pattern_matcher.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969
"""
# Inductor Pattern Matcher

The pattern matcher enables search/replace within an FX graph.

The main entrypoint to the pattern matcher is register_replacement(). Given a
search function and a replacement function, this will register a replacement
with a pass (such as torch._inductor.fx_passes.joint_graph.patterns).

Internally the pattern matcher represents patterns as a graph (a DAG). Creating
new patterns manually as a graph is cumbersome and error-prone, so the standard
way to create patterns (using register_replacement()) is to provide a search
function and a replacement function, each of which is traced and converted into
a graph.

Because the search functions are built to be somewhat generic (they tend to
ignore tensor sizes, for example), register_replacement() allows you to specify
an `extra_check` function which performs additional checks to verify that the
matched pattern fully matches before returning it.

## Precompiled Patterns

New patterns are added using register_replacement(). Patterns added in this way
can have a compile-time overhead because they need to be traced before
use. Patterns can be precompiled and added using gen_register_replacement()
instead. To do this, call gen_register_replacement() instead of
register_replacement(). The arguments are the same except for an additional
unique name which is used as a lookup key.

## Internals

The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr
implements a `_match` method which returns either a `Match` object for a
successful match or a `FailedMatch` object for a failure to match.
"""
  27. # mypy: disallow-untyped-defs
  28. from __future__ import annotations
  29. import contextlib
  30. import dataclasses
  31. import functools
  32. import importlib
  33. import inspect
  34. import itertools
  35. import logging
  36. import operator
  37. import os
  38. import re
  39. import textwrap
  40. import typing
  41. from abc import ABC, abstractmethod
  42. from collections import defaultdict
  43. from pathlib import Path
  44. from typing import (
  45. Any,
  46. Callable,
  47. DefaultDict,
  48. Dict,
  49. Generator,
  50. Iterable,
  51. List,
  52. Mapping,
  53. NoReturn,
  54. Optional,
  55. Protocol,
  56. Sequence,
  57. Set,
  58. Tuple,
  59. Type,
  60. TypeVar,
  61. Union,
  62. )
  63. from typing_extensions import Self, TypeGuard
  64. import torch
  65. import torch._guards
  66. import torch.fx
  67. import torch.utils._pytree as pytree
  68. from torch._dispatch.python import enable_python_dispatcher
  69. from torch._dynamo.utils import counters
  70. from torch._prims_common import is_integer_dtype
  71. from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
  72. from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
  73. from torch.fx.immutable_collections import immutable_dict, immutable_list
  74. from .._functorch import config as functorch_config
  75. from .._functorch.aot_autograd import aot_function, make_boxed_func
  76. from .._functorch.partitioners import default_partition
  77. from .._subclasses import FakeTensorMode
  78. from ..fx import Transformer
  79. from . import config
  80. from .decomposition import select_decomp_table
  81. from .lowering import fallback_node_due_to_unsupported_type
  82. log = logging.getLogger(__name__)
  83. aten = torch.ops.aten
  84. prims = torch.ops.prims
  85. Constant = Any
  86. NodeOrConstant = Union[Constant, torch.fx.Node]
  87. class SearchFn(Protocol):
  88. __name__: str
  89. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  90. ...
  91. class ReplaceFn(Protocol):
  92. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  93. ...
  94. class TraceFn(Protocol):
  95. def __call__(
  96. self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
  97. ) -> torch.fx.GraphModule:
  98. ...
  99. T = TypeVar("T")
  100. # What's a better name for this?
  101. FnsType = Union[torch.fx.node.Target, str]
  102. class Multiple:
  103. def __init__(self) -> None:
  104. # Ensure we're really a singleton.
  105. assert "MULTIPLE" not in globals() or self is MULTIPLE
  106. # Sentinel indicating multiple quantities can be matched
  107. MULTIPLE = Multiple()
  108. class Match:
  109. """
  110. Represents a successfully matched pattern.
  111. The `Match` object is returned to represent a successfully matched
  112. pattern. Included in the Match are the pattern that was matched, the graph
  113. nodes matched, and any args that were used during the matching.
  114. The args and kwargs are specific to the type of pattern that was matched and
  115. provide hints about what was matched.
  116. """
  117. pattern: PatternExpr
  118. args: List[Any]
  119. kwargs: Dict[str, Any]
  120. nodes: List[torch.fx.Node]
  121. targets: Dict[_TargetExpr, torch.fx.node.Target]
  122. ctx: MatchContext
  123. replacement_graph: Optional[torch.fx.Graph]
  124. def __init__(
  125. self,
  126. ctx: MatchContext,
  127. pattern: PatternExpr,
  128. args: Optional[Sequence[Any]] = None,
  129. kwargs: Optional[Dict[str, Any]] = None,
  130. ) -> None:
  131. super().__init__()
  132. self.pattern = pattern
  133. # The input nodes that must be passed in to the result
  134. self.args = list(args or [])
  135. self.kwargs = kwargs or {}
  136. # The nodes matched in this expression
  137. self.nodes = []
  138. # Mapping CallFunction to the node.target
  139. self.targets = {}
  140. self.ctx = ctx
  141. self.replacement_graph = None
  142. @property
  143. def graph(self) -> torch.fx.Graph:
  144. return self.ctx.graph
  145. def extend(self, other: Match) -> None:
  146. if self.kwargs:
  147. for key in set(self.kwargs.keys()) & set(other.kwargs.keys()):
  148. if self.kwargs[key] != other.kwargs[key]:
  149. raise FailedMatch("kwarg mismatch: {}", key)
  150. self.args.extend(other.args)
  151. self.nodes.extend(other.nodes)
  152. self.kwargs.update(other.kwargs)
  153. self.targets.update(other.targets)
  154. def bundle(self) -> Match:
  155. # Wrap args in an extra list
  156. self.args = [tuple(self.args)] if self.args else []
  157. return self
  158. def __repr__(self) -> str:
  159. return f"Match(..., {self.args}, {self.kwargs})"
  160. def erase_nodes(self, graph: torch.fx.Graph) -> None:
  161. for n in reversed(self.nodes):
  162. if not n._erased:
  163. graph.erase_node(n)
  164. def output_nodes(self) -> List[Optional[torch.fx.Node]]:
  165. return [
  166. (self.ctx.pattern_to_node[p] if p is not None else None)
  167. for p in self.ctx.outputs
  168. ]
  169. def output_node(self) -> torch.fx.Node:
  170. return next(p for p in self.output_nodes() if p)
  171. def replace_with_graph(
  172. self, replacement_graph: torch.fx.Graph, args: Sequence[Any]
  173. ) -> None:
  174. ReplacementPatternEntry.replace_with_graph(
  175. self, self.ctx.graph, replacement_graph, args
  176. )
  177. def replace_by_example(
  178. self,
  179. replacement_fn: ReplaceFn,
  180. args: Sequence[Any],
  181. trace_fn: Optional[TraceFn] = None,
  182. run_dce: bool = True,
  183. ) -> None:
  184. from torch._inductor.virtualized import V
  185. context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
  186. with context:
  187. if trace_fn is None:
  188. trace_fn = functools.partial(fwd_only, run_dce=run_dce)
  189. replacement = trace_fn(
  190. replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
  191. )
  192. ReplacementPatternEntry.replace_with_graph(
  193. self,
  194. self.ctx.graph,
  195. replacement,
  196. args,
  197. )
  198. class FailedMatch(RuntimeError):
  199. """
  200. Represents a unsuccessful match.
  201. The `FailedMatch` object is returned to represent a failure to match a
  202. pattern.
  203. """
  204. format_string: str
  205. def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
  206. self.format_string = format_string
  207. # We want to construct error messages lazily instead of eagerly, as
  208. # constructing them eagerly can significantly worsen compile times.
  209. if len(format_string) > 200:
  210. raise RuntimeError(
  211. f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
  212. )
  213. self.args = args
  214. self.kwargs = kwargs
  215. def __str__(self) -> str:
  216. return self.format_string.format(*self.args, **self.kwargs)
  217. def __bool__(self) -> bool:
  218. return False
  219. MatchResult = Union[Match, FailedMatch]
  220. def is_match(m: MatchResult) -> TypeGuard[Match]:
  221. """
  222. TypeGuards cannot act on `self`. Thus this function exists to let mypy
  223. recognize FailedMatch.__bool__ as a TypeGuard.
  224. """
  225. return bool(m)
  226. class MatchContext:
  227. """
  228. Internal state needed while running PatternExpr._match().
  229. """
  230. outputs: List[Optional[PatternExpr]]
  231. pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
  232. graph: torch.fx.Graph
  233. exclusive_node_set: List[NodeOrConstant]
  234. def __init__(
  235. self,
  236. outputs: List[Optional[PatternExpr]],
  237. pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
  238. *,
  239. graph: torch.fx.Graph,
  240. ) -> None:
  241. self.outputs = outputs
  242. self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node)
  243. self.graph = graph
  244. self.exclusive_node_set = []
  245. def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult:
  246. """wrapper to check reused nodes in patterns"""
  247. if pattern in self.pattern_to_node:
  248. if self.pattern_to_node[pattern] == node:
  249. return Match(self, pattern) # already checked this node
  250. else:
  251. return FailedMatch("repeated pattern differs")
  252. m = pattern._match(node, self)
  253. assert pattern not in self.pattern_to_node
  254. self.pattern_to_node[pattern] = node if m else None
  255. return m
  256. def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
  257. return {
  258. pattern: node
  259. for pattern, node in self.pattern_to_node.items()
  260. if pattern.has_multiple_users() and node is not None
  261. }
  262. class PatternExpr(ABC):
  263. """
  264. Base class for types of patterns.
  265. """
  266. @abstractmethod
  267. def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
  268. ...
  269. def match(self, node: torch.fx.Node) -> MatchResult:
  270. try:
  271. return MatchContext([self], graph=node.graph).match(self, node)
  272. except FailedMatch as e:
  273. return e
  274. def has_multiple_users(self) -> bool:
  275. return False
  276. def __repr__(self) -> str:
  277. return self.__class__.__name__ + "()"
  278. def find_anchor_nodes(
  279. self, ctx: MatchContext, searched: Set[torch.fx.Node]
  280. ) -> Generator[Optional[torch.fx.Node], None, None]:
  281. if self in ctx.pattern_to_node:
  282. yield ctx.pattern_to_node[self]
  283. def pattern_eq(self, other: Any) -> bool:
  284. """
  285. Compare two `PatternExpr`s and return true if they are the
  286. same. Note this is NOT matching a pattern - it is comparing the pattern
  287. structures (for debugging).
  288. """
  289. return isinstance(other, self.__class__)
  290. class Arg(PatternExpr):
  291. """
  292. Capture an arg which will become an input to the handler. Args are
  293. passed in depth first order.
  294. """
  295. def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
  296. return Match(ctx, self, args=[node]) # matches anything
  297. class Ignored(PatternExpr):
  298. """
  299. Match an arg, but don't pass it to handler
  300. """
  301. def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
  302. return Match(ctx, self) # matches anything
  303. def __repr__(self) -> str:
  304. return "*"
  305. def pretty_print(self, pp: PatternPrettyPrinter) -> str:
  306. return "Ignored()"
  307. class KeywordArg(PatternExpr):
  308. """
  309. Capture a kwarg which will become an input to the handler.
  310. """
  311. def __init__(self, name: str) -> None:
  312. super().__init__()
  313. self.name = name
  314. def __repr__(self) -> str:
  315. return f"KeywordArg({self.name!r})"
  316. def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
  317. return Match(ctx, self, kwargs={self.name: node}) # matches anything
  318. def pattern_eq(self, other: Any) -> bool:
  319. other = typing.cast(Self, other) # super makes sure this is true
  320. return super().pattern_eq(other) and self.name == other.name
  321. class ExclusiveKeywordArg(PatternExpr):
  322. """
  323. Capture a kwarg which will become an input to the handler.
  324. """
  325. name: str
  326. def __init__(self, name: str) -> None:
  327. super().__init__()
  328. self.name = name
  329. def __repr__(self) -> str:
  330. return f"ExclusiveKeywordArg({self.name!r})"
  331. def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult:
  332. if node in ctx.exclusive_node_set:
  333. return FailedMatch("exclusive arg appears twice")
  334. ctx.exclusive_node_set.append(node)
  335. return Match(ctx, self, kwargs={self.name: node}) # matches anything
  336. def pattern_eq(self, other: Any) -> bool:
  337. other = typing.cast(Self, other) # super makes sure this is true
  338. return super().pattern_eq(other) and self.name == other.name
  339. class _TargetExpr(PatternExpr):
  340. """
  341. Base class for filtering match by node.target
  342. """
  343. fns: List[FnsType]
  344. fns_set: Set[FnsType]
  345. def __init__(
  346. self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
  347. ) -> None:
  348. super().__init__()
  349. fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
  350. for fn in fns:
  351. if isinstance(fn, torch._ops.OpOverloadPacket):
  352. fns.extend(getattr(fn, overload) for overload in fn.overloads())
  353. self.fns = fns
  354. self.fns_set = set(fns)
  355. self.users = users
  356. @property
  357. @abstractmethod
  358. def op(self) -> str:
  359. ...
  360. def fns_repr(self) -> str:
  361. first_repr = self.fns[0]
  362. if not isinstance(first_repr, str):
  363. first_repr = first_repr.__name__
  364. if len(self.fns) > 1:
  365. return f"[{first_repr}, ...]"
  366. elif self.fns[0] is getattr(torch, first_repr, None):
  367. return f"torch.{first_repr}"
  368. elif isinstance(self.fns[0], torch._ops.OpOverload):
  369. return str(self.fns[0])
  370. else:
  371. return first_repr
  372. def __repr__(self) -> str:
  373. if self.users is MULTIPLE:
  374. comma_users = ", MULTIPLE"
  375. elif self.users != 1:
  376. comma_users = f", {self.users})"
  377. else:
  378. comma_users = ""
  379. return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"
  380. def has_multiple_users(self) -> bool:
  381. return isinstance(self.users, Multiple) or self.users > 1
  382. def find_anchor_nodes(
  383. self, ctx: MatchContext, searched: Set[torch.fx.Node]
  384. ) -> Generator[Optional[torch.fx.Node], None, None]:
  385. raise NotImplementedError
  386. def _match_fns(self, node: torch.fx.Node) -> bool:
  387. return (
  388. isinstance(node, torch.fx.Node)
  389. and node.op == self.op
  390. and extract_target(node) in self.fns_set
  391. )
  392. def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
  393. return (
  394. self in ctx.outputs
  395. or self.users is MULTIPLE
  396. or len(node.users) == self.users
  397. )
  398. def pattern_eq(self, other: Any) -> bool:
  399. other = typing.cast(Self, other) # super makes sure this is true
  400. return (
  401. super().pattern_eq(other)
  402. and self.op == other.op
  403. and self.fns == other.fns
  404. and self.users == other.users
  405. )
  406. _SimpleSpec = Tuple[Any, ...]
  407. class _TargetArgsExpr(_TargetExpr):
  408. """
  409. Base class for filtering match by node.{target,args,kwargs}
  410. """
  411. def __init__(
  412. self,
  413. fns: Union[torch.fx.node.Target, str, Sequence[Any]],
  414. *args: Any,
  415. _users: Union[int, Multiple] = 1,
  416. **kwargs: Any,
  417. ) -> None:
  418. super().__init__(fns, _users)
  419. self.args = tuple(args)
  420. self.kwargs = dict(kwargs)
  421. if any(
  422. isinstance(x, (dict, list, tuple))
  423. for x in itertools.chain(args, kwargs.values())
  424. ):
  425. self.flatten = self.pytree_flatten
  426. else:
  427. self.flatten = self.simple_flatten
  428. self.flat_args_kwargs = self.flatten(self.args, self.kwargs)
  429. @staticmethod
  430. def simple_flatten(
  431. args: Sequence[Any], kwargs: Mapping[Any, Any]
  432. ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
  433. values = (*args, *kwargs.values())
  434. spec = (len(args), *kwargs.keys())
  435. return values, spec
  436. @staticmethod
  437. def pytree_flatten(
  438. args: Sequence[Any], kwargs: Mapping[Any, Any]
  439. ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
  440. def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
  441. if s.type is None:
  442. return s
  443. mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
  444. return pytree.TreeSpec(
  445. mapping.get(s.type, s.type),
  446. s.context,
  447. list(map(norm_spec, s.children_specs)),
  448. )
  449. flat, spec = pytree.tree_flatten([args, kwargs])
  450. spec = norm_spec(spec)
  451. return flat, spec
  452. def __repr__(self) -> str:
  453. args = [
  454. self.fns_repr(),
  455. *map(repr, self.args),
  456. *[f"{k}={v}" for k, v in self.kwargs.items()],
  457. ]
  458. if self.users is MULTIPLE:
  459. args.append("_users=MULTIPLE")
  460. elif self.users != 1:
  461. args.append(f"_users={self.users}")
  462. return f"{self.__class__.__name__}({', '.join(args)})"
  463. def pretty_print(self, pp: PatternPrettyPrinter) -> str:
  464. args = [
  465. self.fns_repr(),
  466. *(pp.pretty_print(x) for x in self.args),
  467. *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
  468. ]
  469. if self.users is MULTIPLE:
  470. args.append("_users=MULTIPLE")
  471. elif self.users != 1:
  472. args.append(f"_users={self.users}")
  473. joiner_str = ", "
  474. return f"{self.__class__.__name__}({joiner_str.join(args)})"
  475. def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
  476. if not self._match_fns(node) or len(node.args) != len(self.args):
  477. return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
  478. if not self._match_users(node, ctx):
  479. return FailedMatch("multiple_users {}", self)
  480. _args = node.args
  481. _kwargs = node.kwargs
  482. if len(_kwargs) < len(self.kwargs):
  483. from torch.fx.operator_schemas import normalize_function
  484. normalized_args_and_kwargs = normalize_function(
  485. node.target, node.args, node.kwargs
  486. )
  487. if normalized_args_and_kwargs is None:
  488. return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
  489. else:
  490. _args, _kwargs = normalized_args_and_kwargs
  491. if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
  492. _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
  493. else:
  494. return FailedMatch(
  495. "function_mismatch: node={}, pattern={}", node, self
  496. )
  497. else:
  498. _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
  499. node_items, node_spec = self.flatten(_args, _kwargs)
  500. self_items, self_spec = self.flat_args_kwargs
  501. if node_spec != self_spec:
  502. return FailedMatch("args_structure {} {}", node_spec, self_spec)
  503. assert len(node_items) == len(self_items)
  504. m = Match(ctx, self)
  505. for i, pattern, child_node in zip(itertools.count(), self_items, node_items):
  506. if isinstance(pattern, PatternExpr):
  507. child_match = ctx.match(pattern, child_node)
  508. if not is_match(child_match):
  509. return child_match
  510. m.extend(child_match)
  511. elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
  512. return FailedMatch(
  513. "constant_args: {} {!r}!={pattern!r}", node, child_node
  514. )
  515. m.nodes.append(node)
  516. m.targets[self] = node.target
  517. return m
  518. def find_anchor_nodes(
  519. self, ctx: MatchContext, searched: Set[torch.fx.Node]
  520. ) -> Generator[Optional[torch.fx.Node], None, None]:
  521. """
  522. This is used when we are matching a pattern with multiple outputs.
  523. There is a partial match (stored in ctx) and we want to walk
  524. this pattern to find a connection to an already-matched node.
  525. Yields candidate nodes that `self._match` might like.
  526. """
  527. if self in ctx.pattern_to_node:
  528. yield ctx.pattern_to_node[self]
  529. return
  530. for pattern in self.flat_args_kwargs[0]:
  531. if isinstance(pattern, PatternExpr):
  532. for other_node in pattern.find_anchor_nodes(ctx, searched):
  533. if not isinstance(other_node, torch.fx.Node):
  534. continue
  535. for node in other_node.users:
  536. if node not in searched:
  537. if self._match_fns(node):
  538. yield node
  539. searched.add(node)
  540. def pattern_eq(self, other: Any) -> bool:
  541. other = typing.cast(Self, other) # super makes sure this is true
  542. return (
  543. super().pattern_eq(other)
  544. and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
  545. and all(
  546. a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
  547. for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
  548. )
  549. )
  550. class CallFunction(_TargetArgsExpr):
  551. """
  552. Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)`
  553. """
  554. op = "call_function"
  555. class CallMethod(_TargetArgsExpr):
  556. """
  557. Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)`
  558. """
  559. op = "call_method"
  560. class CallModule(_TargetArgsExpr):
  561. """
  562. Matches a call_module node in the FX graphs: `module(*args, **kwargs)`
  563. """
  564. op = "call_module"
  565. class _TargetExprVarArgs(_TargetExpr):
  566. """
  567. Matches a call_function node with any arguments which are passed into the pattern
  568. """
  569. def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
  570. if not self._match_fns(node):
  571. return FailedMatch("function_mismatch")
  572. if not self._match_users(node, ctx):
  573. return FailedMatch("multiple_users")
  574. m = Match(ctx, self)
  575. m.nodes.append(node)
  576. m.targets[self] = node.target
  577. m.args.extend(node.args)
  578. m.kwargs.update(node.kwargs)
  579. return m
  580. class CallFunctionVarArgs(_TargetExprVarArgs):
  581. op = "call_function"
  582. class CallMethodVarArgs(_TargetExprVarArgs):
  583. op = "call_method"
  584. class CallModuleVarArgs(_TargetExprVarArgs):
  585. op = "call_module"
  586. class ListOf(PatternExpr):
  587. """
  588. Matches a repeated pattern
  589. """
  590. def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
  591. super().__init__()
  592. assert isinstance(pattern, PatternExpr)
  593. self.pattern = pattern
  594. self.partial = partial
  595. def __repr__(self) -> str:
  596. return f"{self.__class__.__name__}({self.pattern})"
  597. def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
  598. if not isinstance(node, (list, tuple)) or len(node) == 0:
  599. return FailedMatch("non_list")
  600. m = Match(ctx, self)
  601. # Propagating patterns with multiple users will ensure we don't revisit
  602. # the same nodes
  603. pattern_to_node = ctx.filter_multi_user_patterns()
  604. matched = False
  605. for i, child_node in enumerate(node):
  606. child_ctx = MatchContext(
  607. ctx.outputs, pattern_to_node, graph=child_node.graph
  608. )
  609. child_match = child_ctx.match(self.pattern, child_node)
  610. pattern_to_node = child_ctx.filter_multi_user_patterns()
  611. if not is_match(child_match):
  612. if not self.partial:
  613. return FailedMatch("list[{}]: {}", i, child_match)
  614. continue
  615. matched = True
  616. m.extend(child_match.bundle())
  617. if not matched:
  618. return FailedMatch("list: no_match")
  619. return m.bundle()
  620. def pattern_eq(self, other: Any) -> bool:
  621. other = typing.cast(Self, other) # super makes sure this is true
  622. return (
  623. super().pattern_eq(other)
  624. and self.pattern.pattern_eq(other.pattern)
  625. and self.partial == other.partial
  626. )
  627. class MultiOutputPattern(PatternExpr):
  628. outputs: List[Optional[PatternExpr]]
  629. def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
  630. super().__init__()
  631. assert isinstance(outputs[0], _TargetExpr)
  632. assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs
  633. self.outputs = list(outputs)
  634. self.op = outputs[0].op
  635. @property
  636. def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]:
  637. # This cast is checked above in __init__()
  638. output = typing.cast(_TargetExpr, self.outputs[0])
  639. return output.fns
  640. def __repr__(self) -> str:
  641. return f"{self.__class__.__name__}({self.outputs})"
  642. def pretty_print(self, pp: PatternPrettyPrinter) -> str:
  643. args = [pp.pretty_print(x) for x in self.outputs]
  644. joiner_str = f",\n{' '}"
  645. str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
  646. str_out = f"{str_out}\n])"
  647. return str_out
  648. def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
  649. output = typing.cast(_TargetExpr, self.outputs[0])
  650. m = ctx.match(output, node)
  651. if not is_match(m):
  652. return m
  653. for pattern in self.outputs[1:]:
  654. if pattern is None:
  655. continue
  656. child_match = self._match_from_anchors(pattern, ctx)
  657. if not is_match(child_match):
  658. return child_match
  659. m.extend(child_match)
  660. return m
  661. def _match_from_anchors(
  662. self, pattern: PatternExpr, ctx: MatchContext
  663. ) -> MatchResult:
  664. prior = dict(ctx.pattern_to_node)
  665. m: MatchResult = FailedMatch("no anchor found")
  666. for node in pattern.find_anchor_nodes(ctx, set()):
  667. m = ctx.match(pattern, node)
  668. if is_match(m):
  669. return m
  670. # revert any partial matches
  671. ctx.pattern_to_node = dict(prior)
  672. return m
  673. def match(self, node: torch.fx.Node) -> MatchResult:
  674. try:
  675. return MatchContext(self.outputs, graph=node.graph).match(self, node)
  676. except FailedMatch as e:
  677. return e
  678. def pattern_eq(self, other: Any) -> bool:
  679. other = typing.cast(Self, other) # super makes sure this is true
  680. return (
  681. super().pattern_eq(other)
  682. and len(self.outputs) == len(other.outputs)
  683. and all(
  684. a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
  685. for a, b in zip(self.outputs, other.outputs)
  686. )
  687. )
  688. class RepeatedExpr(PatternExpr):
  689. """
  690. Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind`
  691. """
  692. def __init__(self, inner_pattern: _TargetExpr) -> None:
  693. super().__init__()
  694. self.inner_pattern = inner_pattern
  695. self.op = inner_pattern.op
  696. @property
  697. def fns(self) -> Sequence[FnsType]:
  698. return self.inner_pattern.fns
  699. def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
  700. m = ctx.match(self.inner_pattern, node)
  701. if not is_match(m):
  702. return m
  703. ctx.pattern_to_node.pop(
  704. self.inner_pattern,
  705. )
  706. # Check all anchor nodes match the pattern
  707. for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()):
  708. anchor_m = MatchContext([self], graph=node.graph).match(
  709. self.inner_pattern, anchor_node
  710. )
  711. if not is_match(anchor_m):
  712. return anchor_m
  713. m.extend(anchor_m)
  714. return m
  715. def pattern_eq(self, other: Any) -> bool:
  716. other = typing.cast(Self, other) # super makes sure this is true
  717. return super().pattern_eq(other) and self.inner_pattern.pattern_eq(
  718. other.inner_pattern
  719. )
  720. class PatternPrettyPrinter:
  721. """
  722. Serializes Patterns to executable python.
  723. XXX: currently only used and tested for fuse attention patterns. May not cover
  724. all patterns.
  725. """
  726. def __init__(self) -> None:
  727. self.namespace = torch.fx.graph._Namespace()
  728. self.memoized_objs_names: Dict[PatternExpr, str] = {}
  729. self.memoized_objs_pp: Dict[PatternExpr, str] = {}
  730. @staticmethod
  731. @functools.lru_cache(None)
  732. def run(obj: PatternExpr, output_name: str = "output") -> str:
  733. """
  734. Serializes obj to python code with obj written out to `output_name`
  735. """
  736. pp = PatternPrettyPrinter()
  737. assert hasattr(obj, "pretty_print")
  738. out_str = obj.pretty_print(pp=pp)
  739. output = []
  740. for key in pp.memoized_objs_names:
  741. output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
  742. output.append(f"{output_name} = {out_str}")
  743. return "\n".join(output)
  744. def pretty_print(self, obj: Any) -> str:
  745. if isinstance(obj, _TargetArgsExpr):
  746. if memoized_name := self.memoized_objs_names.get(obj):
  747. return memoized_name
  748. else:
  749. return self.memoize(obj)
  750. if hasattr(obj, "pretty_print"):
  751. return obj.pretty_print(self)
  752. return repr(obj)
  753. def memoize(self, obj: _TargetArgsExpr) -> str:
  754. obj_str = obj.pretty_print(self)
  755. obj_name = obj.fns_repr()
  756. for prefix in ("aten.", "torch.", "prims."):
  757. obj_name = obj_name.replace(prefix, "")
  758. tmp_name = self.namespace.create_name(obj_name, None)
  759. self.memoized_objs_names[obj] = tmp_name
  760. self.memoized_objs_pp[obj] = obj_str
  761. return tmp_name
