- # mypy: allow-untyped-defs
- from __future__ import annotations
- import collections
- import contextlib
- import dataclasses
- import enum
- import functools
- import inspect
- import io
- import itertools
- import json
- import logging
- import math
- import operator
- import os
- import platform
- import shutil
- import sys
- import tempfile
- import textwrap
- import time
- import unittest
- from datetime import datetime
- from io import StringIO
- from pathlib import Path
- from typing import (
- Any,
- Callable,
- Dict,
- Generic,
- Iterable,
- List,
- NamedTuple,
- Optional,
- Protocol,
- Set,
- Tuple,
- TypeVar,
- Union,
- ValuesView,
- )
- from typing_extensions import Concatenate, ParamSpec
- from unittest import mock
- import sympy
- import torch
- import torch._export
- import torch.utils._pytree as pytree
- from torch._dynamo.device_interface import get_interface_for_device
- from torch._dynamo.utils import detect_fake_mode
- from torch.autograd import DeviceType
- from torch.autograd.profiler_util import EventList
- from torch.fx.passes.shape_prop import ShapeProp
- from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
- from torch.utils._sympy.symbol import make_symbol, SymT
- from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
- from . import config
- from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv
- log = logging.getLogger(__name__)
- _T = TypeVar("_T")
- VarRanges = Dict[sympy.Expr, sympy.Expr]
- GPU_ALIGN_BYTES = 16
- ALIGN_BYTES = 64
- assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be a power of 2 and >= 8"
- def _align(nbytes):
- """Round up to the nearest multiple of ALIGN_BYTES"""
- return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
- def _is_aligned(v: sympy.Expr):
- """v can be statically proven to be a multiple of ALIGN_BYTES"""
- if isinstance(v, (sympy.Add, sympy.Max)):
- return all(map(_is_aligned, v.args))
- return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
- class align(sympy.Function):
- """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
- nargs = (1,)
- is_integer = True
- @classmethod
- def eval(cls, value):
- if isinstance(value, (int, sympy.Integer)):
- return _align(int(value))
- if _is_aligned(value):
- return value
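- # Illustrative example (editor's sketch, not part of the original module): how the
- # alignment helpers behave with ALIGN_BYTES == 64. `_example_alignment` is a
- # hypothetical helper that is never called here.
- def _example_alignment():
-     s = sympy.Symbol("s", integer=True, positive=True)
-     assert _align(1) == 64 and _align(64) == 64 and _align(65) == 128
-     assert align(128) == 128  # integer inputs fold via align.eval
-     assert align(64 * s) == 64 * s  # provably aligned expressions pass through
-     return align(s)  # anything else stays as a symbolic align(...) node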
- def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
- """
- Returns benchmark results by examining torch profiler events.
- This can be more accurate as it does not count CPU-side overhead.
- However, it also requires manually excluding irrelevant events, e.g.
- the vectorized_elementwise_kernel used to fill the L2 cache and
- various CUDA events, so it can also be fragile.
- """
- fn()
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
- # Estimate the runtime of the function
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
- start_event.record()
- for _ in range(5):
- cache.zero_()
- fn()
- end_event.record()
- torch.cuda.synchronize()
- estimate_ms = start_event.elapsed_time(end_event) / 5
- # compute number of warmup and repeat
- n_warmup = max(1, int(warmup / estimate_ms))
- n_repeat = max(1, int(rep / estimate_ms))
- # Warm-up
- for _ in range(n_warmup):
- fn()
- with torch.profiler.profile(
- activities=[
- torch.profiler.ProfilerActivity.CUDA,
- ]
- ) as p:
- # Benchmark
- for i in range(n_repeat):
- # we clear the L2 cache before each run
- cache.zero_()
- # record time of `fn`
- fn()
- # Record clocks
- torch.cuda.synchronize()
- log.debug("raw events")
- log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
- filtered_events = EventList(
- [
- event
- for event in p.events()
- if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
- ]
- )
- if len(filtered_events) % n_repeat != 0:
- raise RuntimeError(
- "Failed to divide all profiling events into #repeat groups. "
- f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
- )
- num_event_per_group = len(filtered_events) // n_repeat
- actual_events = EventList(
- [
- event
- for i, event in enumerate(filtered_events)
- if i % num_event_per_group != 0
- ]
- )
- actual_events._build_tree()
- actual_events = actual_events.key_averages()
- log.debug("profiling time breakdown")
- log.debug(actual_events.table(row_limit=-1))
- res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
- log.debug("profiling results: %s ms", res)
- return res
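- # Illustrative usage (editor's sketch): benchmarking a CUDA matmul with the
- # profiler-based timer above. The tensors `a` and `b` and the helper name are
- # hypothetical; the call returns the mean CUDA kernel time per iteration in ms.
- def _example_do_bench_using_profiling():
-     if not torch.cuda.is_available():
-         return None
-     a = torch.randn(1024, 1024, device="cuda")
-     b = torch.randn(1024, 1024, device="cuda")
-     return do_bench_using_profiling(lambda: a @ b)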
- @functools.lru_cache(None)
- def has_torchvision_roi_align() -> bool:
- try:
- from torchvision.ops import roi_align # noqa: F401
- torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
- return roi_align is not None and hasattr(
- getattr(torch.ops, "torchvision", None), "roi_align"
- )
- except ImportError:
- return False
- except RuntimeError as e:
- assert "torchvision::nms does not exist" in str(e)
- return False
- def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
- if device is None:
- return torch.tensor(0.0).device # default device
- if isinstance(device, str):
- device = torch.device(device)
- if device.type not in ("cpu", "meta") and device.index is None:
- device_interface = get_interface_for_device(device.type)
- return torch.device(device.type, index=device_interface.Worker.current_device())
- return device
- def sympy_product(it):
- return functools.reduce(operator.mul, it, sympy.Integer(1))
- def sympy_dot(seq1, seq2):
- assert len(seq1) == len(seq2)
- return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
- def unique(it: Iterable[_T]) -> ValuesView[_T]:
- return {id(x): x for x in it}.values()
- def ceildiv(
- numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
- ) -> Union[int, sympy.Expr]:
- if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
- return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
- # TODO: There is a bug in a call to this function, to repro:
- # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
- # --amp --only YituTechConvBert --dynamic-shapes
- assert isinstance(numer, int) and isinstance(
- denom, int
- ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
- return runtime_ceildiv(numer, denom)
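- # Illustrative example (editor's sketch): ceildiv computes exact integer results
- # for ints and builds a symbolic expression when either operand is symbolic.
- # `_example_ceildiv` is a hypothetical helper.
- def _example_ceildiv():
-     n = sympy.Symbol("n", integer=True, positive=True)
-     assert ceildiv(10, 4) == 3
-     return ceildiv(n, 4)  # a symbolic ceiling-division expression in n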
- def _type_of(key):
- # Defined here to avoid a dependency on Triton during codegen.
- # Refer to Triton implementation here:
- # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
- # `None` is nullptr. Implicitly convert to *i8.
- if key is None:
- return "*i8"
- dtype_str = str(key).split(".")[-1]
- tys = {
- "bool": "i1",
- "float8e4nv": "fp8e4nv",
- "float8e5": "fp8e5",
- "float8e4b15": "fp8e4b15",
- "float8e4b15x4": "fp8e4b15x4",
- "float8_e4m3fn": "fp8e4nv",
- "float8_e5m2": "fp8e5",
- "float16": "fp16",
- "bfloat16": "bf16",
- "float32": "fp32",
- "float64": "fp64",
- "int8": "i8",
- "int16": "i16",
- "int32": "i32",
- "int64": "i64",
- "uint8": "u8",
- "uint16": "u16",
- "uint32": "u32",
- "uint64": "u64",
- }
- # reinterpret can create triton type
- for v in list(tys.values()):
- tys[v] = v
- return key if isinstance(key, str) else f"*{tys[dtype_str]}"
- def convert_shape_to_inductor(
- lst: Iterable[Union[int, torch.SymInt]]
- ) -> List[sympy.Expr]:
- """
- Converts a shape (or stride) sequence into sympy.Exprs. For static ints this
- is trivial, but symbolic entries need to be mapped from SymInt into
- sympy.Expr.
- """
- return [
- i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
- ]
- def convert_shape_to_symint(
- lst: Iterable[Union[int, sympy.Expr]]
- ) -> List[Union[int, torch.SymInt]]:
- """
- Takes a list of shapes from Inductor and converts them into symints (or just
- ints if all shapes are static).
- """
- from .virtualized import V
- return [
- i
- if isinstance(i, int)
- else int(i)
- if isinstance(i, sympy.Integer)
- else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
- for i in lst
- ]
- def is_view(op: torch._ops.OpOverload):
- """
- Does this op overload have aliasing
- """
- assert isinstance(op, torch._ops.OpOverload)
- return any(a.alias_info is not None for a in op._schema.arguments)
- def is_pointwise_use(use):
- if not use.op == "call_function":
- return False
- if not (
- isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
- ):
- return False
- if use.target is operator.getitem or is_view(use.target):
- return all(is_pointwise_use(u) for u in use.users)
- return torch.Tag.pointwise in use.target.tags
- def gen_gm_and_inputs(target, args, kwargs):
- g = torch.fx.Graph()
- g_args = []
- a_args = []
- for n, arg in enumerate(args):
- if isinstance(arg, torch.Tensor):
- g_args.append(g.placeholder(f"arg{n}"))
- a_args.append(arg)
- else:
- g_args.append(arg)
- assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
- node = g.call_function(target, tuple(g_args), kwargs)
- if (
- len(target._schema.returns) == 1
- and str(target._schema.returns[0].type) == "Tensor"
- ):
- node = (node,)
- g.output(node)
- gm = torch.fx.GraphModule({}, g)
- return gm, a_args
- def synchronize(device: str = "cuda"):
- if device == "cpu":
- return
- device_interface = get_interface_for_device(device)
- if device_interface.is_available():
- device_interface.synchronize()
- def timed(
- model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
- ) -> float:
- synchronize(device)
- torch.manual_seed(1337)
- t0 = time.perf_counter()
- for _ in range(times):
- result = model(*example_inputs)
- synchronize(device)
- t1 = time.perf_counter()
- # GC the result after timing
- assert result is not None # type: ignore[possibly-undefined]
- return t1 - t0
- def print_performance(
- fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
- ):
- timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)])
- took = torch.median(timings) / times
- print(f"{took / baseline:.6f}")
- return took
- def precompute_method(obj: Any, method: str):
- """Replace obj.method() with a new method that returns a precomputed constant."""
- result = getattr(obj, method)()
- setattr(obj, method, lambda: result)
- def precompute_methods(obj: Any, methods: List[str]):
- """Replace methods with new methods that returns a precomputed constants."""
- for method in methods:
- precompute_method(obj, method)
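- # Illustrative example (editor's sketch): after precompute_method, later calls
- # return the captured constant without re-running the original method. The
- # `_Sizer` class is hypothetical.
- def _example_precompute_method():
-     class _Sizer:
-         def numel(self):
-             return 2 * 3 * 4
-     obj = _Sizer()
-     precompute_method(obj, "numel")
-     return obj.numel()  # 24, now served by the precomputed lambda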
- def cmp(a, b) -> int:
- return int(a > b) - int(a < b)
- def pad_listlike(x, size):
- if len(x) == 1:
- return type(x)([x[0]]) * size
- else:
- return x
- # Used to ensure that iterating over a set is deterministic
- def tuple_sorted(x):
- if len(x) == 0:
- return []
- def sort_func(elem):
- if isinstance(elem, str):
- return elem
- else:
- # We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
- # but we are not able to do isinstance assert because of circular dependency
- return elem.get_name()
- return sorted(x, key=sort_func)
- P = ParamSpec("P")
- RV = TypeVar("RV", covariant=True)
- class CachedMethod(Protocol, Generic[P, RV]):
- @staticmethod
- def clear_cache(self) -> None:
- ...
- def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
- ...
- # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
- def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
- key = f"__{fn.__name__}_cache"
- @functools.wraps(fn)
- def wrapper(self):
- if not hasattr(self, key):
- setattr(self, key, fn(self))
- return getattr(self, key)
- def clear_cache(self):
- if hasattr(self, key):
- delattr(self, key)
- wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
- return wrapper # type: ignore[return-value]
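- # Illustrative example (editor's sketch): cache_on_self memoizes a zero-argument
- # method per instance and exposes clear_cache to drop the cached value. The
- # `_Node` class is hypothetical.
- def _example_cache_on_self():
-     class _Node:
-         calls = 0
-         @cache_on_self
-         def get_name(self):
-             _Node.calls += 1
-             return f"node{_Node.calls}"
-     n = _Node()
-     assert n.get_name() == n.get_name() == "node1"  # second call hits the cache
-     _Node.get_name.clear_cache(n)  # removes the per-instance cache attribute
-     return n.get_name()  # recomputed: "node2"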
- def aggregate_origins(node_schedule):
- from . import ir
- if isinstance(node_schedule, list):
- return functools.reduce(
- operator.or_,
- [
- node.node.origins
- for node in node_schedule
- if hasattr(node, "node") and node.node
- ],
- set(),
- )
- elif isinstance(node_schedule, ir.ExternKernel):
- return node_schedule.origins
- else:
- return set()
- def get_fused_kernel_name(node_schedule, descriptive_names):
- all_origins = aggregate_origins(node_schedule)
- if descriptive_names == "original_aten":
- # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
- sources = [
- origin.meta["original_aten"]._overloadpacket.__name__
- for origin in all_origins
- if origin.op == "call_function"
- and "original_aten" in origin.meta
- and origin.meta["original_aten"] is not None
- ]
- sources = sorted(set(sources))
- elif descriptive_names == "torch":
- # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
- sources = []
- for origin in all_origins:
- if origin.op == "call_function" and "source_fn_stack" in origin.meta:
- source_fn = origin.meta["source_fn_stack"][-1]
- if isinstance(source_fn[1], str):
- sources.append(source_fn[1])
- else:
- sources.append(source_fn[1].__name__)
- sources = sorted(set(sources))
- elif descriptive_names == "inductor_node":
- sources = [
- origin.name for origin in all_origins if origin.op == "call_function"
- ]
- else:
- raise NotImplementedError
- return "_".join(["fused"] + sources)
- def get_kernel_metadata(node_schedule, wrapper):
- all_origins = aggregate_origins(node_schedule)
- inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]
- from_node_dict = collections.defaultdict(list)
- original_aten_dict = collections.defaultdict(list)
- for node in inductor_nodes:
- if "original_aten" in node.meta and node.meta["original_aten"] is not None:
- key = str(node.meta["original_aten"]._overloadpacket)
- original_aten_dict[key].append(node.name)
- if "from_node" in node.meta:
- key = node.meta["from_node"][0][0]
- from_node_dict[key].append(node.name)
- metadata = (
- f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], "
- f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]"
- )
- # trace back to original node here
- detailed_metadata = []
- for original_node, nodes in sorted(from_node_dict.items()):
- detailed_metadata.append(
- f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
- )
- return metadata, "\n".join(detailed_metadata)
- def dominated_nodes(
- initial_queue: Iterable[torch.fx.Node], skip_filter=None
- ) -> Set[torch.fx.Node]:
- """Returns the set of nodes whose values depend on those within initial_queue"""
- initial_queue = list(initial_queue)
- dominated_set = set(initial_queue)
- while initial_queue:
- node = initial_queue.pop()
- for user in node.users:
- if skip_filter and skip_filter(user):
- continue
- if user not in dominated_set:
- dominated_set.add(user)
- initial_queue.append(user)
- return dominated_set
- def gather_origins(args, kwargs):
- from . import ir
- def is_unrealized_node(n):
- if isinstance(n, ir.TensorBox):
- return is_unrealized_node(n.data)
- if isinstance(n, ir.StorageBox):
- return is_unrealized_node(n.data)
- return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
- kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
- arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
- return set(itertools.chain(*arg_origins, *kwarg_origins))
- def sympy_str(expr: sympy.Expr) -> str:
- """
- Normal sympy str is very slow; this is a lot faster. The results are
- somewhat worse, as it doesn't do as much simplification, so don't
- use this for final codegen.
- """
- if isinstance(expr, sympy.Symbol):
- return expr.name
- if isinstance(expr, sympy.Add):
- return " + ".join(map(sympy_str, expr.args))
- if isinstance(expr, sympy.Mul):
- return " * ".join(map(sympy_str, expr.args))
- if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
- return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
- return str(expr)
- def get_bounds_index_expr(index):
- from .virtualized import V
- # If this expression does not come from an FX node, we compute its bounds
- if (
- config.compute_all_bounds
- and (fx_node := getattr(V.interpreter, "current_node", None))
- and fx_node.target != "index_expr"
- ):
- return bound_sympy(index)
- else:
- return ValueRanges.unknown()
- def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
- """
- Used to generate a nonnegative integer symbol.
- """
- # This should never be used for creating shape/stride symbols, as those
- # should all be allocated before Inductor.
- assert prefix != SymT.SIZE
- # NOTE: shape symbols are positive (> 0), but index variables are only
- # non-negative (>= 0).
- return make_symbol(prefix, idx, integer=True, nonnegative=True)
- def generate_assert(check):
- return (check or config.debug_index_asserts) and config.assert_indirect_indexing
- def sympy_index_symbol(name: str) -> sympy.Symbol:
- """
- Used to generate a nonnegative integer symbol.
- """
- # This should never be used for creating shape/stride symbols, as those
- # should all be allocated before Inductor.
- assert name[0] != "s"
- # NOTE: shape symbols are positive (> 0), but index variables are only
- # non-negative (>= 0).
- return sympy.Symbol(name, integer=True, nonnegative=True)
- def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
- """
- When a replacement value v is a string, it is converted to a symbol named v
- that carries the same integer and nonnegative assumptions as the expression
- it replaces.
- """
- def to_symbol(replaced, replacement):
- assert isinstance(replaced, sympy.Expr)
- if isinstance(replacement, str):
- return sympy.Symbol(
- replacement,
- integer=replaced.is_integer, # type: ignore[attr-defined]
- nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
- )
- else:
- return replacement
- # xreplace is faster than subs, but is way more picky
- return sympy.sympify(expr).xreplace(
- {k: to_symbol(k, v) for k, v in replacements.items()}
- )
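- # Illustrative example (editor's sketch): replacing a symbol by name keeps the
- # integer/nonnegative assumptions of the expression being replaced.
- def _example_sympy_subs():
-     x = sympy.Symbol("x", integer=True, nonnegative=True)
-     out = sympy_subs(x + 1, {x: "y"})
-     (y,) = out.free_symbols
-     assert y.is_integer and y.is_nonnegative
-     return out  # y + 1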
- def is_symbolic(a: Any) -> bool:
- return isinstance(a, torch.SymInt) or (
- isinstance(a, torch.Tensor)
- and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
- )
- def any_is_symbolic(*args: Any) -> bool:
- return any(is_symbolic(a) for a in args)
- def get_first_incompatible_cudagraph_node(gm):
- from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
- forbidden_set = {
- "aten._fused_moving_avg_obs_fq_helper.default",
- "aten._fused_moving_avg_obs_fq_helper_functional.default",
- "aten.multinomial.default",
- "fbgemm.dense_to_jagged.default",
- "fbgemm.jagged_to_padded_dense.default",
- "run_and_save_rng_state",
- "run_with_rng_state",
- "aten._local_scalar_dense",
- # Technically, it's not necessary to ban this, because an
- # assert_scalar with constant arguments can be validly run
- # with CUDA graphs, but the operator is also pointless with
- # constant arguments, so we might as well ban it.
- "aten._assert_scalar",
- }
- if torch.are_deterministic_algorithms_enabled():
- forbidden_set.update(
- {
- "aten._unsafe_index_put.default",
- "aten.index_put.default",
- "aten.index_put_.default",
- "aten.scatter.src",
- "aten.scatter.reduce",
- "aten.scatter.value_reduce",
- "aten.scatter_add_",
- "aten.scatter_add.default",
- "aten.scatter_reduce.two",
- "aten.scatter_reduce_.two",
- "aten.scatter_reduce.two_out",
- }
- )
- for node in gm.graph.nodes:
- if str(node.target) in forbidden_set:
- return node
- if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
- return node
- return None
- def has_incompatible_cudagraph_ops(gm):
- return get_first_incompatible_cudagraph_node(gm) is not None
- def output_node(gm: torch.fx.GraphModule):
- """Get the output node from an FX graph"""
- last_node = next(iter(reversed(gm.graph.nodes)))
- assert last_node.op == "output"
- return last_node
- _registered_caches: List[Any] = []
- def clear_on_fresh_inductor_cache(obj: Any):
- """
- Use this decorator to register any caches that should be cache_clear'd
- with fresh_inductor_cache().
- """
- if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
- raise AttributeError(f"{obj} does not have a cache_clear method")
- _registered_caches.append(obj)
- return obj
- def clear_inductor_caches():
- """
- Clear all registered caches.
- """
- for obj in _registered_caches:
- obj.cache_clear()
- @contextlib.contextmanager
- def fresh_inductor_cache(cache_entries=None):
- """
- Contextmanager that provides a clean tmp cachedir for inductor.
- Optionally, pass an empty dict as 'cache_entries' to receive a mapping of
- the filenames generated with this cache instance to their sizes.
- """
- clear_inductor_caches()
- inductor_cache_dir = tempfile.mkdtemp()
- try:
- with mock.patch.dict(
- os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
- ):
- triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
- with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
- yield
- if isinstance(cache_entries, dict):
- assert len(cache_entries) == 0, "expected empty cache_entries dict"
- if os.path.exists(triton_cache_dir):
- files = os.listdir(triton_cache_dir)
- cache_entries.update(
- {
- f: os.path.getsize(os.path.join(triton_cache_dir, f))
- for f in files
- if ".lock" not in f
- }
- )
- shutil.rmtree(inductor_cache_dir)
- except Exception:
- log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
- raise
- finally:
- clear_inductor_caches()
- def argsort(seq) -> List[int]:
- # preserve original order for equal strides
- getter = seq.__getitem__
- a_r = range(len(seq))
- return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
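- # Illustrative example (editor's sketch): argsort returns the indices that would
- # sort the sequence ascending, e.g. a stride tuple ordered from the smallest
- # stride outward.
- def _example_argsort():
-     assert argsort([8, 1, 4]) == [1, 2, 0]
-     return argsort([1, 4, 8])  # already ascending: [0, 1, 2]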
- @functools.lru_cache(8)
- def get_dtype_size(dtype):
- return torch.empty((), dtype=dtype).element_size()
- class LineContext(NamedTuple):
- context: Any
- class IndentedBuffer:
- tabwidth = 4
- def __init__(self, initial_indent=0):
- self._lines = []
- self._indent = initial_indent
- def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
- buf = StringIO()
- p = 1
- linemap = []
- for line in self._lines:
- if isinstance(line, DeferredLineBase):
- line = line()
- if line is None:
- continue
- elif isinstance(line, LineContext):
- linemap.append((p, line.context))
- continue
- assert isinstance(line, str)
- buf.write(line)
- buf.write("\n")
- p += 1 + line.count("\n")
- return buf.getvalue(), linemap
- def getvalue(self) -> str:
- v, _ = self.getvaluewithlinemap()
- return v
- def getrawvalue(self) -> str:
- buf = StringIO()
- for line in self._lines:
- if isinstance(line, DeferredLineBase):
- line = line()
- if line is None:
- continue
- elif isinstance(line, LineContext):
- continue
- assert isinstance(line, str)
- # backslash implies line continuation
- if line.endswith("\\"):
- buf.write(line[:-1])
- else:
- buf.write(line)
- buf.write("\n")
- return buf.getvalue()
- def clear(self):
- self._lines.clear()
- def __bool__(self):
- return bool(self._lines)
- def prefix(self):
- return " " * (self._indent * self.tabwidth)
- def newline(self):
- self.writeline("\n")
- def writeline(self, line):
- if isinstance(line, LineContext):
- self._lines.append(line)
- elif isinstance(line, DeferredLineBase):
- self._lines.append(line.with_prefix(self.prefix()))
- elif line.strip():
- self._lines.append(f"{self.prefix()}{line}")
- else:
- self._lines.append("")
- def writelines(self, lines):
- for line in lines:
- self.writeline(line)
- def indent(self, offset=1):
- @contextlib.contextmanager
- def ctx():
- self._indent += offset
- try:
- yield
- finally:
- self._indent -= offset
- return ctx()
- def do_indent(self, offset=1):
- self._indent += offset
- def do_unindent(self, offset=1):
- self._indent -= offset
- def splice(self, other_code, strip=False):
- if isinstance(other_code, IndentedBuffer):
- dedent = float("inf")
- for line in other_code._lines:
- if not isinstance(line, LineContext) and line:
- dedent = min(dedent, len(line) - len(line.lstrip()))
- if math.isinf(dedent):
- dedent = 0
- for line in other_code._lines:
- if isinstance(line, LineContext):
- self._lines.append(line)
- else:
- IndentedBuffer.writeline(self, line[int(dedent) :])
- else:
- other_code = textwrap.dedent(other_code)
- if strip:
- other_code = other_code.lstrip()
- if not other_code:
- return
- other_code = other_code.rstrip()
- for line in other_code.split("\n"):
- self.writeline(line)
- def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
- res = IndentedBuffer(initial_indent=self._indent)
- res._lines = [func(line) for line in self._lines]
- return res
- def __repr__(self):
- return f"{type(self)}({self.getvalue()})"
- def __add__(self, other):
- assert self._indent == other._indent
- res = IndentedBuffer(initial_indent=self._indent)
- res.writelines(self._lines)
- res.writelines(other._lines)
- return res
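- # Illustrative usage (editor's sketch): emitting a small block of generated
- # source with IndentedBuffer; indent() is a context manager.
- def _example_indented_buffer():
-     buf = IndentedBuffer()
-     buf.writeline("def kernel():")
-     with buf.indent():
-         buf.writeline("return 1")
-     return buf.getvalue()  # "def kernel():\n    return 1\n"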
- class FakeIndentedBuffer(IndentedBuffer):
- def __init__(self):
- super().__init__()
- def __getattribute__(self, name):
- if name == "__class__": # Allow access to the class attribute
- return object.__getattribute__(self, name)
- raise RuntimeError(
- f"Tried to call self.{name} on FakeIndentedBuffer. This buffer"
- "is currently used on TritonTemplateKernel to prevent actual"
- "writes to the body without explicitly specifying the body with"
- "`TritonTemplateKernel.set_subgraph_body(name)`"
- )
- @contextlib.contextmanager
- def restore_stdout_stderr(initial_stdout, initial_stderr):
- try:
- yield
- finally:
- sys.stdout = initial_stdout
- sys.stderr = initial_stderr
- class DeferredLineBase:
- """A line that can be 'unwritten' at a later time"""
- def __init__(self, line):
- if not line.strip():
- line = ""
- self.line = line
- def __call__(self) -> Optional[str]:
- """Returns either self.line or None to indicate the line has been 'unwritten'"""
- raise NotImplementedError
- def _new_line(self, line: str) -> DeferredLineBase:
- """Returns a new deferred line with the same condition"""
- raise NotImplementedError
- def with_prefix(self, prefix):
- return self._new_line(f"{prefix}{self.line}")
- def lstrip(self):
- return self._new_line(self.line.lstrip())
- def __getitem__(self, index):
- return self._new_line(self.line[index])
- def __bool__(self):
- return bool(self.line)
- def __len__(self):
- return len(self.line)
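- # Illustrative example (editor's sketch): a minimal DeferredLineBase subclass that
- # only materializes its line while a caller-supplied predicate holds. The class
- # name and the predicate are hypothetical.
- class _ExampleDeferredLine(DeferredLineBase):
-     def __init__(self, line, predicate):
-         super().__init__(line)
-         self.predicate = predicate
-     def __call__(self) -> Optional[str]:
-         return self.line if self.predicate() else None
-     def _new_line(self, line: str) -> DeferredLineBase:
-         return _ExampleDeferredLine(line, self.predicate)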
- @functools.lru_cache(None)
- def is_big_gpu(index) -> bool:
- min_sms = 68 # 3080
- avail_sms = torch.cuda.get_device_properties(index).multi_processor_count
- if avail_sms < min_sms:
- log.warning(
- "Not enough SMs to use max_autotune_gemm mode",
- extra={"min_sms": min_sms, "avail_sms": avail_sms},
- )
- return False
- return True
- def use_max_autotune() -> bool:
- return (
- config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
- )
- def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
- return (
- use_max_autotune()
- and layout.device.type == "cuda"
- and layout.dtype in allowed_layout_dtypes
- and is_big_gpu(layout.device.index or 0)
- )
- def _use_autotune_backend(backend: str) -> bool:
- return backend.upper() in [
- x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
- ]
- def use_triton_template(layout, *, enable_int32=False):
- layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
- if enable_int32:
- layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
- return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
- "TRITON"
- )
- def use_cutlass_template(layout, m, n, k):
- from .virtualized import V
- gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
- if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
- return False
- from .codegen.cuda.cutlass_utils import try_import_cutlass
- # Do not use cutlass template on ROCm
- if torch.version.hip:
- return False
- layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
- res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
- "CUTLASS"
- )
- if res:
- if not try_import_cutlass():
- log.warning(
- "Failed to import CUTLASS lib. Please check whether "
- "_inductor.config.cuda.cutlass_dir is set correctly. "
- "Skipping CUTLASS backend for now."
- )
- return False
- return res
- def _use_template_for_cpu(layout):
- return use_max_autotune() and layout.device.type == "cpu"
- def use_cpp_packed_gemm_template(layout, mat1, mat2):
- from . import ir
- from .codegen.cpp_micro_gemm import create_micro_gemm
- from .kernel.mm_common import mm_args
- if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
- return False
- if not config.cpp.weight_prepack:
- return False
- layout_dtypes = [torch.float32]
- m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2)
- # TODO(jgong5): support dynamic shapes for n or k
- if has_free_symbols((n, k)):
- return False
- if isinstance(mat2, ir.BaseView):
- mat2 = mat2.unwrap_view()
- micro_gemm = create_micro_gemm(
- "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads()
- )
- # TODO(jgong5): support n % n_block_size != 0
- return (
- layout.dtype in layout_dtypes
- and micro_gemm is not None
- and n % micro_gemm.register_blocking[1] == 0
- and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input
- and isinstance(mat2, ir.StorageBox)
- and mat2.is_module_buffer()
- )
- def use_aten_gemm_kernels():
- return not use_max_autotune() or _use_autotune_backend("ATEN")
- class DebugDirManager:
- counter = itertools.count(0)
- prev_debug_name: str
- def __init__(self):
- self.id = next(DebugDirManager.counter)
- def __enter__(self):
- self.prev_debug_name = torch._dynamo.config.debug_dir_root
- self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
- torch._dynamo.config.debug_dir_root = self.new_name
- def __exit__(self, *args):
- shutil.rmtree(self.new_name)
- torch._dynamo.config.debug_dir_root = self.prev_debug_name
- def run_and_get_code(fn, *args, **kwargs):
- from .graph import GraphLowering
- compile_to_module = GraphLowering.compile_to_module
- source_codes: List[str] = []
- def patched_compile_to_module(self):
- mod = compile_to_module(self)
- with open(mod.__file__) as f:
- source_codes.append(f.read())
- return mod
- # If FX code caching is enabled, a hit prevents getting the code.
- with config.patch({"fx_graph_cache": False}):
- with mock.patch.object(
- GraphLowering, "compile_to_module", patched_compile_to_module
- ):
- torch._dynamo.reset()
- result = fn(*args, **kwargs)
- return result, source_codes
- def get_code(fn, *args, **kwargs):
- """Get the inductor-generated code, but skip any actual compilation or running."""
- from .graph import GraphLowering
- source_codes: List[str] = []
- def patched_compile_to_module(self: GraphLowering):
- class DummyModule:
- """This is empty to replace the generated triton module"""
- def __init__(self):
- pass
- def call(self, *args, **kwargs):
- # Don't do anything when called
- pass
- code, _ = (
- self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
- )
- # Skip all the actual compiling.
- source_codes.append(code)
- return DummyModule()
- # If FX code caching is enabled, a hit prevents getting the code.
- with config.patch({"fx_graph_cache": False}):
- with mock.patch.object(
- GraphLowering, "compile_to_module", patched_compile_to_module
- ):
- torch._dynamo.reset()
- # Note the return here is None
- _ = fn(*args, **kwargs)
- return source_codes
- def get_triton_code(fn, *args, **kwargs):
- source_codes = get_code(fn, *args, **kwargs)
- # Can have two outputs if backwards was eagerly compiled
- assert (
- 1 <= len(source_codes) <= 2
- ), f"expected one or two code outputs got {len(source_codes)}"
- return source_codes[0]
- def run_and_get_triton_code(fn, *args, **kwargs):
- _, source_codes = run_and_get_code(fn, *args, **kwargs)
- # Can have two outputs if backwards was eagerly compiled
- assert (
- 1 <= len(source_codes) <= 2
- ), f"expected one or two code outputs got {len(source_codes)}"
- return source_codes[0]
- @contextlib.contextmanager
- def override_lowering(aten_op, override_fn):
- """
- Override the lowering of aten_op with override_fn.
- The first argument of override_fn is the original lowering fn.
- """
- from torch._inductor import lowering
- orig_fn = lowering.lowerings[aten_op]
- try:
- lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
- yield
- finally:
- lowering.lowerings[aten_op] = orig_fn
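- # Illustrative usage (editor's sketch): wrap an existing lowering and delegate to
- # it, e.g. to count how often it fires during compilation. Assumes aten.add.Tensor
- # has a registered lowering; the names below are hypothetical.
- def _example_override_lowering():
-     calls = []
-     def counting_lowering(orig_lowering, *args, **kwargs):
-         calls.append(args)  # observe the call, then fall through to the original
-         return orig_lowering(*args, **kwargs)
-     with override_lowering(torch.ops.aten.add.Tensor, counting_lowering):
-         pass  # run torch.compile'd code here to see `calls` populate
-     return calls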
- def add_scheduler_init_hook(pre_fn, post_fn=None):
- """
- Add hook functions to be called at the beginning and end of Scheduler.__init__.
- Used for unit tests.
- """
- from torch._inductor.scheduler import Scheduler
- orig_fn = Scheduler.__init__
- def wrapper(scheduler, nodes):
- pre_fn(scheduler, nodes)
- out = orig_fn(scheduler, nodes)
- if post_fn:
- post_fn(scheduler, nodes)
- return out
- return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
- def developer_warning(msg):
- """
- Warnings that will be actionable for PyTorch developers, but not
- end users. Allows us to easily disable them in stable releases but
- keep them on for nightly builds.
- """
- if config.developer_warnings:
- log.warning(msg)
- else:
- log.info(msg)
- def get_benchmark_name():
- """
- An experimental API used only when config.benchmark_kernel is true.
- The benchmark name is only available at codegen time, so we cannot
- retrieve it directly in benchmark_all_kernels, which runs after codegen.
- The function assumes the argument after --only is the benchmark name.
- It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
- scripts, this function may return None.
- There are 2 flavors of the --only argument we need to handle:
- 1. --only model_name
- 2. --only=model_name
- """
- try:
- idx = sys.argv.index("--only")
- if (
- idx + 1 < len(sys.argv)
- and len(sys.argv[idx + 1]) > 0
- and sys.argv[idx + 1][0] != "-"
- ):
- return sys.argv[idx + 1]
- except ValueError:
- pass
- for arg in sys.argv:
- if arg.startswith("--only="):
- return arg[len("--only=") :]
- def is_ones(items):
- return all(x == 1 for x in items)
- def is_zeros(items):
- return all(x == 0 for x in items)
- def is_cpu_device(inputs):
- return all(
- item.device == torch.device("cpu")
- for item in inputs
- if isinstance(item, torch.Tensor)
- )
- def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
- assert isinstance(
- val, sympy.Expr
- ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
- if val.is_integer: # type: ignore[attr-defined]
- return torch.int64
- else:
- return torch.float64
- @contextlib.contextmanager
- def maybe_profile(should_profile, *args, **kwargs):
- if should_profile:
- with torch.profiler.profile(*args, **kwargs) as p:
- yield p
- else:
- yield
- def parallel_num_threads():
- threads = config.cpp.threads
- if threads < 1:
- threads = torch.get_num_threads()
- return threads
- @functools.lru_cache(None)
- def get_device_tflops(dtype):
- from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
- assert dtype in (torch.float16, torch.bfloat16, torch.float32)
- if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
- # Triton API change in https://github.com/openai/triton/pull/2293
- from torch._utils_internal import max_clock_rate
- sm_clock = max_clock_rate()
- if dtype in (torch.float16, torch.bfloat16):
- return get_max_tensorcore_tflops(dtype, sm_clock)
- if torch.backends.cuda.matmul.allow_tf32:
- return get_max_tensorcore_tflops(torch.float32, sm_clock)
- else:
- return get_max_simd_tflops(torch.float32, sm_clock)
- else:
- if dtype in (torch.float16, torch.bfloat16):
- return get_max_tensorcore_tflops(dtype)
- if torch.backends.cuda.matmul.allow_tf32:
- return get_max_tensorcore_tflops(torch.float32)
- else:
- return get_max_simd_tflops(torch.float32)
- @functools.lru_cache(None)
- def get_gpu_dram_gbps():
- from triton.testing import get_dram_gbps
- return get_dram_gbps()
- def get_gpu_shared_memory():
- from triton.runtime import driver
- return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
- def is_welford_reduction(reduction_type):
- return reduction_type.startswith("welford")
- def reduction_num_outputs(reduction_type):
- return 3 if is_welford_reduction(reduction_type) else 1
- def is_linux() -> bool:
- return platform.system() == "Linux"
- def has_free_symbols(itr: Iterable[Any]):
- return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
- def is_dynamic(*args):
- from . import ir
- for t in args:
- if isinstance(t, ir.TensorBox):
- if has_free_symbols(t.data.get_size()) or (
- hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
- ):
- return True
- elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
- assert hasattr(t, "get_size") and hasattr(t, "get_stride")
- if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
- return True
- elif not isinstance(t, ir.IRNode):
- continue
- else:
- raise TypeError(f"unexpected type for is_dynamic {type(t)}")
- return False
- # Placeholder strings used in triton codegen.
- class Placeholder(enum.Enum):
- # The placeholder for the actual name of a triton kernel.
- # e.g. for "def triton_" it would be "triton_"
- KERNEL_NAME = "KERNEL_NAME"
- # The descriptive name of the triton kernel; when unique_kernel_names = False, this
- # placeholder will be replaced with a string with more information.
- DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
- def pass_execution_and_save(func, gm, inp, msg):
- from .pattern_matcher import stable_topological_sort
- with tempfile.NamedTemporaryFile(
- mode="w",
- encoding="utf-8",
- delete=False,
- ) as f:
- before_io = io.StringIO()
- after_io = io.StringIO()
- ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
- print(f"Before:\n{gm.graph}", file=f)
- print(gm.graph, file=before_io)
- start_time = datetime.now()
- func(gm.graph)
- time_elapsed = datetime.now() - start_time
- # recompile graph
- stable_topological_sort(gm.graph)
- gm.graph.lint()
- gm.recompile()
- print(f"After:\n{gm.graph}", file=f)
- print(gm.graph, file=after_io)
- t = before_io.getvalue() == after_io.getvalue()
- log.info(
- "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
- msg,
- f.name,
- t,
- time_elapsed,
- )
- def is_collective(node):
- from . import ir
- return type(node) == ir._CollectiveKernel
- def is_wait(node):
- from . import ir
- return type(node) == ir._WaitKernel
- def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
- "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
- num_rng_seed_offset_inputs = (
- 2 if torch._functorch.config.functionalize_rng_ops else 0
- )
- return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
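- # Worked example (editor's addition): if dynamo traced 3 user inputs and the AOT
- # forward graph has 10 placeholders with functionalize_rng_ops disabled, then
- # 10 - 3 - 0 == 7 leading inputs are fixed-address params/buffers.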
- def count_tangents(fx_g: torch.fx.GraphModule):
- """
- Infers which inputs are static for a backwards graph
- """
- def is_saved_tensor(x):
- return (
- "tangents" not in x.name
- and "bwd_seed" not in x.name
- and "bwd_base_offset" not in x.name
- )
- arg_count = 0
- static_arg_idxs = []
- for n in fx_g.graph.nodes:
- if n.op == "placeholder":
- if is_saved_tensor(n):
- static_arg_idxs.append(arg_count)
- arg_count += 1
- assert static_arg_idxs == list(range(len(static_arg_idxs)))
- return len(static_arg_idxs)
- @dataclasses.dataclass
- class BoxedBool:
- value: bool
- def __bool__(self):
- return self.value
- @staticmethod
- def disable(obj):
- if isinstance(obj, BoxedBool):
- obj.value = False
- return obj
- return False
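- # Illustrative example (editor's sketch): BoxedBool lets a callee flip a shared
- # flag in place so every holder of the box observes the change. The variable
- # name is hypothetical.
- def _example_boxed_bool():
-     cudagraphs_enabled = BoxedBool(True)
-     BoxedBool.disable(cudagraphs_enabled)
-     assert bool(cudagraphs_enabled) is False
-     return cudagraphs_enabled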
- @contextlib.contextmanager
- def collect_defined_kernels(kernel_list):
- from .codegen.wrapper import WrapperCodeGen
- orig_define_kernel = WrapperCodeGen.define_kernel
- def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
- nonlocal kernel_list
- kernel_list.append(kernel_code)
- return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)
- with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
- yield
- def get_cloned_parameter_buffer_name(name: str):
- return name + "__original__"
- def is_gpu(device: str):
- return device in ["cuda", "xpu"]
- def device_need_guard(device: str):
- assert isinstance(device, str)
- return is_gpu(device)
- def needs_fallback_due_to_atomic_add_limitations(dtype):
- # tl.atomic_add does NOT support the following types
- return dtype in {torch.int64, torch.bool, torch.bfloat16}
- def use_scatter_fallback(
- op_overload: torch._ops.OpOverload,
- reduction_type,
- self_dtype,
- src_dtype,
- src_device_type,
- src_is_tensor,
- ):
- reduce_ty = (
- "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
- )
- return (
- reduction_type not in {None, reduce_ty}
- or (
- src_is_tensor
- and is_gpu(src_device_type)
- and needs_fallback_due_to_atomic_add_limitations(src_dtype)
- )
- or (
- op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
- and reduction_type == "sum"
- and src_is_tensor
- and src_device_type == "cpu"
- and config.cpp.fallback_scatter_reduce_sum
- and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
- )
- or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
- or torch.are_deterministic_algorithms_enabled()
- )
- def dump_node_schedule(node_schedule):
- """
- An API that can be used in pdb to dump a node_schedule.
- Right now it mainly dumps the read/write dependencies, but more can be added as needed.
- """
- from torch._inductor.codegen.simd import DisableReduction, EnableReduction
- from torch._inductor.scheduler import SchedulerNode
- print(f"Node schedule with {len(node_schedule)} nodes")
- for idx, node in enumerate(node_schedule):
- print(f" {idx:3}:")
- if node is EnableReduction:
- print("enable reduction")
- elif node is DisableReduction:
- print("disable reduction")
- elif isinstance(node, SchedulerNode):
- is_red = node.is_reduction()
- print(f"{'red' if is_red else 'pw'} scheduler node")
- if is_red:
- assert node.node is not None
- print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
- print("ReadDep:")
- for dep in node.read_writes.reads:
- print(dep)
- print("WriteDep:")
- for dep in node.read_writes.writes:
- print(dep)
- else:
- raise RuntimeError(f"Unrecognized node type: {type(node)}")
- def tensor_is_aligned(tensor: torch.Tensor):
- # See Note: [Input Alignment handling in Inductor]
- # Right now, we don't try to guard on the alignment of the storage offset.
- # When this comment was written, non-symbolic storage_offsets are not guarded on
- # but symbolic storage_offsets are. For consistency, we suppress guard creation
- # upon performing this check: that ensures that we don't add recompiles when we
- # add this logic.
- return (
- tensor.storage_offset() * get_dtype_size(tensor.dtype)
- ) % GPU_ALIGN_BYTES == 0
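- # Worked example (editor's addition): a float32 tensor with storage_offset 4 has a
- # 4 * 4 == 16 byte offset, which satisfies GPU_ALIGN_BYTES == 16; a storage_offset
- # of 2 (8 bytes) does not.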
- def should_assume_input_aligned(example_input: torch.Tensor):
- # See Note: [Input Alignment handling in Inductor]
- # right now, we only care about alignment for cuda tensors.
- if not is_gpu(example_input.device.type):
- return False
- return config.assume_aligned_inputs or tensor_is_aligned(example_input)
- def maybe_get_suppress_shape_guards_ctx():
- # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
- # If it's not available, return a nullcontext.
- # If we're dealing with cudagraphs, we might not have a tracing_context
- tracing_context = torch._guards.TracingContext.try_get()
- if not tracing_context:
- return contextlib.nullcontext()
- # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
- shape_env = tracing_context.fake_mode.shape_env
- if not shape_env:
- return contextlib.nullcontext()
- return shape_env.suppress_guards()
- def aoti_eager_cache_dir(namespace: str, device: str):
- return Path(cache_dir()) / "aoti_eager" / namespace / device
- def aoti_eager_op_conf_lock(op_func_name_with_overload: str):
- from filelock import FileLock
- # Avoid circular import
- from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
- op_conf_lock_file = f"{op_func_name_with_overload}.lock"
- lock_dir = get_lock_dir()
- return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)
- def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str):
- device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
- op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
- if not op_conf.exists():
- return []
- with aoti_eager_op_conf_lock(op_func_name_with_overload):
- with open(op_conf) as f:
- json_data = json.load(f)
- for item in json_data:
- # Get the absolute path of the kernel library
- kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
- item["kernel_path"] = kernel_lib_abs_path.as_posix()
- # Check if the kernel library exists
- if not kernel_lib_abs_path.exists():
- return []
- for metadata in item["meta_info"]:
- assert not metadata[
- "is_dynamic"
- ], "Only support static shape for now"
- if metadata["device_type"] == "cpu":
- metadata["device_index"] = -1
- metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])
- return json_data
- def aoti_compile_with_persistent_cache(
- ns: str,
- op_func_name_with_overload: str,
- device_type: str,
- dynamic: bool,
- f: Callable[..., Any],
- args: Tuple[Any],
- kwargs: Dict[str, Any],
- *,
- dynamic_shapes: Optional[Dict[str, Any]] = None,
- options: Optional[Dict[str, Any]] = None,
- remove_runtime_assertions: bool = False,
- disable_constraint_solver: bool = False,
- ):
- """
- Compile the given function with persistent cache for AOTI eager mode.
- """
- assert not dynamic, "Only support static shape for now"
- type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
- supported_scalar_types = tuple(type_to_torch_dtype.keys())
- flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
- if not all(
- isinstance(input, (supported_scalar_types, torch.Tensor))
- for input in flattened_inputs
- ):
- raise NotImplementedError("Only support tensor, int, float, bool for now")
- persistent_cache = aoti_eager_cache_dir(ns, device_type)
- if not persistent_cache.exists():
- persistent_cache.mkdir(parents=True)
- persistent_cache_lib = persistent_cache / "lib"
- if not persistent_cache_lib.exists():
- persistent_cache_lib.mkdir()
- with mock.patch.dict(
- os.environ,
- {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
- ):
- try:
- kernel_lib_path = torch._export.aot_compile(
- f,
- args,
- kwargs,
- dynamic_shapes=dynamic_shapes,
- options=options,
- remove_runtime_assertions=remove_runtime_assertions,
- disable_constraint_solver=disable_constraint_solver,
- # Some operations may have non-Tensor parameters like int, float, bool. These
- # non-Tensor parameters will not be inputs of the graph. Therefore, we
- # cannot keep the same signature.
- same_signature=False,
- )
- kernel_metadata_items = []
- for input in flattened_inputs:
- # TODO(Eikan): To add dynamic support
- metadata: Dict[str, Any] = {}
- metadata["is_dynamic"] = dynamic
- if isinstance(input, torch.Tensor):
- metadata["device_type"] = f"{input.device.type}"
- if is_cpu_device([input]):
- metadata["device_index"] = -1
- else:
- metadata["device_index"] = input.device.index
- metadata["dtype"] = f"{input.dtype}"
- metadata["sizes"] = list(input.size())
- metadata["strides"] = list(input.stride())
- else:
- assert isinstance(input, supported_scalar_types)
- # Scalar tensor
- metadata["device_type"] = device_type
- metadata["device_index"] = -1 if device_type == "cpu" else 0
- metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
- metadata["sizes"] = []
- metadata["strides"] = []
- metadata["scalar_value"] = input
- kernel_metadata_items.append(metadata)
- kernel_meta_info: Dict[str, Any] = {}
- kernel_meta_info["meta_info"] = kernel_metadata_items
- kernel_meta_info["kernel_path"] = (
- Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
- )
- json_data = []
- update_json = True
- op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
- mode = "r" if op_conf.exists() else "w"
- with aoti_eager_op_conf_lock(op_func_name_with_overload):
- with open(op_conf, mode) as op_conf_file:
- try:
- json_data = json.load(op_conf_file)
- except Exception:
- json_data = []
- assert isinstance(json_data, list)
- for item in json_data:
- assert isinstance(item, dict)
- # Same kernel meta info already exists in the json file
- if item["meta_info"] == kernel_metadata_items:
- update_json = False
- break
- if update_json:
- json_data.append(kernel_meta_info)
- with open(op_conf, "w") as op_conf_file:
- json.dump(json_data, op_conf_file, indent=4)
- return kernel_lib_path
- except Exception:
- return ""
- def run_and_get_cpp_code(fn, *args, **kwargs):
- # We use the patch context manager instead of using it as a decorator.
- # In this way, we can ensure that the attribute is patched and unpatched correctly
- # even if this run_and_get_cpp_code function is called multiple times.
- with unittest.mock.patch.object(config, "debug", True):
- torch._dynamo.reset()
- log_capture_string = io.StringIO()
- ch = logging.StreamHandler(log_capture_string)
- from torch._inductor.graph import output_code_log
- output_code_log.addHandler(ch)
- prev_level = output_code_log.level
- output_code_log.setLevel(logging.DEBUG)
- result = fn(*args, **kwargs)
- s = log_capture_string.getvalue()
- output_code_log.setLevel(prev_level)
- output_code_log.removeHandler(ch)
- return result, s
|