UpSample.h

#pragma once

#include <math.h>

#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cpu/utils.h>
/**
 * Note [compute_scales_value]
 * Note [area_pixel_compute_scale]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Interpolation with scale_factor can have different behaviors
 * depending on the value of recompute_scale_factor:
 *
 * - With recompute_scale_factor = True (current default behavior):
 *   the scale_factor, when provided by the user, is used to calculate
 *   the output size. The input size and the computed output_size
 *   are then used to infer new values for the scales which are
 *   used in the interpolation. Because floating-point math is not exact,
 *   this may be a different value from the user-supplied scales.
 *
 * - With recompute_scale_factor = False (which will be the default
 *   behavior starting 1.5.0):
 *   the behavior follows opencv logic, and the scales provided by
 *   the user are the ones used in the interpolation calculations.
 *
 * If the scales are not provided, or if they are provided but
 * recompute_scale_factor is set to True (default behavior), the scales
 * are computed from the input and the output size.
 *
 * When the scales are inferred from the input and output sizes,
 * we view each pixel as an area, with idx + 0.5 as its center index.
 * Here is an example formula in the 1D case:
 * if align_corners: the centers of the two corner pixel areas are preserved,
 *     (0.5, 0.5) -> (0.5, 0.5),
 *     (input_size - 0.5, 0.5) -> (output_size - 0.5)
 *     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
 *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
 * if not align_corners: the whole range is scaled accordingly,
 *     scale = input_size / output_size
 *     src_idx + 0.5 = scale * (dst_index + 0.5)
 */
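// A short worked example of the formulas in the note above (illustrative
// values only): resizing a 1D input of size 4 to size 8,
//   - align_corners = true uses scale = (4 - 1) / (8 - 1) = 3/7, so
//     src_index = scale * dst_index maps dst 0 -> 0 and dst 7 -> 3 exactly;
//   - align_corners = false uses scale = 4 / 8 = 0.5, so
//     src_idx = scale * (dst_index + 0.5) - 0.5 maps dst 0 -> -0.25
//     (later clamped for linear modes) and dst 7 -> 3.25.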
namespace at::native {

namespace upsample {

TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size,  // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    std::optional<c10::ArrayRef<double>> scale_factors);

inline std::optional<double> get_scale_value(
    std::optional<c10::ArrayRef<double>> scales,
    int idx) {
  if (!scales) {
    return c10::nullopt;
  }
  return scales->at(idx);
}

} // namespace upsample

using scale_t = std::optional<double>;
using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);

DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  int64_t output_width = output_size[0];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  TORCH_CHECK(
      input_width > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (W: ",
      input_width,
      ") and output (W: ",
      output_width,
      ")");

  return {nbatch, channels, output_width};
}
static C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_height = input_size[2];
  int64_t input_width = input_size[3];

  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_height, output_width};
}
static C10_UNUSED std::array<int64_t, 5> upsample_3d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 3,
      "It is expected output_size equals to 3, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 5,
      "It is expected input_size equals to 5, but got size ",
      input_size.size());

  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_depth = input_size[2];
  int64_t input_height = input_size[3];
  int64_t input_width = input_size[4];

  TORCH_CHECK(
      input_depth > 0 && input_height > 0 && input_width > 0 &&
          output_depth > 0 && output_height > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (D: ",
      input_depth,
      ", H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (D: ",
      output_depth,
      ", H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_depth, output_height, output_width};
}
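// Illustrative use of the checks above: for a 5D input of shape
// {N, C, D_in, H_in, W_in} and output_size = {D_out, H_out, W_out},
// upsample_3d_common_check returns {N, C, D_out, H_out, W_out}, i.e. the
// full shape of the output tensor; the 1D and 2D variants behave analogously.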
inline void upsample_2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    int64_t nbatch,
    int64_t nchannels,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  if (input.defined()) {
    // Allow for empty batch size but not other dimensions
    TORCH_CHECK(
        (input.numel() != 0 ||
         (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)) &&
            input.dim() == 4,
        "Non-empty 4D data tensor expected but got a tensor with sizes ",
        input.sizes());
  } else if (grad_output.defined()) {
    check_dim_size(grad_output, 4, 0, nbatch);
    check_dim_size(grad_output, 4, 1, nchannels);
    check_dim_size(grad_output, 4, 2, output_height);
    check_dim_size(grad_output, 4, 3, output_width);
  }
}
template <typename scalar_t>
inline scalar_t compute_scales_value(
    const std::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  // see Note [compute_scales_value]
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1
  // defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? static_cast<scalar_t>(1.0 / scale.value())
      : (static_cast<scalar_t>(input_size) / output_size);
}
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const std::optional<double> scale) {
  // see Note [area_pixel_compute_scale]
  if (align_corners) {
    if (output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}
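// Worked example for the two scale helpers above (illustrative values): with
// input_size = 4 and output_size = 8,
//   - compute_scales_value with a user-provided scale of 2.0 returns
//     1.0 / 2.0 = 0.5, the provided scale taking precedence;
//   - compute_scales_value without a provided scale falls back to 4 / 8 = 0.5;
//   - area_pixel_compute_scale with align_corners = true returns
//     (4 - 1) / (8 - 1) = 3.0 / 7.0, regardless of any provided scale.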
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
        static_cast<scalar_t>(0.5);
    // [Note] Follow OpenCV resize logic:
    // We allow negative src_idx here and later will use
    //   dx = src_idx - floorf(src_idx)
    // to compute the "distance" (which affects the weights).
    // For linear modes the weight distribution doesn't matter
    // for negative indices, as they use 2 pixels to interpolate.
    // For example, for [-1, 0] both use the value of pixel 0, so it
    // makes no difference whether we bound src_idx to 0 or not.
    // TODO: Our current linear mode impls should use unbound indices
    // as well, after which this cubic flag can be removed.
    // This matters in cubic mode, as we might need [-1, 0, 1, 2]
    // to interpolate and the weights can be affected.
    return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
                                                          : src_idx;
  }
}
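// A quick illustration of the clamping above (illustrative values): with
// scale = 0.5 and align_corners = false, dst_index = 0 gives
// src_idx = 0.5 * 0.5 - 0.5 = -0.25; linear modes (cubic = false) clamp this
// to 0, while cubic mode keeps -0.25 so that the 4-tap weights stay correct.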
inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // Index computation matching OpenCV INTER_NEAREST
  // which is buggy and kept for BC
  const int64_t src_index =
      std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
  return src_index;
}

inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // index_f32 = (output_index + 0.5) * scale - 0.5
  // input_index = round(index_f32)
  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
  const int64_t src_index = std::min(
      static_cast<int64_t>(floorf((dst_index + 0.5) * scale)), input_size - 1);
  return src_index;
}
inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    std::optional<double> scales) {
  // This method specifically handles the cases output_size == input_size and
  // output_size == 2 * input_size, which we would like to get rid of.
  // We keep this method for BC and consider it deprecated.
  // See nearest_exact_idx as a replacement.
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    return output_index;
  } else if (output_size == 2 * input_size) {
    // scale_factor = 2, shift input index
    return output_index >> 1;
  } else {
    float scale = compute_scales_value<float>(scales, input_size, output_size);
    return nearest_neighbor_compute_source_index(scale, output_index, input_size);
  }
}

inline int64_t nearest_exact_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    std::optional<double> scales) {
  float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
}
// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional<double>);
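// Example of how the two index functions above differ (illustrative values):
// when downscaling from input_size = 4 to output_size = 2 (so scale = 2.0),
// output_index = 1 gives
//   - nearest_idx:       floor(1 * 2.0)         = 2  (legacy, OpenCV-style)
//   - nearest_exact_idx: floor((1 + 0.5) * 2.0) = 3  (Pillow/Scipy-style)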
template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  return data[access_y * width + access_x];
}

template <typename scalar_t>
static void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  data[access_y * width + access_x] += value;
}
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  return ((A + 2) * x - (A + 3)) * x * x + 1;
}

template <typename scalar_t>
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}

template <typename scalar_t>
inline void get_cubic_upsample_coefficients(
    scalar_t coeffs[4],
    scalar_t t) {
  scalar_t A = -0.75;

  scalar_t x1 = t;
  coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
  coeffs[1] = cubic_convolution1<scalar_t>(x1, A);

  // opposite coefficients
  scalar_t x2 = 1.0 - t;
  coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
  coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}

template <typename scalar_t>
inline scalar_t cubic_interp1d(
    scalar_t x0,
    scalar_t x1,
    scalar_t x2,
    scalar_t x3,
    scalar_t t) {
  scalar_t coeffs[4];
  get_cubic_upsample_coefficients<scalar_t>(coeffs, t);

  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
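// Sanity check for the cubic helpers above (derived from the formulas with
// A = -0.75): at t = 0 the coefficients are {0, 1, 0, 0}, so
// cubic_interp1d(x0, x1, x2, x3, 0) == x1; at t = 1 they are {0, 0, 1, 0},
// so the result is x2. For any t in [0, 1] the four coefficients sum to 1.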
// When `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template <typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(
    const opmath_t& real_input_index,
    const int64_t& input_size,
    int64_t& input_index,
    scalar_t& lambda) {
  input_index =
      std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1));
}
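// For instance (illustrative): with float math, for very large sizes (roughly
// beyond 2^24) the computed source index can round up and land at or past
// input_size. Clamping with std::min keeps input_index a valid array index,
// and the lambda clamp keeps the interpolation weight inside [0, 1].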
template <typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    opmath_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
  } else {
    const auto real_input_index = area_pixel_compute_source_index<opmath_t>(
        ratio, output_index, align_corners, /*cubic=*/false);
    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
    input_index1 = input_index0 + offset;
    lambda0 = static_cast<scalar_t>(1.) - lambda1;
  }
}
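// Worked example (illustrative): upscaling 1D from input_size = 4 to
// output_size = 8 with align_corners = false gives ratio = 0.5. For
// output_index = 1 the real source index is 0.5 * 1.5 - 0.5 = 0.25, so
// input_index0 = 0, input_index1 = 1, lambda1 = 0.25 and lambda0 = 0.75,
// i.e. the output sample blends 75% of pixel 0 with 25% of pixel 1.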
// It will not be used by data types other than BFloat16 and Half.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> ||
              !std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
      "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
  TORCH_CHECK((std::is_same<scalar_in, float>::value),
      "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
  return;
}

template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> &&
              std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  using bVec = Vectorized<scalar_out>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gin_bvec = bVec::loadu(gin + d);
    fVec gin_fvec0, gin_fvec1;
    std::tie(gin_fvec0, gin_fvec1) = convert_to_float<scalar_out>(gin_bvec);
    gin_fvec0 += fVec::loadu(buffer_ptr + d);
    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
    fVec(0).store(buffer_ptr + d);
    fVec(0).store(buffer_ptr + d + fVec::size());
    convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += buffer_ptr[d];
    buffer_ptr[d] = 0;
  }
}
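// A minimal usage sketch (hypothetical caller, not a kernel defined in this
// header): a backward kernel producing BFloat16 or Half gradients might keep
// a per-row float accumulation buffer `acc_buffer`, accumulate into it, and
// then call
//   apply_grad_input(acc_buffer, grad_input_ptr, row_size);
// which adds the float accumulator into the reduced-precision grad_input and
// zeroes the buffer in one pass.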
} // namespace at::native