class _PassDictsType(Protocol):
    """Structural type for a pass dict: indexing by ``(op, target)`` yields the list
    of PatternEntry objects registered for that key (e.g. a dict or a
    PatternMatcherPass)."""

    def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        ...
  765. @dataclasses.dataclass
  766. class PatternEntry:
  767. pattern: PatternExpr
  768. extra_check: Callable[[Match], bool]
  769. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
  770. raise NotImplementedError
  771. def register(
  772. self,
  773. pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
  774. target: Union[torch.fx.node.Target, None] = None,
  775. prepend: bool = False,
  776. ) -> None:
  777. if target is None:
  778. assert hasattr(self.pattern, "fns")
  779. for fn in self.pattern.fns:
  780. self.register(pass_dicts, fn, prepend=prepend)
  781. elif isinstance(pass_dicts, (dict, PatternMatcherPass)):
  782. assert hasattr(self.pattern, "op")
  783. if prepend:
  784. pass_dicts[(self.pattern.op, target)].insert(0, self)
  785. else:
  786. pass_dicts[(self.pattern.op, target)].append(self)
  787. else:
  788. pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts)
  789. for x in pass_dicts:
  790. self.register(x, target, prepend=prepend)
  791. @dataclasses.dataclass
  792. class LoweringPatternEntry(PatternEntry):
  793. handler: Callable[..., Any]
  794. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
  795. handler = functools.wraps(self.handler)(functools.partial(self.handler, match))
  796. with graph.inserting_before(node):
  797. replacement = graph.call_function(handler, tuple(match.args), match.kwargs)
  798. replacement.meta.update(node.meta)
  799. node.replace_all_uses_with(replacement)
  800. assert match.nodes[-1] is node
  801. match.erase_nodes(graph)
  802. @dataclasses.dataclass
  803. class GraphPatternEntry(PatternEntry):
  804. """
  805. A pattern that runs a function on the FX graph
  806. """
  807. handler: Callable[..., Any]
  808. def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
  809. with graph.inserting_before(node):
  810. self.handler(match, *match.args, **match.kwargs)
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
    """
    A pattern entry whose match is rewritten by inlining a replacement graph.

    ``apply`` requires ``match.replacement_graph`` to be set already (the
    ``check_fn`` built by ``register_replacement`` sets it). The replacement's
    inputs are ``normalize_args(*match.args, **match.kwargs)``.
    """

    # Maps the match's args/kwargs to the positional inputs of the replacement graph.
    normalize_args: Callable[..., List[Any]]

    @staticmethod
    def replace_with_graph(
        match: Match,
        graph: torch.fx.Graph,
        replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule],
        args: Sequence[torch.fx.Node],
    ) -> None:
        """
        Inline ``replacement_graph`` into ``graph`` in place of the matched nodes.

        Args:
            match: successful match; its output nodes are rewired to the
                replacement outputs and its nodes are erased at the end.
            graph: graph being rewritten.
            replacement_graph: graph whose call_function nodes are re-emitted into
                ``graph`` (placeholder/output/call_function nodes only).
            args: nodes of ``graph`` fed to the replacement's placeholders.
        """
        output_nodes = match.output_nodes()
        # NOTE(review): first_node is never used below.
        first_node = output_nodes[0]

        class Replacer(torch.fx.Interpreter):
            # run_node below only handles placeholder/output/call_function and
            # raises NotImplementedError for anything else.
            call_method = None # type: ignore[assignment]
            call_module = None # type: ignore[assignment]
            get_attr = None # type: ignore[assignment]

            def run_node(self, node: torch.fx.Node) -> Any:
                if node.op in ("placeholder", "output"):
                    return super().run_node(node)
                if node.op == "call_function":
                    target = node.target
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    # Re-emit the call into the target graph at its current
                    # insertion point.
                    result = graph.call_function(target, args, kwargs)
                    # Copy "val" (and, for tensors, "tensor_meta") across when the
                    # new node does not already have one.
                    if "val" in node.meta and "val" not in result.meta:
                        result.meta["val"] = node.meta["val"]
                        if isinstance(node.meta["val"], torch.Tensor):
                            assert "tensor_meta" in node.meta
                            result.meta["tensor_meta"] = node.meta["tensor_meta"]
                    return result
                raise NotImplementedError(f"unhandled {node}")

        # NOTE(review): recomputes the same value as above.
        output_nodes = match.output_nodes()

        if len(output_nodes) == 1:
            last_node = output_nodes[0]
        else:
            assert output_nodes[0]
            nodes = list(output_nodes[0].graph.nodes)
            indices = [
                (nodes.index(n), n)
                for n in output_nodes
                if isinstance(n, torch.fx.Node)
            ]
            # NOTE(review): despite its name, this picks the output node that
            # appears EARLIEST in the graph, so the replacement is inserted before
            # every matched output.
            last_node = min(indices, key=operator.itemgetter(0))[1]

        def percolate_tags(
            node: torch.fx.Node, recompute_tag: str, input_stops: Set[torch.fx.Node]
        ) -> None:
            """Set meta["recompute"] = recompute_tag on ``node`` and, transitively,
            on its input nodes, without crossing any node in ``input_stops``."""
            queue = [node]
            visited = set()
            while queue:
                arg = queue.pop()
                if (
                    arg not in visited
                    and arg not in input_stops
                    and hasattr(arg, "meta")
                ):
                    visited.add(arg)
                    arg.meta["recompute"] = recompute_tag
                    queue.extend(arg.all_input_nodes)

        with graph.inserting_before(last_node):
            replacement = Replacer(replacement_graph).run(*args)
            # Normalize a single-output replacement to a list of outputs.
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]

            def maybe_getitem(node: torch.fx.Node) -> Any:
                """Return the index of an ``operator.getitem`` call, else None."""
                if node.op != "call_function":
                    return None
                if node.target != operator.getitem:
                    return None
                assert len(node.args) == 2
                return node.args[1]

            def replace(
                old: Union[torch.fx.Node, None],
                new: Union[torch.fx.Node, Sequence[torch.fx.Node], None],
            ) -> None:
                """Rewire uses of ``old`` to ``new`` (a node, a list of nodes, or
                None) and erase ``old`` from the graph."""
                if old is None:
                    assert new is None
                    return
                assert isinstance(old, torch.fx.Node)
                if new is None:
                    old.replace_all_uses_with(None)
                    graph.erase_node(old)
                    return
                if isinstance(new, torch.fx.Node):
                    if "val" not in new.meta:
                        new.meta.update(old.meta)
                    # Preserve the recompute tags in the replacement graph. We
                    # look at the recompute tags of the original output node to
                    # propagate the tag from the output all the way to the input
                    # args (named as args in the replace_with_graph).
                    # Note that this is best effort. Since patterns are from
                    # many to many, there is no easy way to correctly map the
                    # recomputable tags. It is possible in some scenarios that we
                    # incorrectly tag some nodes as recomputables.
                    if "recompute" in old.meta:
                        percolate_tags(new, old.meta["recompute"], set(args))
                    old.replace_all_uses_with(new)
                    graph.erase_node(old)
                    return
                new = typing.cast(Sequence[torch.fx.Node], new)
                # `new` is not a node: it's a list of nodes.
                #
                # This happens when we want to replace a node that has a single
                # packed return with multiple unpacked returns. We need to do
                # some graph surgery here.
                #
                # Example:
                # def original_graph(x):
                # a = op(x)
                # b = a[0]
                # c = a[1]
                # ...
                #
                # Assume that we want to replace op(x) with the graph
                # def new_op(x):
                # w = x + 1
                # z = x + 2
                # return (w, z)
                #
                # We need to replace `op` with the contents of `new_op`,
                # and then rewrite a[0] to be w and a[1] to be z, as so:
                # def new_graph(x):
                # w = x + 1
                # z = x + 2
                # b = w
                # c = z
                # ...
                old_uses = list(old.users.keys())
                for user in old_uses:
                    idx = maybe_getitem(user)
                    # Every user of a packed output must be a getitem.
                    if idx is None:
                        raise AssertionError("can't handle")
                    replace(user, new[idx])
                graph.erase_node(old)

            if len(output_nodes) == len(replacement):
                for old, new in zip(output_nodes, replacement):
                    replace(old, new)
            else:
                # One packed output is replaced by several unpacked outputs.
                assert len(output_nodes) == 1
                replace(output_nodes[0], replacement)

        match.erase_nodes(graph)

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        # replacement_graph is populated by the extra_check (check_fn) that
        # register_replacement installs.
        assert match.replacement_graph is not None
        self.replace_with_graph(
            match,
            graph,
            match.replacement_graph,
            self.normalize_args(*match.args, **match.kwargs),
        )
  957. def _return_true(match: Match) -> bool:
  958. return True
  959. def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
  960. log.info(
  961. "Replacement pattern %s failed to apply due to shape mismatch: %s",
  962. search_fn.__name__,
  963. e,
  964. )
  965. def register_replacement(
  966. search_fn: SearchFn,
  967. replace_fn: ReplaceFn,
  968. example_inputs: Iterable[Any],
  969. trace_fn: TraceFn,
  970. pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
  971. extra_check: Callable[[Match], bool] = _return_true,
  972. scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
  973. exclusive_arg_names: Sequence[str] = (),
  974. search_fn_pattern: Union[PatternExpr, None] = None,
  975. ) -> bool:
  976. """
  977. Create a replacement rule based on example functions that get traced
  978. to create patterns. This supports both training and inference when
  979. run on a joint forward+backward graph.
  980. Args:
  981. search_fn: traced to give original pattern
  982. replace_fn: traced to give replacement graph
  983. example_inputs: example inputs for initial trace
  984. trace_fn: fwd_only or joint_fwd_bwd
  985. pass_dict: dict of passes to register to
  986. extra_check: additional check to run on match(using real shapes)
  987. """
  988. argnames_static = [*inspect.signature(search_fn).parameters.keys()]
  989. def check_fn(match: Match) -> bool:
  990. """
  991. Often shapes get burned into the pattern, so our initial match ran with
  992. `ignore_types=(int, ...)`.
  993. Recheck the match with the correct shapes.
  994. """
  995. argnames = list(argnames_static)
  996. for name in argnames:
  997. if name not in match.kwargs:
  998. raise RuntimeError(
  999. f"Not all inputs to pattern found in match.kwargs. Perhaps one "
  1000. f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}"
  1001. )
  1002. args = list(
  1003. torch.fx.map_arg(
  1004. [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
  1005. )
  1006. )
  1007. sym_args: List[torch.SymInt] = []
  1008. with torch._dynamo.utils.detect_fake_mode(args):
  1009. for i, grad in enumerate(requires_grad):
  1010. if isinstance(args[i], torch.Tensor):
  1011. if grad and is_integer_dtype(args[i].dtype):
  1012. return False
  1013. args[i] = torch.empty_strided(
  1014. args[i].size(),
  1015. args[i].stride(),
  1016. dtype=args[i].dtype,
  1017. device=args[i].device,
  1018. requires_grad=grad,
  1019. )
  1020. for v in itertools.chain(args[i].shape, args[i].stride()):
  1021. if isinstance(v, torch.SymInt) and all(
  1022. guard_size_oblivious(v != a) for a in sym_args
  1023. ):
  1024. sym_args.append(v)
  1025. # If we were given a pre-traced pattern then use that instead of
  1026. # retracing. Note that this means the pattern has to be independent
  1027. # of its args.
  1028. specific_pattern = search_fn_pattern
  1029. if not specific_pattern:
  1030. if sym_args:
  1031. # AOT Autograd and make fx will dedupe symbolic shape size
  1032. # accesses of sym ints that appear as inputs
  1033. # We don't want the sym_size uses to interfere with pattern matching
  1034. # so we provide them as inputs.
  1035. # Later, when we actually do the replacement, the symbolic shape
  1036. # sizes will get re-traced and added to the graph.
  1037. def search_fn_new(*args_new: Any) -> Any:
  1038. return search_fn(*args_new[len(args_new) - len(args) :])
  1039. try:
  1040. specific_graph = trace_fn(search_fn_new, sym_args + args)
  1041. except RuntimeError as e:
  1042. log_trace_failure(search_fn, e)
  1043. return False
  1044. # correct argnames in the graph
  1045. sym_arg_names = []
  1046. for i, placeholder in zip(
  1047. range(len(sym_args) + len(args)),
  1048. specific_graph.graph.nodes,
  1049. ):
  1050. if i < len(sym_args):
  1051. sym_arg_names.append(placeholder.target)
  1052. continue
  1053. with specific_graph.graph.inserting_after(placeholder):
  1054. new_node = specific_graph.graph.placeholder(
  1055. argnames[i - len(sym_args)]
  1056. )
  1057. new_node.target = new_node.name
  1058. placeholder.replace_all_uses_with(new_node)
  1059. specific_graph.graph.erase_node(placeholder)
  1060. argnames = sym_arg_names + argnames
  1061. else:
  1062. try:
  1063. specific_graph = trace_fn(search_fn, args)
  1064. except RuntimeError as e:
  1065. log_trace_failure(search_fn, e)
  1066. return False
  1067. specific_pattern = fx_to_pattern(
  1068. specific_graph,
  1069. argnames=argnames,
  1070. exclusive_arg_names=exclusive_arg_names,
  1071. scalar_workaround=scalar_workaround,
  1072. )
  1073. node = match.output_nodes()[0]
  1074. assert node is not None
  1075. specific_pattern_match = specific_pattern.match(node)
  1076. if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
  1077. # trace the pattern using the shapes from the user program
  1078. match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment]
  1079. return True
  1080. return False
  1081. def normalize_args(**kwargs: Any) -> List[Any]:
  1082. args = []
  1083. for name in argnames_static:
  1084. args.append(kwargs.pop(name))
  1085. for i in range(1, len(kwargs) + 1):
  1086. if f"tangents_{i}" not in kwargs:
  1087. break
  1088. args.append(kwargs.pop(f"tangents_{i}"))
  1089. assert not kwargs, f"leftover kwargs: {kwargs!r}"
  1090. return args
  1091. if trace_fn is joint_fwd_bwd:
  1092. # If inference mode is enabled during compilation, assume that we don't
  1093. # want to match on any training graph patterns
  1094. if torch.is_inference_mode_enabled():
  1095. return False
  1096. # TODO: Revisit the functionalize_rng_ops for lowmem dropout
  1097. with functorch_config.patch(functionalize_rng_ops=False):
  1098. requires_grad: List[bool] = [
  1099. isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
  1100. ]
  1101. if search_fn_pattern is None:
  1102. pattern = gen_pattern(
  1103. search_fn,
  1104. example_inputs,
  1105. trace_fn,
  1106. scalar_workaround,
  1107. exclusive_arg_names,
  1108. )
  1109. else:
  1110. pattern = search_fn_pattern
  1111. pattern_repr = PatternPrettyPrinter.run(pattern)
  1112. assert pattern_repr not in _seen_patterns
  1113. _seen_patterns.add(pattern_repr)
  1114. pattern = ReplacementPatternEntry(
  1115. pattern=pattern,
  1116. extra_check=check_fn,
  1117. normalize_args=normalize_args,
  1118. )
  1119. pattern.register(pass_dicts)
  1120. return pattern.pattern
