SsdMobilenetv1.js 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var classes_1 = require("../classes");
  6. var FaceDetection_1 = require("../classes/FaceDetection");
  7. var dom_1 = require("../dom");
  8. var NeuralNetwork_1 = require("../NeuralNetwork");
  9. var extractParams_1 = require("./extractParams");
  10. var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  11. var mobileNetV1_1 = require("./mobileNetV1");
  12. var nonMaxSuppression_1 = require("./nonMaxSuppression");
  13. var outputLayer_1 = require("./outputLayer");
  14. var predictionLayer_1 = require("./predictionLayer");
  15. var SsdMobilenetv1Options_1 = require("./SsdMobilenetv1Options");
  var SsdMobilenetv1 = /** @class */ (function (_super) {
      tslib_1.__extends(SsdMobilenetv1, _super);

      /** IoU threshold handed to nonMaxSuppression by locateFaces. */
      var IOU_THRESHOLD = 0.5;

      /**
       * SSD MobileNet v1 face detector.
       *
       * Passes the name 'SsdMobilenetv1' to the NeuralNetwork base constructor.
       * Parameters must be loaded (so that `this.params` is set) before
       * forwardInput / forward / locateFaces are used.
       */
      function SsdMobilenetv1() {
          return _super.call(this, 'SsdMobilenetv1') || this;
      }

      /**
       * Converts one image's raw detector output into FaceDetection objects.
       * Does not dispose `boxes`; the caller owns that tensor.
       *
       * @param netInput      the NetInput the tensors were computed from
       * @param boxes         box tensor for batch entry 0, read via arraySync();
       *                      each row is used as [top, left, bottom, right] in
       *                      coordinates relative to the padded network input
       *                      (NOTE(review): layout inferred from the indexing
       *                      below — confirm against outputLayer)
       * @param scoresData    plain array of per-box scores, same order as `boxes`
       * @param maxResults    forwarded to nonMaxSuppression
       * @param minConfidence forwarded to nonMaxSuppression
       * @returns {FaceDetection[]} one detection per index kept by NMS
       */
      function toFaceDetections(netInput, boxes, scoresData, maxResults, minConfidence) {
          var indices = nonMaxSuppression_1.nonMaxSuppression(boxes, scoresData, maxResults, IOU_THRESHOLD, minConfidence);
          // The image fills only reshapedDims of the inputSize x inputSize network
          // input (toBatchTensor is called with `false`, presumably "not centered",
          // i.e. padded at the bottom/right — TODO confirm). Scaling by
          // inputSize / reshapedDim makes the coordinates relative to the image.
          var reshapedDims = netInput.getReshapedInputDimensions(0);
          var inputSize = netInput.inputSize;
          var padX = inputSize / reshapedDims.width;
          var padY = inputSize / reshapedDims.height;
          var boxesData = boxes.arraySync();
          return indices.map(function (idx) {
              var row = boxesData[idx];
              var top = Math.max(0, row[0]) * padY;
              var bottom = Math.min(1.0, row[2]) * padY;
              var left = Math.max(0, row[1]) * padX;
              var right = Math.min(1.0, row[3]) * padX;
              return new FaceDetection_1.FaceDetection(scoresData[idx], new classes_1.Rect(left, top, right - left, bottom - top), {
                  height: netInput.getInputHeight(0),
                  width: netInput.getInputWidth(0)
              });
          });
      }

      /**
       * Runs the network on an already-prepared NetInput.
       *
       * @param input NetInput; batched via input.toBatchTensor(512, false)
       * @returns the result of outputLayer(...); locateFaces reads its `boxes`
       *          and `scores` arrays, indexed per batch entry
       * @throws Error if the model params have not been loaded
       */
      SsdMobilenetv1.prototype.forwardInput = function (input) {
          var params = this.params;
          if (!params) {
              throw new Error('SsdMobilenetv1 - load model before inference');
          }
          return tf.tidy(function () {
              var batchTensor = input.toBatchTensor(512, false).toFloat();
              // x * (~2/255) - 1: maps [0, 255] onto [-1, 1], assuming 8-bit
              // pixel values in the batch tensor (TODO confirm).
              var x = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1));
              var features = mobileNetV1_1.mobileNetV1(x, params.mobilenetv1);
              var _a = predictionLayer_1.predictionLayer(features.out, features.conv11, params.prediction_layer), boxPredictions = _a.boxPredictions, classPredictions = _a.classPredictions;
              return outputLayer_1.outputLayer(boxPredictions, classPredictions, params.output_layer);
          });
      };

      /**
       * Converts `input` with toNetInput and runs forwardInput on the result.
       *
       * @param input anything accepted by toNetInput
       * @returns {Promise} resolving to forwardInput's result
       */
      SsdMobilenetv1.prototype.forward = function (input) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var _a;
              return tslib_1.__generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1: return [2 /*return*/, _a.apply(this, [_b.sent()])];
                  }
              });
          });
      };

      /**
       * Detects faces in `input`.
       *
       * Only the first image of a batched input is evaluated; the outputs for
       * any further images are disposed unused. The box and score tensors of
       * the evaluated image are released on every path, including failures.
       *
       * @param input   anything accepted by toNetInput
       * @param options optional fields for SsdMobilenetv1Options (maxResults,
       *                minConfidence); defaults to {}
       * @returns {Promise<FaceDetection[]>}
       */
      SsdMobilenetv1.prototype.locateFaces = function (input, options) {
          var _this = this;
          if (options === void 0) { options = {}; }
          // The executor runs synchronously, so options are validated and
          // toNetInput is called immediately; a synchronous throw becomes a
          // rejection instead of escaping the call.
          return new Promise(function (resolve) {
              var opts = new SsdMobilenetv1Options_1.SsdMobilenetv1Options(options);
              var maxResults = opts.maxResults;
              var minConfidence = opts.minConfidence;
              resolve(Promise.resolve(dom_1.toNetInput(input)).then(function (netInput) {
                  var output = _this.forwardInput(netInput);
                  var allBoxes = output.boxes;
                  var allScores = output.scores;
                  // TODO batches: only batch entry 0 is used; free the others now.
                  for (var i = 1; i < allBoxes.length; i++) {
                      allBoxes[i].dispose();
                      allScores[i].dispose();
                  }
                  var boxes = allBoxes[0];
                  var scores = allScores[0];
                  var releaseTensors = function () {
                      boxes.dispose();
                      scores.dispose();
                  };
                  // Release on success AND on failure: a rejected scores.data() or
                  // a throw while building detections used to leak both tensors.
                  return scores.data().then(function (rawScores) {
                      try {
                          return toFaceDetections(netInput, boxes, Array.from(rawScores), maxResults, minConfidence);
                      }
                      finally {
                          releaseTensors();
                      }
                  }, function (err) {
                      releaseTensors();
                      throw err;
                  });
              }));
          });
      };

      /**
       * @returns {string} this network's default model name,
       *          'ssd_mobilenetv1_model' (presumably consumed by NeuralNetwork
       *          when loading weights — verify in the base class)
       */
      SsdMobilenetv1.prototype.getDefaultModelName = function () {
          return 'ssd_mobilenetv1_model';
      };

      /**
       * Delegates to extractParamsFromWeigthMap (name spelling kept because it
       * is part of the public interface).
       */
      SsdMobilenetv1.prototype.extractParamsFromWeigthMap = function (weightMap) {
          return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(weightMap);
      };

      /** Delegates to extractParams. */
      SsdMobilenetv1.prototype.extractParams = function (weights) {
          return extractParams_1.extractParams(weights);
      };

      return SsdMobilenetv1;
  }(NeuralNetwork_1.NeuralNetwork));
  exports.SsdMobilenetv1 = SsdMobilenetv1;
  110. //# sourceMappingURL=SsdMobilenetv1.js.map