SsdMobilenetv1.js 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import { __awaiter, __extends, __generator } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { Rect } from '../classes';
  4. import { FaceDetection } from '../classes/FaceDetection';
  5. import { toNetInput } from '../dom';
  6. import { NeuralNetwork } from '../NeuralNetwork';
  7. import { extractParams } from './extractParams';
  8. import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
  9. import { mobileNetV1 } from './mobileNetV1';
  10. import { nonMaxSuppression } from './nonMaxSuppression';
  11. import { outputLayer } from './outputLayer';
  12. import { predictionLayer } from './predictionLayer';
  13. import { SsdMobilenetv1Options } from './SsdMobilenetv1Options';
  /**
   * SSD MobileNet v1 face detector.
   *
   * ES5 class (tslib `__extends`) deriving from `NeuralNetwork`; the network
   * name 'SsdMobilenetv1' is passed to the base constructor.
   */
  var SsdMobilenetv1 = /** @class */ (function (_super) {
    __extends(SsdMobilenetv1, _super);

    function SsdMobilenetv1() {
      return _super.call(this, 'SsdMobilenetv1') || this;
    }

    /**
     * Runs the network synchronously on an already-prepared NetInput.
     *
     * Everything is computed inside `tf.tidy`, so intermediate tensors are
     * released and only the tensors in `outputLayer`'s result survive; they are
     * NOT disposed here.
     *
     * @param {NetInput} input - prepared input (see `toNetInput`).
     * @returns {*} the result of `outputLayer`; `locateFaces` reads its `boxes`
     *   and `scores` arrays (one tensor per batch item).
     * @throws {Error} if no model parameters are loaded (`this.params` unset).
     */
    SsdMobilenetv1.prototype.forwardInput = function (input) {
      var params = this.params;
      if (!params) {
        throw new Error('SsdMobilenetv1 - load model before inference');
      }
      return tf.tidy(function () {
        // NOTE(review): 512 is the batch tensor size passed to NetInput; the
        // meaning of the `false` flag is defined by NetInput.toBatchTensor.
        var batchTensor = input.toBatchTensor(512, false).toFloat();
        // x * (2 / 255) - 1: maps the value range [0, 255] onto [-1, 1].
        var normalized = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1));
        var features = mobileNetV1(normalized, params.mobilenetv1);
        var predictions = predictionLayer(features.out, features.conv11, params.prediction_layer);
        return outputLayer(predictions.boxPredictions, predictions.classPredictions, params.output_layer);
      });
    };

    /**
     * Converts `input` with `toNetInput`, then runs `forwardInput` on it.
     *
     * @param {*} input - anything accepted by `toNetInput`.
     * @returns {Promise<*>} the raw `forwardInput` result; its tensors are not
     *   disposed here.
     */
    SsdMobilenetv1.prototype.forward = function (input) {
      return __awaiter(this, void 0, void 0, function () {
        return __generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              return [4 /*yield*/, toNetInput(input)];
            case 1:
              return [2 /*return*/, this.forwardInput(_a.sent())];
          }
        });
      });
    };

    /**
     * Detects faces in `input` and returns them as FaceDetection objects.
     *
     * Only the first batch item is decoded; tensors for the other items are
     * disposed immediately. The box and score tensors of the first item are
     * always disposed before this method settles, including when score
     * download, non-max suppression or box decoding fails.
     *
     * @param {*} input - anything accepted by `toNetInput`.
     * @param {Object} [options={}] - passed to `SsdMobilenetv1Options`; its
     *   `maxResults` and `minConfidence` are forwarded to `nonMaxSuppression`.
     * @returns {Promise<FaceDetection[]>} one detection per index kept by NMS.
     * @throws {Error} if no model parameters are loaded.
     */
    SsdMobilenetv1.prototype.locateFaces = function (input, options) {
      if (options === void 0) { options = {}; }
      return __awaiter(this, void 0, void 0, function () {
        var opts, maxResults, minConfidence, netInput, output, boxes, scores, i, scoresData, iouThreshold, indices, reshapedDims, inputSize, padX, padY, boxesData, results;
        return __generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              opts = new SsdMobilenetv1Options(options);
              maxResults = opts.maxResults;
              minConfidence = opts.minConfidence;
              return [4 /*yield*/, toNetInput(input)];
            case 1:
              netInput = _b.sent();
              output = this.forwardInput(netInput);
              // TODO: batches — only item 0 is decoded; free the rest right away.
              boxes = output.boxes[0];
              scores = output.scores[0];
              for (i = 1; i < output.boxes.length; i++) {
                output.boxes[i].dispose();
                output.scores[i].dispose();
              }
              return [4 /*yield*/, scores.data()];
            case 2:
              // This state contains no yield, so a plain try/finally is safe
              // inside the generator body. `_b.sent()` rethrows here if
              // `scores.data()` rejected, so the finally also covers that path.
              try {
                scoresData = Array.from(_b.sent());
                iouThreshold = 0.5;
                indices = nonMaxSuppression(boxes, scoresData, maxResults, iouThreshold, minConfidence);
                // Rescale by inputSize / reshaped size so coordinates refer to
                // the reshaped image rather than the full network input.
                // NOTE(review): presumably undoes NetInput's square padding — confirm.
                reshapedDims = netInput.getReshapedInputDimensions(0);
                inputSize = netInput.inputSize;
                padX = inputSize / reshapedDims.width;
                padY = inputSize / reshapedDims.height;
                boxesData = boxes.arraySync();
                results = indices.map(function (idx) {
                  // Row layout used below: [top, left, bottom, right]; values
                  // are clamped to >= 0 (top/left) and <= 1 (bottom/right).
                  var row = boxesData[idx];
                  var top = Math.max(0, row[0]) * padY;
                  var bottom = Math.min(1.0, row[2]) * padY;
                  var left = Math.max(0, row[1]) * padX;
                  var right = Math.min(1.0, row[3]) * padX;
                  return new FaceDetection(scoresData[idx], new Rect(left, top, right - left, bottom - top), {
                    height: netInput.getInputHeight(0),
                    width: netInput.getInputWidth(0)
                  });
                });
              } finally {
                boxes.dispose();
                scores.dispose();
              }
              return [2 /*return*/, results];
          }
        });
      });
    };

    /** @returns {string} the default model name, 'ssd_mobilenetv1_model'. */
    SsdMobilenetv1.prototype.getDefaultModelName = function () {
      return 'ssd_mobilenetv1_model';
    };

    /**
     * Delegates to `extractParamsFromWeigthMap` (the misspelled name is part
     * of the public interface and is kept as-is).
     */
    SsdMobilenetv1.prototype.extractParamsFromWeigthMap = function (weightMap) {
      return extractParamsFromWeigthMap(weightMap);
    };

    /** Delegates to `extractParams`. */
    SsdMobilenetv1.prototype.extractParams = function (weights) {
      return extractParams(weights);
    };

    return SsdMobilenetv1;
  }(NeuralNetwork));
  107. export { SsdMobilenetv1 };
  108. //# sourceMappingURL=SsdMobilenetv1.js.map