NeuralNetwork.js 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var getModelUris_1 = require("./common/getModelUris");
  6. var dom_1 = require("./dom");
  7. var env_1 = require("./env");
  var NeuralNetwork = /** @class */ (function () {
    /**
     * Base class for a network whose parameters live in a nested `params`
     * object. Each entry of `paramMappings` carries a `paramPath`: a
     * '/'-separated property path from `params` down to one tensor.
     *
     * `extractParams`, `extractParamsFromWeigthMap` and `getDefaultModelName`
     * are called on `this` but are not defined in this class; presumably each
     * concrete subclass supplies them.
     *
     * @param {string} _name - Network name, used as a prefix in load errors.
     */
    function NeuralNetwork(_name) {
      this._name = _name;
      this._params = undefined;
      this._paramMappings = [];
    }

    /** The loaded parameter tree; undefined before loading and after dispose(). */
    Object.defineProperty(NeuralNetwork.prototype, "params", {
      get: function () { return this._params; },
      enumerable: true,
      configurable: true
    });

    /** Mapping entries (`{ paramPath }`) locating each parameter tensor inside `params`. */
    Object.defineProperty(NeuralNetwork.prototype, "paramMappings", {
      get: function () { return this._paramMappings; },
      enumerable: true,
      configurable: true
    });

    /** True while a parameter tree is loaded. */
    Object.defineProperty(NeuralNetwork.prototype, "isLoaded", {
      get: function () { return !!this.params; },
      enumerable: true,
      configurable: true
    });

    /**
     * Returns the tensor stored at `paramPath` inside `params`.
     * @param {string} paramPath - '/'-separated property path.
     * @throws {Error} If no params are loaded, the path does not exist, or it
     *   does not end at a tf.Tensor.
     */
    NeuralNetwork.prototype.getParamFromPath = function (paramPath) {
      var target = this.traversePropertyPath(paramPath);
      return target.obj[target.objProp];
    };

    /**
     * Disposes the tensor currently stored at `paramPath` and stores `tensor`
     * in its place.
     * @param {string} paramPath - '/'-separated property path.
     * @param {tf.Tensor} tensor - Replacement tensor; ownership passes to the model.
     */
    NeuralNetwork.prototype.reassignParamFromPath = function (paramPath, tensor) {
      var target = this.traversePropertyPath(paramPath);
      target.obj[target.objProp].dispose();
      target.obj[target.objProp] = tensor;
    };

    /**
     * Resolves every mapping to `{ path, tensor }`, in mapping order.
     * @returns {Array<{path: string, tensor: tf.Tensor}>}
     */
    NeuralNetwork.prototype.getParamList = function () {
      var _this = this;
      return this._paramMappings.map(function (mapping) {
        return {
          path: mapping.paramPath,
          tensor: _this.getParamFromPath(mapping.paramPath)
        };
      });
    };

    /** Parameters whose tensor is a tf.Variable. */
    NeuralNetwork.prototype.getTrainableParams = function () {
      return this.getParamList().filter(function (param) { return param.tensor instanceof tf.Variable; });
    };

    /** Parameters whose tensor is not a tf.Variable. */
    NeuralNetwork.prototype.getFrozenParams = function () {
      return this.getParamList().filter(function (param) { return !(param.tensor instanceof tf.Variable); });
    };

    /**
     * Replaces every frozen parameter with `tensor.variable()`; the original
     * tensor is disposed by reassignParamFromPath.
     */
    NeuralNetwork.prototype.variable = function () {
      var _this = this;
      this.getFrozenParams().forEach(function (param) {
        _this.reassignParamFromPath(param.path, param.tensor.variable());
      });
    };

    /**
     * Replaces every trainable parameter (tf.Variable) with a plain tensor
     * holding a copy of its values, with the same shape and dtype.
     */
    NeuralNetwork.prototype.freeze = function () {
      var _this = this;
      this.getTrainableParams().forEach(function (param) {
        var variable = param.tensor;
        // Pass shape and dtype explicitly: tf.tensor(values) alone would
        // produce a rank-1 tensor and lose the parameter's layout.
        var tensor = tf.tensor(variable.dataSync(), variable.shape, variable.dtype);
        // reassignParamFromPath disposes the variable stored at this path, so
        // it must not be disposed here as well.
        _this.reassignParamFromPath(param.path, tensor);
      });
    };

    /**
     * Disposes every parameter tensor and clears `params`.
     * @param {boolean} [throwOnRedispose=true] - When true, throw if a
     *   parameter tensor is already disposed. When false, a repeated call on a
     *   model whose params were already cleared is a no-op.
     * @throws {Error} When `throwOnRedispose` is true and a tensor is already
     *   disposed, or params were cleared while mappings remain.
     */
    NeuralNetwork.prototype.dispose = function (throwOnRedispose) {
      if (throwOnRedispose === void 0) { throwOnRedispose = true; }
      // After a dispose, `_params` is undefined but `_paramMappings` is kept,
      // so walking the mappings would throw from traversePropertyPath. Honour
      // the caller's request to tolerate a repeated dispose.
      if (!throwOnRedispose && !this.params) {
        return;
      }
      this.getParamList().forEach(function (param) {
        if (throwOnRedispose && param.tensor.isDisposed) {
          throw new Error("param tensor has already been disposed for path " + param.path);
        }
        param.tensor.dispose();
      });
      this._params = undefined;
    };

    /**
     * Flattens all parameter values, in mapping order, into one Float32Array.
     * Returns an empty array when there are no parameters.
     * @returns {Float32Array}
     */
    NeuralNetwork.prototype.serializeParams = function () {
      var chunks = this.getParamList().map(function (param) { return param.tensor.dataSync(); });
      var totalLength = chunks.reduce(function (sum, chunk) { return sum + chunk.length; }, 0);
      var serialized = new Float32Array(totalLength);
      var offset = 0;
      // Copy each chunk once into the preallocated result instead of
      // repeatedly concatenating ever-growing plain arrays.
      chunks.forEach(function (chunk) {
        serialized.set(chunk, offset);
        offset += chunk.length;
      });
      return serialized;
    };

    /**
     * Loads parameters. A Float32Array is handed to extractWeights
     * synchronously; anything else is treated as a URI for loadFromUri.
     * @param {Float32Array|string|undefined} weightsOrUrl
     * @returns {Promise<void>}
     */
    NeuralNetwork.prototype.load = function (weightsOrUrl) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        return tslib_1.__generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              if (weightsOrUrl instanceof Float32Array) {
                this.extractWeights(weightsOrUrl);
                return [2 /*return*/];
              }
              return [4 /*yield*/, this.loadFromUri(weightsOrUrl)];
            case 1:
              _a.sent();
              return [2 /*return*/];
          }
        });
      });
    };

    /**
     * Fetches a weight map via dom.loadWeightMap(uri, getDefaultModelName())
     * and loads it through loadFromWeightMap.
     * @param {string|undefined} uri - Falsy values are passed through as-is.
     * @returns {Promise<void>} Rejects if `uri` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromUri = function (uri) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var weightMap;
        return tslib_1.__generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              if (uri && typeof uri !== 'string') {
                throw new Error(this._name + ".loadFromUri - expected model uri");
              }
              return [4 /*yield*/, dom_1.loadWeightMap(uri, this.getDefaultModelName())];
            case 1:
              weightMap = _a.sent();
              this.loadFromWeightMap(weightMap);
              return [2 /*return*/];
          }
        });
      });
    };

    /**
     * Reads the manifest and weight files with the environment's `readFile`.
     * Paths are resolved by getModelUris; the weights are decoded with
     * tf.io.weightsLoaderFactory, and the result goes to loadFromWeightMap.
     * @param {string|undefined} filePath
     * @returns {Promise<void>} Rejects if `filePath` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromDisk = function (filePath) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var readFile, _a, manifestUri, modelBaseUri, fetchWeightsFromDisk, loadWeights, manifest, _b, _c, weightMap;
        return tslib_1.__generator(this, function (_d) {
          switch (_d.label) {
            case 0:
              if (filePath && typeof filePath !== 'string') {
                throw new Error(this._name + ".loadFromDisk - expected model file path");
              }
              readFile = env_1.env.getEnv().readFile;
              _a = getModelUris_1.getModelUris(filePath, this.getDefaultModelName()), manifestUri = _a.manifestUri, modelBaseUri = _a.modelBaseUri;
              fetchWeightsFromDisk = function (filePaths) {
                return Promise.all(filePaths.map(function (weightFilePath) {
                  return readFile(weightFilePath).then(toOwnArrayBuffer);
                }));
              };
              loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
              _c = (_b = JSON).parse;
              return [4 /*yield*/, readFile(manifestUri)];
            case 1:
              manifest = _c.apply(_b, [(_d.sent()).toString()]);
              return [4 /*yield*/, loadWeights(manifest, modelBaseUri)];
            case 2:
              weightMap = _d.sent();
              this.loadFromWeightMap(weightMap);
              return [2 /*return*/];
          }
        });
      });
    };

    /**
     * Replaces params and mappings with the result of extractParamsFromWeigthMap.
     * @param {Object} weightMap
     */
    NeuralNetwork.prototype.loadFromWeightMap = function (weightMap) {
      var extracted = this.extractParamsFromWeigthMap(weightMap);
      this._paramMappings = extracted.paramMappings;
      this._params = extracted.params;
    };

    /**
     * Replaces params and mappings with the result of extractParams.
     * @param {Float32Array} weights
     */
    NeuralNetwork.prototype.extractWeights = function (weights) {
      var extracted = this.extractParams(weights);
      this._paramMappings = extracted.paramMappings;
      this._params = extracted.params;
    };

    /**
     * Walks `paramPath` ('/'-separated) from `params` and returns the owning
     * object plus the final property name.
     * @param {string} paramPath
     * @returns {{obj: Object, objProp: string}}
     * @throws {Error} If no params are loaded, a segment is missing, or the
     *   final value is not a tf.Tensor.
     */
    NeuralNetwork.prototype.traversePropertyPath = function (paramPath) {
      if (!this.params) {
        throw new Error("traversePropertyPath - model has no loaded params");
      }
      var result = paramPath.split('/').reduce(function (res, objProp) {
        // Guard null/undefined intermediates, and call the prototype method so
        // prototype-less objects (Object.create(null)) are supported too.
        if (res.nextObj == null || !Object.prototype.hasOwnProperty.call(res.nextObj, objProp)) {
          throw new Error("traversePropertyPath - object does not have property " + objProp + ", for path " + paramPath);
        }
        return { obj: res.nextObj, objProp: objProp, nextObj: res.nextObj[objProp] };
      }, { nextObj: this.params });
      var obj = result.obj, objProp = result.objProp;
      if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
        throw new Error("traversePropertyPath - parameter is not a tensor, for path " + paramPath);
      }
      return { obj: obj, objProp: objProp };
    };

    /**
     * Returns an ArrayBuffer holding exactly the bytes viewed by `buf`.
     * A Buffer/typed array may be a window into a larger shared ArrayBuffer
     * (e.g. Node's pooled small Buffers); passing `buf.buffer` through would
     * then include unrelated bytes, so copy out only this view's range.
     * Values without a numeric byteOffset are passed through as `buf.buffer`.
     */
    function toOwnArrayBuffer(buf) {
      var backing = buf.buffer;
      var coversWholeBuffer = typeof buf.byteOffset !== 'number'
        || (buf.byteOffset === 0 && buf.byteLength === backing.byteLength);
      if (coversWholeBuffer) {
        return backing;
      }
      return backing.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
    }

    return NeuralNetwork;
  }());
  177. exports.NeuralNetwork = NeuralNetwork;
  178. //# sourceMappingURL=NeuralNetwork.js.map