- """
- # Inductor Pattern Matcher
- The pattern matcher enables search/replace within an FX graph.
- The main entrypoint to the pattern matcher is register_replacement(). Given a
- search function and a replacement function, it registers a replacement with a
- pass (such as torch._inductor.fx_passes.joint_graph.patterns).
- Internally the pattern matcher represents patterns as a graph (a DAG). Creating
- new patterns manually as a graph is cumbersome and error-prone, so the standard
- way to create patterns (using register_replacement()) is to provide a search
- function and a replacement function, which are traced and converted into graphs.
- Because the search functions are written somewhat generically (they tend to
- ignore tensor sizes, for example), register_replacement() allows you to specify
- an `extra_check` function which performs additional checks to verify that the
- matched pattern truly applies before the replacement is made.
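- For example, a minimal sketch (the function names and the target pass dict
- `patterns` are illustrative; any registered pass dict such as
- torch._inductor.fx_passes.joint_graph.patterns works):
-     def double_add(x):
-         return x + x
-     def double_mul(x):
-         return x * 2
-     register_replacement(
-         double_add,
-         double_mul,
-         example_inputs=[torch.empty(4, 4)],
-         trace_fn=fwd_only,
-         pass_dicts=patterns,
-     )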
- ## Precompiled Patterns
- New patterns are added using register_replacement(). Patterns added in this way
- incur a compile-time overhead because they must be traced before use. To avoid
- this overhead, patterns can be precompiled and registered with
- gen_register_replacement() instead. Its arguments are the same as
- register_replacement()'s, except for an additional unique name which is used as
- a lookup key.
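- A minimal sketch of the precompiled variant (the unique name is an
- illustrative lookup key; the remaining arguments mirror the example above):
-     gen_register_replacement(
-         "double_add_pattern",
-         double_add,
-         double_mul,
-         [torch.empty(4, 4)],
-         fwd_only,
-         patterns,
-     )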
- ## Internals
- The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
- implements a `_match` method which returns either a `Match` object for a
- successful match or a `FailedMatch` object for a failure to match.
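- As a sketch of this representation, the hand-built pattern below would match
- `y + relu(x)` for some fx Node `node` (argument order and overload selection
- are illustrative):
-     pattern = CallFunction(
-         torch.ops.aten.add,
-         KeywordArg("y"),
-         CallFunction(torch.ops.aten.relu, KeywordArg("x")),
-     )
-     result = pattern.match(node)  # a Match on success, else a FailedMatch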
- """
- # mypy: disallow-untyped-defs
- from __future__ import annotations
- import contextlib
- import dataclasses
- import functools
- import importlib
- import inspect
- import itertools
- import logging
- import operator
- import os
- import re
- import textwrap
- import typing
- from abc import ABC, abstractmethod
- from collections import defaultdict
- from pathlib import Path
- from typing import (
- Any,
- Callable,
- DefaultDict,
- Dict,
- Generator,
- Iterable,
- List,
- Mapping,
- NoReturn,
- Optional,
- Protocol,
- Sequence,
- Set,
- Tuple,
- Type,
- TypeVar,
- Union,
- )
- from typing_extensions import Self, TypeGuard
- import torch
- import torch._guards
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dispatch.python import enable_python_dispatcher
- from torch._dynamo.utils import counters
- from torch._prims_common import is_integer_dtype
- from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
- from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
- from torch.fx.immutable_collections import immutable_dict, immutable_list
- from .._functorch import config as functorch_config
- from .._functorch.aot_autograd import aot_function, make_boxed_func
- from .._functorch.partitioners import default_partition
- from .._subclasses import FakeTensorMode
- from ..fx import Transformer
- from . import config
- from .decomposition import select_decomp_table
- from .lowering import fallback_node_due_to_unsupported_type
- log = logging.getLogger(__name__)
- aten = torch.ops.aten
- prims = torch.ops.prims
- Constant = Any
- NodeOrConstant = Union[Constant, torch.fx.Node]
- class SearchFn(Protocol):
- __name__: str
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
- ...
- class ReplaceFn(Protocol):
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
- ...
- class TraceFn(Protocol):
- def __call__(
- self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
- ) -> torch.fx.GraphModule:
- ...
- T = TypeVar("T")
- # What's a better name for this?
- FnsType = Union[torch.fx.node.Target, str]
- class Multiple:
- def __init__(self) -> None:
- # Ensure we're really a singleton.
- assert "MULTIPLE" not in globals() or self is MULTIPLE
- # Sentinel indicating multiple quantities can be matched
- MULTIPLE = Multiple()
- class Match:
- """
- Represents a successfully matched pattern.
- The `Match` object is returned to represent a successfully matched
- pattern. Included in the Match are the pattern that was matched, the graph
- nodes matched, and any args that were used during the matching.
- The args and kwargs are specific to the type of pattern that was matched and
- provide hints about what was matched.
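- For example, an `extra_check` might inspect a match like this (a sketch;
- "x" stands for whatever KeywordArg name the pattern used):
-     def check(m: Match) -> bool:
-         return m.kwargs["x"].meta["val"].dtype == torch.float16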
- """
- pattern: PatternExpr
- args: List[Any]
- kwargs: Dict[str, Any]
- nodes: List[torch.fx.Node]
- targets: Dict[_TargetExpr, torch.fx.node.Target]
- ctx: MatchContext
- replacement_graph: Optional[torch.fx.Graph]
- def __init__(
- self,
- ctx: MatchContext,
- pattern: PatternExpr,
- args: Optional[Sequence[Any]] = None,
- kwargs: Optional[Dict[str, Any]] = None,
- ) -> None:
- super().__init__()
- self.pattern = pattern
- # The input nodes that must be passed in to the result
- self.args = list(args or [])
- self.kwargs = kwargs or {}
- # The nodes matched in this expression
- self.nodes = []
- # Mapping CallFunction to the node.target
- self.targets = {}
- self.ctx = ctx
- self.replacement_graph = None
- @property
- def graph(self) -> torch.fx.Graph:
- return self.ctx.graph
- def extend(self, other: Match) -> None:
- if self.kwargs:
- for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
- if self.kwargs[key] != other.kwargs[key]:
- raise FailedMatch("kwarg mismatch: {}", key)
- self.args.extend(other.args)
- self.nodes.extend(other.nodes)
- self.kwargs.update(other.kwargs)
- self.targets.update(other.targets)
- def bundle(self) -> Match:
- # Wrap args in an extra list
- self.args = [tuple(self.args)] if self.args else []
- return self
- def __repr__(self) -> str:
- return f"Match(..., {self.args}, {self.kwargs})"
- def erase_nodes(self, graph: torch.fx.Graph) -> None:
- for n in reversed(self.nodes):
- if not n._erased:
- graph.erase_node(n)
- def output_nodes(self) -> List[Optional[torch.fx.Node]]:
- return [
- (self.ctx.pattern_to_node[p] if p is not None else None)
- for p in self.ctx.outputs
- ]
- def output_node(self) -> torch.fx.Node:
- return next(p for p in self.output_nodes() if p)
- def replace_with_graph(
- self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
- ) -> None:
- ReplacementPatternEntry.replace_with_graph(
- self, self.ctx.graph, replacement_graph, args
- )
- def replace_by_example(
- self,
- replacement_fn: ReplaceFn,
- args: Sequence[Any],
- trace_fn: Optional[TraceFn] = None,
- run_dce: bool = True,
- ) -> None:
- from torch._inductor.virtualized import V
- context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
- with context:
- if trace_fn is None:
- trace_fn = functools.partial(fwd_only, run_dce=run_dce)
- replacement = trace_fn(
- replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
- )
- ReplacementPatternEntry.replace_with_graph(
- self,
- self.ctx.graph,
- replacement,
- args,
- )
- class FailedMatch(RuntimeError):
- """
- Represents an unsuccessful match.
- The `FailedMatch` object is returned to represent a failure to match a
- pattern.
- """
- format_string: str
- def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
- self.format_string = format_string
- # We want to construct error messages lazily instead of eagerly, as
- # constructing them eagerly can significantly worsen compile times.
- if len(format_string) > 200:
- raise RuntimeError(
- f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
- )
- self.args = args
- self.kwargs = kwargs
- def __str__(self) -> str:
- return self.format_string.format(*self.args, **self.kwargs)
- def __bool__(self) -> bool:
- return False
- MatchResult = Union[Match, FailedMatch]
- def is_match(m: MatchResult) -> TypeGuard[Match]:
- """
- TypeGuards cannot act on `self`. Thus this function exists to let mypy
- recognize FailedMatch.__bool__ as a TypeGuard.
- """
- return bool(m)
- class MatchContext:
- """
- Internal state needed while running PatternExpr._match().
- """
- outputs: List[Optional[PatternExpr]]
- pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
- graph: torch.fx.Graph
- exclusive_node_set: List[NodeOrConstant]
- def __init__(
- self,
- outputs: List[Optional[PatternExpr]],
- pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
- *,
- graph: torch.fx.Graph,
- ) -> None:
- self.outputs = outputs
- self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node)
- self.graph = graph
- self.exclusive_node_set = []
- def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult:
- """wrapper to check reused nodes in patterns"""
- if pattern in self.pattern_to_node:
- if self.pattern_to_node[pattern] == node:
- return Match(self, pattern) # already checked this node
- else:
- return FailedMatch("repeated pattern differs")
- m = pattern._match(node, self)
- assert pattern not in self.pattern_to_node
- self.pattern_to_node[pattern] = node if m else None
- return m
- def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
- return {
- pattern: node
- for pattern, node in self.pattern_to_node.items()
- if pattern.has_multiple_users() and node is not None
- }
- class PatternExpr(ABC):
- """
- Base class for types of patterns.
- """
- @abstractmethod
- def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
- ...
- def match(self, node: torch.fx.Node) -> MatchResult:
- try:
- return MatchContext([self], graph=node.graph).match(self, node)
- except FailedMatch as e:
- return e
- def has_multiple_users(self) -> bool:
- return False
- def __repr__(self) -> str:
- return self.__class__.__name__ + "()"
- def find_anchor_nodes(
- self, ctx: MatchContext, searched: Set[torch.fx.Node]
- ) -> Generator[Optional[torch.fx.Node], None, None]:
- if self in ctx.pattern_to_node:
- yield ctx.pattern_to_node[self]
- def pattern_eq(self, other: Any) -> bool:
- """
- Compare two `PatternExpr`s and return true if they are the
- same. Note this is NOT matching a pattern - it is comparing the pattern
- structures (for debugging).
- """
- return isinstance(other, self.__class__)
- class Arg(PatternExpr):
- """
- Capture an arg which will become an input to the handler. Args are
- passed in depth first order.
- """
- def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
- return Match(ctx, self, args=[node]) # matches anything
- class Ignored(PatternExpr):
- """
- Match an arg, but don't pass it to handler
- """
- def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
- return Match(ctx, self) # matches anything
- def __repr__(self) -> str:
- return "*"
- def pretty_print(self, pp: PatternPrettyPrinter) -> str:
- return "Ignored()"
- class KeywordArg(PatternExpr):
- """
- Capture a kwarg which will become an input to the handler.
- """
- def __init__(self, name: str) -> None:
- super().__init__()
- self.name = name
- def __repr__(self) -> str:
- return f"KeywordArg({self.name!r})"
- def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
- return Match(ctx, self, kwargs={self.name: node}) # matches anything
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return super().pattern_eq(other) and self.name == other.name
- class ExclusiveKeywordArg(PatternExpr):
- """
- Capture a kwarg which will become an input to the handler. Unlike
- KeywordArg, the same node may match at most one ExclusiveKeywordArg.
- """
- name: str
- def __init__(self, name: str) -> None:
- super().__init__()
- self.name = name
- def __repr__(self) -> str:
- return f"ExclusiveKeywordArg({self.name!r})"
- def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
- if node in ctx.exclusive_node_set:
- return FailedMatch("exclusive arg appears twice")
- ctx.exclusive_node_set.append(node)
- return Match(ctx, self, kwargs={self.name: node}) # matches anything
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return super().pattern_eq(other) and self.name == other.name
- class _TargetExpr(PatternExpr):
- """
- Base class for filtering match by node.target
- """
- fns: List[FnsType]
- fns_set: Set[FnsType]
- def __init__(
- self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
- ) -> None:
- super().__init__()
- fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
- for fn in fns:
- if isinstance(fn, torch._ops.OpOverloadPacket):
- fns.extend(getattr(fn, overload) for overload in fn.overloads())
- self.fns = fns
- self.fns_set = set(fns)
- self.users = users
- @property
- @abstractmethod
- def op(self) -> str:
- ...
- def fns_repr(self) -> str:
- first_repr = self.fns[0]
- if not isinstance(first_repr, str):
- first_repr = first_repr.__name__
- if len(self.fns) > 1:
- return f"[{first_repr}, ...]"
- elif self.fns[0] is getattr(torch, first_repr, None):
- return f"torch.{first_repr}"
- elif isinstance(self.fns[0], torch._ops.OpOverload):
- return str(self.fns[0])
- else:
- return first_repr
- def __repr__(self) -> str:
- if self.users is MULTIPLE:
- comma_users = ", MULTIPLE"
- elif self.users != 1:
- comma_users = f", {self.users})"
- else:
- comma_users = ""
- return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"
- def has_multiple_users(self) -> bool:
- return isinstance(self.users, Multiple) or self.users > 1
- def find_anchor_nodes(
- self, ctx: MatchContext, searched: Set[torch.fx.Node]
- ) -> Generator[Optional[torch.fx.Node], None, None]:
- raise NotImplementedError
- def _match_fns(self, node: torch.fx.Node) -> bool:
- return (
- isinstance(node, torch.fx.Node)
- and node.op == self.op
- and extract_target(node) in self.fns_set
- )
- def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
- return (
- self in ctx.outputs
- or self.users is MULTIPLE
- or len(node.users) == self.users
- )
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return (
- super().pattern_eq(other)
- and self.op == other.op
- and self.fns == other.fns
- and self.users == other.users
- )
- _SimpleSpec = Tuple[Any, ...]
- class _TargetArgsExpr(_TargetExpr):
- """
- Base class for filtering match by node.{target,args,kwargs}
- """
- def __init__(
- self,
- fns: Union[torch.fx.node.Target, str, Sequence[Any]],
- *args: Any,
- _users: Union[int, Multiple] = 1,
- **kwargs: Any,
- ) -> None:
- super().__init__(fns, _users)
- self.args = tuple(args)
- self.kwargs = dict(kwargs)
- if any(
- isinstance(x, (dict, list, tuple))
- for x in itertools.chain(args, kwargs.values())
- ):
- self.flatten = self.pytree_flatten
- else:
- self.flatten = self.simple_flatten
- self.flat_args_kwargs = self.flatten(self.args, self.kwargs)
- @staticmethod
- def simple_flatten(
- args: Sequence[Any], kwargs: Mapping[Any, Any]
- ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
- values = (*args, *kwargs.values())
- spec = (len(args), *kwargs.keys())
- return values, spec
- @staticmethod
- def pytree_flatten(
- args: Sequence[Any], kwargs: Mapping[Any, Any]
- ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
- def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
- if s.type is None:
- return s
- mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
- return pytree.TreeSpec(
- mapping.get(s.type, s.type),
- s.context,
- list(map(norm_spec, s.children_specs)),
- )
- flat, spec = pytree.tree_flatten([args, kwargs])
- spec = norm_spec(spec)
- return flat, spec
- def __repr__(self) -> str:
- args = [
- self.fns_repr(),
- *map(repr, self.args),
- *[f"{k}={v}" for k, v in self.kwargs.items()],
- ]
- if self.users is MULTIPLE:
- args.append("_users=MULTIPLE")
- elif self.users != 1:
- args.append(f"_users={self.users}")
- return f"{self.__class__.__name__}({', '.join(args)})"
- def pretty_print(self, pp: PatternPrettyPrinter) -> str:
- args = [
- self.fns_repr(),
- *(pp.pretty_print(x) for x in self.args),
- *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
- ]
- if self.users is MULTIPLE:
- args.append("_users=MULTIPLE")
- elif self.users != 1:
- args.append(f"_users={self.users}")
- joiner_str = ", "
- return f"{self.__class__.__name__}({joiner_str.join(args)})"
- def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
- if not self._match_fns(node) or len(node.args) != len(self.args):
- return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
- if not self._match_users(node, ctx):
- return FailedMatch("multiple_users {}", self)
- _args = node.args
- _kwargs = node.kwargs
- if len(_kwargs) < len(self.kwargs):
- from torch.fx.operator_schemas import normalize_function
- normalized_args_and_kwargs = normalize_function(
- node.target, node.args, node.kwargs
- )
- if normalized_args_and_kwargs is None:
- return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
- else:
- _args, _kwargs = normalized_args_and_kwargs
- if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
- _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
- else:
- return FailedMatch(
- "function_mismatch: node={}, pattern={}", node, self
- )
- else:
- _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
- node_items, node_spec = self.flatten(_args, _kwargs)
- self_items, self_spec = self.flat_args_kwargs
- if node_spec != self_spec:
- return FailedMatch("args_structure {} {}", node_spec, self_spec)
- assert len(node_items) == len(self_items)
- m = Match(ctx, self)
- for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
- if isinstance(pattern, PatternExpr):
- child_match = ctx.match(pattern, child_node)
- if not is_match(child_match):
- return child_match
- m.extend(child_match)
- elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
- return FailedMatch(
- "constant_args: {} {!r}!={pattern!r}", node, child_node
- )
- m.nodes.append(node)
- m.targets[self] = node.target
- return m
- def find_anchor_nodes(
- self, ctx: MatchContext, searched: Set[torch.fx.Node]
- ) -> Generator[Optional[torch.fx.Node], None, None]:
- """
- This is used when we are matching a pattern with multiple outputs.
- There is a partial match (stored in ctx) and we want to walk
- this pattern to find a connection to an already-matched node.
- Yields candidate nodes that `self._match` might like.
- """
- if self in ctx.pattern_to_node:
- yield ctx.pattern_to_node[self]
- return
- for pattern in self.flat_args_kwargs[0]:
- if isinstance(pattern, PatternExpr):
- for other_node in pattern.find_anchor_nodes(ctx, searched):
- if not isinstance(other_node, torch.fx.Node):
- continue
- for node in other_node.users:
- if node not in searched:
- if self._match_fns(node):
- yield node
- searched.add(node)
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return (
- super().pattern_eq(other)
- and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
- and all(
- a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
- for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
- )
- )
- class CallFunction(_TargetArgsExpr):
- """
- Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
- """
- op = "call_function"
- class CallMethod(_TargetArgsExpr):
- """
- Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
- """
- op = "call_method"
- class CallModule(_TargetArgsExpr):
- """
- Matches a call_module node in the FX graphs: `module(*args, **kwargs)`
- """
- op = "call_module"
- class _TargetExprVarArgs(_TargetExpr):
- """
- Matches a node of the given op with any arguments; the matched node's
- args and kwargs are passed through to the handler.
- """
- def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
- if not self._match_fns(node):
- return FailedMatch("function_mismatch")
- if not self._match_users(node, ctx):
- return FailedMatch("multiple_users")
- m = Match(ctx, self)
- m.nodes.append(node)
- m.targets[self] = node.target
- m.args.extend(node.args)
- m.kwargs.update(node.kwargs)
- return m
- class CallFunctionVarArgs(_TargetExprVarArgs):
- op = "call_function"
- class CallMethodVarArgs(_TargetExprVarArgs):
- op = "call_method"
- class CallModuleVarArgs(_TargetExprVarArgs):
- op = "call_module"
- class ListOf(PatternExpr):
- """
- Matches a repeated pattern
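- For example (an illustrative sketch), `ListOf(Arg())` matches a list-valued
- argument such as the tensor list passed to `aten.cat`, capturing the
- elements as a single bundled tuple of args.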
- """
- def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
- super().__init__()
- assert isinstance(pattern, PatternExpr)
- self.pattern = pattern
- self.partial = partial
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.pattern})"
- def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
- if not isinstance(node, (list, tuple)) or len(node) == 0:
- return FailedMatch("non_list")
- m = Match(ctx, self)
- # Propagating patterns with multiple users will ensure we don't revisit
- # the same nodes
- pattern_to_node = ctx.filter_multi_user_patterns()
- matched = False
- for i, child_node in enumerate(node):
- child_ctx = MatchContext(
- ctx.outputs, pattern_to_node, graph=child_node.graph
- )
- child_match = child_ctx.match(self.pattern, child_node)
- pattern_to_node = child_ctx.filter_multi_user_patterns()
- if not is_match(child_match):
- if not self.partial:
- return FailedMatch("list[{}]: {}", i, child_match)
- continue
- matched = True
- m.extend(child_match.bundle())
- if not matched:
- return FailedMatch("list: no_match")
- return m.bundle()
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return (
- super().pattern_eq(other)
- and self.pattern.pattern_eq(other.pattern)
- and self.partial == other.partial
- )
- class MultiOutputPattern(PatternExpr):
- outputs: List[Optional[PatternExpr]]
- def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
- super().__init__()
- assert isinstance(outputs[0], _TargetExpr)
- assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs
- self.outputs = list(outputs)
- self.op = outputs[0].op
- @property
- def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]:
- # This cast is checked above in __init__()
- output = typing.cast(_TargetExpr, self.outputs[0])
- return output.fns
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.outputs})"
- def pretty_print(self, pp: PatternPrettyPrinter) -> str:
- args = [pp.pretty_print(x) for x in self.outputs]
- joiner_str = f",\n{' '}"
- str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
- str_out = f"{str_out}\n])"
- return str_out
- def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
- output = typing.cast(_TargetExpr, self.outputs[0])
- m = ctx.match(output, node)
- if not is_match(m):
- return m
- for pattern in self.outputs[1:]:
- if pattern is None:
- continue
- child_match = self._match_from_anchors(pattern, ctx)
- if not is_match(child_match):
- return child_match
- m.extend(child_match)
- return m
- def _match_from_anchors(
- self, pattern: PatternExpr, ctx: MatchContext
- ) -> MatchResult:
- prior = dict(ctx.pattern_to_node)
- m: MatchResult = FailedMatch("no anchor found")
- for node in pattern.find_anchor_nodes(ctx, set()):
- m = ctx.match(pattern, node)
- if is_match(m):
- return m
- # revert any partial matches
- ctx.pattern_to_node = dict(prior)
- return m
- def match(self, node: torch.fx.Node) -> MatchResult:
- try:
- return MatchContext(self.outputs, graph=node.graph).match(self, node)
- except FailedMatch as e:
- return e
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return (
- super().pattern_eq(other)
- and len(self.outputs) == len(other.outputs)
- and all(
- a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
- for a, b in zip(self.outputs, other.outputs)
- )
- )
- class RepeatedExpr(PatternExpr):
- """
- Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`.
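- For example (a rough sketch), RepeatedExpr(CallFunction(operator.getitem,
- KeywordArg("split"), Ignored())) matches only if every anchor user reached
- from the bound `split` node is such a getitem call.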
- """
- def __init__(self, inner_pattern: _TargetExpr) -> None:
- super().__init__()
- self.inner_pattern = inner_pattern
- self.op = inner_pattern.op
- @property
- def fns(self) -> Sequence[FnsType]:
- return self.inner_pattern.fns
- def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
- m = ctx.match(self.inner_pattern, node)
- if not is_match(m):
- return m
- ctx.pattern_to_node.pop(
- self.inner_pattern,
- )
- # Check all anchor nodes match the pattern
- for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
- anchor_m = MatchContext([self], graph=node.graph).match(
- self.inner_pattern, anchor_node
- )
- if not is_match(anchor_m):
- return anchor_m
- m.extend(anchor_m)
- return m
- def pattern_eq(self, other: Any) -> bool:
- other = typing.cast(Self, other) # super makes sure this is true
- return super().pattern_eq(other) and self.inner_pattern.pattern_eq(
- other.inner_pattern
- )
- class PatternPrettyPrinter:
- """
- Serializes Patterns to executable python.
- XXX: currently only used and tested for fuse attention patterns. May not cover
- all patterns.
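- For example, a single-node pattern serializes to something like
- (illustrative):
-     output = CallFunction(aten.relu.default, KeywordArg('x'))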
- """
- def __init__(self) -> None:
- self.namespace = torch.fx.graph._Namespace()
- self.memoized_objs_names: Dict[PatternExpr, str] = {}
- self.memoized_objs_pp: Dict[PatternExpr, str] = {}
- @staticmethod
- @functools.lru_cache(None)
- def run(obj: PatternExpr, output_name: str = "output") -> str:
- """
- Serializes obj to python code with obj written out to `output_name`
- """
- pp = PatternPrettyPrinter()
- assert hasattr(obj, "pretty_print")
- out_str = obj.pretty_print(pp=pp)
- output = []
- for key in pp.memoized_objs_names:
- output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
- output.append(f"{output_name} = {out_str}")
- return "\n".join(output)
- def pretty_print(self, obj: Any) -> str:
- if isinstance(obj, _TargetArgsExpr):
- if memoized_name := self.memoized_objs_names.get(obj):
- return memoized_name
- else:
- return self.memoize(obj)
- if hasattr(obj, "pretty_print"):
- return obj.pretty_print(self)
- return repr(obj)
- def memoize(self, obj: _TargetArgsExpr) -> str:
- obj_str = obj.pretty_print(self)
- obj_name = obj.fns_repr()
- for prefix in ("aten.", "torch.", "prims."):
- obj_name = obj_name.replace(prefix, "")
- tmp_name = self.namespace.create_name(obj_name, None)
- self.memoized_objs_names[obj] = tmp_name
- self.memoized_objs_pp[obj] = obj_str
- return tmp_name
- class _PassDictsType(Protocol):
- def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
- ...
- @dataclasses.dataclass
- class PatternEntry:
- pattern: PatternExpr
- extra_check: Callable[[Match], bool]
- def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
- raise NotImplementedError
- def register(
- self,
- pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
- target: Union[torch.fx.node.Target, None] = None,
- prepend: bool = False,
- ) -> None:
- if target is None:
- assert hasattr(self.pattern, "fns")
- for fn in self.pattern.fns:
- self.register(pass_dicts, fn, prepend=prepend)
- elif isinstance(pass_dicts, (dict, PatternMatcherPass)):
- assert hasattr(self.pattern, "op")
- if prepend:
- pass_dicts[(self.pattern.op, target)].insert(0, self)
- else:
- pass_dicts[(self.pattern.op, target)].append(self)
- else:
- pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts)
- for x in pass_dicts:
- self.register(x, target, prepend=prepend)
- @dataclasses.dataclass
- class LoweringPatternEntry(PatternEntry):
- handler: Callable[..., Any]
- def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
- handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
- with graph.inserting_before(node):
- replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
- replacement.meta.update(node.meta)
- node.replace_all_uses_with(replacement)
- assert match.nodes[-1] is node
- match.erase_nodes(graph)
- @dataclasses.dataclass
- class GraphPatternEntry(PatternEntry):
- """
- A pattern that runs a function on the FX graph
- """
- handler: Callable[..., Any]
- def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
- with graph.inserting_before(node):
- self.handler(match, *match.args, **match.kwargs)
- @dataclasses.dataclass
- class ReplacementPatternEntry(PatternEntry):
- normalize_args: Callable[..., List[Any]]
- @staticmethod
- def replace_with_graph(
- match: Match,
- graph: torch.fx.Graph,
- replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule],
- args: Sequence[torch.fx.Node],
- ) -> None:
- output_nodes = match.output_nodes()
- first_node = output_nodes[0]
- class Replacer(torch.fx.Interpreter):
- call_method = None # type: ignore[assignment]
- call_module = None # type: ignore[assignment]
- get_attr = None # type: ignore[assignment]
- def run_node(self, node: torch.fx.Node) -> Any:
- if node.op in ("placeholder", "output"):
- return super().run_node(node)
- if node.op == "call_function":
- target = node.target
- args, kwargs = self.fetch_args_kwargs_from_env(node)
- result = graph.call_function(target, args, kwargs)
- if "val" in node.meta and "val" not in result.meta:
- result.meta["val"] = node.meta["val"]
- if isinstance(node.meta["val"], torch.Tensor):
- assert "tensor_meta" in node.meta
- result.meta["tensor_meta"] = node.meta["tensor_meta"]
- return result
- raise NotImplementedError(f"unhandled {node}")
- output_nodes = match.output_nodes()
- if len(output_nodes) == 1:
- last_node = output_nodes[0]
- else:
- assert output_nodes[0]
- nodes = list(output_nodes[0].graph.nodes)
- indices = [
- (nodes.index(n), n)
- for n in output_nodes
- if isinstance(n, torch.fx.Node)
- ]
- last_node = min(indices, key=operator.itemgetter(0))[1]
- def percolate_tags(
- node: torch.fx.Node, recompute_tag: str, input_stops: Set[torch.fx.Node]
- ) -> None:
- queue = [node]
- visited = set()
- while queue:
- arg = queue.pop()
- if (
- arg not in visited
- and arg not in input_stops
- and hasattr(arg, "meta")
- ):
- visited.add(arg)
- arg.meta["recompute"] = recompute_tag
- queue.extend(arg.all_input_nodes)
- with graph.inserting_before(last_node):
- replacement = Replacer(replacement_graph).run(*args)
- if isinstance(replacement, torch.fx.Node):
- replacement = [replacement]
- def maybe_getitem(node: torch.fx.Node) -> Any:
- if node.op != "call_function":
- return None
- if node.target != operator.getitem:
- return None
- assert len(node.args) == 2
- return node.args[1]
- def replace(
- old: Union[torch.fx.Node, None],
- new: Union[torch.fx.Node, Sequence[torch.fx.Node], None],
- ) -> None:
- if old is None:
- assert new is None
- return
- assert isinstance(old, torch.fx.Node)
- if new is None:
- old.replace_all_uses_with(None)  # type: ignore[arg-type]
- graph.erase_node(old)
- return
- if isinstance(new, torch.fx.Node):
- if "val" not in new.meta:
- new.meta.update(old.meta)
- # Preserve the recompute tags in the replacement graph. We
- # look at the recompute tags of the original output node to
- # propagate the tag from the output all the way to the input
- # args (named as args in the replace_with_graph).
- # Note that this is best effort. Since patterns are from
- # many to many, there is no easy way to correctly map the
- # recomputable tags. It is possible in some scenarios that we
- # incorrectly tag some nodes as recomputables.
- if "recompute" in old.meta:
- percolate_tags(new, old.meta["recompute"], set(args))
- old.replace_all_uses_with(new)
- graph.erase_node(old)
- return
- new = typing.cast(Sequence[torch.fx.Node], new)
- # `new` is not a node: it's a list of nodes.
- #
- # This happens when we want to replace a node that has a single
- # packed return with multiple unpacked returns. We need to do
- # some graph surgery here.
- #
- # Example:
- # def original_graph(x):
- # a = op(x)
- # b = a[0]
- # c = a[1]
- # ...
- #
- # Assume that we want to replace op(x) with the graph
- # def new_op(x):
- # w = x + 1
- # z = x + 2
- # return (w, z)
- #
- # We need to replace `op` with the contents of `new_op`,
- # and then rewrite a[0] to be w and a[1] to be z, like so:
- # def new_graph(x):
- # w = x + 1
- # z = x + 2
- # b = w
- # c = z
- # ...
- old_uses = list(old.users.keys())
- for user in old_uses:
- idx = maybe_getitem(user)
- if idx is None:
- raise AssertionError("can't handle")
- replace(user, new[idx])
- graph.erase_node(old)
- if len(output_nodes) == len(replacement):
- for old, new in zip(output_nodes, replacement):
- replace(old, new)
- else:
- assert len(output_nodes) == 1
- replace(output_nodes[0], replacement)
- match.erase_nodes(graph)
- def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
- assert match.replacement_graph is not None
- self.replace_with_graph(
- match,
- graph,
- match.replacement_graph,
- self.normalize_args(*match.args, **match.kwargs),
- )
- def _return_true(match: Match) -> bool:
- return True
- def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
- log.info(
- "Replacement pattern %s failed to apply due to shape mismatch: %s",
- search_fn.__name__,
- e,
- )
- def register_replacement(
- search_fn: SearchFn,
- replace_fn: ReplaceFn,
- example_inputs: Iterable[Any],
- trace_fn: TraceFn,
- pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
- extra_check: Callable[[Match], bool] = _return_true,
- scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
- exclusive_arg_names: Sequence[str] = (),
- search_fn_pattern: Union[PatternExpr, None] = None,
- ) -> bool:
- """
- Create a replacement rule based on example functions that get traced
- to create patterns. This supports both training and inference when
- run on a joint forward+backward graph.
- Args:
- search_fn: traced to give original pattern
- replace_fn: traced to give replacement graph
- example_inputs: example inputs for initial trace
- trace_fn: fwd_only or joint_fwd_bwd
- pass_dicts: dict of passes to register to
- extra_check: additional check to run on match (using real shapes)
- """
- argnames_static = [*inspect.signature(search_fn).parameters.keys()]
- def check_fn(match: Match) -> bool:
- """
- Often shapes get burned into the pattern, so our initial match ran with
- `ignore_types=(int, ...)`.
- Recheck the match with the correct shapes.
- """
- argnames = list(argnames_static)
- for name in argnames:
- if name not in match.kwargs:
- raise RuntimeError(
- f"Not all inputs to pattern found in match.kwargs. Perhaps one "
- f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
- )
- args = list(
- torch.fx.map_arg(
- [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
- )
- )
- sym_args: List[torch.SymInt] = []
- with torch._dynamo.utils.detect_fake_mode(args):
- for i, grad in enumerate(requires_grad):
- if isinstance(args[i], torch.Tensor):
- if grad and is_integer_dtype(args[i].dtype):
- return False
- args[i] = torch.empty_strided(
- args[i].size(),
- args[i].stride(),
- dtype=args[i].dtype,
- device=args[i].device,
- requires_grad=grad,
- )
- for v in itertools.chain(args[i].shape, args[i].stride()):
- if isinstance(v, torch.SymInt) and all(
- guard_size_oblivious(v != a) for a in sym_args
- ):
- sym_args.append(v)
- # If we were given a pre-traced pattern then use that instead of
- # retracing. Note that this means the pattern has to be independent
- # of its args.
- specific_pattern = search_fn_pattern
- if not specific_pattern:
- if sym_args:
- # AOT Autograd and make_fx will dedupe symbolic shape size
- # accesses of sym ints that appear as inputs.
- # We don't want the sym_size uses to interfere with pattern matching,
- # so we provide them as inputs.
- # Later, when we actually do the replacement, the symbolic shape
- # sizes will get re-traced and added to the graph.
- def search_fn_new(*args_new: Any) -> Any:
- return search_fn(*args_new[len(args_new) - len(args) :])
- try:
- specific_graph = trace_fn(search_fn_new, sym_args + args)
- except RuntimeError as e:
- log_trace_failure(search_fn, e)
- return False
- # correct argnames in the graph
- sym_arg_names = []
- for i, placeholder in zip(
- range(len(sym_args) + len(args)),
- specific_graph.graph.nodes,
- ):
- if i < len(sym_args):
- sym_arg_names.append(placeholder.target)
- continue
- with specific_graph.graph.inserting_after(placeholder):
- new_node = specific_graph.graph.placeholder(
- argnames[i - len(sym_args)]
- )
- new_node.target = new_node.name
- placeholder.replace_all_uses_with(new_node)
- specific_graph.graph.erase_node(placeholder)
- argnames = sym_arg_names + argnames
- else:
- try:
- specific_graph = trace_fn(search_fn, args)
- except RuntimeError as e:
- log_trace_failure(search_fn, e)
- return False
- specific_pattern = fx_to_pattern(
- specific_graph,
- argnames=argnames,
- exclusive_arg_names=exclusive_arg_names,
- scalar_workaround=scalar_workaround,
- )
- node = match.output_nodes()[0]
- assert node is not None
- specific_pattern_match = specific_pattern.match(node)
- if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
- # trace the pattern using the shapes from the user program
- match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment]
- return True
- return False
- def normalize_args(**kwargs: Any) -> List[Any]:
- args = []
- for name in argnames_static:
- args.append(kwargs.pop(name))
- for i in range(1, len(kwargs) + 1):
- if f"tangents_{i}" not in kwargs:
- break
- args.append(kwargs.pop(f"tangents_{i}"))
- assert not kwargs, f"leftover kwargs: {kwargs!r}"
- return args
- if trace_fn is joint_fwd_bwd:
- # If inference mode is enabled during compilation, assume that we don't
- # want to match on any training graph patterns
- if torch.is_inference_mode_enabled():
- return False
- # TODO: Revisit the functionalize_rng_ops for lowmem dropout
- with functorch_config.patch(functionalize_rng_ops=False):
- requires_grad: List[bool] = [
- isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
- ]
- if search_fn_pattern is None:
- pattern = gen_pattern(
- search_fn,
- example_inputs,
- trace_fn,
- scalar_workaround,
- exclusive_arg_names,
- )
- else:
- pattern = search_fn_pattern
- pattern_repr = PatternPrettyPrinter.run(pattern)
- assert pattern_repr not in _seen_patterns
- _seen_patterns.add(pattern_repr)
- pattern = ReplacementPatternEntry(
- pattern=pattern,
- extra_check=check_fn,
- normalize_args=normalize_args,
- )
- pattern.register(pass_dicts)
- return pattern.pattern  # type: ignore[return-value]
- _serialized_patterns: Set[str] = set()
- def _serialize_pattern(
- unique_name: str,
- search_fn: SearchFn,
- example_inputs: Iterable[Any],
- trace_fn: TraceFn,
- scalar_workaround: Union[Dict[str, Union[float, int]], None],
- ) -> PatternExpr:
- def get_file_template() -> str:
- auto_generated_msg = textwrap.dedent(
- """\
- # This is an auto-generated file. Please do not modify it by hand.
- # To re-generate, run:
- # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
- """
- )
- file_template = textwrap.dedent(
- """\
- # mypy: ignore-errors
- # noqa: F401, E501
- {msg}
- import torch
- import torch._inductor
- aten = torch.ops.aten
- prims = torch.ops.prims
- """
- ).format(msg=auto_generated_msg)
- pattern_matcher_imports = []
- for name in dir(torch._inductor.pattern_matcher):
- attr = getattr(torch._inductor.pattern_matcher, name)
- if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
- pattern_matcher_imports.append(name)
- formatted_imports = ",\n ".join(pattern_matcher_imports)
- formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
- return f"{file_template}{formatted_imports}"
- if not SERIALIZED_PATTERN_PATH.is_dir():
- raise RuntimeError(
- f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}"
- )
- pattern_name = search_fn.__name__
- from torch._functorch import config as functorch_config
- with functorch_config.patch(functionalize_rng_ops=False):
- pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)
- serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name)
- if pattern_name not in _serialized_patterns:
- write_mode = "w"
- _serialized_patterns.add(pattern_name)
- else:
- write_mode = "a"
- file_template = get_file_template()
- with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f:
- if write_mode == "w":
- f.write(file_template)
- else:
- f.write("\n\n")
- f.write(serialized_pattern)
- f.write("\n")
- return pattern
- SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns"
- # This is the set of serialized patterns that we've registered. Used by
- # test_serialized_patterns_up_to_date() to ensure the patterns are up
- # to date.
- _known_precompiled_patterns: List[
- Tuple[
- Any,
- Iterable[Any],
- Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
- Any,
- PatternExpr,
- ]
- ] = []
- def gen_register_replacement(
- unique_name: str,
- search_fn: SearchFn,
- replace_fn: ReplaceFn,
- example_inputs: Iterable[Any],
- trace_fn: TraceFn,
- pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
- extra_check: Callable[[Match], bool] = _return_true,
- scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
- exclusive_arg_names: Sequence[str] = (),
- skip_duplicates: bool = False,
- ) -> None:
- # Make sure the example_inputs is materialized.
- example_inputs = tuple(example_inputs)
- if "PYTORCH_GEN_PATTERNS" in os.environ:
- pat = _serialize_pattern(
- unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
- )
- else:
- pattern_name = search_fn.__name__
- m = importlib.import_module(
- f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
- )
- if not m or not hasattr(m, unique_name):
- log.warning(
- "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
- unique_name,
- )
- pat = getattr(m, unique_name)
- for arg in pytree.tree_iter(example_inputs):
- if torch._subclasses.fake_tensor.is_fake(arg) and arg.constant is not None:
- # This can be a problem: small fake tensors (e.g. `tensor(2)`) will
- # hold onto their original constant value, and stashing that value
- # here would cause a memory leak if it lives on the GPU.
- # Since this is just an optimization we can clear it out.
- arg.constant = None
- if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
- return
- _known_precompiled_patterns.append(
- (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
- )
- register_replacement(
- search_fn,
- replace_fn,
- example_inputs,
- trace_fn,
- pass_dicts,
- extra_check,
- scalar_workaround,
- exclusive_arg_names,
- search_fn_pattern=pat,
- )
- @functorch_config.patch(functionalize_rng_ops=False)
- def gen_pattern(
- search_fn: SearchFn,
- example_inputs: Sequence[Any],
- trace_fn: TraceFn,
- scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
- exclusive_arg_names: Sequence[str] = (),
- ) -> PatternExpr:
- argnames = [*inspect.signature(search_fn).parameters.keys()]
- if scalar_workaround is None:
- scalar_workaround = {}
- flat_inputs = []
- input_idx = 0 # Positional arguments index
- for argname in argnames:
- if argname in scalar_workaround:
- flat_inputs.append(scalar_workaround[argname])
- else:
- flat_inputs.append(example_inputs[input_idx])
- input_idx += 1
- search_gm = trace_fn(search_fn, flat_inputs)
- return fx_to_pattern(
- search_gm,
- ignore_types=(int, float, list, torch.device, torch.dtype),
- argnames=argnames,
- scalar_workaround=scalar_workaround,
- exclusive_arg_names=exclusive_arg_names,
- )
- def register_lowering_pattern(
- pattern: PatternExpr,
- extra_check: Callable[[Match], bool] = _return_true,
- *,
- pass_dict: _PassDictsType,
- prepend: bool = False,
- ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
- """
- Register an aten-to-inductor-IR replacement pattern. The decorated
- function is saved and then called at lowering time, allowing direct
- pattern-to-inductor-IR conversion.
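- A hypothetical sketch (the pattern and `some_lowering_pass` are
- illustrative; the handler receives the Match followed by the captured
- args and kwargs):
-     @register_lowering_pattern(
-         CallFunction(aten.relu, Arg()),
-         pass_dict=some_lowering_pass,
-     )
-     def relu_handler(match: Match, x):
-         ...  # build and return inductor IR for the matched subgraph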
- """
- def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
- assert callable(handler)
- LoweringPatternEntry(
- pattern=pattern, extra_check=extra_check, handler=handler
- ).register(pass_dict, prepend=prepend)
- handler._inductor_lowering_function = True # type: ignore[attr-defined]
- return handler
- return decorator
- def register_graph_pattern(
- pattern: PatternExpr,
- extra_check: Callable[[Match], bool] = _return_true,
- *,
- pass_dict: _PassDictsType,
- prepend: bool = False,
- ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
- """
- Register a pattern that runs a function on the FX graph, allowing
- custom transformation code.
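- A hypothetical sketch (`some_pass_dict` is illustrative; each KeywordArg in
- the pattern is forwarded to the handler as a kwarg, and the handler runs
- with the graph positioned just before the matched node):
-     @register_graph_pattern(
-         CallFunction(aten.add, KeywordArg("x"), KeywordArg("y")),
-         pass_dict=some_pass_dict,
-     )
-     def handle_add(match: Match, x, y):
-         ...  # mutate match.graph as needed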
- """
- def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
- assert callable(handler)
- GraphPatternEntry(
- pattern=pattern, extra_check=extra_check, handler=handler
- ).register(pass_dict, prepend=prepend)
- return handler
- return decorator
- def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
- # first node in the graph
- return node is next(iter(graph.nodes))
- # match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc
- _mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)")
- def is_mutation_op(node: torch.fx.Node) -> bool:
- if node.op == "call_function":
- if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
- return True
- elif node.op == "call_method":
- if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type]
- return True
- return node.kwargs.get("out") is not None
- def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
- n = node
- while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
- n = n.prev
- mutation_region_id = n.meta.get("mutation_region_id", 0)
- while n is not node:
- n = n.next
- if is_mutation_op(n):
- mutation_region_id += 1
- n.meta["mutation_region_id"] = mutation_region_id
- return mutation_region_id
- def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
- return "mutation_region_id" not in next(iter(graph.nodes)).meta
- def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
- mutation_region_id = 0
- for nd in graph.nodes:
- if is_mutation_op(nd):
- mutation_region_id += 1
- nd.meta["mutation_region_id"] = mutation_region_id
- class PatternMatcherPass:
- def __init__(
- self,
- prevent_match_across_mutations: bool = False,
- pass_name: Optional[str] = None,
- ) -> None:
- super().__init__()
- self.patterns: DefaultDict[
- Tuple[str, torch.fx.node.Target], List[PatternEntry]
- ] = defaultdict(list)
- self.prevent_match_across_mutations = prevent_match_across_mutations
- self.pass_name = pass_name
- def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
- return self.patterns[item]
- def apply(self, graph: torch.fx.GraphModule) -> int:
- if not self.patterns:
- return 0
- if isinstance(graph, torch.fx.GraphModule):
- graph = graph.graph
- if self.prevent_match_across_mutations:
- if should_compute_mutation_region_ids(graph):
- compute_mutation_region_ids(graph)
- get_mutation_region_id_partial = functools.partial(
- get_mutation_region_id, graph
- )
- count = 0
- nodes = []
- has_call_module = False
- for op, target in self.patterns:
- if op == "call_module":
- has_call_module = True
- else:
- nodes.append(graph.find_nodes(op=op, target=target, sort=False))
- if has_call_module:
- nodes.append(graph.find_nodes(op="call_module", sort=False))
- for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
- target = extract_target(node)
- if node.op == "call_module":
- if (node.op, target) not in self.patterns:
- continue
- # conservatively not applying pattern for cpu input,
- # since some of the patterns induce codegen and split nodes.
- # Note: we will only skip cpu compute if disable_cpp_codegen=True
- if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
- continue
- for entry in self.patterns[(node.op, target)]:
- if node._erased:
- break
- m = entry.pattern.match(node)
- # pattern match crosses mutation barrier - discard
- if (
- self.prevent_match_across_mutations
- and is_match(m)
- and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
- ):
- continue
- if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
- log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
- if is_match(m) and entry.extra_check(m):
- count += 1
- entry.apply(m, graph, node) # type: ignore[arg-type]
- counters["inductor"]["pattern_matcher_count"] += 1
- counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
- return count
- def clear(self) -> None:
- self.patterns.clear()
- def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
- raise NotImplementedError
- def fx_to_pattern(
- gm: Union[torch.fx.GraphModule, torch.fx.Graph],
- ignore_types: Sequence[Type[Any]] = (),
- argnames: Sequence[str] = (),
- scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
- exclusive_arg_names: Sequence[str] = (),
- ) -> PatternExpr:
- """
- Convert an FX graph into a PatternExpr. This is useful for simple
- patterns that can only match single functions and fixed-length lists.
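- A minimal sketch (the traced function and the pattern shown in the comment
- are illustrative):
-     gm = fwd_only(lambda x, y: x + y, [torch.empty(4), torch.empty(4)])
-     pat = fx_to_pattern(gm, argnames=["x", "y"])
-     # pat is roughly CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y"))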
- """
- # scalar_workaround is a hack to capture dropout_p
- # see https://github.com/pytorch/pytorch/issues/97894
- scalar_workaround = scalar_workaround or {}
- inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
- assert len(inv_scalar_workaround) == len(scalar_workaround)
- def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
- if isinstance(x, (float, int)) and x in inv_scalar_workaround:
- return KeywordArg(inv_scalar_workaround[x])
- if type(x) in ignore_types:
- return Ignored()
- if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
- return Ignored()
- return x
- argnum = itertools.count()
- class Converter(torch.fx.Interpreter):
- call_method = _not_implemented
- call_module = _not_implemented
- get_attr = _not_implemented
- def placeholder(
- self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
- ) -> Union[ExclusiveKeywordArg, KeywordArg]:
- n = next(argnum)
- if n < len(argnames):
- name = argnames[n]
- elif argnames:
- assert target.startswith("tangent")
- name = target
- else:
- target = re.sub(r"_\d+$", "", target) # de-mangle arg name
- name = target
- if name in exclusive_arg_names:
- return ExclusiveKeywordArg(name)
- else:
- return KeywordArg(name)
- def call_function(
- self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
- ) -> PatternExpr:
- args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
- if list in ignore_types:
- # Handle burned-in tensor sizes, which are now [Ignored(), Ignored(), ...]
- args = [process_arg(a) for a in args]
- kwargs = {k: process_arg(a) for k, a in kwargs.items()}
- return CallFunction(target, *args, **kwargs)
- def run_node(self, n: torch.fx.Node) -> Any:
- rv = super().run_node(n)
- if n.op == "output" and isinstance(rv, tuple):
- assert len(rv) == len(n.args[0]) # type: ignore[arg-type]
- for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type]
- r.users = len(arg.users)
- else:
- rv.users = len(n.users)
- return rv
- pattern = Converter(gm).run()
- if not isinstance(pattern, PatternExpr):
- return MultiOutputPattern(pytree.tree_leaves(pattern))
- return pattern
- @torch.no_grad()
- def fwd_only(
- fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True
- ) -> torch.fx.GraphModule:
- """Build a normalized inference graph, for use with fx_to_pattern"""
- # TODO - look into using aot autograd, asserting no mutating ops here
- with enable_python_dispatcher():
- gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
- from .fx_passes.post_grad import remove_noop_ops
- remove_noop_ops(gm.graph)
- if run_dce:
- gm.graph.eliminate_dead_code()
- gm.recompile()
- return gm
- @torch.enable_grad()
- def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule:
- """Build a normalized training graph, for use with fx_to_pattern"""
- gm: Optional[torch.fx.GraphModule] = None
- def record_joint_graph(
- joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any
- ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
- nonlocal gm
- assert not gm
- gm = clone_graph(joint_graph)
- return default_partition(joint_graph, inputs, **kwargs)
- with torch._guards.tracing(None):
- aot_function(
- fn,
- lambda g, i: make_boxed_func(g),
- partition_fn=record_joint_graph,
- decompositions=select_decomp_table(),
- keep_inference_input_mutations=True,
- enable_log=False,
- )(*args)
- assert gm
- from .fx_passes.post_grad import remove_noop_ops
- remove_noop_ops(gm.graph)
- from .fx_passes.joint_graph import pointless_view
- matcher_pass = PatternMatcherPass()
- pattern = CallFunction(
- torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
- )
- GraphPatternEntry(
- pattern=pattern, handler=pointless_view, extra_check=_return_true
- ).register(matcher_pass.patterns)
- matcher_pass.apply(gm.graph) # type: ignore[arg-type]
- # remove in/out specs
- gm.graph._codegen = torch.fx.graph.CodeGen()
- gm.graph.eliminate_dead_code()
- gm.recompile()
- return gm
- def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
- args: List[torch.fx.node.Argument] = list()
- torch.fx.map_arg((n.args, n.kwargs), args.append)
- return args
- def stable_topological_sort(graph: torch.fx.Graph) -> None:
- # Nodes are in exactly one of these three collections:
- # - Nodes in `pending` are waiting to be processed (in reverse order):
- pending = list(reversed(graph.nodes))
- # - Nodes in `ready` have been processed and are already in the correct
- # order.
- ready = set()
- # - `waiting` is a mapping from a dependency to nodes which depend on that
- # dependency.
- waiting = defaultdict(list)
- # The cursor indicates the last processed node so we can add new nodes
- # after it.
- cursor = None
- while pending:
- node = pending.pop()
- waiting_for = [x for x in _args(node) if x not in ready]
- if waiting_for:
- # We have unprocessed input nodes. Might as well wait for the last
- # arg so an already sorted list will only recheck this node once.
- waiting[waiting_for[-1]].append(node)
- else:
- ready.add(node)
- if cursor and cursor.next is not node:
- cursor.append(node)
- cursor = node
- # Mark the nodes that have been waiting for this node to finish as
- # ready to check again.
- pending.extend(reversed(waiting.pop(node, ())))
- assert not waiting and len(ready) == len(graph.nodes)
- def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
- """Wrapper around lazy init functions in fx_passes/"""
- @functools.lru_cache(None)
- @functools.wraps(fn)
- def lazy_init() -> Any:
- counters_ref = counters["inductor"].copy()
- with torch._guards.tracing(
- None
- ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
- result = fn()
- # clear view matches encountered during tracing
- counters["inductor"] = counters_ref
- return result
- return lazy_init
- def config_flag(name: str) -> Callable[[Match], Any]:
- """Function for extra_check to put pass behind a flag"""
- def flag_check(match: Match) -> Any:
- return getattr(config, name)
- return flag_check
- def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
- class CopyGraph(Transformer):
- def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
- new_node = super().run_node(old_node)
- if isinstance(new_node, torch.fx.Proxy):
- new_node.node.meta.update(old_node.meta)
- new_node.node.name = self.new_graph._graph_namespace.create_name(
- old_node.name, None
- )
- return new_node
- return CopyGraph(input_graph).transform()
- _seen_patterns: Set[str] = set()
- def get_arg_value(
- node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
- ) -> Any:
- return (
- node.args[arg_number]
- if len(node.args) > arg_number
- else node.kwargs.get(kwarg_name) # type: ignore[arg-type]
- )
- def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
- fns = [fn]
- if isinstance(fn, torch._ops.OpOverloadPacket):
- fns.extend([getattr(fn, overload) for overload in fn.overloads()])
- return [node for node in nodes if node.target in fns]
- def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
- """For call_function and call_method, we directly use the target function;
- For call_module, the target is string, and we treat the module class
- as a function.
- """
- if node.op == "call_module":
- return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type]
- return node.target
|