# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
from typing import (
    Any,
    Callable,
    Counter,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import sympy

import torch
import torch._logging
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT

from ..._dynamo.utils import counters
from .. import config, ir, scheduler
from ..codecache import code_hash
from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
from ..ir import TritonTemplateBuffer
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import green_text, yellow_text
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
    get_dtype_size,
    IndentedBuffer,
    Placeholder,
    sympy_index_symbol,
    sympy_product,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsWrapper, V
from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
from .multi_kernel import MultiKernel

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")

pexpr = PythonPrinter().doprint


@dataclasses.dataclass
class IterationRanges:
    """
    Each range tree represents multiple sets of iteration indexing
    in a single tiled dimension in the output kernel.

    If you have two loop ranges, one (4, 3, 2) and another (4, 6),
    then the range tree will be:
            4 (i0)
        3 (i1)     6 (i3)
        2 (i2)
    where i0 is shared between both loops, but is then split into
    different indexing vars. All loop ranges must iterate over
    the same number of elements.
    """

    def __init__(
        self,
        name: str,
        var_list: List[sympy.Symbol],
        var_ranges: Dict[sympy.Symbol, sympy.Expr],
        numel: sympy.Expr,
        prefix: str,
        *,
        kernel: SIMDKernel,
        divisor=sympy.Integer(1),
        length=sympy.Integer(1),
        root: IterationRangesRoot,
    ):
        super().__init__()
        self.name = name
        self.var_list = var_list
        self.var_ranges = var_ranges
        self.numel = numel
        self.prefix = prefix
        self.divisor = divisor
        self.length = length
        self.kernel = kernel
        self.root = root

    def symbol(self):
        return sympy_index_symbol(self.name)
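

# Note (illustrative sketch, not executed): each range-tree entry records a
# (divisor, length) pair such that its value is recovered from the flat
# per-tree index as
#     var = (flat_index // divisor) % length
# (a plain floor-division when divisor * length spans the whole numel; see
# IterationRangesRoot.lookup below). For the (4, 3, 2) loop in the docstring
# above, with xindex as the flat index:
#     i2 = xindex % 2           # divisor=1, length=2
#     i1 = (xindex // 2) % 3    # divisor=2, length=3
#     i0 = xindex // 6          # divisor=6, length=4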


class IterationRangesRoot(IterationRanges):
    def __init__(
        self,
        name: str,
        numel: sympy.Expr,
        # TODO: this is probably SymTy.INDEX and SymTy.RINDEX
        prefix: str,
        index: int,
        kernel: SIMDKernel,
        pid_cache=None,
        *,
        is_loop: bool,
        tensor_dim: Optional[int],
        grid_dim: Optional[int],
        has_zdim: bool,
    ):
        if pid_cache is None:
            pid_cache = {}
        super().__init__(
            name=name,
            var_list=[],
            var_ranges={},
            numel=numel,
            prefix=prefix,
            kernel=kernel,
            root=self,
        )
        self.index = index
        # Store all the nodes in one flat list
        self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
        # This is for re-ordering program ID in triton mm template
        # pid_cache["tl.program_id(0)"] = pid_m
        self.pid_cache: Dict[str, str] = pid_cache
        # True if the dimension is implemented as a single program looping over
        # the full dimension (currently only used for non-persistent reduction)
        assert not is_loop or (prefix == "r" and grid_dim is None)
        self.is_loop = is_loop
        # Index of corresponding dimension on triton tensors
        self.tensor_dim = tensor_dim
        # Index of corresponding dimension in the triton grid
        self.grid_dim = grid_dim
        self.has_zdim = has_zdim

    def __repr__(self):
        return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)"

    def cache_clear(self):
        for node in self.nodes.values():
            node.cache_clear()

    def lookup(self, divisor, length):
        """
        Lookup a given RangeTreeEntry, creating it if needed
        """
        if V.graph.sizevars.statically_known_equals(divisor * length, self.numel):
            expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor)
        else:
            expr = ModularIndexing(
                sympy_index_symbol(f"{self.prefix}index"), divisor, length
            )
        if expr not in self.nodes:
            node = IterationRangesEntry(
                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
                divisor,
                length,
                expr,
                self,
            )
            V.kernel.range_tree_nodes[node.symbol()] = node
            self.var_list.append(node.symbol())
            self.var_ranges[node.symbol()] = length
            self.nodes[expr] = node
        return self.nodes[expr]

    def construct_entries(self, lengths: List[sympy.Expr]):
        divisor = sympy.Integer(1)
        itervars = []
        for length in reversed(lengths):
            itervars.append(self.lookup(divisor, length))
            divisor = divisor * length
        return list(reversed(itervars))

    def construct(self, lengths: List[sympy.Expr]):
        return [e.symbol() for e in self.construct_entries(lengths)]
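
    # Worked example (illustrative, assuming a tree of numel 24):
    # construct_entries([4, 3, 2]) walks the lengths innermost-first,
    # accumulating the divisor, so it performs
    #     lookup(1, 2), lookup(2, 3), lookup(6, 4)
    # and returns the entries back in outermost-first order [i0, i1, i2],
    # matching the tree in the IterationRanges docstring.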

    def vars_and_sizes(self, index: sympy.Expr):
        """Figure out vars from this tree used in index"""
        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
        nodes = [n for n in nodes if n and n.prefix == self.prefix]
        nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor))
        divisor = sympy.Integer(1)
        index_vars = []
        sizes = []

        def add(node):
            nonlocal divisor
            index_vars.append(node.symbol())
            sizes.append(node.length)
            divisor = divisor * node.length

        for node in nodes:
            if not V.graph.sizevars.statically_known_equals(node.divisor, divisor):
                # fill in unused index var
                add(self.lookup(divisor, FloorDiv(node.divisor, divisor)))
                divisor = node.divisor
            add(node)
        if not V.graph.sizevars.statically_known_equals(self.numel, divisor):
            # fill in unused index var
            add(self.lookup(divisor, FloorDiv(self.numel, divisor)))

        return list(reversed(index_vars)), list(reversed(sizes))


class IterationRangesEntry(IterationRanges):
    def __init__(
        self,
        name: str,
        divisor: sympy.Expr,
        length: sympy.Expr,
        expr: sympy.Expr,
        parent: IterationRanges,
    ):
        super().__init__(
            name=name,
            numel=parent.numel / length,
            var_list=parent.var_list,
            var_ranges=parent.var_ranges,
            prefix=parent.prefix,
            divisor=divisor,
            length=length,
            kernel=parent.kernel,
            root=parent.root,
        )
        self.parent = parent
        self.codegen = functools.lru_cache(None)(self._codegen)
        self.expr = expr

    def __repr__(self):
        return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})"

    def set_name(self, name):
        self.codegen = lambda: name  # type: ignore[assignment]
        self.codegen.cache_clear = lambda: None  # type: ignore[method-assign]
        self.name = name

    def cache_clear(self):
        self.codegen.cache_clear()

    def _codegen(self):
        V.kernel.codegen_iteration_ranges_entry(self)
        return self.name

    def precomputed_args(self):
        # for dynamic shapes, find parts of indexing expressions that have to be precomputed
        precomputed_args: List[sympy.Expr] = []
        if isinstance(self.expr, sympy.Symbol):
            return precomputed_args
        assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)
        for arg in self.expr.args[1:]:
            if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
                symbols = arg.free_symbols
                if len(symbols) > 0 and all(
                    symbol_is_type(s, SymT.SIZE) for s in symbols
                ):
                    precomputed_args.append(arg)
        return precomputed_args

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return self.name == other.name


def constant_repr(value):
    if value == float("inf"):
        return 'float("inf")'
    elif value == float("-inf"):
        return 'float("-inf")'
    elif math.isnan(value):
        return 'float("nan")'
    return repr(value)
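

# For example (illustrative):
#     constant_repr(float("inf"))  -> 'float("inf")'
#     constant_repr(float("nan"))  -> 'float("nan")'
#     constant_repr(2.5)           -> '2.5'
# i.e. non-finite floats are emitted as expressions that reparse correctly,
# since repr(float("inf")) == "inf" is not valid Python source on its own.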


class SIMDKernel(Kernel):
    """
    Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
    """

    sexpr = pexpr
    kexpr: Callable[[sympy.Expr], str]
    allow_block_ptr = False

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[Set[str]] = None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
        disable_persistent_reduction=False,
    ):
        if pid_cache is None:
            pid_cache = {}
        super().__init__()
        self.body = IndentedBuffer()
        self.indexing_code = IndentedBuffer()
        self.numels = [V.graph.sizevars.simplify(s) for s in groups]
        self.mutations: Set[str] = mutations if mutations is not None else set()
        self.range_trees: List[IterationRangesRoot] = []
        self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
        self.iter_vars_count = itertools.count()
        self.inside_reduction = self.numels[-1] != 1
        self.reduction_hint = reduction_hint
        self.index_dtype: str = index_dtype
        self.last_usage: Set[str] = set()
        self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
        self.persistent_reduction: bool = (
            not disable_persistent_reduction
        ) and self.should_use_persistent_reduction()
        self.no_x_dim = self.want_no_x_dim()
        self.code_hash = None

        # define this in a closure to make cache local to object
        @functools.lru_cache(None)
        def simplify_indexing(index: sympy.Expr):
            index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
            for tree in self.range_trees:
                index = self.combine_contiguous_dims(index, tree)
            return self.combine_modular_indexing_pairs(index)

        self.simplify_indexing = simplify_indexing
        self.initialize_range_tree(pid_cache)

    def want_no_x_dim(self):
        return False

    def initialize_range_tree(self, pid_cache):
        no_r_dim = not self.inside_reduction or self.numels[-1] == 1
        prefixes = "zyxr"
        active_prefixes = prefixes[-len(self.numels) :]

        grid_dims = "xyz"
        if self.no_x_dim:
            tensor_dims = "r"
        elif no_r_dim:
            tensor_dims = "xyz"
        else:
            tensor_dims = "xyzr"

        tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes)

        for i, prefix in enumerate(active_prefixes):
            is_reduction = prefix == "r"
            tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None
            grid_dim = None if is_reduction else grid_dims.find(prefix)
            index = i if grid_dim is None else grid_dim
            self.range_trees.append(
                IterationRangesRoot(
                    f"{prefix}index",
                    self.numels[i],
                    prefix,
                    index,
                    self,
                    pid_cache=pid_cache,
                    is_loop=is_reduction and not self.persistent_reduction,
                    tensor_dim=tensor_dim,
                    grid_dim=grid_dim,
                    has_zdim="z" in active_prefixes,
                )
            )

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        prior = self.inside_reduction
        self.inside_reduction = False
        try:
            return self.store(name, index, value)
        finally:
            self.inside_reduction = prior

    def should_use_persistent_reduction(self) -> bool:
        return False  # defined in subclass

    def var_ranges(self):
        return dict(
            itertools.chain.from_iterable(
                tree.var_ranges.items() for tree in self.range_trees
            )
        )

    def triton_tensor_ndim(self):
        return sum(int(tree.tensor_dim is not None) for tree in self.range_trees)

    def indexing_size_str(self, i):
        sizes = ["None"] * self.triton_tensor_ndim()
        sizes[i] = ":"
        return f"[{', '.join(sizes)}]"

    def dense_size_list(self) -> List[str]:
        sizes = ["1"] * self.triton_tensor_ndim()
        for tree in self.range_trees:
            if tree.tensor_dim is None:
                continue
            if tree.prefix != "r" or self.inside_reduction:
                sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
        return sizes

    def dense_size_str(self):
        sizes = self.dense_size_list()
        return f"[{', '.join(sizes)}]"
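
    # For example (illustrative), in a reduction kernel whose active "x" and
    # "r" trees map to tensor dims 0 and 1:
    #     dense_size_list()    -> ["XBLOCK", "RBLOCK"]
    #     dense_size_str()     -> "[XBLOCK, RBLOCK]"
    #     indexing_size_str(0) -> "[:, None]"
    # i.e. the subscript strings used to broadcast a 1D range over the
    # kernel's dense tile shape in the generated code.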

    def combine_modular_indexing_pairs(self, index):
        if not isinstance(index, ModularIndexing):
            return index
        x = index.args[0]
        if (tree_node := self.range_tree_nodes.get(x)) is None:
            return index
        new_index = sympy_subs(index, {x: tree_node.expr})
        return V.graph.sizevars.combine_modular_indexing_pairs(new_index)

    def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
        if expand_res := V.graph.sizevars.expand_floor_div(index):
            new_index, denominator = expand_res  # type: ignore[misc]
            return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator)
        else:
            return self._combine_contiguous_dims(index, tree)

    def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
        """
        More aggressive simplification to merge contiguous dims
        """
        if isinstance(index, (sympy.Integer, sympy.Symbol)):
            return index
        index_vars, sizes = tree.vars_and_sizes(index)
        if len(sizes) <= 1:
            return index
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
        )
        if new_sizes == sizes:
            return index
        new_index_vars = tree.construct(new_sizes)
        new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
        return new_index

    def set_last_usage(self, nodes):
        if not self.inside_reduction or self.persistent_reduction:
            return
        self.last_usage = set(
            itertools.chain.from_iterable(
                n.last_usage for n in nodes if n is not EnableReduction
            )
        )

    def disable_reduction(self):
        should_flush = self.range_trees[-1].is_loop

        @contextlib.contextmanager
        def ctx():
            if self.numels[-1] == 1:
                assert not self.inside_reduction
                yield
                return
            if should_flush:
                # calling codegen_body() will flush all the pending buffers
                # and write out a reduction loop
                self.codegen_body()
            self.inside_reduction = False
            try:
                yield
                if should_flush:
                    # flush out any code before opening the next loop
                    self.codegen_body()
            finally:
                self.inside_reduction = True

        return ctx()

    def set_ranges(self, *lengths):
        assert len(lengths) == len(self.range_trees)
        return [
            ranges.construct(length)
            for length, ranges in zip(lengths, self.range_trees)
        ]

    @staticmethod
    def _split_iteration_ranges(
        groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
    ):
        sv = V.graph.sizevars
        new_ranges: List[List[sympy.Expr]] = [[] for _ in groups]
        remaining = [sv.simplify(g) for g in groups]
        var_count = itertools.count()

        def add_range(i, expr):
            expr = sv.simplify(expr)
            if not sv.statically_known_multiple_of(remaining[i], expr):
                raise CantSplit
            # guard on the last item out
            remaining[i] = FloorDiv(remaining[i], expr)
            new_ranges[i].append(expr)
            return next(var_count)

        def make_combined(size, idx1, idx2):
            def getter(flat_vars):
                return size * flat_vars[idx1] + flat_vars[idx2]

            return getter

        return_getters_groups = []
        current_group = 0
        for length_group in lengths:
            return_getters = []
            for size in length_group:
                if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
                    return_getters.append(lambda _: sympy.Integer(0))
                    continue
                while current_group < len(remaining) and sv.statically_known_equals(
                    remaining[current_group], 1  # type: ignore[arg-type]
                ):
                    # scroll to next group with remaining elements
                    current_group += 1
                if current_group + 1 < len(remaining) and sv.statically_known_gt(
                    size, remaining[current_group]
                ):
                    # need to break size in two
                    if not sv.statically_known_multiple_of(
                        size, remaining[current_group]
                    ):
                        raise CantSplit
                    size1 = remaining[current_group]
                    size2 = FloorDiv(size, remaining[current_group])
                    return_getters.append(
                        make_combined(
                            size2,
                            add_range(current_group, size1),
                            add_range(current_group + 1, size2),
                        )
                    )
                else:
                    return_getters.append(
                        operator.itemgetter(add_range(current_group, size))
                    )
            return_getters_groups.append(return_getters)

        assert all(
            V.graph.sizevars.size_hint(s) == 1 for s in remaining
        ), f"failed to set ranges {remaining} {lengths}"

        return new_ranges, return_getters_groups

    @classmethod
    def is_compatible(
        cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
    ):
        try:
            cls._split_iteration_ranges(groups, lengths)
            return True
        except CantSplit:
            return False

    def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
        """
        We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).

        To do this we need to split up the iteration space of i0 into something like:
            for i1 in s0:
                for i2 in s1:
                    i0 = i1*s1 + i2
                    ....

        This function matches and resplits lengths to the groups of
        this kernel to enable tiled + non-tiled fusions.
        """
        groups = [rt.numel for rt in self.range_trees]
        if not self.inside_reduction:
            groups[-1] = sympy.Integer(1)

        if len(lengths) == len(self.range_trees) and all(
            V.graph.sizevars.simplify(sympy_product(x) - g) == 0
            for x, g in zip(lengths, groups)
        ):
            return self.set_ranges(*lengths)

        new_ranges, return_getters_groups = self._split_iteration_ranges(
            groups, lengths
        )
        itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
        return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
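
    # Worked example (illustrative): for a tiled kernel with groups
    # (s0, s1, 1) and a node whose ranges flatten to ([s0*s1], []), the single
    # length s0*s1 is split across the first two groups, so
    # _split_iteration_ranges returns getters equivalent to
    #     i0 = s1 * i1 + i2    # i1 in range(s0), i2 in range(s1)
    # which is exactly the resplit shown in the docstring above.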

    def is_indirect_indexing(self, index: sympy.Expr):
        # tmpX means indirect indexing
        return free_symbol_is_type(index, SymT.TMP)

    def is_broadcasted(self, index: sympy.Expr):
        # Note. This may not be correct when there is indirect indexing
        if self.is_indirect_indexing(index):
            return False

        index_numels = [1] * len(self.numels)
        for symbol in index.free_symbols:
            if symbol not in self.range_tree_nodes:
                # Non-iterated variables, e.g. strides
                continue
            entry = self.range_tree_nodes[symbol]  # type: ignore[index]
            assert isinstance(entry.parent, IterationRangesRoot)
            index_numels[entry.parent.index] *= entry.length

        # If the index variables only iterate over a subset of the kernel
        # numels, then it must be broadcasted.
        simplify = V.graph.sizevars.simplify
        return any(
            simplify(idx_range) != simplify(iter_range)  # type: ignore[arg-type]
            for idx_range, iter_range in zip(index_numels, self.numels)
        )

    def index_to_str(self, index: sympy.Expr) -> str:
        """
        Convert an index expr to a string that can be used in output code.
        e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel.

        Index expressions often need to be passed in as arguments to the triton kernel.
        Rename_indexing and codegen_indexing keep track of the needed indices and add
        new parameters to the function signature.
        """
        if isinstance(index, list):
            return f"[{', '.join(map(self.index_to_str, index))}]"
        return self.kexpr(self.rename_indexing(index))  # type: ignore[call-arg]

    def prepare_indexing(
        self,
        index: sympy.Expr,
    ):
        index = self.simplify_indexing(index)
        index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
        # if simple replacements didn't get rid of floor/ceil, try full subs
        if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)):
            index = index.subs(V.graph.sizevars.precomputed_replacements)
        # last resort, if no range vars are in the expr, hoist it
        # TODO instead of trying to blindly find complicated exprs, we should hoist the
        # inputs/outputs sizes and strides, but at the time indexing is generated
        # kernel inputs and outputs are not set yet, we'd need a deeper refactor
        # to do it this way
        if len(index.atoms(sympy.ceiling)):
            for a in index.atoms(sympy.ceiling):
                # for nested exprs, atoms yields top level first (?)
                # so if everything goes fine, lower level replacements will come up empty
                symbols = a.free_symbols
                if len(symbols) > 0 and all(
                    symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE))
                    for s in symbols
                ):
                    replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)}
                    index = sympy_subs(index, replacements)

        return self.codegen_indexing(self.simplify_indexing(index))

    def active_range_trees(self, reorder=False):
        trees = [
            t for t in self.range_trees if t.prefix != "r" or self.inside_reduction
        ]
        if reorder and len(trees) > 1:
            count = sum(t.prefix in "xyz" for t in trees)
            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
                t.prefix for t in trees[:count]
            ]
            trees[:count] = reversed(trees[:count])
        return trees

    def filter_masks(self, mask_vars):
        for tree in self.range_trees:
            # Masks are superfluous if we only have one element
            if V.graph.sizevars.statically_known_equals(tree.numel, 1):  # type: ignore[arg-type]
                mask_vars.discard(f"{tree.prefix}mask")
                continue
            # Masks are superfluous if numel is a multiple of BLOCK
            # (We use the fact that BLOCK is required by triton to be a power of 2)
            if tree.prefix.upper() not in TRITON_MAX_BLOCK:
                continue
            max_block = TRITON_MAX_BLOCK[tree.prefix.upper()]
            # Optional optimization: if block divides numel exactly, we will
            # never need to do a masked load to handle stragglers at the end.
            # It's faster to avoid masking at all. But it is sound to always
            # mask.
            if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block):  # type: ignore[arg-type]
                mask_vars.discard(f"{tree.prefix}mask")
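
    # For example (illustrative): if xnumel is statically a multiple of
    # TRITON_MAX_BLOCK["X"], then, since block sizes are powers of 2, it is
    # also a multiple of any smaller XBLOCK autotuning may pick, so every tile
    # is full and "xmask" can safely be dropped from loads and stores.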

    def codegen_indexing(self, expr: sympy.Expr):
        expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
        for sym in sorted(expr.free_symbols, key=str):
            if sym in self.range_tree_nodes:
                # if indexing expression is complicated, we precompute it on the host side
                # and send the result as a kernel argument
                replacements = {}
                for ps in self.range_tree_nodes[sym].precomputed_args():  # type: ignore[index]
                    replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
                if len(replacements) > 0:
                    self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
                        self.range_tree_nodes[sym].expr, replacements  # type: ignore[index]
                    )
                self.range_tree_nodes[sym].codegen()  # type: ignore[index]
        return expr

    @contextlib.contextmanager
    def mask_loads(self, mask):
        """Context manager to add an additional mask to tl.load/store"""
        prior = self._load_mask
        if prior:
            mask = ops.logical_and(mask, prior)

        mask = OpsWrapper._unwrap(mask)
        self._load_mask = mask
        try:
            # TODO(jansel): do we need a reshape here?
            yield mask
        finally:
            self._load_mask = prior

    def get_strides_of_load(self, index: sympy.Expr):
        """
        This gets the stride of the index for each of the tiling variables
        (technically, it does it at index 0)

        For example, if
            xindex = x0 + 512*x1 + 1024*r0
            x0 = (xindex//512)
            x1 = (xindex % 512)
            r0 = rindex // 1024
        this function would return
            {xindex: 512, rindex: 1024}
        """
        index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)  # type: ignore[arg-type]
        strides = {}
        for range_tree in self.range_trees:
            s = sympy_index_symbol(range_tree.name)
            strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs(
                index_in_tile_vars, {s: 0}
            )
        return strides

    @staticmethod
    def _map_tuple_or_scalar(fn, value):
        if isinstance(value, tuple):
            return tuple(map(fn, value))
        return fn(value)

    def estimate_kernel_num_bytes(self):
        """
        Try our best to estimate the total size (in bytes) of the
        kernel's inputs and outputs, which is used for estimating the memory
        throughput of this kernel. This information is used for checking how
        far we are from the peak memory bandwidth. It's important to avoid
        overestimating the sizes of the inputs and outputs,
        because that can wrongly give us a very large memory traffic value,
        which may be even larger than the theoretical bandwidth and thus
        become very misleading. This is particularly problematic for cases
        where we slice some inputs. In those cases, we should only count
        the size of the "slices" instead of the original inputs, because
        only the slices contribute to the real memory traffic.
        """
        nbytes = []
        ninplace_args = len(unique(self.args.inplace_buffers.values()))
        _, call_args, _, _ = self.args.python_argdefs()

        # For pointwise and reduction kernels, this is the upper-bound numels
        # for the output buffer.
        # FIXME: This is not exactly right for cases like below:
        #     def foo(tensor0, tensor1):
        #         x0 = narrow(tensor0)
        #         return cat(x0, tensor1)
        # For this example, we will end up overestimating the size for the
        # slice s0. Potentially, we could have precise inputs information
        # if we maintained the original inputs of the Pointwise kernel created
        # for the "cat". However, it might be overwhelming to add such
        # complexity only for handling some particular cases for
        # benchmarking.
        out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels))
        for i, arg in enumerate(call_args):
            # "buf" may be narrowed. In this case, the number of memory accesses
            # should be estimated based on the reinterpreted layout.
            # On the other hand, buf may be broadcasted. In this case,
            # counting the size of the underlying storage would give us
            # a better estimation in terms of memory accesses.
            if arg not in self.buf_accesses:
                nbytes.append(0)
                continue
            arg_numel = V.graph.get_numel(arg)
            buf_size = V.graph.sizevars.size_hint(arg_numel)
            if buf_size > out_numel:
                # This arg points to a buf that has been sliced.
                # We need to count each individual slice to have
                # a better estimation.
                indices: Set[Any] = set()
                no_index_dep_count = 0
                for dep in self.buf_accesses[arg]:
                    if isinstance(dep, (StarDep, WeakDep)):
                        indices.add(f"no_index_dep_{no_index_dep_count}")
                        no_index_dep_count += 1
                    else:
                        indices.add(dep.index)
                numel = len(indices) * out_numel
            else:
                numel = buf_size
            dtype = V.graph.get_dtype(arg)
            dtype_size = get_dtype_size(dtype)
            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
        return sum(nbytes)
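
    # Rough sketch of the estimate (illustrative): each argument contributes
    #     numel * dtype_size * (2 if in-place else 1)
    # bytes, since an in-place buffer is both read and written. E.g. a kernel
    # updating one float32 buffer of 2**20 elements in place would be
    # estimated at 2**20 * 4 * 2 = 8 MiB of traffic.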

    def warn_mix_layout(self, kernel_name):
        """
        Print a message if the kernel has mixed-layout inputs.
        Only cares about 4D tensors for now.
        """
        if (
            len(self.args.input_buffers) == 1
            and len(self.args.output_buffers) == 1
            and len(self.args.inplace_buffers) == 0
        ):
            # even if input buffer and output buffer have different layout,
            # this can be a layout conversion kernel. No need to warn for
            # the mix layouts.
            return

        argdefs, call_args, signature, _ = self.args.python_argdefs()
        uniform_stride_order = None
        for arg_name in call_args:
            buf = V.graph.get_buffer(arg_name)
            if buf and len(buf.layout.size) == 4:
                # ignore the tensor if only 1 dimension is non-trivial
                if len([x for x in buf.layout.size if x == 1]) == 3:
                    continue
                stride_order = ir.get_stride_order(buf.layout.stride)
                if uniform_stride_order is None:
                    uniform_stride_order = stride_order
                elif uniform_stride_order != stride_order:
                    msg = yellow_text(
                        f"Expected stride order {uniform_stride_order}, but found stride order"
                        + f" {stride_order} for kernel {kernel_name}"
                    )
                    log.warning(msg)

                    stride_order_list = [
                        ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
                        if V.graph.get_buffer(name)
                        else None
                        for name in call_args
                    ]
                    size_list = [
                        V.graph.get_buffer(name).layout.size
                        if V.graph.get_buffer(name)
                        else None
                        for name in call_args
                    ]
                    source_list = [
                        "GraphInput"
                        if name in V.graph.graph_inputs
                        else "IntermediateBuffer"
                        if name in V.graph.name_to_buffer
                        else None
                        for name in call_args
                    ]

                    msg = yellow_text(
                        f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}"
                        + f"\n sizes {size_list}\n sources {source_list}\n"
                    )
                    log.warning(msg)
                    return
        msg = green_text(
            f"All the inputs for the triton kernel {kernel_name} have uniform layout"
        )
        log.warning(msg)

    def welford_reduce_fallback(self, dtype, value):
        sum_ = ops.reduction(dtype, dtype, "sum", value)
        self.inside_reduction = False
        rnumel = ops.index_expr(self.numels[-1], dtype)
        mean = ops.truediv(sum_, rnumel)

        self.inside_reduction = True
        dx = ops.sub(value, mean)
        dx2 = ops.mul(dx, dx)
        m2 = ops.reduction(dtype, dtype, "sum", dx2)
        return OpsWrapper._unwrap((mean, m2, rnumel))
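
    # This is the textbook two-pass formulation (illustrative):
    #     mean = sum(x) / n
    #     m2   = sum((x - mean) ** 2)
    # so (mean, m2, n) matches what a fused Welford reduction would produce,
    # at the cost of an extra pass over the data.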

    def codegen_kernel(self):
        raise NotImplementedError

    def codegen_body(self):
        pass

    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
        raise NotImplementedError


class SIMDScheduling(BaseScheduling):
    kernel_type = SIMDKernel  # override in subclass
    int32_type = "torch.int32"
    int64_type = "torch.int64"

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def group_fn(self, sizes):
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    def can_fuse(self, node1, node2):
        """
        Hook called by Scheduler to determine if the Triton backend
        can fuse node1 and node2. These nodes might already be
        FusedSchedulerNodes.
        """
        if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance(
            node2, scheduler.ForeachKernelSchedulerNode
        ):
            return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2)

        _, (numel1, rnumel1) = node1.group
        _, (numel2, rnumel2) = node2.group
        why = WhyNoFuse(node1, node2)

        if node1.is_split_scan() and not node2.is_split_scan():
            if node2.is_reduction():
                why("Split scan cannot fuse with reductions")
        elif node2.is_split_scan() and not node1.is_split_scan():
            if node1.is_reduction():
                why("Split scan cannot fuse with reductions")

        if node1.is_reduction() and node2.is_reduction():
            reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2
            if not reduction_can_fuse:
                why(
                    "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)",
                    numel1,
                    numel2,
                    rnumel1,
                    rnumel2,
                )
            return reduction_can_fuse

        if not node1.is_reduction() and not node2.is_reduction():
            if not (numel1 == numel2 and rnumel1 == rnumel2):
                why(
                    "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)",
                    numel1,
                    numel2,
                    rnumel1,
                    rnumel2,
                )
                return False

            if node1.is_template():
                # Only allow fusion for TritonTemplates for now.
                # Fusion for CUDATemplates is not supported.
                is_triton_template = isinstance(node1.node, TritonTemplateBuffer)
                if not is_triton_template:
                    why("node1 is not TritonTemplateBuffer")
                return is_triton_template

            # check for a bad combined tiling
            tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
            tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
            tiling3 = self.select_tiling(
                node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
            )
            if config.triton.tiling_prevents_pointwise_fusion:
                cond = True
                if len(tiling1) > 2:
                    if len(tiling2) > 2:
                        cond = tiling1 == tiling2 == tiling3
                    else:
                        cond = tiling1 == tiling3
                elif len(tiling2) > 2:
                    cond = tiling2 == tiling3
                if not cond:
                    why(
                        "tiling mismatch (%s, %s, %s)",
                        tiling1,
                        tiling2,
                        tiling3,
                    )
                    return False

            return True

        if not node1.is_reduction() and node2.is_reduction():
            assert rnumel1 == 1 and rnumel2 != 1
            if numel1 == numel2 * rnumel2:
                if not all(
                    SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges())
                    for n in node1.get_nodes()
                ):
                    why("nodes numel/rnumel incompatibility")
                    return False
                if (
                    config.triton.tiling_prevents_reduction_fusion
                    and not node1.is_template()
                ):
                    is_reduction_tiling_valid = self.select_tiling(
                        node1.get_nodes(), numel1
                    ) in (
                        (numel1, 1),
                        (numel2, rnumel2, 1),
                    )
                    if not is_reduction_tiling_valid:
                        why("invalid tiling for reduction")
                    return is_reduction_tiling_valid
                return True

            if numel1 != numel2:
                why("nodes numel incompatibility")
            return numel1 == numel2

        assert node1.is_reduction() and not node2.is_reduction()
        # swap args to hit the case above
        return self.can_fuse_horizontal(node2, node1)

    can_fuse_vertical = can_fuse
    can_fuse_horizontal = can_fuse

    def generate_node_schedule(self, nodes, numel, rnumel):
        node_schedule: List[Any] = []
        current_loop_writes: Set[str] = set()

        # Writes with a reduced shape, meaning they are only present once the
        # reduction loop has ended
        current_loop_reduced_writes = set()
        current_loop_has_writes = False
        done = set()

        def fits_in_main_body(n):
            _, (node_numel, node_rnumel) = n.group
            return (node_numel == numel and node_rnumel == rnumel) or (
                node_numel == numel * rnumel and node_rnumel == 1
            )

        def fits_outside_reduction(n):
            _, (node_numel, node_rnumel) = n.group
            return node_numel == numel and node_rnumel == 1 and rnumel != 1

        def schedule_node_in_loop(n):
            nonlocal current_loop_has_writes
            done.add(n)
            node_schedule.append(n)
            current_loop_has_writes = True
            # A scan is modelled as a reduction in the scheduler but has a
            # full sized output that can be used inside the loop body
            if (
                n.is_reduction()
                and isinstance(n, scheduler.SchedulerNode)
                and isinstance(n.node, ir.ComputedBuffer)
                and not isinstance(n.node.data, ir.Scan)
            ):
                current_loop_reduced_writes.add(n.get_name())

        @contextlib.contextmanager
        def end_current_reduction_loop():
            nonlocal current_loop_has_writes
            if current_loop_has_writes:
                # flush out any other runnable nodes to reduce number of loops
                for other_node in nodes[index + 1 :]:
                    if (
                        other_node not in done
                        and fits_in_main_body(other_node)
                        and not (current_loop_reduced_writes & other_node.ancestors)
                    ):
                        schedule_node_in_loop(other_node)

            if node_schedule and node_schedule[-1] is EnableReduction:
                node_schedule.pop()
            else:
                node_schedule.append(DisableReduction)
            yield
            node_schedule.append(EnableReduction)
            current_loop_reduced_writes.clear()
            current_loop_has_writes = False

        for index, node in enumerate(nodes):
            if node in done:
                continue
            done.add(node)

            def requires_closing_previous_reduction(node, node_schedule):
                if rnumel == 1:
                    return False
                if not current_loop_reduced_writes & node.ancestors:
                    return False
                assert node_schedule and not isinstance(
                    node_schedule[-1], (EnableReduction, DisableReduction)
                )
                return bool(current_loop_reduced_writes)

            if fits_in_main_body(node):
                if requires_closing_previous_reduction(node, node_schedule):
                    with end_current_reduction_loop():
                        pass  # need to start a new reduction loop

                schedule_node_in_loop(node)
            elif fits_outside_reduction(node):
                with end_current_reduction_loop():
                    node_schedule.append(node)
            else:
                raise NotImplementedError(
                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
                )

        return node_schedule
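
    # The resulting schedule is a flat list interleaved with sentinels, e.g.
    # (illustrative, with made-up node names):
    #     [pw1, reduction1, DisableReduction, pw2, EnableReduction]
    # where nodes between DisableReduction and EnableReduction run outside
    # the reduction loop.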

    def codegen_node(
        self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode]
    ):
        """
        Given a set of pre-fused nodes, generate a Triton kernel.
        """
        nodes: List[scheduler.SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group

        node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
        buf_accesses = collections.defaultdict(list)
        for node in nodes:
            for access in node.read_writes.reads | node.read_writes.writes:
                buf_accesses[access.name].append(access)

        schedule_log.debug("Schedule:\n %s", node_schedule)
        return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)

    @staticmethod
    def reduction_hint(node):
        assert node.is_reduction()
        if all(
            dep.is_contiguous()
            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
        ):
            return ReductionHint.INNER
        else:
            return node.node.data.reduction_hint

    @staticmethod
    def can_use_32bit_indexing(
        numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]]
    ) -> bool:
        int_max = torch.iinfo(torch.int32).max
        size_hint = V.graph.sizevars.size_hint
        has_hint = V.graph.sizevars.shape_env.has_hint

        def within_32bit(e):
            # Allow for unhinted e as long as we can still statically prove
            # (e.g., via ValueRanges) that it is still in bounds
            if V.graph.sizevars.is_expr_static_and_true(e <= int_max):
                return True
            # Otherwise, the hint MUST exist and be in range
            return has_hint(e) and size_hint(e) <= int_max

        if not within_32bit(numel):
            return False

        # Any use of a MultiOutputLayout will create a buffer with a
        # Layout whose sizes are accounted for
        buf_sizes = [
            buf.get_layout().storage_size()
            for buf in buffers
            if not isinstance(buf.get_layout(), ir.MultiOutputLayout)
        ]
        if not all(within_32bit(size) for size in buf_sizes):
            return False

        # Only install guards for 32-bit indexing as there is no correctness
        # issue with using 64-bit for everything
        V.graph.sizevars.guard_leq(numel, int_max)  # type: ignore[arg-type]
        for size in buf_sizes:
            V.graph.sizevars.guard_leq(size, int_max)  # type: ignore[arg-type]
        return True
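
    # For example (illustrative): a kernel over 2**31 elements exceeds
    # torch.iinfo(torch.int32).max == 2**31 - 1 and falls back to int64
    # indexing, while a kernel over 2**30 elements (with all buffer storage
    # sizes also in range) gets int32 indexing plus the dynamic-shape guards
    # installed above.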
  1045. @classmethod
  1046. def select_index_dtype(cls, node_schedule, numel, reduction_numel):
  1047. # Gather all used buffer names
  1048. buffer_names = set()
  1049. for node in node_schedule:
  1050. if not isinstance(node, scheduler.BaseSchedulerNode):
  1051. continue
  1052. buffer_names.update(node.get_names())
  1053. buffer_names.update(node.used_buffer_names())
  1054. # Get buffers objects
  1055. def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
  1056. buf = V.graph.get_buffer(name)
  1057. if buf is None:
  1058. raise RuntimeError(f"Failed to find buffer matching name {name}")
  1059. return buf
  1060. buffers = [V.graph.get_buffer(name) for name in buffer_names]
  1061. # In theory we can separately check xnumel and rnumel are <= int_max
  1062. # but some indexers do use the full linear index so we need to be
  1063. # conservative here.
  1064. total_numel = numel * reduction_numel
  1065. if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers):
  1066. return cls.int32_type
  1067. return cls.int64_type
  1068. def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel):
  1069. pointwise_nodes = list(
  1070. filter(
  1071. lambda n: n not in (EnableReduction, DisableReduction)
  1072. and not n.is_reduction()
  1073. and n.group[1][0] == numel * rnumel,
  1074. node_schedule,
  1075. )
  1076. )
  1077. for node in pointwise_nodes:
  1078. # An index can be an integer when loading a random seed.
  1079. if not all(
  1080. not isinstance(dep, MemoryDep)
  1081. or dep.is_contiguous()
  1082. or isinstance(dep.index, (sympy.Integer, int))
  1083. or dep.stride1_for_last_dim()
  1084. for dep in itertools.chain(
  1085. node.read_writes.reads, node.read_writes.writes
  1086. )
  1087. ):
  1088. return True
  1089. return False
  1090. def get_kernel_args(self, node_schedule, numel, reduction_numel):
  1091. reductions = list(
  1092. filter(
  1093. lambda n: n not in (EnableReduction, DisableReduction)
  1094. and n.is_reduction(),
  1095. node_schedule,
  1096. )
  1097. )
  1098. if len(reductions) > 0:
  1099. hints = [self.reduction_hint(n) for n in reductions]
  1100. if hints.count(hints[0]) == len(hints):
  1101. reduction_hint_val = hints[0]
  1102. else:
  1103. reduction_hint_val = ReductionHint.DEFAULT
  1104. if (
  1105. reduction_hint_val == ReductionHint.INNER
  1106. and self.has_non_contiguous_pw_in_reduction_kernel(
  1107. node_schedule, numel, reduction_numel
  1108. )
  1109. ):
  1110. reduction_hint_val = ReductionHint.DEFAULT
  1111. else:
  1112. reduction_hint_val = ReductionHint.DEFAULT
  1113. mutations = set()
  1114. for node in node_schedule:
  1115. if hasattr(node, "get_mutations"):
  1116. mutations.update(node.get_mutations())
  1117. index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel)
  1118. return reduction_hint_val, mutations, index_dtype
  1119. def codegen_node_schedule(
  1120. self, node_schedule, buf_accesses, numel, reduction_numel
  1121. ):
  1122. from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel
  1123. tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
  1124. (
  1125. reduction_hint_val,
  1126. mutations,
  1127. index_dtype,
  1128. ) = self.get_kernel_args(node_schedule, numel, reduction_numel)
  1129. is_split_scan = any(
  1130. isinstance(node, BaseSchedulerNode) and node.is_split_scan()
  1131. for node in node_schedule
  1132. )
  1133. kernel_type = TritonSplitScanKernel if is_split_scan else self.kernel_type
  1134. kernel_args = tiled_groups
  1135. kernel_kwargs = {
  1136. "reduction_hint": reduction_hint_val,
  1137. "mutations": mutations,
  1138. "index_dtype": index_dtype,
  1139. }
  1140. kernel = kernel_type(
  1141. *kernel_args,
  1142. **kernel_kwargs,
  1143. )
  1144. kernel.buf_accesses = buf_accesses
  1145. self.codegen_node_schedule_with_kernel(node_schedule, kernel)
  1146. with V.set_kernel_handler(kernel):
  1147. src_code = kernel.codegen_kernel()
  1148. kernel_name = self.define_kernel(src_code, node_schedule, kernel)
  1149. log.debug("Generating kernel code with kernel_name: %s", kernel_name)
  1150. kernel.kernel_name = kernel_name
  1151. kernel.code_hash = code_hash(src_code)
  1152. if kernel.persistent_reduction and config.triton.multi_kernel:
  1153. kernel2 = self.kernel_type(
  1154. *kernel_args,
  1155. **kernel_kwargs,
  1156. disable_persistent_reduction=True,
  1157. )
  1158. self.codegen_node_schedule_with_kernel(node_schedule, kernel2)
  1159. with V.set_kernel_handler(kernel2):
  1160. src_code2 = kernel2.codegen_kernel()
  1161. kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel)
  1162. kernel2.kernel_name = kernel_name2
  1163. kernel2.code_hash = code_hash(src_code2)
  1164. final_kernel = MultiKernel([kernel, kernel2])
  1165. else:
  1166. final_kernel = kernel # type: ignore[assignment]
  1167. with V.set_kernel_handler(final_kernel):
  1168. for node in node_schedule:
  1169. if node not in (EnableReduction, DisableReduction):
  1170. node.mark_run()
  1171. self.codegen_comment(node_schedule)
  1172. final_kernel.call_kernel(final_kernel.kernel_name)
  1173. if config.nan_asserts:
  1174. final_kernel.codegen_nan_check()
  1175. if config.warn_mix_layout:
  1176. final_kernel.warn_mix_layout(kernel_name)
  1177. V.graph.removed_buffers |= final_kernel.removed_buffers
  1178. V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove
  1179. if (
  1180. V.graph.wrapper_code.supports_intermediate_hooks
  1181. and config.generate_intermediate_hooks
  1182. ):
  1183. # Not every node in the schedule will actually be live on output;
  1184. # we can't check dead buffers.
  1185. live_outs = kernel.args.live_output_buffers()
  1186. for node in node_schedule:
  1187. if not isinstance(node, scheduler.BaseSchedulerNode):
  1188. continue
  1189. name = node.get_name()
  1190. if name not in live_outs:
  1191. continue
  1192. assert node.node is not None
  1193. origin_node = node.node.get_origin_node()
  1194. if origin_node is not None:
  1195. counters["inductor"]["intermediate_hooks"] += 1
  1196. V.graph.wrapper_code.writeline(
  1197. f"run_intermediate_hooks({origin_node.name!r}, {name})"
  1198. )
  1199. self.scheduler.free_buffers()
  1200. def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
  1201. def current_reduction_nodes(nodes):
  1202. return itertools.takewhile(lambda n: n is not DisableReduction, nodes)
  1203. with kernel:
  1204. stack = contextlib.ExitStack()
  1205. kernel.set_last_usage(current_reduction_nodes(node_schedule))
  1206. for node in node_schedule:
  1207. if node not in (EnableReduction, DisableReduction):
  1208. node.decide_inplace_update()
  1209. for i, node in enumerate(node_schedule):
  1210. if node is DisableReduction:
  1211. stack.enter_context(kernel.disable_reduction())
  1212. elif node is EnableReduction:
  1213. stack.close()
  1214. kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
  1215. else:
  1216. # TODO - use split ranges ?
  1217. indexing_dtype_strength_reduction(node._body)
  1218. index_vars = kernel.split_and_set_ranges(node.get_ranges())
  1219. node.codegen(index_vars)
  1220. def codegen_template(
  1221. self, template_node, epilogue_nodes, only_gen_src_code=False
  1222. ) -> Optional[str]:
  1223. """
  1224. Codegen a triton template
  1225. If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper
  1226. """
  1227. _, (numel, rnumel) = template_node.group
  1228. assert rnumel == 1
  1229. kernel, render = template_node.node.make_kernel_render(template_node.node)
  1230. with kernel:
  1231. if not only_gen_src_code:
  1232. for node in [template_node, *epilogue_nodes]:
  1233. node.mark_run()
  1234. partial_code = render()
  1235. with kernel.set_subgraph_body("<STORE_OUTPUT>"):
  1236. for node in epilogue_nodes:
  1237. node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
  1238. if not isinstance(partial_code, str):
  1239. partial_code.finalize_hook("<DEF_KERNEL>")
  1240. # finalize must be called after adding epilogue above
  1241. with V.set_kernel_handler(kernel):
  1242. # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
  1243. with kernel.set_subgraph_body("<STORE_OUTPUT>"):
  1244. if isinstance(partial_code, str):
  1245. src_code = partial_code
  1246. else:
  1247. partial_code.finalize_hook("<STORE_OUTPUT>")
  1248. src_code = partial_code.code
  1249. node_schedule = [template_node, *epilogue_nodes]
  1250. if config.benchmark_kernel:
  1251. num_gb = kernel.estimate_kernel_num_bytes() / 1e9
  1252. grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
  1253. assert kernel.meta is not None, "meta is None"
  1254. grid = kernel.grid_fn(*grid_args, kernel.meta)
  1255. src_code = (
  1256. f"{kernel.imports_for_benchmark_kernel()}\n"
  1257. f"{src_code}\n"
  1258. f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
  1259. )
  1260. if only_gen_src_code:
  1261. return src_code
  1262. kernel_name = self.define_kernel(src_code, node_schedule, kernel)
  1263. self.codegen_comment(node_schedule)
  1264. kernel.call_kernel(kernel_name, template_node.node)
  1265. V.graph.removed_buffers |= kernel.removed_buffers
  1266. V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
  1267. self.scheduler.free_buffers()
  1268. return None
  1269. def codegen_sync(self):
  1270. V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())

    def codegen_foreach(self, foreach_node):
        from .triton_foreach import ForeachKernel

        for partitions_with_metadata in ForeachKernel.horizontal_partition(
            foreach_node.get_subkernel_nodes(), self
        ):
            kernel = ForeachKernel()
            for nodes, tiled_groups, numel, rnumel in partitions_with_metadata:
                node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
                (
                    reduction_hint_val,
                    mutations,
                    index_dtype,
                ) = self.get_kernel_args(node_schedule, numel, rnumel)

                subkernel = kernel.create_sub_kernel(
                    *tiled_groups,
                    reduction_hint=reduction_hint_val,
                    mutations=mutations,
                    index_dtype=index_dtype,
                )

                self.codegen_node_schedule_with_kernel(
                    node_schedule,
                    subkernel,
                )

                with V.set_kernel_handler(subkernel):
                    for node in node_schedule:
                        if node not in (EnableReduction, DisableReduction):
                            node.mark_run()
                V.graph.removed_buffers |= subkernel.removed_buffers
                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove

            src_code = kernel.codegen_kernel()
            kernel_name = self.define_kernel(src_code, [foreach_node], kernel)
            self.codegen_comment([foreach_node])
            kernel.call_kernel(V.graph.wrapper_code, kernel_name)

        self.scheduler.free_buffers()
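
    # Each partition from ForeachKernel.horizontal_partition becomes one launched
    # kernel: every (nodes, tiled_groups, numel, rnumel) tuple in the partition is
    # lowered into its own sub-kernel, and the combined source is defined and
    # called once per partition above.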

    @staticmethod
    @functools.lru_cache(32)
    def candidate_tilings(node):
        ranges, reduction_ranges = node.get_ranges()
        if len(ranges) <= 1:
            return ()

        rw = node.pointwise_read_writes()
        assert len(rw.range_vars) == len(ranges)

        # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads
        # that need to access the entire tensor; they don't contribute read indexing
        # information (and practically, they don't have dep.index, so they can't be used
        # for stride_hints below).
        dep_sources = [rw.reads, rw.writes]
        assert all(
            isinstance(dep, (MemoryDep, StarDep))
            for dep in itertools.chain.from_iterable(dep_sources)
        )
        deps = [
            dep
            for dep in itertools.chain.from_iterable(dep_sources)
            if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep)
        ]
        write_names = {dep.name for dep in rw.writes}

        tilings: List[CandidateTiling] = []

        for dep in deps:
            strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
            assert len(strides) == len(ranges)
            try:
                split = strides.index(1) + 1
                if split == len(ranges):
                    continue
                if all(s == 0 for s in strides[split:]):
                    # if this is a broadcasted tensor and all dimensions after split
                    # are broadcast, this is not a real split
                    continue
            except ValueError:
                continue
            tiled_groups = (
                V.graph.sizevars.simplify(sympy_product(ranges[:split])),
                V.graph.sizevars.simplify(sympy_product(ranges[split:])),
            )
            # score by number of elements
            score = V.graph.sizevars.size_hint(
                sympy_product(
                    size for size, stride in zip(ranges, strides) if stride != 0
                )
            )
            if dep.name in write_names:
                # ngimel said contiguous writes are more important than reads
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[0]):
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[1]):
                score *= 2

            if (
                V.graph.sizevars.size_hint(
                    score - sympy_product(itertools.chain(ranges, reduction_ranges))
                )
                >= 0
            ):
                tilings.append(CandidateTiling(tiled_groups, score, dep.name))
        return tilings
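
    # Worked example: for ranges (64, 128) accessed with stride hints (1, 64)
    # (a transposed read), strides.index(1) + 1 == 1, giving the candidate
    # tiling (64, 128). A contiguous access with stride hints (128, 1) yields
    # split == len(ranges) and is skipped, since no split is needed there.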

    @classmethod
    def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
        """
        Heuristics to decide how to tile kernels.
        Currently, we tile based on stride-1 dimensions.

        Returns:
            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
        """
        if reduction_numel != 1 or config.triton.max_tiles <= 1:
            # TODO(jansel): should we tile reductions?
            # do perf hint here if stride-1 dim is not being reduced
            if perf_hint_log.level <= logging.WARNING:
                for node in EnableReduction.filter(node_schedule):
                    if len(cls.candidate_tilings(node)) > 0:
                        perf_hint_log.info("reduction over non-contiguous dims")
                        break
            return (numel, reduction_numel)

        seen_names = set()
        candidate_tiles: Counter[Any] = collections.Counter()
        for node in EnableReduction.filter(node_schedule):
            for tiling in cls.candidate_tilings(node):
                if tiling.name in seen_names:
                    continue
                seen_names.add(tiling.name)
                candidate_tiles[tiling.tiling] += tiling.score

        ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]

        if config.triton.max_tiles >= 3:
            # Consider adding a third dimension of tiling, but only
            # when a1 is a multiple of b1; otherwise, you have a lot
            # of stragglers, which is annoying to generate code for.
            #
            # NB: More than three max tiles is not enabled by default.

            # Add one 3D tiling choice
            for i in range(1, len(ranked_tilings)):
                a0, a1 = ranked_tilings[0]
                b0, b1 = ranked_tilings[i]
                if V.graph.sizevars.size_hint(a1 - b1) == 0:
                    continue
                if V.graph.sizevars.size_hint(a1 - b1) < 0:
                    # swap so a0 is bigger
                    a0, a1 = ranked_tilings[i]
                    b0, b1 = ranked_tilings[0]
                assert V.graph.sizevars.size_hint(a1 - b1) > 0
                if V.graph.sizevars.statically_known_multiple_of(a1, b1):
                    tiling = (a0, FloorDiv(a1, b1), b1)
                    ranked_tilings = [tiling] + ranked_tilings
                    break  # only 1 choice for now

        if len(ranked_tilings) > 1:
            perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)

        for tiled_groups in ranked_tilings:
            new_groups = (*tiled_groups, reduction_numel)
            if all(
                SIMDKernel.is_compatible(new_groups, node.get_ranges())
                for node in node_schedule
                if isinstance(node, scheduler.SchedulerNode)
            ):
                return new_groups

        return (numel, reduction_numel)
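
    # Worked example for the 3D branch: given ranked 2D tilings a = (8, 512) and
    # b = (64, 64), a1 = 512 is a statically known multiple of b1 = 64, so the
    # candidate (a0, a1 // b1, b1) = (8, 8, 64) is tried first; both original
    # tilings factor through it (8 * 8 == 64 and 8 * 64 == 512).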

    def flush(self):
        pass

    def ready_to_flush(self) -> bool:
        return False

    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        @dataclasses.dataclass
        class LastUsageHolder:
            n: Any
            last_usage: Any

            def __del__(self):
                self.n.last_usage = self.last_usage

        last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes]

        # empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
        for n in nodes:
            n.last_usage = set()

        if not nodes[0].is_template():
            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
            reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
                node_schedule, numel, rnumel
            )

            kernel = self.kernel_type(
                *tiled_groups,
                reduction_hint=reduction_hint_val,
                mutations=mutations,
                index_dtype=index_dtype,
            )

            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
            with config.patch(
                "benchmark_kernel", benchmark_kernel
            ), V.set_kernel_handler(kernel):
                src_code = kernel.codegen_kernel()
        else:
            template_node = nodes[0]
            epilogue_nodes = nodes[1:]

            with config.patch("benchmark_kernel", benchmark_kernel):
                src_code = self.codegen_template(
                    template_node, epilogue_nodes, only_gen_src_code=True
                )

        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
        return src_code
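
    # Hypothetical usage sketch (the `scheduling` name is assumed, not part of
    # this file): given a SIMD scheduling instance and a fused group of nodes,
    #   src_code = scheduling.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True)
    # returns standalone kernel source with the kernel named "triton_", suitable
    # for benchmarking a fusion choice before committing to it.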

    def codegen_comment(self, node_schedule):
        pass

    def define_kernel(self, src_code, node_schedule, kernel):
        raise NotImplementedError


@dataclasses.dataclass
class CandidateTiling:
    tiling: Tuple[sympy.Expr, sympy.Expr]
    score: int  # higher is better
    name: Optional[str] = None

    @staticmethod
    def is_good_size(s):
        """Somewhat arbitrary heuristic used to boost scores for some sizes"""
        s = V.graph.sizevars.size_hint(s)
        return s >= 32 and (s % 32 == 0)
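

# Example scores from CandidateTiling.is_good_size: 64 -> True
# (64 >= 32 and 64 % 32 == 0); 48 -> False (48 % 32 == 16); 16 -> False (< 32).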


class DisableReduction:
    """
    Marker to invoke `kernel.disable_reduction()`. This closes a
    reduction loop and allows for pointwise ops to occur on the output
    of a reduction.
    """


class EnableReduction:
    """
    Marker to end a DisableReduction block.
    """

    @staticmethod
    def filter(node_schedule):
        """
        Get the nodes from node_schedule skipping those in a
        DisableReduction block.
        """
        disabled = False
        for node in node_schedule:
            if node in (EnableReduction, DisableReduction):
                # Don't tile stuff outside the main reduction loop
                disabled = node is DisableReduction
            elif disabled:
                pass
            else:
                yield node
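
    # For example, filtering [a, DisableReduction, b, EnableReduction, c]
    # yields only a and c; b sits inside the disabled block and is skipped,
    # and the marker classes themselves are never yielded.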


class CantSplit(Exception):
    pass