ConvUtils.h 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #pragma once
  2. #include <ATen/core/List.h>
  3. #include <ATen/native/ConvUtils.h>
  4. namespace at::native::quantized {
  5. namespace {
// MakeConvOutputShape is used from both the CPU and CUDA libraries.
// Exporting the symbol from torch_cpu would probably take more storage
// than duplicating the implementation, which will likely be inlined away.
  9. template <int kSpatialDim>
  10. at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
  11. int N, // mini-batch
  12. int M, // output channels
  13. const std::array<int64_t, kSpatialDim>& input_image_shape,
  14. const std::vector<int64_t>& kernel,
  15. const torch::List<int64_t>& stride,
  16. const torch::List<int64_t>& padding,
  17. const torch::List<int64_t>& dilation);
  18. #if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK)
  19. template <>
  20. at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
  21. int N, // mini-batch
  22. int M, // output channels
  23. const std::array<int64_t, 2>& input_image_shape,
  24. const std::vector<int64_t>& kernel,
  25. const at::List<int64_t>& stride,
  26. const at::List<int64_t>& padding,
  27. const at::List<int64_t>& dilation) {
  28. const int H = input_image_shape[0];
  29. const int W = input_image_shape[1];
  30. const int64_t Y_H =
  31. (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
  32. const int64_t Y_W =
  33. (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
  34. return {N, M, Y_H, Y_W};
  35. }
  36. template <>
  37. at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
  38. int N, // mini-batch
  39. int M, // output channels
  40. const std::array<int64_t, 3>& input_image_shape,
  41. const std::vector<int64_t>& kernel,
  42. const at::List<int64_t>& stride,
  43. const at::List<int64_t>& padding,
  44. const torch::List<int64_t>& dilation) {
  45. const int D = input_image_shape[0];
  46. const int H = input_image_shape[1];
  47. const int W = input_image_shape[2];
  48. const int64_t Y_D =
  49. (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
  50. const int64_t Y_H =
  51. (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
  52. const int64_t Y_W =
  53. (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1;
  54. return {N, M, Y_D, Y_H, Y_W};
  55. }
  56. #endif
  57. } // anonymous namespace
  58. } // namespace at::native::quantized