# Names (search_fn.__name__) of serialized-pattern files that _serialize_pattern
# has already written in this process. The first write truncates the file; later
# writes for the same name append to it.
_serialized_patterns: Set[str] = set()
def _serialize_pattern(
    unique_name: str,
    search_fn: SearchFn,
    example_inputs: Iterable[Any],
    trace_fn: TraceFn,
    scalar_workaround: Union[Dict[str, Union[float, int]], None],
) -> PatternExpr:
    """
    Trace ``search_fn`` into a pattern and write its Python serialization, bound
    to the name ``unique_name``, into
    ``SERIALIZED_PATTERN_PATH / f"{search_fn.__name__}.py"``.

    The first call for a given ``search_fn.__name__`` in this process rewrites the
    file, starting with a generated header. Later calls append to it.

    Returns:
        The traced pattern.

    Raises:
        RuntimeError: if the serialized patterns directory does not exist.
    """

    def get_file_template() -> str:
        """Build the header written at the top of a freshly created pattern file."""
        auto_generated_msg = textwrap.dedent(
            """\
            # This is an auto-generated file. Please do not modify it by hand.
            # To re-generate, run:
            # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
            """
        )
        file_template = textwrap.dedent(
            """\
            # mypy: ignore-errors
            # noqa: F401, E501
            {msg}
            import torch
            import torch._inductor
            aten = torch.ops.aten
            prims = torch.ops.prims
            """
        ).format(msg=auto_generated_msg)
        # Import every PatternExpr / _TargetExpr subclass found in this module so
        # the generated code can reference them by bare name.
        pattern_matcher_imports = []
        for name in dir(torch._inductor.pattern_matcher):
            attr = getattr(torch._inductor.pattern_matcher, name)
            if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
                pattern_matcher_imports.append(name)
        formatted_imports = ",\n ".join(pattern_matcher_imports)
        formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
        return f"{file_template}{formatted_imports}"

    if not SERIALIZED_PATTERN_PATH.is_dir():
        raise RuntimeError(
            f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}"
        )
    pattern_name = search_fn.__name__
    from torch._functorch import config as functorch_config
    with functorch_config.patch(functionalize_rng_ops=False):
        pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)
    serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name)
    # Truncate the file the first time this process writes it; append afterwards.
    if pattern_name not in _serialized_patterns:
        write_mode = "w"
        _serialized_patterns.add(pattern_name)
    else:
        write_mode = "a"
    file_template = get_file_template()
    with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f:
        if write_mode == "w":
            f.write(file_template)
        else:
            f.write("\n\n")
        f.write(serialized_pattern)
        f.write("\n")
    return pattern
