_semi_structured_ops.py

# mypy: allow-untyped-defs
import contextlib

import torch


__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]


@contextlib.contextmanager
def no_dispatch():
    # Temporarily disable __torch_dispatch__ so the wrapped call runs on the
    # default kernels instead of re-entering the subclass handlers.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    # Run the op on the raw arguments with dispatch disabled.
    with no_dispatch():
        return func(*args)


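def _example_dispatch_table():
    # Illustrative sketch, not part of the upstream module: these handlers are
    # meant to be looked up from a table keyed by aten overloads inside the
    # tensor subclass's __torch_dispatch__. The mapping below is a hypothetical
    # example of such a table (the helper name `_example_dispatch_table` is
    # ours); the actual registration lives on SparseSemiStructuredTensor, not
    # in this file.
    return {
        torch.ops.aten.values.default: semi_sparse_values,
        torch.ops.aten.indices.default: semi_sparse_indices,
        torch.ops.aten.t.default: semi_sparse_t,
        torch.ops.aten.view.default: semi_sparse_view,
        torch.ops.aten.detach.default: semi_sparse_detach,
        torch.ops.aten.mm.default: semi_sparse_mm,
        torch.ops.aten.addmm.default: semi_sparse_addmm,
        torch.ops.aten.linear.default: semi_sparse_linear,
    }

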
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # Values and metadata share a single packed buffer; the first
        # m * k // 2 entries are the kept (specified) values.
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements].view(m, -1)
    else:
        # Metadata is stored separately, so `packed` holds only the values.
        return A.packed.detach()


def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # The metadata follows the values inside the packed buffer; reinterpret
        # its raw bytes as the metadata element type.
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta


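def _example_values_indices():
    # Hypothetical demo (not part of the upstream module): inspect the packed
    # values and metadata of a 2:4-sparse matrix. Assumes a CUDA GPU with
    # semi-structured sparsity support and fp16 inputs; adjust as needed.
    dense = torch.tensor(
        [[0, 0, 1, 1]] * 64, dtype=torch.float16, device="cuda"
    ).tile(1, 16)  # (64, 64), every group of 4 has exactly 2 nonzeros
    sparse = torch.sparse.to_sparse_semi_structured(dense)
    # For an (m, k) matrix, the kept values form an (m, k // 2) tensor.
    print(sparse.values().shape)   # expected: torch.Size([64, 32])
    print(sparse.indices().shape)  # metadata layout depends on the backend

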
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we cannot go from the compressed representation back to the dense
    # representation currently, we just keep track of how many times we have been
    # transposed. Depending on whether the sparse matrix is the first or second
    # argument, we expect an even / odd number of calls to transpose respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
        if self.compressed_swizzled_bitmask is not None
        else None,
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
    )


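def _example_transpose_roundtrip():
    # Hypothetical demo: `t()` on the subclass only swaps the packed/packed_t
    # buffers (see semi_sparse_t above), so transposing never decompresses the
    # tensor. The helper name is ours; assumes the same CUDA / fp16 setup as
    # the demo above.
    dense = torch.tensor(
        [[0, 0, 1, 1]] * 64, dtype=torch.float16, device="cuda"
    ).tile(1, 32)  # (64, 128)
    sparse = torch.sparse.to_sparse_semi_structured(dense)
    assert sparse.t().shape == torch.Size([128, 64])
    assert sparse.t().t().shape == sparse.shape

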
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        # `view` can only be a no-op here: the compressed layout cannot be
        # reinterpreted with a different logical shape.
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self


def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    # Rebuild the wrapper around the same compressed buffers, without autograd tracking.
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        requires_grad=False,
    )


def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        # sparse @ dense: pad the dense operand to the kernel's alignment,
        # multiply, then slice the result back to the original width.
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        # dense @ sparse: the kernels expect the sparse operand on the left,
        # so compute (B^T @ A^T)^T and slice back to the original height.
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]


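def _example_mm_padding():
    # Hypothetical demo: torch.mm with a dense right-hand side whose width is
    # not aligned to the kernel's requirements. semi_sparse_mm pads internally
    # and slices the result back, so the caller sees the logical (64, 20)
    # output. The helper name is ours; assumes the same CUDA / fp16 setup as
    # the demos above.
    dense = torch.tensor(
        [[0, 0, 1, 1]] * 64, dtype=torch.float16, device="cuda"
    ).tile(1, 16)  # (64, 64)
    sparse = torch.sparse.to_sparse_semi_structured(dense)
    B = torch.randn(64, 20, dtype=torch.float16, device="cuda")
    out = torch.mm(sparse, B)
    assert out.shape == (64, 20)

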
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    # Compute (B^T @ A_padded^T + bias)^T so the sparse operand sits on the left
    # of the underlying kernel, then slice away the rows introduced by padding.
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]


def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    # Flatten any leading batch dimensions, run the 2-D matmul (via addmm when a
    # bias is present), then restore the original leading shape.
    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)


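def _example_linear():
    # Hypothetical end-to-end demo (not part of the upstream module): replace a
    # 2:4-sparse fp16 weight with its semi-structured form so that
    # torch.nn.functional.linear is routed through semi_sparse_linear. The
    # helper name is ours; assumes a CUDA GPU with semi-structured sparsity
    # support.
    import torch.nn.functional as F

    weight = torch.tensor(
        [[0, 0, 1, 1]] * 128, dtype=torch.float16, device="cuda"
    ).tile(1, 32)  # (128, 128), valid 2:4 pattern
    bias = torch.randn(128, dtype=torch.float16, device="cuda")
    x = torch.randn(8, 64, 128, dtype=torch.float16, device="cuda")

    sparse_weight = torch.sparse.to_sparse_semi_structured(weight)
    dense_out = F.linear(x, weight, bias)
    sparse_out = F.linear(x, sparse_weight, bias)  # handled by semi_sparse_linear
    assert sparse_out.shape == dense_out.shape
    # The two results should agree up to fp16 rounding/accumulation differences.
    print((dense_out - sparse_out).abs().max())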