SparseCsrTensorUtils.h

#pragma once

#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
#include <ATen/ops/resize_as_sparse_native.h>
#endif
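
// Layout-dispatch helpers. Each macro below checks that LAYOUT is one of the
// sparse compressed layouts it accepts and invokes the corresponding action
// lambda; any other layout raises an error that names NAME.
//
// Illustrative usage (the operator name "my_op" is a placeholder):
//
//   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
//       tensor.layout(), "my_op", [&] { /* CSR, CSC, BSR, or BSC */ });

// AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS accepts all four sparse
// compressed layouts (CSR, CSC, BSR, BSC) and runs a single action.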
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
      case kSparseBsr: \
      case kSparseBsc: \
        return __VA_ARGS__(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
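
// AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS runs ROW_DIM_ACTION for the
// row-compressed layouts (CSR, BSR) and COLUMN_DIM_ACTION for the
// column-compressed layouts (CSC, BSC).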
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COLUMN_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
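
// AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS runs NO_BLOCK_ACTION for the
// non-block layouts (CSR, CSC) and BLOCK_ACTION for the blocked layouts
// (BSR, BSC).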
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (NO_BLOCK_ACTION)(); \
      case kSparseBsr: \
      case kSparseBsc: \
        return (BLOCK_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
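
// AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS accepts only the row-compressed
// layouts (CSR, BSR).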
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse row compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
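
// AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS accepts only the
// column-compressed layouts (CSC, BSC).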
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, COL_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COL_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse column compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
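
// AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS accepts only the non-block
// layouts (CSR, CSC).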
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout); \
    } \
  }()
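
// AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS accepts only the blocked
// layouts (BSR, BSC).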
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseBsr: \
      case kSparseBsc: \
        return (ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed block tensor layout but got ", \
            the_layout); \
    } \
  }()
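
// AT_DISPATCH_SPARSE_VALUE_TYPES dispatches on the value dtype over all
// standard and complex types plus ComplexHalf, Half, Bool, and BFloat16,
// i.e. the value dtypes supported by sparse compressed tensors.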
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH( \
      TYPE, \
      NAME, \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))

namespace at::sparse_csr {

using SparseCsrTensor = Tensor;
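
// Returns true when the given layout (or tensor) uses one of the four
// sparse compressed layouts: CSR, CSC, BSR, or BSC.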
inline bool is_sparse_compressed(const Layout& layout) {
  switch (layout) {
    case kSparseCsr:
    case kSparseCsc:
    case kSparseBsr:
    case kSparseBsc:
      return true;
    default:;
  }
  return false;
}

inline bool is_sparse_compressed(const Tensor& self) {
  return is_sparse_compressed(self.layout());
}
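
// Returns the SparseCsrTensorImpl of a sparse compressed tensor. The
// dispatch macro is used only to validate the layout; its action is a no-op.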
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "get_sparse_csr_impl", [&] {});
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
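
// Human-readable names for layouts, index tensors, and dimensions.
// For example, compressedIndicesName(kSparseCsr) is "crow_indices" and
// plainIndicesName(kSparseBsc) is "row_indices".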
inline std::string layoutToString(
    Layout layout,
    bool upper = false,
    bool lower = false) {
  switch (layout) {
    case kSparseCsr:
      return (upper ? "CSR" : (lower ? "csr" : "Csr"));
    case kSparseCsc:
      return (upper ? "CSC" : (lower ? "csc" : "Csc"));
    case kSparseBsr:
      return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
    case kSparseBsc:
      return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline bool isCompressedRow(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
}

inline bool isCompressedColumn(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "isCompressedColumn",
      [&] { return false; },
      [&] { return true; });
}

inline std::string compressedIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "compressedIndicesName",
      [&] { return "crow_indices"; },
      [&] { return "ccol_indices"; });
}

inline std::string plainIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "plainIndicesName",
      [&] { return "col_indices"; },
      [&] { return "row_indices"; });
}

inline std::string compressedDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "row";
    case kSparseCsc:
      return "column";
    case kSparseBsr:
      return "row block";
    case kSparseBsc:
      return "column block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline std::string plainDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "column";
    case kSparseCsc:
      return "row";
    case kSparseBsr:
      return "column block";
    case kSparseBsc:
      return "row block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}
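
// Helpers mapping a layout and a size vector to dimension indices. The two
// innermost sparse dimensions are rows and columns; the compressed dimension
// is the one whose indices are stored in compressed form and the plain
// dimension is the other one. For example, a batched CSR tensor of shape
// (B, M, N) with no dense dimensions has rowDimension == 1 and
// columnDimension == 2.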
inline size_t rowDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedRow(layout) ? 2 : 1);
}

inline size_t columnDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedColumn(layout) ? 2 : 1);
}

inline size_t compressedDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
}

inline size_t plainDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
}
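
// The compressed indices tensor has shape (*batchsize, compressed_dim + 1),
// so its dimensionality minus one gives the number of batch dimensions.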
inline int64_t numBatchDimensions(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "numBatchDimensions",
      [&self] { return self.crow_indices().dim() - 1; },
      [&self] { return self.ccol_indices().dim() - 1; });
}
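
// Returns the (compressed_indices, plain_indices) pair of a sparse
// compressed tensor: (crow_indices, col_indices) for CSR/BSR and
// (ccol_indices, row_indices) for CSC/BSC.
//
// Illustrative usage:
//   auto [compressed_indices, plain_indices] =
//       at::sparse_csr::getCompressedPlainIndices(self);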
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "getCompressedPlainIndices",
      [&self] {
        return std::make_pair(self.crow_indices(), self.col_indices());
      },
      [&self] {
        return std::make_pair(self.ccol_indices(), self.row_indices());
      });
}
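
// Returns the dtype of the index tensors: the compressed indices dtype for
// sparse compressed layouts, the _indices() dtype for sparse COO, and
// int64 (Long) as a fallback for any other layout.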
inline ScalarType getIndexDtype(Tensor const& self) {
  switch (self.layout()) {
    case kSparseCsr:
    case kSparseBsr:
      return self.crow_indices().scalar_type();
    case kSparseCsc:
    case kSparseBsc:
      return self.ccol_indices().scalar_type();
    case kSparse:
      return self._indices().scalar_type();
    default:
      return ScalarType::Long;
  }
}
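
// Swaps the compressed dimension of a layout: CSR <-> CSC and BSR <-> BSC.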
inline Layout flip_compressed_layout(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return kSparseCsc;
    case kSparseCsc:
      return kSparseCsr;
    case kSparseBsr:
      return kSparseBsc;
    case kSparseBsc:
      return kSparseBsr;
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return kSparseCsr;
  }
}
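
// The values of a BSR/BSC tensor have shape
// (*batchsize, nnz, blocksize[0], blocksize[1], *densesize), so the two
// dimensions following the batch and nnz dimensions give the block size.
// getSymIntBlockSize returns an empty optional for non-block layouts.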
inline DimVector getBlockSize(Tensor const& self) {
  int64_t n_batch = numBatchDimensions(self);
  return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
}

inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
    int64_t n_batch = numBatchDimensions(self);
    return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  } else {
    return {};
  }
}
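
// Handles the trivial aliasing cases of an elementwise binary operation on
// sparse compressed tensors: if self, other, and out are all the same
// tensor, binary_op_out is applied in place on the values; if only self and
// other are the same tensor, out is rebuilt from self's indices and the
// result of binary_op on the values. Returns true when a fast path was
// taken, false otherwise.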
template <typename binary_op_t, typename binary_op_out_t>
inline bool only_sparse_compressed_binary_op_trivial_cases(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& out,
    const binary_op_t& binary_op,
    const binary_op_out_t& binary_op_out) {
  // Only sparse compressed! Just like the name says :)
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));

  // Bypass BLAS if there are matches in (self, other, out)
  if (self.is_same(out) && self.is_same(other)) {
    binary_op_out(self.values(), other.values(), alpha);
    return true;
  }
  if (self.is_same(other)) {
    auto [compressed_indices, plain_indices] =
        at::sparse_csr::getCompressedPlainIndices(self);
    static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
        ->set_member_tensors(
            compressed_indices,
            plain_indices,
            binary_op(self.values(), other.values(), alpha),
            self.sizes());
    return true;
  }
  return false;
}
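
// Instantiates the trivial-case helper above with Tensor::add and
// Tensor::add_ as the value operations.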
inline bool only_sparse_compressed_add_trivial_cases(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& out) {
  return only_sparse_compressed_binary_op_trivial_cases(
      self,
      other,
      alpha,
      out,
      [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
        return v1.add(v2, alpha);
      },
      [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
        return v1.add_(v2, alpha);
      });
}
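
// Converts the values of a sparse compressed tensor to the given dtype
// while reusing its existing index tensors.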
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
  auto [compressed_indices, plain_indices] =
      at::sparse_csr::getCompressedPlainIndices(input);
  return at::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      std::move(input.values()).to(dtype),
      input.sizes(),
      dtype,
      input.layout(),
      input.device(),
      input.options().pinned_memory_opt());
}
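
// Allocates the output values buffer and, when the accumulation type acc_t
// differs from scalar_t, a separate accumulation buffer with dtype acc_t;
// for integral value types the accumulation buffer is reused as the output.
// Both buffers are resized to nnz when it is known (nnz != -1).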
template <typename acc_t, typename scalar_t>
inline std::tuple<Tensor, Tensor> create_acc_buffer(
    TensorOptions option,
    ScalarType type,
    int64_t nnz = -1) {
  Tensor new_values, new_values_acc;
  constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
  bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
  if constexpr (need_acc) {
    auto acc_dtype = CppTypeToScalarType<acc_t>::value;
    new_values_acc = at::empty({}, option.dtype(acc_dtype));
    new_values = is_integral ? new_values_acc : at::empty({}, option);
  } else {
    new_values = new_values_acc = at::empty({}, option);
  }
  if (nnz != -1) {
    return std::make_tuple(
        new_values.resize_(nnz), new_values_acc.resize_(nnz));
  } else {
    return std::make_tuple(new_values, new_values_acc);
  }
}
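
// Copies accumulated results back into the output values buffer when a
// separate accumulation buffer was used.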
inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
  if (!new_values_acc.is_same(new_values)) {
    new_values.copy_(new_values_acc);
  }
}

} // namespace at::sparse_csr