# mypy: allow-untyped-defs
import torch


def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
    """
    This is a PyTorch implementation of the main part of the reorder_meta()
    function from the tools/util/include/cutlass/util/host_reorder.h file
    of the CUTLASS source tree. Furthermore, the CUTLASS template for sparse
    GEMM decides upon the layout of this matrix, and at the moment, for
    sparse GEMM executed on tensor cores, this is the layout described by
    the ColumnMajorInterleaved<2> data structure, in
    include/cutlass/layout/matrix.h of the CUTLASS source tree. The
    reordering of the meta matrix into the meta_reordered matrix, calculated
    according to these segments of CUTLASS code, is re-implemented here.
    Note that this calculation produces offsets for scattering metadata
    matrix elements into reordered metadata matrix elements (or,
    equivalently, for gathering reordered metadata matrix elements back
    into metadata matrix elements).
    """
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group = 32 if meta_dtype.itemsize == 2 else 16
    interweave = 4 if meta_dtype.itemsize == 2 else 2
    dst_rows = (
        dst_rows // group * group
        + (dst_rows % 8) * interweave
        + (dst_rows % group) // 8
    )

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # It is assumed that the meta tensor is to be stored in the CUTLASS
    # InterleavedColumnMajor layout, and the corresponding code that stores
    # values into this tensor was reverse-engineered here.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
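

# Illustrative sketch of the scatter/gather duality noted in the docstring
# above: scattering with these offsets and then gathering with the same
# offsets recovers the original metadata. The m=32, meta_ncols=4, int16
# configuration (and the helper name itself) is an assumption chosen only
# for this demonstration.
def _example_meta_reordering_roundtrip():
    m, meta_ncols, meta_dtype = 32, 4, torch.int16
    meta = torch.arange(m * meta_ncols, dtype=meta_dtype).view(m, meta_ncols)
    offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, meta.device
    )
    # Scatter into the reordered layout, then gather back with the same offsets.
    reordered = meta.new_empty((m * meta_ncols,))
    reordered.scatter_(0, offsets, meta.view(-1))
    recovered = torch.gather(reordered, 0, offsets).view(m, meta_ncols)
    assert torch.equal(recovered, meta)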


def sparse_semi_structured_from_dense_cutlass(dense):
    """
    This function converts a dense matrix into its sparse semi-structured
    representation, producing a "compressed" matrix in the layout used by
    the CUTLASS backend, together with the corresponding metadata matrix.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Quadruples of True/False values are encoded as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the True
    # value at the lowest index in the quadruple, and the higher two bits
    # are the index of the other True value in the quadruple. In case there
    # are fewer than two True values, the False value or values at some
    # index or indices are considered True for the encoding. In case there
    # are more than two True values, the excess True value(s) at some
    # indices are considered False for the encoding. The exact encodings
    # used for these cases are as follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of the Espresso
    # logic minimizer software, in order to minimize the corresponding
    # Boolean functions that translate non-zero flags into encoding bits.
    # Note also that the possible choices for the first and the last of
    # these encodings were limited to (0b0100, 0b1110) only, in order to
    # produce valid encodings for the 1:2 sparsity case.
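    #
    # As a worked instance of the table above: for the quadruple
    # [True, False, True, False], expr0, expr1 and expr2 below are all
    # False, so bit0 = bit1 = False and idxs0 = 0b00, while bit2 = False
    # and bit3 = ~m1 = True give idxs1 = 0b10; packing them as
    # idxs0 | (idxs1 << 2) yields 0b1000, matching the second row of the
    # table.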
  118. expr0 = m0 & m1
  119. expr1 = ~m0 & m1
  120. expr2 = ~m0 & ~m1
  121. bit0 = expr1
  122. bit1 = expr2
  123. bit2 = expr0 | expr2 | m3
  124. bit3 = expr1 | ~m1
  125. idxs0 = bit0 | (bit1.to(torch.int64) << 1)
  126. idxs1 = bit2 | (bit3.to(torch.int64) << 1)
  127. if dense.dtype != torch.float:
  128. sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
  129. sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
  130. sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
  131. else:
  132. sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
  133. meta_4 = idxs0 | (idxs1 << 2)
  134. meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
  135. if quadbits_per_meta_elem == 4:
  136. meta = (
  137. meta_n[:, :, 0]
  138. | (meta_n[:, :, 1] << 4)
  139. | (meta_n[:, :, 2] << 8)
  140. | (meta_n[:, :, 3] << 12)
  141. )
  142. elif quadbits_per_meta_elem == 8:
  143. meta = (
  144. meta_n[:, :, 0]
  145. | (meta_n[:, :, 1] << 4)
  146. | (meta_n[:, :, 2] << 8)
  147. | (meta_n[:, :, 3] << 12)
  148. | (meta_n[:, :, 4] << 16)
  149. | (meta_n[:, :, 5] << 20)
  150. | (meta_n[:, :, 6] << 24)
  151. | (meta_n[:, :, 7] << 28)
  152. )
  153. # Reorder meta tensor elements.
  154. meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
  155. meta_offsets = _calculate_meta_reordering_scatter_offsets(
  156. m, meta_ncols, meta_dtype, device
  157. )
  158. meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
  159. return (sparse, meta_reordered.view(m, meta_ncols))
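

# Illustrative sketch of the shapes produced by the conversion above. The
# 64x128 fp16 input (and the helper name) is an assumption chosen only for
# this demonstration: meta_dtype is then torch.int16, quadbits_per_meta_elem
# is 4, the compressed matrix keeps half of the columns, and the metadata
# matrix has 128 // (4 * 4) = 8 columns.
def _example_from_dense_cutlass_shapes():
    dense = torch.zeros(64, 128, dtype=torch.half)
    sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    assert sparse.shape == (64, 64)
    assert meta.shape == (64, 8) and meta.dtype == torch.int16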


def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """
    This function performs the reverse of the function above: it
    reconstructs the dense matrix from a pair consisting of the
    "compressed" matrix, given in the layout used by the CUTLASS backend,
    and the accompanying metadata matrix.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    if sparse.dtype != torch.float:
        ksparse = 4
    else:
        ksparse = 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}"
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} is different from {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
            "as expected according to the number of columns of meta matrix"
        )

    # Undo the reordering of meta tensor elements.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack the sparse tensor back to the original dense tensor, using
    # the information provided by the meta tensor. Note that the
    # torch.float datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as the metadata for a pair of torch.float
    # values is encoded as if the underlying 8 bytes contained four
    # torch.half/torch.bfloat16 values, where either the first two or the
    # last two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        dense.scatter_(0, dense_offsets, sparse.view(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
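

# Illustrative round-trip sketch of the two conversions above. The 64x128 fp16
# shape, the fixed [x, x, 0, 0] pattern, and the helper name are assumptions
# made only for this example; the round trip is exact only when every group of
# four consecutive values in a row contains at most two non-zeros.
def _example_cutlass_compression_roundtrip():
    mask = torch.tensor([True, True, False, False]).repeat(128 // 4)
    dense = torch.randn(64, 128).half() * mask
    sparse, meta_reordered = sparse_semi_structured_from_dense_cutlass(dense)
    reconstructed = sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered)
    assert torch.equal(reconstructed, dense)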


def _sparse_semi_structured_tile(dense):
    """
    This function computes a 2:4 sparse tile by greedily taking the largest values.

    Since we take the largest values greedily, how the sorting algorithm handles
    duplicates affects the ultimate sparsity pattern.

    Note that this function does not have the same sorting semantics as our CUDA
    backend, which is exposed via `torch._sparse_semi_structured_tile` and thus
    returns a different pattern.
    """

    def greedy_prune_tile(tile):
        num_kept_row = [0, 0, 0, 0]
        num_kept_col = [0, 0, 0, 0]

        # Walk the tile entries from largest to smallest; keep an entry only if
        # its row and column both still have fewer than two kept values.
        for x in tile.flatten().sort(descending=True, stable=True).indices:
            r, c = x // 4, x % 4
            if num_kept_row[r] < 2 and num_kept_col[c] < 2:
                num_kept_row[r] += 1
                num_kept_col[c] += 1
            else:
                tile[r, c] = 0

    # Prune each non-overlapping 4x4 tile in place.
    for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
        for tile in batch:
            greedy_prune_tile(tile)

    return dense
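

# Illustrative sketch of the greedy tiling above. The 8x8 fp32 block (and the
# helper name) is an assumption chosen only for this demonstration; after
# pruning, each 4x4 tile holds at most two non-zeros per row and per column.
def _example_greedy_tile_pruning():
    block = torch.randn(8, 8)
    _sparse_semi_structured_tile(block)
    tiles = block.unfold(0, 4, 4).unfold(1, 4, 4)
    assert ((tiles != 0).sum(dim=-1) <= 2).all()  # non-zeros per row of each tile
    assert ((tiles != 0).sum(dim=-2) <= 2).all()  # non-zeros per column of each tile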


def _compute_compressed_swizzled_bitmask(dense):
    """
    Calculates the compressed swizzled bitmask from a dense tensor
    """
    # First we need to convert the dense tensor to a bitmask.
    int_bitmask = dense.bool().to(torch.uint8)

    # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    # A, B, C and D, as displayed in the following schema:
    # +---+---+
    # | A | B |
    # +---+---+
    # | C | D |
    # +---+---+

    # We first need to split into the 8x8 tiles.
    bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)

    # Then we unfold again to get our individual 4x4 tiles.
    bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)

    # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity
    # pattern of that tile. Note that the least significant bit is stored first.
    # [1 1 0 0]
    # [1 1 0 0] -> 0011 0011 ->  51
    # [0 0 1 1]    1100 1100 -> 204
    # [0 0 1 1]

    # Reshape the tensor to expand the tiles into 8-bit vectors.
    bitmask_binary_representation = bitmask_4x4_chunks.reshape(
        *bitmask_4x4_chunks.shape[:2], 4, 2, 8
    )

    # To convert from the binary representation, we can do a matmul with powers
    # of two, allocated on the input's device.
    powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device=dense.device)
    # To run on GPU: cast to float to do the matmul and then cast back.
    compressed_swizzled_bitmask = (
        bitmask_binary_representation.to(torch.float) @ powers_of_two
    ).to(torch.uint8)

    return compressed_swizzled_bitmask
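

# Illustrative sketch of the bitmask computation above, using the pattern drawn
# in the comments: a single 8x8 thread tile built from four identical 4x4 tiles
# compresses to the pair (51, 204) for each tile. The small CPU input and the
# helper name are assumptions for this demonstration, and the sketch relies on
# the powers-of-two tensor being allocated on the input's device.
def _example_compressed_swizzled_bitmask():
    tile = torch.tensor(
        [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]],
        dtype=torch.float,
    )
    dense = tile.repeat(2, 2)  # one 8x8 thread tile made of four 4x4 tiles
    bitmask = _compute_compressed_swizzled_bitmask(dense)
    assert bitmask.shape == (1, 1, 4, 2)
    assert (bitmask == torch.tensor([51, 204], dtype=torch.uint8)).all()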