| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993 |
- # mypy: allow-untyped-defs
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import math
- import operator
- import re
- from itertools import chain
- from typing import (
- Any,
- Callable,
- ClassVar,
- Dict,
- List,
- NamedTuple,
- Optional,
- Set,
- Tuple,
- Union,
- )
- import sympy
- from sympy.printing.printer import Printer
- import torch
- import torch.fx
- from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
- from torch.utils import _pytree as pytree
- from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
- from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
- from .. import config, metrics
- from ..utils import (
- DeferredLineBase,
- generate_assert,
- IndentedBuffer,
- sympy_dot,
- sympy_subs,
- unique,
- )
- from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
- schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
- def data_type_logger(msg):
- if schedule_log.isEnabledFor(logging.DEBUG):
- schedule_log.debug("Data type propagation: %s", msg)
@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Size of the scratch buffer in bytes (may be symbolic).
    nbytes: sympy.Expr
    # Whether the buffer must be zero-initialized before the kernel runs.
    zero_fill: bool
@dataclasses.dataclass
class TensorArg:
    """A tensor argument passed to a generated kernel."""

    # Name of the kernel parameter.
    name: str
    # Name of the backing buffer in the graph.
    buffer: str
    dtype: torch.dtype
    # Element offset into the buffer where this argument's data starts.
    offset: sympy.Expr = sympy.Integer(0)
@dataclasses.dataclass
class SizeArg:
    """A symbolic size argument passed to a generated kernel."""

    name: str
    expr: sympy.Expr
@dataclasses.dataclass
class DeviceCodegen:
    """Bundle of codegen classes registered for one device type."""

    # Backend-specific Scheduling used for kernel code generation.
    scheduling: Any
    # WrapperCodeGen subclass that produces the Python wrapper code.
    wrapper_codegen: type
    # Wrapper codegen for C++ wrapper mode; type(None) means "not provided".
    cpp_wrapper_codegen: type = type(None)
# Union of all argument kinds a generated kernel can accept.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Registry mapping device type (e.g. "cuda") to its codegen classes;
# populated via register_backend_for_device.
device_codegens: Dict[str, DeviceCodegen] = {}
class DeviceOpOverrides:
    """Interface for device-specific runtime helper codegen.

    Backends register a concrete subclass via register_device_op_overrides;
    every method here must be overridden by the subclass.
    """

    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError
# Registry mapping device type (e.g. "cuda") to its DeviceOpOverrides;
# populated via register_device_op_overrides.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
- # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
- # For any new backend looking to integrate with Inductor, customization of these two main
- # parts is necessary to generate its specific code.
- #
- # Kernel code generation is determined by different Scheduling. Consequently, a new
- # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
- # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
- #
- # For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
- # that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
- # and override specific member functions to create backend-specific Python wrapper code.
- #
- # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
- # of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
- # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
- # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
- # register_backend_for_device, to equip a new backend at runtime.
- #
- # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
- # This backend can be used as a reference:
- # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register the codegen classes (scheduling + wrapper codegens) for a device type."""
    entry = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )
    device_codegens[device] = entry
def get_scheduling_for_device(device: str):
    """Return the registered Scheduling for ``device``, or None if unregistered."""
    entry = device_codegens.get(device)
    if entry is None:
        return None
    return entry.scheduling
def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for ``device``.

    Picks the C++ wrapper variant when ``cpp_wrapper`` is True; returns
    None when no backend is registered for ``device``.
    """
    if device not in device_codegens:
        return None
    entry = device_codegens[device]
    if cpp_wrapper:
        return entry.cpp_wrapper_codegen
    return entry.wrapper_codegen
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` with an extra contiguous index expression appended.

    The added contiguous index prevents reordering of the iteration space.
    """
    from ..ir import FlexibleLayout

    strides = FlexibleLayout.contiguous_strides(sizes)
    return list(index) + [sympy_dot(index_vars, strides)]
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Register the DeviceOpOverrides implementation for a device type."""
    device_op_overrides_dict.update({device: device_op_overrides})
def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for ``device``.

    Lazily imports the in-tree backends' override modules the first time
    this is called (importing them registers their overrides as a side
    effect). Returns None when no overrides exist for ``device``.
    """
    assert isinstance(device, str)

    # `.keys()` truthiness/membership was unidiomatic; test the dict directly.
    if not device_op_overrides_dict:
        # Importing these modules populates device_op_overrides_dict.
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device in device_op_overrides_dict:
        return device_op_overrides_dict[device]

    # Make the historical fall-through explicit rather than implicit.
    return None
