// torch_extension.cpp — pybind11 bindings that expose the kernel launch wrappers to Python.

  #include <vector>

  #include <torch/extension.h>
  #include <ATen/ATen.h>

  #include "cuda_launch.h"
  5. std::vector<at::Tensor> index_max(
  6. at::Tensor index_vals,
  7. at::Tensor indices,
  8. int A_num_block,
  9. int B_num_block
  10. ) {
  11. return index_max_kernel(
  12. index_vals,
  13. indices,
  14. A_num_block,
  15. B_num_block
  16. );
  17. }
  18. at::Tensor mm_to_sparse(
  19. at::Tensor dense_A,
  20. at::Tensor dense_B,
  21. at::Tensor indices
  22. ) {
  23. return mm_to_sparse_kernel(
  24. dense_A,
  25. dense_B,
  26. indices
  27. );
  28. }
  29. at::Tensor sparse_dense_mm(
  30. at::Tensor sparse_A,
  31. at::Tensor indices,
  32. at::Tensor dense_B,
  33. int A_num_block
  34. ) {
  35. return sparse_dense_mm_kernel(
  36. sparse_A,
  37. indices,
  38. dense_B,
  39. A_num_block
  40. );
  41. }
  42. at::Tensor reduce_sum(
  43. at::Tensor sparse_A,
  44. at::Tensor indices,
  45. int A_num_block,
  46. int B_num_block
  47. ) {
  48. return reduce_sum_kernel(
  49. sparse_A,
  50. indices,
  51. A_num_block,
  52. B_num_block
  53. );
  54. }
  55. at::Tensor scatter(
  56. at::Tensor dense_A,
  57. at::Tensor indices,
  58. int B_num_block
  59. ) {
  60. return scatter_kernel(
  61. dense_A,
  62. indices,
  63. B_num_block
  64. );
  65. }
  66. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  67. m.def("index_max", &index_max, "index_max (CUDA)");
  68. m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
  69. m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
  70. m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
  71. m.def("scatter", &scatter, "scatter (CUDA)");
  72. }