# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
import itertools
import logging
import os
import textwrap
from functools import lru_cache
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy

import torch
import torch._logging
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._triton import has_triton_package

from ...utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2
from ..utils import (
    cache_on_self,
    get_bounds_index_expr,
    get_fused_kernel_name,
    get_kernel_metadata,
    is_welford_reduction,
    Placeholder,
    sympy_dot,
    sympy_subs,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code
from .common import (
    CSE,
    CSEVariable,
    DeferredLine,
    IndentedBuffer,
    OpOverrides,
    PythonPrinter,
    SizeArg,
    TensorArg,
)
from .simd import constant_repr, IterationRangesEntry, pexpr, SIMDKernel, SIMDScheduling
from .triton_utils import config_of, signature_of, signature_to_meta

if TYPE_CHECKING:
    from ..ir import IRNode

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")


@lru_cache(None)
def gen_attr_descriptor_import():
    """
    import AttrsDescriptor if the triton version is new enough to have this
    class defined.
    """
    if not has_triton_package():
        return ""

    import triton.compiler.compiler

    if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
        return "from triton.compiler.compiler import AttrsDescriptor"
    else:
        return ""


@lru_cache(None)
def gen_common_triton_imports():
    imports = IndentedBuffer()
    imports.splice(
        """
        import triton
        import triton.language as tl
        """
    )
    if attr_desc := gen_attr_descriptor_import():
        imports.writeline(attr_desc)

    imports.splice(
        """
        from torch._inductor.runtime import triton_helpers, triton_heuristics
        from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
        """
    )
    return imports.getvalue()


@dataclasses.dataclass
class IndexingOptions:
    index_str: str
    mask_vars: Set[sympy.Symbol]
    mask_str: str
    expand_str: Optional[str]
    _has_rindex: bool
    index: sympy.Expr

    def has_mask(self):
        return bool(self.mask_vars)

    def has_indirect(self):
        return free_symbol_is_type(self.index, SymT.TMP)

    def has_rindex(self):
        return self._has_rindex

    def has_tmpmask(self):
        return "tmp" in self.mask_str

    def has_rmask(self):
        return "rmask" in self.mask_str
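
# Illustrative example (not from the original source): a simple masked 1D load
# might produce IndexingOptions(index_str="x0", mask_vars={xmask}, mask_str="xmask",
# expand_str=None, _has_rindex=False, index=x0), which later codegen turns into
# something like `tl.load(in_ptr0 + (x0), xmask)`.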


@dataclasses.dataclass
class BlockPtrOptions:
    constant_offset: sympy.Expr
    shape: List[sympy.Expr]
    strides: List[sympy.Expr]
    block_shape: List[str]
    order: List[int]
    offsets: List[str]
    mask_vars: Set[sympy.Symbol]
    reshape_suffix: List[str]

    @staticmethod
    def create(
        strides: List[sympy.Expr],
        constant_offset: sympy.Expr,
        range_trees: List[IterationRangesEntry],
        mask_vars: Set[sympy.Symbol],
    ) -> BlockPtrOptions:
        """Helper to create a BlockPtrOptions instance"""
        block_shape = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
        reshape_suffix = [*block_shape]

        broadcasting_dim = [s == 0 for s in strides]
        for i, is_broadcasting in enumerate(broadcasting_dim):
            if is_broadcasting:
                # drop any stride==0 dimensions for performance
                reshape_suffix[i] = "1"

        if V.kernel.no_x_dim:
            assert range_trees[0].prefix == "x"
            reshape_suffix.pop(0)

        if (
            not V.kernel.inside_reduction
            and len(strides) == len(V.kernel.numels) - 1
            and V.kernel.numels[-1] != 1
        ):
            # Need to expand rank by 1 to match rank when self.inside_reduction=True
            reshape_suffix.append("1")

        def filter(it):
            """Removes any broadcasting dims from a given sequence"""
            assert len(it) == len(broadcasting_dim)
            return [
                item
                for item, is_broadcasting in zip(it, broadcasting_dim)
                if not is_broadcasting
            ]

        return BlockPtrOptions(
            constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
            shape=[
                V.graph.sizevars.lookup_precomputed_size(t.numel)
                for t in filter(range_trees)
            ],
            strides=[*map(V.graph.sizevars.lookup_precomputed_size, filter(strides))],
            block_shape=filter(block_shape),
            order=V.graph.sizevars.guarded_order(filter(strides)),
            offsets=filter([f"{t.prefix}offset" for t in range_trees]),
            mask_vars=mask_vars,
            reshape_suffix=reshape_suffix,
        )

    def format(self, name: str, roffset=True) -> str:
        """
        Codegen a call to tl.make_block_ptr()

        Args:
            name: variable name for pointer
            roffset: should roffset be included in offsets=..., for use with tl.advance()

        Returns:
            "tl.make_block_ptr(...)"
        """
        f = V.kernel.index_to_str
        offsets = [*self.offsets]
        if not roffset:
            offsets[offsets.index("roffset")] = "0"
        args = [
            f"{name} + ({f(self.constant_offset)})"
            if self.constant_offset != 0
            else name,
            f"shape={f(self.shape)}",
            f"strides={f(self.strides)}",
            f"block_shape={f(self.block_shape)}",
            f"order={f(self.order)}",
            f"offsets={f(offsets)}",
        ]
        return f"tl.make_block_ptr({', '.join(args)})"

    @cache_on_self
    def boundary_check(self) -> List[int]:
        """List of indices to pass to tl.load(boundary_check=...)"""
        check = []
        for i in range(len(self.shape)):
            if (
                self.block_shape[i] != "1"
                and not V.graph.sizevars.statically_known_equals(self.strides[i], 0)  # type: ignore[arg-type]
                and not V.graph.sizevars.statically_known_multiple_of(
                    self.shape[i],
                    TRITON_MAX_BLOCK[self.block_shape[i][0]],  # type: ignore[arg-type]
                )
                and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK")
            ):
                check.append(i)
        return check

    def advance_roffset(self):
        """Codegen string to pass to tl.advance(name, ...)"""
        advance = ["0"] * len(self.shape)
        advance[self.offsets.index("roffset")] = "RBLOCK"
        return V.kernel.index_to_str(advance)

    def has_indirect(self):
        return False  # block_ptr can't do indirect indexing

    def has_rindex(self):
        return "RBLOCK" in self.block_shape

    def has_rmask(self):
        return self.has_rindex()

    def has_tmpmask(self):
        return False  # block_ptr can't do indirect indexing

    def has_mask(self):
        return bool(self.boundary_check())
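
# Illustrative example (hypothetical shapes/strides, not from the original source):
# for a 2D [XBLOCK, RBLOCK] tile, format("in_ptr0") would emit something like
#     tl.make_block_ptr(in_ptr0, shape=[xnumel, rnumel], strides=[rnumel, 1],
#                       block_shape=[XBLOCK, RBLOCK], order=[1, 0],
#                       offsets=[xoffset, roffset])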


def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
    """Workaround https://github.com/openai/triton/issues/2836"""
    assert isinstance(old_shape, list) and isinstance(new_shape, list)
    if old_shape == new_shape:
        return value
    if [s for s in new_shape if s != "1"] != old_shape:
        return f"tl.reshape({value}, [{', '.join(new_shape)}])"
    # rewrite to [:, None] syntax, which is less buggy
    idx = 0
    expand = []
    for size in new_shape:
        if idx < len(old_shape) and size == old_shape[idx]:
            expand.append(":")
            idx += 1
        else:
            assert size == "1"
            expand.append("None")
    assert idx == len(old_shape)
    return f"{value}[{', '.join(expand)}]"
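
# Illustrative examples (traced by hand, not from the original source):
#   triton_reshape("tmp0", ["XBLOCK"], ["XBLOCK", "1"])  -> "tmp0[:, None]"
#   triton_reshape("tmp0", ["XBLOCK"], ["1", "XBLOCK"])  -> "tmp0[None, :]"
#   triton_reshape("tmp0", ["XBLOCK", "RBLOCK"], ["RBLOCK", "XBLOCK"])
#       -> "tl.reshape(tmp0, [RBLOCK, XBLOCK])"  (falls back to tl.reshape)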


# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We
# must override all of these, or it is a potential silent correctness problem.
class TritonPrinter(PythonPrinter):
    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"

    # TODO: This is wrong if one of the inputs is negative. This is hard to
    # tickle though, as the inputs are typically positive (and if we can prove
    # they are positive, we will have used Mod instead, for which this codegen
    # is right). If you are trying to hit this, maybe try something like
    # torch.arange(n, device="cuda") - 1 and then do a modulus on it
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # TODO: This is wrong, see
    # https://github.com/triton-lang/triton/issues/955
    # But for Sympy expressions, things will /mostly/ work out because we
    # don't usually deal with negative numbers in the division
    def _print_FloorDiv(self, expr):
        assert expr.is_integer
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
    # precision algorithm, which we would need to replicate here
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    # NB: sympy.floor/ceiling produce integers, so we have to do the
    # conversion to index dtype
    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"

    def _helper_sqrt(self, expr):
        return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))"

    def _print_Where(self, expr):
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"tl.where({c}, {p}, {q})"

    def _print_Min(self, expr):
        nargs = len(expr.args)
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Min(*expr.args[:mid]))
        b = self._print(sympy.Min(*expr.args[mid:]))
        return f"tl.minimum({a}, {b})"

    def _print_Max(self, expr):
        nargs = len(expr.args)
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Max(*expr.args[:mid]))
        b = self._print(sympy.Max(*expr.args[mid:]))
        return f"tl.maximum({a}, {b})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"tl_math.abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.llrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"


texpr = TritonPrinter().doprint
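
# Illustrative example (traced by hand, not from the original source): the printer
# splits n-ary Min/Max into nested binary calls, so, modulo sympy's canonical
# argument ordering,
#   texpr(sympy.Min(a, b, c))  ->  "tl.minimum(a, tl.minimum(b, c))"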


def triton_compute_type(dtype):
    triton_type_name = str(dtype).split(".")[-1]
    if triton_type_name == "bool":
        triton_type_name = "int1"
    elif triton_type_name in ("float16", "bfloat16"):
        # float16 math is done in float32 inside the kernel
        triton_type_name = "float32"
    elif triton_type_name == "float8_e4m3fn":
        triton_type_name = "float8e4nv"
    elif triton_type_name == "float8_e5m2":
        triton_type_name = "float8e5"
    elif triton_type_name == "float8_e4m3fnuz":
        triton_type_name = "float8e4b8"
    elif triton_type_name == "float8_e5m2fnuz":
        triton_type_name = "float8e5b16"
    return f"tl.{triton_type_name}"


def triton_store_type(dtype):
    triton_type_name = str(dtype).split(".")[-1]
    if triton_type_name == "bool":
        triton_type_name = "int8"
    elif triton_type_name == "float8_e4m3fn":
        triton_type_name = "float8e4nv"
    elif triton_type_name == "float8_e5m2":
        triton_type_name = "float8e5"
    return f"tl.{triton_type_name}"


def triton_acc_type(dtype):
    if is_integer_dtype(dtype) and dtype.is_signed:
        nbits = 64 if dtype == torch.int64 else 32
        return f"tl.int{nbits}"
    return triton_compute_type(dtype)
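
# Illustrative mappings (traced by hand, not from the original source):
#   triton_compute_type(torch.float16) -> "tl.float32"   (fp16 math is upcast to fp32)
#   triton_compute_type(torch.bool)    -> "tl.int1"
#   triton_store_type(torch.bool)      -> "tl.int8"      (bools are stored as int8)
#   triton_acc_type(torch.int16)       -> "tl.int32"     (accumulate in at least 32 bits)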


class TritonCSEVariable(CSEVariable):
    def __init__(self, name, bounds: ValueRanges[Any]):
        super().__init__(name, bounds)
        # We'll use this to track which masks the variable needs when used for indirect indexing
        self.mask_vars: Set[str] = set()

    def update_on_args(self, name, args, kwargs):
        for arg in args:
            if isinstance(arg, TritonCSEVariable):
                self.mask_vars.update(arg.mask_vars)
            elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr":
                # Most of the time index vars don't need masks associated with them.
                # However, when index vars are used to compute indices for indirect reads,
                # those reads should subsequently be masked.
                self.mask_vars.update({f"{arg.name[0]}mask"})


class TritonOverrides(OpOverrides):
    """Map element-wise ops to Triton"""

    @staticmethod
    def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
        def _get_min_elements_per_thread(
            src_dtype: torch.dtype, dst_dtype: torch.dtype
        ) -> int:
            if src_dtype == dst_dtype:
                # No data type conversion is needed. No requirements on min_elem_per_thread.
                return 0

            # fp8 data type conversions have min_elem_per_thread requirements.
            # Refer to Triton implementations here:
            # https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
            fp8_dtypes = {
                torch.float8_e4m3fn,
                torch.float8_e5m2,
            }
            # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
            assert not (
                src_dtype in fp8_dtypes
                and dst_dtype in fp8_dtypes
                and src_dtype != dst_dtype
            ), "Conversions between float8_e5m2 and float8_e4m3fn are not supported!"
            if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2:
                return 4
            if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn:
                return 2
            # No requirements on min_elem_per_thread.
            return 0

        if src_dtype is not None:
            # Both dtype and src_dtype are set. This is used by torch's to(dtype=dtype).
            # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions
            # in the same kernel.
            V.kernel.min_elem_per_thread = max(
                _get_min_elements_per_thread(src_dtype, dtype),
                V.kernel.min_elem_per_thread,
            )

        if dtype == torch.bool:
            return f"({x} != 0)"
        elif dtype == torch.uint8:
            # to work around llvm uint conversion semantics
            # that produce 0's for negative values
            return f"{x}.to(tl.int8).to(tl.uint8)"
        return f"{x}.to({triton_compute_type(dtype)})"
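
    # Illustrative examples (traced by hand, not from the original source):
    #   to_dtype("tmp3", torch.bool)    -> "(tmp3 != 0)"
    #   to_dtype("tmp3", torch.float16) -> "tmp3.to(tl.float32)"  (fp16 math stays in fp32)
    #   to_dtype("tmp3", torch.uint8)   -> "tmp3.to(tl.int8).to(tl.uint8)"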

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        triton_dtype = triton_compute_type(dtype)
        # We may promote float16 or bfloat16 to float32 and cause the
        # bitwidth of dtype to be different from the input tensor (i.e. float32).
        # In such a case, we will have to convert the input tensor to
        # its src_type, perform the bitcast, and then convert the bit-cast
        # tensor back to float to ensure we use values with the right precision.
        if src_dtype in (torch.float16, torch.bfloat16):
            triton_src_dtype = str(src_dtype).split(".")[-1]
            cast_x = f"{x}.to(tl.{triton_src_dtype})"
            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
            return f"{cast_x}.to(tl.float32)"
        else:
            return f"{x}.to({triton_dtype}, bitcast=True)"
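
    # Illustrative example (traced by hand, not from the original source):
    #   to_dtype_bitcast("tmp0", torch.int16, torch.bfloat16)
    #     -> "tmp0.to(tl.bfloat16).to(tl.int16, bitcast=True).to(tl.float32)"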

    @staticmethod
    def _shaped_constant(value, dtype, shape):
        type_ = torch._prims_common.dtype_to_type(dtype)
        triton_val = constant_repr(type_(value))
        triton_type = triton_compute_type(dtype)

        if triton_type == "tl.float32":
            # Float constants are always f32 in triton
            return triton_val

        # NOTE: We use a tensor here in order to get the expected type.
        # Otherwise, e.g. float64 constants would be truncated to float32.
        return f"tl.full({shape}, {triton_val}, {triton_type})"

    @classmethod
    def constant(cls, value, dtype):
        return cls._shaped_constant(value, dtype, shape=[])
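
    # Illustrative examples (traced by hand, not from the original source):
    #   constant(2.0, torch.float32) -> "2.0"                      (bare f32 literal)
    #   constant(1.5, torch.float64) -> "tl.full([], 1.5, tl.float64)"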

    @staticmethod
    def abs(x):
        return f"tl_math.abs({x})"

    @staticmethod
    def libdevice_abs(x):
        return f"libdevice.abs({x})"

    @staticmethod
    def exp(x):
        return f"tl_math.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"libdevice.exp({x})"

    @staticmethod
    def exp2(x):
        return f"libdevice.exp2({x})"

    @staticmethod
    def expm1(x):
        return f"libdevice.expm1({x})"

    @staticmethod
    def sqrt(x):
        return f"libdevice.sqrt({x})"

    @staticmethod
    def libdevice_sqrt(x):
        return f"libdevice.sqrt({x})"

    @staticmethod
    def relu(x):
        bug = config.triton.inject_relu_bug_TESTING_ONLY
        if bug == "compile_error":
            return "compile error!"
        elif bug == "runtime_error":
            # NB: this only triggers a runtime error as long as the input
            # is not all zero
            return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})'
        elif bug == "accuracy":
            return f"{x} + 1"
        elif bug is None:
            return ops.maximum(ops.constant(0, torch.int32), x)
        else:
            raise AssertionError(
                f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
            )

    @staticmethod
    def minimum(a, b):
        return f"triton_helpers.minimum({a}, {b})"

    @staticmethod
    def maximum(a, b):
        return f"triton_helpers.maximum({a}, {b})"

    @staticmethod
    def where(a, b, c):
        return f"tl.where({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"tl_math.cos({x})"

    @staticmethod
    def libdevice_cos(x):
        return f"libdevice.cos({x})"

    @staticmethod
    def sin(x):
        return f"tl_math.sin({x})"

    @staticmethod
    def libdevice_sin(x):
        return f"libdevice.sin({x})"

    @classmethod
    def index_expr(cls, expr, dtype):
        raise NotImplementedError("ops.index_expr not implemented outside a kernel")

    @staticmethod
    def masked(mask, body, other):
        raise NotImplementedError("ops.masked not implemented outside a kernel")

    @staticmethod
    def lgamma(x):
        return f"libdevice.lgamma({x})"

    @staticmethod
    def erf(x):
        return f"libdevice.erf({x})"

    @staticmethod
    def cosh(x):
        return f"libdevice.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"libdevice.sinh({x})"

    @staticmethod
    def acos(x):
        return f"libdevice.acos({x})"

    @staticmethod
    def acosh(x):
        return f"libdevice.acosh({x})"

    @staticmethod
    def asin(x):
        return f"libdevice.asin({x})"

    @staticmethod
    def asinh(x):
        return f"libdevice.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"libdevice.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"libdevice.atan({x})"

    @staticmethod
    def atanh(x):
        return f"libdevice.atanh({x})"

    @staticmethod
    def copysign(x, y):
        return f"libdevice.copysign({x}, {y})"

    @staticmethod
    def erfc(x):
        return f"libdevice.erfc({x})"

    @staticmethod
    def erfinv(x):
        return f"libdevice.erfinv({x})"

    @staticmethod
    def hypot(x, y):
        return f"libdevice.hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"libdevice.log10({x})"

    @staticmethod
    def log2(x):
        return f"libdevice.log2({x})"

    @staticmethod
    def nextafter(x, y):
        return f"libdevice.nextafter({x}, {y})"

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        offset = f"({offset}).to(tl.uint32)"
        return f"tl.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset):
        offset = f"({offset}).to(tl.uint32)"
        return f"tl.randn({seed}, {offset})"

    @staticmethod
    def randint64(seed, offset, low, high):
        offset = f"({offset}).to(tl.uint32)"
        return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})"

    @staticmethod
    def load_seed(name, offset):
        raise NotImplementedError("ops.load_seed not implemented outside a kernel")

    @staticmethod
    def rsqrt(x):
        return f"libdevice.rsqrt({x})"

    @staticmethod
    def log1p(x):
        return f"libdevice.log1p({x})"

    @staticmethod
    def tan(x):
        return f"libdevice.tan({x})"

    @staticmethod
    def tanh(x):
        return f"libdevice.tanh({x})"

    @staticmethod
    def sigmoid(x):
        return f"tl.sigmoid({x})"

    @staticmethod
    def signbit(x):
        # XX: This is wrong for the value -0.0 in floating point
        return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"

    @staticmethod
    def fmod(a, b):
        return f"libdevice.fmod({a}, {b})"

    @staticmethod
    def pow(a, b):
        return f"libdevice.pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"tl_math.log({x})"

    @staticmethod
    def libdevice_log(x):
        return f"libdevice.log({x})"

    @staticmethod
    def isinf(x):
        return f"libdevice.isinf({x}).to(tl.int1)"

    @staticmethod
    def isnan(x):
        return f"libdevice.isnan({x}).to(tl.int1)"

    @staticmethod
    def round(x):
        return f"libdevice.nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"libdevice.floor({x})"

    @staticmethod
    def floordiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Similar to div_floor_kernel_cuda in pytorch core.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        quot = f"{a} // {b}"
        rem = f"{a} % {b}"
        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"
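
    # Illustrative example (traced by hand, not from the original source):
    #   floordiv("a", "b") ->
    #     "tl.where((a < 0) != (b < 0), tl.where(a % b != 0, a // b - 1, a // b), a // b)"
    # i.e. truncating division, corrected toward -inf when the signs differ.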

    @staticmethod
    def sign(x):
        z = ops.constant(0, torch.int32)
        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
        sub = ops.sub(left, right)
        return f"{sub}.to({x}.dtype)"

    @staticmethod
    def trunc(x):
        return f"libdevice.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        return f"{a} // {b}"

    @staticmethod
    def ceil(x):
        return f"libdevice.ceil({x})"


TritonOverrides._initialize_pointwise_overrides("triton")


# Use mypy to check protocol implemented correctly
def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
    return h


class TritonKernelOverrides(TritonOverrides):
    """Map element-wise ops to Triton within a TritonKernel

    Unlike TritonOverrides, these assume the code is going to be inserted into
    the body of the main triton kernel and so it may use indexing and mask
    variables which are assumed to already be defined in the current scope.
    """

    @classmethod
    def constant(cls, value, dtype):
        # NOTE: Cannot use shape=[] as it's not supported by triton-rocm
        # We could use shape=[1] instead but starting with the correct
        # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR.
        ndim = V.kernel.triton_tensor_ndim()
        shape = [1] * ndim
        return cls._shaped_constant(value, dtype, shape=shape)

    @classmethod
    def index_expr(cls, expr, dtype):
        indexing = V.kernel.indexing(expr, block_ptr=False)
        assert isinstance(indexing, IndexingOptions)
        var = V.kernel.cse.generate(
            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
        )

        if dtype not in {torch.int32, torch.int64}:
            var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
        var.mask_vars = indexing.mask_vars
        return var

    @staticmethod
    def masked(mask, body, other):
        with V.kernel.mask_loads(mask) as new_mask:
            result = body()

        # Remove once CSEVariables track the dtype
        if result.bounds.is_bool:
            other = bool(other)
        # Take dtype from result to prevent accidental promotion
        other = V.kernel.cse.generate(
            V.kernel.compute,
            f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
            bounds=ValueRanges.wrap(other),
        )
        ret = ops.where(new_mask, result, other)
        ret.mask_vars.discard(new_mask)
        return ret
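
    # Illustrative example (hypothetical variable names, not from the original
    # source): ops.masked(mask, body, 0.0) roughly emits
    #     tmp1 = tl.full(tmp0.shape, 0.0, tmp0.dtype)
    #     tmp2 = tl.where(mask, tmp0, tmp1)
    # where tmp0 is the result of evaluating body() under the load mask.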

    @staticmethod
    def load_seed(name, offset):
        var = V.kernel.args.input(name)
        return (
            f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})"
        )

    @staticmethod
    def frexp(x):
        cache_key = f"frexp({x})"
        if cache_key in V.kernel.cse.cache:
            return V.kernel.cse.cache[cache_key]

        mantissa = V.kernel.cse.newvar()
        exponent = V.kernel.cse.newvar()
        V.kernel.compute.writeline(
            f"{mantissa}, {exponent} = triton_helpers.frexp({x})"
        )
        V.kernel.cse.cache[cache_key] = (mantissa, exponent)
        return (mantissa, exponent)


# Use mypy to check protocol implemented correctly
def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:
    return h


class HelperFunctions:
    """An ordered set of helper functions."""

    _templates_seen: Dict[str, str]  # Template code to function name
    finalized_helpers: List[str]

    def __init__(self):
        self._templates_seen = {}
        self.finalized_helpers = []

    def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str:
        """This accepts a function definition with the function name
        left as a format specifier e.g.

            @triton.jit
            def {name}(arg0, arg1):
                return arg0 + arg1

        We add the templated code to the function set and return the name
        assigned to that function.
        """
        existing_name = self._templates_seen.get(template_code)
        if existing_name is not None:
            # Don't duplicate existing helpers
            return existing_name

        name = f"{base_name}{len(self.finalized_helpers)}"
        self._templates_seen[template_code] = name
        self.finalized_helpers.append(template_code.format(name=name))
        return name

    def __iter__(self):
        return iter(self.finalized_helpers)

    def __getitem__(self, idx):
        return self.finalized_helpers[idx]
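
# Illustrative usage (hypothetical template, not from the original source):
#     helpers = HelperFunctions()
#     tmpl = "@triton.jit\ndef {name}(a, b):\n    return a + b"
#     helpers.add(tmpl)  # -> "_triton_helper_fn0"
#     helpers.add(tmpl)  # -> "_triton_helper_fn0" (deduplicated, not re-added)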


class TritonKernel(SIMDKernel):
    overrides = TritonKernelOverrides  # type: ignore[assignment]
    helper_functions: HelperFunctions
    kexpr: Callable[[sympy.Expr], str] = texpr
    allow_block_ptr = True

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[Set[str]] = None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
        min_elem_per_thread=0,
        disable_persistent_reduction=False,
    ):
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            reduction_hint=reduction_hint,
            pid_cache=pid_cache,
            disable_persistent_reduction=disable_persistent_reduction,
        )
        self.suffix: IndentedBuffer = IndentedBuffer()  # type: ignore[assignment]
        self.outside_loop_vars: Set[Any] = set()
        self.min_elem_per_thread = min_elem_per_thread
        self.block_ptr_id = itertools.count()
        self.helper_functions = HelperFunctions()

        # A set of autotuning hints to pass as part of triton_meta
        self.autotune_hints: Set[AutotuneHint] = set()
        self.triton_meta: Optional[Dict[str, object]] = None

        self.codegen_range_tree()

    def codegen_range_tree(self):
        for tree in self.range_trees:
            # reduction indexing goes inside a loop
            if not tree.is_loop:
                self.iteration_ranges_codegen_header(tree, self.body)
        if self.inside_reduction and self.range_trees[-1].is_loop:
            # workaround for this issue:
            # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
            self.body.writeline(
                f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}"
            )

    def need_numel_args(self):
        r"""
        Indicate whether we need to provide numel as arguments for the generated
        kernel calls in the benchmark.

        Should be true for pointwise/reduction kernels but false for triton
        matmul kernels.
        """
        return True

    def should_use_persistent_reduction(self) -> bool:
        """
        Heuristic to set self.persistent_reduction and add guards
        if needed.
        """
        if not (self.inside_reduction and config.triton.persistent_reductions):
            return False
        threshold = {
            ReductionHint.INNER: 1024,
        }.get(self.reduction_hint, 64)

        # If multi_kernel is enabled, we do more aggressive persistent reduction.
        # This may result in some persistent reductions being slower than the
        # corresponding non-persistent reductions. MultiKernel will do benchmarking
        # to pick the faster one.
        if config.triton.multi_kernel:
            threshold *= 16
        last_numel = self.numels[-1]
        return V.graph.sizevars.statically_known_leq(last_numel, threshold)  # type: ignore[arg-type]

    def want_no_x_dim(self):
        return (
            self.reduction_hint == ReductionHint.INNER
            and self.persistent_reduction
            and len(self.numels) == 2
            and V.graph.sizevars.statically_known_geq(self.numels[-1], 256)  # type: ignore[arg-type]
        )

    @property
    def assert_function(self) -> str:
        return "tl.device_assert"

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
        override_mask=None,
        block_ptr=False,
    ):
        """
        Compute the index and mask to pass to tl.load() or tl.store()
        """
        index = self.prepare_indexing(index)
        index_vars = index.free_symbols
        has_rindex = False

        mask_vars: Set[str] = set()
        for var in index_vars:
            assert isinstance(var, sympy.Symbol)
            has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
            if override_mask:
                pass
            elif symbol_is_type(var, SymT.TMP):
                # indirect indexing
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)
            elif symbol_is_type(
                var,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                    SymT.INDEX,
                    SymT.FLOAT,
                    SymT.UNBACKED_FLOAT,
                ),
            ):
                pass
            else:
                # var is one of xN, yN or rN
                assert symbol_is_type(
                    var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
                ), var.name
                mask_vars.add(f"{var.name[0]}mask")

        need_dense = (
            config.triton.dense_indexing
            or dense_indexing
            or self._load_mask is not None
        ) and index != 0

        have_dense = True
        have_loop_vars = False
        dense_mask_vars = set()

        for tree in self.active_range_trees():
            if index_vars.intersection(tree.var_list):
                have_loop_vars = True
            else:
                have_dense = False
            dense_mask_vars.add(f"{tree.prefix}mask")

        if (
            block_ptr
            and self.allow_block_ptr
            and config.triton.use_block_ptr
            and not override_mask
            and not self._load_mask
            and len(mask_vars - dense_mask_vars) == 0
            and not self.is_indirect_indexing(index)
            and have_loop_vars
            # workaround https://github.com/openai/triton/issues/2821
            and self.index_dtype == "tl.int32"
        ):
            index_relative_to_xyr_index = sympy_subs(
                index, {v: t.expr for v, t in self.range_tree_nodes.items()}
            )
            range_trees = self.active_range_trees(reorder=True)
            symbols = [t.symbol() for t in range_trees]
            strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols]
            offset = sympy.Wild("_offset", exclude=symbols)
            m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset)
            # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with
            # a tl.reshape the correct block. We will miss these cases today.
            if m:
                self.filter_masks(mask_vars)
                return BlockPtrOptions.create(
                    [m[s] for s in strides],
                    m[offset],
                    range_trees,
                    mask_vars,  # type: ignore[arg-type]
                )

        expand_str = None
        index_str = self.index_to_str(index)
        if isinstance(index, sympy.Integer):
            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
            index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
            return IndexingOptions(
                index_str, set(), "None", expand_str, has_rindex, index
            )

        if need_dense and not have_dense:
            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
            index_str = f"tl.broadcast_to({index_str}, {expand_str})"
            mask_vars = dense_mask_vars
        elif not have_loop_vars and copy_shape:
            index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)"
            mask_vars = dense_mask_vars

        if override_mask:
            mask_vars = {override_mask}

        if self._load_mask:
            mask_vars.add(self._load_mask)

        self.filter_masks(mask_vars)

        mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
        return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index)  # type: ignore[arg-type]
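
    # Illustrative example (hypothetical index, not from the original source): for
    # an index like x0 + 16*r1 inside a reduction, the block_ptr match above can
    # yield a BlockPtrOptions with strides [1, 16]; otherwise the fallback returns
    # an IndexingOptions along the lines of index_str "x0 + 16*r1" with
    # mask_str "rmask & xmask".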

    def codegen_block_ptr(
        self, name: str, var: str, indexing: BlockPtrOptions, other=""
    ) -> Tuple[str, Optional[DeferredLine], str]:
        advance_block_ptr = None
        check = indexing.boundary_check()
        if not check:
            # workaround https://github.com/openai/triton/issues/2813
            other = ""
        elif other:
            assert other == ", other=0.0"
            other = f", boundary_check={check!r}, padding_option='zero'"
        else:
            other = f", boundary_check={check!r}"
        if (
            self.inside_reduction
            and self.range_trees[-1].is_loop
            and indexing.has_rindex()
        ):
            block_ptr = f"block_ptr{next(self.block_ptr_id)}"
            self.body.writeline(
                DeferredLine(
                    name, f"{block_ptr} = {indexing.format(var, roffset=False)}"
                )
            )
            advance_block_ptr = DeferredLine(
                name,
                f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})",
            )
        else:
            block_ptr = indexing.format(var)
        return block_ptr, advance_block_ptr, other

    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
        # broadcasting is not implicit for block_ptrs
        value = (
            f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
        )
        # drop any extra size=1 dimensions
        value = triton_reshape(value, indexing.reshape_suffix, indexing.block_shape)
        # workaround https://github.com/openai/triton/issues/2814
        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
        return f"tl.store({block_ptr}, {value}{other})"
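
    # Illustrative output (hypothetical names, not from the original source):
    #     tl.store(block_ptr0,
    #              tl.broadcast_to(tmp2, [XBLOCK, RBLOCK]).to(tl.float32),
    #              boundary_check=[0])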

    def check_bounds(
        self,
        expr: sympy.Expr,
        size: sympy.Expr,
        lower: bool,
        upper: bool,
    ):
        if not (lower or upper):
            return

        assert isinstance(expr, sympy.Expr)
        indexing = self.indexing(expr, block_ptr=False)
        assert isinstance(indexing, IndexingOptions)

        index_str = indexing.index_str
        mask_str = indexing.mask_str if indexing.has_mask() else None
        size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None

        # expr is already wrapped
        line = self.indirect_assert(
            index_str, "0" if lower else None, size_str, mask_str
        )

        indirect = self.is_indirect_indexing(expr) or any(
            isinstance(m, TritonCSEVariable) for m in indexing.mask_vars
        )
        buffer = self.get_load_buffer(indexing)
        self.cse.generate(buffer, line, assignment=False)

    def get_load_buffer(self, indexing):
        if indexing.has_indirect() or indexing.has_tmpmask():
            # Masked loads must come after the mask is computed
            return self.compute
        elif (
            self.inside_reduction
            and self.range_trees[-1].is_loop
            and not indexing.has_rindex()
        ):
            # can lift a common load outside of reduction loop
            # One exception is when this is an indirect_load.
            return self.body
        else:
            return self.loads

    def load(self, name: str, index: sympy.Expr):
        var = self.args.input(name)
        indirect_indexing = self.is_indirect_indexing(index)
        original_index = index
        indexing = self.indexing(index, block_ptr=True)
        has_rindex = indexing.has_rindex()
        has_tmpmask = indexing.has_tmpmask()

        # Keep the variable in cache if we're going to reuse it. Equiv., if any of the following hold
        #  1) We are doing broadcasting
        #  2) It is a non-coalesced load. The intuition is that if it's
        #     non-coalesced, we will likely load each element multiple times in
        #     practice.
        #  3) It will be used later and it won't be CSE'd. Equiv., if all the following hold
        #   3.1) We are in a reduction loop
        #   3.2) It's not its last use
        #   3.3) This load will not be lifted to the body
        #
        is_coalesced = any(
            i == 1 for i in self.get_strides_of_load(original_index).values()
        )
        if self.is_broadcasted(original_index):
            ep = ", eviction_policy='evict_last'"
        elif not is_coalesced:
            ep = ", eviction_policy='evict_last'"
        elif self.inside_reduction and self.range_trees[-1].is_loop:
            if name in self.args.inplace_buffers:
                names = set(self.args.inplace_buffers[name].other_names)
            else:
                names = {name}
            last_use = len(names & self.last_usage) > 0
            evict_last = not last_use and (has_rindex or indirect_indexing)
            if evict_last:
                ep = ", eviction_policy='evict_last'"
            else:
                ep = ", eviction_policy='evict_first'"
        else:
            ep = ""

        if (has_tmpmask or has_rindex) and indexing.has_mask():
            other = ", other=0.0"
        else:
            other = ""

        advance_block_ptr = None
        append_broadcast = None
        if V.graph.is_unspec_arg(name):
            line = var
        else:
            if isinstance(indexing, BlockPtrOptions):
                block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
                    name, var, indexing, other
                )
                line = f"tl.load({block_ptr}{other}{ep})"
                # add needed size=1 dimensions
                line = triton_reshape(
                    line, indexing.block_shape, indexing.reshape_suffix
                )
            elif isinstance(original_index, sympy.Integer):
                line = f"tl.load({var} + ({original_index}))"
                append_broadcast = indexing.expand_str
            else:
                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"

            dtype = V.graph.get_dtype(name)
            if dtype in (torch.float16, torch.bfloat16):
                line += ".to(tl.float32)"
            if dtype == torch.bool and torch.version.hip is None:
                # Workaround for https://github.com/openai/triton/issues/2151
                # tl.load returns int8 when loading from pointer to int1
                # NOTE: Currently causes hangs on bool UTs for ROCm
                line += ".to(tl.int1)"

        load_buffer = self.get_load_buffer(indexing)
        result_var = self.cse.generate(load_buffer, line)
        assert isinstance(result_var, TritonCSEVariable)
        result_var.mask_vars = indexing.mask_vars  # type: ignore[assignment]

        if append_broadcast:
            line = f"tl.broadcast_to({result_var}, {append_broadcast})"
            result_var = self.cse.generate(load_buffer, line)

        if advance_block_ptr:
            load_buffer.writeline(advance_block_ptr)

        if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex):
            self.outside_loop_vars.add(result_var)

        return result_var
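
    # Illustrative emitted line (hypothetical names, not from the original source):
    #     tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')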
  1161. def store(
  1162. self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
  1163. ) -> None:
  1164. var = self.args.output(name)
  1165. original_index = index
  1166. indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)
  1167. # Guard against write-after-read corruption in triton.
  1168. # See # https://github.com/openai/triton/issues/1615
  1169. # This triton bug means that a load which is broadcasted over multiple
  1170. # warps may see the result of a store that happens later in the triton
  1171. # program. The workaround is to add a barrier before storing, which
  1172. # enforces that all warps have already read the data.
  1173. is_inplace = name in self.args.inplace_buffers
  1174. is_broadcasted = self.is_broadcasted(original_index)
  1175. if is_inplace and is_broadcasted:
  1176. self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
  1177. advance_block_ptr = None
  1178. if isinstance(indexing, BlockPtrOptions):
  1179. block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
  1180. name, var, indexing
  1181. )
  1182. # block_ptr stores don't do implicit casting
  1183. line = self.codegen_block_ptr_store_line(
  1184. name, indexing, block_ptr, value, other
  1185. )
  1186. elif mode is None:
  1187. line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
  1188. elif mode == "atomic_add":
  1189. line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
  1190. else:
  1191. raise NotImplementedError(f"store mode={mode}")
  1192. self.stores.writeline(DeferredLine(name, line))
  1193. if advance_block_ptr:
  1194. self.stores.writeline(advance_block_ptr)
  1195. if not self.inside_reduction:
  1196. self.outside_loop_vars.add(value)
    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        # Triton performance for bucketize_binary_search is much better when the number
        # of threads equals the number of elements.
        # If we're trying to use a bucketize kernel, we should make sure that an
        # autotuning config with num_elements_per_warp=32 exists.
        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)
        offsets_ptr = self.args.input(offsets_name)
        block_size = self.dense_size_str()
        offsets_size_str = self.index_to_str(offsets_size)

        if indexing_dtype == torch.int32:
            triton_dtype = "tl.int32"
        elif indexing_dtype == torch.int64:
            triton_dtype = "tl.int64"
        else:
            raise NotImplementedError(
                "Bucketize only supports indexing with int32 and int64"
            )

        result = self.cse.generate(
            self.compute,
            f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})",  # noqa: B950 line too long
        )

        return result

    def reduction_resize(self, value):
        ndims = self.triton_tensor_ndim()
        if ndims == 1:
            return f"triton_helpers.promote_to_tensor({value})"

        sizes = [":"] * ndims
        sizes[-1] = "None"
        return f"{value}[{', '.join(sizes)}]"

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        assert self.inside_reduction
        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        if self._load_mask:
            masks.append(self._load_mask)
        reduction_range_prefix = self.range_trees[-1].prefix

        # Say we have
        #     tmp0 = ops.constant(1, torch.int64)
        #     tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0)
        # tmp0 in the triton code is either a scalar, or single-element tensor
        # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1
        # To avoid this, we broadcast to the expected shape first.
        dense_size_str = self.dense_size_str()
        value = self._map_tuple_or_scalar(
            lambda v: self.cse.generate(
                self.compute, f"tl.broadcast_to({v}, {dense_size_str})"
            ),
            value,
        )

        dim: int
        root_op: str

        def final_reduction(value):
            use_helper = reduction_type in {"any", "max", "min", "prod"}
            module = "triton_helpers" if use_helper else "tl"
            if reduction_type in {"max", "min"}:
                return self.reduction_resize(
                    f"{module}.{reduction_type}2({value}, {dim})"
                )
            return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})")

        def final_argreduce(buffer, result_var, value, index):
            buffer.splice(
                f"""\
                _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
                {result_var} = {self.reduction_resize(f'{result_var}_tmp')}
                """
            )

        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        dim = self.triton_tensor_ndim() - 1
        acc_type = triton_acc_type(src_dtype)
        result_var: Any = self.cse.newvar()
        result_var.mask_vars = {var for var in masks if var[0] != "r"}
        cond = " & ".join(masks)

        def where_cond(tval, fval):
            if not cond:
                return tval
            return TritonKernelOverrides.where(cond, tval, fval)

        if self.persistent_reduction:
            default = ir.Reduction.default_value(reduction_type, src_dtype)
            default = self._map_tuple_or_scalar(constant_repr, default)

            def _mask_value(value, default):
                return self.cse.generate(self.compute, where_cond(value, default))

            if isinstance(value, tuple):
                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
            else:
                masked_value = _mask_value(value, default)

            if reduction_type in {"argmax", "argmin"}:
                accumulator_index = str(
                    self.cse.generate(
                        self.compute,
                        f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
                    )
                )
                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
                final_argreduce(
                    self.compute, result_var, masked_value, accumulator_index
                )
            elif reduction_type == "welford_reduce":
                # For persistent reductions, don't bother with
                # welford's algorithm since it uses more registers, and
                # taking two reductions doesn't increase memory usage.
                result_var = self.welford_reduce_fallback(dtype, value)
            elif reduction_type == "welford_combine":
                mean, m2, weight = masked_value
                welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})"
                mean, m2, weight = (self.cse.newvar() for _ in range(3))
                self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}")

                result_var = tuple(
                    self.cse.generate(self.compute, self.reduction_resize(var_name))
                    for var_name in (mean, m2, weight)
                )
            else:
                result_var = self.cse.generate(
                    self.compute, final_reduction(masked_value)
                )
        else:
            accumulator = f"_{result_var}"
            default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
            default = self._map_tuple_or_scalar(constant_repr, default)
            if not isinstance(default, tuple):
                self.body.writeline(
                    f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
                )

            if reduction_type in {"argmax", "argmin"}:
                accumulator_index = f"_{result_var}_index"
                long_max = torch.iinfo(torch.int64).max
                self.body.writeline(
                    f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
                )
                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

                self.compute.splice(
                    f"""\
                {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
                    {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
                )
                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
                {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
                """
                )
                final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
            elif is_welford_reduction(reduction_type):
                accumulator = f"{result_var}_mean"
                accumulator_m2 = f"{result_var}_m2"
                accumulator_weight = f"{result_var}_weight"
                self.body.writeline(
                    f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )
                self.body.writeline(
                    f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )
                self.body.writeline(
                    f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )

                if reduction_type == "welford_combine":
                    mean, m2, weight = value
                    self.compute.splice(
                        f"""\
                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine(
                        {accumulator}, {accumulator_m2}, {accumulator_weight},
                        {mean}, {m2}, {weight}
                    )
                    """
                    )
                else:
                    assert reduction_type == "welford_reduce"
                    self.compute.splice(
                        f"""\
                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce(
                        {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0
                    )
                    """
                    )

                self.compute.splice(
                    f"""\
                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
                {accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
                {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
                """
                )

                result_mean = result_var
                result_m2 = self.cse.newvar()
                result_weight = self.cse.newvar()
                self.suffix.splice(
                    f"""\
                {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford(
                    {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim}
                )
                {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')}
                {result_m2} = {self.reduction_resize(f'{result_m2}_tmp')}
                {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')}
                """
                )
                result_var = result_mean, result_m2, result_weight
            else:
                combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
                updated = combine_fn(accumulator, value)
                self.compute.writeline(
                    f"{accumulator} = {where_cond(updated, accumulator)}"
                )

                if src_dtype == torch.bool:
                    # This is only really used for aten.any. It changes the
                    # final reduction of a non-persistent reduction from
                    #     tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
                    # to
                    #     tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
                    # which is needed because tl.reduce doesn't support tl.int1
                    accumulator = f"{accumulator}.to(tl.int8)"
                    result_type = triton_compute_type(dtype)
                    self.suffix.writeline(
                        f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
                    )
                else:
                    self.suffix.writeline(
                        f"{result_var} = {final_reduction(accumulator)}"
                    )

        self.cse.reduction_cache[cache_key] = result_var

        if isinstance(result_var, tuple):
            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
            self.outside_loop_vars |= set(result_var)
        else:
            assert isinstance(result_var, TritonCSEVariable)
            self.outside_loop_vars.add(result_var)

        return result_var

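    # Illustrative lowering (hypothetical names) of a non-persistent "sum"
    # reduction produced by reduction() above:
    #   _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)           # body: accumulator
    #   for roffset in ...:
    #       _tmp2 = tl.where(rmask & xmask, _tmp2 + tmp1, _tmp2)   # compute: combine
    #   tmp2 = tl.sum(_tmp2, 1)[:, None]                           # suffix: final reduce
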
    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        assert self.inside_reduction
        self.inside_reduction = False
        indexing = self.indexing(index, block_ptr=True)
        self.inside_reduction = True
        var = self.args.output(name)

        if isinstance(indexing, BlockPtrOptions):
            self.suffix.writeline(
                DeferredLine(
                    name,
                    self.codegen_block_ptr_store_line(
                        name,
                        indexing,
                        indexing.format(var),
                        value,
                        f", boundary_check={indexing.boundary_check()!r}",
                    ),
                )
            )
        else:
            assert isinstance(indexing, IndexingOptions)
            self.suffix.writeline(
                DeferredLine(
                    name,
                    f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
                )
            )

    def _lift_helper(self, fn, num_args) -> str:
        # Lift IR function for scan operations into a triton function
        # in the global namespace
        helper = IndentedBuffer()
        helper.writeline("@triton.jit")
        args = [tuple(f"arg{i}_{n}" for n in range(num_args)) for i in range(2)]
        signature = ", ".join(itertools.chain.from_iterable(args))
        helper.writeline(f"def {{name}}({signature}):")
        cse = CSE(prefix="", suffix="")
        overrides = TritonOverrides(V.MockHandler())

        # Build a name that changes depending on fn to work around a triton
        # bug where the combine_fn passed to reduce and scan is not hashed, so
        # different scan ops may collide in the triton cache.
        # This is fixed with the latest triton pin, but not the triton-rocm pin.
        helper_name = "_triton_helper_fn"

        class CSEProxy:
            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
                def inner(*args, **kwargs):
                    nonlocal helper_name
                    helper_name += f"_{name}"
                    return cse.generate(
                        helper,
                        getattr(overrides, name)(*args, **kwargs),
                    )

                return inner

        with helper.indent(), V.set_ops_handler(CSEProxy()):
            outputs = fn(*args)
            outputs = ", ".join(str(output) for output in outputs)
            helper.writeline(f"return {outputs}")

        return self.helper_functions.add(helper.getvalue(), base_name=helper_name)

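    # For illustration (hypothetical): lifting an add combine_fn for a
    # single-value scan produces roughly
    #   @triton.jit
    #   def _triton_helper_fn_add(arg0_0, arg1_0):
    #       tmp0 = arg0_0 + arg1_0
    #       return tmp0
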
    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        assert self.inside_reduction
        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        assert not self._load_mask, "ops.scan not supported inside ops.masked"
        reduction_range_prefix = self.range_trees[-1].prefix

        broadcasted_values = []
        accumulators = []

        cse_compute = functools.partial(self.cse.generate, self.compute)
        combine_helper_fn = self._lift_helper(combine_fn, len(values))
        dim = self.triton_tensor_ndim() - 1

        for value, dtype in zip(values, dtypes):
            acc_type = triton_acc_type(dtype)
            cond = " & ".join(masks)

            value_dtype = self.cse.generate(
                self.compute,
                f"{value}.to({triton_compute_type(dtype)})",
            )
            value = self.cse.generate(
                self.compute,
                f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
            )
            broadcasted_values.append(value)

            if not self.persistent_reduction:
                accumulator = self.cse.newvar()
                reduced_size = self.dense_size_list()
                reduced_size[-1] = "1"
                reduced_size = f"[{', '.join(reduced_size)}]"

                default = "float('nan')" if dtype.is_floating_point else "-1"
                self.body.writeline(
                    f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})"
                )

                accumulators.append(accumulator)

        def csv(values):
            return " ".join(f"{value}," for value in values)

        def cse_multiple(line, n, masks):
            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
            if all(cache_key in self.cse.cache for cache_key in cache_keys):
                return [self.cse.cache[cache_key] for cache_key in cache_keys]
            result_vars = [self.cse.newvar() for _ in range(n)]
            self.compute.writeline(
                f"{csv(result_vars)} = {line}",
            )
            for result_var, cache_key in zip(result_vars, cache_keys):
                if masks:
                    result_var.mask_vars = masks  # type: ignore[attr-defined]
                self.cse.cache[cache_key] = result_var
            return tuple(result_vars)

        partial_scan_vars = cse_multiple(
            f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
            len(values),
            masks,
        )

        if not self.persistent_reduction:

            def sum_fn(a, b):
                return [ops.add(ai, bi) for ai, bi in zip(a, b)]

            sum_helper_fn = self._lift_helper(sum_fn, len(values))
            pre_reduce_vars = ", ".join(
                f"{scan_var} * (rbase == (RBLOCK - 1))"
                for scan_var in partial_scan_vars
            )
            # tl.reduce doesn't work for non-commutative operators, so instead
            # of repeating the scan op as a reduction, we use sum to select the
            # last scan value
            partial_reduce_vars = cse_multiple(
                f"tl.reduce(({pre_reduce_vars}), -1, {sum_helper_fn}, keep_dims=True)",
                len(values),
                masks,
            )
            accs_next = combine_fn(tuple(accumulators), partial_reduce_vars)
            full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars)
            result_vars = [
                cse_compute(f"tl.where(roffset > 0, {full_scan}, {partial_scan})")
                for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars)
            ]
            for acc_next, accumulator, partial_reduce in zip(
                accs_next, accumulators, partial_reduce_vars
            ):
                self.compute.writeline(
                    f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})"
                )
        else:
            result_vars = partial_scan_vars

        for result_var in result_vars:
            result_var.mask_vars = masks  # type: ignore[attr-defined]

        return tuple(result_vars)

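    # Illustrative output of scan() above for a single cumsum in the
    # persistent-reduction case (hypothetical names):
    #   tmp2 = tl.associative_scan((tmp1,), 1, _triton_helper_fn_add)
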
    def codegen_body(self):
        """
        Concat output code from index_code, loads, compute, stores,
        suffix into self.body.

        For pointwise kernels, this is called just once at the end.

        For reduction kernels, this generates a loop over the reduction
        axis.
        """
        if not (
            self.indexing_code
            or self.loads
            or self.stores
            or self.compute
            or self.suffix
        ):
            return

        if self.inside_reduction and self.range_trees[-1].is_loop:
            self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
            with self.body.indent():
                # last range tree is always reduction
                self.iteration_ranges_codegen_header(self.range_trees[-1], self.body)
                self.body.splice(self.indexing_code)
                self.body.splice(self.loads)
                self.body.splice(self.compute)
                self.body.splice(self.stores)

            # invalidate any caches that came from inside the reduction loop
            self.cse.invalidate(self.outside_loop_vars)
            self.range_trees[-1].cache_clear()
        else:
            self.body.splice(self.indexing_code)
            self.body.splice(self.loads)
            self.body.splice(self.compute)
            self.body.splice(self.stores)
        self.body.splice(self.suffix)
        self.indexing_code.clear()
        self.loads.clear()
        self.compute.clear()
        self.stores.clear()
        self.suffix.clear()

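    # Illustrative body layout for a looped reduction kernel (hypothetical):
    #   for roffset in range(0, rnumel, RBLOCK):
    #       rindex = roffset + rbase
    #       rmask = rindex < rnumel
    #       ... indexing / loads / compute / stores ...
    #   ... suffix: final reductions and reduction stores ...
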
    def codegen_kernel_benchmark(self, num_gb, grid=None):
        result = IndentedBuffer()
        argdefs, call_args, signature, _ = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offsets will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                else:
                    raise KeyError(
                        f"Could not find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            grid = []
            extra_args = []
            extra_args_str = None
            for tree in self.active_range_trees():
                expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
                extra_args.append(expr)
                if tree.prefix != "r":
                    grid.append(expr)
            if self.need_numel_args():
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
            else:
                extra_args_str = ""
            grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
        else:
            grid_arg = f"grid={grid}"
        current_device = V.graph.scheduler.get_current_device_or_throw()
        index = current_device.index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline("from triton.testing import do_bench")
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
            )

        return result

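    # The generated harness (illustrative) has this shape:
    #   def get_args(): ...              # rand_strided inputs sized from the graph
    #   def call(args): ...              # launches the kernel on the right stream
    #   def benchmark_all_configs(args): ...
    #   if __name__ == '__main__': ...   # do_bench + GB/s report
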
    def imports_for_benchmark_kernel(self):
        return textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
            """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )

    def _get_heuristic(self):
        if self.persistent_reduction:
            assert self.inside_reduction
            return "persistent_reduction"
        elif self.inside_reduction:
            return "reduction"
        return "pointwise"

    @staticmethod
    def inductor_meta_common():
        inductor_meta = {
            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
            "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
            "assert_indirect_indexing": config.assert_indirect_indexing,
            "autotune_local_cache": config.autotune_local_cache,
            "autotune_pointwise": config.triton.autotune_pointwise,
            "autotune_remote_cache": config.autotune_remote_cache,
            "force_disable_caches": config.force_disable_caches,
            "dynamic_scale_rblock": config.dynamic_scale_rblock,
            "max_autotune": config.max_autotune,
            "max_autotune_pointwise": config.max_autotune_pointwise,
            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
            "spill_threshold": config.triton.spill_threshold,
            "store_cubin": config.triton.store_cubin,
        }
        if torch.version.hip is not None:
            inductor_meta["is_hip"] = True
        if config.is_fbcode():
            inductor_meta["is_fbcode"] = True
        if config.profile_bandwidth:
            inductor_meta["profile_bandwidth"] = config.profile_bandwidth
            inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
            inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
        if config.coordinate_descent_tuning:
            inductor_meta[
                "coordinate_descent_tuning"
            ] = config.coordinate_descent_tuning
            inductor_meta[
                "coordinate_descent_search_radius"
            ] = config.coordinate_descent_search_radius
            inductor_meta[
                "coordinate_descent_check_all_directions"
            ] = config.coordinate_descent_check_all_directions
        return inductor_meta

    def codegen_kernel(self, name=None):
        code = IndentedBuffer()

        size_hints = []
        for numel in self.numels:
            numel_hint = V.graph.sizevars.symbolic_hint(numel)
            if not isinstance(numel_hint, (int, sympy.Integer)):
                # This default heuristic hint was picked carefully: it is
                # large, to ensure that we don't shrink the block size (since
                # if you don't have many elements, it'd be wasteful to pick a
                # large block size). Since we don't know how many elements we
                # might have, we should be OK with some inefficiency to make
                # sure we handle the large case well. 8192 is the largest
                # block size we support, so we pick that.
                #
                # If we have a better hint for unbacked SymInts (e.g., because
                # a user told us, or we are tracking upper bounds) we could
                # use that here.
                size_hint = 8192
            else:
                size_hint = next_power_of_2(int(numel_hint))
            size_hints.append(size_hint)

        if not self.inside_reduction:
            size_hints.pop()

        heuristics = self._get_heuristic()

        if name is None:
            code.splice(gen_common_triton_imports())

            if config.benchmark_kernel:
                code.splice(self.imports_for_benchmark_kernel())

        argdefs, _, signature, _ = self.args.python_argdefs()
        # maps actual expression to SizeArg if it is in sizevars replacements
        for i, arg in enumerate(signature):
            if isinstance(arg, SizeArg):
                # mypy is unhappy about the sympy.Expr
                # type for the key of the dict below
                symbol = cast(sympy.Symbol, arg.expr)
                if symbol in V.graph.sizevars.inv_precomputed_replacements:
                    signature[i] = SizeArg(
                        arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
                    )

        mutated_args = set()
        for mutation in self.mutations:
            if mutation in self.args.input_buffers:
                mutated_args.add(self.args.input_buffers[mutation])
            if (
                mutation in self.args.inplace_buffers
                and mutation not in V.graph.removed_buffers
                and mutation not in self.removed_buffers
            ):
                mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
            if mutation in self.args.output_buffers:
                mutated_args.add(self.args.output_buffers[mutation])
        mutated_args = sorted(mutated_args)

        triton_meta_signature = signature_to_meta(
            signature, size_dtype=self.index_dtype
        )
        triton_meta = {
            "signature": triton_meta_signature,
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            "constants": {},
        }

        inductor_meta = {
            "autotune_hints": set(self.autotune_hints),
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_args,
            "no_x_dim": self.no_x_dim,
            "num_load": self.num_load,
            "num_reduction": self.num_reduction,
            **self.inductor_meta_common(),
        }

        num_gb = None
        if config.benchmark_kernel or config.profile_bandwidth:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb

        for tree in self.active_range_trees():
            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
            signature.append(sizearg)
            triton_meta_signature[len(argdefs)] = signature_of(
                sizearg, size_dtype=self.index_dtype
            )
            argdefs.append(f"{tree.prefix}numel")
            # constexpr version causes issues, see
            # https://github.com/pytorch/torchdynamo/pull/1362
            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
            #     tree.numel
            # )
            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
        triton_meta["configs"] = [config_of(signature)]

        # Triton compiler includes equal_to_1 args into constants even
        # when they are not constexpr; otherwise there may be a segfault
        # while launching the Inductor-compiled Triton kernel.
        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
        # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
            triton_meta["constants"][arg_num] = 1  # type: ignore[index]

        self.triton_meta = triton_meta

        for tree in self.range_trees:
            if tree.prefix == "r" and self.persistent_reduction:
                # RBLOCK for persistent_reduction is defined in codegen_static_numels
                continue
            if tree.tensor_dim is None:
                continue
            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")

        self.codegen_body()

        for helper in self.helper_functions:
            code.writeline("")
            code.splice(helper)

        if self.inside_reduction:
            reduction_hint = self.reduction_hint
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            tile_hint = ""
            if len(size_hints) == 2:
                if len(signature) == 4:  # input, output and 2 args
                    tile_hint = "tile_hint=TileHint.SQUARE,"
                else:
                    tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                    min_elem_per_thread={self.min_elem_per_thread}
                )
                @triton.jit
            """
        code.splice(heuristics_line)
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )
        with code.indent():
            self.codegen_static_numels(code)
            for old, new in self.args.aliases():
                code.writeline(f"{old} = {new}")
            code.splice(self.body)

        if config.benchmark_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb))

        return code.getvalue()

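    # Illustrative kernel prologue produced by codegen_kernel() above for a
    # pointwise kernel (values hypothetical):
    #   @triton_heuristics.pointwise(
    #       size_hints=[1024],
    #       filename=__file__,
    #       triton_meta={...},
    #       inductor_meta={...},
    #       min_elem_per_thread=0
    #   )
    #   @triton.jit
    #   def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
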
    def codegen_static_numels(self, code):
        """
        We get a small speedup from hard coding numels if they are static.

        This code stomps on the passed-in values by writing a constant to the top of the kernel.

        In a kernel like:
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):

        We would add
        xnumel = 4096
        rnumel = 768

        after the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
        knows that it's a static numel, as that you just plop a constant into the kernel.
        """
        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
                if isinstance(simplified_tree_numel, (sympy.Integer, int)):
                    code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")

            if tree.prefix == "r" and self.persistent_reduction:
                simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
                if isinstance(simplified_tree_numel, (sympy.Integer, int)):
                    val = int(simplified_tree_numel)
                    val = next_power_of_2(val)
                else:
                    val = 128
                    while not V.graph.sizevars.statically_known_leq(
                        simplified_tree_numel, val
                    ):
                        assert (
                            val <= 16 * 1024
                        ), f"Failed to find static RBLOCK for {simplified_tree_numel}"
                        val *= 2

                code.writeline(f"RBLOCK: tl.constexpr = {val}")

            if tree.prefix == "x" and self.no_x_dim:
                code.writeline("XBLOCK: tl.constexpr = 1")

    def _get_grid_fn(self):
        return "grid"

    def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
        # TODO(jansel): if there are constants, we shouldn't bother passing them as args
        for tree in self.range_trees:
            if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                expr = tree.numel
            else:
                expr = V.graph.wrapper_code.generate_numel_expr(name, tree)

            if tree.prefix != "r" or self.inside_reduction:
                call_args.append(expr)
                arg_types.append(type(expr))
            if tree.grid_dim is not None:
                grid.append(expr)

    def get_call_args(self):
        # arg_types is needed for cpp wrapper codegen
        _, call_args, _, arg_types = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"

        return call_args, arg_types

    def call_kernel(self, name: str, node: Optional[IRNode] = None):
        wrapper = V.graph.wrapper_code
        call_args, arg_types = self.get_call_args()
        grid: List[Any] = []
        self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
        current_device = V.graph.scheduler.get_current_device_or_throw()

        if self.args.workspace_arg is not None:
            ws = self.args.workspace_arg
            wrapper.generate_workspace_allocation(
                ws.nbytes, current_device, ws.zero_fill
            )

        grid = wrapper.generate_default_grid(name, grid)
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid,
            current_device.index,
            cuda=True,
            triton=True,
            arg_types=arg_types,
            grid_fn=self._get_grid_fn(),
            triton_meta=self.triton_meta,
        )

        if self.args.workspace_arg is not None:
            wrapper.writeline(wrapper.make_free_by_names(["workspace"]))

    def codegen_nan_check(self):
        wrapper = V.graph.wrapper_code
        _, call_args, arg_types, _ = self.args.python_argdefs()
        for arg, arg_type in zip(call_args, arg_types):
            if isinstance(arg_type, TensorArg):
                if V.graph.cpp_wrapper:
                    if config.abi_compatible:
                        wrapper.writeline(
                            f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));'
                        )
                    else:
                        wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});')
                else:
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

    def create_cse_var(self, *args, **kwargs):
        return TritonCSEVariable(*args, **kwargs)

    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
        line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
        if entry.root.is_loop:
            self.indexing_code.writeline(line)
        else:
            # lift non-reduction stores outside loop
            self.body.writeline(line)

    def iteration_ranges_ranges_code(self, entry):
        assert entry.tensor_dim is not None
        size = self.indexing_size_str(entry.tensor_dim)
        index_dtype = self.index_dtype
        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}"

    def iteration_ranges_scalar_code(self, entry, value):
        index_dtype = self.index_dtype
        ndim = self.triton_tensor_ndim()
        size = [1] * ndim
        return f"tl.full({size}, {value}, {index_dtype})"

    def iteration_ranges_get_pid(self, entry):
        assert entry.grid_dim is not None
        key = f"tl.program_id({entry.grid_dim})"
        # y_grid has a limit, so express it in terms of y and z in case of overflow.
        # z grid is only exercised when max_tiles == 3 (off by default).
        if (
            entry.grid_dim == 1
            and not entry.has_zdim
            and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid())
        ):
            # For ynumel larger than max_ygrid, we need to use zdim.
            # For each z dimension, there are tl.num_programs(1) yblocks, which is passed by grid(x, y, z).
            # So, we need to add tl.program_id(z) * tl.num_programs(y) * YBLOCK to get the correct yoffset.
            key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))"
        pid = entry.pid_cache.get(key, key)
        if self.index_dtype != "tl.int32":
            return f"{pid}.to({self.index_dtype})"
        return pid

    def iteration_ranges_codegen_header(self, entry, code):
        x = entry.prefix
        if entry.is_loop:
            code.writeline(f"{entry.name} = {x}offset + {x}base")
        elif entry.grid_dim is None:
            # no need to "{x}offset = "
            code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}")
            code.writeline(f"{x}offset = 0")
        else:
            if entry.tensor_dim is not None:
                line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}"
            else:
                line = self.iteration_ranges_scalar_code(entry, f"{x}offset")
            code.writelines(
                [
                    f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK",
                    f"{entry.name} = {line}",
                ]
            )
        code.writeline(f"{x}mask = {entry.name} < {x}numel")

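    # Illustrative header emitted for the "x" range tree (hypothetical dims):
    #   xoffset = tl.program_id(0) * XBLOCK
    #   xindex = xoffset + tl.arange(0, XBLOCK)[:]
    #   xmask = xindex < xnumel
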
class TritonScheduling(SIMDScheduling):
    int32_type = "tl.int32"
    int64_type = "tl.int64"
    kernel_type = TritonKernel

    def codegen_comment(self, node_schedule):
        wrapper = V.graph.wrapper_code
        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        if origins:
            wrapper.writeline(origins)

        if config.debug_fusion:
            from torch._inductor.scheduler import (
                BaseSchedulerNode,
                ForeachKernelSchedulerNode,
            )

            if not any(
                isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule
            ):
                # We should probably look at the nodes inside a foreach
                # schedule node
                node_names = [
                    n.get_name()
                    for n in node_schedule
                    if isinstance(n, BaseSchedulerNode)
                ]
                wrapper.writeline(
                    f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
                )

    def define_kernel(self, src_code, node_schedule, kernel):
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_category = get_kernel_category_by_source_code(src_code)[:3]
            kernel_name = "_".join(
                ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
            )
            # use the original src_code as the key
            wrapper.src_to_kernel[src_code] = kernel_name
            subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"

            # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name
            # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set
            # to "triton_" to maximize caching opportunities (when unique_kernel_names = False).
            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name)

            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
            src_code = src_code.replace("#pragma CMT", "#")

            basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
            compile_wrapper.splice(src_code, strip=True)
            current_device = V.graph.scheduler.get_current_device_or_throw()
            compile_wrapper.writeline(f"''', device_str='{current_device.type}')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )

            # log kernel metadata for offline analysis.
            # E.g. one can find all unaligned inner reductions and check if
            # padding helps with the perf, kernel by kernel.
            if is_metric_table_enabled("kernel_metadata"):
                log_kernel_metadata(kernel_name, kernel_path, src_code)

        return kernel_name

    @preserve_rng_state()
    def benchmark_fused_nodes(self, nodes):
        src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
        mod = PyCodeCache.load(src_code)

        def cache_file_path():
            assert mod.__file__ is not None
            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"

        def load_cache():
            path = cache_file_path()
            if os.path.exists(path):
                with open(path) as fd:
                    return float(fd.read())
            return None

        def store_cache():
            path = cache_file_path()
            with open(path, "w") as fd:
                fd.write(str(ms))

        log.debug(
            "kernel src code for %s written to: %s",
            {n.get_name() for n in nodes},
            mod.__file__,
        )
        ms = load_cache()
        if ms is not None:
            return ms, mod.__file__

        args = mod.get_args()
        call = mod.call
        wrapped_jit_function = mod.triton_

        # call once to trigger the compilation
        try:
            call(wrapped_jit_function.clone_args(*args)[0])
        except Exception as e:
            log.debug(
                "Exception (%s) in compiling fused nodes %s",
                e,
                {n.get_name() for n in nodes},
            )
            ms = float("inf")
            store_cache()
            return ms, mod.__file__

        launchers = wrapped_jit_function.launchers
        assert len(launchers) == 1
        if launchers[0].n_spills > 0:
            # skip benchmarking the kernel if there are register spills
            ms = float("inf")
        else:
            # We have to clone the inplace updated arguments to avoid earlier calls
            # generating out of range indices for later calls.
            ms = do_bench_gpu(lambda: call(wrapped_jit_function.clone_args(*args)[0]))

            # The overhead of cloning args biases the timing in favor of fusing
            # the kernel in the case of a mutating/in-placeable second fusion.
            # TODO - would be better as a hook in triton do_bench that resets
            # the input values between benchmark runs
            ms = ms - do_bench_gpu(lambda: wrapped_jit_function.clone_args(*args))

        log.debug(
            "The fused kernel for %s took %.3f ms to run",
            {n.get_name() for n in nodes},
            ms,
        )
        store_cache()
        return ms, mod.__file__