extractParamsFromWeigthMap.js 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import { disposeUnusedWeightTensors } from '../common/disposeUnusedWeightTensors';
  2. import { loadSeparableConvParamsFactory } from '../common/extractSeparableConvParamsFactory';
  3. import { extractWeightEntryFactory } from '../common/extractWeightEntryFactory';
  /**
   * Builds the helpers that pull convolution-related parameters out of a weight map.
   *
   * All lookups go through a single `extractWeightEntry` function created by
   * `extractWeightEntryFactory(weightMap, paramMappings)`. That factory receives the
   * caller's `paramMappings` array, presumably so it can record each extracted path
   * (verify against extractWeightEntryFactory).
   *
   * Paths read, relative to the prefix handed to each helper:
   *   - conv:        `<prefix>/filters` (4), `<prefix>/bias` (1)
   *   - conv + bn:   `<prefix>/conv/...` as above, then `<prefix>/bn/sub` (1), `<prefix>/bn/truediv` (1)
   *   - separable:   delegated to loadSeparableConvParamsFactory
   * The number after each path is forwarded as the second argument of
   * extractWeightEntry (presumably the expected tensor rank — TODO confirm).
   *
   * @param {Object} weightMap - named weights to read from.
   * @param {Array} paramMappings - array shared with the entry extractor.
   * @returns {{extractConvParams: Function, extractConvWithBatchNormParams: Function,
   *   extractSeparableConvParams: Function}} the public extractor helpers.
   */
  function extractorsFactory(weightMap, paramMappings) {
    const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);

    // Private: the batch-norm half of a conv + bn layer (sub is read before truediv).
    const extractBatchNormParams = (prefix) => ({
      sub: extractWeightEntry(`${prefix}/sub`, 1),
      truediv: extractWeightEntry(`${prefix}/truediv`, 1),
    });

    // Plain convolution: filters first, then bias.
    const extractConvParams = (prefix) => ({
      filters: extractWeightEntry(`${prefix}/filters`, 4),
      bias: extractWeightEntry(`${prefix}/bias`, 1),
    });

    // Convolution followed by batch norm, stored under the `conv` and `bn` sub-prefixes.
    const extractConvWithBatchNormParams = (prefix) => ({
      conv: extractConvParams(`${prefix}/conv`),
      bn: extractBatchNormParams(`${prefix}/bn`),
    });

    return {
      extractConvParams,
      extractConvWithBatchNormParams,
      extractSeparableConvParams: loadSeparableConvParamsFactory(extractWeightEntry),
    };
  }
  /**
   * Assembles the per-layer network parameters (keys conv0 .. conv8) from a weight map.
   *
   * Two layouts are supported, selected by `config.withSeparableConvs`:
   *   - separable: conv0 is a plain conv when `config.isFirstLayerConv2d` is truthy,
   *     otherwise separable; conv1..conv5 are separable; conv6 is separable only when
   *     the filter count is > 7 and conv7 only when it is > 8 (otherwise the key is
   *     present with value undefined).
   *   - default: conv0..conv7 are conv + batch-norm layers.
   * In both layouts conv8 is a plain convolution.
   *
   * After extraction, `disposeUnusedWeightTensors(weightMap, paramMappings)` is called.
   *
   * @param {Object} weightMap - named weights (presumably tensors — confirm with callers).
   * @param {Object} config - reads `withSeparableConvs`, `filterSizes`, `isFirstLayerConv2d`.
   * @returns {{params: Object, paramMappings: Array}} the layer parameters and the
   *   mappings array that was shared with the extractors.
   */
  export function extractParamsFromWeigthMap(weightMap, config) {
    const paramMappings = [];
    const {
      extractConvParams,
      extractConvWithBatchNormParams,
      extractSeparableConvParams,
    } = extractorsFactory(weightMap, paramMappings);

    // Layers are filled strictly in order conv0 -> conv8; this fixes both the key
    // order of `params` and the order in which weights are extracted.
    const params = {};

    if (config.withSeparableConvs) {
      // A missing or empty filterSizes list is treated as the full set of 9 filters.
      const filterCount = (config.filterSizes && config.filterSizes.length) || 9;

      params.conv0 = config.isFirstLayerConv2d
        ? extractConvParams('conv0')
        : extractSeparableConvParams('conv0');

      for (let layer = 1; layer <= 7; layer += 1) {
        const layerName = `conv${layer}`;
        // conv1..conv5 always exist; conv6 needs > 7 filters and conv7 needs > 8.
        // Skipped layers keep their key, holding undefined.
        const isPresent = layer <= 5 || filterCount > layer + 1;
        params[layerName] = isPresent ? extractSeparableConvParams(layerName) : undefined;
      }
    } else {
      for (let layer = 0; layer <= 7; layer += 1) {
        const layerName = `conv${layer}`;
        params[layerName] = extractConvWithBatchNormParams(layerName);
      }
    }

    // The last layer is a plain convolution in both layouts.
    params.conv8 = extractConvParams('conv8');

    disposeUnusedWeightTensors(weightMap, paramMappings);
    return { params, paramMappings };
  }
  62. //# sourceMappingURL=extractParamsFromWeigthMap.js.map