#include "cuda_kernel.h"
//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////
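// index_max_cuda_kernel
// For each batch element, reduces index_vals ([32, num_block]) to a per-(A block, row) maximum,
// writing the result densely to max_vals and scattering it back to block layout in max_vals_scatter.
// Values are packed into ints with a fixed-point scale of 1000 so atomicMax can be used on shared memory.
// Assumed launch (the host launcher is not shown in this file): one thread block per batch element
// (gridDim.x == batch_size) with dynamic shared memory of at least A_num_block * 32 * sizeof(int).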
__global__ void index_max_cuda_kernel(
    float *index_vals,       // [batch_size, 32, num_block]
    int *indices,            // [batch_size, num_block]
    float *max_vals,         // [batch_size, A_num_block * 32]
    float *max_vals_scatter, // [batch_size, 32, num_block]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
) {
    long batch_idx = blockIdx.x;
    long thread_idx = threadIdx.x;
    long num_thread = blockDim.x;

    extern __shared__ float buffer[];
    int *max_buffer = (int*)buffer;

    for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
        int idx = i + thread_idx;
        if (idx < A_num_block * 32) {
            max_buffer[idx] = -1e8;
        }
    }
    __syncthreads();

    int *indices_pt = &indices[batch_idx * num_block];
    float *index_vals_pt = &index_vals[batch_idx * num_block * 32];

    for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
        int idx = idx_start + thread_idx;
        int A_block_idx = indices_pt[idx % num_block] / B_num_block;
        atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000));
    }
    __syncthreads();

    float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32];
    for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
        int idx = i + thread_idx;
        if (idx < A_num_block * 32) {
            max_vals_pt[idx] = (float)max_buffer[idx] / 1000.;
        }
    }

    float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32];
    for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
        int idx = idx_start + thread_idx;
        int A_block_idx = indices_pt[idx % num_block] / B_num_block;
        max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.;
    }
}
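
// mm_to_sparse_cuda_kernel
// For every (A block, B block) pair listed in `indices`, computes the 32x32 block of inner products
// between the columns of a [dim, 32] tile of dense_A and a [dim, 32] tile of dense_B, writing the
// result to sparse_C. Uses double-buffered shared-memory tiles and a 4x4 register accumulator per thread.
// Assumed launch (the host launcher is not shown in this file): blockDim.x == 64, blockDim.y == 4,
// gridDim.y == batch_size, gridDim.x * blockDim.y == num_block, with `dim` a multiple of 8.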
__global__ void mm_to_sparse_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, dim, 32]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    int *indices,    // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long dim,
    long num_block
) {
    long batch_idx = blockIdx.y;
    long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
    long thread_idx = threadIdx.x;

    __shared__ float buffer[4096];
    float *A_buffer = &buffer[threadIdx.y * 1024];       // [2, 8, 32]
    float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32]

    long batch_idx__block_idx = batch_idx * num_block + block_idx;
    long AB_block_idx = indices[batch_idx__block_idx];
    float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32];
    float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32];

    int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777]
    int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567]
    float reg_1[8];
    float reg_2[8];
    float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx];
        B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx];
    }
    __syncthreads();

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        reg_1[i] = A_buffer[reg_1_idx * 4 + i];
        reg_2[i] = B_buffer[reg_2_idx * 4 + i];
    }

    for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) {
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx];
            B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx];
        }
        #pragma unroll
        for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
            #pragma unroll
            for (int i = 0; i < 4; i++) {
                reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
                reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
            }
            #pragma unroll
            for (int i = 0; i < 4; i++) {
                #pragma unroll
                for (int j = 0; j < 4; j++) {
                    reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
                }
            }
        }
        __syncthreads();
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i];
            reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i];
        }
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
            }
        }
    }

    #pragma unroll
    for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
            reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
        }
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
            }
        }
    }
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        #pragma unroll
        for (int j = 0; j < 4; j++) {
            reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
        }
    }
    __syncthreads();

    float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32]
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        #pragma unroll
        for (int j = 0; j < 4; j++) {
            C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j];
        }
    }
    __syncthreads();

    float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024];
    #pragma unroll
    for (int i = 0; i < 16; i++) {
        sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx];
    }
}
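
// sparse_dense_mm_cuda_kernel
// Multiplies each 32x32 block of sparse_A with the tile of dense_B selected by `indices`, accumulating
// the [dim, 32] result into dense_C with atomicAdd (several sparse blocks may map to the same A block).
// `dim` is processed in chunks of 64 rows staged through shared memory.
// Assumed launch (the host launcher is not shown in this file): blockDim.x == 128, blockDim.y == 2,
// gridDim.y == batch_size, gridDim.x * blockDim.y == num_block, with `dim` a multiple of 64.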
__global__ void sparse_dense_mm_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int *indices,    // [batch_size, num_block]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    float *dense_C,  // [batch_size, A_num_block, dim, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long dim,
    long num_block
) {
    long batch_idx = blockIdx.y;
    long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
    long thread_idx = threadIdx.x;

    __shared__ float buffer[6144];
    float *A_buffer = &buffer[threadIdx.y * 3072];        // [32, 32]
    float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64]

    long batch_idx__block_idx = batch_idx * num_block + block_idx;
    float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx];
    }

    long AB_block_idx = indices[batch_idx__block_idx];
    float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim];
    float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim];

    // [0000000011111111222222223333333344444444555555556666666677777777]
    // [0123456701234567012345670123456701234567012345670123456701234567]
    int reg_1_idx = thread_idx / 8;
    int reg_2_idx = thread_idx % 8;
    float reg_1[8];
    float reg_2[8];
    float reg_array[16];

    for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) {
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx];
        }
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            reg_array[i] = 0;
        }
        __syncthreads();

        #pragma unroll
        for (int i = 0; i < 4; i++) {
            reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32];
            reg_2[i] = A_buffer[reg_2_idx * 4 + i];
        }
        #pragma unroll
        for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) {
            #pragma unroll
            for (int i = 0; i < 4; i++) {
                reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx];
                reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i];
            }
            #pragma unroll
            for (int i = 0; i < 4; i++) {
                #pragma unroll
                for (int j = 0; j < 4; j++) {
                    reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
                }
            }
        }
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
            }
        }
        __syncthreads();

        float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32]
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            #pragma unroll
            for (int j = 0; j < 4; j++) {
                C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j];
            }
        }
        __syncthreads();

        #pragma unroll
        for (int i = 0; i < 16; i++) {
            atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]);
        }
        __syncthreads();
    }
}
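
// reduce_sum_cuda_kernel
// Sums each 32x32 block of sparse_A over its 32 rows (one column per thread) and accumulates the
// resulting 32-element vector into the dense_C row selected by the block's A index (atomicAdd, since
// several blocks can map to the same A block). Loads are pipelined through a small register ring buffer.
// Assumed launch (the host launcher is not shown in this file): blockDim.x == 32,
// gridDim.y == batch_size, gridDim.x * blockDim.y == num_block.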
__global__ void reduce_sum_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int *indices,    // [batch_size, num_block]
    float *dense_C,  // [batch_size, A_num_block, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
) {
    long batch_idx = blockIdx.y;
    long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
    long thread_idx = threadIdx.x;

    long batch_idx__block_idx = batch_idx * num_block + block_idx;
    long AB_block_idx = indices[batch_idx__block_idx];
    float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];

    float reg_array[16];
    float value = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++) {
        reg_array[i] = sparse_A_pt[i * 32 + thread_idx];
    }
    #pragma unroll
    for (int stride = 8; stride < 32; stride = stride + 8) {
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx];
        }
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            value = value + reg_array[(stride - 8 + i) % 16];
        }
    }
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        value = value + reg_array[8 + i];
    }

    float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
    atomicAdd(&dense_C_pt[thread_idx], value);
}
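
// scatter_cuda_kernel
// Broadcasts the 32-element row of dense_A selected by each block's A index to all 32 rows of the
// corresponding 32x32 block of sparse_C.
// Assumed launch (the host launcher is not shown in this file): blockDim.x == 32,
// gridDim.y == batch_size, gridDim.x * blockDim.y == num_block.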
__global__ void scatter_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, 32]
    int *indices,    // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
) {
    long batch_idx = blockIdx.y;
    long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
    long thread_idx = threadIdx.x;

    long batch_idx__block_idx = batch_idx * num_block + block_idx;
    long AB_block_idx = indices[batch_idx__block_idx];
    float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
    float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024];

    float value = dense_A_pt[thread_idx];
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        sparse_C_pt[i * 32 + thread_idx] = value;
    }
}