PixelShuffle.h 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  #include <ATen/core/Tensor.h>
  #include <c10/util/Exception.h>

  #include <limits>
  3. namespace at {
  4. namespace native {
  5. inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
  6. TORCH_CHECK(self.dim() >= 3,
  7. "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
  8. self.dim(), " dimension(s)");
  9. TORCH_CHECK(upscale_factor > 0,
  10. "pixel_shuffle expects a positive upscale_factor, but got ",
  11. upscale_factor);
  12. int64_t c = self.size(-3);
  13. int64_t upscale_factor_squared = upscale_factor * upscale_factor;
  14. TORCH_CHECK(c % upscale_factor_squared == 0,
  15. "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
  16. "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
  17. }
  18. inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
  19. TORCH_CHECK(
  20. self.dim() >= 3,
  21. "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
  22. self.dim(),
  23. " dimension(s)");
  24. TORCH_CHECK(
  25. downscale_factor > 0,
  26. "pixel_unshuffle expects a positive downscale_factor, but got ",
  27. downscale_factor);
  28. int64_t h = self.size(-2);
  29. int64_t w = self.size(-1);
  30. TORCH_CHECK(
  31. h % downscale_factor == 0,
  32. "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
  33. h,
  34. " is not divisible by ",
  35. downscale_factor);
  36. TORCH_CHECK(
  37. w % downscale_factor == 0,
  38. "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
  39. w,
  40. " is not divisible by ",
  41. downscale_factor);
  42. }
  43. }} // namespace at::native