cuda_kernel.h

#define WARP_SIZE 32
#define FULL_MASK 0xffffffff
#define OPTIMAL_THREADS 256
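
// Illustrative sketch, not part of the original header: WARP_SIZE and FULL_MASK
// are the usual ingredients of a warp-level shuffle reduction. The helper below
// (hypothetical name, assumed usage) shows the typical pattern for a per-warp max.
__device__ inline float warp_reduce_max(float val) {
    // Butterfly reduction across the 32 lanes of a warp.
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_xor_sync(FULL_MASK, val, offset));
    }
    return val;
}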

__global__ void index_max_cuda_kernel(
    float *index_vals,       // [batch_size, 32, num_block]
    int   *indices,          // [batch_size, num_block]
    float *max_vals,         // [batch_size, A_num_block * 32]
    float *max_vals_scatter, // [batch_size, 32, num_block]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
);

__global__ void mm_to_sparse_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, dim, 32]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    int   *indices,  // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long dim,
    long num_block
);
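
// Illustrative sketch (assumption, not in the original header): if sparse_C is
// stored contiguously in row-major order with the documented shape
// [batch_size, num_block, 32, 32], the flat offset of one 32x32 block would be
// computed as below. The helper name is hypothetical.
__host__ __device__ inline long sparse_block_offset(long batch_idx, long block_idx, long num_block) {
    return (batch_idx * num_block + block_idx) * 32 * 32;
}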

__global__ void sparse_dense_mm_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int   *indices,  // [batch_size, num_block]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    float *dense_C,  // [batch_size, A_num_block, dim, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long dim,
    long num_block
);

__global__ void reduce_sum_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int   *indices,  // [batch_size, num_block]
    float *dense_C,  // [batch_size, A_num_block, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
);

__global__ void scatter_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, 32]
    int   *indices,  // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size,
    long A_num_block,
    long B_num_block,
    long num_block
);
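
// Illustrative host-side launch sketch (assumption, not part of this header):
// the real grid/block configuration lives in the corresponding .cu / binding
// code; OPTIMAL_THREADS threads per block and one block per batch element are
// only a plausible choice, and the wrapper name is hypothetical.
inline void launch_index_max(
    float *index_vals, int *indices, float *max_vals, float *max_vals_scatter,
    long batch_size, long A_num_block, long B_num_block, long num_block
) {
    dim3 threads(OPTIMAL_THREADS);
    dim3 blocks(static_cast<unsigned int>(batch_size));
    index_max_cuda_kernel<<<blocks, threads>>>(
        index_vals, indices, max_vals, max_vals_scatter,
        batch_size, A_num_block, B_num_block, num_block);
}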