conv.py

# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

import torch

from .. import config, ir
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
    lowerings as L,
    register_lowering,
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv,
    is_ones,
    is_zeros,
    pad_listlike,
    sympy_product,
    use_triton_template,
)
from ..virtualized import V
from .mm_common import filtered_configs

if TYPE_CHECKING:
    from ..ir import TensorBox

log = logging.getLogger(__name__)

aten = torch.ops.aten


def conv_grid(n, c, h, w, meta):
    return (
        ceildiv(n * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )
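

# Example (arithmetic only; the shapes are hypothetical): for an output with
# N=8, OUT_C=64, OUT_H=OUT_W=56 and BLOCK_M=64, BLOCK_N=256, GROUPS=1,
#   conv_grid(8, 64, 56, 56, {"BLOCK_M": 64, "BLOCK_N": 256, "GROUPS": 1})
# returns (ceildiv(8 * 56 * 56, 64), ceildiv(64, 256), 1) == (392, 1, 1),
# i.e. one program per BLOCK_M output pixels x BLOCK_N output channels x group.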


# List of dictionaries to store the kernel configs. Configs whose "cond"
# evaluates to true will be utilised on the target platform.
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create a filtered list of configs based on conv
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
    )

conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)
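
# For instance, on ROCm the first tuple above, (64, 256, 16, 2, 4), becomes
# (64, 256, 16, 1, 4): only the num_stages field is forced to 1, while the
# block sizes and num_warps are kept as-is.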


LOOP_BODY = """
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k
        x_ptrs = x_base + (
            (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""

"""
This is a relatively simple conv implementation that can likely be
improved.  Many alternate conv versions can be found here:
https://github.com/pytorch/torchdynamo/pull/971
"""
conv2d_template = TritonTemplate(
    name="convolution",
    grid=conv_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_H = {{size("X", 2)}}
    IN_W = {{size("X", 3)}}
    OUT_C = {{size(None, 1)}}
    OUT_H = {{size(None, 2)}}
    OUT_W = {{size(None, 3)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xh = {{stride("X", 2)}}
    stride_xw = {{stride("X", 3)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wh = {{stride("W", 2)}}
    stride_ww = {{stride("W", 3)}}

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
"""
    + LOOP_BODY
    + """
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    #   for i in range(KERNEL_H):
    #       for j in range(KERNEL_W):
    #           for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W
"""
    + LOOP_BODY
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
""",
)
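

# A minimal eager-mode sketch (illustration only, never called by Inductor)
# of the index arithmetic that LOOP_BODY performs inside the template above:
# each output pixel (n, oc, oh, ow) reads x at
# (i - PADDING_H + oh * STRIDE_H, j - PADDING_W + ow * STRIDE_W) for every
# kernel tap (i, j) and accumulates a dot product over the input channels of
# its group.  Dilation is fixed at 1, matching what the template supports.
# The function name and the plain Python loops are purely illustrative.
def _conv2d_reference_loops(x, w, stride_h, stride_w, padding_h, padding_w, groups=1):
    batch, in_c, in_h, in_w = x.shape
    out_c, group_in_c, kernel_h, kernel_w = w.shape
    out_h = (in_h + 2 * padding_h - kernel_h) // stride_h + 1
    out_w = (in_w + 2 * padding_w - kernel_w) // stride_w + 1
    out = x.new_zeros(batch, out_c, out_h, out_w)
    group_out_c = out_c // groups
    for oc in range(out_c):
        g = oc // group_out_c
        for i in range(kernel_h):
            for j in range(kernel_w):
                for oh in range(out_h):
                    ih = i - padding_h + oh * stride_h
                    if not 0 <= ih < in_h:
                        continue  # the Triton kernel masks these loads to 0
                    for ow in range(out_w):
                        iw = j - padding_w + ow * stride_w
                        if not 0 <= iw < in_w:
                            continue
                        out[:, oc, oh, ow] += (
                            x[:, g * group_in_c : (g + 1) * group_in_c, ih, iw]
                            * w[oc, :, i, j]
                        ).sum(dim=1)
    return out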


aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)


def conv1x1_via_mm(x, w, *, out):
    w = torch.squeeze(torch.squeeze(w, -1), -1)
    return torch.matmul(
        x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
    )


aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
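

# A minimal sanity sketch (illustration only, never called by Inductor):
# for a 1x1 kernel with unit stride and no padding, conv1x1_via_mm matches
# aten convolution on eager tensors.  All shapes below are made up.
def _conv1x1_via_mm_equivalence_sketch():
    x = torch.randn(2, 8, 5, 5)
    w = torch.randn(16, 8, 1, 1)
    # channels-last storage so that out.permute(0, 2, 3, 1), which matmul
    # writes into, is contiguous
    out = torch.empty(2, 16, 5, 5).to(memory_format=torch.channels_last)
    conv1x1_via_mm(x, w, out=out)
    torch.testing.assert_close(out, torch.nn.functional.conv2d(x, w))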


class ConvLayoutParams(TypedDict):
    stride: tuple[int, ...]
    padding: tuple[int, ...]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    groups: int


def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution"""
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
            dilation,
            transposed,
            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
            groups,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )


def channels_last_order(rank):
    order = list(reversed(range(rank)))
    order.insert(1, order.pop(-1))
    return order
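

# Worked example: for a rank-4 (NCHW) tensor this yields
#   channels_last_order(4) == [3, 0, 2, 1]
# i.e. the per-dimension stride order in which the channels dim gets the
# smallest stride, which is the channels-last (NHWC) order requested via
# ir.ExternKernel.require_stride_order below and checked against
# ir.NHWC_STRIDE_ORDER in channels_last_conv().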


def convert_1x1_conv_to_mm(x, weight, bias):
    # special case for 1x1 convolution, which is actually just a matmul
    rank = len(weight.get_size())
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    if x.get_size()[0] != 1:
        x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    else:
        x.realize()
        x.freeze_layout()

    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)
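

# Shape flow of the rewrite above for a rank-4 input (illustration only):
#   weight: (O, C, 1, 1) --squeeze x2--> (O, C) --permute--> (C, O)
#   x:      (N, C, H, W) --permute--> (N, H, W, C) --reshape--> (N*H*W, C)
#   mm/addmm: (N*H*W, C) @ (C, O) -> (N*H*W, O)
#   result: --reshape--> (N, H, W, O) --permute--> (N, O, H, W)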


@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)

    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        if V.graph.layout_opt and ndim == 2:
            return True
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

    choices = [
        aten_convolution.bind(
            args,
            layout,
            ordered_kwargs_for_cpp_kernel,
            **kwargs,
        )
    ]
    if (
        use_triton_template(layout)
        # templates only support these:
        and ndim == 2
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            conv2d_template.maybe_append_choice(
                choices,
                input_nodes=(x, weight),
                layout=layout,
                KERNEL_H=kernel_shape[0],
                KERNEL_W=kernel_shape[1],
                STRIDE_H=stride[0],
                STRIDE_W=stride[1],
                PADDING_H=padding[0],
                PADDING_W=padding[1],
                GROUPS=groups,
                # TODO(jansel): try unroll for bigger kernels once fixed:
                #               https://github.com/openai/triton/issues/1254
                UNROLL=is_ones(kernel_shape),
                ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
                **cfg.kwargs,
            )

    return autotune_select_algorithm("convolution", choices, args, layout)
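

# Hypothetical usage sketch (not part of this module): the Triton template
# choices above are only considered when use_triton_template(layout) is true,
# which requires Inductor's autotuning to be enabled, e.g. (assuming the
# standard config knobs and a CUDA device):
#
#   import torch
#   import torch._inductor.config as inductor_config
#   inductor_config.max_autotune_gemm = True
#   conv = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, device="cuda")
#   compiled = torch.compile(conv)
#   out = compiled(torch.randn(8, 64, 56, 56, device="cuda"))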


@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )


def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    assert fx_node.target == torch.ops.aten.convolution.default
    if V.graph.layout_opt:
        return args, kwargs
    else:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)


add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)