extractParamsFromWeigthMap.js 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var common_1 = require("../common");
  4. var utils_1 = require("../utils");
  5. function extractorsFactory(weightMap, paramMappings) {
  6. var extractWeightEntry = common_1.extractWeightEntryFactory(weightMap, paramMappings);
  7. function extractScaleLayerParams(prefix) {
  8. var weights = extractWeightEntry(prefix + "/scale/weights", 1);
  9. var biases = extractWeightEntry(prefix + "/scale/biases", 1);
  10. return { weights: weights, biases: biases };
  11. }
  12. function extractConvLayerParams(prefix) {
  13. var filters = extractWeightEntry(prefix + "/conv/filters", 4);
  14. var bias = extractWeightEntry(prefix + "/conv/bias", 1);
  15. var scale = extractScaleLayerParams(prefix);
  16. return { conv: { filters: filters, bias: bias }, scale: scale };
  17. }
  18. function extractResidualLayerParams(prefix) {
  19. return {
  20. conv1: extractConvLayerParams(prefix + "/conv1"),
  21. conv2: extractConvLayerParams(prefix + "/conv2")
  22. };
  23. }
  24. return {
  25. extractConvLayerParams: extractConvLayerParams,
  26. extractResidualLayerParams: extractResidualLayerParams
  27. };
  28. }
  29. function extractParamsFromWeigthMap(weightMap) {
  30. var paramMappings = [];
  31. var _a = extractorsFactory(weightMap, paramMappings), extractConvLayerParams = _a.extractConvLayerParams, extractResidualLayerParams = _a.extractResidualLayerParams;
  32. var conv32_down = extractConvLayerParams('conv32_down');
  33. var conv32_1 = extractResidualLayerParams('conv32_1');
  34. var conv32_2 = extractResidualLayerParams('conv32_2');
  35. var conv32_3 = extractResidualLayerParams('conv32_3');
  36. var conv64_down = extractResidualLayerParams('conv64_down');
  37. var conv64_1 = extractResidualLayerParams('conv64_1');
  38. var conv64_2 = extractResidualLayerParams('conv64_2');
  39. var conv64_3 = extractResidualLayerParams('conv64_3');
  40. var conv128_down = extractResidualLayerParams('conv128_down');
  41. var conv128_1 = extractResidualLayerParams('conv128_1');
  42. var conv128_2 = extractResidualLayerParams('conv128_2');
  43. var conv256_down = extractResidualLayerParams('conv256_down');
  44. var conv256_1 = extractResidualLayerParams('conv256_1');
  45. var conv256_2 = extractResidualLayerParams('conv256_2');
  46. var conv256_down_out = extractResidualLayerParams('conv256_down_out');
  47. var fc = weightMap['fc'];
  48. paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });
  49. if (!utils_1.isTensor2D(fc)) {
  50. throw new Error("expected weightMap[fc] to be a Tensor2D, instead have " + fc);
  51. }
  52. var params = {
  53. conv32_down: conv32_down,
  54. conv32_1: conv32_1,
  55. conv32_2: conv32_2,
  56. conv32_3: conv32_3,
  57. conv64_down: conv64_down,
  58. conv64_1: conv64_1,
  59. conv64_2: conv64_2,
  60. conv64_3: conv64_3,
  61. conv128_down: conv128_down,
  62. conv128_1: conv128_1,
  63. conv128_2: conv128_2,
  64. conv256_down: conv256_down,
  65. conv256_1: conv256_1,
  66. conv256_2: conv256_2,
  67. conv256_down_out: conv256_down_out,
  68. fc: fc
  69. };
  70. common_1.disposeUnusedWeightTensors(weightMap, paramMappings);
  71. return { params: params, paramMappings: paramMappings };
  72. }
  73. exports.extractParamsFromWeigthMap = extractParamsFromWeigthMap;
  74. //# sourceMappingURL=extractParamsFromWeigthMap.js.map