TinyYolov2Base.js 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var BoundingBox_1 = require("../classes/BoundingBox");
  6. var ObjectDetection_1 = require("../classes/ObjectDetection");
  7. var common_1 = require("../common");
  8. var dom_1 = require("../dom");
  9. var NeuralNetwork_1 = require("../NeuralNetwork");
  10. var ops_1 = require("../ops");
  11. var nonMaxSuppression_1 = require("../ops/nonMaxSuppression");
  12. var normalize_1 = require("../ops/normalize");
  13. var config_1 = require("./config");
  14. var convWithBatchNorm_1 = require("./convWithBatchNorm");
  15. var depthwiseSeparableConv_1 = require("./depthwiseSeparableConv");
  16. var extractParams_1 = require("./extractParams");
  17. var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  18. var leaky_1 = require("./leaky");
  19. var TinyYolov2Options_1 = require("./TinyYolov2Options");
  var TinyYolov2Base = /** @class */ (function (_super) {
    // NOTE(review): this file is tsc output (see the sourceMappingURL below). Mirror these
    // changes in the TypeScript source, otherwise the next build will overwrite them.
    tslib_1.__extends(TinyYolov2Base, _super);

    /**
     * Base class for Tiny YOLOv2 style detectors.
     *
     * @param config - detector configuration; passed through config_1.validateConfig
     *   before being stored on the instance.
     */
    function TinyYolov2Base(config) {
      var _this = _super.call(this, 'TinyYolov2') || this;
      config_1.validateConfig(config);
      _this._config = config;
      return _this;
    }

    /** The validated configuration this detector was constructed with. */
    Object.defineProperty(TinyYolov2Base.prototype, "config", {
      get: function () {
        return this._config;
      },
      enumerable: true,
      configurable: true
    });

    /**
     * Whether each box encoding carries per-class scores: either requested explicitly
     * via config.withClassScores or implied by having more than one class.
     */
    Object.defineProperty(TinyYolov2Base.prototype, "withClassScores", {
      get: function () {
        return this.config.withClassScores || this.config.classes.length > 1;
      },
      enumerable: true,
      configurable: true
    });

    /**
     * Number of values per anchor in the network output: 4 box values + 1 objectness
     * score, plus one score per class when withClassScores is set.
     */
    Object.defineProperty(TinyYolov2Base.prototype, "boxEncodingSize", {
      get: function () {
        return 5 + (this.withClassScores ? this.config.classes.length : 0);
      },
      enumerable: true,
      configurable: true
    });

    /**
     * Returns the index and value of the largest of the first `numClasses` entries of
     * `classScores`. Ties (and NaN comparisons) resolve to the later index, which matches
     * the previous `reduce((max, curr) => max > curr ? max : curr)` implementation.
     *
     * @param classScores - indexable per-class scores for a single anchor
     * @param numClasses - number of leading entries to consider
     * @returns {{ classScore: number, label: number }}
     * @throws {Error} when numClasses is not positive
     */
    function argMaxClass(classScores, numClasses) {
      if (!(numClasses > 0)) {
        throw new Error('TinyYolov2 - expected at least one class, but config.classes is empty');
      }
      var label = 0;
      for (var i = 1; i < numClasses; i++) {
        // `!(a > b)` rather than `a <= b` so the later entry wins on ties and NaN.
        if (!(classScores[label] > classScores[i])) {
          label = i;
        }
      }
      return { classScore: classScores[label], label: label };
    }

    /**
     * Plain Tiny YOLOv2 backbone: convWithBatchNorm layers conv0..conv5, each followed by
     * 2x2 max pooling (stride 2 for the first five, stride 1 after conv5), then conv6 and
     * conv7, and a final 'valid' conv layer (conv8) producing the raw box encodings.
     */
    TinyYolov2Base.prototype.runTinyYolov2 = function (x, params) {
      var out = convWithBatchNorm_1.convWithBatchNorm(x, params.conv0);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv1);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv2);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv3);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv4);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv5);
      out = tf.maxPool(out, [2, 2], [1, 1], 'same');
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv6);
      out = convWithBatchNorm_1.convWithBatchNorm(out, params.conv7);
      return common_1.convLayer(out, params.conv8, 'valid', false);
    };

    /**
     * Separable-conv backbone. The first layer is either a regular conv followed by
     * leaky activation (config.isFirstLayerConv2d) or a depthwise separable conv. conv6
     * and conv7 are optional: they are skipped when absent from params.
     */
    TinyYolov2Base.prototype.runMobilenet = function (x, params) {
      var out = this.config.isFirstLayerConv2d
        ? leaky_1.leaky(common_1.convLayer(x, params.conv0, 'valid', false))
        : depthwiseSeparableConv_1.depthwiseSeparableConv(x, params.conv0);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv1);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv2);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv3);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv4);
      out = tf.maxPool(out, [2, 2], [2, 2], 'same');
      out = depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv5);
      out = tf.maxPool(out, [2, 2], [1, 1], 'same');
      out = params.conv6 ? depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv6) : out;
      out = params.conv7 ? depthwiseSeparableConv_1.depthwiseSeparableConv(out, params.conv7) : out;
      return common_1.convLayer(out, params.conv8, 'valid', false);
    };

    /**
     * Runs the network on an already-prepared net input inside tf.tidy.
     *
     * The input is converted with toBatchTensor(inputSize, false) and cast to float.
     * config.meanRgb, when set, is applied via normalize. The result is divided by 256
     * (NOTE(review): presumably matching the training preprocessing; do not change to 255
     * without checking the weights). Finally the separable or plain backbone runs,
     * depending on config.withSeparableConvs.
     *
     * @throws {Error} if no params have been loaded yet
     */
    TinyYolov2Base.prototype.forwardInput = function (input, inputSize) {
      var _this = this;
      var params = this.params;
      if (!params) {
        throw new Error('TinyYolov2 - load model before inference');
      }
      return tf.tidy(function () {
        var batchTensor = input.toBatchTensor(inputSize, false).toFloat();
        batchTensor = _this.config.meanRgb
          ? normalize_1.normalize(batchTensor, _this.config.meanRgb)
          : batchTensor;
        batchTensor = batchTensor.div(tf.scalar(256));
        return _this.config.withSeparableConvs
          ? _this.runMobilenet(batchTensor, params)
          : _this.runTinyYolov2(batchTensor, params);
      });
    };

    /**
     * Async convenience wrapper: converts `input` with dom_1.toNetInput and resolves with
     * the raw output tensor of forwardInput. The caller owns (and must dispose) it.
     */
    TinyYolov2Base.prototype.forward = function (input, inputSize) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var _a;
        return tslib_1.__generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              _a = this.forwardInput;
              return [4 /*yield*/, dom_1.toNetInput(input)];
            case 1: return [4 /*yield*/, _a.apply(this, [_b.sent(), inputSize])];
            case 2: return [2 /*return*/, _b.sent()];
          }
        });
      });
    };

    /**
     * Detects objects in `input`.
     *
     * Runs the network, decodes the boxes of the first batch element only, and applies
     * non-max suppression with config.iouThreshold. NMS runs on boxes rescaled by
     * inputSize. Resolves with an array of ObjectDetection.
     *
     * @param input - anything accepted by dom_1.toNetInput
     * @param forwardParams - options forwarded to TinyYolov2Options (inputSize, scoreThreshold)
     */
    TinyYolov2Base.prototype.detect = function (input, forwardParams) {
      if (forwardParams === void 0) { forwardParams = {}; }
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var _a, inputSize, scoreThreshold, netInput, out, out0, inputDimensions, results, boxes, scores, classScores, classNames, indices, detections;
        var _this = this;
        return tslib_1.__generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              _a = new TinyYolov2Options_1.TinyYolov2Options(forwardParams), inputSize = _a.inputSize, scoreThreshold = _a.scoreThreshold;
              return [4 /*yield*/, dom_1.toNetInput(input)];
            case 1:
              netInput = _b.sent();
              return [4 /*yield*/, this.forwardInput(netInput, inputSize)];
            case 2:
              out = _b.sent();
              out0 = tf.tidy(function () { return tf.unstack(out)[0].expandDims(); });
              // `out` is no longer needed once the first batch element has been sliced off;
              // free it now instead of holding (or leaking) it across the extractBoxes await.
              out.dispose();
              inputDimensions = {
                width: netInput.getInputWidth(0),
                height: netInput.getInputHeight(0)
              };
              return [4 /*yield*/, this.extractBoxes(out0, netInput.getReshapedInputDimensions(0), scoreThreshold)];
            case 3:
              results = _b.sent();
              out0.dispose();
              boxes = results.map(function (res) { return res.box; });
              scores = results.map(function (res) { return res.score; });
              classScores = results.map(function (res) { return res.classScore; });
              classNames = results.map(function (res) { return _this.config.classes[res.label]; });
              indices = nonMaxSuppression_1.nonMaxSuppression(boxes.map(function (box) { return box.rescale(inputSize); }), scores, this.config.iouThreshold, true);
              detections = indices.map(function (idx) {
                return new ObjectDetection_1.ObjectDetection(scores[idx], classScores[idx], classNames[idx], boxes[idx], inputDimensions);
              });
              return [2 /*return*/, detections];
          }
        });
      });
    };

    /** No default model file name at this level; returns an empty string. */
    TinyYolov2Base.prototype.getDefaultModelName = function () {
      return '';
    };

    /** Delegates weight-map parsing to extractParamsFromWeigthMap with this.config. */
    TinyYolov2Base.prototype.extractParamsFromWeigthMap = function (weightMap) {
      return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(weightMap, this.config);
    };

    /**
     * Extracts params from a flat weight array.
     *
     * Uses config.filterSizes, falling back to DEFAULT_FILTER_SIZES.
     *
     * @throws {Error} unless there are 7, 8 or 9 filter sizes
     */
    TinyYolov2Base.prototype.extractParams = function (weights) {
      // The `||` fallback guarantees filterSizes is truthy, so read its length directly.
      var filterSizes = this.config.filterSizes || TinyYolov2Base.DEFAULT_FILTER_SIZES;
      var numFilters = filterSizes.length;
      if (numFilters !== 7 && numFilters !== 8 && numFilters !== 9) {
        throw new Error("TinyYolov2 - expected 7 | 8 | 9 convolutional filters, but found " + numFilters + " filterSizes in config");
      }
      return extractParams_1.extractParams(weights, this.config, this.boxEncodingSize, filterSizes);
    };

    /**
     * Decodes the raw output of a single image into candidate boxes.
     *
     * The output is reshaped to [numCells, numCells, numAnchors, boxEncodingSize]; the
     * reshape assumes a batch of one and a square grid taken from outputTensor.shape[1].
     * Objectness and box offsets go through ops_1.sigmoid. Box sizes are exp(t) scaled
     * by config.anchors, and class scores get a softmax over the last axis. Candidates
     * whose score is not above a truthy scoreThreshold are dropped.
     *
     * All tensors are read back to the host once. Class scores used to be re-downloaded
     * for every candidate box.
     *
     * @returns Promise of { box, score, classScore, label, row, col, anchor }[]
     */
    TinyYolov2Base.prototype.extractBoxes = function (outputTensor, inputBlobDimensions, scoreThreshold) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var width, height, inputSize, correctionFactorX, correctionFactorY, numCells, numBoxes, numClasses, withClassScores, anchors, _a, boxesTensor, scoresTensor, classScoresTensor, scoresData, boxesData, classScoresData, results, row, col, anchor, box, score, ctX, ctY, boxWidth, boxHeight, x, y, predicted;
        var _this = this;
        return tslib_1.__generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              width = inputBlobDimensions.width, height = inputBlobDimensions.height;
              inputSize = Math.max(width, height);
              // NOTE(review): presumably compensates for the input being padded to a square.
              correctionFactorX = inputSize / width;
              correctionFactorY = inputSize / height;
              numCells = outputTensor.shape[1];
              numBoxes = this.config.anchors.length;
              numClasses = this.config.classes.length;
              withClassScores = this.withClassScores;
              anchors = this.config.anchors;
              _a = tf.tidy(function () {
                var reshaped = outputTensor.reshape([numCells, numCells, numBoxes, _this.boxEncodingSize]);
                var boxes = reshaped.slice([0, 0, 0, 0], [numCells, numCells, numBoxes, 4]);
                var scores = reshaped.slice([0, 0, 0, 4], [numCells, numCells, numBoxes, 1]);
                var classScores = withClassScores
                  ? tf.softmax(reshaped.slice([0, 0, 0, 5], [numCells, numCells, numBoxes, numClasses]), 3)
                  : tf.scalar(0);
                return [boxes, scores, classScores];
              }), boxesTensor = _a[0], scoresTensor = _a[1], classScoresTensor = _a[2];
              return [4 /*yield*/, scoresTensor.array()];
            case 1:
              scoresData = _b.sent();
              return [4 /*yield*/, boxesTensor.array()];
            case 2:
              boxesData = _b.sent();
              if (!withClassScores) return [3 /*break*/, 4];
              // Single readback of the class scores, shared by every candidate below.
              return [4 /*yield*/, classScoresTensor.array()];
            case 3:
              classScoresData = _b.sent();
              _b.label = 4;
            case 4:
              // Everything needed is on the host now; release the tensors before decoding
              // so an error in the decode loop cannot leak them.
              boxesTensor.dispose();
              scoresTensor.dispose();
              classScoresTensor.dispose();
              results = [];
              for (row = 0; row < numCells; row++) {
                for (col = 0; col < numCells; col++) {
                  for (anchor = 0; anchor < numBoxes; anchor++) {
                    score = ops_1.sigmoid(scoresData[row][col][anchor][0]);
                    // Same predicate as before: a falsy threshold keeps everything, and NaN
                    // scores are dropped whenever a threshold is set.
                    if (scoreThreshold && !(score > scoreThreshold)) {
                      continue;
                    }
                    box = boxesData[row][col][anchor];
                    ctX = ((col + ops_1.sigmoid(box[0])) / numCells) * correctionFactorX;
                    ctY = ((row + ops_1.sigmoid(box[1])) / numCells) * correctionFactorY;
                    boxWidth = ((Math.exp(box[2]) * anchors[anchor].x) / numCells) * correctionFactorX;
                    boxHeight = ((Math.exp(box[3]) * anchors[anchor].y) / numCells) * correctionFactorY;
                    x = ctX - boxWidth / 2;
                    y = ctY - boxHeight / 2;
                    predicted = withClassScores
                      ? argMaxClass(classScoresData[row][col][anchor], numClasses)
                      : { classScore: 1, label: 0 };
                    results.push({
                      box: new BoundingBox_1.BoundingBox(x, y, x + boxWidth, y + boxHeight),
                      score: score,
                      classScore: score * predicted.classScore,
                      label: predicted.label,
                      row: row,
                      col: col,
                      anchor: anchor
                    });
                  }
                }
              }
              return [2 /*return*/, results];
          }
        });
      });
    };

    /**
     * Downloads `classesTensor` and returns the best class for the anchor at `pos`.
     * extractBoxes no longer calls this per box; it is kept for callers that use it.
     *
     * @param classesTensor - class-score tensor indexed [row][col][anchor][class]
     * @param pos - { row, col, anchor }
     * @returns Promise of { classScore, label }
     */
    TinyYolov2Base.prototype.extractPredictedClass = function (classesTensor, pos) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var row, col, anchor, classesData;
        return tslib_1.__generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              row = pos.row, col = pos.col, anchor = pos.anchor;
              return [4 /*yield*/, classesTensor.array()];
            case 1:
              classesData = _a.sent();
              return [2 /*return*/, argMaxClass(classesData[row][col][anchor], this.config.classes.length)];
          }
        });
      });
    };

    /** Fallback per-layer filter sizes used when config.filterSizes is not provided. */
    TinyYolov2Base.DEFAULT_FILTER_SIZES = [
      3, 16, 32, 64, 128, 256, 512, 1024, 1024
    ];
    return TinyYolov2Base;
  }(NeuralNetwork_1.NeuralNetwork));
  exports.TinyYolov2Base = TinyYolov2Base;
  //# sourceMappingURL=TinyYolov2Base.js.map