# Directory holding the generated pattern modules: one file per search_fn.__name__,
# importable as torch._inductor.fx_passes.serialized_patterns.<name>.
SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns"
# This is the set of serialized patterns that we've registered. Used by
# test_serialized_patterns_up_to_date() to ensure the patterns are up
# to date.
# Each entry is (search_fn, example_inputs, trace_fn, scalar_workaround, pattern),
# as appended by gen_register_replacement.
_known_precompiled_patterns: List[
    Tuple[
        Any,
        Iterable[Any],
        Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule],
        Any,
        PatternExpr,
    ]
] = []
  1192. def gen_register_replacement(
  1193. unique_name: str,
  1194. search_fn: SearchFn,
  1195. replace_fn: ReplaceFn,
  1196. example_inputs: Iterable[Any],
  1197. trace_fn: TraceFn,
  1198. pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
  1199. extra_check: Callable[[Match], bool] = _return_true,
  1200. scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
  1201. exclusive_arg_names: Sequence[str] = (),
  1202. skip_duplicates: bool = False,
  1203. ) -> None:
  1204. # Make sure the example_inputs is materialized.
  1205. example_inputs = tuple(example_inputs)
  1206. if "PYTORCH_GEN_PATTERNS" in os.environ:
  1207. pat = _serialize_pattern(
  1208. unique_name, search_fn, example_inputs, trace_fn, scalar_workaround
  1209. )
  1210. else:
  1211. pattern_name = search_fn.__name__
  1212. m = importlib.import_module(
  1213. f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}"
  1214. )
  1215. if not m or not hasattr(m, unique_name):
  1216. log.warning(
  1217. "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.",
  1218. unique_name,
  1219. )
  1220. pat = getattr(m, unique_name)
  1221. for arg in pytree.tree_iter(example_inputs):
  1222. if torch._subclasses.fake_tensor.is_fake(arg) and arg.constant is not None:
  1223. # This can be a problem - small fake tensors (e.g. `tensor(2)`) will
  1224. # hold onto their original constant value - and by stashing it here
  1225. # will cause a memory leak if the constant value is on GPU.
  1226. # Since this is just an optimization we can clear it out.
  1227. arg.constant = None
  1228. if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
  1229. return
  1230. _known_precompiled_patterns.append(
  1231. (search_fn, example_inputs, trace_fn, scalar_workaround, pat)
  1232. )
  1233. register_replacement(
  1234. search_fn,
  1235. replace_fn,
  1236. example_inputs,
  1237. trace_fn,
  1238. pass_dicts,
  1239. extra_check,
  1240. scalar_workaround,
  1241. exclusive_arg_names,
  1242. search_fn_pattern=pat,
  1243. )
  1244. @functorch_config.patch(functionalize_rng_ops=False)
  1245. def gen_pattern(
  1246. search_fn: SearchFn,
  1247. example_inputs: Sequence[Any],
  1248. trace_fn: TraceFn,
  1249. scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
  1250. exclusive_arg_names: Sequence[str] = (),
  1251. ) -> PatternExpr:
  1252. argnames = [*inspect.signature(search_fn).parameters.keys()]
  1253. if scalar_workaround is None:
  1254. scalar_workaround = {}
  1255. flat_inputs = []
  1256. input_idx = 0 # Positional arguments index
  1257. for argname in argnames:
  1258. if argname in scalar_workaround:
  1259. flat_inputs.append(scalar_workaround[argname])
  1260. else:
  1261. flat_inputs.append(example_inputs[input_idx])
  1262. input_idx += 1
  1263. search_gm = trace_fn(search_fn, flat_inputs)
  1264. return fx_to_pattern(
  1265. search_gm,
  1266. ignore_types=(int, float, list, torch.device, torch.dtype),
  1267. argnames=argnames,
  1268. scalar_workaround=scalar_workaround,
  1269. exclusive_arg_names=exclusive_arg_names,
  1270. )
  1271. def register_lowering_pattern(
  1272. pattern: PatternExpr,
  1273. extra_check: Callable[[Match], bool] = _return_true,
  1274. *,
  1275. pass_dict: _PassDictsType,
  1276. prepend: bool = False,
  1277. ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  1278. """
  1279. Register an aten to inductor IR replacement pattern. The decorated
  1280. function is saved and then called a lowering time allowing direct
  1281. pattern to inductor IR conversion.
  1282. """
  1283. def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
  1284. assert callable(handler)
  1285. LoweringPatternEntry(
  1286. pattern=pattern, extra_check=extra_check, handler=handler
  1287. ).register(pass_dict, prepend=prepend)
  1288. handler._inductor_lowering_function = True # type: ignore[attr-defined]
  1289. return handler
  1290. return decorator
  1291. def register_graph_pattern(
  1292. pattern: PatternExpr,
  1293. extra_check: Callable[[Match], bool] = _return_true,
  1294. *,
  1295. pass_dict: _PassDictsType,
  1296. prepend: bool = False,
  1297. ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  1298. """
  1299. Register a pattern that runs a function on the FX graph, allowing
  1300. custom transformation code.
  1301. """
  1302. def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
  1303. assert callable(handler)
  1304. GraphPatternEntry(
  1305. pattern=pattern, extra_check=extra_check, handler=handler
  1306. ).register(pass_dict, prepend=prepend)
  1307. return handler
  1308. return decorator
  1309. def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
  1310. # first node in the graph
  1311. return node is next(iter(graph.nodes))
# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc
# Alternatives, in order:
#   `_$`   - name ends in an underscore (in-place op such as relu_)
#   `_[.]` - an underscore directly followed by a dot (e.g. an overload such as "copy_.default")
#   `(\b|_)(set|enter|exit|seed)(\b|_)` - the word set/enter/exit/seed delimited by a
#            word boundary or an underscore (needed because `_` is itself a word char)
_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)")
  1314. def is_mutation_op(node: torch.fx.Node) -> bool:
  1315. if node.op == "call_function":
  1316. if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
  1317. return True
  1318. elif node.op == "call_method":
  1319. if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type]
  1320. return True
  1321. return node.kwargs.get("out") is not None
  1322. def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
  1323. n = node
  1324. while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
  1325. n = n.prev
  1326. mutation_region_id = n.meta.get("mutation_region_id", 0)
  1327. while n is not node:
  1328. n = n.next
  1329. if is_mutation_op(n):
  1330. mutation_region_id += 1
  1331. n.meta["mutation_region_id"] = mutation_region_id
  1332. return mutation_region_id
  1333. def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
  1334. return "mutation_region_id" not in next(iter(graph.nodes)).meta
  1335. def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
  1336. mutation_region_id = 0
  1337. for nd in graph.nodes:
  1338. if is_mutation_op(nd):
  1339. mutation_region_id += 1
  1340. nd.meta["mutation_region_id"] = mutation_region_id