@functools.lru_cache(None)
def boolean_ops():
    """Names of ops whose results are always boolean-typed."""
    names = "is_inf is_nan bitwise_xor logical_not signbit le lt ge gt eq ne"
    return tuple(names.split())
- DTYPE_TO_COMPUTATION_DTYPE = {
- torch.bfloat16: torch.float,
- torch.float16: torch.float,
- **{
- dtype: dtype
- for dtype in [
- torch.bool,
- torch.float32,
- torch.float64,
- torch.int8,
- torch.int16,
- torch.int32,
- torch.int64,
- torch.uint8,
- torch.uint16,
- torch.uint32,
- torch.uint64,
- ]
- },
- }
class DataTypePropagation:
    """Propagate dtypes through a loop body's FX graphs.

    Each FX node gets an OptimizationContext in its ``meta`` whose ``dtype``
    is deduced from the node's op, target string, or (promoted) input dtypes.
    """

    def __init__(self, body) -> None:
        # body: a LoopBody (see propagate_scheduler_node); its root block and
        # each masked subblock carry an FX graph that we propagate over.
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Deduce a node's dtype by promoting across its input nodes' dtypes.

        Returns None when there are no non-placeholder inputs or when any
        input has not been propagated yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        # Pairwise type promotion across all input dtypes.
        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Deduce a masked-subblock node's dtype by propagating its subgraph."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Deduce the dtype of one FX node; returns None when unknown."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer the output node only if it has exactly 1 arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # dtype is passed explicitly as the last argument
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        if node.target in (
            "get_index",
            "index_expr",
            "randint64",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            # args[1] is the buffer name; look its dtype up in the graph
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        # Fall back to promotion over the node's input dtypes.
        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate every node in ``graph``; return the output node's dtype.

        For masked_subblock, we use output's dtype to represent
        the dtype of this subgraph. For other cases, graph_dtype
        might be None.
        """
        assert graph.nodes
        graph_dtype = None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Propagate dtypes over the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Convenience constructor: propagate over a LoopBody."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Propagate dtypes for a SchedulerNode's loop body."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)
- # This printer contains rules that are supposed to be generic for both C/C++ and
- # Python
class ExprPrinter(Printer):
    """Sympy expression printer with rules generic to both C/C++ and Python.

    Language-specific subclasses implement the operations that differ
    between targets (the NotImplementedError stubs below).
    """

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is atomic or already wrapped."""

        def all_in_parens(string):
            # True iff the string is a single balanced paren group spanning
            # the whole string, e.g. "(a + b)" but not "(a) + (b)".
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        # CleanDiv prints the same as FloorDiv (the division is known exact).
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan: >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this pow by natural, you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  These
    # means exp is guaranteed to be integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        """Print ``expr``, first simplifying via the graph's sizevars when possible."""
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as executable Python (using the ``math`` module)."""

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        # ModularIndexing(x, div, mod) -> (x // div) % mod, eliding the
        # division when div == 1.
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
    """Default op implementations, defined by composing other ops.

    Any op not defined here is delegated to the wrapped ``parent`` handler
    via ``__getattr__``, so backends only override what they customize.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Fall through to the parent ops handler for anything not overridden.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        # erfc(x) = 1 - erf(x)
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        # erfcx(x) = exp(x^2) * erfc(x)
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        # log10(x) = log(x) / log(10), with the reciprocal folded to a constant
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        # 2**x = exp(x * ln 2)
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Python-style remainder: result takes the sign of the divisor, so
        # adjust the C-style mod when the signs of r and b disagree.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the ``target`` backend's pointwise special-function impls on cls."""
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
class OverridesData:
    """Per-backend implementations of one pointwise special function."""

    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
- # NB: if you add a new special function, don't forget to update
- # torch._inductor.ops_handler too
- pointwise_overrides_data: Dict[str, OverridesData] = dict(
- airy_ai=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"airy_ai_forward({x})",
- name="special_airy_ai",
- ),
- bessel_j0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"bessel_j0_forward({x})",
- triton=lambda x: f"libdevice.j0({x})",
- name="special_bessel_j0",
- ),
- bessel_j1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"bessel_j1_forward({x})",
- triton=lambda x: f"libdevice.j1({x})",
- name="special_bessel_j1",
- ),
- bessel_y0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"bessel_y0_forward({x})",
- triton=lambda x: f"libdevice.y0({x})",
- name="special_bessel_y0",
- ),
- bessel_y1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"bessel_y1_forward({x})",
- triton=lambda x: f"libdevice.y1({x})",
- name="special_bessel_y1",
- ),
- digamma=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_digamma({x})",
- cppvec=lambda x: f"{x}.digamma()",
- name="digamma",
- ),
- # no cpp nor triton implementation for entr, it is defined as decomposition
- # erf, erfc
- erfcx=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_erfcx({x})",
- triton=lambda x: f"libdevice.erfcx({x})",
- name="special_erfcx",
- ),
- fma=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
- cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
- triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
- name="fma",
- ),
- # erfinv, exp2, expit, gammaln
- igamma=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"calc_igamma({x}, {y})",
- name="igamma",
- ),
- igammac=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"calc_igammac({x}, {y})",
- name="igammac",
- ),
- gammainc=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"calc_igamma({x}, {y})",
- name="special_gammainc",
- ),
- gammaincc=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"calc_igammac({x}, {y})",
- name="special_gammaincc",
- ),
- i0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_i0({x})",
- triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
- cppvec=lambda x: f"{x}.i0()",
- name="i0",
- ),
- i0e=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_i0e({x})",
- cppvec=lambda x: f"{x}.i0e()",
- name="special_i0e",
- ),
- i1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_i1({x})",
- triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
- name="special_i1",
- ),
- i1e=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_i1e({x})",
- name="special_i1e",
- ),
- log_ndtr=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_log_ndtr({x})",
- name="special_log_ndtr",
- ),
- # logit
- modified_bessel_i0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"modified_bessel_i0_forward({x})",
- triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
- name="special_modified_bessel_i0",
- ),
- modified_bessel_i1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"modified_bessel_i1_forward({x})",
- triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
- name="special_modified_bessel_i1",
- ),
- modified_bessel_k0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"modified_bessel_k0_forward({x})",
- name="special_modified_bessel_k0",
- ),
- modified_bessel_k1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"modified_bessel_k1_forward({x})",
- name="special_modified_bessel_k1",
- ),
- # multigamma
- ndtr=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_ndtr({x})",
- name="special_ndtr",
- ),
- ndtri=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"calc_ndtri({x})",
- name="special_ndtri",
- ),
- polygamma=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"calc_polygamma({y}, {x})",
- name="polygamma",
- ),
- # psi - alias to digamma
- # round
- scaled_modified_bessel_k0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
- name="special_scaled_modified_bessel_k0",
- ),
- scaled_modified_bessel_k1=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
- name="special_scaled_modified_bessel_k1",
- ),
- # sinc
- spherical_bessel_j0=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x: f"spherical_bessel_j0_forward({x})",
- name="special_spherical_bessel_j0",
- ),
- zeta=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"zeta({x}, {y})",
- name="special_zeta",
- ),
- chebyshev_polynomial_t=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
- name="special_chebyshev_polynomial_t",
- ),
- chebyshev_polynomial_u=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
- name="special_chebyshev_polynomial_u",
- ),
- chebyshev_polynomial_v=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
- name="special_chebyshev_polynomial_v",
- ),
- chebyshev_polynomial_w=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
- name="special_chebyshev_polynomial_w",
- ),
- legendre_polynomial_p=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
- name="special_legendre_polynomial_p",
- ),
- shifted_chebyshev_polynomial_t=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
- name="special_shifted_chebyshev_polynomial_t",
- ),
- shifted_chebyshev_polynomial_u=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
- name="special_shifted_chebyshev_polynomial_u",
- ),
- shifted_chebyshev_polynomial_v=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
- name="special_shifted_chebyshev_polynomial_v",
- ),
- shifted_chebyshev_polynomial_w=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
- name="special_shifted_chebyshev_polynomial_w",
- ),
- hermite_polynomial_h=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
- name="special_hermite_polynomial_h",
- ),
- hermite_polynomial_he=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
- name="special_hermite_polynomial_he",
- ),
- laguerre_polynomial_l=OverridesData(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
- name="special_laguerre_polynomial_l",
- ),
- )
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    # Static-typing-only helper: mypy will flag this return if OpOverrides
    # does not structurally satisfy the OpsHandler[str] protocol. Never
    # called at runtime for any real purpose; it simply returns its input.
    return h
