semi_structured.py

# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict

import torch
from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
    fallback_dispatcher,
    semi_sparse_values,
    semi_sparse_indices,
    semi_sparse_detach,
    semi_sparse_t,
    semi_sparse_view,
    semi_sparse_mm,
    semi_sparse_addmm,
    semi_sparse_linear,
)

__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
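
# Each config stores, per dtype, the minimum size (and required multiple) of the rows/cols
# of the sparse operand, and the corresponding minimums for the dense operand it is
# multiplied against (see `_validate_device_dim_dtype_shape` and `_pad_dense_input`).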


class SparseSemiStructuredTensor(torch.Tensor):
    """
    This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero,
    depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    There are two backends available for semi-structured sparsity: cuSPARSELt and CUTLASS.
    This class is meant to serve as a base class for both implementations. SparseSemiStructuredTensorCUTLASS
    and SparseSemiStructuredTensorCUSPARSELT both inherit from this class and define three backend-specific items.
    Note that as such, this class cannot be instantiated directly.

    - `_DTYPE_SHAPE_CONSTRAINTS` - a dictionary holding backend-specific dense/sparse min shape constraints
    - `def from_dense()` - backend-specific compression routines
    - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
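
    A minimal usage sketch (assumes a CUDA device and a supported dtype/shape;
    `to_sparse_semi_structured`, defined later in this module, picks the concrete subclass):

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        >>> A_sparse = to_sparse_semi_structured(A)
        >>> out = torch.mm(A_sparse, torch.randn(128, 128, dtype=torch.float16, device="cuda"))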
  44. """

    _DEFAULT_ALG_ID: int = 0
    _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
    _FORCE_CUTLASS: bool = True
    _FUSE_TRANSPOSE: bool = False
    _PROTOTYPE_WARNING_SHOWN: bool = False

    BACKEND: str
    SPARSE_DISPATCH: Dict[Callable, Callable]

    packed: Optional[torch.Tensor]
    meta: Optional[torch.Tensor]
    packed_t: Optional[torch.Tensor]
    meta_t: Optional[torch.Tensor]
    compressed_swizzled_bitmask: Optional[torch.Tensor]
    fuse_transpose_cusparselt: bool
    alg_id_cusparselt: int
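
    # `__tensor_flatten__` below filters this list to find which of the inner
    # tensors are actually populated for the chosen backend.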
    __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        packed: Optional[torch.Tensor],
        meta: Optional[torch.Tensor],
        packed_t: Optional[torch.Tensor],
        meta_t: Optional[torch.Tensor],
        compressed_swizzled_bitmask: Optional[torch.Tensor],
        fuse_transpose_cusparselt: bool = False,
        alg_id_cusparselt: int = 0,
        requires_grad: bool = False,
    ):
        """
        Create a new instance of the tensor subclass from the compressed sparse representation.

        We have the option to create the subclass with the compressed representations of both X and X', for training.
        For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.

        Depending on the backend selected (cuSPARSELt vs CUTLASS), certain fields will be set to None.

        Args:
            shape: The shape of the original dense tensor
            packed: The compressed representation of the original dense tensor
            meta: The metadata of the original dense tensor, if it is stored separately
            packed_t: The compressed representation of the transposed original dense tensor
            meta_t: The metadata of the transposed original dense tensor, if it is stored separately
            compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
                                         participate in the computation. Used for pointwise ops.
            fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
                                       with a matmul, which is useful in the case of 2:4 sparse training.
            alg_id_cusparselt: The algorithm id to use when using cuSPARSELt; this affects matmul performance.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If all of the tensor arguments are None.
        """
        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

            # Because this only runs once, we also load the dispatch table here.
            # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead.
            # But this is useful since it allows users to overload the dispatch table for debugging / testing.
            cls._load_dispatch_table()

            # we can also register the classes with dynamo when the warning is shown.
            torch._dynamo.allow_in_graph(cls)

        if packed is not None:
            previous_tensor = packed
        elif packed_t is not None:
            previous_tensor = packed_t
        else:
            raise ValueError("At least one of packed or packed_t must be provided")

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

        tensor.packed = packed
        tensor.meta = meta
        tensor.packed_t = packed_t
        tensor.meta_t = meta_t
        tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
        tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
        tensor.alg_id_cusparselt = alg_id_cusparselt
        return tensor

    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"
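
    # `__tensor_flatten__` / `__tensor_unflatten__` implement the wrapper-subclass
    # (de)composition protocol: they split this subclass into its non-None inner tensors
    # plus metadata and rebuild it, which lets torch.compile and other subclass-aware
    # machinery trace through it.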
    def __tensor_flatten__(
        self,
    ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (
            self.shape,
            self.fuse_transpose_cusparselt,
            self.alg_id_cusparselt,
            self.requires_grad,
        )
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool, int, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
        return cls(
            shape=shape,
            packed=inner_tensors.get("packed", None),
            meta=inner_tensors.get("meta", None),
            packed_t=inner_tensors.get("packed_t", None),
            meta_t=inner_tensors.get("meta_t", None),
            compressed_swizzled_bitmask=inner_tensors.get(
                "compressed_swizzled_bitmask", None
            ),
            fuse_transpose_cusparselt=fuse_transpose_cusparselt,
            alg_id_cusparselt=alg_id_cusparselt,
            requires_grad=requires_grad,
        )
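
    # Disable the torch-function override path so that every op on this subclass is routed
    # through `__torch_dispatch__` (and therefore through SPARSE_DISPATCH) instead.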
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        if func._overloadpacket not in cls.SPARSE_DISPATCH:
            raise NotImplementedError(
                f"{cls.__name__} only supports a specific set of operations, "
                f"can't perform requested op ({func.__name__})"
            )
        return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)

    @classmethod
    def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
        """
        Loads the op overload sparse dispatch table for the current class.
        """
        if getattr(cls, "SPARSE_DISPATCH", None) is None:
            cls.SPARSE_DISPATCH = {
                torch.ops.aten.values: semi_sparse_values,
                torch.ops.aten.indices: semi_sparse_indices,
                torch.ops.aten.is_same_size: fallback_dispatcher,
                torch.ops.aten.detach_: fallback_dispatcher,
                torch.ops.aten.detach: semi_sparse_detach,
                torch.ops.aten.t: semi_sparse_t,
                torch.ops.aten.view: semi_sparse_view,
                torch.ops.aten.mm: semi_sparse_mm,
                torch.ops.aten.matmul: semi_sparse_mm,
                torch.ops.aten.addmm: semi_sparse_addmm,
                torch.ops.aten.linear: semi_sparse_linear,
                torch.ops.aten._to_copy: fallback_dispatcher,
            }
            if custom_dispatch_table is not None:
                cls.SPARSE_DISPATCH.update(custom_dispatch_table)

    @classmethod
    def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
        """
        Assert that the given tensor is valid for semi-structured sparse compression.
        """
        # check device
        if not original_tensor.is_cuda:
            raise RuntimeError(
                f"Error original_tensor.device={original_tensor.device} is not supported! "
                "Only CUDA tensors are currently supported."
            )

        # check dim
        if original_tensor.dim() != 2:
            raise RuntimeError(
                f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                "Only 2d tensors are currently supported."
            )

        # check contiguous
        if not original_tensor.is_contiguous():
            raise RuntimeError(
                "Error original_tensor is not contiguous! "
                "Only contiguous tensors are currently supported."
            )

        # check dtype
        if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
            raise RuntimeError(
                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                f"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
            )

        # check shape
        m, n = original_tensor.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
        if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
            # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
            raise RuntimeError(
                f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                f"Both dimensions must be at least, and a multiple of, ({min_rows}, {min_cols})"
            )

    @classmethod
    def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
        """
        Calculates padding for dense tensor and pads tensor if necessary.
        If padding is not required, this function returns the original tensor.
        """
        # only 2d matmul
        assert dense_input.dim() == 2

        # check shape
        m, n = dense_input.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols

        # calculate padding
        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
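        # e.g. with m = 62 and min_rows = 32: 62 is not a multiple of 32, so
        # to_pad_m = -62 % 32 = 2 and the input is padded up to 64 rows.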
        if to_pad_m or to_pad_n:
            return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
        else:
            return dense_input

    def to_dense(self):
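        # Multiplying by an identity matrix routes through the sparse matmul kernel,
        # which materializes the dense equivalent of this tensor.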
        col = self.shape[-1]
        return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))

    @classmethod
    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
        raise NotImplementedError

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        raise NotImplementedError


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
    `_DTYPE_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): deprecated arg to be removed in another release. Do not use.

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor

    Raises:
        None

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                ...,
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
        >>> A_sparse.values()
        tensor([[1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                ...,
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse.indices()
        tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)
    """
    if transposed:
        warnings.warn(
            "Setting transpose from `to_sparse_semi_structured` is deprecated "
            "and will be removed in a future release. "
            "`SparseSemiStructuredTensor` only supports contiguous input tensors.",
            FutureWarning,
            stacklevel=2,
        )

    # set from _FORCE_CUTLASS flag
    SPARSE_SUBCLASS = (
        torch.sparse.SparseSemiStructuredTensorCUTLASS
        if SparseSemiStructuredTensor._FORCE_CUTLASS
        else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
    )

    return SPARSE_SUBCLASS.from_dense(original_tensor)


class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
    """
    This class implements semi-structured sparsity for the CUTLASS backend.

    In this implementation, the specified elements and metadata are stored separately,
    in packed and meta respectively.

    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into
    _sparse_semi_structured_(mm|addmm) and sparse_semi_structured_from_dense_cutlass for
    conversion to the compressed format.
    """

    BACKEND = "cutlass"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUTLASS":
        cls._validate_device_dim_dtype_shape(original_tensor)
        (
            sparse_tensor_cutlass,
            meta_tensor_cutlass,
        ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
        return cls(
            original_tensor.shape,
            packed=sparse_tensor_cutlass,
            meta=meta_tensor_cutlass,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            requires_grad=original_tensor.requires_grad,
        )

    def to_dense(self):
        assert self.meta is not None and self.packed is not None
        return (
            sparse_semi_structured_to_dense_cutlass(
                self.packed,
                self.meta,
            )
            if self.meta.ndim == 2
            else super().to_dense()
        )

    @classmethod
    def prune_dense_static_sort(
        cls, original_tensor: torch.Tensor, algorithm=""
    ) -> "SparseSemiStructuredTensor":
        """
        This function takes in an unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.

        It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
        The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.

        Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
        It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
        pruned dense tensor.
        Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.

        Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern.
        This can be used in the backward pass to mask the gradients.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4]  -> prune 4x4 tile -> [8 0 0 4]  -> pack to CUTLASS semi-structured   -> packed
        [1 2 6 2]                       [0 0 6 2]                                       -> metadata

                                                   -> pack to transposed CUTLASS        -> packed_t
                                                      semi-structured representation    -> metadata_t

                                                   -> compute swizzled bitmask          -> compressed_swizzled_bitmask

        The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUTLASS
        from torch.sparse._semi_structured_conversions import (
            _sparse_semi_structured_tile,
            _compute_compressed_swizzled_bitmask,
            sparse_semi_structured_from_dense_cutlass,
        )

        pruned = _sparse_semi_structured_tile(dense)
        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
        ```
        """
        # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
        (
            packed,
            meta,
            packed_t,
            meta_t,
            compressed_swizzled_bitmask,
        ) = torch._sparse_semi_structured_tile(
            original_tensor, algorithm=algorithm, use_cutlass=True
        )

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        cls_name = self.__class__.__name__
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{cls_name}` matmul: Broadcasting is not implemented"
            )
        if self.packed is None or self.meta is None:
            raise NotImplementedError(
                f"`{cls_name}` matmul: operation is not supported"
            )
        else:
            if bias is None:
                res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
            else:
                res = torch._sparse_semi_structured_addmm(
                    bias, self.packed, self.meta, B
                )
            return res[: self.shape[0]]


class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """
    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:

    packed = [ specified elements of original tensor | metadata ]

    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
    attributes respectively.

    cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
    as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
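
    A minimal sketch for selecting this backend (assumes a CUDA build with cuSPARSELt support
    and an input that already satisfies the 2:4 pattern and shape constraints):

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> SparseSemiStructuredTensor._FORCE_CUTLASS = False
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        >>> A_sparse = to_sparse_semi_structured(A)  # returns a SparseSemiStructuredTensorCUSPARSELT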
  455. """

    BACKEND = "cusparselt"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUSPARSELT":
        cls._validate_device_dim_dtype_shape(original_tensor)
        return cls(
            shape=original_tensor.shape,
            packed=torch._cslt_compress(original_tensor),
            meta=None,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
            alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
            requires_grad=original_tensor.requires_grad,
        )

    @classmethod
    def prune_dense_static_sort(
        cls, original_tensor: torch.Tensor, algorithm=""
    ) -> "SparseSemiStructuredTensor":
        """
        This function does the same thing as described in SparseSemiStructuredTensorCUTLASS, but uses the cuSPARSELt
        metadata layout and sparse matmul.

        The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4]  -> prune 4x4 tile -> [8 0 0 4]  -> pack to cuSPARSELt semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]

                                                   -> pack to transposed cuSPARSELt      -> packed_t
                                                      semi-structured representation

                                                   -> compute swizzled bitmask           -> compressed_swizzled_bitmask

        The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cusparselt, None, packed_t_cusparselt, None, bitmask)
        ```
        """
        (
            packed,
            meta,
            packed_t,
            meta_t,
            compressed_swizzled_bitmask,
        ) = torch._sparse_semi_structured_tile(
            original_tensor, algorithm=algorithm, use_cutlass=False
        )

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
            )
        if B.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
                f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
                "This operation is only supported when A and B have the same data type."
            )
        if bias is not None and bias.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
                f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. "
                "This operation is only supported when A, B and C have the same data type."
            )
        if self.packed is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )
        else:
            res = torch._cslt_sparse_mm(
                self.packed,
                B,
                bias=bias,
                transpose_result=self.fuse_transpose_cusparselt,
                alg_id=self.alg_id_cusparselt,
            )
            return res.t() if self.fuse_transpose_cusparselt else res