// BatchLinearAlgebra.h 9.7 KB
// (file-viewer line-number gutter, 1-321, omitted)
  1. #pragma once
  2. #include <c10/util/Optional.h>
  3. #include <c10/util/string_view.h>
  4. #include <ATen/Config.h>
  5. #include <ATen/native/DispatchStub.h>
  // Forward declarations of Tensor, TensorIterator ("TI") and TransposeType.
  // The declarations in this header only name these types, so their full
  // definitions are not pulled in here.
  namespace at {

  class Tensor;
  struct TensorIterator;

  namespace native {
  enum class TransposeType;
  } // namespace native

  } // namespace at
  14. namespace at::native {
  // Compile-time tag selecting the least-squares driver used by lapackLstsq.
  // The enumerator order (underlying int64_t values 0..3) is unchanged.
  enum class LapackLstsqDriverType : int64_t {
    Gels,
    Gelsd,
    Gelsy,
    Gelss,
  };
  16. #if AT_BUILD_WITH_LAPACK()
  // Define per-batch functions to be used in the implementation of batched
  // linear algebra operations.
  // Per-batch Cholesky routine; declared here without a body.
  // NOTE(review): the argument list (uplo, n, a, lda, info) mirrors LAPACK's
  // ?potrf — presumably a thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackCholesky(
      char uplo,
      int n,
      scalar_t* a,
      int lda,
      int* info);
  // Per-batch Cholesky-inverse routine; declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?potri — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackCholeskyInverse(
      char uplo,
      int n,
      scalar_t* a,
      int lda,
      int* info);
  // Per-batch general eigen-decomposition routine; declared here without a
  // body. `value_t` (type of `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list mirrors LAPACK's ?geev — presumably a thin
  // wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackEig(
      char jobvl,
      char jobvr,
      int n,
      scalar_t* a,
      int lda,
      scalar_t* w,
      scalar_t* vl,
      int ldvl,
      scalar_t* vr,
      int ldvr,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int* info);
  // Per-batch QR factorization routine; declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?geqrf — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackGeqrf(
      int m,
      int n,
      scalar_t* a,
      int lda,
      scalar_t* tau,
      scalar_t* work,
      int lwork,
      int* info);
  // Per-batch routine taking the (a, tau) outputs of a QR factorization;
  // declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?orgqr/?ungqr —
  // presumably a thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackOrgqr(
      int m,
      int n,
      int k,
      scalar_t* a,
      int lda,
      scalar_t* tau,
      scalar_t* work,
      int lwork,
      int* info);
  // Per-batch routine applying QR reflectors (a, tau) to a matrix `c`;
  // declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?ormqr/?unmqr —
  // presumably a thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackOrmqr(
      char side,
      char trans,
      int m,
      int n,
      int k,
      scalar_t* a,
      int lda,
      scalar_t* tau,
      scalar_t* c,
      int ldc,
      scalar_t* work,
      int lwork,
      int* info);
  // Per-batch symmetric/Hermitian eigen-decomposition routine; declared here
  // without a body. `value_t` (type of `w` and `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list mirrors LAPACK's ?syevd/?heevd —
  // presumably a thin wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackSyevd(
      char jobz,
      char uplo,
      int n,
      scalar_t* a,
      int lda,
      value_t* w,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int lrwork,
      int* iwork,
      int liwork,
      int* info);
  // Per-batch least-squares routine ("Gels" driver); declared here without a
  // body. It is the only least-squares driver declared in this header that
  // takes a `trans` argument.
  // NOTE(review): the argument list mirrors LAPACK's ?gels — presumably a thin
  // wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackGels(
      char trans,
      int m,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      scalar_t* b,
      int ldb,
      scalar_t* work,
      int lwork,
      int* info);
  // Per-batch least-squares routine ("Gelsd" driver); declared here without a
  // body. `value_t` (type of `s`, `rcond`, `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list mirrors LAPACK's ?gelsd — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackGelsd(
      int m,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      scalar_t* b,
      int ldb,
      value_t* s,
      value_t rcond,
      int* rank,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int* iwork,
      int* info);
  // Per-batch least-squares routine ("Gelsy" driver); declared here without a
  // body. `value_t` (type of `rcond`, `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list mirrors LAPACK's ?gelsy — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackGelsy(
      int m,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      scalar_t* b,
      int ldb,
      int* jpvt,
      value_t rcond,
      int* rank,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int* info);
  // Per-batch least-squares routine ("Gelss" driver); declared here without a
  // body. `value_t` (type of `s`, `rcond`, `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list mirrors LAPACK's ?gelss — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackGelss(
      int m,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      scalar_t* b,
      int ldb,
      value_t* s,
      value_t rcond,
      int* rank,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int* info);
  54. template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
  55. struct lapackLstsq_impl;
  56. template <class scalar_t, class value_t>
  57. struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
  58. static void call(
  59. char trans, int m, int n, int nrhs,
  60. scalar_t *a, int lda, scalar_t *b, int ldb,
  61. scalar_t *work, int lwork, int *info, // Gels flavor
  62. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  63. value_t *s, // Gelss flavor
  64. int *iwork // Gelsd flavor
  65. ) {
  66. lapackGels<scalar_t>(
  67. trans, m, n, nrhs,
  68. a, lda, b, ldb,
  69. work, lwork, info);
  70. }
  71. };
  72. template <class scalar_t, class value_t>
  73. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
  74. static void call(
  75. char trans, int m, int n, int nrhs,
  76. scalar_t *a, int lda, scalar_t *b, int ldb,
  77. scalar_t *work, int lwork, int *info, // Gels flavor
  78. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  79. value_t *s, // Gelss flavor
  80. int *iwork // Gelsd flavor
  81. ) {
  82. lapackGelsy<scalar_t, value_t>(
  83. m, n, nrhs,
  84. a, lda, b, ldb,
  85. jpvt, rcond, rank,
  86. work, lwork, rwork, info);
  87. }
  88. };
  89. template <class scalar_t, class value_t>
  90. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
  91. static void call(
  92. char trans, int m, int n, int nrhs,
  93. scalar_t *a, int lda, scalar_t *b, int ldb,
  94. scalar_t *work, int lwork, int *info, // Gels flavor
  95. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  96. value_t *s, // Gelss flavor
  97. int *iwork // Gelsd flavor
  98. ) {
  99. lapackGelsd<scalar_t, value_t>(
  100. m, n, nrhs,
  101. a, lda, b, ldb,
  102. s, rcond, rank,
  103. work, lwork,
  104. rwork, iwork, info);
  105. }
  106. };
  107. template <class scalar_t, class value_t>
  108. struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
  109. static void call(
  110. char trans, int m, int n, int nrhs,
  111. scalar_t *a, int lda, scalar_t *b, int ldb,
  112. scalar_t *work, int lwork, int *info, // Gels flavor
  113. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  114. value_t *s, // Gelss flavor
  115. int *iwork // Gelsd flavor
  116. ) {
  117. lapackGelss<scalar_t, value_t>(
  118. m, n, nrhs,
  119. a, lda, b, ldb,
  120. s, rcond, rank,
  121. work, lwork,
  122. rwork, info);
  123. }
  124. };
  125. template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
  126. void lapackLstsq(
  127. char trans, int m, int n, int nrhs,
  128. scalar_t *a, int lda, scalar_t *b, int ldb,
  129. scalar_t *work, int lwork, int *info, // Gels flavor
  130. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  131. value_t *s, // Gelss flavor
  132. int *iwork // Gelsd flavor
  133. ) {
  134. lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
  135. trans, m, n, nrhs,
  136. a, lda, b, ldb,
  137. work, lwork, info,
  138. jpvt, rcond, rank, rwork,
  139. s,
  140. iwork);
  141. }
  // Per-batch solve using an LU factorization (a, ipiv); declared here without
  // a body.
  // NOTE(review): the argument list mirrors LAPACK's ?getrs — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLuSolve(
      char trans,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      int* ipiv,
      scalar_t* b,
      int ldb,
      int* info);
  // Per-batch LU factorization routine; declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?getrf — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLu(
      int m,
      int n,
      scalar_t* a,
      int lda,
      int* ipiv,
      int* info);
  // Per-batch LDL factorization routine, Hermitian variant; declared here
  // without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?hetrf — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLdlHermitian(
      char uplo, int n,
      scalar_t* a, int lda,
      int* ipiv,
      scalar_t* work, int lwork,
      int* info);
  // Per-batch LDL factorization routine, symmetric variant; declared here
  // without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?sytrf — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLdlSymmetric(
      char uplo, int n,
      scalar_t* a, int lda,
      int* ipiv,
      scalar_t* work, int lwork,
      int* info);
  // Per-batch solve using an LDL factorization (a, ipiv), Hermitian variant;
  // declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?hetrs — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLdlSolveHermitian(
      char uplo, int n, int nrhs,
      scalar_t* a, int lda,
      int* ipiv,
      scalar_t* b, int ldb,
      int* info);
  // Per-batch solve using an LDL factorization (a, ipiv), symmetric variant;
  // declared here without a body.
  // NOTE(review): the argument list mirrors LAPACK's ?sytrs — presumably a
  // thin wrapper; confirm against the definition.
  template <class scalar_t>
  void lapackLdlSolveSymmetric(
      char uplo, int n, int nrhs,
      scalar_t* a, int lda,
      int* ipiv,
      scalar_t* b, int ldb,
      int* info);
  // Per-batch singular value decomposition routine; declared here without a
  // body. `value_t` (type of `s` and `rwork`) defaults to `scalar_t`.
  // NOTE(review): the argument list (jobz, ..., iwork) mirrors LAPACK's
  // ?gesdd — presumably a thin wrapper; confirm against the definition.
  template <class scalar_t, class value_t = scalar_t>
  void lapackSvd(
      char jobz,
      int m,
      int n,
      scalar_t* a,
      int lda,
      value_t* s,
      scalar_t* u,
      int ldu,
      scalar_t* vt,
      int ldvt,
      scalar_t* work,
      int lwork,
      value_t* rwork,
      int* iwork,
      int* info);
  190. #endif
  191. #if AT_BUILD_WITH_BLAS()
  // Per-batch triangular solve; declared here without a body. Unlike the
  // lapack* declarations in this header it takes no `info` status pointer.
  // NOTE(review): the argument list mirrors BLAS ?trsm — presumably a thin
  // wrapper; confirm against the definition.
  template <class scalar_t>
  void blasTriangularSolve(
      char side,
      char uplo,
      char trans,
      char diag,
      int n,
      int nrhs,
      scalar_t* a,
      int lda,
      scalar_t* b,
      int ldb);
  194. #endif
  195. using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
  196. DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
  197. using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
  198. DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
  199. using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
  200. DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
  201. using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
  202. DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
  203. using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
  204. DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
  205. using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
  206. DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
  207. using linalg_eigh_fn = void (*)(
  208. const Tensor& /*eigenvalues*/,
  209. const Tensor& /*eigenvectors*/,
  210. const Tensor& /*infos*/,
  211. bool /*upper*/,
  212. bool /*compute_eigenvectors*/);
  213. DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
  214. using lstsq_fn = void (*)(
  215. const Tensor& /*a*/,
  216. Tensor& /*b*/,
  217. Tensor& /*rank*/,
  218. Tensor& /*singular_values*/,
  219. Tensor& /*infos*/,
  220. double /*rcond*/,
  221. std::string /*driver_name*/);
  222. DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
  223. using triangular_solve_fn = void (*)(
  224. const Tensor& /*A*/,
  225. const Tensor& /*B*/,
  226. bool /*left*/,
  227. bool /*upper*/,
  228. TransposeType /*transpose*/,
  229. bool /*unitriangular*/);
  230. DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
  231. using lu_factor_fn = void (*)(
  232. const Tensor& /*input*/,
  233. const Tensor& /*pivots*/,
  234. const Tensor& /*infos*/,
  235. bool /*compute_pivots*/);
  236. DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
  237. using unpack_pivots_fn = void(*)(
  238. TensorIterator& iter,
  239. const int64_t dim_size,
  240. const int64_t max_pivot);
  241. DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
  242. using lu_solve_fn = void (*)(
  243. const Tensor& /*LU*/,
  244. const Tensor& /*pivots*/,
  245. const Tensor& /*B*/,
  246. TransposeType /*trans*/);
  247. DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
  248. using ldl_factor_fn = void (*)(
  249. const Tensor& /*LD*/,
  250. const Tensor& /*pivots*/,
  251. const Tensor& /*info*/,
  252. bool /*upper*/,
  253. bool /*hermitian*/);
  254. DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
  255. using svd_fn = void (*)(
  256. const Tensor& /*A*/,
  257. const bool /*full_matrices*/,
  258. const bool /*compute_uv*/,
  259. const std::optional<c10::string_view>& /*driver*/,
  260. const Tensor& /*U*/,
  261. const Tensor& /*S*/,
  262. const Tensor& /*Vh*/,
  263. const Tensor& /*info*/);
  264. DECLARE_DISPATCH(svd_fn, svd_stub);
  265. using ldl_solve_fn = void (*)(
  266. const Tensor& /*LD*/,
  267. const Tensor& /*pivots*/,
  268. const Tensor& /*result*/,
  269. bool /*upper*/,
  270. bool /*hermitian*/);
  271. DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
  272. } // namespace at::native