#pragma once
#include <ATen/NumericUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Load.h>
#include <limits>
#include <cmath>

namespace at {
namespace native {
template <typename integer>
constexpr inline integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}
template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
  integer log_num_threads_x = 0;
  integer log_num_threads_y = 0;
  while (((integer)1 << log_num_threads_x) < row_size) {
    ++log_num_threads_x;
  }
  while (((integer)1 << log_num_threads_y) < num_rows) {
    ++log_num_threads_y;
  }
  // We want to keep the ratio between the x-threads and y-threads about the same as
  // the ratio between row_size and num_rows, but the total number of threads in
  // a block should be about 512.
  integer diff = log_num_threads_x - log_num_threads_y;
  // 9 is from log2(512)
  log_num_threads_x = ((integer)9 + diff) / (integer)2;
  // Empirically, a larger log_num_threads_x can give a significant speed-up in some cases
  // but is detrimental in others, so keep the lower bound at log2(16) == 4 to stay
  // similar to the previous implementation.
  // Keep the upper bound at log2(512) == 9, the maximum number of threads in a block.
  log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
  return log_num_threads_x;
}
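// Worked example: get_log_num_threads_x_inner_scan(/*num_rows=*/8, /*row_size=*/1024)
// finds log_num_threads_x == 10 and log_num_threads_y == 3, so diff == 7 and
// (9 + 7) / 2 == 8, already within [4, 9]. The innermost-dim launchers below would then
// use a 256 x 2 thread block (2^8 x-threads, 512 / 256 y-threads).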
template<typename scalar_t, typename idx_t, typename BinaryOperation>
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
  if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
    rhs = lhs;
    rhs_idx = lhs_idx;
  }
}
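// Note on binary_op_update: NaNs propagate through the scan. If the incoming lhs is NaN it
// overwrites rhs (and rhs_idx); if rhs is already NaN it is kept as-is. Otherwise lhs replaces
// rhs only when binary_op(rhs, lhs) is false, i.e. when the running value no longer "wins".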
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
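// Illustrative example: a contiguous tensor of shape (3, 4, 1000) scanned along its last
// dimension is treated as num_rows == 12 rows of row_size == 1000 elements each.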
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                              int num_rows, int row_size,
                                                              const uint32_t num_threads, const uint32_t log_num_threads_x,
                                                              scalar_t init, BinaryFunction binary_op) {
  // dynamic memory allocation for vbuf and ibuf
  alignas(sizeof(double)) extern __shared__ char buf[];
  scalar_t* vbuf = reinterpret_cast<scalar_t*>(buf); // the size is num_threads * 2
  int64_t* ibuf = reinterpret_cast<int64_t*>(vbuf + num_threads * 2);
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y;
  int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y;

  for (int block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    const scalar_t *row_self = self_ + row * row_size;
    scalar_t *row_values = values_ + row * row_size;
    int64_t *row_indices = indices_ + row * row_size;
    scalar_t block_total = init;
    int64_t block_idx_final = 0;
    const bool row_exists = row < num_rows;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      int col1 = block_col + threadIdx.x;
      int col2 = block_col + num_threads_x + threadIdx.x;
      if (row_exists) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = c10::load(&row_self[col1]);
          row_idx_buf[threadIdx.x] = col1;
        } else {
          row_buf[threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
          row_idx_buf[num_threads_x + threadIdx.x] = col2;
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
        }
      }
      __syncthreads();
      // Parallel reduction with the Sklansky method. A diagram can be found in this paper:
      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
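      // Illustrative index pattern for num_threads_x == 4 (8 elements per row block):
      //   s == 1: threads 0..3 combine buf[0]->buf[1], buf[2]->buf[3], buf[4]->buf[5], buf[6]->buf[7]
      //   s == 2: buf[1] is combined into buf[2] and buf[3]; buf[5] into buf[6] and buf[7]
      //   s == 4: buf[3] is combined into buf[4..7]
      // After log2(2 * num_threads_x) rounds every element holds its inclusive prefix.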
      for (uint32_t s = 1; s <= num_threads_x; s <<= 1) {
        if (row_exists) {
          uint32_t a = (threadIdx.x / s) * (2 * s) + s;
          uint32_t ti = a + (threadIdx.x % s);
          uint32_t si = a - 1;
          binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row_exists) {
        if (col1 < row_size) {
          row_values[col1] = row_buf[threadIdx.x];
          row_indices[col1] = row_idx_buf[threadIdx.x];
        }
        if (col2 < row_size) {
          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
        }
      }
      block_total = row_buf[2 * num_threads_x - 1];
      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
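// Memory layout used below: element (orow, col, irow) of the flattened (num_orows, row_size,
// num_irows) view lives at offset orow * row_size * num_irows + col * num_irows + irow, so
// consecutive threads (consecutive irow) touch consecutive addresses.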
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                  const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *self = self_ + orow * row_size * num_irows + irow;
      scalar_t *values = values_ + orow * row_size * num_irows + irow;
      int64_t *indices = indices_ + orow * row_size * num_irows + irow;
      scalar_t out = init;
      int64_t out_idx = 0;

      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
        const auto val = c10::load(self);
        if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
          out = val;
          out_idx = col;
        }
        *values = out;
        *indices = out_idx;
        self += num_irows;
        values += num_irows;
        indices += num_irows;
      }
    }
  }
}
inline void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  TORCH_CHECK(
      val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value");
}
template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int dim, scalar_t init, BinaryFunction binary_op) {
  int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
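  // Illustrative example: for a contiguous tensor of shape (4, 5, 6) with dim == 1,
  // num_orows == 4, row_size == 5 and num_irows == 6.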
  // For performance reasons, the CUDA kernels use uint32_t for the loops over irows, orows and row,
  // so make sure the input does not exceed what uint32_t can represent.
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int row_size = self.size(ndim - 1);
  int num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));

  const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t));
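  // mem_size is split by the kernel into 2 * num_threads scalar_t values (vbuf)
  // followed by 2 * num_threads int64_t indices (ibuf).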
  tensor_kernel_scan_innermost_dim_with_indices<scalar_t><<<grid, threads, mem_size,
                                                            at::cuda::getCurrentCUDAStream()>>>(
    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
    num_rows, row_size, num_threads, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices,
                           int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
  if (dim == ndim - 1) {
    scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
  }
}
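// Usage sketch (illustrative assumption, not defined in this header): a cummax-style
// kernel could dispatch along these lines:
//   scan_dim_with_indices<scalar_t>(self, values, indices, dim,
//       std::numeric_limits<scalar_t>::lowest(), std::greater_equal<scalar_t>());
// i.e. init is an identity element for the comparator and binary_op decides which of
// two values "wins" the running result.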
// TODO: The implementations of `tensor_kernel_scan_outer_dim` and
// `tensor_kernel_scan_innermost_dim` are similar to
// `tensor_kernel_scan_outer_dim_with_indices` and
// `tensor_kernel_scan_innermost_dim_with_indices` and should be refactored to
// remove the duplication.
/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
                                             const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                             const scalar_t init, BinaryOp binary_op)
{
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
        acc = binary_op(acc, c10::load(src));
        *tgt = acc;

        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
template<typename T, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_,
                                                      const uint32_t num_rows, const uint32_t row_size,
                                                      const uint32_t log_num_threads_x,
                                                      T init, BinaryFunction binary_op) {
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  for (uint32_t block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    uint32_t row = block_row + threadIdx.y;
    T block_total = init;

    const T *row_src = src_ + row * row_size;
    T *row_tgt = tgt_ + row * row_size;
    const bool row_exists = row < num_rows;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      uint32_t col1 = block_col + threadIdx.x;
      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
      if (row_exists) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      __syncthreads();
      // Parallel reduction with the Sklansky method. A diagram can be found in this paper:
      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
      for (uint32_t m = 0; m <= log_num_threads_x; ++m) {
        if (row_exists) {
          uint32_t s = 1 << m; // s = 2 ^ m
          uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s
          uint32_t ti = a + (threadIdx.x % s);
          uint32_t si = a - 1;
          row_buf[ti] = binary_op(row_buf[ti], row_buf[si]);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row_exists) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      block_total = row_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
template <
    typename T,
    class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim(
    T* tgt_,
    const T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    const uint32_t log_num_threads_x,
    T init,
    BinaryFunction binary_op) {
  alignas(sizeof(double)) extern __shared__ char sbuf[];
  T* sbuf2 = reinterpret_cast<T*>(sbuf);
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  T* row_buf = reinterpret_cast<T*>(sbuf2 + num_threads_x * 2 * threadIdx.y);

  tensor_kernel_scan_innermost_dim_impl<T>(
      row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op);
}
template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
                                               at::cuda::getCurrentCUDAStream()>>>(
    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
    num_rows, row_size, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
              int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());
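  // Fast path: when self.numel() == self.size(dim) the tensor is effectively one-dimensional
  // along `dim`, so the whole scan can be handed to CUB's single-sequence inclusive scan.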
  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}
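// Usage sketch (illustrative assumption, not defined in this header): a cumsum-style
// kernel could call
//   scan_dim<scalar_t>(self, result, dim, scalar_t(0), std::plus<scalar_t>());
// with `result` preallocated with the same shape and a contiguous layout.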
}}  // namespace at::native