NeuralNetwork.js 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import { __awaiter, __generator } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { getModelUris } from './common/getModelUris';
  4. import { loadWeightMap } from './dom';
  5. import { env } from './env';
  /**
   * Base class for networks whose parameters are tensors stored in a nested
   * object (`_params`) and addressed by slash-separated paths such as
   * "conv/filters" (one path per entry of `_paramMappings[i].paramPath`).
   *
   * NOTE(review): `extractParams`, `extractParamsFromWeigthMap` and
   * `getDefaultModelName` are called below but not defined in this class;
   * presumably each concrete network subclass provides them — confirm.
   */
  var NeuralNetwork = /** @class */ (function () {
    /**
     * @param {string} _name - network name; used as a prefix in load errors.
     */
    function NeuralNetwork(_name) {
      this._name = _name;
      this._params = undefined;
      this._paramMappings = [];
    }

    /**
     * Returns an ArrayBuffer containing exactly the bytes of `view`.
     * A Node.js Buffer (or any typed-array view) may start at a non-zero
     * offset inside a larger ArrayBuffer, so `view.buffer` alone can carry
     * unrelated bytes. Copies only when the view does not span the buffer.
     * Assumes `view` is a typed-array/Buffer view (loadFromDisk already
     * relies on `.buffer` and `.toString()` being present on readFile results).
     */
    function toExactArrayBuffer(view) {
      var buffer = view.buffer;
      if (view.byteOffset === 0 && view.byteLength === buffer.byteLength) {
        return buffer;
      }
      return buffer.slice(view.byteOffset, view.byteOffset + view.byteLength);
    }

    /** The nested parameter object, or undefined when nothing is loaded. */
    Object.defineProperty(NeuralNetwork.prototype, "params", {
      get: function () { return this._params; },
      enumerable: true,
      configurable: true
    });

    /** List of `{ paramPath }` entries describing every parameter tensor. */
    Object.defineProperty(NeuralNetwork.prototype, "paramMappings", {
      get: function () { return this._paramMappings; },
      enumerable: true,
      configurable: true
    });

    /** True when a parameter object is currently held. */
    Object.defineProperty(NeuralNetwork.prototype, "isLoaded", {
      get: function () { return !!this.params; },
      enumerable: true,
      configurable: true
    });

    /**
     * Looks up the tensor stored at `paramPath`.
     * @throws {Error} if nothing is loaded, the path is missing, or it does
     *   not end at a tf.Tensor (see traversePropertyPath).
     */
    NeuralNetwork.prototype.getParamFromPath = function (paramPath) {
      var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
      return obj[objProp];
    };

    /**
     * Replaces the tensor at `paramPath` with `tensor`, disposing the tensor
     * that was stored there before.
     */
    NeuralNetwork.prototype.reassignParamFromPath = function (paramPath, tensor) {
      var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
      obj[objProp].dispose();
      obj[objProp] = tensor;
    };

    /**
     * @returns {Array<{path: string, tensor: *}>} one entry per param mapping,
     *   in `_paramMappings` order.
     */
    NeuralNetwork.prototype.getParamList = function () {
      var _this = this;
      return this._paramMappings.map(function (_a) {
        var paramPath = _a.paramPath;
        return ({
          path: paramPath,
          tensor: _this.getParamFromPath(paramPath)
        });
      });
    };

    /** Params whose tensor is a tf.Variable. */
    NeuralNetwork.prototype.getTrainableParams = function () {
      return this.getParamList().filter(function (param) { return param.tensor instanceof tf.Variable; });
    };

    /** Params whose tensor is not a tf.Variable. */
    NeuralNetwork.prototype.getFrozenParams = function () {
      return this.getParamList().filter(function (param) { return !(param.tensor instanceof tf.Variable); });
    };

    /**
     * Converts every frozen param into a variable in place
     * (`tensor.variable()`); the replaced tensor is disposed by
     * reassignParamFromPath.
     */
    NeuralNetwork.prototype.variable = function () {
      var _this = this;
      this.getFrozenParams().forEach(function (_a) {
        var path = _a.path, tensor = _a.tensor;
        _this.reassignParamFromPath(path, tensor.variable());
      });
    };

    /**
     * Converts every variable param back into a plain tensor holding the
     * same values, shape and dtype.
     */
    NeuralNetwork.prototype.freeze = function () {
      var _this = this;
      this.getTrainableParams().forEach(function (_a) {
        var path = _a.path, variable = _a.tensor;
        // dataSync() returns a flat typed array; pass shape and dtype
        // explicitly so the frozen tensor is not collapsed to 1-D.
        var tensor = tf.tensor(variable.dataSync(), variable.shape, variable.dtype);
        // reassignParamFromPath disposes the variable being replaced, so no
        // separate dispose() is needed here.
        _this.reassignParamFromPath(path, tensor);
      });
    };

    /**
     * Disposes every parameter tensor and clears `_params`.
     * @param {boolean} [throwOnRedispose=true] - when true, throw if a param
     *   tensor was already disposed (and, as before, if nothing is loaded
     *   while param mappings exist); when false, redisposal is a no-op.
     */
    NeuralNetwork.prototype.dispose = function (throwOnRedispose) {
      if (throwOnRedispose === void 0) { throwOnRedispose = true; }
      // Already disposed (or never loaded): tolerate it when asked to,
      // instead of letting traversePropertyPath throw.
      if (!throwOnRedispose && !this.params) {
        return;
      }
      this.getParamList().forEach(function (param) {
        if (throwOnRedispose && param.tensor.isDisposed) {
          throw new Error("param tensor has already been disposed for path " + param.path);
        }
        param.tensor.dispose();
      });
      this._params = undefined;
    };

    /**
     * Concatenates the values of all params (in `_paramMappings` order) into
     * one Float32Array. Returns an empty array when there are no params.
     */
    NeuralNetwork.prototype.serializeParams = function () {
      var chunks = this.getParamList().map(function (_a) { return _a.tensor.dataSync(); });
      var totalLength = chunks.reduce(function (sum, chunk) { return sum + chunk.length; }, 0);
      var serialized = new Float32Array(totalLength);
      var offset = 0;
      chunks.forEach(function (chunk) {
        serialized.set(chunk, offset);
        offset += chunk.length;
      });
      return serialized;
    };

    /**
     * Loads params either from a flat Float32Array (via extractWeights) or,
     * for any other argument, from a model URI (via loadFromUri).
     */
    NeuralNetwork.prototype.load = function (weightsOrUrl) {
      return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              if (weightsOrUrl instanceof Float32Array) {
                this.extractWeights(weightsOrUrl);
                return [2 /*return*/];
              }
              return [4 /*yield*/, this.loadFromUri(weightsOrUrl)];
            case 1:
              _a.sent();
              return [2 /*return*/];
          }
        });
      });
    };

    /**
     * Fetches a weight map with loadWeightMap and installs it.
     * @throws {Error} (as a rejection) if `uri` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromUri = function (uri) {
      return __awaiter(this, void 0, void 0, function () {
        var weightMap;
        return __generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              if (uri && typeof uri !== 'string') {
                throw new Error(this._name + ".loadFromUri - expected model uri");
              }
              return [4 /*yield*/, loadWeightMap(uri, this.getDefaultModelName())];
            case 1:
              weightMap = _a.sent();
              this.loadFromWeightMap(weightMap);
              return [2 /*return*/];
          }
        });
      });
    };

    /**
     * Reads the manifest and weight files through the environment's
     * `readFile` and installs the resulting weight map.
     * @throws {Error} (as a rejection) if `filePath` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromDisk = function (filePath) {
      return __awaiter(this, void 0, void 0, function () {
        var readFile, _a, manifestUri, modelBaseUri, fetchWeightsFromDisk, loadWeights, manifest, weightMap;
        return __generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              if (filePath && typeof filePath !== 'string') {
                throw new Error(this._name + ".loadFromDisk - expected model file path");
              }
              readFile = env.getEnv().readFile;
              _a = getModelUris(filePath, this.getDefaultModelName()), manifestUri = _a.manifestUri, modelBaseUri = _a.modelBaseUri;
              // Weight fetcher for tf.io.weightsLoaderFactory: one ArrayBuffer per path.
              fetchWeightsFromDisk = function (weightPaths) {
                return Promise.all(weightPaths.map(function (weightPath) {
                  return readFile(weightPath).then(toExactArrayBuffer);
                }));
              };
              loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
              return [4 /*yield*/, readFile(manifestUri)];
            case 1:
              manifest = JSON.parse(_b.sent().toString());
              return [4 /*yield*/, loadWeights(manifest, modelBaseUri)];
            case 2:
              weightMap = _b.sent();
              this.loadFromWeightMap(weightMap);
              return [2 /*return*/];
          }
        });
      });
    };

    /** Installs params/mappings produced by extractParamsFromWeigthMap. */
    NeuralNetwork.prototype.loadFromWeightMap = function (weightMap) {
      var _a = this.extractParamsFromWeigthMap(weightMap), paramMappings = _a.paramMappings, params = _a.params;
      this._paramMappings = paramMappings;
      this._params = params;
    };

    /** Installs params/mappings produced by extractParams from a flat array. */
    NeuralNetwork.prototype.extractWeights = function (weights) {
      var _a = this.extractParams(weights), paramMappings = _a.paramMappings, params = _a.params;
      this._paramMappings = paramMappings;
      this._params = params;
    };

    /**
     * Walks `paramPath` (segments separated by '/') through `params`.
     * @returns {{obj: Object, objProp: string}} the object holding the final
     *   segment and that segment's key.
     * @throws {Error} if nothing is loaded, a segment is missing (including
     *   when an intermediate value is null/undefined), or the final value is
     *   not a tf.Tensor.
     */
    NeuralNetwork.prototype.traversePropertyPath = function (paramPath) {
      if (!this.params) {
        throw new Error("traversePropertyPath - model has no loaded params");
      }
      var result = paramPath.split('/').reduce(function (res, objProp) {
        var current = res.nextObj;
        // Check null/undefined first and call hasOwnProperty via the
        // prototype so a missing intermediate yields this error rather than
        // a TypeError.
        if (current == null || !Object.prototype.hasOwnProperty.call(current, objProp)) {
          throw new Error("traversePropertyPath - object does not have property " + objProp + ", for path " + paramPath);
        }
        return { obj: current, objProp: objProp, nextObj: current[objProp] };
      }, { nextObj: this.params });
      var obj = result.obj, objProp = result.objProp;
      if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
        throw new Error("traversePropertyPath - parameter is not a tensor, for path " + paramPath);
      }
      return { obj: obj, objProp: objProp };
    };

    return NeuralNetwork;
  }());
  export { NeuralNetwork };
  176. //# sourceMappingURL=NeuralNetwork.js.map