bmm.py

# mypy: allow-untyped-defs
import logging

import torch

from .. import ir, lowering as L
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv as cdiv,
    use_aten_gemm_kernels,
    use_cutlass_template,
    use_triton_template,
)
from ..virtualized import V
from .mm import _is_static_problem
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options

log = logging.getLogger(__name__)
aten = torch.ops.aten

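
# Launch grid for the bmm template: axis 0 enumerates the (BLOCK_M, BLOCK_N)
# output tiles, axis 1 enumerates the batch, axis 2 is unused.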
def bmm_grid(b, m, n, meta):
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)

bmm_template = TritonTemplate(
    name="bmm",
    grid=bmm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}

    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}

    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
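    # Consecutive programs step through up to GROUP_M row-tiles for a fixed
    # column-tile before advancing the column, so a small window of A tiles
    # and one B column-tile stay resident in L2.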

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
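    # If an operand is contiguous along its last two dims, tell Triton the
    # index vectors are aligned, contiguous blocks so loads can be coalesced.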
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q * stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q * stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
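    # Main loop over K in steps of BLOCK_K.  EVEN_K means K is divisible by
    # BLOCK_K, so the tail iteration needs no masking.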
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)

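
# ATen fallbacks; these compete with the Triton (and CUTLASS) choices during
# autotuning and are also used when no template applies.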
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")


@L.register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
    if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
        # decompose to small ops when memory bound
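        # (B, M, K) @ (B, K, N) with M == 1 or N == 1 is memory bound: unsqueeze
        # to (B, M, K, 1) and (B, 1, K, N), broadcast-multiply, then reduce over
        # K (axis=2) to get the (B, M, N) result without a GEMM kernel.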
        if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
            mat1 = L.unsqueeze(mat1, -1)
            mat2 = L.unsqueeze(mat2, 1)
            return L.sum_(L.mul(mat1, mat2), axis=2)

        def is_valid_to_require_contiguous(t):
            if not ir.is_storage_and_layout(t):
                return True
            _, layout = ir.as_storage_and_layout(t, freeze=False)
            return isinstance(layout, ir.FlexibleLayout)

        def is_preferred_layout_as_bmm_input(sizes, strides):
            # contiguous on one of the last two dims
            return (
                strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
            ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))

        # Make the input of bmm contiguous if it is not contiguous
        # on either of the last two dims, because the bmm CPU
        # implementation would call contiguous() on it otherwise.
        # This avoids additional copies in bmm.
        def may_require_contiguous(t, meta_t):
            sizes = meta_t.meta["val"].size()
            strides = meta_t.meta["val"].stride()
            if not is_preferred_layout_as_bmm_input(sizes, strides):
                t = ir.ExternKernel.require_contiguous(t)
            return t

        if is_valid_to_require_contiguous(mat1):
            meta_mat1 = V.graph.current_node.args[0]
            mat1 = may_require_contiguous(mat1, meta_mat1)
        if is_valid_to_require_contiguous(mat2):
            meta_mat2 = V.graph.current_node.args[1]
            mat2 = may_require_contiguous(mat2, meta_mat2)

    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
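    # CUTLASS GEMM choices are only considered for statically shaped,
    # non-empty problems.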
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate

        CUTLASSGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        choices.append(aten_bmm.bind((mat1, mat2), layout))

    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)


# Don't register this since it is slower than decomposing it
# @L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

    # options to tune from
    choices = (
        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout):
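        # prefix_args=1 keeps `inp` out of the template's main loop (the kernel
        # body only sees A and B); it is consumed by the addmm epilogue, which
        # applies alpha/beta scaling when the output is stored.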
        for config in mm_configs(m, n, k):
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(inp, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)