// extractParamsFromWeigthMap.js (3.3 KB)

  "use strict";
  // Marks this CommonJS module as transpiled from an ES module for interop.
  Object.defineProperty(exports, "__esModule", { value: true });
  // Shared helpers used by the extractors below.
  var disposeUnusedWeightTensors_1 = require("../common/disposeUnusedWeightTensors");
  var extractSeparableConvParamsFactory_1 = require("../common/extractSeparableConvParamsFactory");
  var extractWeightEntryFactory_1 = require("../common/extractWeightEntryFactory");
  /**
   * Builds the parameter extractors used to read layer weights out of a weight map.
   *
   * All extractors share one `extractWeightEntry` function created by
   * `extractWeightEntryFactory(weightMap, paramMappings)`.
   *
   * @param {Object} weightMap - Weight map forwarded to `extractWeightEntryFactory`.
   * @param {Array} paramMappings - Array forwarded to `extractWeightEntryFactory`
   *   (presumably filled with one mapping per extracted entry - confirm against that helper).
   * @returns {{extractConvParams: Function, extractConvWithBatchNormParams: Function, extractSeparableConvParams: Function}}
   *   Extractors that each take a path prefix string.
   */
  function extractorsFactory(weightMap, paramMappings) {
    var extractWeightEntry = extractWeightEntryFactory_1.extractWeightEntryFactory(weightMap, paramMappings);

    // NOTE(review): the numeric second argument of extractWeightEntry is presumably
    // the expected tensor rank - confirm against extractWeightEntryFactory.

    /** Reads the "<prefix>/sub" and "<prefix>/truediv" entries. */
    function extractBatchNormParams(prefix) {
      var sub = extractWeightEntry(prefix + "/sub", 1);
      var truediv = extractWeightEntry(prefix + "/truediv", 1);
      return { sub: sub, truediv: truediv };
    }

    /** Reads the "<prefix>/filters" and "<prefix>/bias" entries. */
    function extractConvParams(prefix) {
      var filters = extractWeightEntry(prefix + "/filters", 4);
      var bias = extractWeightEntry(prefix + "/bias", 1);
      return { filters: filters, bias: bias };
    }

    /** Reads conv params under "<prefix>/conv", then batch norm params under "<prefix>/bn". */
    function extractConvWithBatchNormParams(prefix) {
      var conv = extractConvParams(prefix + "/conv");
      var bn = extractBatchNormParams(prefix + "/bn");
      return { conv: conv, bn: bn };
    }

    var extractSeparableConvParams = extractSeparableConvParamsFactory_1.loadSeparableConvParamsFactory(extractWeightEntry);

    return {
      extractConvParams: extractConvParams,
      extractConvWithBatchNormParams: extractConvWithBatchNormParams,
      extractSeparableConvParams: extractSeparableConvParams
    };
  }
  /**
   * Extracts the parameters of the nine layers `conv0` .. `conv8` from a weight map.
   *
   * Two layouts are supported, selected by `config.withSeparableConvs`:
   *  - truthy: separable conv params for conv0..conv7 and plain conv params for conv8.
   *    conv0 uses plain conv params instead when `config.isFirstLayerConv2d` is truthy.
   *    conv6 is extracted only when the filter count is greater than 7 and conv7 only
   *    when it is greater than 8; otherwise the key is present with value `undefined`.
   *  - falsy: conv + batch norm params for conv0..conv7 and plain conv params for conv8.
   *
   * After extraction, `disposeUnusedWeightTensors(weightMap, paramMappings)` is called
   * (presumably to free tensors that were not mapped - confirm against that helper).
   *
   * @param {Object} weightMap - Weight map the extractors read from.
   * @param {Object} config - Reads `withSeparableConvs`, `isFirstLayerConv2d` and `filterSizes`.
   * @returns {{params: Object, paramMappings: Array}} The extracted params and the
   *   mappings array that was handed to the extractors.
   */
  function extractParamsFromWeigthMap(weightMap, config) {
    var paramMappings = [];
    var extractors = extractorsFactory(weightMap, paramMappings);
    var extractConvParams = extractors.extractConvParams;
    var extractConvWithBatchNormParams = extractors.extractConvWithBatchNormParams;
    var extractSeparableConvParams = extractors.extractSeparableConvParams;
    var params;

    if (config.withSeparableConvs) {
      // Falls back to 9 when filterSizes is missing or empty (a length of 0 is falsy).
      var numFilters = (config.filterSizes && config.filterSizes.length || 9);
      params = {
        conv0: config.isFirstLayerConv2d ? extractConvParams('conv0') : extractSeparableConvParams('conv0'),
        conv1: extractSeparableConvParams('conv1'),
        conv2: extractSeparableConvParams('conv2'),
        conv3: extractSeparableConvParams('conv3'),
        conv4: extractSeparableConvParams('conv4'),
        conv5: extractSeparableConvParams('conv5'),
        // Skipped layers keep their key, set to undefined.
        conv6: numFilters > 7 ? extractSeparableConvParams('conv6') : undefined,
        conv7: numFilters > 8 ? extractSeparableConvParams('conv7') : undefined,
        conv8: extractConvParams('conv8')
      };
    }
    else {
      params = {
        conv0: extractConvWithBatchNormParams('conv0'),
        conv1: extractConvWithBatchNormParams('conv1'),
        conv2: extractConvWithBatchNormParams('conv2'),
        conv3: extractConvWithBatchNormParams('conv3'),
        conv4: extractConvWithBatchNormParams('conv4'),
        conv5: extractConvWithBatchNormParams('conv5'),
        conv6: extractConvWithBatchNormParams('conv6'),
        conv7: extractConvWithBatchNormParams('conv7'),
        conv8: extractConvParams('conv8')
      };
    }

    // Runs after extraction so paramMappings already lists every extracted entry.
    disposeUnusedWeightTensors_1.disposeUnusedWeightTensors(weightMap, paramMappings);
    return { params: params, paramMappings: paramMappings };
  }
  // The exported name (including its "Weigth" spelling) is the public interface; keep it as is.
  exports.extractParamsFromWeigthMap = extractParamsFromWeigthMap;
  //# sourceMappingURL=extractParamsFromWeigthMap.js.map