extractParams.js 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var common_1 = require("../common");
  /**
   * Builds the two extractor functions that turn raw weight values into the
   * tensors of the `mobilenetv1` and `prediction_layer` parameter groups.
   *
   * Every tensor is created from `extractWeights(count)` and a matching
   * `{ paramPath }` entry is appended to `paramMappings`. The order of calls
   * is significant: `extractWeights` presumably consumes consecutive values
   * from one flat weight buffer (confirm against `../common`), so layers must
   * be extracted in exactly the order used below.
   *
   * @param {function(number): *} extractWeights - Called with the number of
   *   values needed; its result is handed directly to the `tf.tensor*`
   *   constructors.
   * @param {Array<{paramPath: string}>} paramMappings - Mutated in place; one
   *   entry is pushed per extracted tensor.
   * @returns {{extractMobilenetV1Params: function(): Object,
   *            extractPredictionLayerParams: function(): Object}}
   */
  function extractorsFactory(extractWeights, paramMappings) {
    // [channelsIn, channelsOut] for mobilenetv1/conv_1 .. conv_13, in order.
    var mobilenetPairChannels = [
      [32, 64],
      [64, 128],
      [128, 128],
      [128, 256],
      [256, 256],
      [256, 512],
      [512, 512],
      [512, 512],
      [512, 512],
      [512, 512],
      [512, 512],
      [512, 1024],
      [1024, 1024]
    ];
    // [channelsIn, channelsOut, filterSize] for prediction_layer/conv_0 .. conv_7.
    var predictionConvSpecs = [
      [1024, 256, 1],
      [256, 512, 3],
      [512, 128, 1],
      [128, 256, 3],
      [256, 128, 1],
      [128, 256, 3],
      [256, 64, 1],
      [64, 128, 3]
    ];
    // [channelsIn, boxEncodingChannelsOut, classChannelsOut] for
    // prediction_layer/box_predictor_0 .. box_predictor_5.
    var boxPredictorSpecs = [
      [512, 12, 9],
      [1024, 24, 18],
      [512, 24, 18],
      [256, 24, 18],
      [256, 24, 18],
      [128, 24, 18]
    ];

    /** Extracts a 3x3 depthwise filter plus its four batch-norm vectors. */
    function extractDepthwiseConvParams(numChannels, mappedPrefix) {
      var filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1]);
      var batch_norm_scale = tf.tensor1d(extractWeights(numChannels));
      var batch_norm_offset = tf.tensor1d(extractWeights(numChannels));
      var batch_norm_mean = tf.tensor1d(extractWeights(numChannels));
      var batch_norm_variance = tf.tensor1d(extractWeights(numChannels));
      paramMappings.push(
        { paramPath: mappedPrefix + "/filters" },
        { paramPath: mappedPrefix + "/batch_norm_scale" },
        { paramPath: mappedPrefix + "/batch_norm_offset" },
        { paramPath: mappedPrefix + "/batch_norm_mean" },
        { paramPath: mappedPrefix + "/batch_norm_variance" }
      );
      return {
        filters: filters,
        batch_norm_scale: batch_norm_scale,
        batch_norm_offset: batch_norm_offset,
        batch_norm_mean: batch_norm_mean,
        batch_norm_variance: batch_norm_variance
      };
    }

    /**
     * Extracts a square conv filter and its per-output-channel vector.
     * The vector is mapped as `batch_norm_offset` when `isPointwiseConv` is
     * truthy and as `bias` otherwise (including when the flag is omitted).
     */
    function extractConvParams(channelsIn, channelsOut, filterSize, mappedPrefix, isPointwiseConv) {
      var filters = tf.tensor4d(
        extractWeights(channelsIn * channelsOut * filterSize * filterSize),
        [filterSize, filterSize, channelsIn, channelsOut]
      );
      var bias = tf.tensor1d(extractWeights(channelsOut));
      paramMappings.push(
        { paramPath: mappedPrefix + "/filters" },
        { paramPath: mappedPrefix + "/" + (isPointwiseConv ? 'batch_norm_offset' : 'bias') }
      );
      return { filters: filters, bias: bias };
    }

    /** Same as extractConvParams, but exposes the vector as `batch_norm_offset`. */
    function extractPointwiseConvParams(channelsIn, channelsOut, filterSize, mappedPrefix) {
      var conv = extractConvParams(channelsIn, channelsOut, filterSize, mappedPrefix, true);
      return {
        filters: conv.filters,
        batch_norm_offset: conv.bias
      };
    }

    /** Extracts a depthwise conv followed by a 1x1 pointwise conv. */
    function extractConvPairParams(channelsIn, channelsOut, mappedPrefix) {
      var depthwise_conv = extractDepthwiseConvParams(channelsIn, mappedPrefix + "/depthwise_conv");
      var pointwise_conv = extractPointwiseConvParams(channelsIn, channelsOut, 1, mappedPrefix + "/pointwise_conv");
      return { depthwise_conv: depthwise_conv, pointwise_conv: pointwise_conv };
    }

    /**
     * Extracts one box predictor: the box-encoding 1x1 conv first, then the
     * class 1x1 conv (this order matches the weight layout being consumed).
     */
    function extractBoxPredictorParams(channelsIn, boxEncodingChannelsOut, classChannelsOut, mappedPrefix) {
      var box_encoding_predictor = extractConvParams(channelsIn, boxEncodingChannelsOut, 1, mappedPrefix + "/box_encoding_predictor");
      var class_predictor = extractConvParams(channelsIn, classChannelsOut, 1, mappedPrefix + "/class_predictor");
      return {
        box_encoding_predictor: box_encoding_predictor,
        class_predictor: class_predictor
      };
    }

    /** Extracts conv_0 (3 -> 32, 3x3) followed by conv_1 .. conv_13 pairs. */
    function extractMobilenetV1Params() {
      var params = {
        conv_0: extractPointwiseConvParams(3, 32, 3, 'mobilenetv1/conv_0')
      };
      for (var i = 0; i < mobilenetPairChannels.length; i++) {
        var name = 'conv_' + (i + 1);
        params[name] = extractConvPairParams(
          mobilenetPairChannels[i][0],
          mobilenetPairChannels[i][1],
          'mobilenetv1/' + name
        );
      }
      return params;
    }

    /** Extracts conv_0 .. conv_7 first, then box_predictor_0 .. box_predictor_5. */
    function extractPredictionLayerParams() {
      var params = {};
      var i;
      for (i = 0; i < predictionConvSpecs.length; i++) {
        var convName = 'conv_' + i;
        params[convName] = extractPointwiseConvParams(
          predictionConvSpecs[i][0],
          predictionConvSpecs[i][1],
          predictionConvSpecs[i][2],
          'prediction_layer/' + convName
        );
      }
      for (i = 0; i < boxPredictorSpecs.length; i++) {
        var predictorName = 'box_predictor_' + i;
        params[predictorName] = extractBoxPredictorParams(
          boxPredictorSpecs[i][0],
          boxPredictorSpecs[i][1],
          boxPredictorSpecs[i][2],
          'prediction_layer/' + predictorName
        );
      }
      return params;
    }

    return {
      extractMobilenetV1Params: extractMobilenetV1Params,
      extractPredictionLayerParams: extractPredictionLayerParams
    };
  }
  /**
   * Extracts all network parameters from `weights`.
   *
   * Extraction order is: the `mobilenetv1` group, the `prediction_layer`
   * group, then the `output_layer/extra_dim` tensor of shape [1, 5118, 4].
   * Afterwards the weight source must be fully consumed.
   *
   * @param {*} weights - Passed unchanged to `common_1.extractWeightsFactory`
   *   (presumably a flat typed array of weight values; confirm against
   *   `../common`).
   * @returns {{params: {mobilenetv1: Object, prediction_layer: Object,
   *            output_layer: {extra_dim: *}},
   *            paramMappings: Array<{paramPath: string}>}} The extracted
   *   tensors and one `{ paramPath }` entry per tensor in extraction order.
   * @throws {Error} If any weights are left over once extraction finishes.
   */
  function extractParams(weights) {
    var paramMappings = [];
    var weightSource = common_1.extractWeightsFactory(weights);
    var extractWeights = weightSource.extractWeights;
    var getRemainingWeights = weightSource.getRemainingWeights;
    var extractors = extractorsFactory(extractWeights, paramMappings);

    // Order matters: each step takes its values from the shared extractor.
    var mobilenetv1 = extractors.extractMobilenetV1Params();
    var prediction_layer = extractors.extractPredictionLayerParams();
    var extra_dim = tf.tensor3d(extractWeights(5118 * 4), [1, 5118, 4]);
    var output_layer = {
      extra_dim: extra_dim
    };
    paramMappings.push({ paramPath: 'output_layer/extra_dim' });

    // Leftover values mean the weights do not match the expected layout.
    var remainingCount = getRemainingWeights().length;
    if (remainingCount !== 0) {
      throw new Error("weights remaining after extract: " + remainingCount);
    }

    return {
      params: {
        mobilenetv1: mobilenetv1,
        prediction_layer: prediction_layer,
        output_layer: output_layer
      },
      paramMappings: paramMappings
    };
  }
  161. exports.extractParams = extractParams;
  162. //# sourceMappingURL=extractParams.js.map