utils.py 56 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import collections
  4. import contextlib
  5. import dataclasses
  6. import enum
  7. import functools
  8. import inspect
  9. import io
  10. import itertools
  11. import json
  12. import logging
  13. import math
  14. import operator
  15. import os
  16. import platform
  17. import shutil
  18. import sys
  19. import tempfile
  20. import textwrap
  21. import time
  22. import unittest
  23. from datetime import datetime
  24. from io import StringIO
  25. from pathlib import Path
  26. from typing import (
  27. Any,
  28. Callable,
  29. Dict,
  30. Generic,
  31. Iterable,
  32. List,
  33. NamedTuple,
  34. Optional,
  35. Protocol,
  36. Set,
  37. Tuple,
  38. TypeVar,
  39. Union,
  40. ValuesView,
  41. )
  42. from typing_extensions import Concatenate, ParamSpec
  43. from unittest import mock
  44. import sympy
  45. import torch
  46. import torch._export
  47. import torch.utils._pytree as pytree
  48. from torch._dynamo.device_interface import get_interface_for_device
  49. from torch._dynamo.utils import detect_fake_mode
  50. from torch.autograd import DeviceType
  51. from torch.autograd.profiler_util import EventList
  52. from torch.fx.passes.shape_prop import ShapeProp
  53. from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
  54. from torch.utils._sympy.symbol import make_symbol, SymT
  55. from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
  56. from . import config
  57. from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv
  58. log = logging.getLogger(__name__)
  59. _T = TypeVar("_T")
  60. VarRanges = Dict[sympy.Expr, sympy.Expr]
  61. GPU_ALIGN_BYTES = 16
  62. ALIGN_BYTES = 64
  63. assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"
  64. def _align(nbytes):
  65. """Round up to the nearest multiple of ALIGN_BYTES"""
  66. return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
  67. def _is_aligned(v: sympy.Expr):
  68. """v can be statically proven to be a multiple of ALIGN_BYTES"""
  69. if isinstance(v, (sympy.Add, sympy.Max)):
  70. return all(map(_is_aligned, v.args))
  71. return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
  72. class align(sympy.Function):
  73. """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
  74. nargs = (1,)
  75. is_integer = True
  76. @classmethod
  77. def eval(cls, value):
  78. if isinstance(value, (int, sympy.Integer)):
  79. return _align(int(value))
  80. if _is_aligned(value):
  81. return value
  82. def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
  83. """
  84. Returns benchmark results by examining torch profiler events.
  85. This could be more accurate as it doesn't count CPU side overhead.
  86. However, this also requires manually excluding irrelevant event, e.g.
  87. vectorized_elementwise_kernel which is used to fill L2 cache,
  88. various CUDA events, etc, so could also be fragile.
  89. """
  90. fn()
  91. torch.cuda.synchronize()
  92. cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
  93. # Estimate the runtime of the function
  94. start_event = torch.cuda.Event(enable_timing=True)
  95. end_event = torch.cuda.Event(enable_timing=True)
  96. start_event.record()
  97. for _ in range(5):
  98. cache.zero_()
  99. fn()
  100. end_event.record()
  101. torch.cuda.synchronize()
  102. estimate_ms = start_event.elapsed_time(end_event) / 5
  103. # compute number of warmup and repeat
  104. n_warmup = max(1, int(warmup / estimate_ms))
  105. n_repeat = max(1, int(rep / estimate_ms))
  106. # Warm-up
  107. for _ in range(n_warmup):
  108. fn()
  109. with torch.profiler.profile(
  110. activities=[
  111. torch.profiler.ProfilerActivity.CUDA,
  112. ]
  113. ) as p:
  114. # Benchmark
  115. for i in range(n_repeat):
  116. # we clear the L2 cache before each run
  117. cache.zero_()
  118. # record time of `fn`
  119. fn()
  120. # Record clocks
  121. torch.cuda.synchronize()
  122. log.debug("raw events")
  123. log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
  124. filtered_events = EventList(
  125. [
  126. event
  127. for event in p.events()
  128. if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
  129. ]
  130. )
  131. if len(filtered_events) % n_repeat != 0:
  132. raise RuntimeError(
  133. "Failed to divide all profiling events into #repeat groups. "
  134. "#CUDA events: %d, #repeats: %s",
  135. len(filtered_events),
  136. n_repeat,
  137. )
  138. num_event_per_group = len(filtered_events) / n_repeat
  139. actual_events = EventList(
  140. [
  141. event
  142. for i, event in enumerate(filtered_events)
  143. if i % num_event_per_group != 0
  144. ]
  145. )
  146. actual_events._build_tree()
  147. actual_events = actual_events.key_averages()
  148. log.debug("profiling time breakdown")
  149. log.debug(actual_events.table(row_limit=-1))
  150. res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
  151. log.debug("profiling results: %s ms", res)
  152. return res
  153. @functools.lru_cache(None)
  154. def has_torchvision_roi_align() -> bool:
  155. try:
  156. from torchvision.ops import roi_align # noqa: F401
  157. torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
  158. return roi_align is not None and hasattr(
  159. getattr(torch.ops, "torchvision", None), "roi_align"
  160. )
  161. except ImportError:
  162. return False
  163. except RuntimeError as e:
  164. assert "torchvision::nms does not exist" in str(e)
  165. return False
  166. def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
  167. if device is None:
  168. return torch.tensor(0.0).device # default device
  169. if isinstance(device, str):
  170. device = torch.device(device)
  171. if device.type not in ("cpu", "meta") and device.index is None:
  172. device_interface = get_interface_for_device(device.type)
  173. return torch.device(device.type, index=device_interface.Worker.current_device())
  174. return device
  175. def sympy_product(it):
  176. return functools.reduce(operator.mul, it, sympy.Integer(1))
  177. def sympy_dot(seq1, seq2):
  178. assert len(seq1) == len(seq2)
  179. return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
  180. def unique(it: Iterable[_T]) -> ValuesView[_T]:
  181. return {id(x): x for x in it}.values()
  182. def ceildiv(
  183. numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
  184. ) -> Union[int, sympy.Expr]:
  185. if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
  186. return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
  187. # TODO: There is a bug in a call to this function, to repro:
  188. # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
  189. # --amp --only YituTechConvBert --dynamic-shapes
  190. assert isinstance(numer, int) and isinstance(
  191. denom, int
  192. ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
  193. return runtime_ceildiv(numer, denom)
  194. def _type_of(key):
  195. # Use the function here to get rid of dependencies on the Triton during the codegen.
  196. # Refer to Triton implementation here:
  197. # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
  198. # `None` is nullptr. Implicitly convert to *i8.
  199. if key is None:
  200. return "*i8"
  201. dtype_str = str(key).split(".")[-1]
  202. tys = {
  203. "bool": "i1",
  204. "float8e4nv": "fp8e4nv",
  205. "float8e5": "fp8e5",
  206. "float8e4b15": "fp8e4b15",
  207. "float8e4b15x4": "fp8e4b15x4",
  208. "float8_e4m3fn": "fp8e4nv",
  209. "float8_e5m2": "fp8e5",
  210. "float16": "fp16",
  211. "bfloat16": "bf16",
  212. "float32": "fp32",
  213. "float64": "fp64",
  214. "int8": "i8",
  215. "int16": "i16",
  216. "int32": "i32",
  217. "int64": "i64",
  218. "uint8": "u8",
  219. "uint16": "u16",
  220. "uint32": "u32",
  221. "uint64": "u64",
  222. }
  223. # reinterpret can create triton type
  224. for v in list(tys.values()):
  225. tys[v] = v
  226. return key if isinstance(key, str) else f"*{tys[dtype_str]}"
  227. def convert_shape_to_inductor(
  228. lst: Iterable[Union[int, torch.SymInt]]
  229. ) -> List[sympy.Expr]:
  230. """
  231. Gets the shape and stride of a tensor. For non-symbolic tensors, this is
  232. trivial. But for symbolic tensors, we need to map from SymIntNode into
  233. sympy.Expr.
  234. """
  235. return [
  236. i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
  237. ]
  238. def convert_shape_to_symint(
  239. lst: Iterable[Union[int, sympy.Expr]]
  240. ) -> List[Union[int, torch.SymInt]]:
  241. """
  242. Takes a list of shapes from Inductor and converts them into symints (or just
  243. ints if all shapes are static).
  244. """
  245. from .virtualized import V
  246. return [
  247. i
  248. if isinstance(i, int)
  249. else int(i)
  250. if isinstance(i, sympy.Integer)
  251. else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
  252. for i in lst
  253. ]
  254. def is_view(op: torch._ops.OpOverload):
  255. """
  256. Does this op overload have aliasing
  257. """
  258. assert isinstance(op, torch._ops.OpOverload)
  259. return any(a.alias_info is not None for a in op._schema.arguments)
  260. def is_pointwise_use(use):
  261. if not use.op == "call_function":
  262. return False
  263. if not (
  264. isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
  265. ):
  266. return False
  267. if use.target is operator.getitem or is_view(use.target):
  268. return all(is_pointwise_use(u) for u in use.users)
  269. return torch.Tag.pointwise in use.target.tags
  270. def gen_gm_and_inputs(target, args, kwargs):
  271. g = torch.fx.Graph()
  272. g_args = []
  273. a_args = []
  274. for n, arg in enumerate(args):
  275. if isinstance(arg, torch.Tensor):
  276. g_args.append(g.placeholder(f"arg{n}"))
  277. a_args.append(arg)
  278. else:
  279. g_args.append(arg)
  280. assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
  281. node = g.call_function(target, tuple(g_args), kwargs)
  282. if (
  283. len(target._schema.returns) == 1
  284. and str(target._schema.returns[0].type) == "Tensor"
  285. ):
  286. node = (node,)
  287. g.output(node)
  288. gm = torch.fx.GraphModule({}, g)
  289. return gm, a_args
  290. def synchronize(device: str = "cuda"):
  291. if device == "cpu":
  292. return
  293. device_interface = get_interface_for_device(device)
  294. if device_interface.is_available():
  295. device_interface.synchronize()
  296. def timed(
  297. model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
  298. ) -> float:
  299. synchronize(device)
  300. torch.manual_seed(1337)
  301. t0 = time.perf_counter()
  302. for _ in range(times):
  303. result = model(*example_inputs)
  304. synchronize(device)
  305. t1 = time.perf_counter()
  306. # GC the result after timing
  307. assert result is not None # type: ignore[possibly-undefined]
  308. return t1 - t0
  309. def print_performance(
  310. fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
  311. ):
  312. timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)])
  313. took = torch.median(timings) / times
  314. print(f"{took / baseline:.6f}")
  315. return took
  316. def precompute_method(obj: Any, method: str):
  317. """Replace obj.method() with a new method that returns a precomputed constant."""
  318. result = getattr(obj, method)()
  319. setattr(obj, method, lambda: result)
  320. def precompute_methods(obj: Any, methods: List[str]):
  321. """Replace methods with new methods that returns a precomputed constants."""
  322. for method in methods:
  323. precompute_method(obj, method)
  324. def cmp(a, b) -> int:
  325. return int(a > b) - int(a < b)
  326. def pad_listlike(x, size):
  327. if len(x) == 1:
  328. return type(x)([x[0]]) * size
  329. else:
  330. return x
  331. # Used to ensure that iterating over a set is deterministic
  332. def tuple_sorted(x):
  333. if len(x) == 0:
  334. return []
  335. def sort_func(elem):
  336. if isinstance(elem, str):
  337. return elem
  338. else:
  339. # We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
  340. # but we are not able to do isinstance assert because of circular dependency
  341. return elem.get_name()
  342. return sorted(x, key=sort_func)
  343. P = ParamSpec("P")
  344. RV = TypeVar("RV", covariant=True)
  345. class CachedMethod(Protocol, Generic[P, RV]):
  346. @staticmethod
  347. def clear_cache(self) -> None:
  348. ...
  349. def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
  350. ...
  351. # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
  352. def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
  353. key = f"__{fn.__name__}_cache"
  354. @functools.wraps(fn)
  355. def wrapper(self):
  356. if not hasattr(self, key):
  357. setattr(self, key, fn(self))
  358. return getattr(self, key)
  359. def clear_cache(self):
  360. if hasattr(self, key):
  361. delattr(self, key)
  362. wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
  363. return wrapper # type: ignore[return-value]
  364. def aggregate_origins(node_schedule):
  365. from . import ir
  366. if isinstance(node_schedule, list):
  367. return functools.reduce(
  368. operator.or_,
  369. [
  370. node.node.origins
  371. for node in node_schedule
  372. if hasattr(node, "node") and node.node
  373. ],
  374. set(),
  375. )
  376. elif isinstance(node_schedule, ir.ExternKernel):
  377. return node_schedule.origins
  378. else:
  379. return set()
  380. def get_fused_kernel_name(node_schedule, descriptive_names):
  381. all_origins = aggregate_origins(node_schedule)
  382. if descriptive_names == "original_aten":
  383. # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
  384. sources = [
  385. origin.meta["original_aten"]._overloadpacket.__name__
  386. for origin in all_origins
  387. if origin.op == "call_function"
  388. and "original_aten" in origin.meta
  389. and origin.meta["original_aten"] is not None
  390. ]
  391. sources = sorted(set(sources))
  392. elif descriptive_names == "torch":
  393. # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
  394. sources = []
  395. for origin in all_origins:
  396. if origin.op == "call_function" and "source_fn_stack" in origin.meta:
  397. source_fn = origin.meta["source_fn_stack"][-1]
  398. if isinstance(source_fn[1], str):
  399. sources.append(source_fn[1])
  400. else:
  401. sources.append(source_fn[1].__name__)
  402. sources = sorted(set(sources))
  403. elif descriptive_names == "inductor_node":
  404. sources = [
  405. origin.name for origin in all_origins if origin.op == "call_function"
  406. ]
  407. else:
  408. raise NotImplementedError
  409. sources = sources
  410. return "_".join(["fused"] + sources)
  411. def get_kernel_metadata(node_schedule, wrapper):
  412. all_origins = aggregate_origins(node_schedule)
  413. inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]
  414. from_node_dict = collections.defaultdict(list)
  415. original_aten_dict = collections.defaultdict(list)
  416. for node in inductor_nodes:
  417. if "original_aten" in node.meta and node.meta["original_aten"] is not None:
  418. key = str(node.meta["original_aten"]._overloadpacket)
  419. original_aten_dict[key].append(node.name)
  420. if "from_node" in node.meta:
  421. key = node.meta["from_node"][0][0]
  422. from_node_dict[key].append(node.name)
  423. metadata = (
  424. f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], "
  425. f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]"
  426. )
  427. # trace back to original node here
  428. detailed_metadata = []
  429. for original_node, nodes in sorted(from_node_dict.items()):
  430. detailed_metadata.append(
  431. f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
  432. )
  433. return metadata, "\n".join(detailed_metadata)
  434. def dominated_nodes(
  435. initial_queue: Iterable[torch.fx.Node], skip_filter=None
  436. ) -> Set[torch.fx.Node]:
  437. """Returns the set of nodes whose values depend on those within initial_queue"""
  438. initial_queue = list(initial_queue)
  439. dominated_set = set(initial_queue)
  440. while initial_queue:
  441. node = initial_queue.pop()
  442. for user in node.users:
  443. if skip_filter and skip_filter(user):
  444. continue
  445. if user not in dominated_set:
  446. dominated_set.add(user)
  447. initial_queue.append(user)
  448. return dominated_set
  449. def gather_origins(args, kwargs):
  450. import itertools
  451. from . import ir
  452. def is_unrealized_node(n):
  453. if isinstance(n, ir.TensorBox):
  454. return is_unrealized_node(n.data)
  455. if isinstance(n, ir.StorageBox):
  456. return is_unrealized_node(n.data)
  457. return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
  458. kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
  459. arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
  460. return set(itertools.chain(*arg_origins, *kwarg_origins))
  461. def sympy_str(expr: sympy.Expr) -> str:
  462. """
  463. Normal sympy str is very slow, this is a lot faster. The result are
  464. somewhat worse, as it doesn't do as much simplification. So don't
  465. use this for final codegen.
  466. """
  467. if isinstance(expr, sympy.Symbol):
  468. return expr.name
  469. if isinstance(expr, sympy.Add):
  470. return " + ".join(map(sympy_str, expr.args))
  471. if isinstance(expr, sympy.Mul):
  472. return " * ".join(map(sympy_str, expr.args))
  473. if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
  474. return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
  475. return str(expr)
  476. def get_bounds_index_expr(index):
  477. from .virtualized import V
  478. # If this expression does not come from an FX node, we compute its bounds
  479. if (
  480. config.compute_all_bounds
  481. and (fx_node := getattr(V.interpreter, "current_node", None))
  482. and fx_node.target != "index_expr"
  483. ):
  484. return bound_sympy(index)
  485. else:
  486. return ValueRanges.unknown()
  487. def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
  488. """
  489. Used to generate an integer-nonnegative symbol.
  490. """
  491. # This should never be used for creating shape/stride symbols, as those
  492. # should all be allocated before Inductor.
  493. assert prefix != SymT.SIZE
  494. # NOTE: shape symbols are positive (> 0), but index variables are only
  495. # non-negative (>= 0).
  496. return make_symbol(prefix, idx, integer=True, nonnegative=True)
  497. def generate_assert(check):
  498. return (check or config.debug_index_asserts) and config.assert_indirect_indexing
  499. def sympy_index_symbol(name: str) -> sympy.Symbol:
  500. """
  501. Used to generate an integer-nonnegative symbol.
  502. """
  503. # This should never be used for creating shape/stride symbols, as those
  504. # should all be allocated before Inductor.
  505. assert name[0] != "s"
  506. # NOTE: shape symbols are positive (> 0), but index variables are only
  507. # non-negative (>= 0).
  508. return sympy.Symbol(name, integer=True, nonnegative=True)
  509. def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
  510. """
  511. When the passed replacement symbol v is a string, it is converted to a symbol with name v that
  512. have the same replaced expression integer and nonnegative properties.
  513. """
  514. def to_symbol(replaced, replacement):
  515. assert isinstance(replaced, sympy.Expr)
  516. if isinstance(replacement, str):
  517. return sympy.Symbol(
  518. replacement,
  519. integer=replaced.is_integer, # type: ignore[attr-defined]
  520. nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
  521. )
  522. else:
  523. return replacement
  524. # xreplace is faster than subs, but is way more picky
  525. return sympy.sympify(expr).xreplace(
  526. {k: to_symbol(k, v) for k, v in replacements.items()}
  527. )
  528. def is_symbolic(a: Any) -> bool:
  529. return isinstance(a, torch.SymInt) or (
  530. isinstance(a, torch.Tensor)
  531. and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
  532. )
  533. def any_is_symbolic(*args: Any) -> bool:
  534. return any(is_symbolic(a) for a in args)
  535. def get_first_incompatible_cudagraph_node(gm):
  536. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  537. forbidden_set = {
  538. "aten._fused_moving_avg_obs_fq_helper.default",
  539. "aten._fused_moving_avg_obs_fq_helper_functional.default",
  540. "aten.multinomial.default",
  541. "fbgemm.dense_to_jagged.default",
  542. "fbgemm.jagged_to_padded_dense.default",
  543. "run_and_save_rng_state",
  544. "run_with_rng_state",
  545. "aten._local_scalar_dense",
  546. # Technically, it's not necessary to ban this, because an
  547. # assert_scalar with constant arguments can be validly run
  548. # with CUDA graphs, but the operator is also pointless with
  549. # constant arguments, so might as well ban
  550. "aten._assert_scalar",
  551. }
  552. if torch.are_deterministic_algorithms_enabled():
  553. forbidden_set.update(
  554. {
  555. "aten._unsafe_index_put.default",
  556. "aten.index_put.default",
  557. "aten.index_put_.default",
  558. "aten.scatter.src",
  559. "aten.scatter.reduce",
  560. "aten.scatter.value_reduce",
  561. "aten.scatter_add_",
  562. "aten.scatter_add.default",
  563. "aten.scatter_reduce.two",
  564. "aten.scatter_reduce_.two",
  565. "aten.scatter_reduce.two_out",
  566. }
  567. )
  568. for node in gm.graph.nodes:
  569. if str(node.target) in forbidden_set:
  570. return node
  571. if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
  572. return node
  573. return None
  574. def has_incompatible_cudagraph_ops(gm):
  575. return get_first_incompatible_cudagraph_node(gm) is not None
  576. def output_node(gm: torch.fx.GraphModule):
  577. """Get the output node from an FX graph"""
  578. last_node = next(iter(reversed(gm.graph.nodes)))
  579. assert last_node.op == "output"
  580. return last_node
  581. _registered_caches: List[Any] = []
  582. def clear_on_fresh_inductor_cache(obj: Any):
  583. """
  584. Use this decorator to register any caches that should be cache_clear'd
  585. with fresh_inductor_cache().
  586. """
  587. if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
  588. raise AttributeError(f"{obj} does not have a cache_clear method")
  589. _registered_caches.append(obj)
  590. return obj
  591. def clear_inductor_caches():
  592. """
  593. Clear all registered caches.
  594. """
  595. for obj in _registered_caches:
  596. obj.cache_clear()
  597. @contextlib.contextmanager
  598. def fresh_inductor_cache(cache_entries=None):
  599. """
  600. Contextmanager that provides a clean tmp cachedir for inductor.
  601. Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
  602. generated with this cache instance.
  603. """
  604. clear_inductor_caches()
  605. inductor_cache_dir = tempfile.mkdtemp()
  606. try:
  607. with mock.patch.dict(
  608. os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
  609. ):
  610. triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
  611. with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
  612. yield
  613. if isinstance(cache_entries, dict):
  614. assert len(cache_entries) == 0, "expected empty cache_entries dict"
  615. if os.path.exists(triton_cache_dir):
  616. files = os.listdir(triton_cache_dir)
  617. cache_entries.update(
  618. {
  619. f: os.path.getsize(os.path.join(triton_cache_dir, f))
  620. for f in files
  621. if ".lock" not in f
  622. }
  623. )
  624. shutil.rmtree(inductor_cache_dir)
  625. except Exception:
  626. log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
  627. raise
  628. finally:
  629. clear_inductor_caches()
  630. def argsort(seq) -> List[int]:
  631. # preserve original order for equal strides
  632. getter = seq.__getitem__
  633. a_r = range(len(seq))
  634. return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
  635. @functools.lru_cache(8)
  636. def get_dtype_size(dtype):
  637. return torch.empty((), dtype=dtype).element_size()
  638. class LineContext(NamedTuple):
  639. context: Any
  640. class IndentedBuffer:
  641. tabwidth = 4
  642. def __init__(self, initial_indent=0):
  643. self._lines = []
  644. self._indent = initial_indent
  645. def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
  646. buf = StringIO()
  647. p = 1
  648. linemap = []
  649. for line in self._lines:
  650. if isinstance(line, DeferredLineBase):
  651. line = line()
  652. if line is None:
  653. continue
  654. elif isinstance(line, LineContext):
  655. linemap.append((p, line.context))
  656. continue
  657. assert isinstance(line, str)
  658. buf.write(line)
  659. buf.write("\n")
  660. p += 1 + line.count("\n")
  661. return buf.getvalue(), linemap
  662. def getvalue(self) -> str:
  663. v, _ = self.getvaluewithlinemap()
  664. return v
  665. def getrawvalue(self) -> str:
  666. buf = StringIO()
  667. for line in self._lines:
  668. if isinstance(line, DeferredLineBase):
  669. line = line()
  670. if line is None:
  671. continue
  672. elif isinstance(line, LineContext):
  673. continue
  674. assert isinstance(line, str)
  675. # backslash implies line continuation
  676. if line.endswith("\\"):
  677. buf.write(line[:-1])
  678. else:
  679. buf.write(line)
  680. buf.write("\n")
  681. return buf.getvalue()
  682. def clear(self):
  683. self._lines.clear()
  684. def __bool__(self):
  685. return bool(self._lines)
  686. def prefix(self):
  687. return " " * (self._indent * self.tabwidth)
  688. def newline(self):
  689. self.writeline("\n")
  690. def writeline(self, line):
  691. if isinstance(line, LineContext):
  692. self._lines.append(line)
  693. elif isinstance(line, DeferredLineBase):
  694. self._lines.append(line.with_prefix(self.prefix()))
  695. elif line.strip():
  696. self._lines.append(f"{self.prefix()}{line}")
  697. else:
  698. self._lines.append("")
  699. def writelines(self, lines):
  700. for line in lines:
  701. self.writeline(line)
  702. def indent(self, offset=1):
  703. @contextlib.contextmanager
  704. def ctx():
  705. self._indent += offset
  706. try:
  707. yield
  708. finally:
  709. self._indent -= offset
  710. return ctx()
  711. def do_indent(self, offset=1):
  712. self._indent += offset
  713. def do_unindent(self, offset=1):
  714. self._indent -= offset
  715. def splice(self, other_code, strip=False):
  716. if isinstance(other_code, IndentedBuffer):
  717. dedent = float("inf")
  718. for line in other_code._lines:
  719. if not isinstance(line, LineContext) and line:
  720. dedent = min(dedent, len(line) - len(line.lstrip()))
  721. if math.isinf(dedent):
  722. dedent = 0
  723. for line in other_code._lines:
  724. if isinstance(line, LineContext):
  725. self._lines.append(line)
  726. else:
  727. IndentedBuffer.writeline(self, line[int(dedent) :])
  728. else:
  729. other_code = textwrap.dedent(other_code)
  730. if strip:
  731. other_code = other_code.lstrip()
  732. if not other_code:
  733. return
  734. other_code = other_code.rstrip()
  735. for line in other_code.split("\n"):
  736. self.writeline(line)
  737. def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
  738. res = IndentedBuffer(initial_indent=self._indent)
  739. res._lines = [func(line) for line in self._lines]
  740. return res
  741. def __repr__(self):
  742. return f"{type(self)}({self.getvalue()})"
  743. def __add__(self, other):
  744. assert self._indent == other._indent
  745. res = IndentedBuffer(initial_indent=self._indent)
  746. res.writelines(self._lines)
  747. res.writelines(other._lines)
  748. return res
  749. class FakeIndentedBuffer(IndentedBuffer):
  750. def __init__(self):
  751. super().__init__()
  752. def __getattribute__(self, name):
  753. if name == "__class__": # Allow access to the class attribute
  754. return object.__getattribute__(self, name)
  755. raise RuntimeError(
  756. f"Tried to call self.{name} on FakeIndentedBuffer. This buffer"
  757. "is currently used on TritonTemplateKernel to prevent actual"
  758. "writes to the body without explicitly specifying the body with"
  759. "`TritonTemplateKernel.set_subgraph_body(name)`"
  760. )
  761. @contextlib.contextmanager
  762. def restore_stdout_stderr(initial_stdout, initial_stderr):
  763. try:
  764. yield
  765. finally:
  766. sys.stdout = initial_stdout
  767. sys.stderr = initial_stderr
  768. class DeferredLineBase:
  769. """A line that can be 'unwritten' at a later time"""
  770. def __init__(self, line):
  771. if not line.strip():
  772. line = ""
  773. self.line = line
  774. def __call__(self) -> Optional[str]:
  775. """Returns either self.line or None to indicate the line has been 'unwritten'"""
  776. raise NotImplementedError
  777. def _new_line(self, line: str) -> DeferredLineBase:
  778. """Returns a new deferred line with the same condition"""
  779. raise NotImplementedError
  780. def with_prefix(self, prefix):
  781. return self._new_line(f"{prefix}{self.line}")
  782. def lstrip(self):
  783. return self._new_line(self.line.lstrip())
  784. def __getitem__(self, index):
  785. return self._new_line(self.line[index])
  786. def __bool__(self):
  787. return bool(self.line)
  788. def __len__(self):
  789. return len(self.line)
  790. @functools.lru_cache(None)
  791. def is_big_gpu(index) -> bool:
  792. min_sms = 68 # 3080
  793. avail_sms = torch.cuda.get_device_properties(index).multi_processor_count
  794. if avail_sms < min_sms:
  795. log.warning(
  796. "Not enough SMs to use max_autotune_gemm mode",
  797. extra={"min_sms": min_sms, "avail_sms": avail_sms},
  798. )
  799. return False
  800. return True
  801. def use_max_autotune() -> bool:
  802. return (
  803. config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
  804. )
  805. def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
  806. return (
  807. use_max_autotune()
  808. and layout.device.type == "cuda"
  809. and layout.dtype in allowed_layout_dtypes
  810. and is_big_gpu(layout.device.index or 0)
  811. )
  812. def _use_autotune_backend(backend: str) -> bool:
  813. return backend.upper() in [
  814. x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
  815. ]
  816. def use_triton_template(layout, *, enable_int32=False):
  817. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
  818. if enable_int32:
  819. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
  820. return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
  821. "TRITON"
  822. )
  823. def use_cutlass_template(layout, m, n, k):
  824. from .virtualized import V
  825. gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
  826. if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
  827. return False
  828. from .codegen.cuda.cutlass_utils import try_import_cutlass
  829. # Do not use cutlass template on ROCm
  830. if torch.version.hip:
  831. return False
  832. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
  833. res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
  834. "CUTLASS"
  835. )
  836. if res:
  837. if not try_import_cutlass():
  838. log.warning(
  839. "Failed to import CUTLASS lib. Please check whether "
  840. "_inductor.config.cuda.cutlass_dir is set correctly. "
  841. "Skipping CUTLASS backend for now."
  842. )
  843. return False
  844. return res
  845. def _use_template_for_cpu(layout):
  846. return use_max_autotune() and layout.device.type == "cpu"
  847. def use_cpp_packed_gemm_template(layout, mat1, mat2):
  848. from . import ir
  849. from .codegen.cpp_micro_gemm import create_micro_gemm
  850. from .kernel.mm_common import mm_args
  851. if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
  852. return False
  853. if not config.cpp.weight_prepack:
  854. return False
  855. layout_dtypes = [torch.float32]
  856. m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2)
  857. # TODO(jgong5): support dynamic shapes for n or k
  858. if has_free_symbols((n, k)):
  859. return False
  860. if isinstance(mat2, ir.BaseView):
  861. mat2 = mat2.unwrap_view()
  862. micro_gemm = create_micro_gemm(
  863. "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads()
  864. )
  865. # TODO(jgong5): support n % n_block_size != 0
  866. return (
  867. layout.dtype in layout_dtypes
  868. and micro_gemm is not None
  869. and n % micro_gemm.register_blocking[1] == 0
  870. and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input
  871. and isinstance(mat2, ir.StorageBox)
  872. and mat2.is_module_buffer()
  873. )
  874. def use_aten_gemm_kernels():
  875. return not use_max_autotune() or _use_autotune_backend("ATEN")
  876. class DebugDirManager:
  877. counter = itertools.count(0)
  878. prev_debug_name: str
  879. def __init__(self):
  880. self.id = next(DebugDirManager.counter)
  881. def __enter__(self):
  882. self.prev_debug_name = torch._dynamo.config.debug_dir_root
  883. self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
  884. torch._dynamo.config.debug_dir_root = self.new_name
  885. def __exit__(self, *args):
  886. shutil.rmtree(self.new_name)
  887. torch._dynamo.config.debug_dir_root = self.prev_debug_name
  888. def run_and_get_code(fn, *args, **kwargs):
  889. from .graph import GraphLowering
  890. compile_to_module = GraphLowering.compile_to_module
  891. source_codes: List[str] = []
  892. def patched_compile_to_module(self):
  893. mod = compile_to_module(self)
  894. with open(mod.__file__) as f:
  895. source_codes.append(f.read())
  896. return mod
  897. # If FX code caching is enabled, a hit prevents getting the code.
  898. with config.patch({"fx_graph_cache": False}):
  899. with mock.patch.object(
  900. GraphLowering, "compile_to_module", patched_compile_to_module
  901. ):
  902. torch._dynamo.reset()
  903. result = fn(*args, **kwargs)
  904. return result, source_codes
  905. def get_code(fn, *args, **kwargs):
  906. """Get the inductor-generated code, but skip any actual compilation or running."""
  907. from .graph import GraphLowering
  908. source_codes: List[str] = []
  909. def patched_compile_to_module(self: GraphLowering):
  910. class DummyModule:
  911. """This is empty to replace the generated triton module"""
  912. def __init__(self):
  913. pass
  914. def call(self, *args, **kwargs):
  915. # Don't do anything when called
  916. pass
  917. code, _ = (
  918. self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  919. )
  920. # Skip all the actual compiling.
  921. source_codes.append(code)
  922. return DummyModule()
  923. # If FX code caching is enabled, a hit prevents getting the code.
  924. with config.patch({"fx_graph_cache": False}):
  925. with mock.patch.object(
  926. GraphLowering, "compile_to_module", patched_compile_to_module
  927. ):
  928. torch._dynamo.reset()
  929. # Note the return here is None
  930. _ = fn(*args, **kwargs)
  931. return source_codes
  932. def get_triton_code(fn, *args, **kwargs):
  933. source_codes = get_code(fn, *args, **kwargs)
  934. # Can have two outputs if backwards was eagerly compiled
  935. assert (
  936. 1 <= len(source_codes) <= 2
  937. ), f"expected one or two code outputs got {len(source_codes)}"
  938. return source_codes[0]
  939. def run_and_get_triton_code(fn, *args, **kwargs):
  940. _, source_codes = run_and_get_code(fn, *args, **kwargs)
  941. # Can have two outputs if backwards was eagerly compiled
  942. assert (
  943. 1 <= len(source_codes) <= 2
  944. ), f"expected one or two code outputs got {len(source_codes)}"
  945. return source_codes[0]
  946. @contextlib.contextmanager
  947. def override_lowering(aten_op, override_fn):
  948. """
  949. Override the lowering of aten_op with override_fn.
  950. The first argument of override_fn is the original lowering fn.
  951. """
  952. from torch._inductor import lowering
  953. orig_fn = lowering.lowerings[aten_op]
  954. try:
  955. lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
  956. yield
  957. finally:
  958. lowering.lowerings[aten_op] = orig_fn
  959. def add_scheduler_init_hook(pre_fn, post_fn=None):
  960. """
  961. Add hook functions to be called at the beginning and end of Scheduler.__init__.
  962. Used for unit tests.
  963. """
  964. from torch._inductor.scheduler import Scheduler
  965. orig_fn = Scheduler.__init__
  966. def wrapper(scheduler, nodes):
  967. pre_fn(scheduler, nodes)
  968. out = orig_fn(scheduler, nodes)
  969. if post_fn:
  970. post_fn(scheduler, nodes)
  971. return out
  972. return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
  973. def developer_warning(msg):
  974. """
  975. Warnings that will be actionable for PyTorch developers, but not
  976. end users. Allows us to easily disable them in stable releases but
  977. keep them on for nightly builds.
  978. """
  979. if config.developer_warnings:
  980. log.warning(msg)
  981. else:
  982. log.info(msg)
  983. def get_benchmark_name():
  984. """
  985. An experimental API used only when config.benchmark_kernel is true.
  986. The benchmark name is only available at codegen time. So we can not
  987. directly call it in benchmark_all_kernels which is run after codegen.
  988. The function assumes the argument after --only is the benchmark name.
  989. It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
  990. scripts, this function may return None.
  991. There are 2 flavors of --only argument we need handle:
  992. 1. --only model_name
  993. 2. --only=model_name
  994. """
  995. try:
  996. idx = sys.argv.index("--only")
  997. if (
  998. idx + 1 < len(sys.argv)
  999. and len(sys.argv[idx + 1]) > 0
  1000. and sys.argv[idx + 1][0] != "-"
  1001. ):
  1002. return sys.argv[idx + 1]
  1003. except ValueError:
  1004. pass
  1005. for arg in sys.argv:
  1006. if arg.startswith("--only="):
  1007. return arg[len("--only=") :]
  1008. def is_ones(items):
  1009. return all(x == 1 for x in items)
  1010. def is_zeros(items):
  1011. return all(x == 0 for x in items)
  1012. def is_cpu_device(inputs):
  1013. return all(
  1014. item.device == torch.device("cpu")
  1015. for item in inputs
  1016. if isinstance(item, torch.Tensor)
  1017. )
  1018. def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
  1019. assert isinstance(
  1020. val, sympy.Expr
  1021. ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
  1022. if val.is_integer: # type: ignore[attr-defined]
  1023. return torch.int64
  1024. else:
  1025. return torch.float64
  1026. @contextlib.contextmanager
  1027. def maybe_profile(should_profile, *args, **kwargs):
  1028. if should_profile:
  1029. with torch.profiler.profile(*args, **kwargs) as p:
  1030. yield p
  1031. else:
  1032. yield
  1033. def parallel_num_threads():
  1034. threads = config.cpp.threads
  1035. if threads < 1:
  1036. threads = torch.get_num_threads()
  1037. return threads
  1038. @functools.lru_cache(None)
  1039. def get_device_tflops(dtype):
  1040. from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
  1041. assert dtype in (torch.float16, torch.bfloat16, torch.float32)
  1042. if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
  1043. # Triton API change in https://github.com/openai/triton/pull/2293
  1044. from torch._utils_internal import max_clock_rate
  1045. sm_clock = max_clock_rate()
  1046. if dtype in (torch.float16, torch.bfloat16):
  1047. return get_max_tensorcore_tflops(dtype, sm_clock)
  1048. if torch.backends.cuda.matmul.allow_tf32:
  1049. return get_max_tensorcore_tflops(torch.float32, sm_clock)
  1050. else:
  1051. return get_max_simd_tflops(torch.float32, sm_clock)
  1052. else:
  1053. if dtype in (torch.float16, torch.bfloat16):
  1054. return get_max_tensorcore_tflops(dtype)
  1055. if torch.backends.cuda.matmul.allow_tf32:
  1056. return get_max_tensorcore_tflops(torch.float32)
  1057. else:
  1058. return get_max_simd_tflops(torch.float32)
  1059. @functools.lru_cache(None)
  1060. def get_gpu_dram_gbps():
  1061. from triton.testing import get_dram_gbps
  1062. return get_dram_gbps()
  1063. def get_gpu_shared_memory():
  1064. from triton.runtime import driver
  1065. return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
  1066. def is_welford_reduction(reduction_type):
  1067. return reduction_type.startswith("welford")
  1068. def reduction_num_outputs(reduction_type):
  1069. return 3 if is_welford_reduction(reduction_type) else 1
  1070. def is_linux() -> bool:
  1071. return platform.system() == "Linux"
  1072. def has_free_symbols(itr: Iterable[Any]):
  1073. return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
  1074. def is_dynamic(*args):
  1075. from . import ir
  1076. for t in args:
  1077. if isinstance(t, ir.TensorBox):
  1078. if has_free_symbols(t.data.get_size()) or (
  1079. hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
  1080. ):
  1081. return True
  1082. elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
  1083. assert hasattr(t, "get_size") and hasattr(t, "get_stride")
  1084. if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
  1085. return True
  1086. elif not isinstance(t, ir.IRNode):
  1087. continue
  1088. else:
  1089. raise TypeError(f"unexpected type for is_dynamic {type(t)}")
  1090. return False
  1091. # Placeholder strings used in triton codegen.
  1092. class Placeholder(enum.Enum):
  1093. # The placeholder for the actual name of a triton kernel.
  1094. # e.g. for "def triton_" it would be "triton_"
  1095. KERNEL_NAME = "KERNEL_NAME"
  1096. # The descriptive name of the triton kernel; when unique_kernel_names = False, this
  1097. # placeholder will be replaced with a string with more information.
  1098. DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
  1099. def pass_execution_and_save(func, gm, inp, msg):
  1100. from .pattern_matcher import stable_topological_sort
  1101. with tempfile.NamedTemporaryFile(
  1102. mode="w",
  1103. encoding="utf-8",
  1104. delete=False,
  1105. ) as f:
  1106. before_io = io.StringIO()
  1107. after_io = io.StringIO()
  1108. ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
  1109. print(f"Before:\n{gm.graph}", file=f)
  1110. print(gm.graph, file=before_io)
  1111. start_time = datetime.now()
  1112. func(gm.graph)
  1113. time_elapsed = datetime.now() - start_time
  1114. # recompile graph
  1115. stable_topological_sort(gm.graph)
  1116. gm.graph.lint()
  1117. gm.recompile()
  1118. print(f"After:\n{gm.graph}", file=f)
  1119. print(gm.graph, file=after_io)
  1120. t = before_io.getvalue() == after_io.getvalue()
  1121. log.info(
  1122. "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
  1123. msg,
  1124. f.name,
  1125. t,
  1126. time_elapsed,
  1127. )
  1128. def is_collective(node):
  1129. from . import ir
  1130. return type(node) == ir._CollectiveKernel
  1131. def is_wait(node):
  1132. from . import ir
  1133. return type(node) == ir._WaitKernel
  1134. def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
  1135. "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
  1136. num_rng_seed_offset_inputs = (
  1137. 2 if torch._functorch.config.functionalize_rng_ops else 0
  1138. )
  1139. return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
  1140. def count_tangents(fx_g: torch.fx.GraphModule):
  1141. """
  1142. Infers which inputs are static for a backwards graph
  1143. """
  1144. def is_saved_tensor(x):
  1145. return (
  1146. "tangents" not in x.name
  1147. and "bwd_seed" not in x.name
  1148. and "bwd_base_offset" not in x.name
  1149. )
  1150. arg_count = 0
  1151. static_arg_idxs = []
  1152. for n in fx_g.graph.nodes:
  1153. if n.op == "placeholder":
  1154. if is_saved_tensor(n):
  1155. static_arg_idxs.append(arg_count)
  1156. arg_count += 1
  1157. assert static_arg_idxs == list(range(len(static_arg_idxs)))
  1158. return len(static_arg_idxs)
  1159. @dataclasses.dataclass
  1160. class BoxedBool:
  1161. value: bool
  1162. def __bool__(self):
  1163. return self.value
  1164. @staticmethod
  1165. def disable(obj):
  1166. if isinstance(obj, BoxedBool):
  1167. obj.value = False
  1168. return obj
  1169. return False
  1170. @contextlib.contextmanager
  1171. def collect_defined_kernels(kernel_list):
  1172. from .codegen.wrapper import WrapperCodeGen
  1173. orig_define_kernel = WrapperCodeGen.define_kernel
  1174. def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
  1175. nonlocal kernel_list
  1176. kernel_list.append(kernel_code)
  1177. return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)
  1178. with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
  1179. yield
  1180. def get_cloned_parameter_buffer_name(name: str):
  1181. return name + "__original__"
  1182. def is_gpu(device: str):
  1183. return device in ["cuda", "xpu"]
  1184. def device_need_guard(device: str):
  1185. assert isinstance(device, str)
  1186. return is_gpu(device)
  1187. def needs_fallback_due_to_atomic_add_limitations(dtype):
  1188. # tl.atomic_add does NOT support the following types
  1189. return dtype in {torch.int64, torch.bool, torch.bfloat16}
  1190. def use_scatter_fallback(
  1191. op_overload: torch._ops.OpOverload,
  1192. reduction_type,
  1193. self_dtype,
  1194. src_dtype,
  1195. src_device_type,
  1196. src_is_tensor,
  1197. ):
  1198. reduce_ty = (
  1199. "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
  1200. )
  1201. return (
  1202. reduction_type not in {None, reduce_ty}
  1203. or (
  1204. src_is_tensor
  1205. and is_gpu(src_device_type)
  1206. and needs_fallback_due_to_atomic_add_limitations(src_dtype)
  1207. )
  1208. or (
  1209. op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
  1210. and reduction_type == "sum"
  1211. and src_is_tensor
  1212. and src_device_type == "cpu"
  1213. and config.cpp.fallback_scatter_reduce_sum
  1214. and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
  1215. )
  1216. or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
  1217. or torch.are_deterministic_algorithms_enabled()
  1218. )
  1219. def dump_node_schedule(node_schedule):
  1220. """
  1221. An API that can be used in pdb to dump a node_schedule.
  1222. Right mainly dump the read/write dependencies but can add more as needed.
  1223. """
  1224. from torch._inductor.codegen.simd import DisableReduction, EnableReduction
  1225. from torch._inductor.scheduler import SchedulerNode
  1226. print(f"Node schedule with {len(node_schedule)} nodes")
  1227. for idx, node in enumerate(node_schedule):
  1228. print(f" {idx:3}:")
  1229. if node is EnableReduction:
  1230. print("enable reduction")
  1231. elif node is DisableReduction:
  1232. print("disable reduction")
  1233. elif isinstance(node, SchedulerNode):
  1234. is_red = node.is_reduction()
  1235. print(f"{'red' if is_red else 'pw'} scheduler node")
  1236. if is_red:
  1237. assert node.node is not None
  1238. print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
  1239. print("ReadDep:")
  1240. for dep in node.read_writes.reads:
  1241. print(dep)
  1242. print("WriteDep:")
  1243. for dep in node.read_writes.writes:
  1244. print(dep)
  1245. else:
  1246. raise RuntimeError(f"Unrecognized node type: {type(node)}")
  1247. def tensor_is_aligned(tensor: torch.Tensor):
  1248. # See Note: [Input Alignment handling in Inductor]
  1249. # Right now, we don't try to guard on the alignment of the storage offset.
  1250. # When this comment was written, non-symbolic storage_offsets are not guarded on
  1251. # but symbolic storage_offsets are. For consistency, we suppress guard creation
  1252. # upon performing this check: that ensures that we don't add recompiles when we
  1253. # add this logic.
  1254. return (
  1255. tensor.storage_offset() * get_dtype_size(tensor.dtype)
  1256. ) % GPU_ALIGN_BYTES == 0
  1257. def should_assume_input_aligned(example_input: torch.Tensor):
  1258. # See Note: [Input Alignment handling in Inductor]
  1259. # right now, we only care about alignment for cuda tensors.
  1260. if not is_gpu(example_input.device.type):
  1261. return False
  1262. return config.assume_aligned_inputs or tensor_is_aligned(example_input)
  1263. def maybe_get_suppress_shape_guards_ctx():
  1264. # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
  1265. # If it's not available, return a nullcontext.
  1266. # If we're dealing with cudagraphs, we might not have a tracing_context
  1267. tracing_context = torch._guards.TracingContext.try_get()
  1268. if not tracing_context:
  1269. return contextlib.nullcontext()
  1270. # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
  1271. shape_env = tracing_context.fake_mode.shape_env
  1272. if not shape_env:
  1273. return contextlib.nullcontext()
  1274. return shape_env.suppress_guards()
  1275. def aoti_eager_cache_dir(namespace: str, device: str):
  1276. return Path(cache_dir()) / "aoti_eager" / namespace / device
  1277. def aoti_eager_op_conf_lock(op_func_name_with_overload: str):
  1278. from filelock import FileLock
  1279. # Avoid circular import
  1280. from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
  1281. op_conf_lock_file = f"{op_func_name_with_overload}.lock"
  1282. lock_dir = get_lock_dir()
  1283. return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)
  1284. def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str):
  1285. device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
  1286. op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
  1287. if not op_conf.exists():
  1288. return []
  1289. with aoti_eager_op_conf_lock(op_func_name_with_overload):
  1290. with open(op_conf) as f:
  1291. json_data = json.load(f)
  1292. for item in json_data:
  1293. # Get absolution path for kernel library
  1294. kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
  1295. item["kernel_path"] = kernel_lib_abs_path.as_posix()
  1296. # Check if the kernel library exists
  1297. if not kernel_lib_abs_path.exists():
  1298. return []
  1299. for metadata in item["meta_info"]:
  1300. assert not metadata[
  1301. "is_dynamic"
  1302. ], "Only support static shape for now"
  1303. if metadata["device_type"] == "cpu":
  1304. metadata["device_index"] = -1
  1305. metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])
  1306. return json_data
  1307. def aoti_compile_with_persistent_cache(
  1308. ns: str,
  1309. op_func_name_with_overload: str,
  1310. device_type: str,
  1311. dynamic: bool,
  1312. f: Callable[..., Any],
  1313. args: Tuple[Any],
  1314. kwargs: Dict[str, Any],
  1315. *,
  1316. dynamic_shapes: Optional[Dict[str, Any]] = None,
  1317. options: Optional[Dict[str, Any]] = None,
  1318. remove_runtime_assertions: bool = False,
  1319. disable_constraint_solver: bool = False,
  1320. ):
  1321. """
  1322. Compile the given function with persistent cache for AOTI eager mode.
  1323. """
  1324. assert not dynamic, "Only support static shape for now"
  1325. type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
  1326. supported_scalar_types = tuple(type_to_torch_dtype.keys())
  1327. flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
  1328. if not all(
  1329. isinstance(input, (supported_scalar_types, torch.Tensor))
  1330. for input in flattened_inputs
  1331. ):
  1332. raise NotImplementedError("Only support tensor, int, float, bool for now")
  1333. persistent_cache = aoti_eager_cache_dir(ns, device_type)
  1334. if not persistent_cache.exists():
  1335. persistent_cache.mkdir(parents=True)
  1336. persistent_cache_lib = persistent_cache / "lib"
  1337. if not persistent_cache_lib.exists():
  1338. persistent_cache_lib.mkdir()
  1339. with mock.patch.dict(
  1340. os.environ,
  1341. {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
  1342. ):
  1343. try:
  1344. kernel_lib_path = torch._export.aot_compile(
  1345. f,
  1346. args,
  1347. kwargs,
  1348. dynamic_shapes=dynamic_shapes,
  1349. options=options,
  1350. remove_runtime_assertions=remove_runtime_assertions,
  1351. disable_constraint_solver=disable_constraint_solver,
  1352. # Some operations may have non-Tensor parameters like int, float, bool. These
  1353. # non-Tensor parameters will not be the input of the graph. Therefore, we do
  1354. # need to keep the same signature.
  1355. same_signature=False,
  1356. )
  1357. kernel_metadata_items = []
  1358. for input in flattened_inputs:
  1359. # TODO(Eikan): To add dynamic support
  1360. metadata: Dict[str, Any] = {}
  1361. metadata["is_dynamic"] = dynamic
  1362. if isinstance(input, torch.Tensor):
  1363. metadata["device_type"] = f"{input.device.type}"
  1364. if is_cpu_device([input]):
  1365. metadata["device_index"] = -1
  1366. else:
  1367. metadata["device_index"] = input.device.index
  1368. metadata["dtype"] = f"{input.dtype}"
  1369. metadata["sizes"] = list(input.size())
  1370. metadata["strides"] = list(input.stride())
  1371. else:
  1372. assert isinstance(input, supported_scalar_types)
  1373. # Scalar tensor
  1374. metadata["device_type"] = device_type
  1375. metadata["device_index"] = -1 if device_type == "cpu" else 0
  1376. metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
  1377. metadata["sizes"] = []
  1378. metadata["strides"] = []
  1379. metadata["scalar_value"] = input
  1380. kernel_metadata_items.append(metadata)
  1381. kernel_meta_info: Dict[str, Any] = {}
  1382. kernel_meta_info["meta_info"] = kernel_metadata_items
  1383. kernel_meta_info["kernel_path"] = (
  1384. Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
  1385. )
  1386. json_data = []
  1387. update_json = True
  1388. op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
  1389. mode = "r" if op_conf.exists() else "w"
  1390. with aoti_eager_op_conf_lock(op_func_name_with_overload):
  1391. with open(op_conf, mode) as op_conf_file:
  1392. try:
  1393. json_data = json.load(op_conf_file)
  1394. except Exception as e:
  1395. json_data = []
  1396. assert isinstance(json_data, list)
  1397. for item in json_data:
  1398. assert isinstance(item, dict)
  1399. # Same kernel meta info already exists in the json file
  1400. if item["meta_info"] == kernel_metadata_items:
  1401. update_json = False
  1402. break
  1403. if update_json:
  1404. json_data.append(kernel_meta_info)
  1405. with open(op_conf, "w") as op_conf_file:
  1406. json.dump(json_data, op_conf_file, indent=4)
  1407. return kernel_lib_path
  1408. except Exception as e:
  1409. return ""
  1410. def run_and_get_cpp_code(fn, *args, **kwargs):
  1411. # We use the patch context manager instead of using it as a decorator.
  1412. # In this way, we can ensure that the attribute is patched and unpatched correctly
  1413. # even if this run_and_get_cpp_code function is called multiple times.
  1414. with unittest.mock.patch.object(config, "debug", True):
  1415. torch._dynamo.reset()
  1416. import io
  1417. import logging
  1418. log_capture_string = io.StringIO()
  1419. ch = logging.StreamHandler(log_capture_string)
  1420. from torch._inductor.graph import output_code_log
  1421. output_code_log.addHandler(ch)
  1422. prev_level = output_code_log.level
  1423. output_code_log.setLevel(logging.DEBUG)
  1424. result = fn(*args, **kwargs)
  1425. s = log_capture_string.getvalue()
  1426. output_code_log.setLevel(prev_level)
  1427. output_code_log.removeHandler(ch)
  1428. return result, s