# mypy: allow-untyped-defs
import contextlib

import torch


__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]
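
# The handlers below service __torch_dispatch__ calls for
# torch.sparse.SparseSemiStructuredTensor. Each takes the standard
# (func, types, args, kwargs) signature; the table that maps aten ops to these
# handlers is assumed to live on the tensor subclass itself, not in this file.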
@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard
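
# Re-runs `func` with torch dispatch disabled so the call hits the default
# dense kernels instead of re-entering the subclass handlers. Note that
# `kwargs` is not forwarded.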
def fallback_dispatcher(func, types, args, kwargs):
    with no_dispatch():
        return func(*args)
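
# aten.values: returns the kept (specified) values. When `meta is None` the
# values and metadata are packed into a single buffer and the first m*k//2
# elements are the values; otherwise `packed` already holds just the values.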
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements].view(m, -1)
    else:
        return A.packed.detach()
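
# aten.indices: returns the 2:4 sparsity metadata. With a fused buffer
# (`meta is None`) the metadata follows the first m*k//2 packed values and is
# reinterpreted as int16 (or int32 for int32 inputs); otherwise the separate
# `meta` tensor is returned directly.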
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta
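
# aten.t: swaps in the pre-packed transposed buffers instead of re-packing.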
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we cannot go from the compressed representation back to the dense
    # representation currently, we just keep track of how many times we have
    # been transposed. Depending on whether the sparse matrix is the first or
    # second argument, we expect an even / odd number of calls to transpose
    # respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=(
            self.compressed_swizzled_bitmask.transpose(0, 1)
            if self.compressed_swizzled_bitmask is not None
            else None
        ),
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
    )
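
# aten.view: only the no-op reshape to the existing shape is supported.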
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self
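
# aten.detach: rewraps the same compressed buffers with requires_grad=False.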
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        requires_grad=False,
    )
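
# aten.mm: if the sparse operand is on the left, pad the dense right operand to
# the kernel's minimum supported size and slice the result back; if it is on
# the right, use (A @ B) = (B.t() @ A.t()).t() so the sparse operand is always
# the left argument of the underlying _mm call.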
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]
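
# aten.addmm: same transpose trick as semi_sparse_mm, restricted to a dense A,
# a sparse B, and a 1-D bias that is passed through to the underlying _mm call.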
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]
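
# aten.linear: flattens the input to 2-D, dispatches to the mm/addmm paths
# above, and restores the leading batch dimensions on the way out.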
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)
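
# Usage sketch (not executed here): these handlers are exercised indirectly
# through ordinary tensor ops once a tensor has been converted to the
# semi-structured format. Assumes a CUDA device with 2:4 sparse kernel support;
# the 0011 tile below is just one valid 2:4 pattern.
#
#     from torch.sparse import to_sparse_semi_structured
#
#     A = torch.tensor([[0, 0, 1, 1]], dtype=torch.float16, device="cuda").tile((128, 32))
#     A_sparse = to_sparse_semi_structured(A)   # compressed 2:4 representation
#     B = torch.randn(128, 128, dtype=torch.float16, device="cuda")
#
#     out = torch.mm(A_sparse, B)               # expected to route to semi_sparse_mm
#     vals = A_sparse.values()                  # semi_sparse_values
#     meta = A_sparse.indices()                 # semi_sparse_indices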