| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- #pragma once
- #include <ATen/SparseCsrTensorImpl.h>
- #include <ATen/SparseTensorImpl.h>
- #include <ATen/core/Tensor.h>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #include <ATen/NativeFunctions.h>
- #include <ATen/Operators.h>
- #else
- #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
- #include <ATen/ops/resize_as_sparse_native.h>
- #endif
// Layout guard for the four sparse compressed layouts (kSparseCsr, kSparseCsc,
// kSparseBsr, kSparseBsc): for any of them, invokes the callable passed as the
// variadic argument with no arguments and yields its result. Any other layout
// raises via AT_ERROR with NAME prepended to the message.
// The macro expands to an immediately-invoked [&] lambda, so it is usable as
// an expression. The default branch has no return statement and relies on
// AT_ERROR not returning.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
      case kSparseBsr: \
      case kSparseBsc: \
        return __VA_ARGS__(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Two-way dispatch on the compression direction of LAYOUT:
//   - kSparseCsr / kSparseBsr (row-compressed)    -> calls ROW_DIM_ACTION()
//   - kSparseCsc / kSparseBsc (column-compressed) -> calls COLUMN_DIM_ACTION()
// Any other layout raises via AT_ERROR with NAME prepended to the message.
// Both actions are returned from one lambda with a deduced return type, so
// they must return the same type. Expands to an immediately-invoked lambda.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COLUMN_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Two-way dispatch on whether LAYOUT is blocked:
//   - kSparseCsr / kSparseCsc (non-block) -> calls NO_BLOCK_ACTION()
//   - kSparseBsr / kSparseBsc (block)     -> calls BLOCK_ACTION()
// Any other layout raises via AT_ERROR with NAME prepended to the message.
// Both actions must return the same type (single deduced lambda return type).
// Expands to an immediately-invoked lambda.
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (NO_BLOCK_ACTION)(); \
      case kSparseBsr: \
      case kSparseBsc: \
        return (BLOCK_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Row-compressed-only guard: for kSparseCsr or kSparseBsr it calls
// ROW_DIM_ACTION() and yields its result. Any other layout, including the
// column-compressed ones, raises via AT_ERROR with NAME prepended.
// Expands to an immediately-invoked lambda.
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse row compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Column-compressed-only guard: for kSparseCsc or kSparseBsc it calls
// COL_DIM_ACTION() and yields its result. Any other layout, including the
// row-compressed ones, raises via AT_ERROR with NAME prepended.
// Expands to an immediately-invoked lambda.
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, COL_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COL_DIM_ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse column compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
// Non-block-only guard: for kSparseCsr or kSparseCsc it calls ACTION() and
// yields its result. Block layouts (BSR/BSC) and all other layouts raise via
// AT_ERROR with NAME prepended. Expands to an immediately-invoked lambda.
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout); \
    } \
  }()
// Block-only guard: for kSparseBsr or kSparseBsc it calls ACTION() and yields
// its result. Non-block layouts (CSR/CSC) and all other layouts raise via
// AT_ERROR with NAME prepended. Expands to an immediately-invoked lambda.
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseBsr: \
      case kSparseBsc: \
        return (ACTION)(); \
      default: \
        AT_ERROR( \
            NAME, \
            " expected sparse compressed block tensor layout but got ", \
            the_layout); \
    } \
  }()
