extractParams.js 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var common_1 = require("../common");
  4. var utils_1 = require("../utils");
  /**
   * Creates the layer- and block-level parameter extractors used by extractParams.
   *
   * Every returned extractor is built from the same extractWeights/paramMappings
   * pair, so the order in which they are invoked presumably determines which
   * part of the weight buffer each layer receives (TODO confirm against ../common).
   *
   * @param extractWeights - weight reader obtained from common_1.extractWeightsFactory
   * @param paramMappings - array handed to the common_1 factories (presumably
   *   appended to once per extracted tensor — verify in ../common)
   * @returns {{ extractConvParams: Function, extractSeparableConvParams: Function,
   *   extractReductionBlockParams: Function, extractMainBlockParams: Function }}
   */
  function extractorsFactory(extractWeights, paramMappings) {
      var convParams = common_1.extractConvParamsFactory(extractWeights, paramMappings);
      var separableConvParams = common_1.extractSeparableConvParamsFactory(extractWeights, paramMappings);
      /**
       * Reduction block: separable_conv0 (channelsIn -> channelsOut),
       * separable_conv1 (channelsOut -> channelsOut), then expansion_conv
       * (channelsIn -> channelsOut, size argument 1 — presumably a 1x1 conv).
       */
      function extractReductionBlockParams(channelsIn, channelsOut, mappedPrefix) {
          var block = {};
          // Statement order is significant: it fixes the extraction order.
          block.separable_conv0 = separableConvParams(channelsIn, channelsOut, mappedPrefix + "/separable_conv0");
          block.separable_conv1 = separableConvParams(channelsOut, channelsOut, mappedPrefix + "/separable_conv1");
          block.expansion_conv = convParams(channelsIn, channelsOut, 1, mappedPrefix + "/expansion_conv");
          return block;
      }
      /**
       * Main block: three channel-preserving separable convs, extracted in the
       * order separable_conv0, separable_conv1, separable_conv2.
       */
      function extractMainBlockParams(channels, mappedPrefix) {
          var block = {};
          ["separable_conv0", "separable_conv1", "separable_conv2"].forEach(function (layerName) {
              block[layerName] = separableConvParams(channels, channels, mappedPrefix + "/" + layerName);
          });
          return block;
      }
      return {
          extractConvParams: convParams,
          extractSeparableConvParams: separableConvParams,
          extractReductionBlockParams: extractReductionBlockParams,
          extractMainBlockParams: extractMainBlockParams
      };
  }
  /**
   * Reads all network parameters from a flat weight buffer.
   *
   * Layers are extracted in a fixed order:
   *   entry flow  - conv_in, reduction_block_0, reduction_block_1
   *   middle flow - numMainBlocks main blocks ("main_block_<idx>")
   *   exit flow   - reduction_block, separable_conv
   * That order presumably has to match the layout of `weights` (TODO confirm
   * against the weight export format).
   *
   * @param weights - buffer passed to common_1.extractWeightsFactory
   *   (presumably a Float32Array — verify against callers)
   * @param numMainBlocks - forwarded to utils_1.range(numMainBlocks, 0, 1); one
   *   middle-flow main block is extracted per index that range yields
   * @returns {{ paramMappings: Array, params: { entry_flow: Object, middle_flow: Object, exit_flow: Object } }}
   * @throws {Error} when weights are left over after every layer has been extracted
   */
  function extractParams(weights, numMainBlocks) {
      var paramMappings = [];
      var weightReader = common_1.extractWeightsFactory(weights);
      var extractWeights = weightReader.extractWeights;
      var getRemainingWeights = weightReader.getRemainingWeights;
      var extractors = extractorsFactory(extractWeights, paramMappings);
      // Pulled into locals so the extractors are invoked as plain functions.
      var extractConvParams = extractors.extractConvParams;
      var extractSeparableConvParams = extractors.extractSeparableConvParams;
      var extractReductionBlockParams = extractors.extractReductionBlockParams;
      var extractMainBlockParams = extractors.extractMainBlockParams;
      // Entry flow: 3 -> 32 -> 64 -> 128 channels.
      var entry_flow_conv_in = extractConvParams(3, 32, 3, 'entry_flow/conv_in');
      var entry_flow_reduction_block_0 = extractReductionBlockParams(32, 64, 'entry_flow/reduction_block_0');
      var entry_flow_reduction_block_1 = extractReductionBlockParams(64, 128, 'entry_flow/reduction_block_1');
      var entry_flow = {
          conv_in: entry_flow_conv_in,
          reduction_block_0: entry_flow_reduction_block_0,
          reduction_block_1: entry_flow_reduction_block_1
      };
      // Middle flow: numMainBlocks channel-preserving main blocks at 128 channels.
      var middle_flow = {};
      utils_1.range(numMainBlocks, 0, 1).forEach(function (idx) {
          middle_flow["main_block_" + idx] = extractMainBlockParams(128, "middle_flow/main_block_" + idx);
      });
      // Exit flow: 128 -> 256 -> 512 channels.
      var exit_flow_reduction_block = extractReductionBlockParams(128, 256, 'exit_flow/reduction_block');
      var exit_flow_separable_conv = extractSeparableConvParams(256, 512, 'exit_flow/separable_conv');
      var exit_flow = {
          reduction_block: exit_flow_reduction_block,
          separable_conv: exit_flow_separable_conv
      };
      // Leftover weights likely mean the buffer does not match this architecture
      // (e.g. a wrong numMainBlocks), so fail loudly instead of returning.
      var numRemainingWeights = getRemainingWeights().length;
      if (numRemainingWeights !== 0) {
          throw new Error("weights remaining after extract: " + numRemainingWeights);
      }
      return {
          paramMappings: paramMappings,
          params: { entry_flow: entry_flow, middle_flow: middle_flow, exit_flow: exit_flow }
      };
  }
  57. exports.extractParams = extractParams;
  58. //# sourceMappingURL=extractParams.js.map