extractParamsFromWeigthMap.js 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import { disposeUnusedWeightTensors, extractWeightEntryFactory } from '../common';
  2. import { isTensor3D } from '../utils';
  3. function extractorsFactory(weightMap, paramMappings) {
  4. var extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
  5. function extractPointwiseConvParams(prefix, idx, mappedPrefix) {
  6. var filters = extractWeightEntry(prefix + "/Conv2d_" + idx + "_pointwise/weights", 4, mappedPrefix + "/filters");
  7. var batch_norm_offset = extractWeightEntry(prefix + "/Conv2d_" + idx + "_pointwise/convolution_bn_offset", 1, mappedPrefix + "/batch_norm_offset");
  8. return { filters: filters, batch_norm_offset: batch_norm_offset };
  9. }
  10. function extractConvPairParams(idx) {
  11. var mappedPrefix = "mobilenetv1/conv_" + idx;
  12. var prefixDepthwiseConv = "MobilenetV1/Conv2d_" + idx + "_depthwise";
  13. var mappedPrefixDepthwiseConv = mappedPrefix + "/depthwise_conv";
  14. var mappedPrefixPointwiseConv = mappedPrefix + "/pointwise_conv";
  15. var filters = extractWeightEntry(prefixDepthwiseConv + "/depthwise_weights", 4, mappedPrefixDepthwiseConv + "/filters");
  16. var batch_norm_scale = extractWeightEntry(prefixDepthwiseConv + "/BatchNorm/gamma", 1, mappedPrefixDepthwiseConv + "/batch_norm_scale");
  17. var batch_norm_offset = extractWeightEntry(prefixDepthwiseConv + "/BatchNorm/beta", 1, mappedPrefixDepthwiseConv + "/batch_norm_offset");
  18. var batch_norm_mean = extractWeightEntry(prefixDepthwiseConv + "/BatchNorm/moving_mean", 1, mappedPrefixDepthwiseConv + "/batch_norm_mean");
  19. var batch_norm_variance = extractWeightEntry(prefixDepthwiseConv + "/BatchNorm/moving_variance", 1, mappedPrefixDepthwiseConv + "/batch_norm_variance");
  20. return {
  21. depthwise_conv: {
  22. filters: filters,
  23. batch_norm_scale: batch_norm_scale,
  24. batch_norm_offset: batch_norm_offset,
  25. batch_norm_mean: batch_norm_mean,
  26. batch_norm_variance: batch_norm_variance
  27. },
  28. pointwise_conv: extractPointwiseConvParams('MobilenetV1', idx, mappedPrefixPointwiseConv)
  29. };
  30. }
  31. function extractMobilenetV1Params() {
  32. return {
  33. conv_0: extractPointwiseConvParams('MobilenetV1', 0, 'mobilenetv1/conv_0'),
  34. conv_1: extractConvPairParams(1),
  35. conv_2: extractConvPairParams(2),
  36. conv_3: extractConvPairParams(3),
  37. conv_4: extractConvPairParams(4),
  38. conv_5: extractConvPairParams(5),
  39. conv_6: extractConvPairParams(6),
  40. conv_7: extractConvPairParams(7),
  41. conv_8: extractConvPairParams(8),
  42. conv_9: extractConvPairParams(9),
  43. conv_10: extractConvPairParams(10),
  44. conv_11: extractConvPairParams(11),
  45. conv_12: extractConvPairParams(12),
  46. conv_13: extractConvPairParams(13)
  47. };
  48. }
  49. function extractConvParams(prefix, mappedPrefix) {
  50. var filters = extractWeightEntry(prefix + "/weights", 4, mappedPrefix + "/filters");
  51. var bias = extractWeightEntry(prefix + "/biases", 1, mappedPrefix + "/bias");
  52. return { filters: filters, bias: bias };
  53. }
  54. function extractBoxPredictorParams(idx) {
  55. var box_encoding_predictor = extractConvParams("Prediction/BoxPredictor_" + idx + "/BoxEncodingPredictor", "prediction_layer/box_predictor_" + idx + "/box_encoding_predictor");
  56. var class_predictor = extractConvParams("Prediction/BoxPredictor_" + idx + "/ClassPredictor", "prediction_layer/box_predictor_" + idx + "/class_predictor");
  57. return { box_encoding_predictor: box_encoding_predictor, class_predictor: class_predictor };
  58. }
  59. function extractPredictionLayerParams() {
  60. return {
  61. conv_0: extractPointwiseConvParams('Prediction', 0, 'prediction_layer/conv_0'),
  62. conv_1: extractPointwiseConvParams('Prediction', 1, 'prediction_layer/conv_1'),
  63. conv_2: extractPointwiseConvParams('Prediction', 2, 'prediction_layer/conv_2'),
  64. conv_3: extractPointwiseConvParams('Prediction', 3, 'prediction_layer/conv_3'),
  65. conv_4: extractPointwiseConvParams('Prediction', 4, 'prediction_layer/conv_4'),
  66. conv_5: extractPointwiseConvParams('Prediction', 5, 'prediction_layer/conv_5'),
  67. conv_6: extractPointwiseConvParams('Prediction', 6, 'prediction_layer/conv_6'),
  68. conv_7: extractPointwiseConvParams('Prediction', 7, 'prediction_layer/conv_7'),
  69. box_predictor_0: extractBoxPredictorParams(0),
  70. box_predictor_1: extractBoxPredictorParams(1),
  71. box_predictor_2: extractBoxPredictorParams(2),
  72. box_predictor_3: extractBoxPredictorParams(3),
  73. box_predictor_4: extractBoxPredictorParams(4),
  74. box_predictor_5: extractBoxPredictorParams(5)
  75. };
  76. }
  77. return {
  78. extractMobilenetV1Params: extractMobilenetV1Params,
  79. extractPredictionLayerParams: extractPredictionLayerParams
  80. };
  81. }
  82. export function extractParamsFromWeigthMap(weightMap) {
  83. var paramMappings = [];
  84. var _a = extractorsFactory(weightMap, paramMappings), extractMobilenetV1Params = _a.extractMobilenetV1Params, extractPredictionLayerParams = _a.extractPredictionLayerParams;
  85. var extra_dim = weightMap['Output/extra_dim'];
  86. paramMappings.push({ originalPath: 'Output/extra_dim', paramPath: 'output_layer/extra_dim' });
  87. if (!isTensor3D(extra_dim)) {
  88. throw new Error("expected weightMap['Output/extra_dim'] to be a Tensor3D, instead have " + extra_dim);
  89. }
  90. var params = {
  91. mobilenetv1: extractMobilenetV1Params(),
  92. prediction_layer: extractPredictionLayerParams(),
  93. output_layer: {
  94. extra_dim: extra_dim
  95. }
  96. };
  97. disposeUnusedWeightTensors(weightMap, paramMappings);
  98. return { params: params, paramMappings: paramMappings };
  99. }
  100. //# sourceMappingURL=extractParamsFromWeigthMap.js.map