common.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import dataclasses
  4. import functools
  5. import itertools
  6. import logging
  7. import math
  8. import operator
  9. import re
  10. from itertools import chain
  11. from typing import (
  12. Any,
  13. Callable,
  14. ClassVar,
  15. Dict,
  16. List,
  17. NamedTuple,
  18. Optional,
  19. Set,
  20. Tuple,
  21. Union,
  22. )
  23. import sympy
  24. from sympy.printing.printer import Printer
  25. import torch
  26. import torch.fx
  27. from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
  28. from torch.utils import _pytree as pytree
  29. from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
  30. from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
  31. from .. import config, metrics
  32. from ..utils import (
  33. DeferredLineBase,
  34. generate_assert,
  35. IndentedBuffer,
  36. sympy_dot,
  37. sympy_subs,
  38. unique,
  39. )
  40. from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
  41. schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
  42. def data_type_logger(msg):
  43. if schedule_log.isEnabledFor(logging.DEBUG):
  44. schedule_log.debug("Data type propagation: %s", msg)
  45. @dataclasses.dataclass
  46. class WorkspaceArg:
  47. """A temporary buffer used for a single kernel, then discarded.
  48. Not registered as a traditional buffer since there are no users,
  49. so it would be dead code eliminated.
  50. """
  51. nbytes: sympy.Expr
  52. zero_fill: bool
  53. @dataclasses.dataclass
  54. class TensorArg:
  55. name: str
  56. buffer: str
  57. dtype: torch.dtype
  58. offset: sympy.Expr = sympy.Integer(0)
  59. @dataclasses.dataclass
  60. class SizeArg:
  61. name: str
  62. expr: sympy.Expr
  63. @dataclasses.dataclass
  64. class DeviceCodegen:
  65. scheduling: Any
  66. wrapper_codegen: type
  67. cpp_wrapper_codegen: type = type(None)
  68. KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
  69. device_codegens: Dict[str, DeviceCodegen] = {}
  70. class DeviceOpOverrides:
  71. def import_get_raw_stream_as(self, name):
  72. raise NotImplementedError
  73. def set_device(self, device_idx):
  74. raise NotImplementedError
  75. def synchronize(self):
  76. raise NotImplementedError
  77. def device_guard(self, device_idx):
  78. raise NotImplementedError
  79. device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
  83. #
  84. # Kernel code generation is determined by different Scheduling. Consequently, a new
  85. # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
  86. # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
  87. #
  88. # For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
  89. # that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
  90. # and override specific member functions to create backend-specific Python wrapper code.
  91. #
  92. # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
  93. # of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
  94. # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
  95. # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
  96. # register_backend_for_device, to equip a new backend at runtime.
  97. #
  98. # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
  99. # This backend can be used as a reference:
  100. # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
  101. def register_backend_for_device(
  102. device: str,
  103. device_scheduling: Any,
  104. device_wrapper_codegen: type,
  105. device_cpp_wrapper_codegen: type = type(None),
  106. ):
  107. device_codegens[device] = DeviceCodegen(
  108. device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
  109. )
  110. def get_scheduling_for_device(device: str):
  111. return device_codegens[device].scheduling if device in device_codegens else None
  112. def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
  113. if device in device_codegens:
  114. wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
  115. return (
  116. wrapper_codegen_obj.cpp_wrapper_codegen
  117. if cpp_wrapper
  118. else wrapper_codegen_obj.wrapper_codegen
  119. )
  120. else:
  121. return None
  122. def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
  123. from ..ir import FlexibleLayout
  124. # added contiguous index prevents reordering
  125. return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
  126. def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
  127. device_op_overrides_dict[device] = device_op_overrides
  128. def get_device_op_overrides(device: str):
  129. assert isinstance(device, str)
  130. if not device_op_overrides_dict.keys():
  131. from .cuda import device_op_overrides # noqa: F401
  132. from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
  133. if device in device_op_overrides_dict.keys():
  134. return device_op_overrides_dict[device]
  135. @functools.lru_cache(None)
  136. def boolean_ops():
  137. return (
  138. "is_inf",
  139. "is_nan",
  140. "bitwise_xor",
  141. "logical_not",
  142. "signbit",
  143. "le",
  144. "lt",
  145. "ge",
  146. "gt",
  147. "eq",
  148. "ne",
  149. )
  150. DTYPE_TO_COMPUTATION_DTYPE = {
  151. torch.bfloat16: torch.float,
  152. torch.float16: torch.float,
  153. **{
  154. dtype: dtype
  155. for dtype in [
  156. torch.bool,
  157. torch.float32,
  158. torch.float64,
  159. torch.int8,
  160. torch.int16,
  161. torch.int32,
  162. torch.int64,
  163. torch.uint8,
  164. torch.uint16,
  165. torch.uint32,
  166. torch.uint64,
  167. ]
  168. },
  169. }
  170. class DataTypePropagation:
  171. def __init__(self, body) -> None:
  172. self.body = body
  173. self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
  174. "root": body.root_block.graph
  175. }
  176. for k, v in body.subblocks.items():
  177. self.graphs[k] = v.graph
  178. def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
  179. inputs = node.all_input_nodes
  180. input_nodes = [
  181. n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
  182. ]
  183. if len(input_nodes) == 0:
  184. return None
  185. all_input_nodes_propagated = all(
  186. OptimizationContext.key in n.meta
  187. and n.meta[OptimizationContext.key].dtype is not None
  188. for n in input_nodes
  189. )
  190. if not all_input_nodes_propagated:
  191. return None
  192. return functools.reduce(
  193. torch.promote_types,
  194. [n.meta[OptimizationContext.key].dtype for n in input_nodes],
  195. )
  196. def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
  197. sub_graph = self.graphs[node.target]
  198. dtype = self.propagate_graph(sub_graph)
  199. assert dtype
  200. return dtype
  201. def deduce_node_dtype(self, node: torch.fx.Node):
  202. if node.target in boolean_ops():
  203. return torch.bool
  204. if node.op == "placeholder":
  205. return None
  206. if node.target == "output":
  207. # we can infer output node if it only have 1 arg
  208. if len(node.args) != 1:
  209. return None
  210. if node.target in (
  211. "to_dtype",
  212. "index_expr",
  213. ):
  214. return node.args[-1]
  215. if node.target in (
  216. "rand",
  217. "randn",
  218. ):
  219. return torch.float
  220. if node.target in (
  221. "get_index",
  222. "index_expr",
  223. "randint64",
  224. ):
  225. return torch.int64
  226. if node.target in (
  227. "load",
  228. "store",
  229. "store_reduction",
  230. ):
  231. buf_name = node.args[1]
  232. return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
  233. if node.target == operator.getitem:
  234. return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
  235. assert isinstance(node.target, str)
  236. if node.target == "reduction":
  237. return node.args[1]
  238. if node.target == "constant":
  239. return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index]
  240. if node.target.startswith("masked_subblock"):
  241. return self.deduce_node_dtype_by_subgraph(node)
  242. return self.deduce_node_dtype_by_inputs(node)
  243. def propagate_graph(self, graph: torch.fx.Graph):
  244. assert graph.nodes
  245. graph_dtype = None
  246. # For masked_subblock, we use output's dtype to represent
  247. # the dtype of this subgraph. For other cases, graph_dtype
  248. # might be None
  249. for node in graph.nodes:
  250. if OptimizationContext.key in node.meta:
  251. opt_ctx = node.meta[OptimizationContext.key]
  252. else:
  253. opt_ctx = OptimizationContext()
  254. opt_ctx.dtype = self.deduce_node_dtype(node)
  255. node.meta[OptimizationContext.key] = opt_ctx
  256. if node.target == "output":
  257. graph_dtype = opt_ctx.dtype
  258. return graph_dtype
  259. def propagate(self):
  260. self.propagate_graph(self.graphs["root"])
  261. @classmethod
  262. def propagate_loopbody(cls, body):
  263. return cls(body).propagate()
  264. @classmethod
  265. def propagate_scheduler_node(cls, node):
  266. from ..ir import LoopBody
  267. from ..scheduler import SchedulerNode
  268. assert isinstance(node, SchedulerNode)
  269. assert isinstance(node._body, LoopBody)
  270. DataTypePropagation.propagate_loopbody(node._body)
  271. # This printer contains rules that are supposed to be generic for both C/C++ and
  272. # Python
  273. class ExprPrinter(Printer):
  274. @staticmethod
  275. def paren(string):
  276. def all_in_parens(string):
  277. if string[0] != "(" or len(string) < 2:
  278. return False
  279. count = 1
  280. for i, char in enumerate(string[1:]):
  281. if char == "(":
  282. count += 1
  283. elif char == ")":
  284. count -= 1
  285. if count == 0 and i != len(string) - 2:
  286. return False
  287. assert count == 0
  288. return True
  289. if (
  290. isinstance(string, CSEVariable)
  291. or re.match(r"^[a-z0-9_.]+$", string, re.I)
  292. or re.match(r"^\([^)]*\)$", string, re.I)
  293. or string == ""
  294. ):
  295. return string
  296. # don't put extra parens for strings that are already wrapped in parens
  297. if all_in_parens(string):
  298. return string
  299. return f"({string})"
  300. def _print_Relational(self, expr):
  301. return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
  302. def _print_Mul(self, expr):
  303. return "*".join(map(self.paren, map(self._print, expr.args)))
  304. def _print_Add(self, expr):
  305. return " + ".join(map(self.paren, map(self._print, expr.args)))
  306. # NB: this is OK to put here, because Mod is only defined for positive
  307. # numbers, and so across C/Python its behavior is consistent
  308. def _print_Mod(self, expr):
  309. return " % ".join(map(self.paren, map(self._print, expr.args)))
  310. def _print_FloatTrueDiv(self, expr):
  311. lhs, rhs = expr.args
  312. return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
  313. def _print_CleanDiv(self, expr):
  314. return self._print_FloorDiv(expr)
  315. def _print_GreaterThan(self, expr):
  316. # GreaterThan: >=
  317. # StrictlyGreaterThan: >
  318. # Go figure...
  319. return " >= ".join(map(self.paren, map(self._print, expr.args)))
  320. # NB: The C implementation is injected into codegen at
  321. # torch/_inductor/codegen/wrapper.py
  322. def _print_align(self, expr):
  323. assert len(expr.args) == 1
  324. return f"align({self._print(expr.args[0])})"
  325. # This must be implemented because sympy will collect x * x into Pow(x, 2), without
  326. # any explicit intervention. We print it just like x * x, notably, we
  327. # never generate sympy.Pow with floats.
  328. #
  329. # NB: this pow by natural, you should never have used builtin sympy.pow
  330. # for FloatPow, and a symbolic exponent should be PowByNatural. These
  331. # means exp is guaranteed to be integer.
  332. def _print_Pow(self, expr):
  333. base, exp = expr.args
  334. base = self._print(base)
  335. assert exp == int(exp), exp
  336. exp = int(exp)
  337. assert exp >= 0
  338. if exp > 0:
  339. return "*".join([self.paren(base)] * exp)
  340. else: # exp == 0
  341. return "1"
  342. # Explicit NotImplemented functions are to prevent default sympy printing
  343. # behavior, which will just barf out ToFloat(...) to your IR. The error
  344. # message is better here because it tells you which printer class it needs
  345. # to go in.
  346. def _print_ToFloat(self, expr):
  347. raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
  348. def _print_Infinity(self, expr):
  349. raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
  350. def _print_NegativeInfinity(self, expr):
  351. raise NotImplementedError(
  352. f"_print_NegativeInfinity not implemented for {type(self)}"
  353. )
  354. def _print_FloorDiv(self, expr):
  355. raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
  356. def _print_PythonMod(self, expr):
  357. raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
  358. def _print_IntTrueDiv(self, expr):
  359. raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
  360. def _print_PowByNatural(self, expr):
  361. raise NotImplementedError(
  362. f"_print_PowByNatural not implemented for {type(self)}"
  363. )
  364. def _print_FloatPow(self, expr):
  365. raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
  366. def _print_TruncToInt(self, expr):
  367. raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
  368. def _print_RoundToInt(self, expr):
  369. raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
  370. def _print_RoundDecimal(self, expr):
  371. raise NotImplementedError(
  372. f"_print_RoundDecimal not implemented for {type(self)}"
  373. )
  374. # NB: Some float operations are INTENTIONALLY not implemented for
  375. # printers. You can implement them as a quick unblock, but it is better
  376. # to ask yourself why we haven't done this computation in the Tensor
  377. # universe instead
  378. def _print_TruncToFloat(self, expr):
  379. raise NotImplementedError(
  380. f"_print_TruncToFloat not implemented for {type(self)}"
  381. )
  382. def doprint(self, expr, *, simplify: bool = True):
  383. # TODO: why are people passing strings to the printer here :think:
  384. if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
  385. expr = V.graph.sizevars.simplify(expr)
  386. return super().doprint(expr)
  387. class PythonPrinter(ExprPrinter):
  388. def _print_ToFloat(self, expr):
  389. assert len(expr.args) == 1
  390. return f"float({self._print(expr.args[0])})"
  391. def _print_ModularIndexing(self, expr):
  392. x, div, mod = expr.args
  393. x = self.paren(self.doprint(x))
  394. div = self.paren(self.doprint(div))
  395. mod = self.paren(self.doprint(mod))
  396. if div != "1":
  397. x = f"({x} // {div})"
  398. return f"{x} % {mod}"
  399. def _print_Infinity(self, expr):
  400. return "math.inf"
  401. def _print_NegativeInfinity(self, expr):
  402. return "-math.inf"
  403. # WARNING: this is dangerous for Triton, which has C-style modulus
  404. def _print_PythonMod(self, expr):
  405. return " % ".join(map(self.paren, map(self._print, expr.args)))
  406. # WARNING: this is dangerous for Triton, which has C-style modulus
  407. def _print_FloorDiv(self, expr):
  408. x, div = expr.args
  409. x = self.paren(self.doprint(x))
  410. div = self.paren(self.doprint(div))
  411. return f"({x} // {div})"
  412. # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
  413. # does a special algorithm
  414. def _print_IntTrueDiv(self, expr):
  415. lhs, rhs = expr.args
  416. return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
  417. def _helper_sqrt(self, expr):
  418. return f"math.sqrt({self._print(expr)})"
  419. def _print_OpaqueUnaryFn_sqrt(self, expr):
  420. return self._helper_sqrt(expr.args[0])
  421. def _print_FloatPow(self, expr):
  422. base, exp = expr.args
  423. return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
  424. # TODO: Not sure this works with Triton, even when base/exp are integral
  425. def _print_PowByNatural(self, expr):
  426. base, exp = expr.args
  427. return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
  428. def _print_floor(self, expr):
  429. assert len(expr.args) == 1
  430. return f"math.floor({self._print(expr.args[0])})"
  431. def _print_FloorToInt(self, expr):
  432. assert len(expr.args) == 1
  433. return f"math.floor({self._print(expr.args[0])})"
  434. def _print_TruncToInt(self, expr):
  435. assert len(expr.args) == 1
  436. # This also could have been int(), they'll do the same thing for float
  437. return f"math.trunc({self._print(expr.args[0])})"
  438. def _print_ceiling(self, expr):
  439. assert len(expr.args) == 1
  440. return f"math.ceil({self._print(expr.args[0])})"
  441. def _print_CeilToInt(self, expr):
  442. assert len(expr.args) == 1
  443. return f"math.ceil({self._print(expr.args[0])})"
  444. def _print_Abs(self, expr):
  445. assert len(expr.args) == 1
  446. return f"abs({self._print(expr.args[0])})"
  447. # NB: It's expected that we've made explicit any promotion in the sympy
  448. # expression, so it doesn't matter that Python max/min doesn't perform
  449. # promotion
  450. def _print_Max(self, expr):
  451. assert len(expr.args) >= 2
  452. return f"max({', '.join(map(self._print, expr.args))})"
  453. def _print_Min(self, expr):
  454. assert len(expr.args) >= 2
  455. return f"min({', '.join(map(self._print, expr.args))})"
  456. def _print_OpaqueUnaryFn_cos(self, expr):
  457. assert len(expr.args) == 1
  458. return f"math.cos({self._print(expr.args[0])})"
  459. def _print_OpaqueUnaryFn_cosh(self, expr):
  460. assert len(expr.args) == 1
  461. return f"math.cosh({self._print(expr.args[0])})"
  462. def _print_OpaqueUnaryFn_acos(self, expr):
  463. assert len(expr.args) == 1
  464. return f"math.acos({self._print(expr.args[0])})"
  465. def _print_OpaqueUnaryFn_sin(self, expr):
  466. assert len(expr.args) == 1
  467. return f"math.sin({self._print(expr.args[0])})"
  468. def _print_OpaqueUnaryFn_sinh(self, expr):
  469. assert len(expr.args) == 1
  470. return f"math.sinh({self._print(expr.args[0])})"
  471. def _print_OpaqueUnaryFn_asin(self, expr):
  472. assert len(expr.args) == 1
  473. return f"math.asin({self._print(expr.args[0])})"
  474. def _print_OpaqueUnaryFn_tan(self, expr):
  475. assert len(expr.args) == 1
  476. return f"math.tan({self._print(expr.args[0])})"
  477. def _print_OpaqueUnaryFn_tanh(self, expr):
  478. assert len(expr.args) == 1
  479. return f"math.tanh({self._print(expr.args[0])})"
  480. def _print_OpaqueUnaryFn_atan(self, expr):
  481. assert len(expr.args) == 1
  482. return f"math.atan({self._print(expr.args[0])})"
  483. def _print_RoundToInt(self, expr):
  484. assert len(expr.args) == 1
  485. return f"round({self._print(expr.args[0])})"
  486. def _print_RoundDecimal(self, expr):
  487. assert len(expr.args) == 2
  488. number, ndigits = expr.args
  489. assert isinstance(ndigits, sympy.Integer)
  490. return f"round({self._print(number)}, {ndigits})"
  491. class OpOverrides:
  492. def __init__(self, parent):
  493. super().__init__()
  494. self._parent = parent
  495. def __getattr__(self, item):
  496. return getattr(self._parent, item)
  497. @staticmethod
  498. def identity(value):
  499. # used to trigger cse
  500. return value
  501. @staticmethod
  502. def constant(value, dtype):
  503. return repr(value)
  504. @staticmethod
  505. def reciprocal(x):
  506. return ops.truediv(ops.constant(1, torch.int32), x)
  507. @staticmethod
  508. def square(x):
  509. return ops.mul(x, x)
  510. @staticmethod
  511. def erfc(x):
  512. return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
  513. @staticmethod
  514. def erfcx(x):
  515. return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))
  516. @staticmethod
  517. def expm1(x):
  518. return ops.sub(ops.exp(x), ops.constant(1, torch.float32))
  519. @staticmethod
  520. def log10(x):
  521. return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))
  522. @staticmethod
  523. def log2(x):
  524. return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))
  525. @staticmethod
  526. def exp2(x):
  527. return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))
  528. @staticmethod
  529. def log1p(x):
  530. return ops.log(ops.add(x, ops.constant(1, torch.int32)))
  531. @staticmethod
  532. def sigmoid(x):
  533. one = ops.constant(1, torch.int32)
  534. return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
  535. @staticmethod
  536. def libdevice_sigmoid(x):
  537. one = ops.constant(1, torch.int32)
  538. return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
  539. @staticmethod
  540. def relu(x):
  541. return ops.maximum(x, ops.constant(0, torch.int32))
  542. @staticmethod
  543. def libdevice_abs(x):
  544. return ops.abs(x)
  545. @staticmethod
  546. def libdevice_sqrt(x):
  547. return ops.sqrt(x)
  548. @staticmethod
  549. def libdevice_cos(x):
  550. return ops.cos(x)
  551. @staticmethod
  552. def libdevice_sin(x):
  553. return ops.sin(x)
  554. @staticmethod
  555. def libdevice_log(x):
  556. return ops.log(x)
  557. @staticmethod
  558. def libdevice_exp(x):
  559. return ops.exp(x)
  560. @staticmethod
  561. def bitwise_not(x):
  562. return f"~{ExprPrinter.paren(x)}"
  563. @staticmethod
  564. def logical_not(a):
  565. return f"{ExprPrinter.paren(a)} == 0"
  566. @staticmethod
  567. def bitwise_and(x, y):
  568. return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
  569. @staticmethod
  570. def bitwise_or(x, y):
  571. return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
  572. @staticmethod
  573. def bitwise_xor(x, y):
  574. return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
  575. @staticmethod
  576. def bitwise_left_shift(x, y):
  577. return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
  578. @staticmethod
  579. def bitwise_right_shift(x, y):
  580. return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
  581. @staticmethod
  582. def remainder(a, b):
  583. r = ops.mod(a, b)
  584. cond = ops.and_(
  585. ops.ne(r, ops.constant(0, torch.int32)),
  586. ops.ne(ops.signbit(r), ops.signbit(b)),
  587. )
  588. return ops.where(cond, ops.add(r, b), r)
  589. @staticmethod
  590. def trunc_to_int(a, dtype):
  591. return ops.to_dtype(ops.trunc(a), dtype)
  592. @staticmethod
  593. def floor_to_int(a, dtype):
  594. return ops.to_dtype(ops.floor(a), dtype)
  595. @staticmethod
  596. def ceil_to_int(a, dtype):
  597. return ops.to_dtype(ops.ceil(a), dtype)
  598. @staticmethod
  599. def round_to_int(a, dtype):
  600. return ops.to_dtype(ops.round(a), dtype)
  601. @staticmethod
  602. def int_truediv(a, b):
  603. # TODO: this is wrong
  604. # TODO: an easy bandaid is to generate runtime asserts that it's
  605. # <= 2**53, which is when this equation is correct
  606. return ops.truediv(a, b)
  607. @staticmethod
  608. def load_seed(name, offset):
  609. return ops.load(name, sympy.Integer(offset))
  610. @classmethod
  611. def _initialize_pointwise_overrides(cls, target):
  612. assert target in {"triton", "cpp", "cppvec"}, target
  613. for funcname, data in pointwise_overrides_data.items():
  614. impl = getattr(data, target)
  615. if impl is None:
  616. continue
  617. setattr(cls, funcname, staticmethod(impl))
  618. @dataclasses.dataclass
  619. class OverridesData:
  620. name: str
  621. cpp: Callable[..., str]
  622. # None when not impl in libdevice/triton
  623. triton: Optional[Callable[..., str]] = None
  624. # None when not impl in aten/.../vec
  625. cppvec: Optional[Callable[..., str]] = None
  626. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
  627. ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  628. )
  629. # NB: if you add a new special function, don't forget to update
  630. # torch._inductor.ops_handler too
  631. pointwise_overrides_data: Dict[str, OverridesData] = dict(
  632. airy_ai=OverridesData(
  633. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  634. cpp=lambda x: f"airy_ai_forward({x})",
  635. name="special_airy_ai",
  636. ),
  637. bessel_j0=OverridesData(
  638. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  639. cpp=lambda x: f"bessel_j0_forward({x})",
  640. triton=lambda x: f"libdevice.j0({x})",
  641. name="special_bessel_j0",
  642. ),
  643. bessel_j1=OverridesData(
  644. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  645. cpp=lambda x: f"bessel_j1_forward({x})",
  646. triton=lambda x: f"libdevice.j1({x})",
  647. name="special_bessel_j1",
  648. ),
  649. bessel_y0=OverridesData(
  650. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  651. cpp=lambda x: f"bessel_y0_forward({x})",
  652. triton=lambda x: f"libdevice.y0({x})",
  653. name="special_bessel_y0",
  654. ),
  655. bessel_y1=OverridesData(
  656. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  657. cpp=lambda x: f"bessel_y1_forward({x})",
  658. triton=lambda x: f"libdevice.y1({x})",
  659. name="special_bessel_y1",
  660. ),
  661. digamma=OverridesData(
  662. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  663. cpp=lambda x: f"calc_digamma({x})",
  664. cppvec=lambda x: f"{x}.digamma()",
  665. name="digamma",
  666. ),
  667. # no cpp nor triton implementation for entr, it is defined as decomposition
  668. # erf, erfc
  669. erfcx=OverridesData(
  670. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  671. cpp=lambda x: f"calc_erfcx({x})",
  672. triton=lambda x: f"libdevice.erfcx({x})",
  673. name="special_erfcx",
  674. ),
  675. fma=OverridesData(
  676. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  677. cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
  678. cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
  679. triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
  680. name="fma",
  681. ),
  682. # erfinv, exp2, expit, gammaln
  683. igamma=OverridesData(
  684. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  685. cpp=lambda x, y: f"calc_igamma({x}, {y})",
  686. name="igamma",
  687. ),
  688. igammac=OverridesData(
  689. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  690. cpp=lambda x, y: f"calc_igammac({x}, {y})",
  691. name="igammac",
  692. ),
  693. gammainc=OverridesData(
  694. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  695. cpp=lambda x, y: f"calc_igamma({x}, {y})",
  696. name="special_gammainc",
  697. ),
  698. gammaincc=OverridesData(
  699. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  700. cpp=lambda x, y: f"calc_igammac({x}, {y})",
  701. name="special_gammaincc",
  702. ),
  703. i0=OverridesData(
  704. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  705. cpp=lambda x: f"calc_i0({x})",
  706. triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
  707. cppvec=lambda x: f"{x}.i0()",
  708. name="i0",
  709. ),
  710. i0e=OverridesData(
  711. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  712. cpp=lambda x: f"calc_i0e({x})",
  713. cppvec=lambda x: f"{x}.i0e()",
  714. name="special_i0e",
  715. ),
  716. i1=OverridesData(
  717. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  718. cpp=lambda x: f"calc_i1({x})",
  719. triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
  720. name="special_i1",
  721. ),
  722. i1e=OverridesData(
  723. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  724. cpp=lambda x: f"calc_i1e({x})",
  725. name="special_i1e",
  726. ),
  727. log_ndtr=OverridesData(
  728. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  729. cpp=lambda x: f"calc_log_ndtr({x})",
  730. name="special_log_ndtr",
  731. ),
  732. # logit
  733. modified_bessel_i0=OverridesData(
  734. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  735. cpp=lambda x: f"modified_bessel_i0_forward({x})",
  736. triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
  737. name="special_modified_bessel_i0",
  738. ),
  739. modified_bessel_i1=OverridesData(
  740. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  741. cpp=lambda x: f"modified_bessel_i1_forward({x})",
  742. triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
  743. name="special_modified_bessel_i1",
  744. ),
  745. modified_bessel_k0=OverridesData(
  746. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  747. cpp=lambda x: f"modified_bessel_k0_forward({x})",
  748. name="special_modified_bessel_k0",
  749. ),
  750. modified_bessel_k1=OverridesData(
  751. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  752. cpp=lambda x: f"modified_bessel_k1_forward({x})",
  753. name="special_modified_bessel_k1",
  754. ),
  755. # multigamma
  756. ndtr=OverridesData(
  757. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  758. cpp=lambda x: f"calc_ndtr({x})",
  759. name="special_ndtr",
  760. ),
  761. ndtri=OverridesData(
  762. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  763. cpp=lambda x: f"calc_ndtri({x})",
  764. name="special_ndtri",
  765. ),
  766. polygamma=OverridesData(
  767. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  768. cpp=lambda x, y: f"calc_polygamma({y}, {x})",
  769. name="polygamma",
  770. ),
  771. # psi - alias to digamma
  772. # round
  773. scaled_modified_bessel_k0=OverridesData(
  774. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  775. cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
  776. name="special_scaled_modified_bessel_k0",
  777. ),
  778. scaled_modified_bessel_k1=OverridesData(
  779. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  780. cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
  781. name="special_scaled_modified_bessel_k1",
  782. ),
  783. # sinc
  784. spherical_bessel_j0=OverridesData(
  785. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  786. cpp=lambda x: f"spherical_bessel_j0_forward({x})",
  787. name="special_spherical_bessel_j0",
  788. ),
  789. zeta=OverridesData(
  790. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  791. cpp=lambda x, y: f"zeta({x}, {y})",
  792. name="special_zeta",
  793. ),
  794. chebyshev_polynomial_t=OverridesData(
  795. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  796. cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
  797. name="special_chebyshev_polynomial_t",
  798. ),
  799. chebyshev_polynomial_u=OverridesData(
  800. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  801. cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
  802. name="special_chebyshev_polynomial_u",
  803. ),
  804. chebyshev_polynomial_v=OverridesData(
  805. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  806. cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
  807. name="special_chebyshev_polynomial_v",
  808. ),
  809. chebyshev_polynomial_w=OverridesData(
  810. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  811. cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
  812. name="special_chebyshev_polynomial_w",
  813. ),
  814. legendre_polynomial_p=OverridesData(
  815. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  816. cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
  817. name="special_legendre_polynomial_p",
  818. ),
  819. shifted_chebyshev_polynomial_t=OverridesData(
  820. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  821. cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
  822. name="special_shifted_chebyshev_polynomial_t",
  823. ),
  824. shifted_chebyshev_polynomial_u=OverridesData(
  825. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  826. cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
  827. name="special_shifted_chebyshev_polynomial_u",
  828. ),
  829. shifted_chebyshev_polynomial_v=OverridesData(
  830. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  831. cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
  832. name="special_shifted_chebyshev_polynomial_v",
  833. ),
  834. shifted_chebyshev_polynomial_w=OverridesData(
  835. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  836. cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
  837. name="special_shifted_chebyshev_polynomial_w",
  838. ),
  839. hermite_polynomial_h=OverridesData(
  840. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  841. cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
  842. name="special_hermite_polynomial_h",
  843. ),
  844. hermite_polynomial_he=OverridesData(
  845. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  846. cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
  847. name="special_hermite_polynomial_he",
  848. ),
  849. laguerre_polynomial_l=OverridesData(
  850. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  851. cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
  852. name="special_laguerre_polynomial_l",
  853. ),
  854. )
  855. # Use mypy to check protocol implemented correctly
  856. def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
  857. return h
  858. class DeferredLine(DeferredLineBase):
  859. """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
  860. def __init__(self, name, line):
  861. super().__init__(line)
  862. self.name = name
  863. assert not isinstance(line, DeferredLineBase)
  864. def __call__(self):
  865. if all(
  866. self.name not in x
  867. for x in (
  868. V.graph.removed_buffers,
  869. V.kernel.removed_buffers,
  870. V.graph.inplaced_to_remove,
  871. V.kernel.inplaced_to_remove,
  872. )
  873. ):
  874. return self.line
  875. return None
  876. def _new_line(self, line):
  877. return DeferredLine(self.name, line)
  878. class BracesBuffer(IndentedBuffer):
  879. def indent(self, offset=1):
  880. @contextlib.contextmanager
  881. def ctx():
  882. for _ in range(offset):
  883. self.writeline("{")
  884. self._indent += 1
  885. for _ in range(-offset):
  886. self._indent -= 1
  887. self.writeline("}")
  888. yield
  889. for _ in range(-offset):
  890. self.writeline("{")
  891. self._indent += 1
  892. for _ in range(offset):
  893. self._indent -= 1
  894. self.writeline("}")
  895. return ctx()
  896. class InplacedBuffer(NamedTuple):
  897. inner_name: str
  898. other_names: List[str]
class KernelArgs:
    """Tracks the buffers and size variables passed to one generated kernel.

    Each outer name (a graph buffer name, or a size expression/symbol) is mapped
    to the parameter name used inside the kernel: ``in_ptrN`` / ``seedN`` for
    inputs, ``out_ptrN`` for outputs, ``in_out_ptrN`` for in-place buffers and
    ``ksN`` for size variables.  ``cpp_argdefs`` / ``python_argdefs`` render these
    mappings as parallel lists of parameter definitions and call-site arguments.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        # Return the inner name registered for ``name``; on first sight allocate
        # f"{prefix}{len(odict)}", so numbering follows the order in which names
        # were added to that particular dict (all prefixes share the count).
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        # outer buffer name -> inner parameter name
        self.input_buffers = dict()
        self.output_buffers = dict()
        # outer buffer name -> InplacedBuffer; every name aliased by make_inplace
        # maps to the same InplacedBuffer object.  Presumably entries can also be
        # replaced by a "REMOVED..." marker string elsewhere (see
        # _buffer_is_marked_removed) -- confirm against callers.
        self.inplace_buffers = dict()
        # outer size expression/symbol -> inner parameter name.  A caller-supplied
        # dict is adopted only if it is non-empty (because of ``or``).
        self.sizevars = sizevars or dict()
        # Accumulated WorkspaceArg from workspace(), or None if never requested.
        self.workspace_arg = None

    def __repr__(self):
        # Note: workspace_arg is not included in the repr.
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        # An entry counts as removed only if it is a string starting with
        # "REMOVED"; any non-string value (e.g. an InplacedBuffer) never does.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Register ``name`` as a kernel input and return its inner name.

        A buffer already registered as an output or in-place buffer reuses that
        inner name.  Names starting with "seed" are numbered as ``seedN`` rather
        than ``in_ptrN``.
        """
        if V.graph.scheduler:
            # Resolve through the scheduler's mutation_real_name mapping.
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Register ``name`` as a kernel output and return its inner name.

        In-place buffers reuse their ``in_out_ptrN`` name.
        """
        if V.graph.scheduler:
            # Resolve through the scheduler's mutation_real_name mapping.
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Alias ``output_name`` to ``input_name`` so both share one ``in_out_ptrN``.

        If ``input_name`` is already in place, ``output_name`` joins its existing
        InplacedBuffer.  Otherwise a new one is created, numbered by the count of
        InplacedBuffer entries after de-duplication with ``unique``.
        """
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` in the kernel's single workspace buffer.

        Returns ``("ws_ptr", offset)``, where ``offset`` is the workspace size
        before this request (0 for the first request).  The whole workspace is
        zero-filled if any request asked for it.
        """
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Register ``value`` as a size argument named ``name`` and return the name.

        An already registered ``value`` returns its existing name.  If ``name`` is
        already in use, it gets a numeric suffix equal to the number of existing
        names that start with ``name``.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Register a size variable and return its inner name (``ksN``).

        The special name "seed" is registered under and returned as "seed".
        """
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        # Outer names from input_buffers, output_buffers and sizevars, in that
        # order; inplace_buffers is not iterated here.
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        # Hook for subclasses: how a buffer is spelled at the call site.
        return buf

    def wrap_size_arg(self, size):
        # Hook for subclasses: how a size argument is spelled at the call site.
        return str(size)

    def cpp_argdefs(self):
        """Build C++ kernel arguments.

        Returns three parallel lists ``(arg_defs, call_args, arg_types)`` ordered
        as: in-place buffers, inputs, outputs, size variables.  Entries marked
        removed are skipped, and workspace arguments are not supported.
        """
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            # The last aliased name is the one passed at the call site.
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build Python/Triton kernel arguments.

        Returns ``(arg_defs, call_args, precompile_args, arg_types)`` ordered as:
        in-place buffers, inputs and outputs, size variables, then the workspace
        (if any).  Entries marked removed are skipped.
        """
        arg_defs = []
        call_args = []
        arg_types = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            # NOTE(review): nothing is appended to arg_types here, so when a
            # workspace exists arg_types is one shorter than arg_defs/call_args.
            # Confirm that callers zipping these lists tolerate this.
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        # Yield (registered inner name, in_out inner name) pairs for in-place
        # buffers whose aliased names were also registered as plain inputs or
        # outputs, skipping names queued in graph/kernel inplaced_to_remove.
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        # True when ``name`` is absent or marked removed in BOTH output_buffers
        # and inplace_buffers.
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
  1107. class CSEVariable:
  1108. """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
  1109. To do so, the backends can simply overload `Kernel.create_cse_var`
  1110. The "CSEVariable.update_on_args" method gives you a hook for annotations
  1111. See example of TritonCSEVariable in triton.py
  1112. """
  1113. def __init__(self, name, bounds: ValueRanges[Any]):
  1114. assert isinstance(bounds, ValueRanges)
  1115. self.name = name
  1116. self.bounds = bounds
  1117. self.use_count = 1 # track how many tims this expression is used
  1118. def __str__(self):
  1119. return self.name
  1120. def __hash__(self) -> int:
  1121. return hash(self.name)
  1122. def __eq__(self, other) -> bool:
  1123. return type(other) == type(self) and other.name == self.name
  1124. def update_on_args(self, name, args, kwargs):
  1125. pass
  1126. def __repr__(self):
  1127. return f"{self.__class__.__name__}({self.name!r})"
  1128. class CppWrapperKernelArgs(KernelArgs):
  1129. def wrap_ptr_arg(self, buf, dtype):
  1130. from .cpp_utils import DTYPE_TO_CPP
  1131. if config.abi_compatible:
  1132. # In the abi_compatible model, we just return the buf here.
  1133. # We will form correct call args later in wrapper.generate_kernel_all.
  1134. return buf
  1135. else:
  1136. return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
  1137. def wrap_size_arg(self, size):
  1138. return f"{size}"
  1139. class CSE:
  1140. """Common subexpression elimination"""
  1141. def __init__(
  1142. self,
  1143. prefix="",
  1144. suffix="",
  1145. name_prefix="tmp",
  1146. iter_buffers=None,
  1147. store_cache=None,
  1148. reduction_cache=None,
  1149. varname_map=None,
  1150. ):
  1151. self.prefix = prefix
  1152. self.suffix = suffix
  1153. self.cache = {}
  1154. self.name_prefix = name_prefix
  1155. self.store_cache = store_cache or {}
  1156. self.reduction_cache = reduction_cache or {}
  1157. self.iter_buffer_ids = iter_buffers or itertools.count()
  1158. self.invalidated_stores = set()
  1159. self.varname_map = varname_map or {}
  1160. def invalidate(self, keep_vars: Set[str]):
  1161. for name, tmp in list(self.store_cache.items()):
  1162. if tmp not in keep_vars:
  1163. del self.store_cache[name]
  1164. self.invalidated_stores.add(name)
  1165. self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
  1166. def clone(self):
  1167. # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
  1168. return CSE(
  1169. prefix=self.prefix,
  1170. suffix=self.suffix,
  1171. name_prefix=self.name_prefix,
  1172. iter_buffers=self.iter_buffer_ids,
  1173. store_cache=self.store_cache,
  1174. varname_map=self.varname_map,
  1175. )
  1176. def generate(
  1177. self,
  1178. buffer: IndentedBuffer,
  1179. expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
  1180. *,
  1181. bounds: ValueRanges[Any] = ValueRanges.unknown(),
  1182. write=True,
  1183. assignment=True,
  1184. ) -> CSEVariable:
  1185. if isinstance(expr, OpsValue):
  1186. expr = expr.value
  1187. assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
  1188. assert write or assignment
  1189. if isinstance(expr, CSEVariable):
  1190. # If the expressions were always created with all the information, we could
  1191. # assert expr.bounds == bounds, but sometimes the expression is created
  1192. # with the loose ValueRanges.unknown(), so we need to tighten the bounds
  1193. expr.bounds = expr.bounds.tighten(bounds)
  1194. expr.use_count += 1
  1195. return expr
  1196. cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
  1197. var = self.cache.get(cache_key, None)
  1198. if not var:
  1199. var = self.newvar(bounds)
  1200. self.cache[cache_key] = var
  1201. if write:
  1202. if V.kernel.current_node:
  1203. V.kernel.current_node.codegen_originating_info(
  1204. buffer, only_once=True
  1205. )
  1206. if isinstance(expr, IndentedBuffer):
  1207. if assignment:
  1208. buffer.writeline(f"{self.prefix}{var} =")
  1209. buffer.splice(expr)
  1210. buffer.writeline(self.suffix)
  1211. else:
  1212. if assignment:
  1213. line = f"{self.prefix}{var} = {expr}{self.suffix}"
  1214. else:
  1215. line = f"{expr}{self.suffix}"
  1216. buffer.writeline(line)
  1217. else:
  1218. var.bounds = var.bounds.tighten(bounds)
  1219. var.use_count += 1
  1220. return var
  1221. def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
  1222. var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
  1223. var = V.kernel.create_cse_var(var_name, bounds)
  1224. self.varname_map[var_name] = var
  1225. return var
  1226. class CodeGen:
  1227. def __init__(self):
  1228. super().__init__()
  1229. self.exit_stack = contextlib.ExitStack()
  1230. def __enter__(self):
  1231. self.exit_stack.__enter__()
  1232. return self
  1233. def __exit__(self, exc_type, exc_val, exc_tb):
  1234. self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
  1235. class ScopedDict:
  1236. def __init__(self, original_dict):
  1237. self.original_dict = original_dict
  1238. self.new_items = {}
  1239. def __getitem__(self, key):
  1240. if key in self.new_items:
  1241. return self.new_items[key]
  1242. return self.original_dict[key]
  1243. def __setitem__(self, key, value):
  1244. self.new_items[key] = value
  1245. def __contains__(self, key):
  1246. return key in self.new_items or key in self.original_dict
  1247. def get(self, key, default=None):
  1248. if key in self.new_items:
  1249. return self.new_items[key]
  1250. return self.original_dict.get(key, default)
  1251. class Kernel(CodeGen):
  1252. newvar_prefix = ""
  1253. suffix = ""
  1254. overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
  1255. # TODO: these look dead, but with all the getattr it's hard to tell...
  1256. load_format: None = None
  1257. store_format: None = None
  1258. def __init__(self, args=None, increase_kernel_count=True):
  1259. super().__init__()
  1260. if increase_kernel_count:
  1261. metrics.generated_kernel_count += 1
  1262. self.args = args or KernelArgs()
  1263. self.loads = IndentedBuffer()
  1264. self.compute = IndentedBuffer()
  1265. self.stores = IndentedBuffer()
  1266. self.num_load = 0
  1267. self.num_reduction = 0
  1268. self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
  1269. self.must_keep_buffers = set()
  1270. self.store_buffer_names = set()
  1271. self._load_mask = None
  1272. # set in set_current_node
  1273. self.current_node = None
  1274. self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
  1275. self.removed_buffers = set()
  1276. self.inplaced_to_remove = set()
  1277. # key: the buffer to write
  1278. # value: the buffer to read and whose memory can be reused for
  1279. # the buffer specified by key
  1280. self.inplace_update_buffers = dict()
  1281. # Set minimum number of elements processed per thread.
  1282. self.min_elem_per_thread = 1
  1283. self.kernel_name = None
  1284. @contextlib.contextmanager
  1285. def set_current_node(self, node):
  1286. prior = self.current_node
  1287. self.current_node = node
  1288. self.node_to_bounds = node._body.bounds().get_bounds()
  1289. try:
  1290. yield
  1291. finally:
  1292. self.current_node = prior
  1293. @contextlib.contextmanager
  1294. def swap_buffers(self, lb, cb=None, sb=None):
  1295. def scope_cse(cse):
  1296. new_cse = cse.clone()
  1297. new_cse.cache = ScopedDict(cse.cache)
  1298. new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
  1299. new_cse.store_cache = ScopedDict(cse.store_cache)
  1300. return new_cse
  1301. if cb is None:
  1302. cb = lb
  1303. loads = self.loads
  1304. compute = self.compute
  1305. stores = self.stores
  1306. cse = self.cse
  1307. self.loads = lb
  1308. self.compute = cb
  1309. self.stores = sb
  1310. self.cse = scope_cse(cse)
  1311. try:
  1312. yield
  1313. finally:
  1314. self.loads = loads
  1315. self.compute = compute
  1316. self.stores = stores
  1317. self.cse = cse
  1318. def load(self, name: str, index: sympy.Expr) -> CSEVariable:
  1319. raise NotImplementedError
  1320. def indirect_load(self, name: str, index: sympy.Expr):
  1321. """A load the depends on an index we have read"""
  1322. prior = self.loads
  1323. try:
  1324. # put the load in the compute section as it might have deps
  1325. self.loads = self.compute
  1326. return self.load(name, index)
  1327. finally:
  1328. self.loads = prior
  1329. def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
  1330. raise NotImplementedError
  1331. def store(
  1332. self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
  1333. ) -> None:
  1334. raise NotImplementedError
  1335. def reduction(
  1336. self,
  1337. dtype: torch.dtype,
  1338. src_dtype: torch.dtype,
  1339. reduction_type: ReductionType,
  1340. value: Union[CSEVariable, Tuple[CSEVariable, ...]],
  1341. ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
  1342. raise NotImplementedError
  1343. def scan(
  1344. self,
  1345. dtypes: Tuple[torch.dtype, ...],
  1346. combine_fn: Callable[
  1347. [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
  1348. ],
  1349. values: Tuple[CSEVariable, ...],
  1350. ) -> Tuple[CSEVariable, ...]:
  1351. raise NotImplementedError
  1352. def var_ranges(self):
  1353. raise NotImplementedError
  1354. def bucketize(
  1355. self,
  1356. values: CSEVariable,
  1357. offsets_name: str,
  1358. offsets_size: sympy.Expr,
  1359. indexing_dtype: torch.dtype,
  1360. right: bool,
  1361. ) -> CSEVariable:
  1362. """
  1363. See [Note: Inductor bucketize op]
  1364. """
  1365. raise NotImplementedError
  1366. @property
  1367. def assert_function(self) -> str:
  1368. raise NotImplementedError
  1369. def indirect_assert(
  1370. self,
  1371. var: Union[CSEVariable, str],
  1372. lower: Optional[str],
  1373. upper: Optional[str],
  1374. mask: Optional[str] = None,
  1375. ) -> str:
  1376. if isinstance(var, CSEVariable):
  1377. var = str(var)
  1378. assert isinstance(var, str)
  1379. assert lower is None or isinstance(lower, str)
  1380. assert upper is None or isinstance(upper, str)
  1381. if lower and upper:
  1382. # The conditions need to be in parens because of Python's operator precedence.
  1383. # It'd be less error-prone to use and/or/not, which is suported by triton
  1384. cond = f"({lower} <= {var}) & ({var} < {upper})"
  1385. cond_print = f"{lower} <= {var} < {upper}"
  1386. elif lower:
  1387. cond = f"{lower} <= {var}"
  1388. cond_print = cond
  1389. else:
  1390. assert upper
  1391. cond = f"{var} < {upper}"
  1392. cond_print = cond
  1393. if mask:
  1394. cond = f"({cond}) | ~({mask})"
  1395. return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
  1396. def check_bounds(
  1397. self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
  1398. ):
  1399. raise NotImplementedError
  1400. def index_to_str(self, index: sympy.Expr) -> str:
  1401. raise NotImplementedError
  def __enter__(self):
      """Enter code-generation mode for this kernel.

      Builds ``CSEProxy``, an ops handler layered on top of
      ``self.overrides(V.get_ops_handler())`` (the "parent handler").
      It then installs an instance of it with ``V.set_ops_handler`` and makes
      this kernel the current kernel with ``V.set_kernel_handler``. Both
      context managers are entered on ``self.exit_stack``.

      Any op that ``CSEProxy`` does not define explicitly goes through its
      ``__getattr__``: the parent handler computes the value, and every leaf of
      the result is passed through ``self.cse.generate`` into ``self.compute``
      together with value-range bounds. Loads, stores, reductions, scans,
      bucketize and bounds checks are delegated to this kernel's own methods,
      with extra bookkeeping (store cache, load/reduction counters).

      Returns:
          ``self``.
      """
      # TODO: hoist this to top level
      class CSEProxy:
          # Must be a plain class attribute. In a class body nested inside a
          # method, ``self`` resolves to the enclosing Kernel (a closure
          # variable), so ``self.name = "CSEProxy"`` here would overwrite the
          # kernel's ``name`` attribute instead of naming this handler.
          name = "CSEProxy"
          # Used by _bound_variable to compute bounds for ops that have no
          # precomputed FX-node bound.
          vr_analysis = ValueRangeAnalysis()

          @staticmethod
          def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
              # Generic path for every op not defined explicitly on this class.
              def inner(*args, **kwargs):
                  bounds = CSEProxy._bound_variable(name, *args, **kwargs)
                  value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                  def do_cse(v):
                      csevar = self.cse.generate(self.compute, v, bounds=bounds)
                      csevar.update_on_args(name, args, kwargs)
                      return csevar

                  # The op may return a nested structure; CSE every leaf.
                  return pytree.tree_map(do_cse, value)

              return inner

          @staticmethod
          def _bound_variable(name, *args, **kwargs):
              """
              Compute the value-range bounds for the result of op ``name``.

              If the variable comes from an FX node, we forward the bound we
              have already computed for that node. Otherwise, if the variable
              is created while codegen'ing another op, we try to compute its
              bounds from the bounds of its arguments.
              """
              from ..select_algorithm import TritonTemplateKernel

              if isinstance(V.kernel, TritonTemplateKernel):
                  return ValueRanges.unknown()

              fx_node = V.interpreter.current_node
              if fx_node.target == name and self.node_to_bounds is not None:
                  assert isinstance(self.node_to_bounds, dict)
                  return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
              elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                  # These create lots of inner strings. We would need to compute the bounds at the ops
                  # We will also likely not get much from computing VRs on these nodes
                  if any(
                      s in fx_node.target
                      for s in ("set_indirect", "reduction", "scan")
                  ):
                      return ValueRanges.unknown()

                  # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                  # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                  # If there is no FX bound but we know how to compute one we do so
                  assert not kwargs

                  def arg_to_bound(x):
                      if isinstance(x, CSEVariable):
                          return x.bounds
                      elif isinstance(x, sympy.Expr):
                          return bound_sympy(x)
                      else:
                          return x

                  arg_bounds = list(map(arg_to_bound, args))
                  return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
              else:
                  return ValueRanges.unknown()

          @staticmethod
          def indirect_indexing(
              var: CSEVariable, size: Union[sympy.Expr, int], check: bool = True
          ):
              if isinstance(size, int):
                  size = sympy.Integer(size)
              assert isinstance(size, sympy.Expr), size
              # Skip CSE since this doesn't return an expression

              if var.bounds.lower < 0:  # type: ignore[operator]
                  # Possibly negative index: wrap it as var + size.
                  stm = ops.add(var, ops.index_expr(size, torch.long))
                  # Mixed negative and non-negative
                  if var.bounds.upper >= 0:  # type: ignore[operator]
                      lt = ops.lt(var, 0)
                      stm = ops.where(lt, stm, var)

                  # Propagate bounds as we know how to compute them properly
                  new_bounds = ValueRanges.unknown()
                  if var.bounds != ValueRanges.unknown() and isinstance(
                      size, sympy.Number
                  ):
                      # Take the negative part of the bound and add size to it
                      # Then take union of that and the positive part
                      # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                      neg_bounds = var.bounds & ValueRanges(-sympy.oo, -1)
                      new_bounds = ValueRanges(
                          neg_bounds.lower + size, neg_bounds.upper + size
                      )
                      # We don't have a good way of representing the empty range
                      if var.bounds.upper >= 0:  # type: ignore[operator]
                          pos = var.bounds & ValueRanges(0, sympy.oo)
                          new_bounds = new_bounds | pos

                  var = self.cse.generate(self.compute, stm, bounds=new_bounds)

              sympy_var = parent_handler.indirect_indexing(var, size, check)
              if generate_assert(check):
                  assert_lower = not (var.bounds.lower >= 0)
                  # value ranges cannot x < s when x and s are symbols
                  assert_upper = not isinstance(size, sympy.Number) or not (
                      var.bounds.upper < size
                  )
                  self.check_bounds(sympy_var, size, assert_lower, assert_upper)
              return sympy_var

          @staticmethod
          def check_bounds(
              expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
          ):
              return self.check_bounds(expr, size, lower, upper)

          @staticmethod
          def load(name: str, index: sympy.Expr) -> CSEVariable:
              if name in self.cse.invalidated_stores:
                  # A load from an invalidated store requires us to
                  # keep the actual buffer around
                  V.kernel.must_keep_buffers.add(name)
              if free_symbol_is_type(index, SymT.TMP):
                  return self.indirect_load(name, index)
              store_cache = self.cse.store_cache
              if name in store_cache:
                  return store_cache[name]
              out = self.load(name, index)
              # count load that is not in the store_cache, and also not in the
              # cse cache.
              if out.use_count == 1:
                  self.num_load += 1
              return out

          @staticmethod
          def store(
              name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
          ) -> None:
              self.store_buffer_names.add(name)
              if mode is None:
                  # Later loads of this buffer (and of buffers the current
                  # node mutates) can be served from the stored value.
                  self.cse.store_cache[name] = value
                  if self.current_node:
                      for other_name in self.current_node.get_mutations():
                          self.cse.store_cache[other_name] = value
              if name not in V.graph.removed_buffers:
                  return self.store(name, index, value, mode=mode)
              else:
                  return None  # type: ignore[return-value]

          @staticmethod
          def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
              self.store_buffer_names.add(name)
              self.cse.store_cache[name] = value
              if self.current_node:
                  for other_name in self.current_node.get_mutations():
                      self.cse.store_cache[other_name] = value

              if name not in V.graph.removed_buffers:
                  return self.store_reduction(name, index, value)

          @staticmethod
          def reduction(
              dtype: torch.dtype,
              src_dtype: torch.dtype,
              reduction_type: ReductionType,
              value: Union[CSEVariable, Tuple[CSEVariable, ...]],
          ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
              self.num_reduction += 1
              return self.reduction(dtype, src_dtype, reduction_type, value)

          @staticmethod
          def scan(
              dtypes: Tuple[torch.dtype, ...],
              combine_fn: Callable[
                  [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                  Tuple[CSEVariable, ...],
              ],
              values: Tuple[CSEVariable, ...],
          ) -> Tuple[CSEVariable, ...]:
              return self.scan(dtypes, combine_fn, values)

          @staticmethod
          def bucketize(
              values: CSEVariable,
              offsets_name: str,
              offsets_size: sympy.Expr,
              indexing_dtype: torch.dtype,
              right: bool,
          ) -> CSEVariable:
              """
              [Note: Inductor bucketize op]

              Given values (tensor) and offsets_name (reference to the name of a 1D
              tensor), calculate the bucket that each value belongs to.

              e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
              return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

              When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
              When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

              Offsets must be non-decreasing or the result is undefined.
              """
              return self.bucketize(
                  values, offsets_name, offsets_size, indexing_dtype, right
              )

      # Use mypy to check protocol implemented correctly
      def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
          return h

      super().__enter__()
      assert self.overrides
      # Read by CSEProxy's methods as a closure variable; it only has to exist
      # by the time an op is dispatched, not when the class is created.
      parent_handler = self.overrides(V.get_ops_handler())
      self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
      self.exit_stack.enter_context(V.set_kernel_handler(self))
      return self
  def __exit__(self, exc_type, exc_val, exc_tb):
      """Leave code-generation mode for this kernel.

      If the graph has a scheduler, it is asked to remove kernel-local
      buffers. Control is then handed to the base class ``__exit__``.
      ``V.graph.scheduler`` can be None when codegening triton template
      kernels, which is why the scheduler step is guarded.
      """
      scheduler = V.graph.scheduler
      if scheduler:
          scheduler.remove_kernel_local_buffers()
      super().__exit__(exc_type, exc_val, exc_tb)
  def rename_indexing(self, index) -> sympy.Expr:
      """Rewrite an index expression in terms of kernel argument names.

      The expression is first simplified with ``V.graph.sizevars``. Every free
      symbol of kind UNBACKED_INT, SIZE or PRECOMPUTED_SIZE is then replaced by
      ``self.args.size(symbol)``, which adds the needed kernel arg and yields
      its name. A list or tuple of expressions is handled element-wise, and
      the result is always returned as a list.
      """
      if isinstance(index, (list, tuple)):
          return [self.rename_indexing(item) for item in index]  # type: ignore[return-value]

      simplified = V.graph.sizevars.simplify(index)
      renamable_kinds = (
          SymT.UNBACKED_INT,
          SymT.SIZE,
          SymT.PRECOMPUTED_SIZE,
      )
      replacements = {}
      # Symbols are visited in name order, presumably so kernel args are
      # registered in a deterministic order.
      for symbol in sorted(index_symbols := simplified.free_symbols, key=lambda s: s.name):
          if symbol_is_type(symbol, renamable_kinds):
              replacements[symbol] = self.args.size(symbol)
      return sympy_subs(simplified, replacements)
  def create_cse_var(self, *args, **kwargs):
      """Construct a new CSE variable, forwarding all arguments unchanged.

      The base kernel builds a plain ``CSEVariable``.
      NOTE(review): presumably overridden by backends that need a
      ``CSEVariable`` subclass; confirm against subclasses.
      """
      variable_cls = CSEVariable
      return variable_cls(*args, **kwargs)
  @dataclasses.dataclass
  class OptimizationContext:
      """Small mutable record holding an optional dtype and an ops name.

      ``key`` is declared as a ``ClassVar``, so it is not a dataclass field.
      It is the constant ``"opt_ctx"``, presumably the lookup key under which
      instances are stored elsewhere (TODO: confirm at call sites).
      """

      # Constant lookup key; ClassVar keeps it out of __init__, __eq__ and __repr__.
      key: ClassVar[str] = "opt_ctx"

      # Optional dtype; None when not set.
      dtype: Optional[torch.dtype] = dataclasses.field(default=None)
      # Ops name string; empty when not set.
      ops_name: str = dataclasses.field(default="")
  @functools.lru_cache(None)
  def jinja2_env():
      """Return the shared jinja2 ``Environment`` used for kernel templates.

      The environment is created with ``undefined=jinja2.StrictUndefined``.
      Because of ``functools.lru_cache``, every call in the process returns the
      same object.

      Returns:
          The cached ``jinja2.Environment``, or None if jinja2 cannot be
          imported.
      """
      try:
          import jinja2
      except ImportError:
          # jinja2 is optional here; None signals that templates are unavailable.
          return None
      # Constructed outside the ``try`` so that an ImportError raised while
      # building the environment is not misreported as "jinja2 is missing".
      return jinja2.Environment(
          undefined=jinja2.StrictUndefined,
      )
  class KernelTemplate:
      """
      Base class for defining kernel templates.

      Children classes: TritonTemplate, CUDATemplate
      """

      @staticmethod
      def indent_except_first(source: str, num_indents: int, indents_spacing=4):
          """Indent every line of ``source`` except the first.

          Every line after the first is prefixed with
          ``indents_spacing * num_indents`` spaces. Line endings are preserved.
          A source with zero or one line is returned unchanged. This method is
          registered as the ``indent_except_first`` jinja2 filter by
          ``_template_from_string``.
          """
          pieces = source.splitlines(True)
          if len(pieces) <= 1:
              return "".join(pieces)
          prefix = " " * indents_spacing * num_indents
          first, rest = pieces[0], pieces[1:]
          return first + "".join(prefix + piece for piece in rest)

      @staticmethod
      def _template_from_string(source):
          """Compile ``source`` into a jinja2 template.

          Returns None when jinja2 is unavailable. The ``indent_except_first``
          filter is (re-)registered on the shared environment before compiling.
          """
          env = jinja2_env()
          if env is None:
              return None
          env.filters["indent_except_first"] = KernelTemplate.indent_except_first
          return env.from_string(source)

      @staticmethod
      def _fake_get_dtype(fake_out):
          """Build a ``get_dtype`` replacement that also knows about ``fake_out``.

          The returned function answers ``fake_out.get_dtype()`` for the name
          ``fake_out.get_name()``. Every other name is deferred to the
          ``V.graph.get_dtype`` captured when this helper is called.
          """
          real_get_dtype = V.graph.get_dtype

          def get_dtype(name):
              return (
                  fake_out.get_dtype()
                  if name == fake_out.get_name()
                  else real_get_dtype(name)
              )

          return get_dtype

      def __init__(self, name: str):
          self.name = name

      def maybe_append_choice(self, choices, **kwargs):
          """
          Maybe generates a new ChoiceCaller and appends it into existing choices.

          choices: A list of ChoiceCallers.
          kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.

          A ``NotImplementedError`` is deliberately swallowed: the template
          simply contributes no choice in that case.
          """
          try:
              new_choice = self.generate(**kwargs)
              choices.append(new_choice)
          except NotImplementedError:
              pass

      def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
          """
          Generates a ChoiceCaller instance from the given arguments.

          Abstract in the base class: always raises NotImplementedError.
          """
          raise NotImplementedError