extractParams.js 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { extractConvParamsFactory } from '../common';
  3. import { extractSeparableConvParamsFactory } from '../common/extractSeparableConvParamsFactory';
  4. import { extractWeightsFactory } from '../common/extractWeightsFactory';
  5. function extractorsFactory(extractWeights, paramMappings) {
  6. var extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
  7. function extractBatchNormParams(size, mappedPrefix) {
  8. var sub = tf.tensor1d(extractWeights(size));
  9. var truediv = tf.tensor1d(extractWeights(size));
  10. paramMappings.push({ paramPath: mappedPrefix + "/sub" }, { paramPath: mappedPrefix + "/truediv" });
  11. return { sub: sub, truediv: truediv };
  12. }
  13. function extractConvWithBatchNormParams(channelsIn, channelsOut, mappedPrefix) {
  14. var conv = extractConvParams(channelsIn, channelsOut, 3, mappedPrefix + "/conv");
  15. var bn = extractBatchNormParams(channelsOut, mappedPrefix + "/bn");
  16. return { conv: conv, bn: bn };
  17. }
  18. var extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
  19. return {
  20. extractConvParams: extractConvParams,
  21. extractConvWithBatchNormParams: extractConvWithBatchNormParams,
  22. extractSeparableConvParams: extractSeparableConvParams
  23. };
  24. }
  25. export function extractParams(weights, config, boxEncodingSize, filterSizes) {
  26. var _a = extractWeightsFactory(weights), extractWeights = _a.extractWeights, getRemainingWeights = _a.getRemainingWeights;
  27. var paramMappings = [];
  28. var _b = extractorsFactory(extractWeights, paramMappings), extractConvParams = _b.extractConvParams, extractConvWithBatchNormParams = _b.extractConvWithBatchNormParams, extractSeparableConvParams = _b.extractSeparableConvParams;
  29. var params;
  30. if (config.withSeparableConvs) {
  31. var s0 = filterSizes[0], s1 = filterSizes[1], s2 = filterSizes[2], s3 = filterSizes[3], s4 = filterSizes[4], s5 = filterSizes[5], s6 = filterSizes[6], s7 = filterSizes[7], s8 = filterSizes[8];
  32. var conv0 = config.isFirstLayerConv2d
  33. ? extractConvParams(s0, s1, 3, 'conv0')
  34. : extractSeparableConvParams(s0, s1, 'conv0');
  35. var conv1 = extractSeparableConvParams(s1, s2, 'conv1');
  36. var conv2 = extractSeparableConvParams(s2, s3, 'conv2');
  37. var conv3 = extractSeparableConvParams(s3, s4, 'conv3');
  38. var conv4 = extractSeparableConvParams(s4, s5, 'conv4');
  39. var conv5 = extractSeparableConvParams(s5, s6, 'conv5');
  40. var conv6 = s7 ? extractSeparableConvParams(s6, s7, 'conv6') : undefined;
  41. var conv7 = s8 ? extractSeparableConvParams(s7, s8, 'conv7') : undefined;
  42. var conv8 = extractConvParams(s8 || s7 || s6, 5 * boxEncodingSize, 1, 'conv8');
  43. params = { conv0: conv0, conv1: conv1, conv2: conv2, conv3: conv3, conv4: conv4, conv5: conv5, conv6: conv6, conv7: conv7, conv8: conv8 };
  44. }
  45. else {
  46. var s0 = filterSizes[0], s1 = filterSizes[1], s2 = filterSizes[2], s3 = filterSizes[3], s4 = filterSizes[4], s5 = filterSizes[5], s6 = filterSizes[6], s7 = filterSizes[7], s8 = filterSizes[8];
  47. var conv0 = extractConvWithBatchNormParams(s0, s1, 'conv0');
  48. var conv1 = extractConvWithBatchNormParams(s1, s2, 'conv1');
  49. var conv2 = extractConvWithBatchNormParams(s2, s3, 'conv2');
  50. var conv3 = extractConvWithBatchNormParams(s3, s4, 'conv3');
  51. var conv4 = extractConvWithBatchNormParams(s4, s5, 'conv4');
  52. var conv5 = extractConvWithBatchNormParams(s5, s6, 'conv5');
  53. var conv6 = extractConvWithBatchNormParams(s6, s7, 'conv6');
  54. var conv7 = extractConvWithBatchNormParams(s7, s8, 'conv7');
  55. var conv8 = extractConvParams(s8, 5 * boxEncodingSize, 1, 'conv8');
  56. params = { conv0: conv0, conv1: conv1, conv2: conv2, conv3: conv3, conv4: conv4, conv5: conv5, conv6: conv6, conv7: conv7, conv8: conv8 };
  57. }
  58. if (getRemainingWeights().length !== 0) {
  59. throw new Error("weights remaing after extract: " + getRemainingWeights().length);
  60. }
  61. return { params: params, paramMappings: paramMappings };
  62. }
  63. //# sourceMappingURL=extractParams.js.map