Mtcnn.js 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import { __awaiter, __extends, __generator } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { Point, Rect } from '../classes';
  4. import { FaceDetection } from '../classes/FaceDetection';
  5. import { FaceLandmarks5 } from '../classes/FaceLandmarks5';
  6. import { toNetInput } from '../dom';
  7. import { extendWithFaceDetection, extendWithFaceLandmarks } from '../factories';
  8. import { NeuralNetwork } from '../NeuralNetwork';
  9. import { bgrToRgbTensor } from './bgrToRgbTensor';
  10. import { CELL_SIZE } from './config';
  11. import { extractParams } from './extractParams';
  12. import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
  13. import { getSizesForScale } from './getSizesForScale';
  14. import { MtcnnOptions } from './MtcnnOptions';
  15. import { pyramidDown } from './pyramidDown';
  16. import { stage1 } from './stage1';
  17. import { stage2 } from './stage2';
  18. import { stage3 } from './stage3';
  /**
   * Mtcnn network wrapper built on top of NeuralNetwork.
   *
   * Inference runs three consecutive stages (stage1, stage2, stage3) using the
   * `pnet`, `rnet` and `onet` entries of the loaded params.
   *
   * Deprecated: `load` and `loadFromDisk` log a deprecation warning.
   */
  var Mtcnn = /** @class */ (function (_super) {
    __extends(Mtcnn, _super);
    function Mtcnn() {
      return _super.call(this, 'Mtcnn') || this;
    }
    /**
     * Logs a deprecation warning, then delegates to the base class `load`.
     *
     * @param weightsOrUrl Forwarded unchanged to the base class implementation.
     * @returns Promise resolving to whatever the base class `load` returns.
     */
    Mtcnn.prototype.load = function (weightsOrUrl) {
      return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
          console.warn('mtcnn is deprecated and will be removed soon');
          return [2 /*return*/, _super.prototype.load.call(this, weightsOrUrl)];
        });
      });
    };
    /**
     * Logs a deprecation warning, then delegates to the base class `loadFromDisk`.
     *
     * @param filePath Forwarded unchanged to the base class implementation.
     * @returns Promise resolving to whatever the base class `loadFromDisk` returns.
     */
    Mtcnn.prototype.loadFromDisk = function (filePath) {
      return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
          console.warn('mtcnn is deprecated and will be removed soon');
          return [2 /*return*/, _super.prototype.loadFromDisk.call(this, filePath)];
        });
      });
    };
    /**
     * Runs the three detection stages on the first canvas of `input`.
     *
     * The intermediate image tensor is released in a `finally` section, so it is
     * disposed on every exit path: early return (no boxes left after a stage),
     * normal return, and any error thrown or rejected by a stage.
     *
     * @param input Object exposing a `canvases` array (as produced by `toNetInput`
     *   in `forward`); only `canvases[0]` is used.
     * @param forwardParams Options handed to the `MtcnnOptions` constructor.
     *   Defaults to `{}`.
     * @returns Promise resolving to `{ results, stats }`. `results` is empty when
     *   stage 1 or stage 2 yields no boxes. `stats` collects timings and
     *   per-stage bookkeeping.
     * @throws Error when no params are loaded or when `input.canvases[0]` is missing.
     */
    Mtcnn.prototype.forwardInput = function (input, forwardParams) {
      if (forwardParams === void 0) { forwardParams = {}; }
      return __awaiter(this, void 0, void 0, function () {
        var params, inputCanvas, stats, tsTotal, imgTensor, _a, height, width, _b, minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps, scales, ts, out1, out2, out3, results;
        return __generator(this, function (_c) {
          switch (_c.label) {
            case 0:
              params = this.params;
              if (!params) {
                throw new Error('Mtcnn - load model before inference');
              }
              inputCanvas = input.canvases[0];
              if (!inputCanvas) {
                throw new Error('Mtcnn - inputCanvas is not defined, note that passing tensors into Mtcnn.forwardInput is not supported yet.');
              }
              stats = {};
              tsTotal = Date.now();
              imgTensor = tf.tidy(function () {
                return bgrToRgbTensor(tf.expandDims(tf.browser.fromPixels(inputCanvas)).toFloat());
              });
              _c.label = 1;
            case 1:
              // try { ... } finally { ... }: labels 1-4 are the protected region,
              // label 5 is the finally section, label 6 the continuation after it.
              _c.trys.push([1, , 5, 6]);
              _a = imgTensor.shape.slice(1), height = _a[0], width = _a[1];
              _b = new MtcnnOptions(forwardParams), minFaceSize = _b.minFaceSize, scaleFactor = _b.scaleFactor, maxNumScales = _b.maxNumScales, scoreThresholds = _b.scoreThresholds, scaleSteps = _b.scaleSteps;
              // Keep only scales whose resized image is larger than one cell,
              // then cap the number of scales.
              scales = (scaleSteps || pyramidDown(minFaceSize, scaleFactor, [height, width]))
                .filter(function (scale) {
                  var sizes = getSizesForScale(scale, [height, width]);
                  return Math.min(sizes.width, sizes.height) > CELL_SIZE;
                })
                .slice(0, maxNumScales);
              stats.scales = scales;
              stats.pyramid = scales.map(function (scale) { return getSizesForScale(scale, [height, width]); });
              ts = Date.now();
              return [4 /*yield*/, stage1(imgTensor, scales, scoreThresholds[0], params.pnet, stats)];
            case 2:
              out1 = _c.sent();
              stats.total_stage1 = Date.now() - ts;
              if (!out1.boxes.length) {
                return [2 /*return*/, { results: [], stats: stats }];
              }
              stats.stage2_numInputBoxes = out1.boxes.length;
              // using the inputCanvas to extract and resize the image patches, since it is faster
              // than doing this on the gpu
              ts = Date.now();
              return [4 /*yield*/, stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet, stats)];
            case 3:
              out2 = _c.sent();
              stats.total_stage2 = Date.now() - ts;
              if (!out2.boxes.length) {
                return [2 /*return*/, { results: [], stats: stats }];
              }
              stats.stage3_numInputBoxes = out2.boxes.length;
              ts = Date.now();
              return [4 /*yield*/, stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)];
            case 4:
              out3 = _c.sent();
              stats.total_stage3 = Date.now() - ts;
              // Boxes are divided by the image size; landmark points are shifted
              // by the box origin and divided by the box size.
              results = out3.boxes.map(function (box, idx) {
                return extendWithFaceLandmarks(extendWithFaceDetection({}, new FaceDetection(out3.scores[idx], new Rect(box.left / width, box.top / height, box.width / width, box.height / height), {
                  height: height,
                  width: width
                })), new FaceLandmarks5(out3.points[idx].map(function (pt) { return pt.sub(new Point(box.left, box.top)).div(new Point(box.width, box.height)); }), { width: box.width, height: box.height }));
              });
              return [2 /*return*/, { results: results, stats: stats }];
            case 5:
              // finally: always release the image tensor and record total time,
              // including when a stage throws or rejects.
              imgTensor.dispose();
              stats.total = Date.now() - tsTotal;
              return [7 /*endfinally*/];
            case 6: return [2 /*return*/];
          }
        });
      });
    };
    /**
     * Converts `input` with `toNetInput`, runs `forwardInput` and returns only
     * the `results` part of its output.
     *
     * @param input Value accepted by `toNetInput`.
     * @param forwardParams Options forwarded to `forwardInput`. Defaults to `{}`.
     * @returns Promise resolving to the `results` array of `forwardInput`.
     */
    Mtcnn.prototype.forward = function (input, forwardParams) {
      if (forwardParams === void 0) { forwardParams = {}; }
      return __awaiter(this, void 0, void 0, function () {
        var _a;
        return __generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              _a = this.forwardInput;
              return [4 /*yield*/, toNetInput(input)];
            case 1: return [4 /*yield*/, _a.apply(this, [_b.sent(),
                forwardParams])];
            case 2: return [2 /*return*/, (_b.sent()).results];
          }
        });
      });
    };
    /**
     * Converts `input` with `toNetInput`, runs `forwardInput` and returns its
     * full output, including `stats`.
     *
     * @param input Value accepted by `toNetInput`.
     * @param forwardParams Options forwarded to `forwardInput`. Defaults to `{}`.
     * @returns Promise resolving to `{ results, stats }`.
     */
    Mtcnn.prototype.forwardWithStats = function (input, forwardParams) {
      if (forwardParams === void 0) { forwardParams = {}; }
      return __awaiter(this, void 0, void 0, function () {
        var _a;
        return __generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              _a = this.forwardInput;
              return [4 /*yield*/, toNetInput(input)];
            case 1: return [2 /*return*/, _a.apply(this, [_b.sent(),
                forwardParams])];
          }
        });
      });
    };
    /**
     * @returns The default model name, `'mtcnn_model'`.
     */
    Mtcnn.prototype.getDefaultModelName = function () {
      return 'mtcnn_model';
    };
    /**
     * Delegates to the module-level `extractParamsFromWeigthMap` helper.
     *
     * @param weightMap Forwarded unchanged to the helper.
     * @returns Whatever the helper returns.
     */
    Mtcnn.prototype.extractParamsFromWeigthMap = function (weightMap) {
      return extractParamsFromWeigthMap(weightMap);
    };
    /**
     * Delegates to the module-level `extractParams` helper.
     *
     * @param weights Forwarded unchanged to the helper.
     * @returns Whatever the helper returns.
     */
    Mtcnn.prototype.extractParams = function (weights) {
      return extractParams(weights);
    };
    return Mtcnn;
  }(NeuralNetwork));
  export { Mtcnn };
  153. //# sourceMappingURL=Mtcnn.js.map