NetInput.js 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var env_1 = require("../env");
  5. var padToSquare_1 = require("../ops/padToSquare");
  6. var utils_1 = require("../utils");
  7. var createCanvas_1 = require("./createCanvas");
  8. var imageToSquare_1 = require("./imageToSquare");
  /**
   * Per-item container for the inputs handed to a network.
   *
   * Each entry of the constructor's `inputs` array is kept either as an image
   * tensor (a tf.Tensor3D, or a tf.Tensor4D whose batch dimension is 1) or as a
   * canvas. Entries that are neither a tensor nor an instance of the
   * environment's Canvas class are converted with createCanvasFromMedia.
   * The dimensions of every entry are recorded and exposed via inputDimensions.
   */
  var NetInput = /** @class */ (function () {
    /**
     * @param inputs Array of resolved inputs (tensors, canvases or media elements).
     * @param treatAsBatchInput (optional, default: false) If true, isBatchInput
     * reports true even when the array holds a single input.
     * @throws Error if inputs is not an array, or if it contains a tf.Tensor4D
     * whose batch dimension is not 1.
     */
    function NetInput(inputs, treatAsBatchInput) {
      var _this = this;
      if (treatAsBatchInput === void 0) { treatAsBatchInput = false; }
      // Sparse arrays indexed by batch position: an index is populated in
      // exactly one of _imageTensors / _canvases.
      this._imageTensors = [];
      this._canvases = [];
      this._inputDimensions = [];
      if (!Array.isArray(inputs)) {
        throw new Error("NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have " + inputs);
      }
      this._treatAsBatchInput = treatAsBatchInput;
      this._batchSize = inputs.length;
      inputs.forEach(function (input, idx) {
        if (utils_1.isTensor3D(input)) {
          _this._imageTensors[idx] = input;
          _this._inputDimensions[idx] = input.shape;
          return;
        }
        if (utils_1.isTensor4D(input)) {
          var batchSize = input.shape[0];
          if (batchSize !== 1) {
            throw new Error("NetInput - tf.Tensor4D with batchSize " + batchSize + " passed, but not supported in input array");
          }
          _this._imageTensors[idx] = input;
          // Drop the leading batch dimension so every entry is recorded with
          // the same rank as a 3D tensor shape.
          _this._inputDimensions[idx] = input.shape.slice(1);
          return;
        }
        var canvas = input instanceof env_1.env.getEnv().Canvas ? input : createCanvas_1.createCanvasFromMedia(input);
        _this._canvases[idx] = canvas;
        // Canvases are always recorded with 3 channels; toBatchTensor reads
        // them through tf.browser.fromPixels.
        _this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
      });
    }
    /** Tensor inputs, indexed by batch position (sparse). */
    Object.defineProperty(NetInput.prototype, "imageTensors", {
      get: function () {
        return this._imageTensors;
      },
      enumerable: true,
      configurable: true
    });
    /** Canvas inputs, indexed by batch position (sparse). */
    Object.defineProperty(NetInput.prototype, "canvases", {
      get: function () {
        return this._canvases;
      },
      enumerable: true,
      configurable: true
    });
    /** True when more than one input was given or treatAsBatchInput was set. */
    Object.defineProperty(NetInput.prototype, "isBatchInput", {
      get: function () {
        return this.batchSize > 1 || this._treatAsBatchInput;
      },
      enumerable: true,
      configurable: true
    });
    /** Number of inputs passed to the constructor. */
    Object.defineProperty(NetInput.prototype, "batchSize", {
      get: function () {
        return this._batchSize;
      },
      enumerable: true,
      configurable: true
    });
    /** Recorded dimensions of each input; index 0 is height, index 1 is width. */
    Object.defineProperty(NetInput.prototype, "inputDimensions", {
      get: function () {
        return this._inputDimensions;
      },
      enumerable: true,
      configurable: true
    });
    /** The inputSize of the last toBatchTensor call; undefined before any call. */
    Object.defineProperty(NetInput.prototype, "inputSize", {
      get: function () {
        return this._inputSize;
      },
      enumerable: true,
      configurable: true
    });
    /** getReshapedInputDimensions for every batch index (requires toBatchTensor first). */
    Object.defineProperty(NetInput.prototype, "reshapedInputDimensions", {
      get: function () {
        var _this = this;
        return utils_1.range(this.batchSize, 0, 1).map(function (_, batchIdx) { return _this.getReshapedInputDimensions(batchIdx); });
      },
      enumerable: true,
      configurable: true
    });
    /** Returns the canvas at batchIdx if there is one, otherwise the tensor. */
    NetInput.prototype.getInput = function (batchIdx) {
      return this.canvases[batchIdx] || this.imageTensors[batchIdx];
    };
    NetInput.prototype.getInputDimensions = function (batchIdx) {
      return this._inputDimensions[batchIdx];
    };
    NetInput.prototype.getInputHeight = function (batchIdx) {
      return this._inputDimensions[batchIdx][0];
    };
    NetInput.prototype.getInputWidth = function (batchIdx) {
      return this._inputDimensions[batchIdx][1];
    };
    /**
     * Dimensions of the input at batchIdx as computed by
     * utils.computeReshapedDimensions for the current inputSize.
     *
     * @throws Error if toBatchTensor has not been called yet (inputSize unset).
     */
    NetInput.prototype.getReshapedInputDimensions = function (batchIdx) {
      if (typeof this.inputSize !== 'number') {
        throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
      }
      var width = this.getInputWidth(batchIdx);
      var height = this.getInputHeight(batchIdx);
      return utils_1.computeReshapedDimensions({ width: width, height: height }, this.inputSize);
    };
    /**
     * Create a batch tensor from all input canvases and tensors
     * with shape [batchSize, inputSize, inputSize, 3].
     *
     * Tensor inputs are padded to a square with padToSquare and, if their size
     * differs from inputSize, resized with tf.image.resizeBilinear. Canvas
     * inputs are drawn square via imageToSquare and read with
     * tf.browser.fromPixels. Intermediate tensors are released by tf.tidy.
     * Also stores inputSize, which getReshapedInputDimensions requires.
     *
     * @param inputSize Height and width of each image in the batch tensor.
     * @param isCenterInputs (optional, default: true) If true, add an equal
     * amount of padding on both sides of the minor dimension of the image
     * (passed through to padToSquare / imageToSquare).
     * @returns The batch tensor, cast with toFloat.
     * @throws Error if an input is neither a tf.Tensor nor an instance of the
     * environment's Canvas class.
     */
    NetInput.prototype.toBatchTensor = function (inputSize, isCenterInputs) {
      var _this = this;
      if (isCenterInputs === void 0) { isCenterInputs = true; }
      this._inputSize = inputSize;
      return tf.tidy(function () {
        var inputTensors = utils_1.range(_this.batchSize, 0, 1).map(function (batchIdx) {
          var input = _this.getInput(batchIdx);
          if (input instanceof tf.Tensor) {
            var imgTensor = utils_1.isTensor4D(input) ? input : input.expandDims();
            imgTensor = padToSquare_1.padToSquare(imgTensor, isCenterInputs);
            if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
              imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
            }
            return imgTensor.as3D(inputSize, inputSize, 3);
          }
          if (input instanceof env_1.env.getEnv().Canvas) {
            return tf.browser.fromPixels(imageToSquare_1.imageToSquare(input, inputSize, isCenterInputs));
          }
          throw new Error("toBatchTensor - at batchIdx " + batchIdx + ", expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have " + input);
        });
        var batchTensor = tf.stack(inputTensors.map(function (t) { return t.toFloat(); })).as4D(_this.batchSize, inputSize, inputSize, 3);
        return batchTensor;
      });
    };
    return NetInput;
  }());
  // CommonJS export of the NetInput class defined above.
  exports["NetInput"] = NetInput;
  148. //# sourceMappingURL=NetInput.js.map