// wkv_op.cpp — C++ binding layer that exposes the WKV CUDA kernels to PyTorch.
#include <torch/extension.h>
#include "ATen/ATen.h"

// Shorthand for ATen's bfloat16 scalar type, used by the *_bf16 variants below.
typedef at::BFloat16 bf16;

// Forward declarations of the kernel launchers implemented in the companion
// .cu file. All take flat device pointers plus the (B, T, C) problem size;
// B/T/C are taken from k.size(0..2) by the wrappers below.
// fp32 forward: y <- wkv(w, u, k, v).
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
// bf16 forward: w stays fp32, u/k/v/y are bf16.
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
// fp32 forward that also carries an extra fp32 buffer s (recurrent state, presumably — confirm against the kernel).
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
// bf16 forward-with-state: s remains fp32.
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
// fp32 backward: consumes y and the incoming gradient gy, fills gw/gu/gk/gv.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
// bf16 backward: w stays fp32, every other tensor (including all gradients) is bf16.
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
  10. void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
  11. const int B = k.size(0);
  12. const int T = k.size(1);
  13. const int C = k.size(2);
  14. cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
  15. }
  16. void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
  17. const int B = k.size(0);
  18. const int T = k.size(1);
  19. const int C = k.size(2);
  20. cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
  21. }
  22. void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
  23. const int B = k.size(0);
  24. const int T = k.size(1);
  25. const int C = k.size(2);
  26. cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
  27. }
  28. void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
  29. const int B = k.size(0);
  30. const int T = k.size(1);
  31. const int C = k.size(2);
  32. cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
  33. }
  34. void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
  35. const int B = k.size(0);
  36. const int T = k.size(1);
  37. const int C = k.size(2);
  38. cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
  39. }
  40. void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
  41. const int B = k.size(0);
  42. const int T = k.size(1);
  43. const int C = k.size(2);
  44. cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
  45. gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
  46. }
// pybind11 bindings: exposes the wrapper functions above as attributes of the
// compiled extension module (module name supplied by TORCH_EXTENSION_NAME at
// build time).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
m.def("forward_with_state", &forward_with_state, "wkv forward with state");
m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
m.def("backward", &backward, "wkv backward");
m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}
// TorchScript operator registration: makes the same functions available as
// torch.ops.wkv.* (schemas are inferred from the C++ signatures).
TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("forward_bf16", forward_bf16);
m.def("forward_with_state", forward_with_state);
m.def("forward_with_state_bf16", forward_with_state_bf16);
m.def("backward", backward);
m.def("backward_bf16", backward_bf16);
}