TinyYolov2Base.js 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import { __assign, __awaiter, __extends, __generator } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { BoundingBox } from '../classes/BoundingBox';
  4. import { ObjectDetection } from '../classes/ObjectDetection';
  5. import { convLayer } from '../common';
  6. import { toNetInput } from '../dom';
  7. import { NeuralNetwork } from '../NeuralNetwork';
  8. import { sigmoid } from '../ops';
  9. import { nonMaxSuppression } from '../ops/nonMaxSuppression';
  10. import { normalize } from '../ops/normalize';
  11. import { validateConfig } from './config';
  12. import { convWithBatchNorm } from './convWithBatchNorm';
  13. import { depthwiseSeparableConv } from './depthwiseSeparableConv';
  14. import { extractParams } from './extractParams';
  15. import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
  16. import { leaky } from './leaky';
  17. import { TinyYolov2Options } from './TinyYolov2Options';
  /**
   * Shared implementation of the Tiny YOLO v2 detector family: the plain Tiny YOLO v2 backbone
   * and a depthwise-separable ("mobilenet") variant, selected by `config.withSeparableConvs`.
   *
   * The constructor validates and stores `config`. The loaded network parameters are read from
   * `this.params`, which is presumably populated by the `NeuralNetwork` base class when a model is loaded.
   */
  var TinyYolov2Base = /** @class */ (function (_super) {
      __extends(TinyYolov2Base, _super);

      /**
       * Calls `run()` and disposes every tensor in `tensors` once its result settles, on success
       * and on failure alike. If `run` throws synchronously, the tensors are disposed before the
       * error is rethrown.
       * @param {Array<{dispose: Function}>} tensors - Tensors owned by this call.
       * @param {Function} run - Callback returning a value or a promise.
       * @returns {Promise<*>} Settles with whatever `run` produced.
       */
      function disposeWhenSettled(tensors, run) {
          var disposeAll = function () {
              tensors.forEach(function (tensor) { tensor.dispose(); });
          };
          var pending;
          try {
              pending = Promise.resolve(run());
          }
          catch (err) {
              disposeAll();
              throw err;
          }
          return pending.then(function (value) {
              disposeAll();
              return value;
          }, function (err) {
              disposeAll();
              throw err;
          });
      }

      /**
       * Picks the highest-scoring class at one grid position from downloaded class-score data.
       * Ties resolve to the later class index.
       * @param {Array} classesData - Nested array indexed as [row][col][anchor][class].
       * @param {{row: number, col: number, anchor: number}} pos - Grid cell and anchor index.
       * @param {number} numClasses - Number of class entries to consider.
       * @returns {{classScore: number, label: number}} The best class and its score.
       */
      function selectPredictedClass(classesData, pos, numClasses) {
          var cellScores = classesData[pos.row][pos.col][pos.anchor];
          return Array(numClasses).fill(0)
              .map(function (_, i) { return cellScores[i]; })
              .map(function (classScore, label) { return ({
              classScore: classScore,
              label: label
          }); })
              .reduce(function (max, curr) { return max.classScore > curr.classScore ? max : curr; });
      }

      /**
       * @param {object} config - Detector configuration. It is checked by `validateConfig` before
       *   being stored.
       */
      function TinyYolov2Base(config) {
          var _this = _super.call(this, 'TinyYolov2') || this;
          validateConfig(config);
          _this._config = config;
          return _this;
      }

      /** The validated configuration passed to the constructor. */
      Object.defineProperty(TinyYolov2Base.prototype, "config", {
          get: function () {
              return this._config;
          },
          enumerable: true,
          configurable: true
      });

      /**
       * Whether the output encodes per-class scores. True when explicitly requested in the
       * config, or when more than one class is configured.
       */
      Object.defineProperty(TinyYolov2Base.prototype, "withClassScores", {
          get: function () {
              return this.config.withClassScores || this.config.classes.length > 1;
          },
          enumerable: true,
          configurable: true
      });

      /**
       * Number of output channels per anchor:
       * - 4 box values plus 1 objectness score,
       * - plus one entry per class when class scores are encoded.
       */
      Object.defineProperty(TinyYolov2Base.prototype, "boxEncodingSize", {
          get: function () {
              return 5 + (this.withClassScores ? this.config.classes.length : 0);
          },
          enumerable: true,
          configurable: true
      });

      /**
       * Plain Tiny YOLO v2 backbone:
       * - six conv+batch-norm blocks, each followed by 2x2 max pooling (stride 2 for the first
       *   five, stride 1 for the sixth);
       * - then conv6 and conv7;
       * - then a final `convLayer` (conv8) with 'valid' padding.
       */
      TinyYolov2Base.prototype.runTinyYolov2 = function (x, params) {
          var out = convWithBatchNorm(x, params.conv0);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = convWithBatchNorm(out, params.conv1);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = convWithBatchNorm(out, params.conv2);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = convWithBatchNorm(out, params.conv3);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = convWithBatchNorm(out, params.conv4);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = convWithBatchNorm(out, params.conv5);
          // Stride-1 pooling: the sixth stage does not reduce spatial resolution further.
          out = tf.maxPool(out, [2, 2], [1, 1], 'same');
          out = convWithBatchNorm(out, params.conv6);
          out = convWithBatchNorm(out, params.conv7);
          return convLayer(out, params.conv8, 'valid', false);
      };

      /**
       * Depthwise-separable variant of the backbone.
       * - The first layer is a regular conv followed by `leaky` when `config.isFirstLayerConv2d`
       *   is set; otherwise it is depthwise-separable.
       * - conv6 and conv7 are optional and are skipped when absent from `params`.
       */
      TinyYolov2Base.prototype.runMobilenet = function (x, params) {
          var out = this.config.isFirstLayerConv2d
              ? leaky(convLayer(x, params.conv0, 'valid', false))
              : depthwiseSeparableConv(x, params.conv0);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = depthwiseSeparableConv(out, params.conv1);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = depthwiseSeparableConv(out, params.conv2);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = depthwiseSeparableConv(out, params.conv3);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = depthwiseSeparableConv(out, params.conv4);
          out = tf.maxPool(out, [2, 2], [2, 2], 'same');
          out = depthwiseSeparableConv(out, params.conv5);
          out = tf.maxPool(out, [2, 2], [1, 1], 'same');
          out = params.conv6 ? depthwiseSeparableConv(out, params.conv6) : out;
          out = params.conv7 ? depthwiseSeparableConv(out, params.conv7) : out;
          return convLayer(out, params.conv8, 'valid', false);
      };

      /**
       * Runs the network synchronously on an already-converted net input.
       * Intermediate tensors are released by `tf.tidy`; the caller owns the returned tensor.
       * @param input - Net input (as produced by `toNetInput`).
       * @param {number} inputSize - Size passed to `input.toBatchTensor`.
       * @throws {Error} If no parameters have been loaded yet.
       */
      TinyYolov2Base.prototype.forwardInput = function (input, inputSize) {
          var _this = this;
          var params = this.params;
          if (!params) {
              throw new Error('TinyYolov2 - load model before inference');
          }
          return tf.tidy(function () {
              var batchTensor = input.toBatchTensor(inputSize, false).toFloat();
              batchTensor = _this.config.meanRgb
                  ? normalize(batchTensor, _this.config.meanRgb)
                  : batchTensor;
              // Scales by 1/256 (not 1/255).
              // NOTE(review): presumably this matches the pretrained weights; confirm before changing.
              batchTensor = batchTensor.div(tf.scalar(256));
              return _this.config.withSeparableConvs
                  ? _this.runMobilenet(batchTensor, params)
                  : _this.runTinyYolov2(batchTensor, params);
          });
      };

      /**
       * Converts `input` with `toNetInput` and runs `forwardInput` on it.
       * @returns {Promise<*>} The raw network output tensor; the caller must dispose it.
       */
      TinyYolov2Base.prototype.forward = function (input, inputSize) {
          return __awaiter(this, void 0, void 0, function () {
              var _a;
              return __generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, toNetInput(input)];
                      case 1: return [4 /*yield*/, _a.apply(this, [_b.sent(), inputSize])];
                      case 2: return [2 /*return*/, _b.sent()];
                  }
              });
          });
      };

      /**
       * Detects objects in the first image of `input`.
       * - Options come from `TinyYolov2Options(forwardParams)`: `inputSize` and `scoreThreshold`.
       * - Overlapping boxes are suppressed with `config.iouThreshold`.
       * - Network output tensors are disposed whether or not box extraction succeeds.
       * @returns {Promise<ObjectDetection[]>}
       */
      TinyYolov2Base.prototype.detect = function (input, forwardParams) {
          if (forwardParams === void 0) { forwardParams = {}; }
          return __awaiter(this, void 0, void 0, function () {
              var _a, inputSize, scoreThreshold, netInput, out, inputDimensions, results, boxes, scores, classScores, classNames, indices, detections;
              var _this = this;
              return __generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = new TinyYolov2Options(forwardParams), inputSize = _a.inputSize, scoreThreshold = _a.scoreThreshold;
                          return [4 /*yield*/, toNetInput(input)];
                      case 1:
                          netInput = _b.sent();
                          return [4 /*yield*/, this.forwardInput(netInput, inputSize)];
                      case 2:
                          out = _b.sent();
                          // Only batch entry 0 is decoded.
                          // `out` and the sliced `out0` are released once extraction settles, including on error.
                          return [4 /*yield*/, disposeWhenSettled([out], function () {
                                  var out0 = tf.tidy(function () { return tf.unstack(out)[0].expandDims(); });
                                  return disposeWhenSettled([out0], function () {
                                      return _this.extractBoxes(out0, netInput.getReshapedInputDimensions(0), scoreThreshold);
                                  });
                              })];
                      case 3:
                          results = _b.sent();
                          inputDimensions = {
                              width: netInput.getInputWidth(0),
                              height: netInput.getInputHeight(0)
                          };
                          boxes = results.map(function (res) { return res.box; });
                          scores = results.map(function (res) { return res.score; });
                          classScores = results.map(function (res) { return res.classScore; });
                          classNames = results.map(function (res) { return _this.config.classes[res.label]; });
                          // Boxes are rescaled by inputSize for suppression only.
                          // The ObjectDetections keep the unscaled boxes.
                          indices = nonMaxSuppression(boxes.map(function (box) { return box.rescale(inputSize); }), scores, this.config.iouThreshold, true);
                          detections = indices.map(function (idx) {
                              return new ObjectDetection(scores[idx], classScores[idx], classNames[idx], boxes[idx], inputDimensions);
                          });
                          return [2 /*return*/, detections];
                  }
              });
          });
      };

      TinyYolov2Base.prototype.getDefaultModelName = function () {
          return '';
      };

      TinyYolov2Base.prototype.extractParamsFromWeigthMap = function (weightMap) {
          return extractParamsFromWeigthMap(weightMap, this.config);
      };

      /**
       * Extracts parameters from a flat weight array.
       * Uses `config.filterSizes`, falling back to `TinyYolov2Base.DEFAULT_FILTER_SIZES`.
       * @throws {Error} If the filter-size list does not have 7, 8 or 9 entries.
       */
      TinyYolov2Base.prototype.extractParams = function (weights) {
          var filterSizes = this.config.filterSizes || TinyYolov2Base.DEFAULT_FILTER_SIZES;
          var numFilters = filterSizes ? filterSizes.length : undefined;
          if (numFilters !== 7 && numFilters !== 8 && numFilters !== 9) {
              throw new Error("TinyYolov2 - expected 7 | 8 | 9 convolutional filters, but found " + numFilters + " filterSizes in config");
          }
          return extractParams(weights, this.config, this.boxEncodingSize, filterSizes);
      };

      /**
       * Decodes one image's network output into candidate boxes.
       *
       * Input shape:
       * - `outputTensor.shape[1]` is taken as the grid size `numCells`.
       * - The tensor is reshaped to [numCells, numCells, numAnchors, boxEncodingSize].
       * - It presumably has shape [1, numCells, numCells, numAnchors * boxEncodingSize];
       *   confirm against callers.
       *
       * Decoding, per grid cell and anchor:
       * - objectness is `sigmoid(raw)`;
       * - the centre is (cell index + sigmoid(offset)) / numCells;
       * - the size is exp(raw) * anchor / numCells;
       * - both are scaled by the aspect correction factors derived from `inputBlobDimensions`.
       *
       * Candidates are skipped when `scoreThreshold` is set and the score is not above it.
       *
       * Memory and subclass notes:
       * - Intermediate tensors are downloaded once each and disposed even if a download fails.
       * - Class scores are read from that single download rather than through
       *   `extractPredictedClass`. A subclass override of that method is therefore not
       *   consulted here.
       *
       * @returns {Promise<Array<{box: BoundingBox, score: number, classScore: number, label: number, row: number, col: number, anchor: number}>>}
       */
      TinyYolov2Base.prototype.extractBoxes = function (outputTensor, inputBlobDimensions, scoreThreshold) {
          return __awaiter(this, void 0, void 0, function () {
              var width, height, inputSize, correctionFactorX, correctionFactorY, numCells, numBoxes, numClasses, withClassScores, _a, boxesTensor, scoresTensor, classScoresTensor, _b, scoresData, boxesData, classScoresData, results, row, col, anchor, score, ctX, ctY, boxWidth, boxHeight, x, y, pos, predicted;
              var _this = this;
              return __generator(this, function (_c) {
                  switch (_c.label) {
                      case 0:
                          width = inputBlobDimensions.width, height = inputBlobDimensions.height;
                          inputSize = Math.max(width, height);
                          correctionFactorX = inputSize / width;
                          correctionFactorY = inputSize / height;
                          numCells = outputTensor.shape[1];
                          numBoxes = this.config.anchors.length;
                          numClasses = this.config.classes.length;
                          withClassScores = this.withClassScores;
                          _a = tf.tidy(function () {
                              var reshaped = outputTensor.reshape([numCells, numCells, numBoxes, _this.boxEncodingSize]);
                              var boxes = reshaped.slice([0, 0, 0, 0], [numCells, numCells, numBoxes, 4]);
                              var scores = reshaped.slice([0, 0, 0, 4], [numCells, numCells, numBoxes, 1]);
                              var classScores = withClassScores
                                  ? tf.softmax(reshaped.slice([0, 0, 0, 5], [numCells, numCells, numBoxes, numClasses]), 3)
                                  : tf.scalar(0);
                              return [boxes, scores, classScores];
                          }), boxesTensor = _a[0], scoresTensor = _a[1], classScoresTensor = _a[2];
                          // Download everything once; the tensors are not needed after this.
                          return [4 /*yield*/, disposeWhenSettled([boxesTensor, scoresTensor, classScoresTensor], function () {
                                  return Promise.all([
                                      scoresTensor.array(),
                                      boxesTensor.array(),
                                      withClassScores ? classScoresTensor.array() : null
                                  ]);
                              })];
                      case 1:
                          _b = _c.sent(), scoresData = _b[0], boxesData = _b[1], classScoresData = _b[2];
                          results = [];
                          for (row = 0; row < numCells; row++) {
                              for (col = 0; col < numCells; col++) {
                                  for (anchor = 0; anchor < numBoxes; anchor++) {
                                      score = sigmoid(scoresData[row][col][anchor][0]);
                                      // `!(score > t)` rather than `score <= t`, so that a NaN score is skipped.
                                      if (scoreThreshold && !(score > scoreThreshold)) {
                                          continue;
                                      }
                                      ctX = ((col + sigmoid(boxesData[row][col][anchor][0])) / numCells) * correctionFactorX;
                                      ctY = ((row + sigmoid(boxesData[row][col][anchor][1])) / numCells) * correctionFactorY;
                                      boxWidth = ((Math.exp(boxesData[row][col][anchor][2]) * this.config.anchors[anchor].x) / numCells) * correctionFactorX;
                                      boxHeight = ((Math.exp(boxesData[row][col][anchor][3]) * this.config.anchors[anchor].y) / numCells) * correctionFactorY;
                                      x = ctX - (boxWidth / 2);
                                      y = ctY - (boxHeight / 2);
                                      pos = { row: row, col: col, anchor: anchor };
                                      predicted = withClassScores
                                          ? selectPredictedClass(classScoresData, pos, numClasses)
                                          : { classScore: 1, label: 0 };
                                      results.push(__assign({ box: new BoundingBox(x, y, x + boxWidth, y + boxHeight), score: score, classScore: score * predicted.classScore, label: predicted.label }, pos));
                                  }
                              }
                          }
                          return [2 /*return*/, results];
                  }
              });
          });
      };

      /**
       * Downloads `classesTensor` and returns the highest-scoring class at `pos`.
       * Ties resolve to the later class index.
       * Kept for backward compatibility: `extractBoxes` no longer calls it per box.
       * @returns {Promise<{classScore: number, label: number}>}
       */
      TinyYolov2Base.prototype.extractPredictedClass = function (classesTensor, pos) {
          return __awaiter(this, void 0, void 0, function () {
              var classesData;
              return __generator(this, function (_a) {
                  switch (_a.label) {
                      case 0: return [4 /*yield*/, classesTensor.array()];
                      case 1:
                          classesData = _a.sent();
                          return [2 /*return*/, selectPredictedClass(classesData, pos, this.config.classes.length)];
                  }
              });
          });
      };

      /** Filter sizes used by `extractParams` when `config.filterSizes` is not set. */
      TinyYolov2Base.DEFAULT_FILTER_SIZES = [
          3, 16, 32, 64, 128, 256, 512, 1024, 1024
      ];
      return TinyYolov2Base;
  }(NeuralNetwork));
  272. export { TinyYolov2Base };
  273. //# sourceMappingURL=TinyYolov2Base.js.map