NetInput.js 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { env } from '../env';
  3. import { padToSquare } from '../ops/padToSquare';
  4. import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils';
  5. import { createCanvasFromMedia } from './createCanvas';
  6. import { imageToSquare } from './imageToSquare';
  /**
   * NetInput normalizes an array of already-resolved network inputs so that
   * downstream code can handle tensors and canvases uniformly.
   *
   * For every element it keeps the input itself (tensors in `imageTensors`,
   * canvases in `canvases`, each indexed by the element's position in the
   * original array) and its dimensions in `inputDimensions`. Index 0 of each
   * dimensions entry is read as the height and index 1 as the width, so tensor
   * inputs are presumably expected in [height, width, channels] layout —
   * TODO confirm against callers.
   */
  var NetInput = /** @class */ (function () {
    /**
     * @param {Array} inputs Inputs to wrap. Each element may be a tf.Tensor3D,
     *   a tf.Tensor4D whose first (batch) dimension is 1, an instance of the
     *   environment's Canvas class, or any other media, which is converted
     *   with createCanvasFromMedia.
     * @param {boolean} [treatAsBatchInput=false] If true, `isBatchInput`
     *   reports true even when only a single input was passed.
     * @throws {Error} If `inputs` is not an array, or if a tf.Tensor4D element
     *   has a batch size other than 1.
     */
    function NetInput(inputs, treatAsBatchInput) {
      var _this = this;
      if (treatAsBatchInput === void 0) { treatAsBatchInput = false; }
      this._imageTensors = [];
      this._canvases = [];
      this._treatAsBatchInput = false;
      this._inputDimensions = [];
      if (!Array.isArray(inputs)) {
        throw new Error("NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have " + inputs);
      }
      this._treatAsBatchInput = treatAsBatchInput;
      this._batchSize = inputs.length;
      inputs.forEach(function (input, idx) {
        if (isTensor3D(input)) {
          // The tensor itself is stored by reference (not cloned), but its shape
          // is copied, as in the 4D branch below, so that mutating the exposed
          // inputDimensions entry cannot alter the tensor's own shape array.
          _this._imageTensors[idx] = input;
          _this._inputDimensions[idx] = input.shape.slice();
          return;
        }
        if (isTensor4D(input)) {
          var batchSize = input.shape[0];
          if (batchSize !== 1) {
            throw new Error("NetInput - tf.Tensor4D with batchSize " + batchSize + " passed, but not supported in input array");
          }
          _this._imageTensors[idx] = input;
          // Drop the leading batch dimension of 1.
          _this._inputDimensions[idx] = input.shape.slice(1);
          return;
        }
        // Anything that is not a tensor ends up as a canvas: used directly if it
        // already is one, otherwise converted with createCanvasFromMedia.
        var canvas = input instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
        _this._canvases[idx] = canvas;
        // Canvas inputs are always recorded as having 3 channels.
        _this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
      });
    }
    Object.defineProperty(NetInput.prototype, "imageTensors", {
      /** Tensor inputs by original index; positions holding canvases are empty. */
      get: function () {
        return this._imageTensors;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "canvases", {
      /** Canvas inputs by original index; positions holding tensors are empty. */
      get: function () {
        return this._canvases;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "isBatchInput", {
      /** True when more than one input was passed or treatAsBatchInput was set. */
      get: function () {
        return this.batchSize > 1 || this._treatAsBatchInput;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "batchSize", {
      /** Number of elements in the inputs array given to the constructor. */
      get: function () {
        return this._batchSize;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "inputDimensions", {
      /** Per-input dimensions as recorded by the constructor. */
      get: function () {
        return this._inputDimensions;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "inputSize", {
      /** inputSize of the most recent toBatchTensor call; undefined before that. */
      get: function () {
        return this._inputSize;
      },
      enumerable: true,
      configurable: true
    });
    Object.defineProperty(NetInput.prototype, "reshapedInputDimensions", {
      /**
       * getReshapedInputDimensions for every batch index.
       * Throws until toBatchTensor has been called.
       */
      get: function () {
        var _this = this;
        return range(this.batchSize, 0, 1).map(function (_, batchIdx) { return _this.getReshapedInputDimensions(batchIdx); });
      },
      enumerable: true,
      configurable: true
    });
    /**
     * @param {number} batchIdx Index into the original inputs array.
     * @returns The canvas at batchIdx if there is one, otherwise the tensor.
     */
    NetInput.prototype.getInput = function (batchIdx) {
      return this.canvases[batchIdx] || this.imageTensors[batchIdx];
    };
    /** @returns The recorded dimensions entry for batchIdx. */
    NetInput.prototype.getInputDimensions = function (batchIdx) {
      return this._inputDimensions[batchIdx];
    };
    /** @returns Index 0 of the dimensions entry for batchIdx (the height). */
    NetInput.prototype.getInputHeight = function (batchIdx) {
      return this._inputDimensions[batchIdx][0];
    };
    /** @returns Index 1 of the dimensions entry for batchIdx (the width). */
    NetInput.prototype.getInputWidth = function (batchIdx) {
      return this._inputDimensions[batchIdx][1];
    };
    /**
     * Computes the reshaped dimensions of one input by delegating to
     * computeReshapedDimensions with the input's width/height and the
     * inputSize recorded by the last toBatchTensor call.
     *
     * @param {number} batchIdx Index into the original inputs array.
     * @throws {Error} If toBatchTensor has not been called yet.
     */
    NetInput.prototype.getReshapedInputDimensions = function (batchIdx) {
      if (typeof this.inputSize !== 'number') {
        throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
      }
      var width = this.getInputWidth(batchIdx);
      var height = this.getInputHeight(batchIdx);
      return computeReshapedDimensions({ width: width, height: height }, this.inputSize);
    };
    /**
     * Create a batch tensor from all input canvases and tensors
     * with size [batchSize, inputSize, inputSize, 3].
     *
     * Also records `inputSize`, which getReshapedInputDimensions requires.
     *
     * @param {number} inputSize Height and width of each image in the batch tensor.
     * @param {boolean} [isCenterInputs=true] If true, add an equal amount of
     *   padding on both sides of the minor dimension of the image. Forwarded as
     *   the centering flag to padToSquare (tensor inputs) and imageToSquare
     *   (canvas inputs).
     * @returns {tf.Tensor4D} The batch tensor, cast to float.
     * @throws {Error} If an input is neither a tf.Tensor nor an environment Canvas.
     */
    NetInput.prototype.toBatchTensor = function (inputSize, isCenterInputs) {
      var _this = this;
      if (isCenterInputs === void 0) { isCenterInputs = true; }
      this._inputSize = inputSize;
      // tf.tidy disposes every intermediate tensor created inside the callback;
      // only the returned batch tensor survives.
      return tf.tidy(function () {
        var inputTensors = range(_this.batchSize, 0, 1).map(function (batchIdx) {
          var input = _this.getInput(batchIdx);
          if (input instanceof tf.Tensor) {
            // Work on a 4D tensor with a batch dimension of 1.
            var imgTensor = isTensor4D(input) ? input : input.expandDims();
            imgTensor = padToSquare(imgTensor, isCenterInputs);
            // Presumably NHWC after padToSquare, so shape[1] / shape[2] are
            // height / width — TODO confirm.
            if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
              imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
            }
            // Drops the batch dimension; assumes the tensor has exactly
            // 3 channels (reshape fails otherwise) — TODO confirm for
            // grayscale/RGBA tensor inputs.
            return imgTensor.as3D(inputSize, inputSize, 3);
          }
          if (input instanceof env.getEnv().Canvas) {
            return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs));
          }
          throw new Error("toBatchTensor - at batchIdx " + batchIdx + ", expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have " + input);
        });
        var batchTensor = tf.stack(inputTensors.map(function (t) { return t.toFloat(); })).as4D(_this.batchSize, inputSize, inputSize, 3);
        return batchTensor;
      });
    };
    return NetInput;
  }());
  145. export { NetInput };
  146. //# sourceMappingURL=NetInput.js.map