// DilatedConvolutionUtils.h

#pragma once

#include <algorithm>
#include <vector>

#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
  TORCH_CHECK(                                       \
      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
      "Need " #T " of dimension ",                   \
      DIM,                                           \
      " and " #T ".size[",                           \
      DIM_SIZE,                                      \
      "] == ",                                       \
      SIZE,                                          \
      " but got input to be of shape ",              \
      T.sizes())
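
// for example, the call TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0))
// used below checks that bias is 1-D and that bias.size(0) == weight.size(0),
// failing with a message that names the tensor and its actual shape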

namespace at::native::internal {

namespace {

// returns true if every element of arr is strictly positive
inline bool all_positive(IntArrayRef& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}

// returns true if every element of arr is non-negative
inline bool all_nonnegative(std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}

} // namespace

// calculate the rear (spatial) part of the output tensor sizes;
// div_rtn divides rounding toward negative infinity
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> sizes;
  for (const auto index : c10::irange(dim)) {
    sizes.push_back(
        div_rtn<int64_t>(
            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
                (dilation_size[index] * (kernel_size[index] - 1) + 1),
            stride_size[index]) +
        1);
  }
  return sizes;
}
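
// worked example (illustrative only, not part of the original header):
// for dim = 2 with spatial input 10x12, kernel 3x3, stride 1, pad 0 and
// dilation 2, the effective kernel extent is 2 * (3 - 1) + 1 = 5, so the
// rear sizes come out as div_rtn(10 - 5, 1) + 1 = 6 and
// div_rtn(12 - 5, 1) + 1 = 8, i.e. {6, 8}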

// calculate the full output tensor sizes: prepend the number of output
// channels (weight.size(0)) and, for batched input, the batch size
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  output_size.insert(output_size.begin(), weight.size(0));
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
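
// continuing the example above: a batched input of shape (N, C_in, 10, 12)
// with weight of shape (C_out, C_in, 3, 3) gives output_size =
// {N, C_out, 6, 8}; without the batch dimension the result is {C_out, 6, 8}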

/*
  slow_conv_dilated_shape_check - check user input to the dilated
  convolution forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the tensors bias, grad_weight, and grad_output are defined,
    they are assumed to be contiguous without checking, because these
    tensors are made contiguous either by a .contiguous() call or by
    resizing zero-sized tensors in the forward/backward functions.

    When grad_weight is defined, it is assumed (without checking) to
    have the same shape as weight; see the backward functions.
  */
  // check size arguments: each must have exactly dim entries
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());
  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // check input: it must be a (dim + 1)-D (unbatched) or (dim + 2)-D
  // (batched) tensor
  TORCH_CHECK(input.defined(), "input must be defined");
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);
  int64_t ndim = n + dim;
  if (!is_batch) {
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be ",
        dim + 1,
        "D or ",
        dim + 2,
        "D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // check output sizes
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // check weight: (dim + 2)-D, with rear dimensions equal to kernel_size
  // and input channels matching weight.size(1)
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor");
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);
  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // check bias when present
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // check grad_output when present: its shape must match the batch size,
  // the number of output channels, and the calculated output size
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape ",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}
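
// hypothetical usage sketch (not part of the original header): a forward
// implementation would typically validate its arguments before computing,
// e.g.
//
//   slow_conv_dilated_shape_check<2>(
//       input, weight, bias, /*grad_output=*/Tensor(),
//       kernel_size, stride_size, pad_size, dilation_size);
//   auto output_size = get_output_size<2>(
//       input, weight, kernel_size, stride_size, pad_size, dilation_size);
//
// where passing an undefined Tensor() as grad_output skips the gradient
// checks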

} // namespace at::native::internal