FaceRecognitionNet.js 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var dom_1 = require("../dom");
  6. var NeuralNetwork_1 = require("../NeuralNetwork");
  7. var ops_1 = require("../ops");
  8. var convLayer_1 = require("./convLayer");
  9. var extractParams_1 = require("./extractParams");
  10. var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  11. var residualLayer_1 = require("./residualLayer");
  /**
   * Face recognition network built on the shared NeuralNetwork base class.
   *
   * The forward pass is a stack of convolution / residual layers followed by a
   * spatial mean and a single matrix multiplication with `params.fc`.
   * Parameters are supplied by the base class through `this.params`
   * (populated via `extractParams` / `extractParamsFromWeigthMap`).
   */
  var FaceRecognitionNet = /** @class */ (function (_super) {
      tslib_1.__extends(FaceRecognitionNet, _super);
      /**
       * Registers the network with the base class under the name
       * 'FaceRecognitionNet'.
       */
      function FaceRecognitionNet() {
          return _super.call(this, 'FaceRecognitionNet') || this;
      }
      /**
       * Runs the network on an already-prepared net input.
       *
       * @param input Object exposing `toBatchTensor(size, flag)`
       *     (presumably a NetInput from ../dom -- confirm against caller).
       * @returns The tensor produced by multiplying the spatially averaged
       *     feature map with `params.fc`. The caller owns it and must dispose it.
       * @throws {Error} If no parameters have been loaded yet.
       */
      FaceRecognitionNet.prototype.forwardInput = function (input) {
          var params = this.params;
          if (!params) {
              throw new Error('FaceRecognitionNet - load model before inference');
          }
          // tf.tidy releases every intermediate tensor created in the callback;
          // only the returned tensor survives.
          return tf.tidy(function () {
              // NOTE(review): 150 / true are presumably the input size and a
              // "center image" flag -- confirm against NetInput.toBatchTensor.
              var batchTensor = input.toBatchTensor(150, true).toFloat();
              var meanRgb = [122.782, 117.001, 104.298];
              // Normalize with the per-channel constants, then scale by 1/256.
              var normalized = ops_1.normalize(batchTensor, meanRgb).div(tf.scalar(256));
              // Stem: strided convolution followed by a 3x3 / stride-2 max pool.
              var out = convLayer_1.convDown(normalized, params.conv32_down);
              out = tf.maxPool(out, 3, 2, 'valid');
              // conv32 stage.
              out = residualLayer_1.residual(out, params.conv32_1);
              out = residualLayer_1.residual(out, params.conv32_2);
              out = residualLayer_1.residual(out, params.conv32_3);
              // conv64 stage.
              out = residualLayer_1.residualDown(out, params.conv64_down);
              out = residualLayer_1.residual(out, params.conv64_1);
              out = residualLayer_1.residual(out, params.conv64_2);
              out = residualLayer_1.residual(out, params.conv64_3);
              // conv128 stage.
              out = residualLayer_1.residualDown(out, params.conv128_down);
              out = residualLayer_1.residual(out, params.conv128_1);
              out = residualLayer_1.residual(out, params.conv128_2);
              // conv256 stage.
              out = residualLayer_1.residualDown(out, params.conv256_down);
              out = residualLayer_1.residual(out, params.conv256_1);
              out = residualLayer_1.residual(out, params.conv256_2);
              out = residualLayer_1.residualDown(out, params.conv256_down_out);
              // Average over axes 1 and 2, then project with the fc matrix.
              var globalAvg = out.mean([1, 2]);
              var fullyConnected = tf.matMul(globalAvg, params.fc);
              return fullyConnected;
          });
      };
      /**
       * Converts `input` with `toNetInput` and runs `forwardInput` on the result.
       *
       * @param input Anything accepted by `toNetInput` from ../dom.
       * @returns Promise resolving to the tensor returned by `forwardInput`.
       */
      FaceRecognitionNet.prototype.forward = function (input) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var _a;
              return tslib_1.__generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1: return [2 /*return*/, _a.apply(this, [_b.sent()])];
                  }
              });
          });
      };
      /**
       * Computes the descriptor data for `input`.
       *
       * Equivalent async source:
       *
       *     const netInput = await toNetInput(input)
       *     const tensors = tf.tidy(() => tf.unstack(this.forwardInput(netInput)))
       *     try {
       *       const data = await Promise.all(tensors.map(t => t.data()))
       *       return netInput.isBatchInput ? data : data[0]
       *     } finally {
       *       tensors.forEach(t => t.dispose())
       *     }
       *
       * @param input Anything accepted by `toNetInput` from ../dom.
       * @returns Promise resolving to an array with one `data()` result per
       *     unstacked tensor when `netInput.isBatchInput` is truthy, otherwise
       *     to the first result only.
       */
      FaceRecognitionNet.prototype.computeFaceDescriptor = function (input) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var netInput, faceDescriptorTensors, faceDescriptorsForBatch;
              var _this = this;
              return tslib_1.__generator(this, function (_a) {
                  switch (_a.label) {
                      case 0: return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1:
                          netInput = _a.sent();
                          // The unstacked tensors are the tidy's return value, so
                          // they are NOT released by tf.tidy and must be disposed
                          // explicitly below.
                          faceDescriptorTensors = tf.tidy(function () { return tf.unstack(_this.forwardInput(netInput)); });
                          _a.label = 2;
                      case 2:
                          // try-region [start=2, no catch, finally=4, end=5]:
                          // guarantees the dispose in case 4 runs even when a
                          // data() download rejects or throws.
                          _a.trys.push([2, , 4, 5]);
                          return [4 /*yield*/, Promise.all(faceDescriptorTensors.map(function (t) { return t.data(); }))];
                      case 3:
                          faceDescriptorsForBatch = _a.sent();
                          return [2 /*return*/, netInput.isBatchInput
                                  ? faceDescriptorsForBatch
                                  : faceDescriptorsForBatch[0]];
                      case 4:
                          // finally: release the descriptor tensors on every path.
                          faceDescriptorTensors.forEach(function (t) { return t.dispose(); });
                          return [7 /*endfinally*/];
                      case 5: return [2 /*return*/];
                  }
              });
          });
      };
      /**
       * @returns The default model name, 'face_recognition_model'.
       */
      FaceRecognitionNet.prototype.getDefaultModelName = function () {
          return 'face_recognition_model';
      };
      /**
       * Builds the network parameters from a weight map by delegating to
       * ./extractParamsFromWeigthMap. (The "Weigth" spelling is part of the
       * existing interface and is kept as-is.)
       *
       * @param weightMap Weight map passed through unchanged.
       */
      FaceRecognitionNet.prototype.extractParamsFromWeigthMap = function (weightMap) {
          return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(weightMap);
      };
      /**
       * Builds the network parameters from raw weights by delegating to
       * ./extractParams.
       *
       * @param weights Weights passed through unchanged.
       */
      FaceRecognitionNet.prototype.extractParams = function (weights) {
          return extractParams_1.extractParams(weights);
      };
      return FaceRecognitionNet;
  }(NeuralNetwork_1.NeuralNetwork));
  92. exports.FaceRecognitionNet = FaceRecognitionNet;
  93. //# sourceMappingURL=FaceRecognitionNet.js.map