#pragma once

#include <ATen/core/Tensor.h>

#include <ATen/Parallel.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Pool.h>
  6. namespace at::native {
  7. static void check_max_pool1d(
  8. const Tensor& self,
  9. IntArrayRef kernel_size,
  10. IntArrayRef stride,
  11. IntArrayRef padding,
  12. IntArrayRef dilation,
  13. bool ceil_mode) {
  14. TORCH_CHECK(
  15. self.dim() == 2 || self.dim() == 3,
  16. "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
  17. TORCH_CHECK(
  18. kernel_size.size() == 1,
  19. "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
  20. kernel_size.size());
  21. TORCH_CHECK(
  22. stride.empty() || stride.size() == 1,
  23. "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
  24. stride.size());
  25. TORCH_CHECK(
  26. padding.size() == 1,
  27. "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
  28. padding.size());
  29. TORCH_CHECK(
  30. dilation.size() == 1,
  31. "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
  32. dilation.size());
  33. // If stride=None then set it to kernel_size
  34. if (stride.empty()) {
  35. stride = kernel_size;
  36. }
  37. TORCH_CHECK(
  38. kernel_size[0] > 0,
  39. "max_pool1d() kernel_size must be greater than zero, but got ",
  40. kernel_size[0]);
  41. TORCH_CHECK(
  42. stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  43. TORCH_CHECK(
  44. padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  45. TORCH_CHECK(
  46. padding[0] <= kernel_size[0] / 2,
  47. "max_pool1d() padding should be at most half of kernel size, but got padding=",
  48. padding[0],
  49. " and kernel_size=",
  50. kernel_size[0]);
  51. TORCH_CHECK(
  52. dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
  53. const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  54. TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
  55. }
  56. // TODO(Heitor) Template by dimension
  57. struct PoolingParams1D {
  58. int64_t NB; // Number of batches
  59. int64_t NC; // Number of channels
  60. int64_t IW; // Input width
  61. int64_t OW; // Output width
  62. int64_t KW; // Kernel width
  63. int64_t SJ; // Column stride
  64. int64_t PJ; // Column padding
  65. int64_t DJ; // Column dilation
  66. // Return index of input element for the given kernel and output index
  67. inline int64_t index(int64_t kj, int64_t oj) const {
  68. return oj * SJ + kj * DJ - PJ;
  69. }
  70. // Return index of first output within bounds for this kernel index
  71. inline int64_t valid_output_start(int64_t kj) const {
  72. int64_t ij = index(kj, 0);;
  73. return ij < 0 ? at::divup(-ij, SJ) : 0;
  74. }
  75. // Return index one past last output within bounds for this kernel index
  76. inline int64_t valid_output_end(int64_t kj) const {
  77. int64_t ij = index(kj, OW - 1);
  78. return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
  79. }
  80. };
  81. using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
  82. DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
  83. } // namespace at::native