padToSquare.js 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import * as tf from '@tensorflow/tfjs-core';
  /**
   * Pads the smaller spatial dimension of an image tensor with zeros, such that width === height.
   *
   * Height and width are read from axes 1 and 2 (`imgTensor.shape[1]` / `imgTensor.shape[2]`);
   * presumably the input is a batched [batch, height, width, channels] tensor — TODO confirm
   * against callers.
   *
   * @param {tf.Tensor} imgTensor The image tensor.
   * @param {boolean} [isCenterImage=false] If true, split the padding between both sides of the
   *   minor dimension of the image: the rounded half goes after the image and the remainder
   *   before it. If false, all padding is appended after the image.
   * @returns {tf.Tensor} The zero-padded float32 tensor with width === height. An input that is
   *   already square is returned as-is (the same tensor object, with its original dtype).
   */
  export function padToSquare(imgTensor, isCenterImage = false) {
    return tf.tidy(() => {
      const [height, width] = imgTensor.shape.slice(1);
      if (height === width) {
        return imgTensor;
      }

      const dimDiff = Math.abs(height - width);
      // Grow the smaller dimension: axis 1 (height) for wide images, axis 2 (width) for tall ones.
      const paddingAxis = height > width ? 2 : 1;
      const appendAmount = Math.round(dimDiff * (isCenterImage ? 0.5 : 1));
      // Always 0 when not centering, since appendAmount === dimDiff in that case.
      const prependAmount = dimDiff - appendAmount;

      // Zero tensor shaped like the input except along the padding axis.
      const createPaddingTensor = (amount) => {
        const paddingShape = [...imgTensor.shape];
        paddingShape[paddingAxis] = amount;
        return tf.fill(paddingShape, 0, 'float32');
      };

      // tf.concat needs matching dtypes, so the image is brought to float32 like the padding.
      // tf.cast is used instead of the deprecated Tensor#toFloat, which newer tfjs-core
      // releases no longer provide. Intermediate tensors are released by tf.tidy.
      const tensorsToConcat = [tf.cast(imgTensor, 'float32'), createPaddingTensor(appendAmount)];
      if (prependAmount > 0) {
        tensorsToConcat.unshift(createPaddingTensor(prependAmount));
      }
      return tf.concat(tensorsToConcat, paddingAxis);
    });
  }
  40. //# sourceMappingURL=padToSquare.js.map