AgeGenderNet.js 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import { __awaiter, __extends, __generator } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
  4. import { seperateWeightMaps } from '../faceProcessor/util';
  5. import { TinyXception } from '../xception/TinyXception';
  6. import { extractParams } from './extractParams';
  7. import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
  8. import { Gender } from './types';
  9. import { NeuralNetwork } from '../NeuralNetwork';
  10. import { NetInput, toNetInput } from '../dom';
  var AgeGenderNet = /** @class */ (function (_super) {
      __extends(AgeGenderNet, _super);
      /**
       * Network that estimates age and gender. A feature extractor turns the
       * input into bottleneck features, and two fully connected heads
       * (`params.fc.age` and `params.fc.gender`) map the pooled features to an
       * age value and gender scores.
       *
       * @param {TinyXception} [faceFeatureExtractor] - backbone used to turn a
       *   NetInput into bottleneck features; defaults to `new TinyXception(2)`.
       */
      function AgeGenderNet(faceFeatureExtractor) {
          if (faceFeatureExtractor === void 0) { faceFeatureExtractor = new TinyXception(2); }
          var _this = _super.call(this, 'AgeGenderNet') || this;
          _this._faceFeatureExtractor = faceFeatureExtractor;
          return _this;
      }
      /**
       * Backbone that produces the bottleneck features consumed by the
       * classifier heads. Read-only: only a getter is defined.
       */
      Object.defineProperty(AgeGenderNet.prototype, "faceFeatureExtractor", {
          get: function () {
              return this._faceFeatureExtractor;
          },
          enumerable: true,
          configurable: true
      });
      /**
       * Runs the backbone (when needed) and both classifier heads, returning the
       * raw head outputs. Intermediate tensors are released by `tf.tidy`.
       *
       * @param {NetInput|tf.Tensor} input - a NetInput, which is passed through
       *   the feature extractor, or a tensor that is used directly as
       *   bottleneck features (NOTE(review): presumably a 4D batch tensor —
       *   confirm against callers).
       * @returns {{age: tf.Tensor, gender: tf.Tensor}} the age head output
       *   flattened to 1D, and the unnormalized gender head output (softmax is
       *   applied by `forwardInput`).
       * @throws {Error} if the model parameters have not been loaded.
       */
      AgeGenderNet.prototype.runNet = function (input) {
          var _this = this;
          var params = this.params;
          if (!params) {
              throw new Error(this._name + " - load model before inference");
          }
          return tf.tidy(function () {
              var bottleneckFeatures = input instanceof NetInput
                  ? _this.faceFeatureExtractor.forwardInput(input)
                  : input;
              // 7x7 average pool (stride 2, 'valid' padding), then flatten every
              // batch entry into a single feature row: [batch, features].
              var pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1);
              var age = fullyConnectedLayer(pooled, params.fc.age).as1D();
              var gender = fullyConnectedLayer(pooled, params.fc.gender);
              return { age: age, gender: gender };
          });
      };
      /**
       * Same as `runNet`, but converts the gender scores to probabilities with
       * a softmax.
       *
       * @param {NetInput|tf.Tensor} input - see `runNet`.
       * @returns {{age: tf.Tensor, gender: tf.Tensor}} age values and gender
       *   probabilities. The caller owns (and must dispose) both tensors.
       * @throws {Error} if the model parameters have not been loaded.
       */
      AgeGenderNet.prototype.forwardInput = function (input) {
          var _this = this;
          return tf.tidy(function () {
              var _a = _this.runNet(input), age = _a.age, gender = _a.gender;
              return { age: age, gender: tf.softmax(gender) };
          });
      };
      /**
       * Converts `input` with `toNetInput` and runs `forwardInput` on the result.
       *
       * @param {*} input - anything accepted by `toNetInput`.
       * @returns {Promise<{age: tf.Tensor, gender: tf.Tensor}>} the output of
       *   `forwardInput`; the caller owns the returned tensors.
       */
      AgeGenderNet.prototype.forward = function (input) {
          return __awaiter(this, void 0, void 0, function () {
              var _a;
              return __generator(this, function (_b) {
                  switch (_b.label) {
                      case 0:
                          _a = this.forwardInput;
                          return [4 /*yield*/, toNetInput(input)];
                      case 1: return [2 /*return*/, _a.apply(this, [_b.sent()])];
                  }
              });
          });
      };
      /**
       * Predicts age and gender for every entry of `input`.
       *
       * The gender is `Gender.MALE` when the probability at index 0 exceeds
       * 0.5, and `Gender.FEMALE` otherwise. `genderProbability` is the
       * probability of the reported gender.
       *
       * Every tensor created here is disposed exactly once before the returned
       * promise settles. This holds on the failure path too (e.g. a rejected
       * `data()` read), which previously leaked them.
       *
       * @param {*} input - anything accepted by `toNetInput`.
       * @returns {Promise<Object|Object[]>} an array of
       *   `{ age, gender, genderProbability }` objects for batch input,
       *   otherwise the single prediction object.
       * @throws {Error} (as a rejection) if the model parameters have not been
       *   loaded.
       */
      AgeGenderNet.prototype.predictAgeAndGender = function (input) {
          var _this = this;
          var netInput;
          var out = null;
          var ageTensors = [];
          var genderTensors = [];
          // Releases everything this call allocated. It runs on both the success
          // and failure paths, and it is the only place these tensors are
          // disposed, so none is disposed twice.
          function disposeTensors() {
              ageTensors.forEach(function (tensor) { tensor.dispose(); });
              genderTensors.forEach(function (tensor) { tensor.dispose(); });
              if (out !== null) {
                  out.age.dispose();
                  out.gender.dispose();
              }
          }
          // Start from a resolved promise so that synchronous throws (e.g. the
          // model not being loaded) surface as rejections, as they did when
          // this method was an async function.
          return Promise.resolve()
              .then(function () { return toNetInput(input); })
              .then(function (resolvedInput) {
                  netInput = resolvedInput;
                  out = _this.forwardInput(netInput);
                  ageTensors = tf.unstack(out.age);
                  genderTensors = tf.unstack(out.gender);
                  return Promise.all(ageTensors.map(function (ageTensor, i) {
                      return Promise.all([ageTensor.data(), genderTensors[i].data()])
                          .then(function (values) {
                              var age = values[0][0];
                              var probMale = values[1][0];
                              var isMale = probMale > 0.5;
                              return {
                                  age: age,
                                  gender: isMale ? Gender.MALE : Gender.FEMALE,
                                  genderProbability: isMale ? probMale : (1 - probMale)
                              };
                          });
                  }));
              })
              .then(function (predictionsByBatch) {
                  disposeTensors();
                  return netInput.isBatchInput
                      ? predictionsByBatch
                      : predictionsByBatch[0];
              }, function (err) {
                  disposeTensors();
                  throw err;
              });
      };
      /**
       * @returns {string} the default model name, 'age_gender_model'
       *   (NOTE(review): presumably used by the base class to locate weight
       *   files — confirm in NeuralNetwork).
       */
      AgeGenderNet.prototype.getDefaultModelName = function () {
          return 'age_gender_model';
      };
      /**
       * Disposes the feature extractor, then delegates to
       * `NeuralNetwork.prototype.dispose` for this network.
       *
       * @param {boolean} [throwOnRedispose=true] - forwarded to both dispose
       *   calls.
       */
      AgeGenderNet.prototype.dispose = function (throwOnRedispose) {
          if (throwOnRedispose === void 0) { throwOnRedispose = true; }
          this.faceFeatureExtractor.dispose(throwOnRedispose);
          _super.prototype.dispose.call(this, throwOnRedispose);
      };
      /**
       * Loads only the classifier heads from a flat weight array, without
       * touching the feature extractor. Stores the decoded params and their
       * mappings on the instance.
       *
       * @param {Float32Array} weights - classifier weights only.
       */
      AgeGenderNet.prototype.loadClassifierParams = function (weights) {
          var _a = this.extractClassifierParams(weights), params = _a.params, paramMappings = _a.paramMappings;
          this._params = params;
          this._paramMappings = paramMappings;
      };
      /**
       * Decodes classifier params from a flat weight array via `./extractParams`.
       *
       * @param {Float32Array} weights - classifier weights only.
       * @returns {{params: Object, paramMappings: Array}} decoded params and
       *   their mappings.
       */
      AgeGenderNet.prototype.extractClassifierParams = function (weights) {
          return extractParams(weights);
      };
      /**
       * Splits a named weight map into feature-extractor and classifier parts
       * (`seperateWeightMaps`), loads the former into the feature extractor,
       * and returns classifier params decoded from the latter. The misspelled
       * method name is part of the public/override API and is kept as-is.
       *
       * @param {Object} weightMap - named weight tensors for the full model.
       * @returns {Object} the result of `extractParamsFromWeigthMap(classifierMap)`.
       */
      AgeGenderNet.prototype.extractParamsFromWeigthMap = function (weightMap) {
          var _a = seperateWeightMaps(weightMap), featureExtractorMap = _a.featureExtractorMap, classifierMap = _a.classifierMap;
          this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
          return extractParamsFromWeigthMap(classifierMap);
      };
      /**
       * Splits a flat weight array for the full model. The trailing
       * `classifierWeightSize` values are the classifier heads; everything
       * before them is loaded into the feature extractor.
       *
       * @param {Float32Array} weights - weights for the full model.
       * @returns {{params: Object, paramMappings: Array}} decoded classifier
       *   params.
       */
      AgeGenderNet.prototype.extractParams = function (weights) {
          // NOTE(review): reads as a 512-input FC layer with 1 output plus one
          // with 2 outputs (weights + biases) — presumably the age and gender
          // heads; confirm against ./extractParams.
          var classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2);
          var featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize);
          var classifierWeights = weights.slice(weights.length - classifierWeightSize);
          this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
          return this.extractClassifierParams(classifierWeights);
      };
      return AgeGenderNet;
  }(NeuralNetwork));
  export { AgeGenderNet };
  //# sourceMappingURL=AgeGenderNet.js.map