// Scalar-type dispatch for sparse values: wraps AT_DISPATCH_SWITCH on TYPE.
// The case list comes from AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4,
// extended with kComplexHalf, kHalf, kBool and kBFloat16.
// __VA_ARGS__ is the per-type body handed to each case.
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH( \
      TYPE, \
      NAME, \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
- namespace at::sparse_csr {
// Alias for Tensor, used (e.g. by get_sparse_csr_impl) to mark arguments that
// are expected to have a sparse compressed layout. The alias itself enforces
// nothing at compile time.
using SparseCsrTensor = Tensor;
- inline bool is_sparse_compressed(const Layout& layout) {
- switch (layout) {
- case kSparseCsr:
- case kSparseCsc:
- case kSparseBsr:
- case kSparseBsc:
- return true;
- default:;
- }
- return false;
- }
- inline bool is_sparse_compressed(const Tensor& self) {
- return is_sparse_compressed(self.layout());
- }
- inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
- AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(), "get_sparse_csr_impl", [&] {});
- return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
- }
- inline std::string layoutToString(
- Layout layout,
- bool upper = false,
- bool lower = false) {
- switch (layout) {
- case kSparseCsr:
- return (upper ? "CSR" : (lower ? "csr" : "Csr"));
- case kSparseCsc:
- return (upper ? "CSC" : (lower ? "csc" : "Csc"));
- case kSparseBsr:
- return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
- case kSparseBsc:
- return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline bool isCompressedRow(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
- }
- inline bool isCompressedColumn(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "isCompressedColumn",
- [&] { return false; },
- [&] { return true; });
- }
- inline std::string compressedIndicesName(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "compressedIndicesName",
- [&] { return "crow_indices"; },
- [&] { return "ccol_indices"; });
- }
- inline std::string plainIndicesName(Layout layout) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- layout,
- "plainIndicesName",
- [&] { return "col_indices"; },
- [&] { return "row_indices"; });
- }
- inline std::string compressedDimName(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return "row";
- case kSparseCsc:
- return "column";
- case kSparseBsr:
- return "row block";
- case kSparseBsc:
- return "column block";
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline std::string plainDimName(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return "column";
- case kSparseCsc:
- return "row";
- case kSparseBsr:
- return "column block";
- case kSparseBsc:
- return "row block";
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return "";
- }
- }
- inline size_t rowDimension(Layout layout, IntArrayRef size) {
- return size.size() - (isCompressedRow(layout) ? 2 : 1);
- }
- inline size_t columnDimension(Layout layout, IntArrayRef size) {
- return size.size() - (isCompressedColumn(layout) ? 2 : 1);
- }
- inline size_t compressedDimension(
- Layout layout,
- IntArrayRef size,
- size_t dense_ndim = 0) {
- return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
- }
- inline size_t plainDimension(
- Layout layout,
- IntArrayRef size,
- size_t dense_ndim = 0) {
- return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
- }
- inline int64_t numBatchDimensions(Tensor const& self) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(),
- "numBatchDimensions",
- [&self] { return self.crow_indices().dim() - 1; },
- [&self] { return self.ccol_indices().dim() - 1; });
- }
- inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
- return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
- self.layout(),
- "getCompressedPlainIndices",
- [&self] {
- return std::make_pair(self.crow_indices(), self.col_indices());
- },
- [&self] {
- return std::make_pair(self.ccol_indices(), self.row_indices());
- });
- }
- inline ScalarType getIndexDtype(Tensor const& self) {
- switch (self.layout()) {
- case kSparseCsr:
- case kSparseBsr:
- return self.crow_indices().scalar_type();
- case kSparseCsc:
- case kSparseBsc:
- return self.ccol_indices().scalar_type();
- case kSparse:
- return self._indices().scalar_type();
- default:
- return ScalarType::Long;
- }
- }
- inline Layout flip_compressed_layout(Layout layout) {
- switch (layout) {
- case kSparseCsr:
- return kSparseCsc;
- case kSparseCsc:
- return kSparseCsr;
- case kSparseBsr:
- return kSparseBsc;
- case kSparseBsc:
- return kSparseBsr;
- default:
- TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
- return kSparseCsr;
- }
- }
- inline DimVector getBlockSize(Tensor const& self) {
- int64_t n_batch = numBatchDimensions(self);
- return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
- }
- inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
- if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
- int64_t n_batch = numBatchDimensions(self);
- return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
- } else {
- return {};
- }
- }
- template <typename binary_op_t, typename binary_op_out_t>
- inline bool only_sparse_compressed_binary_op_trivial_cases(
- const Tensor& self,
- const Tensor& other,
- const Scalar& alpha,
- Tensor& out,
- const binary_op_t& binary_op,
- const binary_op_out_t& binary_op_out) {
- // Only sparse compressed! Just like the name says :)
- TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
- TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
- TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
- // Bypass BLAS if there are matches in (self, other, out)
- if (self.is_same(out) && self.is_same(other)) {
- binary_op_out(self.values(), other.values(), alpha);
- return true;
- }
- if (self.is_same(other)) {
- auto [compressed_indices, plain_indices] =
- at::sparse_csr::getCompressedPlainIndices(self);
- static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
- ->set_member_tensors(
- compressed_indices,
- plain_indices,
- binary_op(self.values(), other.values(), alpha),
- self.sizes());
- return true;
- }
- return false;
- }
- inline bool only_sparse_compressed_add_trivial_cases(
- const Tensor& self,
- const Tensor& other,
- const Scalar& alpha,
- Tensor& out) {
- return only_sparse_compressed_binary_op_trivial_cases(
- self,
- other,
- alpha,
- out,
- [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
- return v1.add(v2, alpha);
- },
- [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
- return v1.add_(v2, alpha);
- });
- }
- inline Tensor to_type(const Tensor& input, ScalarType dtype) {
- auto [compressed_indices, plain_indices] =
- at::sparse_csr::getCompressedPlainIndices(input);
- return at::_sparse_compressed_tensor_unsafe(
- compressed_indices,
- plain_indices,
- std::move(input.values()).to(dtype),
- input.sizes(),
- dtype,
- input.layout(),
- input.device(),
- input.options().pinned_memory_opt());
- }
- template <typename acc_t, typename scalar_t>
- inline std::tuple<Tensor, Tensor> create_acc_buffer(
- TensorOptions option,
- ScalarType type,
- int64_t nnz = -1) {
- Tensor new_values, new_values_acc;
- constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
- bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
- if constexpr (need_acc) {
- auto acc_dtype = CppTypeToScalarType<acc_t>::value;
- new_values_acc = at::empty({}, option.dtype(acc_dtype));
- new_values = is_integral ? new_values_acc : at::empty({}, option);
- } else {
- new_values = new_values_acc = at::empty({}, option);
- }
- if (nnz != -1) {
- return std::make_tuple(
- new_values.resize_(nnz), new_values_acc.resize_(nnz));
- } else {
- return std::make_tuple(new_values, new_values_acc);
- }
- }
- inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
- if (!new_values_acc.is_same(new_values)) {
- new_values.copy_(new_values_acc);
- }
- }
- } // namespace at::sparse_csr
|