class PatternMatcherPass:
    """
    A set of PatternEntry objects, keyed by ``(op, target)``, that can be applied
    to an FX graph in one pass via ``apply``.
    """

    def __init__(
        self,
        prevent_match_across_mutations: bool = False,
        pass_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            prevent_match_across_mutations: when True, matches whose nodes lie in
                more than one mutation region are discarded.
            pass_name: optional name for this pass (stored; not read by ``apply``).
        """
        super().__init__()
        self.patterns: DefaultDict[
            Tuple[str, torch.fx.node.Target], List[PatternEntry]
        ] = defaultdict(list)
        self.prevent_match_across_mutations = prevent_match_across_mutations
        self.pass_name = pass_name

    def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        # defaultdict: an unseen key yields (and stores) a new empty list.
        return self.patterns[item]

    def apply(self, graph: torch.fx.GraphModule) -> int:
        """
        Try every registered pattern on the matching nodes of ``graph``.

        The candidates are every node whose ``(op, target)`` has registered
        entries, plus every call_module node if any call_module patterns exist.
        They are visited in reverse of Node's sort order (presumably reverse graph
        order — confirm). For each node, the entries are tried in turn, stopping
        once the node has been erased. The first entry that matches and passes
        ``extra_check`` is applied.

        Args:
            graph: a GraphModule (unwrapped to its ``.graph``) or a bare Graph.

        Returns:
            Number of pattern applications performed.
        """
        if not self.patterns:
            return 0
        if isinstance(graph, torch.fx.GraphModule):
            graph = graph.graph
        if self.prevent_match_across_mutations:
            if should_compute_mutation_region_ids(graph):
                compute_mutation_region_ids(graph)
            get_mutation_region_id_partial = functools.partial(
                get_mutation_region_id, graph
            )
        count = 0
        nodes = []
        has_call_module = False
        for op, target in self.patterns:
            if op == "call_module":
                has_call_module = True
            else:
                nodes.append(graph.find_nodes(op=op, target=target, sort=False))
        if has_call_module:
            nodes.append(graph.find_nodes(op="call_module", sort=False))
        for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
            target = extract_target(node)
            # call_module nodes were collected wholesale; skip those without entries.
            if node.op == "call_module":
                if (node.op, target) not in self.patterns:
                    continue
            # conservatively not applying pattern for cpu input,
            # since some of the patterns induce codegen and split nodes.
            # Note: we will only skip cpu compute if disable_cpp_codegen=True
            if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False):
                continue
            for entry in self.patterns[(node.op, target)]:
                # An earlier entry may have rewritten (and erased) this node.
                if node._erased:
                    break
                m = entry.pattern.match(node)
                # pattern match crosses mutation barrier - discard
                if (
                    self.prevent_match_across_mutations
                    and is_match(m)
                    and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
                ):
                    continue
                # Set TORCHINDUCTOR_PATTERN_MATCH_DEBUG=<node name> to log every
                # match attempt on that node.
                if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                    log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                if is_match(m) and entry.extra_check(m):
                    count += 1
                    entry.apply(m, graph, node) # type: ignore[arg-type]
                    counters["inductor"]["pattern_matcher_count"] += 1
                    counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes)
        return count

    def clear(self) -> None:
        """Remove all registered patterns."""
        self.patterns.clear()
  1407. def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
  1408. raise NotImplementedError
def fx_to_pattern(
    gm: Union[torch.fx.GraphModule, torch.fx.Graph],
    ignore_types: Sequence[Type[Any]] = (),
    argnames: Sequence[str] = (),
    scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
    exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
    """
    Convert an FX graph into a PatternExpr. This is useful for simple
    patterns that can only match single functions and fixed-length lists.

    Args:
        gm: graph (or module) to convert; only placeholder, call_function and
            output nodes are supported.
        ignore_types: argument values whose exact type is listed here become
            ``Ignored()``.
        argnames: names given, in order, to the graph's placeholders. Any extra
            placeholders must have targets starting with "tangent". When empty,
            the placeholder targets are used with a trailing ``_<digits>``
            suffix stripped.
        scalar_workaround: argument name -> scalar value; an int/float argument
            equal to one of these values becomes ``KeywordArg(name)``. Values
            must be distinct.
        exclusive_arg_names: placeholder names that become ``ExclusiveKeywordArg``
            instead of ``KeywordArg``.

    Returns:
        The converted pattern, or a ``MultiOutputPattern`` over the output leaves
        when the graph's output is not a single PatternExpr.
    """
    # scalar_workaround is a hack to capture dropout_p
    # see https://github.com/pytorch/pytorch/issues/97894
    scalar_workaround = scalar_workaround or {}
    inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()}
    # Inverting must not lose entries, i.e. the scalar values must be unique.
    assert len(inv_scalar_workaround) == len(scalar_workaround)

    def process_arg(x: T) -> Union[T, KeywordArg, Ignored]:
        """Map one call argument to its pattern form (KeywordArg, Ignored, or itself)."""
        if isinstance(x, (float, int)) and x in inv_scalar_workaround:
            return KeywordArg(inv_scalar_workaround[x])
        if type(x) in ignore_types:
            return Ignored()
        # A non-empty list whose items are all Ignored collapses to Ignored.
        if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x:
            return Ignored()
        return x

    # Position of the next placeholder, used to index into argnames.
    argnum = itertools.count()

    class Converter(torch.fx.Interpreter):
        call_method = _not_implemented
        call_module = _not_implemented
        get_attr = _not_implemented

        def placeholder(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
        ) -> Union[ExclusiveKeywordArg, KeywordArg]:
            n = next(argnum)
            if n < len(argnames):
                name = argnames[n]
            elif argnames:
                # Placeholders beyond argnames must be tangents.
                assert target.startswith("tangent")
                name = target
            else:
                target = re.sub(r"_\d+$", "", target) # de-mangle arg name
                name = target
            if name in exclusive_arg_names:
                return ExclusiveKeywordArg(name)
            else:
                return KeywordArg(name)

        def call_function(
            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
        ) -> PatternExpr:
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
                # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...]
                args = [process_arg(a) for a in args]
                kwargs = {k: process_arg(a) for k, a in kwargs.items()}
            return CallFunction(target, *args, **kwargs)

        def run_node(self, n: torch.fx.Node) -> Any:
            rv = super().run_node(n)
            # Record each produced pattern's user count from the FX graph; for a
            # tuple output, each element takes the user count of its source node.
            if n.op == "output" and isinstance(rv, tuple):
                assert len(rv) == len(n.args[0]) # type: ignore[arg-type]
                for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type]
                    r.users = len(arg.users)
            else:
                rv.users = len(n.users)
            return rv

    pattern = Converter(gm).run()
    if not isinstance(pattern, PatternExpr):
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
  1476. @torch.no_grad()
  1477. def fwd_only(
  1478. fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True
  1479. ) -> torch.fx.GraphModule:
  1480. """Build a normalized inference graph, for use with fx_to_pattern"""
  1481. # TODO - look into using aot autograd, asserting no mutating ops here
  1482. with enable_python_dispatcher():
  1483. gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
  1484. from .fx_passes.post_grad import remove_noop_ops
  1485. remove_noop_ops(gm.graph)
  1486. if run_dce:
  1487. gm.graph.eliminate_dead_code()
  1488. gm.recompile()
  1489. return gm
  1490. @torch.enable_grad()
  1491. def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule:
  1492. """Build a normalized training graph, for use with fx_to_pattern"""
  1493. gm: Optional[torch.fx.GraphModule] = None
  1494. def record_joint_graph(
  1495. joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any
  1496. ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
  1497. nonlocal gm
  1498. assert not gm
  1499. gm = clone_graph(joint_graph)
  1500. return default_partition(joint_graph, inputs, **kwargs)
  1501. with torch._guards.tracing(None):
  1502. aot_function(
  1503. fn,
  1504. lambda g, i: make_boxed_func(g),
  1505. partition_fn=record_joint_graph,
  1506. decompositions=select_decomp_table(),
  1507. keep_inference_input_mutations=True,
  1508. enable_log=False,
  1509. )(*args)
  1510. assert gm
  1511. from .fx_passes.post_grad import remove_noop_ops
  1512. remove_noop_ops(gm.graph)
  1513. from .fx_passes.joint_graph import pointless_view
  1514. matcher_pass = PatternMatcherPass()
  1515. pattern = CallFunction(
  1516. torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
  1517. )
  1518. GraphPatternEntry(
  1519. pattern=pattern, handler=pointless_view, extra_check=_return_true
  1520. ).register(matcher_pass.patterns)
  1521. matcher_pass.apply(gm.graph) # type: ignore[arg-type]
  1522. # remove in/out specs
  1523. gm.graph._codegen = torch.fx.graph.CodeGen()
  1524. gm.graph.eliminate_dead_code()
  1525. gm.recompile()
  1526. return gm
  1527. def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
  1528. args: List[torch.fx.node.Argument] = list()
  1529. torch.fx.map_arg((n.args, n.kwargs), args.append)
  1530. return args
  1531. def stable_topological_sort(graph: torch.fx.Graph) -> None:
  1532. # Nodes are in exactly one of these three collections:
  1533. # - Nodes in `pending` are waiting to be processed (in reverse order):
  1534. pending = list(reversed(graph.nodes))
  1535. # - Nodes in `ready` have been processed and are already in the correct
  1536. # order.
  1537. ready = set()
  1538. # - `waiting` is a mapping from a dependency to nodes which depend on that
  1539. # dependency.
  1540. waiting = defaultdict(list)
  1541. # The cursor indicates the last processed node so we can add new nodes
  1542. # after it.
  1543. cursor = None
  1544. while pending:
  1545. node = pending.pop()
  1546. waiting_for = [x for x in _args(node) if x not in ready]
  1547. if waiting_for:
  1548. # We have unprocessed input nodes. Might as well wait for the last
  1549. # arg so an already sorted list will only recheck this node once.
  1550. waiting[waiting_for[-1]].append(node)
  1551. else:
  1552. ready.add(node)
  1553. if cursor and cursor.next is not node:
  1554. cursor.append(node)
  1555. cursor = node
  1556. # Mark the nodes that have been waiting for this node to finish as
  1557. # ready to check again.
  1558. pending.extend(reversed(waiting.pop(node, ())))
  1559. assert not waiting and len(ready) == len(graph.nodes)
  1560. def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
  1561. """Wrapper around lazy init functions in fx_passes/"""
  1562. @functools.lru_cache(None)
  1563. @functools.wraps(fn)
  1564. def lazy_init() -> Any:
  1565. counters_ref = counters["inductor"].copy()
  1566. with torch._guards.tracing(
  1567. None
  1568. ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
  1569. result = fn()
  1570. # clear view matches encountered during tracing
  1571. counters["inductor"] = counters_ref
  1572. return result
  1573. return lazy_init
  1574. def config_flag(name: str) -> Callable[[Match], Any]:
  1575. """Function for extra_check to put pass behind a flag"""
  1576. def flag_check(match: Match) -> Any:
  1577. return getattr(config, name)
  1578. return flag_check
  1579. def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
  1580. class CopyGraph(Transformer):
  1581. def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node:
  1582. new_node = super().run_node(old_node)
  1583. if isinstance(new_node, torch.fx.Proxy):
  1584. new_node.node.meta.update(old_node.meta)
  1585. new_node.node.name = self.new_graph._graph_namespace.create_name(
  1586. old_node.name, None
  1587. )
  1588. return new_node
  1589. return CopyGraph(input_graph).transform()
  1590. _seen_patterns: Set[str] = set()
  1591. def get_arg_value(
  1592. node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None
  1593. ) -> Any:
  1594. return (
  1595. node.args[arg_number]
  1596. if len(node.args) > arg_number
  1597. else node.kwargs.get(kwarg_name) # type: ignore[arg-type]
  1598. )
  1599. def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
  1600. fns = [fn]
  1601. if isinstance(fn, torch._ops.OpOverloadPacket):
  1602. fns.extend([getattr(fn, overload) for overload in fn.overloads()])
  1603. return [node for node in nodes if node.target in fns]
  1604. def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
  1605. """For call_function and call_method, we directly use the target function;
  1606. For call_module, the target is string, and we treat the module class
  1607. as a function.
  1608. """
  1609. if node.op == "call_module":
  1610. return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type]
  1611. return node.target