// NeuralNetwork.d.ts (1.6 KB)
  import * as tf from '@tensorflow/tfjs-core';
  import { ParamMapping } from './common';
  /**
   * Abstract base for a network whose weights are tfjs tensors held in a
   * `TNetParams` object, with parameters addressable by string paths.
   *
   * Concrete networks provide three pieces through the abstract members:
   * a default model name, and conversion of either a `tf.NamedTensorMap` or a
   * flat `Float32Array` into `{ params, paramMappings }`.
   *
   * NOTE(review): the link between `paramMappings` and the string paths used by
   * `getParamFromPath` is inferred from naming — confirm against `ParamMapping`.
   */
  export declare abstract class NeuralNetwork<TNetParams> {
    /** Network name supplied to the constructor. */
    protected _name: string;
    /** Current parameters; `undefined` when none are set (presumably before loading). */
    protected _params: TNetParams | undefined;
    /** Mappings describing the parameters held in `_params`. */
    protected _paramMappings: ParamMapping[];

    /**
     * @param _name - Network name (presumably stored in `_name`).
     */
    constructor(_name: string);

    /** Current parameters, or `undefined` if none are set. */
    get params(): TNetParams | undefined;
    /** Parameter mappings for the current parameter set. */
    get paramMappings(): ParamMapping[];
    /**
     * Whether parameters are available.
     * NOTE(review): presumably equivalent to `params !== undefined` — confirm.
     */
    get isLoaded(): boolean;

    /**
     * Looks up a parameter tensor by its path.
     * @param paramPath - Path of the parameter within the parameter tree.
     * @returns The tensor stored at `paramPath`.
     */
    getParamFromPath(paramPath: string): tf.Tensor;
    /**
     * Replaces the parameter tensor at `paramPath` with `tensor`.
     * NOTE(review): whether the previous tensor is disposed is not visible here.
     */
    reassignParamFromPath(paramPath: string, tensor: tf.Tensor): void;

    /** Lists every parameter as a `{ path, tensor }` pair. */
    getParamList(): {
      path: string;
      tensor: tf.Tensor<tf.Rank>;
    }[];
    /**
     * Lists the trainable parameters as `{ path, tensor }` pairs.
     * NOTE(review): presumably those held as `tf.Variable` (see `variable()`) — confirm.
     */
    getTrainableParams(): {
      path: string;
      tensor: tf.Tensor<tf.Rank>;
    }[];
    /**
     * Lists the frozen (non-trainable) parameters as `{ path, tensor }` pairs.
     * NOTE(review): presumably the complement of `getTrainableParams()` — confirm.
     */
    getFrozenParams(): {
      path: string;
      tensor: tf.Tensor<tf.Rank>;
    }[];

    /** Makes the parameters trainable (presumably by converting frozen tensors to variables). */
    variable(): void;
    /** Freezes the parameters (presumably the inverse of `variable()`). */
    freeze(): void;
    /**
     * Releases the parameter tensors.
     * @param throwOnRedispose - Presumably controls whether disposing an
     *   already-disposed tensor throws; the default is not visible in this declaration.
     */
    dispose(throwOnRedispose?: boolean): void;
    /**
     * Flattens all parameters into one `Float32Array`.
     * NOTE(review): presumably in `getParamList()` order, i.e. the layout
     * `extractWeights` consumes — confirm.
     */
    serializeParams(): Float32Array;

    /**
     * Loads parameters from raw weights or from a URI/path string.
     * @param weightsOrUrl - Flat weight buffer, a location string, or `undefined`
     *   (presumably falls back to a default location — confirm).
     */
    load(weightsOrUrl: Float32Array | string | undefined): Promise<void>;
    /**
     * Loads parameters from a URI.
     * @param uri - Model location; `undefined` presumably selects a default derived
     *   from `getDefaultModelName()` — confirm.
     */
    loadFromUri(uri: string | undefined): Promise<void>;
    /**
     * Loads parameters from the local file system.
     * @param filePath - Model path; `undefined` presumably selects a default — confirm.
     */
    loadFromDisk(filePath: string | undefined): Promise<void>;
    /** Sets parameters from a named tensor map (presumably via `extractParamsFromWeigthMap`). */
    loadFromWeightMap(weightMap: tf.NamedTensorMap): void;
    /** Sets parameters from a flat weight buffer (presumably via `extractParams`). */
    extractWeights(weights: Float32Array): void;

    /** Internal path-resolution helper; its signature is elided in this declaration. */
    private traversePropertyPath;

    /** Default model name a concrete network reports. */
    protected abstract getDefaultModelName(): string;
    /**
     * Builds parameters and their mappings from a named tensor map.
     * NOTE: "Weigth" is misspelled, but subclasses override this exact name,
     * so it must not be renamed here alone.
     */
    protected abstract extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): {
      params: TNetParams;
      paramMappings: ParamMapping[];
    };
    /** Builds parameters and their mappings from a flat weight buffer. */
    protected abstract extractParams(weights: Float32Array): {
      params: TNetParams;
      paramMappings: ParamMapping[];
    };
  }