extractSeparableConvParamsFactory.js 1.4 KB

123456789101112131415161718192021222324
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var types_1 = require("./types");
  5. function extractSeparableConvParamsFactory(extractWeights, paramMappings) {
  6. return function (channelsIn, channelsOut, mappedPrefix) {
  7. var depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
  8. var pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
  9. var bias = tf.tensor1d(extractWeights(channelsOut));
  10. paramMappings.push({ paramPath: mappedPrefix + "/depthwise_filter" }, { paramPath: mappedPrefix + "/pointwise_filter" }, { paramPath: mappedPrefix + "/bias" });
  11. return new types_1.SeparableConvParams(depthwise_filter, pointwise_filter, bias);
  12. };
  13. }
  14. exports.extractSeparableConvParamsFactory = extractSeparableConvParamsFactory;
  15. function loadSeparableConvParamsFactory(extractWeightEntry) {
  16. return function (prefix) {
  17. var depthwise_filter = extractWeightEntry(prefix + "/depthwise_filter", 4);
  18. var pointwise_filter = extractWeightEntry(prefix + "/pointwise_filter", 4);
  19. var bias = extractWeightEntry(prefix + "/bias", 1);
  20. return new types_1.SeparableConvParams(depthwise_filter, pointwise_filter, bias);
  21. };
  22. }
  23. exports.loadSeparableConvParamsFactory = loadSeparableConvParamsFactory;
  24. //# sourceMappingURL=extractSeparableConvParamsFactory.js.map