AgeGenderNet.js 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var fullyConnectedLayer_1 = require("../common/fullyConnectedLayer");
  6. var util_1 = require("../faceProcessor/util");
  7. var TinyXception_1 = require("../xception/TinyXception");
  8. var extractParams_1 = require("./extractParams");
  9. var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  10. var types_1 = require("./types");
  11. var NeuralNetwork_1 = require("../NeuralNetwork");
  12. var dom_1 = require("../dom");
  /**
   * Age and gender classifier.
   *
   * A face feature extractor (by default `new TinyXception(2)`) produces
   * bottleneck features; `runNet` average-pools them, flattens them per batch
   * entry and feeds the result to two fully connected heads: `params.fc.age`
   * and `params.fc.gender` (the latter is turned into probabilities with a
   * softmax in `forwardInput`).
   *
   * NOTE: this file is TypeScript-emitted ES5 (tslib helpers, `var`,
   * prototype methods); keep hand edits ES5-compatible.
   */
  var AgeGenderNet = /** @class */ (function (_super) {
      tslib_1.__extends(AgeGenderNet, _super);
      /**
       * @param faceFeatureExtractor - network that computes the bottleneck
       *   features. Defaults to `new TinyXception(2)`; the meaning of the `2`
       *   is defined in ../xception/TinyXception, not here.
       */
      function AgeGenderNet(faceFeatureExtractor) {
          if (faceFeatureExtractor === void 0) { faceFeatureExtractor = new TinyXception_1.TinyXception(2); }
          var _this = _super.call(this, 'AgeGenderNet') || this;
          _this._faceFeatureExtractor = faceFeatureExtractor;
          return _this;
      }
      /** The feature extractor passed to (or created by) the constructor. */
      Object.defineProperty(AgeGenderNet.prototype, "faceFeatureExtractor", {
          get: function () {
              return this._faceFeatureExtractor;
          },
          enumerable: true,
          configurable: true
      });
      /**
       * Computes the raw head outputs (no softmax on the gender head).
       *
       * @param input - a NetInput, which is first run through the feature
       *   extractor, or a tensor that is used directly as the bottleneck
       *   features (assumed batch-first 4D so `avgPool` applies — TODO confirm).
       * @returns `{ age, gender }` tensors created inside `tf.tidy` and
       *   returned from it, so the caller owns them. `age` is flattened to 1D.
       * @throws Error if the classifier params have not been loaded.
       */
      AgeGenderNet.prototype.runNet = function (input) {
          var _this = this;
          var params = this.params;
          if (!params) {
              throw new Error(this._name + " - load model before inference");
          }
          return tf.tidy(function () {
              var bottleneckFeatures = input instanceof dom_1.NetInput
                  ? _this.faceFeatureExtractor.forwardInput(input)
                  : input;
              // 7x7 average pool (stride 2, 'valid' padding), then collapse every
              // non-batch dimension so each face becomes one row for the FC heads.
              var pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1);
              var age = fullyConnectedLayer_1.fullyConnectedLayer(pooled, params.fc.age).as1D();
              var gender = fullyConnectedLayer_1.fullyConnectedLayer(pooled, params.fc.gender);
              return { age: age, gender: gender };
          });
      };
      /**
       * Same as `runNet`, but with a softmax applied to the gender head.
       * The returned tensors escape `tf.tidy` and must be disposed by the caller.
       */
      AgeGenderNet.prototype.forwardInput = function (input) {
          var _this = this;
          return tf.tidy(function () {
              var _a = _this.runNet(input), age = _a.age, gender = _a.gender;
              return { age: age, gender: tf.softmax(gender) };
          });
      };
      /**
       * Converts `input` with `toNetInput` and runs `forwardInput` on it.
       * Resolves to the `{ age, gender }` tensors, which the caller must dispose.
       */
      AgeGenderNet.prototype.forward = function (input) {
          return tslib_1.__awaiter(this, void 0, void 0, function () {
              var _a;
              return tslib_1.__generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, dom_1.toNetInput(input)];
                      case 1: return [2 /*return*/, _a.apply(this, [_b.sent()])];
                  }
              });
          });
      };
      /**
       * Predicts age and gender and returns plain JS values.
       *
       * @param input - anything accepted by `toNetInput`.
       * @returns Promise resolving to `{ age, gender, genderProbability }`, or an
       *   array of those when the NetInput is a batch input. `gender` is
       *   `Gender.MALE` when the first softmax output is > 0.5, otherwise
       *   `Gender.FEMALE`. `genderProbability` is the probability of the
       *   returned gender.
       *
       * Every intermediate tensor is released on both the success and the
       * failure path, so a rejected `data()` read (or a throwing `unstack`)
       * does not leak tensors.
       */
      AgeGenderNet.prototype.predictAgeAndGender = function (input) {
          var _this = this;
          var netInputPromise;
          try {
              // toNetInput is still invoked synchronously; a synchronous throw is
              // reported as a rejection so callers always receive a Promise.
              netInputPromise = Promise.resolve(dom_1.toNetInput(input));
          }
          catch (err) {
              return Promise.reject(err);
          }
          return netInputPromise.then(function (netInput) {
              var out = _this.forwardInput(netInput);
              var ageTensors = [];
              var genderTensors = [];
              // Called exactly once, on whichever path (success / failure) is taken.
              var releaseTensors = function () {
                  tf.dispose(ageTensors);
                  tf.dispose(genderTensors);
                  tf.dispose([out.age, out.gender]);
              };
              var pendingPredictions;
              try {
                  ageTensors = tf.unstack(out.age);
                  genderTensors = tf.unstack(out.gender);
                  pendingPredictions = Promise.all(ageTensors.map(function (ageTensor, i) {
                      return Promise.all([ageTensor.data(), genderTensors[i].data()])
                          .then(function (values) {
                          var age = values[0][0];
                          var probMale = values[1][0];
                          var isMale = probMale > 0.5;
                          return {
                              age: age,
                              gender: isMale ? types_1.Gender.MALE : types_1.Gender.FEMALE,
                              genderProbability: isMale ? probMale : (1 - probMale)
                          };
                      });
                  }));
              }
              catch (err) {
                  releaseTensors();
                  throw err;
              }
              // ES5 has no Promise#finally, so release on both settle branches.
              return pendingPredictions.then(function (predictionsByBatch) {
                  releaseTensors();
                  return netInput.isBatchInput
                      ? predictionsByBatch
                      : predictionsByBatch[0];
              }, function (err) {
                  releaseTensors();
                  throw err;
              });
          });
      };
      /** Name passed to the base class when locating the model files. */
      AgeGenderNet.prototype.getDefaultModelName = function () {
          return 'age_gender_model';
      };
      /**
       * Disposes the feature extractor, then this network (via the base class).
       * @param throwOnRedispose - forwarded to both dispose calls; defaults to true.
       */
      AgeGenderNet.prototype.dispose = function (throwOnRedispose) {
          if (throwOnRedispose === void 0) { throwOnRedispose = true; }
          this.faceFeatureExtractor.dispose(throwOnRedispose);
          _super.prototype.dispose.call(this, throwOnRedispose);
      };
      /**
       * Loads only the classifier heads from a flat weight array, leaving the
       * feature extractor untouched.
       */
      AgeGenderNet.prototype.loadClassifierParams = function (weights) {
          var _a = this.extractClassifierParams(weights), params = _a.params, paramMappings = _a.paramMappings;
          this._params = params;
          this._paramMappings = paramMappings;
      };
      /** Parses classifier-head params from a flat weight array. */
      AgeGenderNet.prototype.extractClassifierParams = function (weights) {
          return extractParams_1.extractParams(weights);
      };
      /**
       * Splits a named weight map into feature-extractor and classifier parts.
       * Loads the former into the feature extractor and returns the parsed
       * classifier params.
       */
      AgeGenderNet.prototype.extractParamsFromWeigthMap = function (weightMap) {
          var _a = util_1.seperateWeightMaps(weightMap), featureExtractorMap = _a.featureExtractorMap, classifierMap = _a.classifierMap;
          this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
          return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(classifierMap);
      };
      /**
       * Splits a flat weight array laid out as
       * [feature extractor weights..., classifier weights...].
       * Loads the first part into the feature extractor and returns the parsed
       * classifier params.
       */
      AgeGenderNet.prototype.extractParams = function (weights) {
          // Classifier tail: a 512 -> 1 dense layer (512 weights + 1 bias) plus a
          // 512 -> 2 dense layer (1024 weights + 2 biases) — presumably the age and
          // gender heads respectively.
          var classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2);
          var featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize);
          var classifierWeights = weights.slice(weights.length - classifierWeightSize);
          this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
          return this.extractClassifierParams(classifierWeights);
      };
      return AgeGenderNet;
  }(NeuralNetwork_1.NeuralNetwork));
  exports.AgeGenderNet = AgeGenderNet;
  146. //# sourceMappingURL=AgeGenderNet.js.map