extractSeparableConvParamsFactory.js 1.2 KB

1234567891011121314151617181920
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { SeparableConvParams } from './types';
  3. export function extractSeparableConvParamsFactory(extractWeights, paramMappings) {
  4. return function (channelsIn, channelsOut, mappedPrefix) {
  5. var depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
  6. var pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
  7. var bias = tf.tensor1d(extractWeights(channelsOut));
  8. paramMappings.push({ paramPath: mappedPrefix + "/depthwise_filter" }, { paramPath: mappedPrefix + "/pointwise_filter" }, { paramPath: mappedPrefix + "/bias" });
  9. return new SeparableConvParams(depthwise_filter, pointwise_filter, bias);
  10. };
  11. }
  12. export function loadSeparableConvParamsFactory(extractWeightEntry) {
  13. return function (prefix) {
  14. var depthwise_filter = extractWeightEntry(prefix + "/depthwise_filter", 4);
  15. var pointwise_filter = extractWeightEntry(prefix + "/pointwise_filter", 4);
  16. var bias = extractWeightEntry(prefix + "/bias", 1);
  17. return new SeparableConvParams(depthwise_filter, pointwise_filter, bias);
  18. };
  19. }
  20. //# sourceMappingURL=extractSeparableConvParamsFactory.js.map