123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import { __awaiter, __generator } from "tslib";
- import * as tf from '@tensorflow/tfjs-core';
- import { getModelUris } from './common/getModelUris';
- import { loadWeightMap } from './dom';
- import { env } from './env';
var NeuralNetwork = /** @class */ (function () {
    /**
     * Returns an ArrayBuffer holding exactly the bytes viewed by `buf`.
     *
     * A Node.js Buffer (like any typed-array view) can cover only part of its
     * backing ArrayBuffer, so handing `buf.buffer` straight to the weights
     * loader could include bytes that do not belong to the file. The backing
     * buffer is returned as-is when the view already spans all of it, or when
     * `buf` does not expose numeric byteOffset/byteLength.
     *
     * @param {{buffer: ArrayBuffer, byteOffset?: number, byteLength?: number}} buf
     * @returns {ArrayBuffer}
     */
    function toExactArrayBuffer(buf) {
        var underlying = buf.buffer;
        if (typeof buf.byteOffset !== 'number' || typeof buf.byteLength !== 'number') {
            return underlying;
        }
        if (buf.byteOffset === 0 && buf.byteLength === underlying.byteLength) {
            // The view covers the whole buffer: no copy needed.
            return underlying;
        }
        return underlying.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
    }
    /**
     * Base class for a model whose parameters are tensors stored in a nested
     * `params` object and addressed by '/'-separated paths (`paramMappings`).
     *
     * `getDefaultModelName`, `extractParamsFromWeigthMap` and `extractParams`
     * are called here but not defined in this class; presumably subclasses
     * supply them.
     *
     * @param {string} _name - model name, used as a prefix in loader error messages.
     */
    function NeuralNetwork(_name) {
        this._name = _name;
        // Remains undefined until weights are loaded; `isLoaded` keys off it.
        this._params = undefined;
        // Each entry carries a `paramPath` locating one tensor inside `_params`.
        this._paramMappings = [];
    }
    Object.defineProperty(NeuralNetwork.prototype, "params", {
        /** The nested parameter object, or undefined when nothing is loaded. */
        get: function () { return this._params; },
        enumerable: true,
        configurable: true
    });
    Object.defineProperty(NeuralNetwork.prototype, "paramMappings", {
        /** Array of `{ paramPath, ... }` entries describing every parameter. */
        get: function () { return this._paramMappings; },
        enumerable: true,
        configurable: true
    });
    Object.defineProperty(NeuralNetwork.prototype, "isLoaded", {
        /** True when `params` holds a truthy value. */
        get: function () { return !!this.params; },
        enumerable: true,
        configurable: true
    });
    /**
     * Returns the tensor stored at `paramPath`.
     * @param {string} paramPath - '/'-separated path into `params`.
     * @throws {Error} if no params are loaded or the path does not resolve to a tensor.
     */
    NeuralNetwork.prototype.getParamFromPath = function (paramPath) {
        var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
        return obj[objProp];
    };
    /**
     * Stores `tensor` at `paramPath`, disposing the tensor it replaces.
     * @param {string} paramPath - '/'-separated path into `params`.
     * @param {*} tensor - value to store at that path.
     * @throws {Error} if no params are loaded or the path does not resolve to a tensor.
     */
    NeuralNetwork.prototype.reassignParamFromPath = function (paramPath, tensor) {
        var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
        obj[objProp].dispose();
        obj[objProp] = tensor;
    };
    /**
     * Lists every parameter as `{ path, tensor }`, in `paramMappings` order.
     * @returns {Array<{path: string, tensor: *}>}
     */
    NeuralNetwork.prototype.getParamList = function () {
        var _this = this;
        return this._paramMappings.map(function (_a) {
            var paramPath = _a.paramPath;
            return ({
                path: paramPath,
                tensor: _this.getParamFromPath(paramPath)
            });
        });
    };
    /** Parameters currently held as `tf.Variable`s. */
    NeuralNetwork.prototype.getTrainableParams = function () {
        return this.getParamList().filter(function (param) { return param.tensor instanceof tf.Variable; });
    };
    /** Parameters that are not `tf.Variable`s. */
    NeuralNetwork.prototype.getFrozenParams = function () {
        return this.getParamList().filter(function (param) { return !(param.tensor instanceof tf.Variable); });
    };
    /**
     * Replaces every frozen parameter with `tensor.variable()`; the replaced
     * tensors are disposed by `reassignParamFromPath`.
     */
    NeuralNetwork.prototype.variable = function () {
        var _this = this;
        this.getFrozenParams().forEach(function (_a) {
            var path = _a.path, tensor = _a.tensor;
            _this.reassignParamFromPath(path, tensor.variable());
        });
    };
    /**
     * Replaces every trainable (variable) parameter with a plain tensor that
     * holds a copy of its values, shape and dtype.
     */
    NeuralNetwork.prototype.freeze = function () {
        var _this = this;
        this.getTrainableParams().forEach(function (_a) {
            var path = _a.path, variable = _a.tensor;
            // Pass shape and dtype explicitly: tf.tensor(values) on its own would
            // infer a flat 1-D shape and lose the parameter's real shape.
            var tensor = tf.tensor(variable.dataSync(), variable.shape, variable.dtype);
            // reassignParamFromPath disposes the variable being replaced, so it
            // is not disposed separately here.
            _this.reassignParamFromPath(path, tensor);
        });
    };
    /**
     * Disposes every parameter tensor and marks the model as unloaded.
     * @param {boolean} [throwOnRedispose=true] - throw if a tensor is already disposed.
     * @throws {Error} when `throwOnRedispose` is set and a parameter was already disposed.
     */
    NeuralNetwork.prototype.dispose = function (throwOnRedispose) {
        if (throwOnRedispose === void 0) { throwOnRedispose = true; }
        this.getParamList().forEach(function (param) {
            if (throwOnRedispose && param.tensor.isDisposed) {
                throw new Error("param tensor has already been disposed for path " + param.path);
            }
            param.tensor.dispose();
        });
        this._params = undefined;
    };
    /**
     * Concatenates the values of all parameters, in `paramMappings` order,
     * into a single Float32Array. A model with no parameters yields an
     * empty array.
     * @returns {Float32Array}
     */
    NeuralNetwork.prototype.serializeParams = function () {
        var chunks = this.getParamList().map(function (_a) { return _a.tensor.dataSync(); });
        var totalLength = chunks.reduce(function (sum, chunk) { return sum + chunk.length; }, 0);
        var flat = new Float32Array(totalLength);
        var offset = 0;
        // Copy each chunk in place instead of building intermediate JS arrays.
        chunks.forEach(function (chunk) {
            flat.set(chunk, offset);
            offset += chunk.length;
        });
        return flat;
    };
    /**
     * Loads weights from a Float32Array (via `extractWeights`) or, for any
     * other argument, from a URI (via `loadFromUri`).
     * @param {Float32Array|string} weightsOrUrl
     * @returns {Promise<void>}
     */
    NeuralNetwork.prototype.load = function (weightsOrUrl) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (weightsOrUrl instanceof Float32Array) {
                            this.extractWeights(weightsOrUrl);
                            return [2 /*return*/];
                        }
                        return [4 /*yield*/, this.loadFromUri(weightsOrUrl)];
                    case 1:
                        _a.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Fetches a weight map with `loadWeightMap(uri, getDefaultModelName())`
     * and installs it through `loadFromWeightMap`.
     * A falsy `uri` is passed through unchanged (presumably so that
     * loadWeightMap falls back to a default location — confirm there).
     * @param {string} [uri]
     * @returns {Promise<void>}
     * @throws {Error} if `uri` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromUri = function (uri) {
        return __awaiter(this, void 0, void 0, function () {
            var weightMap;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (uri && typeof uri !== 'string') {
                            throw new Error(this._name + ".loadFromUri - expected model uri");
                        }
                        return [4 /*yield*/, loadWeightMap(uri, this.getDefaultModelName())];
                    case 1:
                        weightMap = _a.sent();
                        this.loadFromWeightMap(weightMap);
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Loads weights from the file system using the environment's `readFile`.
     *
     * Resolves the manifest and shard locations with `getModelUris` and parses
     * the manifest as JSON. Shard files are then read through
     * `tf.io.weightsLoaderFactory` and the result is installed via
     * `loadFromWeightMap`. Assumes `readFile` resolves to a Buffer-like value
     * with a `.buffer` property — TODO confirm for non-Node environments.
     * @param {string} [filePath]
     * @returns {Promise<void>}
     * @throws {Error} if `filePath` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromDisk = function (filePath) {
        return __awaiter(this, void 0, void 0, function () {
            var readFile, _a, manifestUri, modelBaseUri, fetchWeightsFromDisk, loadWeights, manifest, _b, _c, weightMap;
            return __generator(this, function (_d) {
                switch (_d.label) {
                    case 0:
                        if (filePath && typeof filePath !== 'string') {
                            throw new Error(this._name + ".loadFromDisk - expected model file path");
                        }
                        readFile = env.getEnv().readFile;
                        _a = getModelUris(filePath, this.getDefaultModelName()), manifestUri = _a.manifestUri, modelBaseUri = _a.modelBaseUri;
                        fetchWeightsFromDisk = function (filePaths) {
                            return Promise.all(filePaths.map(function (weightPath) {
                                // Only the bytes belonging to this file are handed to the loader.
                                return readFile(weightPath).then(toExactArrayBuffer);
                            }));
                        };
                        loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
                        _c = (_b = JSON).parse;
                        return [4 /*yield*/, readFile(manifestUri)];
                    case 1:
                        manifest = _c.apply(_b, [(_d.sent()).toString()]);
                        return [4 /*yield*/, loadWeights(manifest, modelBaseUri)];
                    case 2:
                        weightMap = _d.sent();
                        this.loadFromWeightMap(weightMap);
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Installs params and mappings produced by `extractParamsFromWeigthMap`.
     * @param {*} weightMap - weight map as produced by the loaders above.
     */
    NeuralNetwork.prototype.loadFromWeightMap = function (weightMap) {
        var _a = this.extractParamsFromWeigthMap(weightMap), paramMappings = _a.paramMappings, params = _a.params;
        this._paramMappings = paramMappings;
        this._params = params;
    };
    /**
     * Installs params and mappings produced by `extractParams`.
     * @param {Float32Array} weights
     */
    NeuralNetwork.prototype.extractWeights = function (weights) {
        var _a = this.extractParams(weights), paramMappings = _a.paramMappings, params = _a.params;
        this._paramMappings = paramMappings;
        this._params = params;
    };
    /**
     * Walks `params` along the '/'-separated `paramPath`.
     * @param {string} paramPath
     * @returns {{obj: Object, objProp: string}} the owner object and key of the final segment.
     * @throws {Error} if no params are loaded, a segment is missing (including when
     *   an intermediate value is null/undefined), or the final value is not a `tf.Tensor`.
     */
    NeuralNetwork.prototype.traversePropertyPath = function (paramPath) {
        if (!this.params) {
            throw new Error("traversePropertyPath - model has no loaded params");
        }
        var result = paramPath.split('/').reduce(function (res, objProp) {
            var nextObj = res.nextObj;
            // Borrow hasOwnProperty so null intermediates and prototype-less
            // objects produce the descriptive error rather than a TypeError.
            if (nextObj === null || nextObj === undefined
                || !Object.prototype.hasOwnProperty.call(nextObj, objProp)) {
                throw new Error("traversePropertyPath - object does not have property " + objProp + ", for path " + paramPath);
            }
            return { obj: nextObj, objProp: objProp, nextObj: nextObj[objProp] };
        }, { nextObj: this.params });
        var obj = result.obj, objProp = result.objProp;
        if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
            throw new Error("traversePropertyPath - parameter is not a tensor, for path " + paramPath);
        }
        return { obj: obj, objProp: objProp };
    };
    return NeuralNetwork;
}());
- export { NeuralNetwork };
- //# sourceMappingURL=NeuralNetwork.js.map
|