PersistentSoftmax.cuh 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. #pragma once
  2. #include <cfloat>
  3. #include <limits>
  4. #include <stdint.h>
  5. #include <cuda_fp16.h>
  6. #include <c10/macros/Macros.h>
  7. #include <ATen/cuda/DeviceUtils.cuh>
  8. namespace {
  9. int log2_ceil(int value) {
  10. int log2_value = 0;
  11. while ((1 << log2_value) < value) ++log2_value;
  12. return log2_value;
  13. }
  14. template<typename T>
  15. struct Add {
  16. __device__ __forceinline__ T operator()(T a, T b) const {
  17. return a + b;
  18. }
  19. };
  20. template<typename T>
  21. struct Max {
  22. __device__ __forceinline__ T operator()(T a, T b) const {
  23. return a < b ? b : a;
  24. }
  25. };
  26. template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
  27. __device__ __forceinline__ void warp_reduce(acc_t* sum) {
  28. ReduceOp<acc_t> r;
  29. #pragma unroll
  30. for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
  31. #pragma unroll
  32. for (int i = 0; i < WARP_BATCH; ++i) {
  33. acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE);
  34. sum[i] = r(sum[i], b);
  35. }
  36. }
  37. }
// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
// A "WARP" contains "WARP_SIZE" threads (at most C10_WARP_SIZE); these threads are guaranteed to belong to the same warp.
// This is important because it means only __shfl_ instructions are required for reductions.
// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future archs.
// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs.
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
// The mask should have the same shape as the input, with each boolean indicating whether the corresponding value is masked out (true = masked).
// The exception is the transformer mask (is_transformer_mask): there, one mask row is shared by every head_chunk_size input elements.
// head_chunk_size is only used for the transformer mask softmax and equals H * D * D.
// For instance:
// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor.
// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
// input_t=float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.
  58. template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
  59. __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false)
  60. {
  61. // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
  62. constexpr int next_power_of_two = 1 << log2_elements;
  63. constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  64. constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  65. constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  66. int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  67. // batch_size might not be a multiple of WARP_BATCH. Check how
  68. // many batches have to computed within this WARP.
  69. int local_batches = batch_size - first_batch;
  70. if (local_batches > WARP_BATCH)
  71. local_batches = WARP_BATCH;
  72. // there might be multiple batches per warp. compute the index within the batch
  73. int local_idx = threadIdx.x;
  74. int idx_offset = first_batch * stride + local_idx;
  75. src += idx_offset;
  76. dst += idx_offset;
  77. if (is_transformer_mask) {
  78. mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
  79. } else {
  80. mask += idx_offset;
  81. }
  82. // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
  83. // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
  84. // the nested loops.
  85. // This should have no impact on performance because the loops are unrolled anyway.
  86. // load data from global memory
  87. acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  88. for (int i = 0; i < WARP_BATCH; ++i) {
  89. int batch_element_count = (i >= local_batches) ? 0 : element_count;
  90. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  91. int element_index = local_idx + it * WARP_SIZE;
  92. if (element_index < batch_element_count) {
  93. elements[i][it] = src[i*element_count+it*WARP_SIZE];
  94. } else {
  95. elements[i][it] = -std::numeric_limits<acc_t>::infinity();
  96. }
  97. }
  98. }
  99. // compute max_value
  100. acc_t max_value[WARP_BATCH];
  101. #pragma unroll
  102. for (int i = 0; i < WARP_BATCH; ++i) {
  103. int batch_element_count = (i >= local_batches) ? 0 : element_count;
  104. bool is_meaningful_max = false;
  105. max_value[i] = elements[i][0];
  106. #pragma unroll
  107. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  108. if (is_masked) {
  109. int idx = it*WARP_SIZE;
  110. if ((idx + local_idx) < batch_element_count) {
  111. if (!is_transformer_mask) {
  112. idx += i*element_count;
  113. }
  114. if (!mask[idx]) {
  115. max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
  116. is_meaningful_max = true;
  117. }
  118. }
  119. } else {
  120. max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
  121. }
  122. }
  123. if (is_masked) {
  124. if (!is_meaningful_max) {
  125. max_value[i] = -std::numeric_limits<acc_t>::infinity();
  126. }
  127. }
  128. }
  129. warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
  130. acc_t sum[WARP_BATCH] { 0.0f };
  131. #pragma unroll
  132. for (int i = 0; i < WARP_BATCH; ++i) {
  133. int batch_element_count = (i >= local_batches) ? 0 : element_count;
  134. #pragma unroll
  135. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  136. if (!is_masked) {
  137. if (is_log_softmax) {
  138. sum[i] += std::exp(elements[i][it] - max_value[i]);
  139. } else {
  140. elements[i][it] = std::exp(elements[i][it] - max_value[i]);
  141. sum[i] += elements[i][it];
  142. }
  143. } else {
  144. int idx = it*WARP_SIZE;
  145. bool valid = (idx + local_idx) < batch_element_count;
  146. if (!is_transformer_mask) {
  147. idx += i*element_count;
  148. }
  149. if (valid) {
  150. if (!mask[idx]) {
  151. if (is_log_softmax) {
  152. sum[i] += std::exp(elements[i][it] - max_value[i]);
  153. } else {
  154. elements[i][it] = std::exp(elements[i][it] - max_value[i]);
  155. sum[i] += elements[i][it];
  156. }
  157. } else {
  158. if (!is_log_softmax) {
  159. // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
  160. elements[i][it] = 0;
  161. }
  162. }
  163. } else {
  164. if (!is_log_softmax) {
  165. elements[i][it] = 0.;
  166. }
  167. }
  168. }
  169. }
  170. }
  171. warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
  172. // store result
  173. #pragma unroll
  174. for (int i = 0; i < WARP_BATCH; ++i) {
  175. if (i >= local_batches)
  176. break;
  177. if (is_log_softmax) sum[i] = std::log(sum[i]);
  178. #pragma unroll
  179. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  180. int element_index = local_idx + it * WARP_SIZE;
  181. if (element_index < element_count) {
  182. if (is_log_softmax) {
  183. dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
  184. } else if (sum[i] == 0) {
  185. dst[i*element_count+it*WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
  186. } else {
  187. dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i];
  188. }
  189. } else {
  190. break;
  191. }
  192. }
  193. }
  194. }
  195. template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
  196. __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr)
  197. {
  198. // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
  199. constexpr int next_power_of_two = 1 << log2_elements;
  200. constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  201. constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  202. constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  203. int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  204. // batch_size might not be a multiple of WARP_BATCH. Check how
  205. // many batches have to computed within this WARP.
  206. int local_batches = batch_size - first_batch;
  207. if (local_batches > WARP_BATCH)
  208. local_batches = WARP_BATCH;
  209. // there might be multiple batches per warp. compute the index within the batch
  210. int local_idx = threadIdx.x % WARP_SIZE;
  211. // the first element to process by the current thread
  212. int thread_offset = first_batch * stride + local_idx;
  213. grad += thread_offset;
  214. output += thread_offset;
  215. gradInput += thread_offset;
  216. if (is_masked) {
  217. mask += thread_offset;
  218. }
  219. // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
  220. // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
  221. // the nested loops.
  222. // This should have no impact on performance because the loops are unrolled anyway.
  223. // load data from global memory
  224. acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  225. acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  226. for (int i = 0; i < WARP_BATCH; ++i) {
  227. int batch_element_count = (i >= local_batches) ? 0 : element_count;
  228. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  229. int element_index = local_idx + it * WARP_SIZE;
  230. if (element_index < batch_element_count) {
  231. grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE];
  232. output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
  233. } else {
  234. grad_reg[i][it] = acc_t(0);
  235. output_reg[i][it] = acc_t(0);
  236. }
  237. }
  238. }
  239. acc_t sum[WARP_BATCH] { 0.0f };
  240. #pragma unroll
  241. for (int i = 0; i < WARP_BATCH; ++i) {
  242. #pragma unroll
  243. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  244. if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) {
  245. sum[i] += grad_reg[i][it];
  246. }
  247. }
  248. }
  249. warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
  250. // store result
  251. #pragma unroll
  252. for (int i = 0; i < WARP_BATCH; ++i) {
  253. if (i >= local_batches)
  254. break;
  255. #pragma unroll
  256. for (int it = 0; it < WARP_ITERATIONS; ++it) {
  257. int element_index = local_idx + it * WARP_SIZE;
  258. if (element_index < element_count) {
  259. if (is_masked && mask[i*element_count+it*WARP_SIZE]) {
  260. gradInput[i*element_count+it*WARP_SIZE] = 0;
  261. }
  262. // compute gradients
  263. else if (is_log_softmax) {
  264. gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
  265. } else {
  266. gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
  267. }
  268. }
  269. }
  270. }
  271. }
  272. } // end of anonymous namespace
  273. template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
  274. void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
  275. {
  276. TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
  277. if (softmax_elements == 0) {
  278. return;
  279. } else {
  280. int log2_elements = log2_ceil(softmax_elements);
  281. const int next_power_of_two = 1 << log2_elements;
  282. // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
  283. int warp_size = at::cuda::warp_size();
  284. warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;
  285. // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
  286. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  287. // use 128 threads per block to maximize gpu utilization
  288. constexpr int threads_per_block = 128;
  289. int warps_per_block = (threads_per_block / warp_size);
  290. int batches_per_block = warps_per_block * batches_per_warp;
  291. int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
  292. dim3 threads(warp_size, warps_per_block, 1);
  293. // Launch code would be more elegant if C++ supported FOR CONSTEXPR
  294. switch (log2_elements) {
  295. #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \
  296. softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
  297. <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, \
  298. src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \
  299. C10_CUDA_KERNEL_LAUNCH_CHECK(); \
  300. break;
  301. LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
  302. LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
  303. LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
  304. LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
  305. LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
  306. LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
  307. LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
  308. LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
  309. LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
  310. LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
  311. LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024
  312. default:
  313. break;
  314. }
  315. }
  316. }
  317. template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
  318. void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
  319. {
  320. TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
  321. if (softmax_elements == 0) {
  322. return;
  323. } else {
  324. int log2_elements = log2_ceil(softmax_elements);
  325. const int next_power_of_two = 1 << log2_elements;
  326. // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
  327. int warp_size = at::cuda::warp_size();
  328. warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;
  329. // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
  330. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  331. // use 128 threads per block to maximize gpu utilization
  332. constexpr int threads_per_block = 128;
  333. int warps_per_block = (threads_per_block / warp_size);
  334. int batches_per_block = warps_per_block * batches_per_warp;
  335. int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
  336. dim3 threads(warp_size, warps_per_block, 1);
  337. // Launch code would be more elegant if C++ supported FOR CONSTEXPR
  338. switch (log2_elements) {
  339. #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \
  340. softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
  341. <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> \
  342. (grad_input, grad, output, batch_count, softmax_elements_stride, \
  343. softmax_elements, mask); \
  344. C10_CUDA_KERNEL_LAUNCH_CHECK(); \
  345. break;
  346. LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1
  347. LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2
  348. LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4
  349. LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8
  350. LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16
  351. LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32
  352. LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64
  353. LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128
  354. LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256
  355. LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512
  356. LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024
  357. default:
  358. break;
  359. }
  360. }
  361. }