// CPUBlas.h -- declarations of the gemm / axpy / copy entry points in at::native::cpublas.
  1. #pragma once
  2. #include <ATen/OpMathType.h>
  3. #include <ATen/native/DispatchStub.h>
  4. #include <ATen/native/TransposeType.h>
  5. #include <c10/util/complex.h>
  6. #include <c10/core/ScalarType.h>
  7. #include <c10/core/Scalar.h>
  8. namespace at::native::cpublas {
  9. namespace internal {
  10. void normalize_last_dims(
  11. TransposeType transa, TransposeType transb,
  12. int64_t m, int64_t n, int64_t k,
  13. int64_t *lda, int64_t *ldb, int64_t *ldc);
  14. } // namespace internal
  15. using gemm_fn = void(*)(
  16. at::ScalarType type,
  17. TransposeType transa, TransposeType transb,
  18. int64_t m, int64_t n, int64_t k,
  19. const Scalar& alpha,
  20. const void *a, int64_t lda,
  21. const void *b, int64_t ldb,
  22. const Scalar& beta,
  23. void *c, int64_t ldc);
  24. DECLARE_DISPATCH(gemm_fn, gemm_stub);
  25. template <typename scalar_t>
  26. void gemm(
  27. TransposeType transa, TransposeType transb,
  28. int64_t m, int64_t n, int64_t k,
  29. at::opmath_type<scalar_t> alpha,
  30. const scalar_t *a, int64_t lda,
  31. const scalar_t *b, int64_t ldb,
  32. at::opmath_type<scalar_t> beta,
  33. scalar_t *c, int64_t ldc) {
  34. internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  35. gemm_stub(
  36. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  37. transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  38. }
  39. void gemm(
  40. TransposeType transa, TransposeType transb,
  41. int64_t m, int64_t n, int64_t k,
  42. double alpha,
  43. const double *a, int64_t lda,
  44. const double *b, int64_t ldb,
  45. double beta,
  46. double *c, int64_t ldc);
  47. void gemm(
  48. TransposeType transa, TransposeType transb,
  49. int64_t m, int64_t n, int64_t k,
  50. float alpha,
  51. const float *a, int64_t lda,
  52. const float *b, int64_t ldb,
  53. float beta,
  54. float *c, int64_t ldc);
  55. void gemm(
  56. TransposeType transa, TransposeType transb,
  57. int64_t m, int64_t n, int64_t k,
  58. float alpha,
  59. const at::BFloat16 *a, int64_t lda,
  60. const at::BFloat16 *b, int64_t ldb,
  61. float beta,
  62. at::BFloat16 *c, int64_t ldc);
  63. void gemm(
  64. TransposeType transa, TransposeType transb,
  65. int64_t m, int64_t n, int64_t k,
  66. const float alpha,
  67. const at::BFloat16 *a, int64_t lda,
  68. const at::BFloat16 *b, int64_t ldb,
  69. const float beta,
  70. float *c, int64_t ldc);
  71. void gemm(
  72. TransposeType transa, TransposeType transb,
  73. int64_t m, int64_t n, int64_t k,
  74. float alpha,
  75. const at::Half *a, int64_t lda,
  76. const at::Half *b, int64_t ldb,
  77. float beta,
  78. at::Half *c, int64_t ldc);
  79. void gemm(
  80. TransposeType transa, TransposeType transb,
  81. int64_t m, int64_t n, int64_t k,
  82. const float alpha,
  83. const at::Half *a, int64_t lda,
  84. const at::Half *b, int64_t ldb,
  85. const float beta,
  86. float *c, int64_t ldc);
  87. void gemm(
  88. TransposeType transa, TransposeType transb,
  89. int64_t m, int64_t n, int64_t k,
  90. c10::complex<double> alpha,
  91. const c10::complex<double> *a, int64_t lda,
  92. const c10::complex<double> *b, int64_t ldb,
  93. c10::complex<double> beta,
  94. c10::complex<double> *c, int64_t ldc);
  95. void gemm(
  96. TransposeType transa, TransposeType transb,
  97. int64_t m, int64_t n, int64_t k,
  98. c10::complex<float> alpha,
  99. const c10::complex<float> *a, int64_t lda,
  100. const c10::complex<float> *b, int64_t ldb,
  101. c10::complex<float> beta,
  102. c10::complex<float> *c, int64_t ldc);
  103. void gemm(
  104. TransposeType transa, TransposeType transb,
  105. int64_t m, int64_t n, int64_t k,
  106. int64_t alpha,
  107. const int64_t *a, int64_t lda,
  108. const int64_t *b, int64_t ldb,
  109. int64_t beta,
  110. int64_t *c, int64_t ldc);
  111. template <typename scalar_t>
  112. void gemm_batched(
  113. TransposeType transa, TransposeType transb,
  114. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  115. scalar_t alpha,
  116. const scalar_t * const *a, int64_t lda,
  117. const scalar_t * const *b, int64_t ldb,
  118. const scalar_t beta,
  119. scalar_t * const *c, int64_t ldc);
  120. template <typename scalar_t>
  121. void gemm_batched_with_stride(
  122. TransposeType transa, TransposeType transb,
  123. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  124. scalar_t alpha,
  125. const scalar_t *a, int64_t lda, int64_t batch_stride_a,
  126. const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
  127. scalar_t beta,
  128. scalar_t *c, int64_t ldc, int64_t batch_stride_c);
  129. using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
  130. DECLARE_DISPATCH(axpy_fn, axpy_stub);
  131. template<typename scalar_t>
  132. void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
  133. if(n == 1)
  134. {
  135. incx = 1;
  136. incy = 1;
  137. }
  138. axpy_stub(
  139. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  140. n, a, x, incx, y, incy);
  141. }
  142. void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
  143. void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
  144. void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  145. void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  146. using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
  147. DECLARE_DISPATCH(copy_fn, copy_stub);
  148. template<typename scalar_t>
  149. void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  150. if(n == 1)
  151. {
  152. incx = 1;
  153. incy = 1;
  154. }
  155. copy_stub(
  156. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  157. n, x, incx, y, incy);
  158. }
  159. void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
  160. void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
  161. void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  162. void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  163. } // namespace at::native::cpublas