| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633 |
- # mypy: allow-untyped-defs
- import warnings
- from collections import namedtuple
- from typing import Any, Optional, Tuple, List, Callable, Dict
- import torch
- from torch.sparse._semi_structured_conversions import (
- sparse_semi_structured_from_dense_cutlass,
- sparse_semi_structured_to_dense_cutlass
- )
- from torch.sparse._semi_structured_ops import (
- fallback_dispatcher,
- semi_sparse_values,
- semi_sparse_indices,
- semi_sparse_detach,
- semi_sparse_t,
- semi_sparse_view,
- semi_sparse_mm,
- semi_sparse_addmm,
- semi_sparse_linear,
- )
# Public API of this module: the base tensor subclass, its two backend
# implementations, and the dense -> semi-structured conversion entry point.
__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]
# Per-dtype shape constraints for a semi-structured sparse backend:
#   sparse_min_rows / sparse_min_cols: each dimension of the tensor being compressed
#       must be >= and a multiple of these values.
#   dense_min_rows / dense_min_cols: the dense matmul operand is zero-padded up to
#       multiples of these values.
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    ["sparse_min_rows", "sparse_min_cols", "dense_min_rows", "dense_min_cols"],
)
- class SparseSemiStructuredTensor(torch.Tensor):
- """
- This class implementes semi-structured sparsity as a Tensor subclass.
- Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
- depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
- structured sparsity.
- There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
- This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
- and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
- Note that as such, this class cannot be insantiated directly.
- -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- - `def from_dense()` - backend specific compression routines
- - `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
- """
- _DEFAULT_ALG_ID: int = 0
- _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
- _FORCE_CUTLASS: bool = True
- _FUSE_TRANSPOSE: bool = False
- _PROTOTYPE_WARNING_SHOWN: bool = False
- BACKEND: str
- SPARSE_DISPATCH: Dict[Callable, Callable]
- packed: Optional[torch.Tensor]
- meta: Optional[torch.Tensor]
- packed_t: Optional[torch.Tensor]
- meta_t: Optional[torch.Tensor]
- compressed_swizzled_bitmask: Optional[torch.Tensor]
- fuse_transpose_cusparselt: bool
- alg_id_cusparselt: int
- __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
- @staticmethod
- def __new__( # noqa: PYI034
- cls,
- shape: torch.Size,
- packed: Optional[torch.Tensor],
- meta: Optional[torch.Tensor],
- packed_t: Optional[torch.Tensor],
- meta_t: Optional[torch.Tensor],
- compressed_swizzled_bitmask: Optional[torch.Tensor],
- fuse_transpose_cusparselt: bool = False,
- alg_id_cusparselt: int = 0,
- requires_grad: bool = False,
- ):
- """
- Create a new instance of the tensor subclass from the compressed sparse representation.
- We have the option to create the subclass with the compressed representations of both X and X', for training.
- For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
- Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
- Args:
- shape: The shape of the original dense tensor
- packed: The compressed representation of the original dense tensor
- meta: The metadata of the original dense tensor, if it is stored separately
- packed_t: The compressed representation of the transposed original dense tensor
- meta_t: The metadata of the transposed original dense tensor, if it is stored separately
- compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
- participate in the computation. Used for pointwise ops.
- fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
- with a matmul, which is useful in the case of 2:4 sparse training.
- alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
- Returns:
- torch.Tensor: A torch.Tensor wrapper subclass.
- Raises:
- ValueError: If all of the tensor arguments are None.
- """
- if not cls._PROTOTYPE_WARNING_SHOWN:
- warnings.warn(
- (
- "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
- "and will change in the near future. Please open a Github issue "
- "for features requests and see our documentation on the torch.sparse "
- "module for further information about the project."
- ),
- UserWarning,
- )
- cls._PROTOTYPE_WARNING_SHOWN = True
- # Because this only runs onces, we also load the dispatch table here as well.
- # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
- # But this is useful since it allows users to overload the dispatch table for debugging / testing.
- cls._load_dispatch_table()
- # we can also register the classes with dynamo when the warning is shown.
- torch._dynamo.allow_in_graph(cls)
- if packed is not None:
- previous_tensor = packed
- elif packed_t is not None:
- previous_tensor = packed_t
- else:
- raise ValueError("At least one of packed or packed_t must be provided")
- kwargs = {
- "device": previous_tensor.device,
- "dtype": previous_tensor.dtype,
- "layout": previous_tensor.layout,
- "requires_grad": requires_grad,
- }
- tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
- tensor.packed = packed
- tensor.meta = meta
- tensor.packed_t = packed_t
- tensor.meta_t = meta_t
- tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
- tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
- tensor.alg_id_cusparselt = alg_id_cusparselt
- return tensor
- def __repr__(self) -> str: # type: ignore[override]
- assert hasattr(self, "shape")
- return f"{self.__class__.__name__}(shape={self.shape})"
- def __tensor_flatten__(
- self,
- ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
- inner_tensors = list(
- filter(lambda x: getattr(self, x) is not None, self.__slots__)
- )
- tensor_meta = (
- self.shape,
- self.fuse_transpose_cusparselt,
- self.alg_id_cusparselt,
- self.requires_grad,
- )
- return inner_tensors, tensor_meta
- @classmethod
- def __tensor_unflatten__(
- cls,
- inner_tensors,
- tensor_meta : Tuple[torch.Size, bool, int, bool],
- outer_size,
- outer_stride,
- ) -> torch.Tensor:
- shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
- return cls(
- shape=shape,
- packed=inner_tensors.get("packed", None),
- meta=inner_tensors.get("meta", None),
- packed_t=inner_tensors.get("packed_t", None),
- meta_t=inner_tensors.get("meta_t", None),
- compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
- fuse_transpose_cusparselt=fuse_transpose_cusparselt,
- alg_id_cusparselt=alg_id_cusparselt,
- requires_grad=requires_grad,
- )
- __torch_function__ = torch._C._disabled_torch_function_impl
- @classmethod
- def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
- if func._overloadpacket not in cls.SPARSE_DISPATCH:
- raise NotImplementedError(
- f"{cls.__name__} only supports a specific set of operations, "
- f"can't perform requested op ({func.__name__})"
- )
- return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
- @classmethod
- def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
- """
- Loads the op overload sparse dispatch table for the current class.
- """
- if getattr(cls, "SPARSE_DISPATCH", None) is None:
- cls.SPARSE_DISPATCH = {
- torch.ops.aten.values: semi_sparse_values,
- torch.ops.aten.indices: semi_sparse_indices,
- torch.ops.aten.is_same_size: fallback_dispatcher,
- torch.ops.aten.detach_: fallback_dispatcher,
- torch.ops.aten.detach: semi_sparse_detach,
- torch.ops.aten.t: semi_sparse_t,
- torch.ops.aten.view: semi_sparse_view,
- torch.ops.aten.mm: semi_sparse_mm,
- torch.ops.aten.matmul: semi_sparse_mm,
- torch.ops.aten.addmm: semi_sparse_addmm,
- torch.ops.aten.linear: semi_sparse_linear,
- torch.ops.aten._to_copy: fallback_dispatcher,
- }
- if custom_dispatch_table is not None:
- cls.SPARSE_DISPATCH.update(custom_dispatch_table)
- @classmethod
- def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
- """
- Assert that the given tensor is valid for semi-structured sparse compression.
- """
- # check device
- if not original_tensor.is_cuda:
- raise RuntimeError(
- f"Error original_tensor.device= {original_tensor.device} is not supported! "
- "Only CUDA tensors are currently supported."
- )
- # check dim
- if original_tensor.dim() != 2:
- raise RuntimeError(
- f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
- "Only 2d tensors are currently supported."
- )
- # check contiguous
- if not original_tensor.is_contiguous():
- raise RuntimeError(
- "Error original_tensor is not contiguous!"
- "Only contiguous tensors are currently supported."
- )
- # check dtype
- if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
- raise RuntimeError(
- f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
- "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
- )
- # check shape
- m, n = original_tensor.shape
- min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
- min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
- if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
- # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
- raise RuntimeError(
- f"Error original_tensor.shape {original_tensor.shape} is not supported! "
- f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
- )
- @classmethod
- def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
- """
- Calculates padding for dense tensor and pads tensor if necessary.
- If padding is not required, this function returns the original tensor.
- """
- # only 2d matmul
- assert dense_input.dim() == 2
- # check shape
- m, n = dense_input.shape
- min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
- min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
- # calculate padding
- to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
- to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
- if to_pad_m or to_pad_n:
- return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
- else:
- return dense_input
- def to_dense(self):
- col = self.shape[-1]
- return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
- @classmethod
- def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
- raise NotImplementedError
- def _mm(
- self,
- B: torch.Tensor,
- *,
- bias: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> torch.Tensor:
- raise NotImplementedError
- def to_sparse_semi_structured(
- original_tensor: torch.Tensor,
- transposed: bool = False,
- ) -> SparseSemiStructuredTensor:
- """
- This function converts a dense tensor into a sparse semi-structured tensor.
- It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
- This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
- We currently only support semi-structured sparse tensors for 2d CUDA tensors.
- Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
- `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
- Args:
- original_tensor (Tensor): the dense tensor to convert
- transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
- Returns:
- SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
- Raises:
- None
- Example:
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
- >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
- tensor([[0., 0., 1., ..., 0., 1., 1.],
- [0., 0., 1., ..., 0., 1., 1.],
- [0., 0., 1., ..., 0., 1., 1.],
- ...,
- [0., 0., 1., ..., 0., 1., 1.],
- [0., 0., 1., ..., 0., 1., 1.],
- [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
- >>> A_sparse = to_sparse_semi_structured(A)
- SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
- >>> A_sparse.values()
- tensor([[1., 1., 1., ..., 1., 1., 1.],
- [1., 1., 1., ..., 1., 1., 1.],
- [1., 1., 1., ..., 1., 1., 1.],
- ...,
- [1., 1., 1., ..., 1., 1., 1.],
- [1., 1., 1., ..., 1., 1., 1.],
- [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
- >>> A_sparse.indices()
- tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
- [-4370, -4370, -4370, ..., -4370, -4370, -4370],
- [-4370, -4370, -4370, ..., -4370, -4370, -4370],
- ...,
- [-4370, -4370, -4370, ..., -4370, -4370, -4370],
- [-4370, -4370, -4370, ..., -4370, -4370, -4370],
- [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
- """
- if transposed:
- warnings.warn(
- "Setting transpose from `to_sparse_semi_structured` is deprecated "
- "and will be removed in a future release. "
- "`SparseSemiStructuredTensor` only support contiguous input tensors.",
- FutureWarning,
- stacklevel=2,
- )
- # set from _FORCE_CUTLASS flag
- SPARSE_SUBCLASS = (
- torch.sparse.SparseSemiStructuredTensorCUTLASS
- if SparseSemiStructuredTensor._FORCE_CUTLASS
- else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
- )
- return SPARSE_SUBCLASS.from_dense(original_tensor)
- class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
- """
- This class implements semi-structured sparsity for the CUTLASS backend.
- In this implementation, the specified elements and metadata are stored seprately,
- in packed and meta respectively.
- When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
- sparse_semi_structured_from_dense for conversion to the compressed format.
- """
- BACKEND = "cutlass"
- _DTYPE_SHAPE_CONSTRAINTS = {
- torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
- torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
- torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
- torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
- }
- @classmethod
- def from_dense(
- cls, original_tensor: torch.Tensor
- ) -> "SparseSemiStructuredTensorCUTLASS":
- cls._validate_device_dim_dtype_shape(original_tensor)
- (
- sparse_tensor_cutlass,
- meta_tensor_cutlass,
- ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
- return cls(
- original_tensor.shape,
- packed=sparse_tensor_cutlass,
- meta=meta_tensor_cutlass,
- packed_t=None,
- meta_t=None,
- compressed_swizzled_bitmask=None,
- requires_grad=original_tensor.requires_grad,
- )
- def to_dense(self):
- assert self.meta is not None and self.packed is not None
- return sparse_semi_structured_to_dense_cutlass(
- self.packed,
- self.meta,
- ) if self.meta.ndim == 2 else super().to_dense()
- @classmethod
- def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
- """
- This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
- It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
- The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
- Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
- It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
- pruned dense tensor.
- Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
- Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
- This can be used in the backward pass to mask the gradients.
- [9 1 7 4] [9 0 7 0]
- [1 2 3 0] [0 2 0 0]
- [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
- [1 2 6 2] [0 0 6 2] -> metadata
- -> pack to transposed CUTLASS -> packed_t
- semi-structured representation -> metadata_t
- -> compute swizzled bitmask -> compressed_swizzled_bitmask
- The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
- ```
- from torch.sparse import SparseSemiStructuredTensorCUTLASS
- from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
- pruned = _sparse_semi_structured_tile(dense)
- packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
- packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
- bitmask = _compute_compressed_swizzled_bitmask(pruned)
- SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
- ```
- """
- # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
- (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
- original_tensor,
- algorithm=algorithm,
- use_cutlass=True)
- return cls(
- original_tensor.shape,
- packed=packed,
- meta=meta,
- packed_t=packed_t,
- meta_t=meta_t,
- compressed_swizzled_bitmask=compressed_swizzled_bitmask,
- requires_grad=False,
- )
- def _mm(
- self,
- B: torch.Tensor,
- *,
- bias: Optional[torch.Tensor] = None,
- **kwargs
- ) -> torch.Tensor:
- if isinstance(B, SparseSemiStructuredTensor):
- raise ValueError(
- "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
- )
- cls_name = self.__class__.__name__
- if self.ndim != 2 or B.ndim != 2:
- raise NotImplementedError(
- f"`{cls_name}` matmul: Broadcasting is not implemented"
- )
- if self.packed is None or self.meta is None:
- raise NotImplementedError(
- f"`{cls_name}` matmul: operation is not supported"
- )
- else:
- if bias is None:
- res = torch._sparse_semi_structured_mm(
- self.packed, self.meta, B
- )
- else:
- res = torch._sparse_semi_structured_addmm(
- bias, self.packed, self.meta, B
- )
- return res[: self.shape[0]]
- class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
- """
- The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
- packed = [ specified elements of original tensor | metadata ]
- For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
- The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
- attributes respectively.
- cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
- as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
- """
- BACKEND = "cusparselt"
- _DTYPE_SHAPE_CONSTRAINTS = {
- torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
- torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
- torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
- torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
- }
- @classmethod
- def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
- cls._validate_device_dim_dtype_shape(original_tensor)
- return cls(
- shape=original_tensor.shape,
- packed=torch._cslt_compress(original_tensor),
- meta=None,
- packed_t=None,
- meta_t=None,
- compressed_swizzled_bitmask=None,
- fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
- alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
- requires_grad=original_tensor.requires_grad,
- )
- @classmethod
- def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
- """
- This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
- layout and sparse matmul.
- The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
- [9 1 7 4] [9 0 7 0]
- [1 2 3 0] [0 2 0 0]
- [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
- [1 2 6 2] [0 0 6 2]
- -> pack to transposed cuSPARSELt -> packed_t
- semi-structured representation
- -> compute swizzled bitmask -> compressed_swizzled_bitmask
- The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
- ```
- from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
- from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
- pruned = _sparse_semi_structured_tile(dense)
- packed_cusparselt = torch._cslt_compress(pruned)
- packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
- bitmask = _compute_compressed_swizzled_bitmask(pruned)
- SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
- ```
- """
- (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
- original_tensor,
- algorithm=algorithm,
- use_cutlass=False)
- return cls(
- original_tensor.shape,
- packed=packed,
- meta=meta,
- packed_t=packed_t,
- meta_t=meta_t,
- compressed_swizzled_bitmask=compressed_swizzled_bitmask,
- requires_grad=False,
- )
- def _mm(
- self,
- B: torch.Tensor,
- *,
- bias: Optional[torch.Tensor] = None,
- **kwargs
- ) -> torch.Tensor:
- if isinstance(B, SparseSemiStructuredTensor):
- raise ValueError(
- "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
- )
- if self.ndim != 2 or B.ndim != 2:
- raise NotImplementedError(
- f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
- )
- if B.dtype != self.dtype:
- raise NotImplementedError(
- f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
- f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
- "This operation is only supported when A and B have the same data type."
- )
- if bias is not None and bias.dtype != self.dtype:
- raise NotImplementedError(
- f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
- "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
- "This operation is only supported when A, B and C have the same data type."
- )
- if self.packed is None:
- raise NotImplementedError(
- f"`{self.__class__.__name__}` matmul: operation is not supported"
- )
- else:
- res = torch._cslt_sparse_mm(
- self.packed,
- B,
- bias=bias,
- transpose_result=self.fuse_transpose_cusparselt,
- alg_id=self.alg_id_cusparselt,
- )
- return res.t() if self.fuse_transpose_cusparselt else res
|