# mypy: allow-untyped-defs
import collections
import contextlib
import dataclasses
import dis
import functools
import inspect
import operator
import re
from itertools import count
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy import Expr

import torch
import torch._ops
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codegen.multi_kernel import MultiKernelState
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
from torch.fx.node import _get_qualified_name
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import async_compile, config, ir
from ..ir import ReinterpretView
from ..runtime import triton_heuristics
from ..runtime.hints import DeviceProperties
from ..utils import (
    cache_on_self,
    get_benchmark_name,
    LineContext,
    sympy_product,
    sympy_str,
)
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
from .triton_utils import config_of, signature_to_meta

if TYPE_CHECKING:
    import triton

    from ..graph import GraphLowering


pexpr = PythonPrinter().doprint


ReuseKey = Tuple[torch.device, torch.dtype, str]


def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
    return (
        node.get_device(),
        node.get_dtype(),
        # NB: this is symbolic so that we don't try to reuse a buffer
        # for s0 for s1, just because they happen to share the same
        # size hint
        sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
    )


def convert_arg_type(arg: torch.Argument) -> str:
    from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP

    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(arg.real_type)  # type: ignore[attr-defined]

    if python_type == "Tensor":
        # Conversion rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
        if arg.alias_info is not None and arg.alias_info.is_write:
            return f"at::{python_type}&"
        else:
            return f"at::{python_type} const&"

    if python_type in PYTHON_TO_CPP:
        cpp_type = PYTHON_TO_CPP[python_type]
        return cpp_type

    # Convert args of container types e.g. Optional[*]
    for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items():
        container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
        if len(container_match) == 1:
            contained_type = container_match[0]
            assert (
                contained_type in PYTHON_TO_CPP
            ), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
            cpp_contained_type = PYTHON_TO_CPP[contained_type]
            return f"{cpp_container}<{cpp_contained_type}>"

    raise AssertionError(f"unsupported python_type: {python_type}")


def convert_return_type(ret: torch.Argument) -> str:
    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(ret.real_type)  # type: ignore[attr-defined]
    python_to_cpp = {
        "Tensor": "at::Tensor",
        "List[Tensor]": "std::vector<at::Tensor>",
    }

    cpp_type = python_to_cpp.get(python_type, None)
    assert cpp_type is not None, f"NYI return type: {python_type}"
    # An output aliasing an input is returned by reference only when it's a
    # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
    # aliases the input tensor, but the op returns a vector by value.
    if python_type == "Tensor" and ret.alias_info is not None:
        cpp_type += "&"
    return cpp_type


def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
    args = kernel._schema.arguments
    returns = kernel._schema.returns

    num_returns = len(returns)
    assert num_returns > 0, "must have at least one return value"

    if num_returns == 1:
        cpp_return_value = convert_return_type(returns[0])
    elif num_returns > 1:
        tuple_returns = ", ".join([convert_return_type(r) for r in returns])
        cpp_return_value = f"std::tuple<{tuple_returns}>"

    cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
    return f"{cpp_return_value}({', '.join(cpp_arg_type)})"  # type: ignore[possibly-undefined]


# TODO: Move to a well known place
TritonMetaParams = Dict[str, int]
TritonGrid = Union[
    Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]]
]


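# Emits a `grid_wrapper_for_<name>(meta)` helper for user-defined Triton kernels:
# a single grid is returned directly, while multiple autotuned configs are
# dispatched by matching `meta` against each config's kwargs.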
def user_defined_kernel_grid_fn_code(
    name: str,
    configs: List["triton.Config"],
    grids: List[TritonGrid],
    wrapper: Optional["WrapperCodeGen"] = None,
) -> Tuple[str, str]:
    output = IndentedBuffer()

    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)

    def determine_grid(grid: TritonGrid):
        if wrapper is None or callable(grid):
            # return as-is when used in eager mode or when grid is callable
            return grid
        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
        return wrapper.codegen_shape_tuple(sympy_grid)

    fn_name = f"grid_wrapper_for_{name}"
    output.writeline(f"def {fn_name}(meta):")
    with output.indent():
        if len(grids) == 1:
            grid = determine_grid(grids[0])
            output.writeline(f"return {grid}")
        else:
            assert len(grids) > 1
            assert len(grids) == len(configs)
            seen = set()
            for grid, c in zip(grids, configs):
                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
                guards = " and ".join(guards)
                grid = determine_grid(grid)
                statement = f"if {guards}: return {grid}"
                if statement in seen:
                    continue
                seen.add(statement)
                output.writeline(statement)

    return fn_name, output.getvalue()


@dataclasses.dataclass
class SymbolicCallArg:
    inner: str
    # the original symbolic expression represented by inner
    inner_expr: sympy.Expr

    def __str__(self):
        return str(self.inner)


# Default thread stack sizes vary by platform:
# - Linux: 8 MB
# - macOS: 512 KB
# - Windows: 1 MB
# Just pick something comfortably smaller than the smallest for now.
MAX_STACK_ALLOCATION_SIZE = 1024 * 100


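# Tracks buffers freed during memory planning, keyed by (device, dtype, symbolic
# storage size), so later allocations with a matching key can reuse them. Also
# accumulates the total statically-known CPU allocation size, which is compared
# against MAX_STACK_ALLOCATION_SIZE to decide whether stack allocation is allowed.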
class MemoryPlanningState:
    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            ReuseKey, List[FreeIfNotReusedLine]
        ] = collections.defaultdict(list)
        self.total_allocated_buffer_size: int = 0

    def __contains__(self, key: ReuseKey) -> bool:
        return bool(self.reuse_pool.get(key, None))

    def pop(self, key: ReuseKey) -> "FreeIfNotReusedLine":
        item = self.reuse_pool[key].pop()
        assert not item.is_reused
        return item

    def push(self, key: ReuseKey, item: "FreeIfNotReusedLine") -> None:
        assert not item.is_reused
        self.reuse_pool[key].append(item)


class WrapperLine:
    pass


@dataclasses.dataclass
class EnterSubgraphLine(WrapperLine):
    wrapper: "WrapperCodeGen"
    graph: "GraphLowering"

    def codegen(self, code: IndentedBuffer) -> None:
        self.wrapper.push_codegened_graph(self.graph)
        code.do_indent()


@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
    wrapper: "WrapperCodeGen"

    def codegen(self, code: IndentedBuffer) -> None:
        self.wrapper.pop_codegened_graph()
        code.do_unindent()


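# Emits the code that activates a device before kernel calls: C++ stream/device
# guards in cpp_wrapper (and AOT) mode, or a `with V.graph.device_ops.device_guard(...):`
# block in the default Python wrapper.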
@dataclasses.dataclass
class EnterDeviceContextManagerLine(WrapperLine):
    device_idx: int
    last_seen_device_guard_index: Optional[int]

    def codegen(self, code: IndentedBuffer) -> None:
        if V.graph.cpp_wrapper:
            code.writeline("\n")
            if V.graph.aot_mode:
                # In AOT mode, we have a stream provided as a param. A stream is
                # associated with a device, so we never expect the device to change.
                # CUDAStreamGuard sets the stream and the device.
                if self.last_seen_device_guard_index is None:
                    if config.abi_compatible:
                        code.writeline(
                            "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);"
                        )
                    else:
                        code.writeline(
                            maybe_hipify_code_wrapper(
                                "at::cuda::CUDAStreamGuard stream_guard("
                                + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
                            )
                        )
                else:
                    assert (
                        self.last_seen_device_guard_index == self.device_idx
                    ), "AOTInductor only supports running on one CUDA device"
            else:
                if self.last_seen_device_guard_index is None:
                    code.writeline(
                        f"AOTICudaGuard device_guard({self.device_idx});"
                        if config.abi_compatible
                        else maybe_hipify_code_wrapper(
                            f"at::cuda::CUDAGuard device_guard({self.device_idx});"
                        )
                    )
                else:
                    code.writeline(f"device_guard.set_index({self.device_idx});")
        else:
            # Note _DeviceGuard has less overhead than device, but only accepts
            # integers
            code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
            code.do_indent()
            code.writeline(V.graph.device_ops.set_device(self.device_idx))


class ExitDeviceContextManagerLine(WrapperLine):
    def codegen(self, code: IndentedBuffer) -> None:
        if not V.graph.cpp_wrapper:
            code.do_unindent()


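# MemoryPlanningLine and its subclasses below are recorded into WrapperCodeGen.lines
# and processed in two passes: plan() may rewrite a line (e.g. turning an allocation
# into a reuse, or dropping it entirely), and codegen() emits the final wrapper text.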
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
    wrapper: "WrapperCodeGen"

    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse"""
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        """Second pass to output code"""
        pass

    def __str__(self) -> str:
        """
        Emits a string representation that fits on one line.
        """
        args: List[str] = []
        for field in dataclasses.fields(self):
            if field.name == "wrapper":
                continue
            val = getattr(self, field.name)
            args.append(
                f"{field.name}={val.get_name() if field.type is ir.Buffer else val}"
            )
        return f"{type(self).__name__}({', '.join(args)})"


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if config.allow_buffer_reuse and key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return ReuseLine(self.wrapper, free_line.node, self.node)

        if self.node.get_device().type == "cpu":
            static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
            if static_shape is not None:
                state.total_allocated_buffer_size += int(
                    functools.reduce(operator.mul, static_shape, 1)
                )

        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        line = self.wrapper.make_buffer_allocation(self.node)
        code.writeline(line)


@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if len(self.node.get_inputs_that_alias_output()) > 0:
            return self
        if isinstance(self.node.layout, ir.MultiOutputLayout):
            return self
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)
        if config.allow_buffer_reuse:
            state.push(buffer_reuse_key(self.node), self)
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(self.wrapper.make_buffer_free(self.node))


@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    node: ir.Buffer
    reused_as: ir.Buffer
    delete_old: bool = True

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if self.node.get_name() in V.graph.removed_buffers:
            assert self.reused_as.get_name() in V.graph.removed_buffers
            return NullLine(self.wrapper)
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
        )


class NullLine(MemoryPlanningLine):
    pass


BufferName = str


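# The generated module is assembled from four IndentedBuffers: `header` (imports and
# kernel definitions), `prefix` (the `call(args)` signature plus input unpacking),
# `wrapper_call` (the body built from self.lines after memory planning), and `suffix`.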
class WrapperCodeGen(CodeGen):
    """
    Generate outer wrapper in Python that calls the kernels.
    """

    def __init__(self):
        super().__init__()
        self._names_iter: Iterator[int] = count()
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.suffix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        # If the generated source code is exactly the same, reuse the
        # pre-existing kernel for it
        self.src_to_kernel: Dict[str, str] = {}
        self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set()
        self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
        self.declare = ""
        self.declare_maybe_reference = ""
        self.ending = ""
        self.open_bracket = "["
        self.closed_bracket = "]"
        self.comment = "#"
        self.namespace = ""
        self.none_str = "None"
        self.size = "size()"
        self.stride = "stride()"
        self.last_seen_device_guard_index: Optional[int] = None
        self.supports_intermediate_hooks = True
        self.expr_printer: Callable[[Any], str] = pexpr
        self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {}
        self.unbacked_symbol_decls: Set[str] = set()  # str of sympy.Symbol
        self.allow_stack_allocation: Optional[bool] = None
        self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
        self.computed_sizes: Set[sympy.Symbol] = set()

        # this is used for tracking which GraphLowering instance---parent graph
        # or (nested) subgraph---is currently codegened; the primary use case is
        # including the graph instance into a cache key to avoid cross-graph
        # caching during lowering of nested subgraphs
        self.codegened_graph_stack = []

        self.write_header()
        self.write_prefix()

        if not V.graph.aot_mode:
            for name, hashed in V.graph.constant_reprs.items():
                # include a hash so our code cache puts different constants into different files
                self.write_constant(name, hashed)

        self.allocated: Set[BufferName] = set()
        self.freed: Set[BufferName] = set()

        # maps from reusing buffer to reused buffer
        self.reuses: Dict[BufferName, BufferName] = dict()

        self.write_get_raw_stream = functools.lru_cache(None)(  # type: ignore[assignment]
            self.write_get_raw_stream
        )

        @functools.lru_cache(None)
        def add_import_once(line: str) -> None:
            self.header.writeline(line)

        self.add_import_once = add_import_once
        self._metas: Dict[str, str] = {}
        self.multi_kernel_state = MultiKernelState()

    def write_constant(self, name: str, hashed: str) -> None:
        self.header.writeline(f"{name} = None # {hashed}")

    def write_header(self) -> None:
        context = torch._guards.TracingContext.try_get()
        aot_config_comment = ""
        if context is not None and context.aot_graph_name is not None:
            aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
        self.header.splice(
            f"""
            {aot_config_comment}
            from ctypes import c_void_p, c_long
            import torch
            import math
            import random
            import os
            import tempfile
            from math import inf, nan
            from torch._inductor.hooks import run_intermediate_hooks
            from torch._inductor.utils import maybe_profile
            from torch._inductor.codegen.memory_planning import _align as align
            from torch import device, empty_strided
            from {async_compile.__name__} import AsyncCompile
            from torch._inductor.select_algorithm import extern_kernels
            from torch._inductor.codegen.multi_kernel import MultiKernelCall
            aten = torch.ops.aten
            inductor_ops = torch.ops.inductor
            _quantized = torch.ops._quantized
            assert_size_stride = torch._C._dynamo.guards.assert_size_stride
            empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
            empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
            reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
            alloc_from_pool = torch.ops.inductor._alloc_from_pool
            async_compile = AsyncCompile()
            """
        )

    @cache_on_self
    def write_triton_header_once(self) -> None:
        self.header.splice(
            """
            import triton
            import triton.language as tl
            from {} import grid, split_scan_grid, start_graph, end_graph
            {}
            """.format(
                triton_heuristics.__name__,
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"),
            )
        )

    def add_meta_once(self, meta: TritonMetaParams) -> str:
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
        return self._metas[meta]

    @cache_on_self
    def get_output_refs(self) -> List[str]:
        return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs]

    def mark_output_type(self) -> None:
        return

    def codegen_input_size_asserts(self) -> None:
        for name, buf in V.graph.graph_inputs.items():
            if isinstance(buf, sympy.Expr):
                continue
            # comparing strides for 0 size tensor is tricky. Ignore them for now.
            if sympy_product(buf.get_size()) == 0:
                continue
            size = self.codegen_shape_tuple(buf.get_size())
            stride = self.codegen_shape_tuple(buf.get_stride())
            self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")

    def codegen_input_nan_asserts(self) -> None:
        self.prefix.writeline("# make sure graph inputs are not nan/inf")
        for name, buf in V.graph.graph_inputs.items():
            if isinstance(buf, sympy.Expr):
                continue
            line = f"assert not {name}.isnan().any().item()"
            self.prefix.writeline(line)
            line = f"assert not {name}.isinf().any().item()"
            self.prefix.writeline(line)

    def write_prefix(self) -> None:
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline(V.graph.device_ops.synchronize())
            if V.graph.graph_inputs:
                lhs = ", ".join(V.graph.graph_input_names)
                if len(V.graph.graph_input_names) == 1:
                    lhs += ","
                self.prefix.writeline(f"{lhs} = args")
                self.prefix.writeline("args.clear()")
            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
            if config.size_asserts:
                self.codegen_input_size_asserts()
            if config.nan_asserts:
                self.codegen_input_nan_asserts()

    # this function (and below) takes a graph as input so
    # that stream caching happens per graph instance. this
    # is important for nested subgraph codegening.
    def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
        self.write_triton_header_once()
        name = f"stream{device_idx}"
        self.writeline(f"{name} = get_raw_stream({device_idx})")
        return name

    def get_codegened_graph(self):
        return self.codegened_graph_stack[-1]

    def push_codegened_graph(self, graph):
        self.codegened_graph_stack.append(graph)

    def pop_codegened_graph(self):
        return self.codegened_graph_stack.pop()

    def next_kernel_suffix(self) -> str:
        return f"{next(self._names_iter)}"

    def codegen_device_guard_enter(self, device_idx: int) -> None:
        self.writeline(
            EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
        )
        self.last_seen_device_guard_index = device_idx

    def codegen_device_guard_exit(self) -> None:
        self.writeline(ExitDeviceContextManagerLine())

    def generate_return(self, output_refs: List[str]) -> None:
        if output_refs:
            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
        else:
            self.wrapper_call.writeline("return ()")

    def generate_before_suffix(self, result: IndentedBuffer) -> None:
        return

    def generate_end(self, result: IndentedBuffer) -> None:
        return

    def generate_fallback_kernel(self, fallback_kernel, args):
        self.generate_extern_kernel_alloc(fallback_kernel, args)

    def generate_extern_kernel_alloc(self, extern_kernel, args):
        output_name = extern_kernel.get_name()
        origin_node = extern_kernel.get_origin_node()
        kernel_name = extern_kernel.get_kernel_name()
        ending = self.ending
        if config.memory_planning and "view_as_complex" in kernel_name:
            # view operation fallbacks cause issues since inductor
            # doesn't know the memory is still needed and might reuse it.
            ending = f".clone(){ending}"
        self.writeline(
            f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
        )
        if (
            self.supports_intermediate_hooks
            and config.generate_intermediate_hooks
            and origin_node is not None
        ):
            counters["inductor"]["intermediate_hooks"] += 1
            self.writeline(
                f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
            )

    def generate_extern_kernel_out(
        self, kernel: str, out: str, out_view: Optional[str], args: List[str]
    ):
        args.append(f"out={out_view if out_view else out}")
        self.writeline(f"{kernel}({', '.join(args)})")

    def generate_user_defined_triton_kernel(
        self, kernel_name, grid, configs, args, triton_meta, arg_types=None
    ):
        grid, code = user_defined_kernel_grid_fn_code(
            kernel_name, configs, grid, wrapper=self
        )
        # Must happen after free symbols are already codegened
        # Emit the grid wrapper function right before the call
        for line in code.split("\n"):
            self.writeline(line)

        current_device = V.graph.scheduler.get_current_device_or_throw()
        stream_name = self.write_get_raw_stream(current_device.index, V.graph)
        self.writeline(
            f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})"
        )

    def generate_scatter_fallback(
        self,
        output,
        inputs,
        cpp_kernel_name,
        python_kernel_name,
        src_is_tensor,
        reduce,
        kwargs,
    ):
        line = f"{python_kernel_name}({','.join(map(str, inputs))}"
        if python_kernel_name.startswith("aten.scatter_reduce"):
            line += ", ".join([""] + kwargs)
        else:
            if reduce:
                line += f", reduce={repr(reduce)}"
        line += ")"
        self.writeline(line)

    def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
        indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
        args = [x, indices_str, values, accumulate]
        self.writeline(self.wrap_kernel_call(kernel, args))

    def generate_extern_kernel_alloc_and_find_schema_if_needed(
        self,
        buf_name: str,
        python_kernel_name: str,
        cpp_kernel_name: str,
        codegen_args: List[str],
        cpp_op_schema: str,
        cpp_kernel_key: str,
        cpp_kernel_overload_name: str = "",
        op_overload: Optional[torch._ops.OpOverload] = None,
        raw_args=None,
        outputs=None,
    ):
        self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})")

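    # generate() drives final codegen: it replays self.lines (after memory planning)
    # into wrapper_call, then stitches header, prefix, wrapper_call, and suffix into
    # the module source returned to the compiler.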
    @dynamo_timed
    def generate(self, is_inference):
        if config.profile_bandwidth:
            self.write_triton_header_once()
        result = IndentedBuffer()
        result.splice(self.header)
        # We do not want the cpp header for intermediate const graph. Headers would be
        # rendered by the main module instead.
        if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
            result = IndentedBuffer()

        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.generate_profiler_mark_wrapper_call(stack)
            if config.profile_bandwidth:
                self.generate_start_graph()

            # We disable planning during training because it presently increases peak memory consumption.
            if is_inference and config.memory_planning:
                self.memory_plan()
                # TODO: integrate memory planning & stack allocation?
                self.allow_stack_allocation = False
            else:
                self.memory_plan_reuse()

            if config.triton.store_cubin:
                self.generate_reset_kernel_saved_flags()

            for line in self.lines:
                if isinstance(line, WrapperLine):
                    line.codegen(self.wrapper_call)
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            self.mark_output_type()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline(V.graph.device_ops.synchronize())

            if config.profile_bandwidth:
                self.generate_end_graph()

            if config.triton.store_cubin:
                self.generate_save_uncompiled_kernels()

            self.generate_return(output_refs)

        self.finalize_prefix()
        result.splice(self.prefix)
        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_before_suffix(result)
        result.splice(self.suffix)
        self.generate_end(result)
        self.add_benchmark_harness(result)
        return result.getvaluewithlinemap()

    def memory_plan(self):
        from .memory_planning import MemoryPlanner

        self.lines = MemoryPlanner(self).plan(self.lines)

    def memory_plan_reuse(self):
        out_names = V.graph.get_output_names()

        while (
            self.lines
            and isinstance(self.lines[-1], MemoryPlanningLine)
            # TODO: this seems legit, NullLine has no node
            and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
        ):
            # these lines will be pointless
            self.lines.pop()

        # codegen allocations in two passes
        planning_states = [MemoryPlanningState()]
        past_planning_states = []

        for i in range(len(self.lines)):
            line = self.lines[i]
            if isinstance(line, MemoryPlanningLine):
                self.lines[i] = line.plan(planning_states[-1])
            elif isinstance(line, EnterSubgraphLine):
                planning_states.append(MemoryPlanningState())
            elif isinstance(line, ExitSubgraphLine):
                past_planning_states.append(planning_states.pop())
        past_planning_states.append(planning_states.pop())
        assert len(planning_states) == 0

        # conservatively use the sum of all allocated buffer sizes
        # in potentially nested scopes as the total allocated size
        total_allocated_buffer_size = sum(
            s.total_allocated_buffer_size for s in past_planning_states
        )

        self.allow_stack_allocation = (
            self.allow_stack_allocation is not False
            and config.allow_stack_allocation
            and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
        )

    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
        code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}")

    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
        code.writeline(
            f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
        )

    def codegen_inputs(
        self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox]
    ):
        """Assign all symbolic shapes to locals"""

        @functools.lru_cache(None)
        def sizeof(name):
            self.codegen_input_size_var_decl(code, name)
            return f"{name}_size"

        @functools.lru_cache(None)
        def strideof(name):
            self.codegen_input_stride_var_decl(code, name)
            return f"{name}_stride"

        # Assign all symbolic shapes needed to local variables
        bound_vars: Set[sympy.Symbol] = set()

        def is_expr(x):
            return isinstance(x[1], sympy.Expr)

        graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
        graph_inputs_tensors = list(
            filter(lambda x: not is_expr(x), graph_inputs.items())
        )

        for name, shape in graph_inputs_expr:
            if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
                bound_vars.add(shape)

        for name, value in graph_inputs_tensors:
            shapes = value.get_size()
            for dim, shape in enumerate(shapes):
                if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
                    code.writeline(
                        f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                    )
                    bound_vars.add(shape)

        for name, value in graph_inputs_tensors:
            shapes = value.get_stride()
            for dim, shape in enumerate(shapes):
                if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
                    code.writeline(
                        f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
                    )
                    bound_vars.add(shape)

    def ensure_size_computed(self, sym: sympy.Symbol):
        if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
            if sym in self.computed_sizes:
                return
            self.computed_sizes.add(sym)
            expr = V.graph.sizevars.inv_precomputed_replacements[sym]
            self.writeline(
                f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
            )

    def finalize_prefix(self):
        pass

    def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str:
        return pexpr(x, simplify=simplify)

    def codegen_sizevar(self, x: Expr) -> str:
        return self.codegen_python_sizevar(x)

    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
        return f"{basename}[{index}]"

    def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        parts = list(map(self.codegen_python_sizevar, shape))
        if len(parts) == 0:
            return "()"
        if len(parts) == 1:
            return f"({parts[0]}, )"
        return f"({', '.join(parts)})"

    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        return self.codegen_python_shape_tuple(shape)

    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
        return "alloc_from_pool({})".format(
            ", ".join(
                [
                    name,
                    pexpr(offset),  # bytes not numel
                    str(dtype),
                    self.codegen_shape_tuple(shape),
                    self.codegen_shape_tuple(stride),
                ]
            )
        )

    def codegen_reinterpret_view(self, data, size, stride, offset, writer) -> str:
        size = self.codegen_shape_tuple(size)
        stride = self.codegen_shape_tuple(stride)
        offset = self.codegen_sizevar(offset)
        return f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"

    def codegen_device_copy(self, src, dst):
        self.writeline(f"{dst}.copy_({src})")

    def codegen_multi_output(self, name, value):
        self.writeline(f"{self.declare}{name} = {value}{self.ending}")

    def codegen_dynamic_scalar(self, node):
        (data,) = (t.codegen_reference() for t in node.inputs)
        if len(node.keypath) == 0:
            self.writeline(f"{node.sym} = {data}.item()")
        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
            self.writeline(f"{node.sym} = 1 if {data}.item() else 0")
        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
            self.writeline(f"{node.sym}_undivided = {data}.item()")
            self.writeline(
                f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, "
                f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'"
            )
            self.writeline(
                f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}"
            )
        else:
            raise AssertionError(f"unrecognized keypath {node.keypath}")
        # No one should ever use this buffer, but for uniformity
        # define the variable and assign it None
        self.writeline(f"{node.get_name()} = None")

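    # Emits a standalone `benchmark_compiled_module()` harness that rebuilds the graph
    # inputs with rand_strided (using size hints for symbolic shapes) and times the
    # generated call() via print_performance.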
    def benchmark_compiled_module(self, output):
        def add_fake_input(name, shape, stride, device, dtype):
            output.writeline(
                f"{name} = rand_strided("
                f"{self.codegen_python_shape_tuple(shape)}, "
                f"{self.codegen_python_shape_tuple(stride)}, "
                f"device='{device}', dtype={dtype})"
            )

        def add_expr_input(name, val):
            output.writeline(f"{name} = {val}")

        def add_torchbind_input(name, value):
            import pickle

            output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})")

        output.writelines(
            ["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
        )
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )
            for name, value in V.graph.constants.items():
                # all the constants are global variables, that's why we need
                # these 'global var_name' lines
                output.writeline(f"global {name}")
                add_fake_input(
                    name, value.size(), value.stride(), value.device, value.dtype
                )

            if len(V.graph.torchbind_constants) > 0:
                output.writeline("import pickle")
                for name, torchbind_obj in V.graph.torchbind_constants.items():
                    # all the constants are global variables, that's why we need
                    # these 'global var_name' lines
                    output.writeline(f"global {name}")
                    add_torchbind_input(name, torchbind_obj)

            for name, value in V.graph.graph_inputs.items():
                if isinstance(value, sympy.Symbol) and isinstance(
                    V.graph.sizevars.var_to_val.get(value, None), SingletonInt
                ):
                    # Inductor should only work with dense -> dense graph, and
                    # SingletonInts belong to metadata that should only live on
                    # the subclass.
                    continue
                if isinstance(value, sympy.Expr):  # Don't need to add symbolic
                    # TODO: this fallback and those below actually will generate possibly
                    # invalid benchmark code, because it's not guaranteed 42
                    # is actually a valid value for the kernel in question.
                    # See https://github.com/pytorch/pytorch/issues/124686
                    add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
                else:
                    shape = [
                        V.graph.sizevars.size_hint(x, fallback=42)
                        for x in value.get_size()
                    ]
                    stride = [
                        V.graph.sizevars.size_hint(x, fallback=42)
                        for x in value.get_stride()
                    ]
                    add_fake_input(
                        name,
                        shape,
                        stride,
                        value.get_device(),
                        value.get_dtype(),
                    )

            call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
            output.writeline(f"fn = lambda: {call_str}")
            output.writeline("return print_performance(fn, times=times, repeat=repeat)")

    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging
        """
        if not config.benchmark_harness:
            return

        self.benchmark_compiled_module(output)

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.writelines(
                [
                    "from torch._inductor.wrapper_benchmark import compiled_module_main",
                    f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
                ]
            )

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
    ):
        metadata_comment = f"{metadata}\n" if metadata else ""
        self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")

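    # Builds the async_compile.triton(...) source for a user-written Triton kernel:
    # it collects the argument signature, constexpr values, and triton_meta, caches
    # the result keyed on the kernel's function id plus non-tensor args, and inlines
    # any JITFunctions or triton globals the kernel references.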
    def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
        from torch.utils._triton import patch_triton_dtype_repr

        patch_triton_dtype_repr()

        original_name = kernel.__name__

        from .common import KernelArgType, SizeArg, TensorArg

        signature: List[KernelArgType] = []
        constants: Dict[int, Any] = {}
        non_constant_indices = []
        equal_to_1_arg_idx: List[int] = []
        for idx, key in enumerate(kernel.arg_names):
            if key not in kwargs:
                continue
            arg = kwargs[key]
            if idx in kernel.constexprs:
                constants[idx] = arg
            else:
                non_constant_indices.append(idx)
                if isinstance(arg, ir.Buffer):
                    signature.append(
                        TensorArg(
                            name=key,
                            buffer=arg.get_name(),
                            dtype=arg.get_dtype(),
                        )
                    )
                elif isinstance(arg, ir.ReinterpretView):
                    # for ReinterpretView we use the underlying
                    # buffer name and note the (possibly non-zero)
                    # offset relative to the underlying buffer
                    signature.append(
                        TensorArg(
                            name=key,
                            buffer=arg.data.get_name(),
                            dtype=arg.get_dtype(),
                            offset=arg.layout.offset,
                        )
                    )
                else:
                    signature.append(SizeArg(key, arg))
                    if isinstance(
                        arg, (int, sympy.Integer)
                    ) and V.graph.sizevars.statically_known_equals(
                        arg, 1  # type: ignore[arg-type]
                    ):
                        equal_to_1_arg_idx.append(idx)
        index_dtype = "tl.int32"
        triton_meta = {
            "signature": signature_to_meta(
                signature,
                size_dtype=index_dtype,
                indices=non_constant_indices,
            ),
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            # Triton compiler includes equal_to_1 args into constants even
            # when they are not constexpr. otherwise there may be a segfault
            # during launching the Inductor-compiled Triton kernel.
            # TODO(aakhundov): add None args to constants, too. currently, this
            # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
            # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
            # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
            "constants": {
                **constants,
                **dict.fromkeys(equal_to_1_arg_idx, 1),
            },
            "configs": [
                config_of(
                    signature,
                    indices=non_constant_indices,
                )
            ],
        }

        # Distinguish between different functions using function id
        cache_key: List[Any] = [id(kernel.fn)]
        if len(configs) > 0:
            for arg in kwargs.values():
                # We need to key on non tensor arg only in autotune mode
                if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
                    cache_key.append(arg)
        cache_key.append(str(triton_meta))
        cache_key = tuple(cache_key)

        if cache_key in self.user_defined_kernel_cache:
            return self.user_defined_kernel_cache[cache_key]

        name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
        # Add to the cache for the next use
        self.user_defined_kernel_cache[cache_key] = (name, triton_meta)

        compile_wrapper = IndentedBuffer()
        compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")

        from .triton import gen_common_triton_imports, TritonKernel

        compile_wrapper.splice(gen_common_triton_imports())

        inductor_meta = {
            "kernel_name": name,
            **TritonKernel.inductor_meta_common(),
        }

        configs = [
            {
                "kwargs": config.kwargs,
                "num_warps": config.num_warps,
                "num_stages": config.num_stages,
            }
            for config in configs
        ]

        compile_wrapper.splice(
            f"""
            @triton_heuristics.user_autotune(
                configs={configs!r},
                inductor_meta={inductor_meta!r},
                triton_meta={triton_meta!r},
                filename=__file__,
                custom_kernel=True,
            )
            @triton.jit
            """
        )
        compile_wrapper.splice(kernel.src, strip=True)

        # Also include any possible kernel being called indirectly
        from triton import JITFunction
        from triton.language import constexpr

        # global constexpr vars handled above
        symbols_included = {original_name}

        def traverse(cur_kernel):
            # here we extract the unqualified names (i.e., not attributes and
            # without prepended module name) loaded in the kernel code, which
            # are matched with the co_names and __globals__ below to codegen
            # the respective imports necessary for the kernel compilation
            unqualified_loads = {
                inst.argval
                for inst in dis.Bytecode(cur_kernel.fn)
                if inst.opname == "LOAD_GLOBAL"
            }
            global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {})
            for symbol_name in cur_kernel.fn.__code__.co_names:
                if symbol_name in symbols_included:
                    continue
                if symbol_name in cur_kernel.fn.__globals__:
                    symbol = cur_kernel.fn.__globals__[symbol_name]
                    if isinstance(symbol, JITFunction):
                        compile_wrapper.newline()
                        compile_wrapper.writeline("@triton.jit")
                        compile_wrapper.splice(symbol.src, strip=True)
                        symbols_included.add(symbol_name)
                        traverse(symbol)
                    elif isinstance(symbol, (int, str, bool, constexpr)):
                        compile_wrapper.newline()
                        if isinstance(symbol, constexpr):
                            symbol_str = f"tl.constexpr({symbol.value!r})"
                        else:
                            symbol_str = f"{symbol!r}"
                        if annotation := global_annotations.get(symbol_name):
                            annotation_code = ""
                            if isinstance(annotation, type):
                                annotation_code = (
                                    f": {annotation.__module__}.{annotation.__name__}"
                                )
                            else:
                                annotation_code = f": {annotation!r}"
                            compile_wrapper.writeline(
                                f"{symbol_name}{annotation_code} = {symbol_str}"
                            )
                        else:
                            compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
                        symbols_included.add(symbol_name)
                    elif (
                        symbol_name in unqualified_loads
                        and symbol_name != "tl"  # already imported
                        and hasattr(symbol, "__module__")
                        # only codegen imports from triton; JITFunctions
                        # imported from other modules will be codegened
                        # in the separate branch above
                        and symbol.__module__.startswith("triton")
                    ):
                        # a global symbol imported from triton is referenced
                        # without module qualification (i.e., `store` instead
                        # of `tl.store`): need to codegen an import
                        compile_wrapper.writeline(
                            f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}"
                        )
                        symbols_included.add(symbol_name)

        traverse(kernel)

        current_device = V.graph.scheduler.get_current_device_or_throw()
        compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
        _, lineno = inspect.getsourcelines(kernel.fn)
        srcfile = inspect.getsourcefile(kernel.fn)
        metadata = f"# Original path: {srcfile}:{lineno}"
        self.define_kernel(
            name,
            compile_wrapper.getvalue(),
            metadata,
        )
        return name, triton_meta

    def generate_numel_expr(self, kernel_name: str, tree):
        expr = f"{kernel_name}_{tree.prefix}numel"
        if (expr, V.graph) not in self.kernel_numel_expr:
            # declare expr once in each graph (scope)
            self.kernel_numel_expr.add((expr, V.graph))
            self.writeline(
                f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
            )
        else:
            self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}")
        # We can get symbolic expressions here, like s0*64
        # It is fine to have them here, but we need to handle them correctly as their own type
        # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
        # scalars as well.
        # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
        # constant now, need type info. I agree, this needs type info, and while this is not true type info
        # it suffices as a type hint for the purposes of producing the correct code for this type.
        return SymbolicCallArg(expr, tree.numel)

    def generate_workspace_allocation(self, nbytes, device, zero_fill):
        line = self.make_allocation(
            "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
        )
        self.writeline(line)
        if zero_fill:
            self.writeline(f"workspace.zero_(){self.ending}")

    def wrap_kernel_call(self, name, call_args):
        return f"{name}({', '.join(call_args)}){self.ending}"

    def generate_profiler_mark_wrapper_call(self, stack):
        self.wrapper_call.writeline("from torch.profiler import record_function")
        self.wrapper_call.writeline(
            f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):"
        )
        stack.enter_context(self.wrapper_call.indent())

    def generate_start_graph(self):
        self.wrapper_call.writeline("start_graph()")

    def generate_end_graph(self):
        self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})")

    def generate_reset_kernel_saved_flags(self):
        self.wrapper_call.splice(
            f"""
            for kernel in globals().values():
                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
                    kernel.cuda_kernel_saved = False
            """
        )

    def generate_save_uncompiled_kernels(self):
        """
        Precompile and save the CUBINs of the Triton kernels that haven't
        been precompiled and saved as a side effect of running the generated
        JIT model (Python wrapper). This can happen when the model contains
        control flow: only one pass through the control flow operators covers
        the kernels that are saved, the remaining kernels are not launched,
        hence not saved. The main purpose of this codegen is to compile and
        save the Triton kernels outside the active control flow path for
        subsequent AOTInductor code generation and compilation.
        """
        self.wrapper_call.splice(
            f"""
            for kernel in globals().values():
                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
                    if not kernel.cuda_kernel_saved:
                        if len(kernel.launchers) == 0:
                            kernel.precompile()
                        kernel.save_cuda_kernel(
                            grid=(0, 0, 0), # use dummy grid
                            stream="stream", # use dummy stream
                            launcher=kernel.launchers[0],
                        )
            """
        )

    def generate_default_grid(self, name: str, grid_args: List[Any]):
        return grid_args

    def generate_kernel_call(
        self,
        name,
        call_args,
        grid=None,
        device_index=None,
        cuda=True,
        triton=True,
        arg_types=None,
        grid_fn: str = "grid",
        triton_meta=None,
    ):
        """
        Generates kernel call code.

        cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.

        triton: Defines whether the GPU backend uses Triton for codegen.
                Otherwise it uses the CUDA language for codegen.
                Only valid when cuda == True.
        """
        if cuda:
            call_args_str = ", ".join(pexpr(item) for item in call_args)
            current_device = V.graph.scheduler.get_current_device_or_throw()
            stream_name = self.write_get_raw_stream(current_device.index, V.graph)
            if triton:
                grid_str = ", ".join(pexpr(item) for item in grid)
                grid_str = f"{grid_fn}({grid_str})"
                self.writeline(
                    f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
                )
            else:
                stream_ptr = f"c_void_p({stream_name})"
                self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})")
        else:
            self.writeline(self.wrap_kernel_call(name, call_args))

    def writeline(self, line):
        self.lines.append(line)

    def writelines(self, lines):
        for line in lines:
            self.writeline(line)

    def enter_context(self, ctx):
        self.lines.append(LineContext(ctx))

  1192. def val_to_arg_str(self, s, type_=None):
  1193. from torch.utils._triton import dtype_to_string, has_triton_package
  1194. if has_triton_package():
  1195. import triton
  1196. if isinstance(s, SymTypes):
  1197. return pexpr(s.node.expr)
  1198. elif isinstance(s, sympy.Expr):
  1199. return pexpr(s)
  1200. elif isinstance(s, (tuple, list)):
  1201. @dataclasses.dataclass
  1202. class Shim:
  1203. ref: Any
  1204. def __repr__(self):
  1205. return self.ref
  1206. return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
  1207. elif isinstance(s, torch._ops.OpOverload):
  1208. return _get_qualified_name(s)
  1209. elif isinstance(s, (ir.Buffer, ReinterpretView)):
  1210. return s.codegen_reference()
  1211. elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined]
  1212. return dtype_to_string(s)
  1213. else:
  1214. return repr(s)

    # The following methods are for memory management

    def make_buffer_allocation(self, buffer):
        device = buffer.get_device()
        dtype = buffer.get_dtype()
        shape = tuple(buffer.get_size())
        stride = tuple(buffer.get_stride())
        return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)

    def make_allocation(self, name, device, dtype, shape, stride):
        if device.type in ("cpu", "cuda"):
            # optimized path for faster allocations, saving ~2us versus the stuff below
            return (
                f"{name} = empty_strided_{device.type}("
                f"{self.codegen_shape_tuple(shape)}, "
                f"{self.codegen_shape_tuple(stride)}, "
                f"{dtype})"
            )
        # all other devices:
        return (
            f"{name} = empty_strided("
            f"{self.codegen_shape_tuple(shape)}, "
            f"{self.codegen_shape_tuple(stride)}, "
            f"device='{device.type}', dtype={dtype})"
        )
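    # Illustrative output of make_allocation (buffer name, sizes, and the non-cuda
    # device are hypothetical):
    #   cpu/cuda fast path: buf0 = empty_strided_cuda((8, 16), (16, 1), torch.float32)
    #   other devices:      buf0 = empty_strided((8, 16), (16, 1), device='xpu', dtype=torch.float32)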

    def make_tensor_alias(self, new_name, old_name, comment=""):
        return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"

    def make_buffer_free(self, buffer):
        return f"del {buffer.get_name()}"

    def make_free_by_names(self, names_to_del: List[str]):
        return f"del {', '.join(name for name in names_to_del)}"

    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
        return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"

    def make_buffer_reuse(self, old, new, delete_old: bool):
        assert old.get_dtype() == new.get_dtype()
        old_name = old.get_name()
        new_name = new.get_name()
        del_line = ";"
        if old_name not in V.graph.get_output_names() and delete_old:
            del_line = f"; {self.make_buffer_free(old)}"

        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
            if old_name in self.stack_allocated_buffers:
                self.stack_allocated_buffers[new_name] = new
            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)

        reinterpret_view = self.codegen_reinterpret_view(
            old, new.get_size(), new.get_stride(), 0, self.wrapper_call
        )
        if reinterpret_view in self.stack_allocated_buffers:
            self.stack_allocated_buffers[new_name] = new
        return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse"
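    # Roughly what the Python wrapper emits for a reuse (illustrative only;
    # buffer names, sizes, and the exact reinterpret helper are hypothetical):
    #   exact reuse:    buf3 = buf0; del buf0  # reuse
    #   layout change:  buf3 = reinterpret_tensor(buf0, (4, 8), (8, 1), 0); del buf0  # reuse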

    def codegen_deferred_allocation(self, name, layout):
        self.writeline(
            DeferredLine(
                name,
                f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} "
                f"{self.comment} alias",
            )
        )

    def codegen_allocation(self, buffer):
        name = buffer.get_name()
        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)
        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
            return
        if isinstance(layout, ir.NonOwningLayout):
            assert isinstance(
                layout.view, ir.ReinterpretView
            ), f"unexpected {type(layout.view)}: {layout.view}"
            self.codegen_allocation(layout.view.data)
            self.codegen_deferred_allocation(name, layout)
            return

        self.writeline(AllocateLine(self, buffer))

    def codegen_free(self, buffer):
        assert (
            buffer.get_workspace_size() == 0
        ), "Only support zero workspace size for now!"
        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.writeline(self.make_buffer_free(buffer))
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        self.writeline(FreeIfNotReusedLine(self, buffer))

    def can_reuse(self, input_buffer, output_buffer=None):
        name = input_buffer.get_name()
        if (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in V.graph.torchbind_constants
            or name in V.graph.never_reuse_buffers
            or name in self.freed
        ):
            return False

        return True

    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.writeline(ReuseLine(self, input_buffer, output_buffer))

    def codegen_unbacked_symbol_decl(self, symbol):
        name = str(symbol)
        if name in self.unbacked_symbol_decls:
            return name
        else:
            # When in CppWrapperCpu, we should only generate the declaration once
            self.unbacked_symbol_decls.add(name)
            return self.declare + name

    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
            self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")

    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
        for inner_output, outer_output in zip(
            subgraph.graph.graph_outputs, outer_outputs
        ):
            self.writeline(
                f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
            )

    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
        try:
            self.push_codegened_graph(subgraph.graph)
            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
            parent_graph = V.graph
            with V.set_graph_handler(subgraph.graph):
                subgraph.graph.codegen_subgraph(
                    parent_graph=parent_graph,
                )
            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
        finally:
            self.pop_codegened_graph()

    def codegen_conditional(self, conditional):
        name = conditional.get_name()

        self.writeline(f"{name} = [None] * {len(conditional.outputs)}")

        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
        outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]

        predicate = conditional.predicate.codegen_reference()
        if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
            # move the Tensor predicate to host
            predicate = f"{predicate}.item()"

        self.writeline(f"if {predicate}:")
        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
        self.writeline("else:")
        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
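    # Illustrative shape of the generated wrapper code for a conditional
    # (not from the original source; buffer and subgraph names hypothetical):
    #   buf2 = [None] * 2
    #   if buf0.item():
    #       # subgraph: true_graph_0 writes buf2[0], buf2[1]
    #   else:
    #       # subgraph: false_graph_0 writes buf2[0], buf2[1]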

    def codegen_while_loop(self, while_loop):
        name = while_loop.get_name()
        outer_carried_inputs = [
            buf.codegen_reference() for buf in while_loop.carried_inputs
        ]
        outer_additional_inputs = [
            buf.codegen_reference() for buf in while_loop.additional_inputs
        ]

        self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
        for i, inp in enumerate(outer_carried_inputs):
            # set the initial state before the loop
            self.writeline(f"{name}[{i}] = {inp}")

        cond_outer_inputs = [
            *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
            *outer_additional_inputs,
        ]
        cond_outer_outputs = [f"{name}_cond_result"]
        body_outer_inputs = list(
            cond_outer_inputs
        )  # same inputs for cond_fn and body_fn

        # Carry over the state from body_fn. Note: we only carry over
        # the carried_inputs part of the inputs; the additional ones
        # are passed in unchanged.
        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]

        self.writeline("while True:")
        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
        self.codegen_subgraph(
            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
        )
        self.writeline(
            f"if not {cond_outer_outputs[0]}.item(): break"
        )  # condition doesn't hold
        self.writeline(ExitSubgraphLine(self))
        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
        self.codegen_subgraph(
            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
        )
        self.writeline(ExitSubgraphLine(self))
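    # Illustrative shape of the generated wrapper code for a while_loop
    # (not from the original source; names hypothetical, one carried input):
    #   buf5 = [None] * 1
    #   buf5[0] = arg0_1
    #   while True:
    #       # subgraph: cond_graph_0 writes buf5_cond_result
    #       if not buf5_cond_result.item(): break
    #       # subgraph: body_graph_0 writes the new state back into buf5[0]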

    @staticmethod
    def statically_known_int_or_none(x):
        try:
            if getattr(x, "free_symbols", None):
                # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but
                # the actual codegen will still generate the full expression here.
                return None
            val = V.graph._shape_env._maybe_evaluate_static(x)
            return int(val)
        except Exception:
            return None

    @staticmethod
    def statically_known_list_of_ints_or_none(lst):
        result = []
        for x in lst:
            num = WrapperCodeGen.statically_known_int_or_none(x)
            if num is None:
                return None
            result.append(num)
        return result

    @staticmethod
    def is_statically_known_list_of_ints(lst):
        return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None

    @staticmethod
    def static_shape_for_buffer_or_none(buffer):
        return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size())

    @staticmethod
    def can_prove_buffer_has_static_shape(buffer):
        return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None
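    # Illustrative behaviour of the static-shape helpers above (hypothetical values;
    # s0 stands for a free/unbacked symbol):
    #   statically_known_int_or_none(sympy.Integer(16)) -> 16
    #   statically_known_int_or_none(s0)                -> None
    #   statically_known_list_of_ints_or_none([4, s0])  -> None
    #   static_shape_for_buffer_or_none(buf)            -> [4, 16] when every dim is static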