Mtcnn.js 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var classes_1 = require("../classes");
  6. var FaceDetection_1 = require("../classes/FaceDetection");
  7. var FaceLandmarks5_1 = require("../classes/FaceLandmarks5");
  8. var dom_1 = require("../dom");
  9. var factories_1 = require("../factories");
  10. var NeuralNetwork_1 = require("../NeuralNetwork");
  11. var bgrToRgbTensor_1 = require("./bgrToRgbTensor");
  12. var config_1 = require("./config");
  13. var extractParams_1 = require("./extractParams");
  14. var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  15. var getSizesForScale_1 = require("./getSizesForScale");
  16. var MtcnnOptions_1 = require("./MtcnnOptions");
  17. var pyramidDown_1 = require("./pyramidDown");
  18. var stage1_1 = require("./stage1");
  19. var stage2_1 = require("./stage2");
  20. var stage3_1 = require("./stage3");
  /**
   * Mtcnn face detector network (deprecated: load() and loadFromDisk() log a
   * deprecation warning). Runs three stages (stage1 with params.pnet, stage2 with
   * params.rnet, stage3 with params.onet) and returns face detections plus
   * 5-point landmarks.
   *
   * NOTE: this is compiled (ES5 + tslib) output; async methods are expressed as
   * tslib __awaiter/__generator state machines.
   */
  var Mtcnn = /** @class */ (function (_super) {
      tslib_1.__extends(Mtcnn, _super);
      function Mtcnn() {
          return _super.call(this, 'Mtcnn') || this;
      }
      /**
       * Loads weights via the NeuralNetwork base class, after warning that mtcnn
       * is deprecated.
       * @param weightsOrUrl - forwarded unchanged to NeuralNetwork.prototype.load.
       * @returns {Promise} whatever the base-class load resolves to.
       */
      Mtcnn.prototype.load = function (weightsOrUrl) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              return tslib_1.__generator(this, function (_a) {
                  console.warn('mtcnn is deprecated and will be removed soon');
                  return [2 /*return*/, _super.prototype.load.call(this, weightsOrUrl)];
              });
          });
      };
      /**
       * Loads weights from disk via the NeuralNetwork base class, after warning
       * that mtcnn is deprecated.
       * @param filePath - forwarded unchanged to NeuralNetwork.prototype.loadFromDisk.
       * @returns {Promise} whatever the base-class loadFromDisk resolves to.
       */
      Mtcnn.prototype.loadFromDisk = function (filePath) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              return tslib_1.__generator(this, function (_a) {
                  console.warn('mtcnn is deprecated and will be removed soon');
                  return [2 /*return*/, _super.prototype.loadFromDisk.call(this, filePath)];
              });
          });
      };
      /**
       * Runs the three-stage pipeline on an already-converted net input.
       *
       * @param input - a net input whose `canvases[0]` must be a canvas; tensor
       *   inputs are rejected (see error message below).
       * @param [forwardParams={}] - options passed to `new MtcnnOptions(...)`.
       * @returns {Promise<{results: Array, stats: Object}>} `results` holds one
       *   object per stage-3 box, extended with a FaceDetection (box divided by the
       *   image width/height) and FaceLandmarks5 (points made relative to the box).
       *   `stats` collects timings (`total`, `total_stage1..3`), the chosen
       *   `scales`, the `pyramid` sizes, and stage input box counts; the stage
       *   functions also receive `stats` and may add to it.
       * @throws (as a rejection) if the model is not loaded, if no input canvas is
       *   available, or if options parsing / any stage fails.
       */
      Mtcnn.prototype.forwardInput = function (input, forwardParams) {
          if (forwardParams === void 0) { forwardParams = {}; }
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var params, inputCanvas, stats, tsTotal, imgTensor, _a, height, width, _b, minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps, scales, ts, out1, out2, out3, results;
              return tslib_1.__generator(this, function (_c) {
                  switch (_c.label) {
                      case 0:
                          params = this.params;
                          if (!params) {
                              throw new Error('Mtcnn - load model before inference');
                          }
                          inputCanvas = input.canvases[0];
                          if (!inputCanvas) {
                              throw new Error('Mtcnn - inputCanvas is not defined, note that passing tensors into Mtcnn.forwardInput is not supported yet.');
                          }
                          stats = {};
                          tsTotal = Date.now();
                          imgTensor = tf.tidy(function () {
                              return bgrToRgbTensor_1.bgrToRgbTensor(tf.expandDims(tf.browser.fromPixels(inputCanvas)).toFloat());
                          });
                          _c.label = 1;
                          // falls through into the try block
                      case 1:
                          // try { cases 1-4 } finally { case 5 }: [tryLoc, catchLoc (none), finallyLoc, endLoc].
                          // Everything after imgTensor exists runs inside the try so the tensor is
                          // released on every exit path, including rejected stages and invalid options.
                          _c.trys.push([1, , 5, 6]);
                          _a = imgTensor.shape.slice(1), height = _a[0], width = _a[1];
                          _b = new MtcnnOptions_1.MtcnnOptions(forwardParams), minFaceSize = _b.minFaceSize, scaleFactor = _b.scaleFactor, maxNumScales = _b.maxNumScales, scoreThresholds = _b.scoreThresholds, scaleSteps = _b.scaleSteps;
                          // explicit scaleSteps win over the computed pyramid; scales whose resized
                          // image would not exceed CELL_SIZE on its shorter side are dropped
                          scales = (scaleSteps || pyramidDown_1.pyramidDown(minFaceSize, scaleFactor, [height, width]))
                              .filter(function (scale) {
                              var sizes = getSizesForScale_1.getSizesForScale(scale, [height, width]);
                              return Math.min(sizes.width, sizes.height) > config_1.CELL_SIZE;
                          })
                              .slice(0, maxNumScales);
                          stats.scales = scales;
                          stats.pyramid = scales.map(function (scale) { return getSizesForScale_1.getSizesForScale(scale, [height, width]); });
                          ts = Date.now();
                          return [4 /*yield*/, stage1_1.stage1(imgTensor, scales, scoreThresholds[0], params.pnet, stats)];
                      case 2:
                          out1 = _c.sent();
                          stats.total_stage1 = Date.now() - ts;
                          if (!out1.boxes.length) {
                              return [2 /*return*/, { results: [], stats: stats }];
                          }
                          stats.stage2_numInputBoxes = out1.boxes.length;
                          // using the inputCanvas to extract and resize the image patches, since it is faster
                          // than doing this on the gpu
                          ts = Date.now();
                          return [4 /*yield*/, stage2_1.stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet, stats)];
                      case 3:
                          out2 = _c.sent();
                          stats.total_stage2 = Date.now() - ts;
                          if (!out2.boxes.length) {
                              return [2 /*return*/, { results: [], stats: stats }];
                          }
                          stats.stage3_numInputBoxes = out2.boxes.length;
                          ts = Date.now();
                          return [4 /*yield*/, stage3_1.stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)];
                      case 4:
                          out3 = _c.sent();
                          stats.total_stage3 = Date.now() - ts;
                          results = out3.boxes.map(function (box, idx) {
                              // detection box is expressed as a fraction of the full image size
                              var relativeBox = new classes_1.Rect(box.left / width, box.top / height, box.width / width, box.height / height);
                              var withDetection = factories_1.extendWithFaceDetection({}, new FaceDetection_1.FaceDetection(out3.scores[idx], relativeBox, {
                                  height: height,
                                  width: width
                              }));
                              // landmark points are made relative to the detected box
                              var landmarks = new FaceLandmarks5_1.FaceLandmarks5(out3.points[idx].map(function (pt) {
                                  return pt.sub(new classes_1.Point(box.left, box.top)).div(new classes_1.Point(box.width, box.height));
                              }), { width: box.width, height: box.height });
                              return factories_1.extendWithFaceLandmarks(withDetection, landmarks);
                          });
                          return [2 /*return*/, { results: results, stats: stats }];
                      case 5:
                          // finally: always release the input tensor and record the total time
                          imgTensor.dispose();
                          stats.total = Date.now() - tsTotal;
                          return [7 /*endfinally*/];
                      case 6: return [2 /*return*/];
                  }
              });
          });
      };
      /**
       * Converts `input` with toNetInput and runs forwardInput.
       * @returns {Promise<Array>} only the `results` array (stats are discarded).
       */
      Mtcnn.prototype.forward = function (input, forwardParams) {
          if (forwardParams === void 0) { forwardParams = {}; }
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var _a;
              return tslib_1.__generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1: return [4 /*yield*/, _a.apply(this, [_b.sent(),
                              forwardParams])];
                      case 2: return [2 /*return*/, (_b.sent()).results];
                  }
              });
          });
      };
      /**
       * Converts `input` with toNetInput and runs forwardInput.
       * @returns {Promise<{results: Array, stats: Object}>} the full forwardInput result.
       */
      Mtcnn.prototype.forwardWithStats = function (input, forwardParams) {
          if (forwardParams === void 0) { forwardParams = {}; }
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var _a;
              return tslib_1.__generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1: return [2 /*return*/, _a.apply(this, [_b.sent(),
                              forwardParams])];
                  }
              });
          });
      };
      /**
       * @returns {string} the model name 'mtcnn_model' (presumably used by the
       *   NeuralNetwork base class to locate weight files — confirm there).
       */
      Mtcnn.prototype.getDefaultModelName = function () {
          return 'mtcnn_model';
      };
      /** Delegates to the extractParamsFromWeigthMap helper module. */
      Mtcnn.prototype.extractParamsFromWeigthMap = function (weightMap) {
          return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(weightMap);
      };
      /** Delegates to the extractParams helper module. */
      Mtcnn.prototype.extractParams = function (weights) {
          return extractParams_1.extractParams(weights);
      };
      return Mtcnn;
  }(NeuralNetwork_1.NeuralNetwork));
  exports.Mtcnn = Mtcnn;
  155. //# sourceMappingURL=Mtcnn.js.map