extractParams.js 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var common_1 = require("../common");
  5. var extractSeparableConvParamsFactory_1 = require("../common/extractSeparableConvParamsFactory");
  6. var extractWeightsFactory_1 = require("../common/extractWeightsFactory");
  /**
   * Builds the layer-level extractors used by extractParams.
   *
   * Every extractor returned here is created with the same `extractWeights`
   * function and the same `paramMappings` array. Each call therefore reads
   * from the shared weight source and appends its parameter paths to the
   * shared list.
   *
   * @param extractWeights - function taking a count and returning that many
   *   weights (supplied by extractWeightsFactory).
   * @param paramMappings - array that receives `{ paramPath }` entries.
   * @returns {{ extractConvParams: Function,
   *             extractConvWithBatchNormParams: Function,
   *             extractSeparableConvParams: Function }}
   */
  function extractorsFactory(extractWeights, paramMappings) {
    var convFromWeights = common_1.extractConvParamsFactory(extractWeights, paramMappings);
    var separableFromWeights = extractSeparableConvParamsFactory_1.extractSeparableConvParamsFactory(extractWeights, paramMappings);

    // Reads the two batch-norm vectors (`sub`, then `truediv`), each `size`
    // long. Their mappings are recorded only after both tensors exist.
    function readBatchNorm(size, mappedPrefix) {
      var subTensor = tf.tensor1d(extractWeights(size));
      var truedivTensor = tf.tensor1d(extractWeights(size));
      paramMappings.push({ paramPath: mappedPrefix + "/sub" });
      paramMappings.push({ paramPath: mappedPrefix + "/truediv" });
      return { sub: subTensor, truediv: truedivTensor };
    }

    return {
      extractConvParams: convFromWeights,
      // 3x3 convolution followed by its batch-norm vectors; the conv weights
      // are consumed before the batch-norm weights.
      extractConvWithBatchNormParams: function (channelsIn, channelsOut, mappedPrefix) {
        var convPart = convFromWeights(channelsIn, channelsOut, 3, mappedPrefix + "/conv");
        var bnPart = readBatchNorm(channelsOut, mappedPrefix + "/bn");
        return { conv: convPart, bn: bnPart };
      },
      extractSeparableConvParams: separableFromWeights
    };
  }
  /**
   * Extracts all layer parameters (conv0 ... conv8) from a flat weight buffer.
   *
   * Weights are consumed sequentially in the order the extractors are called.
   * The call order below (conv0 first, conv8 last) must therefore match the
   * order in which the weights were serialized.
   *
   * @param weights - weight buffer handed to extractWeightsFactory.
   * @param config - reads `withSeparableConvs`; in the separable branch it
   *   also reads `isFirstLayerConv2d` (regular 3x3 conv for conv0 when set).
   * @param boxEncodingSize - conv8 produces 5 * boxEncodingSize output channels.
   * @param filterSizes - channel counts s0..s8. In the separable branch s7 and
   *   s8 are optional. When s7 is falsy, conv6 is undefined; when s8 is falsy,
   *   conv7 is undefined. conv8 reads from the last truthy size of s8, s7, s6.
   * @returns {{ params: Object, paramMappings: Array }} the extracted layers
   *   plus the `{ paramPath }` entries recorded by the extractors.
   * @throws {Error} if any weights are left over after conv8 is extracted.
   */
  function extractParams(weights, config, boxEncodingSize, filterSizes) {
    var weightSource = extractWeightsFactory_1.extractWeightsFactory(weights);
    var extractWeights = weightSource.extractWeights;
    var getRemainingWeights = weightSource.getRemainingWeights;
    var paramMappings = [];
    var extractors = extractorsFactory(extractWeights, paramMappings);
    var extractConvParams = extractors.extractConvParams;
    var extractConvWithBatchNormParams = extractors.extractConvWithBatchNormParams;
    var extractSeparableConvParams = extractors.extractSeparableConvParams;

    // Both branches index the same nine sizes, so read them once.
    var s0 = filterSizes[0];
    var s1 = filterSizes[1];
    var s2 = filterSizes[2];
    var s3 = filterSizes[3];
    var s4 = filterSizes[4];
    var s5 = filterSizes[5];
    var s6 = filterSizes[6];
    var s7 = filterSizes[7];
    var s8 = filterSizes[8];

    // NOTE(review): 5 presumably is the number of anchor boxes predicted per
    // cell; confirm against the model config before changing it.
    var outputChannels = 5 * boxEncodingSize;

    // Object-literal properties are evaluated in source order, which keeps
    // the weight consumption order conv0 -> conv8.
    var params;
    if (config.withSeparableConvs) {
      params = {
        conv0: config.isFirstLayerConv2d
          ? extractConvParams(s0, s1, 3, 'conv0')
          : extractSeparableConvParams(s0, s1, 'conv0'),
        conv1: extractSeparableConvParams(s1, s2, 'conv1'),
        conv2: extractSeparableConvParams(s2, s3, 'conv2'),
        conv3: extractSeparableConvParams(s3, s4, 'conv3'),
        conv4: extractSeparableConvParams(s4, s5, 'conv4'),
        conv5: extractSeparableConvParams(s5, s6, 'conv5'),
        // The smaller variants omit conv6/conv7 by leaving s7/s8 unset.
        conv6: s7 ? extractSeparableConvParams(s6, s7, 'conv6') : undefined,
        conv7: s8 ? extractSeparableConvParams(s7, s8, 'conv7') : undefined,
        conv8: extractConvParams(s8 || s7 || s6, outputChannels, 1, 'conv8')
      };
    } else {
      params = {
        conv0: extractConvWithBatchNormParams(s0, s1, 'conv0'),
        conv1: extractConvWithBatchNormParams(s1, s2, 'conv1'),
        conv2: extractConvWithBatchNormParams(s2, s3, 'conv2'),
        conv3: extractConvWithBatchNormParams(s3, s4, 'conv3'),
        conv4: extractConvWithBatchNormParams(s4, s5, 'conv4'),
        conv5: extractConvWithBatchNormParams(s5, s6, 'conv5'),
        conv6: extractConvWithBatchNormParams(s6, s7, 'conv6'),
        conv7: extractConvWithBatchNormParams(s7, s8, 'conv7'),
        conv8: extractConvParams(s8, outputChannels, 1, 'conv8')
      };
    }

    // Leftover weights mean filterSizes/config do not match the weight buffer.
    var remaining = getRemainingWeights().length;
    if (remaining !== 0) {
      throw new Error("weights remaining after extract: " + remaining);
    }
    return { params: params, paramMappings: paramMappings };
  }
  // Expose extractParams as this CommonJS module's named export; the source
  // map pragma below must remain the final line of the file.
  exports["extractParams"] = extractParams;
  //# sourceMappingURL=extractParams.js.map