// extractFCParamsFactory.js (730 B)

  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  /**
   * Creates an extractor for the parameters of an "FC" layer (presumably a
   * fully-connected / dense layer — naming assumption, confirm).
   *
   * @param {function(number): *} extractWeights - called with an element count;
   *   whatever it returns is passed straight to tf.tensor2d / tf.tensor1d.
   *   NOTE(review): presumably it reads sequentially from a flat weight buffer,
   *   which is why the weights-then-bias call order below must not change.
   * @param {Array<{paramPath: string}>} paramMappings - array that each call of
   *   the returned extractor appends two path records to (mutated in place).
   * @returns {function(number, number, string): {weights: *, bias: *}}
   *   Extractor taking (channelsIn, channelsOut, mappedPrefix) and returning an
   *   object with a [channelsIn, channelsOut] `weights` tensor and a 1-D `bias`.
   */
  function extractFCParamsFactory(extractWeights, paramMappings) {
    return function (channelsIn, channelsOut, mappedPrefix) {
      // Weight matrix first: channelsIn * channelsOut values shaped as a 2-D tensor.
      var weightCount = channelsIn * channelsOut;
      var weightValues = extractWeights(weightCount);
      var weights = tf.tensor2d(weightValues, [channelsIn, channelsOut]);

      // Bias second: one value per output channel.
      var biasValues = extractWeights(channelsOut);
      var bias = tf.tensor1d(biasValues);

      // Record where each tensor lives, relative to the caller-supplied prefix.
      var weightsMapping = { paramPath: mappedPrefix + "/weights" };
      var biasMapping = { paramPath: mappedPrefix + "/bias" };
      paramMappings.push(weightsMapping, biasMapping);

      var params = {};
      params.weights = weights;
      params.bias = bias;
      return params;
    };
  }
  15. exports.extractFCParamsFactory = extractFCParamsFactory;
  16. //# sourceMappingURL=extractFCParamsFactory.js.map