extractParamsFromWeigthMap.js 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import { disposeUnusedWeightTensors, extractWeightEntryFactory, loadSeparableConvParamsFactory, } from '../common';
  2. import { loadConvParamsFactory } from '../common/loadConvParamsFactory';
  3. import { range } from '../utils';
  /**
   * Builds the loader functions used to read network parameters out of a weight map.
   *
   * All loaders share one `extractWeightEntry` created from `weightMap` and
   * `paramMappings`. Given the name of extractWeightEntryFactory, each extraction
   * presumably appends an entry to `paramMappings` (TODO confirm in ../common), so the
   * order of the extraction calls below is kept deliberately.
   *
   * @param {Object} weightMap - source of the weights; its layout is defined by the
   *   ../common helpers (not visible here).
   * @param {Array} paramMappings - array shared with every loader.
   * @returns {{extractConvParams: Function, extractSeparableConvParams: Function,
   *   extractReductionBlockParams: Function, extractMainBlockParams: Function}}
   */
  function loadParamsFactory(weightMap, paramMappings) {
    const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
    const extractConvParams = loadConvParamsFactory(extractWeightEntry);
    const extractSeparableConvParams = loadSeparableConvParamsFactory(extractWeightEntry);

    /**
     * Reads a reduction block stored under `mappedPrefix`.
     * It contains two separable convolutions followed by a regular `expansion_conv`.
     * @param {string} mappedPrefix - path prefix of the block inside the weight map.
     */
    function extractReductionBlockParams(mappedPrefix) {
      const separable_conv0 = extractSeparableConvParams(`${mappedPrefix}/separable_conv0`);
      const separable_conv1 = extractSeparableConvParams(`${mappedPrefix}/separable_conv1`);
      const expansion_conv = extractConvParams(`${mappedPrefix}/expansion_conv`);
      return { separable_conv0, separable_conv1, expansion_conv };
    }

    /**
     * Reads a main block stored under `mappedPrefix`: three separable convolutions.
     * @param {string} mappedPrefix - path prefix of the block inside the weight map.
     */
    function extractMainBlockParams(mappedPrefix) {
      const separable_conv0 = extractSeparableConvParams(`${mappedPrefix}/separable_conv0`);
      const separable_conv1 = extractSeparableConvParams(`${mappedPrefix}/separable_conv1`);
      const separable_conv2 = extractSeparableConvParams(`${mappedPrefix}/separable_conv2`);
      return { separable_conv0, separable_conv1, separable_conv2 };
    }

    return {
      extractConvParams,
      extractSeparableConvParams,
      extractReductionBlockParams,
      extractMainBlockParams,
    };
  }
  /**
   * Collects the entry/middle/exit-flow parameter tree from a weight map.
   *
   * Loaders come from loadParamsFactory and are invoked in this fixed order:
   * - entry_flow/conv_in
   * - entry_flow/reduction_block_0
   * - entry_flow/reduction_block_1
   * - middle_flow/main_block_<idx> for each idx yielded by range(numMainBlocks, 0, 1)
   * - exit_flow/reduction_block
   * - exit_flow/separable_conv
   *
   * The shared `paramMappings` array is then passed, together with `weightMap`,
   * to disposeUnusedWeightTensors.
   *
   * @param {Object} weightMap - weights to read from; layout defined by the loader helpers.
   * @param {number} numMainBlocks - number of middle_flow main blocks to read
   *   (assumes range(n, 0, 1) yields 0..n-1 — confirm in ../utils).
   * @returns {{params: {entry_flow: Object, middle_flow: Object, exit_flow: Object},
   *   paramMappings: Array}} the nested parameters plus the shared mappings array.
   */
  export function extractParamsFromWeigthMap(weightMap, numMainBlocks) {
    // Shared with every loader. The extraction order below is preserved because it
    // likely determines the order of entries in this array.
    const paramMappings = [];
    const {
      extractConvParams,
      extractSeparableConvParams,
      extractReductionBlockParams,
      extractMainBlockParams,
    } = loadParamsFactory(weightMap, paramMappings);

    // Object-literal properties are evaluated left to right, so the call order matches
    // the listing in the doc comment.
    const entry_flow = {
      conv_in: extractConvParams('entry_flow/conv_in'),
      reduction_block_0: extractReductionBlockParams('entry_flow/reduction_block_0'),
      reduction_block_1: extractReductionBlockParams('entry_flow/reduction_block_1'),
    };

    const middle_flow = {};
    range(numMainBlocks, 0, 1).forEach((idx) => {
      middle_flow[`main_block_${idx}`] = extractMainBlockParams(`middle_flow/main_block_${idx}`);
    });

    const exit_flow = {
      reduction_block: extractReductionBlockParams('exit_flow/reduction_block'),
      separable_conv: extractSeparableConvParams('exit_flow/separable_conv'),
    };

    // NOTE(review): by its name this releases weight tensors not referenced by
    // paramMappings — confirm in ../common.
    disposeUnusedWeightTensors(weightMap, paramMappings);

    return { params: { entry_flow, middle_flow, exit_flow }, paramMappings };
  }
  51. //# sourceMappingURL=extractParamsFromWeigthMap.js.map