extractFCParamsFactory.ts 711 B

123456789101112131415161718192021222324252627
  1. import * as tf from '../../dist/tfjs.esm';
  2. import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
  /**
   * Creates a reader for fully-connected (dense) layer parameters.
   *
   * Each call to the returned function:
   *  1. requests `channelsIn * channelsOut` values from `extractWeights` and wraps them
   *     in a 2-D tensor of shape `[channelsIn, channelsOut]` (the weight matrix);
   *  2. requests `channelsOut` further values and wraps them in a 1-D tensor (the bias);
   *  3. appends `{ paramPath: `${mappedPrefix}/weights` }` and
   *     `{ paramPath: `${mappedPrefix}/bias` }` to `paramMappings`.
   *
   * @param extractWeights - Supplies the requested number of weight values.
   *   NOTE(review): presumably consumes values sequentially from a flat weight
   *   buffer, so the read order (weights first, then bias) matters — confirm
   *   against the caller.
   * @param paramMappings - Mutated in place: receives one entry per tensor created.
   * @returns A function `(channelsIn, channelsOut, mappedPrefix) => FCParams`. It
   *   throws an `Error` when either channel count is not a positive integer. The
   *   check runs before `extractWeights` is called, so nothing is consumed and
   *   no mapping is recorded on failure.
   */
  export function extractFCParamsFactory(
    extractWeights: ExtractWeightsFunction,
    paramMappings: ParamMapping[],
  ): (channelsIn: number, channelsOut: number, mappedPrefix: string) => FCParams {
    // Guards a channel count; a non-integer or non-positive value would request a
    // nonsensical number of weights from extractWeights.
    const requirePositiveInteger = (name: string, value: number): void => {
      if (!Number.isInteger(value) || value <= 0) {
        throw new Error(`extractFCParamsFactory: ${name} must be a positive integer, got ${value}`);
      }
    };

    return (channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams => {
      // Validate before touching extractWeights so a bad call fails without
      // consuming any values.
      requirePositiveInteger('channelsIn', channelsIn);
      requirePositiveInteger('channelsOut', channelsOut);

      // Order matters: the weight matrix is read before the bias vector.
      const weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
      const bias = tf.tensor1d(extractWeights(channelsOut));

      paramMappings.push(
        { paramPath: `${mappedPrefix}/weights` },
        { paramPath: `${mappedPrefix}/bias` },
      );

      return { weights, bias };
    };
  }