// TinyXception.js (3.8 KB)
  "use strict";
  // Compiled CommonJS module preamble: flag the exports object as an ES module
  // and pull in the helpers and sibling modules used by the code below.
  Object.defineProperty(exports, "__esModule", { value: true });
  var tslib_1 = require("tslib");
  var tf = require("@tensorflow/tfjs-core");
  var common_1 = require("../common");
  var dom_1 = require("../dom");
  var NeuralNetwork_1 = require("../NeuralNetwork");
  var ops_1 = require("../ops");
  var utils_1 = require("../utils");
  var extractParams_1 = require("./extractParams");
  var extractParamsFromWeigthMap_1 = require("./extractParamsFromWeigthMap");
  /**
   * Applies a 2D convolution with 'same' padding and then adds the bias.
   *
   * @param x - Input passed straight to `tf.conv2d`.
   * @param params - Object providing `filters` (convolution kernel) and `bias`.
   * @param stride - Stride argument forwarded to `tf.conv2d`.
   * @returns The result of `tf.add(conv2d(x, filters, stride, 'same'), bias)`.
   */
  function conv(x, params, stride) {
    var convolved = tf.conv2d(x, params.filters, stride, 'same');
    return tf.add(convolved, params.bias);
  }
  /**
   * Reduction block: two depthwise separable convolutions followed by a
   * stride-2 max pool, summed with a stride-2 convolution of the block input.
   *
   * @param x - Block input; also feeds the shortcut (`expansion_conv`) branch
   *   without any activation applied.
   * @param params - Object providing `separable_conv0`, `separable_conv1`
   *   and `expansion_conv`.
   * @param isActivateInput - When omitted (`undefined`) it defaults to `true`;
   *   if truthy, ReLU is applied to `x` before the first separable convolution.
   * @returns Sum of the pooled main branch and the shortcut branch.
   */
  function reductionBlock(x, params, isActivateInput) {
    // Compiled default parameter: only `undefined` triggers the default.
    if (isActivateInput === void 0) {
      isActivateInput = true;
    }

    // Main branch.
    var out = isActivateInput ? tf.relu(x) : x;
    out = common_1.depthwiseSeparableConv(out, params.separable_conv0, [1, 1]);
    out = common_1.depthwiseSeparableConv(tf.relu(out), params.separable_conv1, [1, 1]);
    out = tf.maxPool(out, [3, 3], [2, 2], 'same');

    // Shortcut branch uses the same [2, 2] stride as the max pool above.
    var shortcut = conv(x, params.expansion_conv, [2, 2]);
    return tf.add(out, shortcut);
  }
  /**
   * Main block: three ReLU + depthwise separable convolution stages
   * (all with [1, 1] stride) whose output is added back to the block input.
   *
   * @param x - Block input; reused as the residual term in the final add.
   * @param params - Object providing `separable_conv0`, `separable_conv1`
   *   and `separable_conv2`.
   * @returns `tf.add(<stacked convolutions>, x)`.
   */
  function mainBlock(x, params) {
    var unitStride = [1, 1];
    var out = common_1.depthwiseSeparableConv(tf.relu(x), params.separable_conv0, unitStride);
    out = common_1.depthwiseSeparableConv(tf.relu(out), params.separable_conv1, unitStride);
    out = common_1.depthwiseSeparableConv(tf.relu(out), params.separable_conv2, unitStride);
    return tf.add(out, x);
  }
  /**
   * TinyXception network built on `NeuralNetwork`.
   *
   * The forward pass runs an entry flow (input convolution plus two reduction
   * blocks), a configurable number of main blocks, and an exit flow (one
   * reduction block plus a final separable convolution).
   */
  var TinyXception = /** @class */ (function (_super) {
    tslib_1.__extends(TinyXception, _super);

    /**
     * @param numMainBlocks - Number of middle-flow main blocks; also forwarded
     *   to both parameter-extraction helpers.
     */
    function TinyXception(numMainBlocks) {
      var _this = _super.call(this, 'TinyXception') || this;
      _this._numMainBlocks = numMainBlocks;
      return _this;
    }

    /**
     * Runs the network on an already prepared net input.
     *
     * @param input - Object exposing `toBatchTensor` (presumably a NetInput —
     *   confirm against `../dom`).
     * @returns The output of the final ReLU, produced inside `tf.tidy`.
     * @throws {Error} If `this.params` is not set (model not loaded).
     */
    TinyXception.prototype.forwardInput = function (input) {
      var _this = this;
      var params = this.params;
      if (!params) {
        throw new Error('TinyXception - load model before inference');
      }
      return tf.tidy(function () {
        // NOTE(review): 112 is presumably the square input size and `true`
        // a centering flag — confirm against toBatchTensor's signature.
        var batchTensor = input.toBatchTensor(112, true);
        var meanRgb = [122.782, 117.001, 104.298];
        var normalized = ops_1.normalize(batchTensor, meanRgb).div(tf.scalar(256));

        // Entry flow: the first reduction block skips the input ReLU because
        // the preceding convolution output has already been activated.
        var out = tf.relu(conv(normalized, params.entry_flow.conv_in, [2, 2]));
        out = reductionBlock(out, params.entry_flow.reduction_block_0, false);
        out = reductionBlock(out, params.entry_flow.reduction_block_1);

        // Middle flow: one main block per index, looked up as "main_block_<idx>".
        // NOTE(review): range(n, 0, 1) presumably yields 0..n-1 — confirm in ../utils.
        utils_1.range(_this._numMainBlocks, 0, 1).forEach(function (idx) {
          out = mainBlock(out, params.middle_flow["main_block_" + idx]);
        });

        // Exit flow.
        out = reductionBlock(out, params.exit_flow.reduction_block);
        out = tf.relu(common_1.depthwiseSeparableConv(out, params.exit_flow.separable_conv, [1, 1]));
        return out;
      });
    };

    /**
     * Converts `input` with `toNetInput` and runs `forwardInput` on the result.
     * Compiled async function; the state machine below corresponds to
     * `return this.forwardInput(await toNetInput(input))`.
     *
     * @param input - Anything accepted by `dom_1.toNetInput`.
     * @returns A promise for the `forwardInput` result.
     */
    TinyXception.prototype.forward = function (input) {
      return tslib_1.__awaiter(this, void 0, void 0, function () {
        var _a;
        return tslib_1.__generator(this, function (_b) {
          switch (_b.label) {
            case 0:
              // The method reference is captured before awaiting the input.
              _a = this.forwardInput;
              return [4 /*yield*/, dom_1.toNetInput(input)];
            case 1: return [2 /*return*/, _a.apply(this, [_b.sent()])];
          }
        });
      });
    };

    /** @returns The default model name, 'tiny_xception_model'. */
    TinyXception.prototype.getDefaultModelName = function () {
      return 'tiny_xception_model';
    };

    /**
     * Delegates to the weight-map extractor, passing the main block count.
     *
     * @param weightMap - Weight map forwarded unchanged to the extractor.
     */
    TinyXception.prototype.extractParamsFromWeigthMap = function (weightMap) {
      return extractParamsFromWeigthMap_1.extractParamsFromWeigthMap(weightMap, this._numMainBlocks);
    };

    /**
     * Delegates to the raw-weights extractor, passing the main block count.
     *
     * @param weights - Weights forwarded unchanged to the extractor.
     */
    TinyXception.prototype.extractParams = function (weights) {
      return extractParams_1.extractParams(weights, this._numMainBlocks);
    };

    return TinyXception;
  }(NeuralNetwork_1.NeuralNetwork));
  exports.TinyXception = TinyXception;
  //# sourceMappingURL=TinyXception.js.map