mm.py

# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional

import torch
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate
from ..codegen.wrapper import WrapperCodeGen
from ..ir import FlexibleLayout
from ..lowering import register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    NoValidChoicesError,
    TritonTemplate,
)
from ..utils import (
    use_aten_gemm_kernels,
    use_cpp_packed_gemm_template,
    use_cutlass_template,
    use_max_autotune,
    use_triton_template,
)
from .mm_common import (
    addmm_epilogue,
    int8_mm_configs,
    mixed_mm_configs,
    mm_args,
    mm_configs,
    mm_grid,
    mm_options,
)

log = logging.getLogger(__name__)
aten = torch.ops.aten

mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
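
# Worked example of the "re-order program ID" swizzle in the template above
# (illustrative numbers, not tied to any particular config): with grid_m = 4,
# grid_n = 4 and GROUP_M = 2, width = 8 and pids 0..7 map to
#   (pid_m, pid_n): (0, 0) (1, 0) (0, 1) (1, 1) (0, 2) (1, 2) (0, 3) (1, 3)
# i.e. consecutive programs sweep the N dimension within a group of GROUP_M row
# blocks, so the group's A tiles stay resident in L2 while B tiles are streamed.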

aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")


def _is_int8_mat(mat):
    return mat.get_dtype() in (torch.int8, torch.uint8)


def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
    """
    Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
    kernel under the hood. There are a few shapes where this is slower,
    but they are rare.
    """
    if inp.stride(0) == 0 or inp.size(0) == 1:
        return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
    return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)


aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
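
# Hedged usage sketch of the fast path above (hypothetical tensors, shown only
# for illustration): a bias that was broadcast to 2D, e.g.
#   inp = torch.randn(16, device="cuda").expand(8, 16)   # stride(0) == 0
# takes the first branch, and bias_addmm(inp, mat1, mat2) behaves like
# torch.addmm(inp[0], mat1, mat2), which lets cuBLASLt fuse the bias epilogue.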


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    aten_layout = layout
    if not use_max_autotune():
        aten_layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASSGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [mat1, mat2],
        )

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    try:
        return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()


def _is_static_problem(inputs_tensors, layout):
    # checks whether all input tensors and the output layout
    # have a static shape by attempting to convert the dimensions
    # to int
    static_shape = True
    static_size = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size)
    if static_size is None:
        nonzero = True
        for s in layout.size:
            sz = WrapperCodeGen.statically_known_int_or_none(s)
            if sz is not None and sz == 0:
                nonzero = False
                break
        return False, nonzero
    numel = 1
    for dim in static_size:
        numel *= dim
    nonzero = numel > 0
    return static_shape, nonzero
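
# For illustration (not executed here): a layout with fully static sizes such as
# [8, 16] yields (True, True); a symbolic size such as [s0, 16] yields
# (False, nonzero), where nonzero is only False if some dimension is statically
# known to be zero.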


@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

    choices = (
        [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    )

    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []

    if use_cutlass:
        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
    if is_nonzero and use_triton_template(layout, enable_int32=True):
        for config in int8_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if len(choices) == 0:
        log.warning(
            "No choices for integer GEMM available using configured backends, using ATen backend as fallback"
        )
        choices = [aten__int_mm.bind((mat1, mat2), layout)]
    try:
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        choices = [aten__int_mm.bind((mat1, mat2), layout)]
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)


@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout)
    if (not is_nonzero) or (not use_max_autotune()):
        # Use a FlexibleLayout if we are not autotuning.
        # This allows padding strides for the output.
        from torch._inductor.ir import FixedLayout, FlexibleLayout

        if isinstance(layout, FixedLayout):
            layout = FlexibleLayout(
                device=layout.device, dtype=layout.dtype, size=layout.size
            )
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        # Filter out a known cause of CUDA illegal memory access errors:
        # broadcasting on the last dim of the bias term seems not to be working
        # in the linear GEMM epilogue used by addmm.
        if (
            WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1])
            != 0
        ):
            CUTLASSGemmTemplate.add_cutlass_gemm_choices(
                choices,
                layout,
                [mat1, mat2, inp_expanded],
                alpha=alpha,
                beta=beta,
            )

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [inp_expanded, mat1, mat2],
            alpha=alpha,
            beta=beta,
        )

    add_aten_fallback = False
    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        add_aten_fallback = True

    if add_aten_fallback:
        choices.append(
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                ordered_kwargs_for_cpp_kernel,
                alpha=alpha,
                beta=beta,
            )
        )

        if (
            inp_expanded.get_stride()[0] == 0
            and inp_expanded.get_device().type == "cuda"
            and inductor_config.triton.autotune_cublasLt
        ):
            # unexpand inp to make sure fused addmm from cublasLt is used
            choices.insert(
                0,
                aten_bias_addmm.bind(
                    (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
                ),
            )

    try:
        return autotune_select_algorithm(
            "addmm", choices, [inp_expanded, mat1, mat2], layout
        )
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        fallback_choice = aten_addmm.bind(
            (inp, mat1, mat2),
            layout,
            ordered_kwargs_for_cpp_kernel,
            alpha=alpha,
            beta=beta,
        )
        return fallback_choice.output_node()
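
# Note on the FixedLayout -> FlexibleLayout swap in the non-autotuning branch
# above (descriptive): a FixedLayout pins exact output strides, while a
# FlexibleLayout leaves them for Inductor to decide later, which is what allows
# the output strides to be padded.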


def fallback_mixed_mm(mat1, mat2, *, out):
    return torch.mm(mat1, mat2.to(mat1.dtype), out=out)


aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)


@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    props = torch.cuda.get_device_properties(index or 0)
    return props.major <= 7


def tuned_mixed_mm(mat1, mat2, mat2_dtype):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)
    choices = [fallback]
    # The Triton kernel can only be used when mat1 is fp32 or mat2 is
    # contiguous/transposed, and not on sm7x-or-older GPUs, e.g. V100
    # (numerical issues).
    skip_triton = (
        mat1.layout.dtype != torch.float32
        and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
    ) or _is_sm7x_or_older_gpu(layout.device.index)

    if inductor_config.force_mixed_mm:
        choices = []
    if not skip_triton:
        b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
        has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
        for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout, b_prologue_cast_type),
            )
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
    if skip_triton and not choices:
        choices = [fallback]
    return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout)
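
# Sketch of what the mixed-dtype path above relies on (schematic, assuming an
# int8/uint8 weight mat2 upcast to mat2_dtype, e.g. torch.float16): the Triton
# template loads B in its storage dtype and casts it in the prologue, roughly
#   b = tl.load(B).to(tl.float16)   # via B_PROLOGUE_CAST_TYPE = "tl.float16"
# so no full-precision copy of the weight is materialized in global memory.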


# This op is a special case of the int_mm op which we use based on the pattern
# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
# realization of the int32 _int_mm output by forcing fusion with the mul op.
# This is only used when config.force_fuse_int_mm_with_mul = True
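#
# Schematic of the fused pattern (illustrative eager-mode equivalent):
#   torch._int_mm(a_int8, b_int8) * c
# is lowered to a single Triton matmul whose epilogue applies the mul, so the
# int32 matmul result never has to be written out to memory.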
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
        if out_dtype is None
        else out_dtype
    )
    m, n, k, layout, mat1, mat2, mat3 = mm_args(
        mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
    )
    choices: List[Dict[Any, Any]] = []
    for config in int8_mm_configs(m, n, k):
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2, mat3),
            layout=layout,
            **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
            suffix_args=1,
            epilogue_fn=V.ops.mul,
        )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)