class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        # Nesting deferred lines would double-defer; forbid it.
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        # The line is suppressed (returns None) once its buffer shows up in
        # any of the graph- or kernel-level removal sets.
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in removed for removed in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        # Preserve the guarded buffer name when the base class re-wraps text.
        return DeferredLine(self.name, line)
class BracesBuffer(IndentedBuffer):
    """An IndentedBuffer whose indent() also emits C-style braces."""

    def indent(self, offset=1):
        """Return a context manager that opens ``offset`` brace levels on
        entry and closes them on exit; a negative ``offset`` closes levels
        on entry and reopens them on exit."""

        def _open(count):
            # Emit `count` opening braces, bumping the indent after each.
            for _ in range(count):
                self.writeline("{")
                self._indent += 1

        def _close(count):
            # Emit `count` closing braces, dropping the indent before each.
            for _ in range(count):
                self._indent -= 1
                self.writeline("}")

        @contextlib.contextmanager
        def ctx():
            # Exactly one of these runs: range() over a negative count is empty.
            _open(offset)
            _close(-offset)
            yield
            _open(-offset)
            _close(offset)

        return ctx()
class InplacedBuffer(NamedTuple):
    """Pairs the kernel-internal name of an in/out pointer with every
    graph-level buffer name that aliases the same allocation."""

    # Name used inside the generated kernel, e.g. "in_out_ptr0".
    inner_name: str
    # Graph buffer names sharing this allocation, in discovery order.
    other_names: List[str]
class KernelArgs:
    """Tracks the arguments a generated kernel takes (input/output buffers,
    in-place buffers, size variables, and an optional workspace), assigning
    stable inner names such as ``in_ptr0`` / ``out_ptr1`` / ``ks2`` and
    rendering them as C++ or Python definition/call lists."""

    @staticmethod
    def _lookup(prefix, odict, name):
        # Map `name` to "<prefix><index>", allocating the next index on first
        # sight; `odict` doubles as the insertion-ordered registry.
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        # outer (graph) buffer name -> inner (kernel) argument name
        self.input_buffers = dict()
        self.output_buffers = dict()
        # outer name -> InplacedBuffer shared by all aliases of one allocation
        self.inplace_buffers = dict()
        # size expression/symbol -> inner size-var name
        self.sizevars = sizevars or dict()
        # single fused WorkspaceArg, grown incrementally by workspace()
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        # Removed inplace-buffer slots are tombstoned with a "REMOVED..." string.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Register (or look up) `name` as a kernel input; returns the inner name."""
        if V.graph.scheduler:
            # Resolve mutation aliases to the buffer that actually holds the data.
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            # RNG seeds keep a recognizable "seedN" prefix instead of "in_ptrN".
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Register (or look up) `name` as a kernel output; returns the inner name."""
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Alias `output_name` onto `input_name`'s allocation as an "in_out_ptrN"."""
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            # Extend the existing alias group; the same InplacedBuffer object
            # is shared under every aliased outer name.
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve `nbytes` of scratch space; returns ("ws_ptr", offset).

        Repeated calls append to one shared allocation; once any caller asks
        for zero_fill, the whole workspace stays zero-filled."""
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0
        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            # Disambiguate a clashing name by appending a count of
            # already-registered names that share its prefix.
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Register (or look up) a size variable; returns the inner name."""
        if str(name) == "seed":
            # "seed" is passed through under its own name rather than "ksN".
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        # Outer names in the order arguments are passed at the call site.
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        # Hook for subclasses (e.g. CppWrapperKernelArgs) to render pointer args.
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        """Return (arg_defs, call_args, arg_types) for a C++ kernel signature:
        in/out pointers first, then inputs, outputs, and size variables."""
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Return (arg_defs, call_args, precompile_args, arg_types) for a
        Python (e.g. Triton) kernel launch, in the same ordering rules as
        cpp_argdefs, plus the optional trailing workspace pointer."""
        arg_defs = []
        call_args = []
        arg_types = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        """Yield (inner alias name, in_out_ptr name) pairs for every live
        in-place alias, so codegen can declare them as the same pointer."""
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        """True when `name` is absent from (or tombstoned in) both the
        output and inplace registries."""
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers. Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data? Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
- class CSEVariable:
- """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
- To do so, the backends can simply overload `Kernel.create_cse_var`
- The "CSEVariable.update_on_args" method gives you a hook for annotations
- See example of TritonCSEVariable in triton.py
- """
- def __init__(self, name, bounds: ValueRanges[Any]):
- assert isinstance(bounds, ValueRanges)
- self.name = name
- self.bounds = bounds
- self.use_count = 1 # track how many tims this expression is used
- def __str__(self):
- return self.name
- def __hash__(self) -> int:
- return hash(self.name)
- def __eq__(self, other) -> bool:
- return type(other) == type(self) and other.name == self.name
- def update_on_args(self, name, args, kwargs):
- pass
- def __repr__(self):
- return f"{self.__class__.__name__}({self.name!r})"
class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs specialization for the C++ wrapper codegen path."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        # Non-ABI path: cast the raw data pointer to the C++ element type.
        return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        # Size args are rendered verbatim for the C++ wrapper.
        return f"{size}"
class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        # Text emitted before/after every generated assignment line.
        self.prefix = prefix
        self.suffix = suffix
        # expression text -> CSEVariable already holding that value
        self.cache = {}
        self.name_prefix = name_prefix
        # buffer name -> CSEVariable last stored into it (load forwarding)
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        # counter producing unique variable ids; shared with clones
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        # variable name -> CSEVariable, for later lookup by name
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        """Drop cached stores/expressions whose variables are not in
        `keep_vars`; invalidated store names are recorded so later loads
        know the real buffer must be kept."""
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Shallow copy sharing the id counter, store cache and varname map."""
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return a CSEVariable for `expr`, emitting the assignment into
        `buffer` only the first time this expression text is seen."""
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            # First occurrence: allocate a fresh variable and (optionally)
            # write the assignment text out.
            var = self.newvar(bounds)
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            # Cache hit: no code emitted, just refine bounds and bump usage.
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        """Allocate a fresh backend-specific CSEVariable ("tmpN")."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
class CodeGen:
    """Minimal context-manager base: owns an ExitStack that subclasses push
    cleanup callbacks onto; entering/exiting the CodeGen enters/exits it."""

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Delegate unwinding (and exception info) to the owned stack.
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
class ScopedDict:
    """Read-through overlay on a dict: writes land in a private layer while
    reads consult that layer first and fall back to the wrapped dict, which
    is never mutated."""

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        try:
            return self.new_items[key]
        except KeyError:
            # Fall back to the wrapped dict; KeyError propagates if absent.
            return self.original_dict[key]

    def __setitem__(self, key, value):
        # All mutation is captured in the overlay layer.
        self.new_items[key] = value

    def __contains__(self, key):
        return any(key in layer for layer in (self.new_items, self.original_dict))

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
class Kernel(CodeGen):
    """Base class for kernel codegen.

    Collects generated code into three IndentedBuffers (loads, compute,
    stores), runs common-subexpression elimination through `self.cse`, and
    on __enter__ installs a CSEProxy ops handler so `ops.*` calls made while
    lowering get CSE'd and bounds-annotated before reaching the backend's
    override handler."""

    newvar_prefix = ""
    suffix = ""
    # Backend-supplied wrapper turning the raw ops handler into one that
    # emits this backend's code (e.g. OpOverrides subclass); must be set
    # before __enter__ (asserted there).
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.num_load = 0
        self.num_reduction = 0
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        # Temporarily make `node` the current scheduler node and expose its
        # precomputed value-range bounds; restored on exit.
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        # Temporarily redirect loads/compute/stores into caller-provided
        # buffers, with the CSE caches scoped so entries added inside the
        # `with` block do not leak back out.
        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        # Backend hook: emit a load of buffer `name` at `index`.
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load the depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        # Backend hook: store a reduction result.
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        # Backend hook: emit a store (mode may request e.g. atomic add).
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        # Backend hook: emit a reduction over `value`.
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        # Backend hook: emit an inclusive scan combining `values`.
        raise NotImplementedError

    def var_ranges(self):
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        # Backend hook: name of the runtime assert function to call.
        raise NotImplementedError

    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[str] = None,
    ) -> str:
        """Render a bounds-check assertion `lower <= var < upper` (either
        side optional); `mask` disables the check on masked-out lanes."""
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is suported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            # Masked-out elements trivially pass the assertion.
            cond = f"({cond}) | ~({mask})"
        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        # Backend hook: emit a bounds check for `expr` against [0, size).
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        # Backend hook: render an index expression as backend source text.
        raise NotImplementedError

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            # NOTE(review): class bodies don't create a scope for `self`,
            # so `self` here is the enclosing Kernel instance (closure) and
            # this assignment sets kernel.name — looks intentional, confirm.
            self.name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                # Generic op dispatch: compute bounds, forward to the
                # backend handler, then CSE every returned value.
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed
                Else, if the variable when codegen'ing another op, we try to compute its bounds
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name and self.node_to_bounds is not None:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: Union[sympy.Expr, int], check: bool = True
            ):
                if isinstance(size, int):
                    size = sympy.Integer(size)
                assert isinstance(size, sympy.Expr), size
                # Skip CSE since this doesn't return an expression
                if var.bounds.lower < 0:  # type: ignore[operator]
                    # Wrap possibly-negative indices: idx -> idx + size
                    stm = ops.add(var, ops.index_expr(size, torch.long))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, 0)
                        stm = ops.where(lt, stm, var)

                    # Propagate bounds as we know how to compute them properly
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg_bounds = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(
                            neg_bounds.lower + size, neg_bounds.upper + size
                        )
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                sympy_var = parent_handler.indirect_indexing(var, size, check)
                if generate_assert(check):
                    # Only emit the half of the check that bounds analysis
                    # could not already prove.
                    assert_lower = not (var.bounds.lower >= 0)
                    # value ranges cannot x < s when x and s are symbols
                    assert_upper = not isinstance(size, sympy.Number) or not (
                        var.bounds.upper < size
                    )
                    self.check_bounds(sympy_var, size, assert_lower, assert_upper)
                return sympy_var

            @staticmethod
            def check_bounds(
                expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
            ):
                return self.check_bounds(expr, size, lower, upper)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    # Forward the value of a store we already emitted.
                    return store_cache[name]
                out = self.load(name, index)
                # count load that is not in the store_cache, and also not in the
                # cse cache.
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    # Plain stores populate the forwarding cache (including
                    # for every buffer the current node mutates).
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =   [ 0,  1,  1,  1,  1,  3,  3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        # Sort for deterministic arg-registration order.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        # Backend hook: subclasses return their own CSEVariable subclass.
        return CSEVariable(*args, **kwargs)
@dataclasses.dataclass
class OptimizationContext:
    # Key under which this context is stashed (shared by all instances).
    key: ClassVar[str] = "opt_ctx"

    # Dtype this op should compute in, when pinned by analysis; else None.
    dtype: Optional[torch.dtype] = None
    # Name of the ops.* call this context describes.
    ops_name: str = ""
@functools.lru_cache(None)
def jinja2_env():
    """Return a process-wide jinja2 Environment using StrictUndefined,
    or None when jinja2 is not installed (result is cached)."""
    try:
        import jinja2
    except ImportError:
        return None
    return jinja2.Environment(
        undefined=jinja2.StrictUndefined,
    )
class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        """Indent every line of `source` except the first by
        `num_indents * indents_spacing` spaces (jinja2 filter helper)."""
        pad = " " * (indents_spacing * num_indents)
        return "".join(
            line if i == 0 else pad + line
            for i, line in enumerate(source.splitlines(True))
        )

    @staticmethod
    def _template_from_string(source):
        # Compile `source` with jinja2, or return None when it is unavailable.
        env = jinja2_env()
        if env is None:
            return None
        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        return env.from_string(source)

    @staticmethod
    def _fake_get_dtype(fake_out):
        # Wrap V.graph.get_dtype so queries for `fake_out`'s name answer from
        # the fake buffer instead of the real graph.
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return real_get_dtype(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        # A template that cannot handle this case raises NotImplementedError,
        # which we treat as "no choice to add".
        with contextlib.suppress(NotImplementedError):
            choices.append(self.generate(**kwargs))

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError
|