output_graph.py

  1. # mypy: allow-untyped-defs
  2. import collections
  3. import contextlib
  4. import copy
  5. import functools
  6. import itertools
  7. import logging
  8. import operator
  9. import re
  10. import sys
  11. import traceback
  12. import weakref
  13. from dataclasses import dataclass
  14. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
  15. import sympy
  16. import torch._guards
  17. import torch._logging
  18. import torch.nn
  19. import torch.utils._pytree as pytree
  20. from torch import fx
  21. from torch._guards import GlobalContextCheckpointState, Source, TracingContext
  22. from torch._utils_internal import signpost_event
  23. from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
  24. from torch.fx.experimental._backward_state import BackwardState
  25. from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
  26. from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
  27. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  28. from . import config, logging as torchdynamo_logging, variables
  29. from .backends.registry import CompiledFn, CompilerFn
  30. from .bytecode_transformation import (
  31. create_call_function,
  32. create_instruction,
  33. Instruction,
  34. unique_id,
  35. )
  36. from .code_context import code_context
  37. from .codegen import PyCodegen
  38. from .current_scope_id import enter_new_scope
  39. from .exc import (
  40. BackendCompilerFailed,
  41. exceptions_allowed_to_be_fallback,
  42. SkipFrame,
  43. unimplemented,
  44. unimplemented_with_warning,
  45. )
  46. from .guards import GuardBuilder, install_guard
  47. from .mutation_guard import is_dynamic_nn_module
  48. from .side_effects import AttributeMutationExisting, SideEffects
  49. from .source import (
  50. AttrSource,
  51. BackwardStateSource,
  52. ConstantSource,
  53. GetItemSource,
  54. GlobalStateSource,
  55. is_constant_source,
  56. is_from_local_source,
  57. LocalSource,
  58. ParamBufferSource,
  59. ShapeEnvSource,
  60. SyntheticLocalSource,
  61. TensorProperty,
  62. TensorPropertySource,
  63. )
  64. from .utils import (
  65. checkpoint_params,
  66. CleanupHook,
  67. clone_inputs,
  68. count_calls,
  69. counters,
  70. dynamo_timed,
  71. get_instruction_source_311,
  72. get_locals_to_steal,
  73. get_static_address_type,
  74. graph_break_reasons,
  75. increment_op_count,
  76. lazy_format_graph_code,
  77. LazyString,
  78. nn_module_proxy,
  79. same,
  80. set_example_value,
  81. )
  82. from .variables.base import VariableTracker
  83. from .variables.builder import (
  84. BackwardStateGraphArg,
  85. GraphArg,
  86. TrackedFake,
  87. VariableBuilder,
  88. wrap_fx_proxy,
  89. )
  90. from .variables.lists import BaseListVariable
  91. from .variables.misc import NullVariable
  92. from .variables.nn_module import NNModuleVariable
  93. from .variables.tensor import (
  94. NumpyNdarrayVariable,
  95. SymNodeVariable,
  96. TensorVariable,
  97. UnspecializedPythonVariable,
  98. )
  99. from .variables.torch_function import TensorWithTFOverrideVariable
  100. if TYPE_CHECKING:
  101. from torch._dynamo.symbolic_convert import InstructionTranslatorBase
  102. log = logging.getLogger(__name__)
  103. graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
  104. graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
  105. graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
  106. trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
  107. @dataclass(frozen=True)
  108. class VariableTrackerCacheKey:
  109. vt_id: int
  110. # Two different sources can point to the same object. However, Dynamo handles
  111. # global and local sources differently when it comes to guards and possibly
  112. # some other parts as well. So, the cache also relies on the source.
  113. source: Source
  114. class VariableTrackerCache:
  115. def __init__(self):
  116. self.cache = {}
  117. def lookup(self, value, source):
  118. key = VariableTrackerCacheKey(id(value), source)
  119. if key not in self.cache:
  120. return None
  121. return self.cache[key]
  122. def add(self, value, source, vt):
  123. key = VariableTrackerCacheKey(id(value), source)
  124. self.cache[key] = vt
  125. def clone(self):
  126. # Needed for copying and restoring graph state
  127. new_cache = VariableTrackerCache()
  128. new_cache.cache.update(self.cache)
  129. return new_cache
  130. def clear(self):
  131. self.cache.clear()
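# A minimal sketch of how this cache is keyed (illustrative only; the object and
# sources below are hypothetical, not real Dynamo state):
#
#   cache = VariableTrackerCache()
#   cache.lookup(obj, LocalSource("x"))        # -> None on first access
#   cache.add(obj, LocalSource("x"), vt_local)
#   cache.lookup(obj, LocalSource("x"))        # -> vt_local
#   cache.lookup(obj, GlobalSource("x"))       # -> None: same id(obj), but a
#                                              #    different source is a different key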
  132. @functools.lru_cache(None)
  133. def _step_logger():
  134. return torchdynamo_logging.get_step_logger(log)
  135. @dataclass
  136. class GraphCompileReason:
  137. """Stores why a given output graph was compiled; i.e. what caused the graph break."""
  138. reason: str
  139. user_stack: List[traceback.FrameSummary]
  140. # Indicates whether this compile was triggered by a graph break.
  141. graph_break: bool = True
  142. def __post_init__(self):
  143. if self.graph_break:
  144. graph_break_reasons.append(self)
  145. def _get_gen_rand_values_fn(random_calls):
  146. def _gen_rand_values():
  147. return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
  148. return _gen_rand_values
  149. class FakeRootModule(torch.nn.Module):
  150. """Trick the constructor of fx.GraphModule"""
  151. def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
  152. super().__init__()
  153. for k, v in nn_modules.items():
  154. setattr(self, k, v)
  155. def __repr__(self):
  156. return "FakeRootModule(...)"
  157. class WrapperBackend:
  158. def __init__(self, backend: CompilerFn):
  159. self.backend: CompilerFn = backend
  160. def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  161. self.restore = checkpoint_params(gm)
  162. self.gm = gm
  163. copy_gm = copy.deepcopy(self.gm)
  164. self.candidate = self.backend(copy_gm, example_inputs)
  165. if self.candidate is None or self.candidate is self.gm.forward:
  166. return self.gm.forward
  167. if not config.verify_correctness:
  168. return self.candidate
  169. # if verify_correctness=True
  170. try:
  171. correct = self.gm.forward(*clone_inputs(example_inputs))
  172. result = self.candidate(*clone_inputs(example_inputs))
  173. # TODO: replace `same` function with the one in testing
  174. if same(correct, result):
  175. return self.candidate
  176. raise RuntimeError(f"incorrect results of backend {self}")
  177. return self.gm.forward
  178. except Exception:
  179. log.exception("error in verify_correctness")
  180. raise
  181. finally:
  182. self.restore()
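# Note: WrapperBackend is not invoked directly by user code. When
# config.verify_correctness is set, call_user_compiler (further below) wraps the
# configured backend with it, roughly:
#
#   compiler_fn = self.compiler_fn
#   if config.verify_correctness:
#       compiler_fn = WrapperBackend(compiler_fn)
#   compiled_fn = compiler_fn(gm, self.example_inputs())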
  183. Scope = Dict[str, object]
  184. class OutputGraph:
  185. """
  186. Wrapper class to hold outputs of InstructionTranslator. Mainly the
  187. generated fx.Graph.
  188. OutputGraph is 1:1 with a frame being processed. Each frame is associated
  189. with some root InstructionTranslator. When user code calls a function,
  190. we construct an InliningInstructionTranslator that continues to write into
  191. the root InstructionTranslator's OutputGraph.
  192. """
  193. def __init__(
  194. self,
  195. code_options: Dict[str, Any],
  196. compiler_fn: Optional[CompilerFn],
  197. root_tx,
  198. export: bool,
  199. export_constraints,
  200. frame_state,
  201. local_scope: Scope,
  202. global_scope: Scope,
  203. f_code,
  204. ):
  205. super().__init__()
  206. self.tracers = [SubgraphTracer(self, export_root=export)]
  207. # Map from graph input's `Source` to its `VariableTracker` to
  208. # de-duplicate graph inputs by source and reuse the tracker
  209. self.input_source_to_var: Dict[Source, VariableTracker] = {}
  210. self.export = export
  211. self.export_constraints = export_constraints
  212. self.frame_state = frame_state
  213. # Map from graph input's `Source` to sizes / strides metadata
  214. self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {}
  215. self.cleanup_hooks: List[Callable[[], Any]] = []
  216. # compile_id is an id number for the current torch.compile
  217. self.compile_id: int = next(_compile_id_counter)
  218. # Set of globals installed via install_global* APIs
  219. self.installed_globals: Set[str] = set()
  220. # TODO: maybe should just pass the entire f_code in here? Not
  221. # sure...
  222. self.co_fields = {
  223. "co_name": f_code.co_name,
  224. "co_filename": f_code.co_filename,
  225. "co_firstlineno": f_code.co_firstlineno,
  226. }
  227. # tracked_fakes says where any tensor that was wrapped to fake came
  228. # from. It is similar to GraphArg, in that all GraphArgs will get
  229. # added to TrackedFakes, but TrackedFakes also contains
  230. # GraphArgs that got pruned, and things like Tensor attributes which
  231. # aren't explicit graph inputs. Used by shape guards.
  232. self.tracked_fakes: List[TrackedFake] = []
  233. # List of symbols for which we have exact bindings in the arguments
  234. # already
  235. self.bound_symbols: Set[sympy.Symbol] = set()
  236. shape_env = ShapeEnv(
  237. # Reference Cycle!
  238. # Share a reference to the list of TrackedFake.
  239. #
  240. # ShapeEnv needs this in order to be able to reproduce the call
  241. # to produce_guards at an arbitrary time point. That is because
  242. # TrackedFake instances may have their metadata changed throughout
  243. # the program execution.
  244. tracked_fakes=self.tracked_fakes,
  245. allow_scalar_outputs=config.capture_scalar_outputs,
  246. allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
  247. prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
  248. _allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts,
  249. co_fields=self.co_fields,
  250. )
  251. # In export mode, we force the shape_env to strictly disallow any constraining
  252. # of the user marked dynamic dims
  253. import torch._functorch.config as _config
  254. with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
  255. fake_mode = torch._subclasses.FakeTensorMode(
  256. shape_env=shape_env,
  257. # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
  258. allow_non_fake_inputs=True if self.export else False,
  259. export=self.export,
  260. )
  261. self.tracing_context: TracingContext = TracingContext(fake_mode)
  262. self.init_ambient_guards()
  263. # Map each tensor id to a list of sources. This is necessary because
  264. # tensor ids cannot be recovered from tracked fakes (in general).
  265. # We use this map to interpret (i.e., check for violations of) constraints,
  266. # specifically equality constraints, which have shared tensor ids in them.
  267. # This map should also be generally useful, e.g., for (de)serialization.
  268. self.tracked_fakes_id_to_source: Dict[
  269. int, List[Source]
  270. ] = collections.defaultdict(list)
  271. # Stores the full fqn of a param or buffer to the relevant source.
  272. self.param_name_to_source: Optional[Dict[str, Source]] = dict()
  273. self.side_effects = SideEffects()
  274. # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
  275. # and LOAD_ATTR for the same python objects free.
  276. self.variable_tracker_cache = VariableTrackerCache()
  277. self.unique_var_id = itertools.count()
  278. self.code_options = dict(code_options)
  279. self.output_instructions: List[Instruction] = []
  280. # used to track nodes that are added between calls of copy_graphstate
  281. # and restore_graphstate
  282. self.timestamp = 0
  283. # A list of register_finalizer_fns to apply to the output graph module
  284. self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []
  285. # Not checkpointed
  286. self.compiler_fn: Optional[CompilerFn] = compiler_fn
  287. self.global_scope = global_scope
  288. self.local_scope = local_scope
  289. self.root_tx = root_tx
  290. # Given a source, what are the user stacks of all locations that
  291. # accessed it?
  292. #
  293. # For efficiency, we only populate this:
  294. # - During export, and
  295. # - If the source could potentially lead to a spurious export input
  296. #
  297. # Feel free to populate this more frequently if other use-cases arise,
  298. # but be aware that we have to generate full stacks for each
  299. # recording!
  300. self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}
  301. self._current_tx: List[InstructionTranslatorBase] = []
  302. self.cleanups: List[CleanupHook] = []
  303. self.should_exit = False
  304. self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
  305. self.torch_function_enabled = torch._C._is_torch_function_enabled()
  306. # Tracks if the output graph has a user-defined allowed function in the
  307. # graph. This is used later to determine if we should fall back to eager
  308. # for certain exceptions. The idea is that if the user has applied
  309. # allow_in_graph, they would like to see the error instead of falling
  310. # back for backend errors.
  311. self.has_user_defined_allowed_in_graph = False
  312. # Tracks the set of called ops that were not tagged with "pt2_compliant_tag".
  313. # This information is useful for logging.
  314. self.non_compliant_ops: Set[torch._ops.OpOverload] = set({})
  315. # Tracks the set of called custom ops that were tagged with "pt2_compliant_tag".
  316. # This information is useful for logging.
  317. self.compliant_custom_ops: Set[torch._ops.OpOverload] = set({})
  318. # We save the global torch state here to be restored in case of graph
  319. # breaks. The relevant issue is seen here
  320. # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
  321. # where inlining of a function changes the global state (because of the
  322. # presence of torch.no_grad) and there is a graph break.
  323. self.save_global_state()
  324. # Tracks the original FQNs of the constant tensors from the original graph,
  325. # i.e. buffers and parameters.
  326. self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {}
  327. # All calls to random() are replaced with a single call to the __gen_rand_values
  328. # function, which returns a tuple of random values, one for each original call.
  329. # random_calls tracks calls to random() and random_values_var stores the name of
  330. # the variable that stores __gen_rand_values results.
  331. self.random_calls: List[
  332. Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
  333. ] = []
  334. self.random_values_var = None
  335. # Bytecode to insert right before we call the graph
  336. self.pregraph_bytecode: List[Instruction] = []
  337. # Used to pass values to backward hooks when using compiled autograd
  338. self.backward_state: Dict[str, VariableTracker] = {}
  339. self.backward_state_proxy: Optional[torch.fx.Proxy] = None
  340. self.backward_state_var: Optional[str] = None
  341. self.name_of_builtins_dict_key_in_fglobals: str = (
  342. self.install_builtins_dict_in_fglobals()
  343. )
  344. self.guard_on_key_order: Set[str] = set()
  345. def install_builtins_dict_in_fglobals(self):
  346. # f_globals["__builtins__"] can be a dict or a module. This is an
  347. # implementation detail -
  348. # https://docs.python.org/3/library/builtins.html.
  349. # This makes guarding on any builtin messy because the guard check_fn
  350. # has to check whether __builtins__ is a module or a dict, and then access
  351. # it by either using getattr or getitem, respectively.
  352. # To solve this problem, we insert a new entry in f_globals which points
  353. # to the builtins __dict__ and then we guard any builtin on this dict.
  354. # To avoid any collision with pre-existing keys, we use
  355. # install_global to give us a unique dict key.
  356. f_builtins = self.global_scope["__builtins__"]
  357. if not isinstance(f_builtins, dict):
  358. f_builtins = f_builtins.__dict__
  359. return self.install_global("__builtins_dict__", f_builtins)
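# Illustrative example of the two shapes of f_globals["__builtins__"] that the
# method above normalizes (hypothetical globals dicts):
#
#   import builtins
#   g1 = {"__builtins__": builtins}            # module form
#   g2 = {"__builtins__": builtins.__dict__}   # dict form
#
# Either way, the installed global ends up pointing at builtins' __dict__, so
# guards can always use a plain dict lookup.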
  360. def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
  361. name = f"{prefix}{len(self.backward_state)}"
  362. assert name not in self.backward_state
  363. self.backward_state[name] = hook
  364. return name, self.get_backward_state_proxy()
  365. def get_backward_state_proxy(self):
  366. if self.backward_state_proxy is None:
  367. if self.export:
  368. unimplemented("backward_state does not support export")
  369. self.backward_state_proxy = self.root_tracer.create_graph_input(
  370. "dynamo_backward_state", BackwardState, source=BackwardStateSource()
  371. )
  372. self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
  373. set_example_value(self.backward_state_proxy.node, BackwardState())
  374. self.backward_state_var = self.new_var()
  375. return self.backward_state_proxy
  376. # This gets its own helper function so guards DEBUG logs are more informative
  377. def init_ambient_guards(self):
  378. # Register a SHAPE_ENV guard to make sure we setup shape guards
  379. # that show up in ShapeEnv
  380. self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
  381. self.guards.add(
  382. GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
  383. )
  384. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
  385. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
  386. self.guards.add(
  387. GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
  388. )
  389. ci = torch._C._functorch.peek_interpreter_stack()
  390. if ci is not None:
  391. self.guards.add(
  392. GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
  393. )
  394. def synthetic_graph_input(self, fn, args):
  395. """
  396. call fn(*args) before the graph runs and turn the result into a fake input.
  397. """
  398. example_value = fn(*args)
  399. varname = self.new_var()
  400. cg = PyCodegen(self.root_tx)
  401. cg.load_import_from(
  402. fn.__module__,
  403. fn.__name__,
  404. )
  405. cg.foreach(map(variables.ConstantVariable.create, args))
  406. cg.call_function(len(args), True)
  407. cg.store(varname)
  408. self.pregraph_bytecode.extend(cg.get_instructions())
  409. source = SyntheticLocalSource(varname)
  410. result = VariableBuilder(self.root_tx, source)(example_value)
  411. TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
  412. source
  413. )
  414. return result
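# The pregraph bytecode generated above is roughly equivalent to the following
# Python (tmp_N stands for the synthetic local returned by new_var(); names are
# hypothetical):
#
#   from fn_module import fn     # cg.load_import_from(fn.__module__, fn.__name__)
#   tmp_N = fn(*args)            # executed right before the compiled graph runs
#
# The resulting value is then re-wrapped through VariableBuilder under a
# SyntheticLocalSource, and any guards keyed on that source are dropped.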
  415. def add_cleanup_hook(self, fn: Callable[[], Any]):
  416. self.cleanup_hooks.append(fn)
  417. def call_cleanup_hooks(self):
  418. for hook in reversed(self.cleanup_hooks):
  419. hook()
  420. self.cleanup_hooks.clear()
  421. @property
  422. def root_tracer(self):
  423. return self.tracers[0]
  424. @property
  425. def current_tracer(self):
  426. return self.tracers[-1]
  427. def is_root_tracer(self):
  428. # Helper to tell if we are inside higher-order operator tracing.
  429. return len(self.tracers) == 1
  430. @property
  431. def graph(self):
  432. return self.current_tracer.graph
  433. # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
  434. @graph.setter
  435. def graph(self, value):
  436. self.current_tracer.graph = value
  437. @property
  438. def input_name_to_proxy(self):
  439. return self.current_tracer.input_name_to_proxy
  440. @property
  441. def real_value_cache(self):
  442. return self.current_tracer.real_value_cache
  443. # If you are here, and you're looking for create_graph_input,
  444. # to avoid ambiguity, please call one of the following:
  445. # - self.current_tracer.create_graph_input
  446. # - self.root_tracer.create_graph_input
  447. # See NOTE [HigherOrderOperator tracing design] for more context.
  448. def create_proxy(self, *args, **kwargs):
  449. return self.current_tracer.create_proxy(*args, **kwargs)
  450. def create_node(self, *args, **kwargs):
  451. return self.current_tracer.create_node(*args, **kwargs)
  452. def remove_node(self, *args, **kwargs):
  453. return self.current_tracer.remove_node(*args, **kwargs)
  454. @contextlib.contextmanager
  455. def subtracer(self, source_target, prior_tracer):
  456. new_scope_ctx = enter_new_scope()
  457. try:
  458. if prior_tracer:
  459. # Lineage MUST stay preserved
  460. assert prior_tracer.parent is self.current_tracer
  461. new_scope_ctx.__enter__()
  462. tracer = (
  463. prior_tracer
  464. if prior_tracer
  465. else SubgraphTracer(
  466. self, parent=self.current_tracer, source_target=source_target
  467. )
  468. )
  469. self.tracers.append(tracer)
  470. yield tracer
  471. finally:
  472. new_scope_ctx.__exit__(None, None, None)
  473. self.tracers.pop()
  474. @property
  475. def output(self):
  476. return self
  477. @property
  478. def fake_mode(self):
  479. return self.tracing_context.fake_mode
  480. @property
  481. def shape_env(self):
  482. return self.tracing_context.fake_mode.shape_env
  483. @property
  484. def guards(self) -> torch._guards.GuardsSet:
  485. return self.tracing_context.guards_context.dynamo_guards
  486. @property
  487. def nn_modules(self) -> Dict[str, Any]:
  488. return self.tracing_context.module_context.nn_modules
  489. def save_global_state(self, out=None):
  490. """
  491. Saves to out if it is provided. Else saves to the tracing context's global_state.
  492. """
  493. global_state = (
  494. out if out is not None else self.tracing_context.global_context.global_state
  495. )
  496. # TODO - Consider having a torch level API for torch_function_state. As
  497. # of now, we create a ref cycle by passing the
  498. # output.set_torch_function_state to
  499. # output.tracing_context.global_context.global_state. In the interim,
  500. # the problem can be solved by manually setting
  501. # output.tracing_context.global_context.global_state to None at cleanup.
  502. global_state["torch_function_enabled"] = (
  503. self.set_torch_function_state,
  504. self.torch_function_enabled,
  505. )
  506. global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
  507. global_state["autocast_enabled"] = (
  508. functools.partial(torch.set_autocast_enabled, "cuda"),
  509. torch.is_autocast_enabled("cuda"),
  510. )
  511. global_state["autocast_cpu_enabled"] = (
  512. functools.partial(torch.set_autocast_enabled, "cpu"),
  513. torch.is_autocast_enabled("cpu"),
  514. )
  515. global_state["autocast_gpu_dtype"] = (
  516. functools.partial(torch.set_autocast_dtype, "cuda"),
  517. torch.get_autocast_dtype("cuda"),
  518. )
  519. global_state["autocast_cpu_dtype"] = (
  520. functools.partial(torch.set_autocast_dtype, "cpu"),
  521. torch.get_autocast_dtype("cpu"),
  522. )
  523. global_state["autocast_cache_enabled"] = (
  524. torch.set_autocast_cache_enabled,
  525. torch.is_autocast_cache_enabled(),
  526. )
  527. def push_tx(self, tx):
  528. self._current_tx.append(tx)
  529. def pop_tx(self):
  530. return self._current_tx.pop()
  531. @property
  532. def current_tx(self):
  533. return self.root_tx if not self._current_tx else self._current_tx[-1]
  534. def add_symbol_bindings(self, arg: GraphArg):
  535. # Insert implicit size vars as necessary. With dynamic shapes, we
  536. # maintain the invariant that every sizevar gets a direct SymInt input
  537. # into the graph. This means downstream graph transforms can assume
  538. # every size variable is explicitly bound and accessible, instead of
  539. # having to pull it out implicitly from tensors.
  540. if self.export:
  541. return
  542. assert arg.fake_tensor is not None
  543. def bind_symint(s, prop):
  544. if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
  545. return
  546. s0 = s.node.expr
  547. if s0 in self.bound_symbols:
  548. return
  549. self.bound_symbols.add(s0)
  550. log.debug("bind_symint %s %s", s, prop.name())
  551. # TODO: don't readd symint if we already have it in graph
  552. # (this is harmless because we do remove the unused ones later)
  553. proxy = self.root_tracer.create_graph_input(
  554. str(s0),
  555. torch.SymInt,
  556. before=True,
  557. source=prop,
  558. )
  559. set_example_value(proxy.node, s)
  560. proxy.node.meta["grapharg"] = GraphArg(
  561. prop,
  562. s,
  563. pass_arg_as_tensor=False,
  564. fake_tensor=None,
  565. is_tensor=False,
  566. )
  567. def handle_tensor(t, src):
  568. for i, s in enumerate(t.size()):
  569. bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
  570. if t.layout is torch.strided:
  571. for i, s in enumerate(t.stride()):
  572. bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
  573. bind_symint(
  574. t.storage_offset(),
  575. TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
  576. )
  577. elif t.layout is torch.sparse_coo:
  578. handle_tensor(t._indices(), src)
  579. handle_tensor(t._values(), src)
  580. elif t.layout in {torch.sparse_csr, torch.sparse_bsr}:
  581. handle_tensor(t.crow_indices(), src)
  582. handle_tensor(t.col_indices(), src)
  583. elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
  584. handle_tensor(t.ccol_indices(), src)
  585. handle_tensor(t.row_indices(), src)
  586. if is_traceable_wrapper_subclass(t):
  587. attrs, ctx = t.__tensor_flatten__()
  588. for attr in attrs:
  589. inner_t = getattr(t, attr)
  590. handle_tensor(inner_t, AttrSource(src, attr))
  591. handle_tensor(arg.fake_tensor, arg.source)
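# Sketch of the effect of add_symbol_bindings (hypothetical shapes and names):
# for a graph arg `x` whose fake tensor has size (s0, 3) and stride (3, 1),
# bind_symint adds a SymInt placeholder for s0 ahead of the tensor input, so the
# traced graph looks roughly like
#
#   def forward(self, s0: torch.SymInt, L_x_: torch.Tensor):
#       ...
#
# The constant size 3 and strides (3, 1) are not symbolic, so no extra inputs
# are added for them.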
  592. def count_calls(self):
  593. return count_calls(self.graph)
  594. def is_empty_graph(self):
  595. return len(list(self.graph.nodes)) == 0
  596. def get_submodule(self, keys):
  597. assert keys
  598. obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
  599. for k in keys.split("."):
  600. if isinstance(obj, dict):
  601. obj = obj[k]
  602. else:
  603. obj = getattr(obj, k)
  604. return obj
  605. def new_var(self, name="tmp"):
  606. existing = set(self.code_options["co_varnames"])
  607. # In the common case, this will be O(1)
  608. while True:
  609. var = f"{name}_{next(self.unique_var_id)}"
  610. if var not in existing:
  611. self.code_options["co_varnames"] += (var,)
  612. return var
  613. def update_co_names(self, name):
  614. """Ensure self.code_options.co_names contains name"""
  615. if name not in self.code_options["co_names"]:
  616. self.code_options["co_names"] += (name,)
  617. @staticmethod
  618. def module_key_name(*names):
  619. # create a new unique name
  620. name = "_".join(map(str, names))
  621. # Strip the guard lookup L/G access
  622. name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
  623. # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
  624. name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
  625. # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
  626. name = re.sub(r"[^a-zA-Z0-9]", "_", name)
  627. if not name or not name[0].isalpha():
  628. name = "sub" + name
  629. return name
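# Worked examples of the normalization above (mirroring the inline comments):
#
#   module_key_name("abc.xyz[123].qkv")
#     -> "abc.xyz_123.qkv"   # [123] becomes _123
#     -> "abc_xyz_123_qkv"   # remaining non-alphanumerics become _
#
#   module_key_name("L['mod']")  ->  "mod"   # guard-lookup L[...] prefix stripped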
  630. def register_attr_or_module(
  631. self,
  632. target: Union[torch.nn.Module, torch.Tensor, Any],
  633. *names,
  634. **options,
  635. ):
  636. if is_dynamic_nn_module(target, self.root_tx.export):
  637. return variables.UnspecializedNNModuleVariable(target, **options)
  638. options = dict(options)
  639. assert "source" in options
  640. source = options["source"]
  641. assert not isinstance(source, ParamBufferSource)
  642. if isinstance(target, torch.Tensor):
  643. tracer = self.current_tracer
  644. if not self.is_root_tracer():
  645. # For higher order ops, we don't want to insert the get_attr in
  646. # the innermost graph. Instead, we want to raise the params/buffers
  647. # as inputs to the higher-order graph, and register them as
  648. # get_attrs in the root tracer.
  649. # Note that Dynamo will still call lift_tracked_freevar_to_input
  650. # when these inputs are encountered for the inner graph. The
  651. # only difference is what happens at the root tracer for
  652. # nn.Parameters vs free inputs. The free inputs are registered
  653. # as placeholders in the root graph, whereas the nn.Parameters
  654. # are registered as get_attr nodes in the root graph.
  655. tracer = self.root_tracer
  656. def wrap_name(module_key):
  657. assert self.param_name_to_source is not None
  658. self.param_name_to_source[module_key] = source
  659. # Check if the attr has already been registered. This can happen
  660. # when two different sources point to the same tensor.
  661. if target in self.root_tx.output.side_effects:
  662. return self.root_tx.output.side_effects[target]
  663. if get_static_address_type(target) == "guarded":
  664. install_guard(source.make_guard(GuardBuilder.ID_MATCH))
  665. elif not is_constant_source(source):
  666. install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
  667. vt = wrap_fx_proxy(
  668. self.root_tx,
  669. tracer.create_proxy("get_attr", module_key, tuple(), {}),
  670. example_value=target,
  671. **options,
  672. )
  673. # Track the object so as to avoid duplicate registration in case of
  674. # different sources pointing to the same tensor object.
  675. vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
  676. return vt
  677. elif isinstance(target, torch.nn.Module):
  678. assert isinstance(target, torch.nn.Module)
  679. if source:
  680. install_guard(source.make_guard(GuardBuilder.NN_MODULE))
  681. def wrap_name(module_key):
  682. return NNModuleVariable(type(target), module_key, target, **options)
  683. else:
  684. # This is a Dynamo-created graph module, e.g. a graph module coming
  685. # from higher-order ops. The NNModuleVariable tracker can't be
  686. # sourceless, so let's return an UnspecializedNNModuleVariable
  687. # tracker instead.
  688. def wrap_name(module_key):
  689. return variables.UnspecializedNNModuleVariable(target, **options)
  690. elif isinstance(target, (torch.SymInt, torch.SymFloat)):
  691. # HACKY CODE REGION BEGIN
  692. # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
  693. # This ultimately gets written to self.nn_modules, which is unfortunate.
  694. # Attrs that are tensors, symints, and such need to be migrated to have their
  695. # own storage.
  696. # Alas, this is like this for now.
  697. def wrap_name(module_key):
  698. return SymNodeVariable.create(
  699. self,
  700. self.create_proxy("get_attr", module_key, tuple(), {}),
  701. sym_num=target,
  702. **options,
  703. )
  704. # HACKY CODE REGION END
  705. else:
  706. def wrap_name(module_key):
  707. self.output.update_co_names(module_key)
  708. self.global_scope[module_key] = target
  709. return VariableBuilder(self, ConstantSource(source_name=module_key))(
  710. target
  711. )
  712. for k, v in self.nn_modules.items():
  713. if v is target:
  714. # it already exists
  715. return wrap_name(k)
  716. name = OutputGraph.module_key_name(*names)
  717. base = name
  718. for i in itertools.count():
  719. if name not in self.nn_modules:
  720. self.nn_modules[name] = target
  721. if isinstance(target, torch.nn.Module):
  722. def register_leaf_name(leaf_name):
  723. assert self.param_name_to_source is not None
  724. new_source = ParamBufferSource(source, leaf_name)
  725. new_name = f"{name}.{leaf_name}"
  726. self.param_name_to_source[new_name] = new_source
  727. if isinstance(source, LocalSource):
  728. self.dynamo_flat_name_to_original_fqn[
  729. OutputGraph.module_key_name(new_source.name())
  730. ] = leaf_name
  731. # annoying, but there are cases when we do not have parameters
  732. # see test_nn_moduledict_contains
  733. if hasattr(target, "_parameters"):
  734. for leaf_name, _ in target.named_parameters():
  735. register_leaf_name(leaf_name)
  736. if hasattr(target, "_buffers"):
  737. for leaf_name, _ in target.named_buffers():
  738. register_leaf_name(leaf_name)
  739. return wrap_name(name)
  740. name = f"{base}_{i}"
  741. raise AssertionError("unreachable")
  742. def handle_aliases_for_stolen_lists(self, tx):
  743. # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
  744. maybe_gm = self.local_scope.get("self")
  745. stolen_list_names = get_locals_to_steal(maybe_gm)
  746. if not stolen_list_names:
  747. return []
  748. alias_insts = []
  749. needs_alias: Dict[
  750. str, List[Union[VariableTracker, AttributeMutationExisting]]
  751. ] = {}
  752. queue = [
  753. *tx.stack,
  754. *tx.symbolic_locals.values(),
  755. *self.side_effects.store_attr_mutations.keys(),
  756. ]
  757. while queue:
  758. x = queue.pop()
  759. if isinstance(x, BaseListVariable):
  760. assert isinstance(x.items, List)
  761. queue += x.items
  762. continue
  763. if not (
  764. isinstance(x, (VariableTracker, AttributeMutationExisting))
  765. and isinstance(x.source, GetItemSource)
  766. and isinstance(x.source.base, LocalSource)
  767. and x.source.base.local_name in stolen_list_names
  768. ):
  769. continue
  770. stolen_name = x.source.base.local_name
  771. if stolen_name not in needs_alias:
  772. needs_alias[stolen_name] = []
  773. needs_alias[stolen_name].append(x)
  774. visited = {}
  775. for arg in self.graphargs:
  776. if not (
  777. isinstance(arg._example, list)
  778. and isinstance(arg.source, LocalSource)
  779. and arg.source.local_name in needs_alias
  780. ):
  781. continue
  782. # arg is a list that will be cleared by the compiled function
  783. list_name = arg.source.local_name
  784. assert list_name in self.code_options["co_varnames"]
  785. for x in needs_alias[list_name]:
  786. list_idx = x.source.index
  787. if list_idx not in visited:
  788. alias_name = self.new_var(
  789. f"{list_name}_ref"
  790. ) # self.new_var already adds unique id suffix
  791. visited[list_idx] = alias_name
  792. # bytecode of `alias_name = list_name[list_idx]`
  793. alias_insts.extend(
  794. [
  795. create_instruction("LOAD_FAST", argval=list_name),
  796. create_instruction("LOAD_CONST", argval=list_idx),
  797. create_instruction("BINARY_SUBSCR"),
  798. create_instruction("STORE_FAST", argval=alias_name),
  799. ]
  800. )
  801. # operate on alias, handled by suffix codegen
  802. x.source = LocalSource(visited[list_idx])
  803. return alias_insts
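# In Python terms, the instructions emitted above amount to (hypothetical names,
# for a stolen list local `inputs` whose element 0 is still live afterwards):
#
#   inputs_ref_0 = inputs[0]   # LOAD_FAST inputs; LOAD_CONST 0; BINARY_SUBSCR; STORE_FAST
#
# and the affected VariableTrackers have their source rewritten to the alias, so
# later codegen reads `inputs_ref_0` instead of the (cleared) original list.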
  804. def compile_subgraph(
  805. self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
  806. ):
  807. """
  808. Generate a subgraph to continue execution on user code.
  809. Automatically restore live variables.
  810. """
  811. assert reason is not None
  812. from .decorators import disable
  813. self.partial_convert = partial_convert
  814. self.compile_subgraph_reason = reason
  815. self.should_exit = True
  816. log.debug("COMPILING GRAPH due to %s", reason)
  817. if not all(block.can_restore() for block in tx.block_stack):
  818. unimplemented("compile_subgraph with block_depth != 0")
  819. prefix_insts: List[Instruction] = []
  820. if sys.version_info >= (3, 11):
  821. # prefix instructions (Python 3.11+)
  822. for inst in tx.prefix_insts:
  823. if inst.opname == "MAKE_CELL":
  824. prefix_insts.append(
  825. create_instruction("MAKE_CELL", argval=inst.argval)
  826. )
  827. elif inst.opname == "COPY_FREE_VARS":
  828. prefix_insts.append(
  829. create_instruction(
  830. "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
  831. )
  832. )
  833. else:
  834. prefix_insts.append(copy.copy(inst))
  835. assert not (
  836. self.pregraph_bytecode and self.export
  837. ), "export does not support pregraph_bytecode"
  838. prefix_insts.extend(self.pregraph_bytecode)
  839. prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx))
  840. def append_prefix_insts():
  841. self.add_output_instructions(prefix_insts)
  842. prefix_insts.clear()
  843. for block in reversed(tx.block_stack):
  844. block.exit(tx)
  845. self.cleanup_graph()
  846. tx.prune_dead_locals()
  847. stack_values = list(tx.stack)
  848. # realize any unrealized tensor VTs in case they
  849. # need to be added to self.nn_modules as attributes
  850. for value in stack_values:
  851. value.realize()
  852. # Use nn.Module "proxies" in the constructed GraphModule so that
  853. # the resulting GM does not hold additional strong references to the original modules.
  854. # This prevents a strong ref cycle where Dynamo created code holds on to references
  855. # to modules that also have Dynamo code cache invalidation checks.
  856. # When cache invalidation runs, the generated GM will be invalidated, which also deletes
  857. # the proxies.
  858. nn_modules_proxies = {
  859. name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
  860. }
  861. root = FakeRootModule(nn_modules_proxies)
  862. # Add all the local vars to the "stack" so we can restore them at the end
  863. restore_vars = []
  864. val_to_names: Dict[VariableTracker, List[str]] = {}
  865. if stack_values:
  866. val_to_names[stack_values[-1]] = list()
  867. # NB: Typically (i.e., for graph compile from RETURN_VALUE),
  868. # symbolic_locals will be empty at this point, as prune_dead_locals
  869. # will clear out all of symbolic_locals because RETURN_VALUE is the
  870. # last instruction and no more locals are used. The fanciness here
  871. # is only needed for partial graphs.
  872. for k, v in tx.symbolic_locals.items():
  873. # Note! this explicitly uses .local_name for matching
  874. # Failure to do so will cause spurious registrations in val_to_names.
  875. # This will in turn result in spurious variables showing up in the graph.
  876. # This was very tricky to debug. For an example, dump the graph at call_user_compiler
  877. # while running test_subgraphs.py
  878. if isinstance(v.source, LocalSource) and v.source.local_name == k:
  879. continue # no need to restore initial state
  880. # Do not load variable if it is NULL.
  881. if sys.version_info >= (3, 12):
  882. # Continuation function will load the NULL for v.
  883. if type.__instancecheck__(NullVariable, v):
  884. continue
  885. else:
  886. # A variable should never be NULL in < 3.12
  887. assert not type.__instancecheck__(NullVariable, v)
  888. if v not in val_to_names:
  889. val_to_names[v] = list()
  890. val_to_names[v].append(k)
  891. for v in val_to_names.keys():
  892. restore_vars.extend(val_to_names[v])
  893. stack_values.extend([v] * len(val_to_names[v]))
  894. # to handle random calls
  895. if len(self.random_calls) > 0:
  896. append_prefix_insts()
  897. random_calls_instructions = []
  898. self.random_values_var = self.new_var("random_values")
  899. rand_fn = disable(_get_gen_rand_values_fn(self.random_calls))
  900. rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
  901. codegen = PyCodegen(tx, root)
  902. random_calls_instructions.extend(
  903. codegen.load_function_name(rand_fn_name, True)
  904. )
  905. random_calls_instructions.extend(create_call_function(0, False))
  906. random_calls_instructions.append(
  907. codegen.create_store(tx.output.random_values_var),
  908. )
  909. self.add_output_instructions(random_calls_instructions)
  910. if (
  911. stack_values
  912. and all(
  913. not isinstance(
  914. v,
  915. (
  916. UnspecializedPythonVariable,
  917. NumpyNdarrayVariable,
  918. TensorWithTFOverrideVariable,
  919. ),
  920. )
  921. and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
  922. for v in stack_values
  923. )
  924. and all(isinstance(x, TensorVariable) for x in stack_values)
  925. and len(set(stack_values)) == len(stack_values)
  926. and self.side_effects.is_empty()
  927. and not len(tx.debug_locals) != 0
  928. and not self.backward_state
  929. ):
  930. append_prefix_insts()
  931. # optimization to generate better code in a common case
  932. self.add_output_instructions(
  933. self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  934. + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
  935. )
  936. # restore all the live local vars
  937. self.add_output_instructions(
  938. [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
  939. )
  940. else:
  941. graph_output_var = self.new_var("graph_out")
  942. pass1 = PyCodegen(tx, root, graph_output_var)
  943. self.codegen_suffix(tx, stack_values, pass1)
  944. # one more time now that we have established tempvars
  945. pass2 = PyCodegen(
  946. tx,
  947. root,
  948. graph_output_var,
  949. tempvars={val: None for val, count in pass1.uses.items() if count > 1},
  950. )
  951. self.codegen_suffix(tx, stack_values, pass2)
  952. stored_graph_output_var = False
  953. output = []
  954. if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
  955. output.extend(
  956. self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  957. )
  958. if len(pass2.graph_outputs) != 0:
  959. output.append(pass2.create_store(graph_output_var))
  960. stored_graph_output_var = True
  961. else:
  962. output.append(create_instruction("POP_TOP"))
  963. append_prefix_insts()
  964. self.add_output_instructions(output + pass2.get_instructions())
  965. # restore all the live local vars
  966. self.add_output_instructions(
  967. [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
  968. )
  969. if stored_graph_output_var:
  970. self.add_output_instructions(
  971. [PyCodegen(tx).create_delete(graph_output_var)]
  972. )
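# In the common all-tensor case (the first branch above), the emitted
# continuation bytecode behaves roughly like this Python (hypothetical names):
#
#   tmp = __compiled_fn_0(...)   # instructions from compile_and_call_fx_graph
#   a, b = tmp                   # UNPACK_SEQUENCE over the returned outputs
#   local_y = b                  # stores for restore_vars, in reverse order
#   local_x = a
#
# The general branch instead stores the whole graph output in graph_out_N and
# lets PyCodegen index into it while rebuilding the stack and locals.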
  973. def codegen_suffix(self, tx, stack_values, cg):
  974. if self.backward_state:
  975. assert not self.export
  976. for name, val in self.backward_state.items():
  977. cg(val)
  978. cg.append_output(cg.create_load(self.backward_state_var))
  979. cg.store_attr(name)
  980. self.side_effects.codegen_hooks(cg)
  981. self.side_effects.codegen_save_tempvars(cg)
  982. # Return variables used for logging at the end
  983. for debug_var, args in tx.debug_locals:
  984. cg(debug_var)
  985. for arg in args:
  986. cg(arg)
  987. cg.extend_output(create_call_function(len(args), True))
  988. cg.extend_output([create_instruction("POP_TOP")])
  989. cg.restore_stack(stack_values, value_from_source=not tx.export)
  990. self.side_effects.codegen_update_mutated(cg)
  991. def cleanup_graph(self):
  992. """
  993. Remove "creation_timestamp" from node meta
  994. Remove this pattern from the graph:
  995. torch._C._set_grad_enabled(False)
  996. torch._C._set_grad_enabled(True)
  997. """
  998. assert self.should_exit
  999. nodes = list(self.graph.nodes)
  1000. for node in nodes:
  1001. node.meta.pop("creation_timestamp", None)
  1002. grad_enabled = torch.is_grad_enabled()
  1003. for node1, node2 in zip(nodes, nodes[1:]):
  1004. if (
  1005. node1.target is torch._C._set_grad_enabled
  1006. and tuple(node1.args) == (not grad_enabled,)
  1007. and not node1._erased
  1008. ):
  1009. grad_enabled = node1.args[0]
  1010. if (
  1011. node2.target is torch._C._set_grad_enabled
  1012. and tuple(node2.args) == (not grad_enabled,)
  1013. and not node2._erased
  1014. ):
  1015. grad_enabled = node2.args[0]
  1016. self.graph.erase_node(node1)
  1017. self.graph.erase_node(node2)
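# Example of the pattern removed above: entering and immediately leaving a
# grad-mode context with nothing traced in between produces back-to-back nodes
# like
#
#   call_function  torch._C._set_grad_enabled  (False,)
#   call_function  torch._C._set_grad_enabled  (True,)
#
# which cancel each other out, so both nodes are erased.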
  1018. def get_graph_sizes_structured(self):
  1019. ret = {}
  1020. for node in self.graph.nodes:
  1021. example_value = node.meta.get("example_value", None)
  1022. if isinstance(example_value, torch._subclasses.FakeTensor):
  1023. size = example_value.size()
  1024. ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
  1025. return ret
  1026. def get_graph_sizes(self, name: str):
  1027. graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
  1028. graph_sizes_str += f"===== {name} =====\n"
  1029. for node in self.graph.nodes:
  1030. example_value = node.meta.get("example_value", None)
  1031. if isinstance(example_value, torch._subclasses.FakeTensor):
  1032. size = example_value.size()
  1033. graph_sizes_str += f"{node.name}: {tuple(size)}\n"
  1034. concrete_size = []
  1035. has_symint = False
  1036. for sz in size:
  1037. if isinstance(sz, int):
  1038. concrete_size.append(sz)
  1039. elif isinstance(sz, torch.SymInt):
  1040. has_symint = True
  1041. concrete_size.append(sz.node.hint)
  1042. else:
  1043. break
  1044. else:
  1045. if has_symint:
  1046. graph_sizes_str += (
  1047. f"{node.name} (concrete): {tuple(concrete_size)}\n"
  1048. )
  1049. return graph_sizes_str
  1050. @contextlib.contextmanager
  1051. def restore_global_state(self):
  1052. """
  1053. Momentarily restores the global state to what it was prior to tracing the current output
  1054. """
  1055. prior_global_state = self.tracing_context.global_context.copy_graphstate()
  1056. current_global_state: Dict[str, Tuple[Any, bool]] = {}
  1057. self.save_global_state(out=current_global_state)
  1058. try:
  1059. # Set to state prior to tracing the graph
  1060. self.tracing_context.global_context.restore_graphstate(prior_global_state)
  1061. yield
  1062. finally:
  1063. # Reset to state at the current time (e.g. before calling the user compiler)
  1064. self.tracing_context.global_context.restore_graphstate(
  1065. GlobalContextCheckpointState(current_global_state)
  1066. )
  1067. @torch._guards.TracingContext.clear_frame()
  1068. def compile_and_call_fx_graph(self, tx, rv, root):
  1069. """
  1070. Generate code from self.graph and return the Instruction()s to
  1071. call that generated code.
  1072. """
  1073. from .decorators import disable
  1074. assert self.should_exit
  1075. name = unique_id("__compiled_fn")
  1076. assert isinstance(rv, list)
  1077. assert isinstance(root, FakeRootModule)
  1078. self.create_node(
  1079. "output",
  1080. "output",
  1081. (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
  1082. {},
  1083. )
  1084. if not config.do_not_emit_runtime_asserts:
  1085. insert_deferred_runtime_asserts(
  1086. fx.GraphModule(root, self.graph),
  1087. self.shape_env,
  1088. name,
  1089. )
  1090. # NB: deferred runtime asserts can keep graphargs live, so make sure
  1091. # those are inserted before pruning
  1092. self.remove_unused_graphargs()
  1093. ncalls = count_calls(self.graph)
  1094. counters["stats"]["calls_captured"] += ncalls
  1095. # free a bit of memory
  1096. self.real_value_cache.clear()
  1097. gm = _make_graph_module(root, self.graph)
  1098. for register_finalizer in self.register_finalizer_fns:
  1099. register_finalizer(gm)
  1100. gm.compile_subgraph_reason = self.compile_subgraph_reason
  1101. gm.meta[
  1102. "dynamo_flat_name_to_original_fqn"
  1103. ] = self.dynamo_flat_name_to_original_fqn.copy()
  1104. graph_code_log.debug(
  1105. "%s",
  1106. lazy_format_graph_code(name, gm, include_stride=True, include_device=True),
  1107. )
  1108. torch._logging.trace_structured(
  1109. "dynamo_output_graph",
  1110. lambda: {"sizes": self.get_graph_sizes_structured()},
  1111. payload_fn=lambda: gm.print_readable(
  1112. print_output=False, include_stride=True, include_device=True
  1113. ),
  1114. )
  1115. self.call_cleanup_hooks()
  1116. old_fake_mode = self.tracing_context.fake_mode
  1117. if not self.export:
  1118. import torch._functorch.config as _config
  1119. with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
  1120. # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
  1121. backend_fake_mode = torch._subclasses.FakeTensorMode(
  1122. shape_env=old_fake_mode.shape_env,
  1123. )
  1124. # TODO(voz): Ostensibly, this should be scoped and
  1125. # restored back to old_fake_mode, but doing so currently violates
  1126. # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
  1127. self.tracing_context.fake_mode = backend_fake_mode
  1128. with self.restore_global_state():
  1129. compiled_fn = self.call_user_compiler(gm)
  1130. from torch.fx._lazy_graph_module import _LazyGraphModule
  1131. if isinstance(compiled_fn, _LazyGraphModule) or (
  1132. isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
  1133. and compiled_fn.__name__ == "_lazy_forward"
  1134. ):
  1135. # Since dynamo will run the forward method for the GraphModule shortly
  1136. # anyway, it does not hurt to do the real recompilation here if
  1137. # this is a _LazyGraphModule. This makes it easier for dynamo to
  1138. # optimize a _LazyGraphModule.
  1139. lazy_gm = (
  1140. compiled_fn
  1141. if isinstance(compiled_fn, _LazyGraphModule)
  1142. else compiled_fn.__self__
  1143. )
  1144. _LazyGraphModule.force_recompile(lazy_gm)
  1145. if not isinstance(compiled_fn, _LazyGraphModule):
  1146. # replace compiled_fn with the real forward method
  1147. compiled_fn = lazy_gm.forward
  1148. compiled_fn = disable(compiled_fn)
  1149. counters["stats"]["unique_graphs"] += 1
  1150. # This is safe because we pre-process name to be unique
  1151. self.install_global_unsafe(name, compiled_fn)
  1152. cg = PyCodegen(tx)
  1153. cg.make_call_generated_code(name)
  1154. return cg.get_instructions()

    @property
    def placeholders(self) -> List[fx.Node]:
        return self.graph.find_nodes(op="placeholder")

    @property
    def graphargs(self) -> List[GraphArg]:
        return [node.meta["grapharg"] for node in self.placeholders]

    @dynamo_timed(phase_name="backend_compile")
    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            pl._dynamo_source = arg.source

        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]

        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
            compiled_fn = compiler_fn(gm, self.example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except exceptions_allowed_to_be_fallback as e:
            if self.has_user_defined_allowed_in_graph:
                raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
                    e.__traceback__
                ) from None
            msg = (
                "Backend compiler failed with a fake tensor exception at \n"
                f"{self.root_tx.format_frame_summary()}"
                "Adding a graph break."
            )
            unimplemented_with_warning(e, self.root_tx.f_code, msg)
        except SkipFrame as e:
            # The backend compiler has requested that we skip the frame, instead of
            # aborting execution.
            raise e
        except Exception as e:
            raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
                e.__traceback__
            ) from None

        signpost_event(
            "dynamo",
            "OutputGraph.call_user_compiler",
            {
                **self.co_fields,
                "op_count": tot,
                "node_count": len(gm.graph.nodes),
                "input_count": len(placeholders),
            },
        )

        return compiled_fn
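
    # A minimal sketch of the backend contract that call_user_compiler relies
    # on: self.compiler_fn is any callable taking an fx.GraphModule plus the
    # example inputs and returning a callable.  For instance, a hypothetical
    # pass-through backend (not defined in this file) could be:
    #
    #     def my_backend(gm: torch.fx.GraphModule, example_inputs):
    #         return gm.forward  # just run the captured graph eagerly
    #
    #     torch.compile(fn, backend=my_backend)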

    def example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            result.append(arg.example)
        return result

    def remove_unused_graphargs(self) -> None:
        # NB: It's always OK to drop GraphArg for symbols that ended up being
        # specialized. You don't even have to make a guard for it, because
        # ShapeEnv produce_guards operates on tracked_fakes, which never gets
        # pruned. That being said, you'll get marginally better generated
        # guard code if you promote the guard into a Dynamo guard (since that
        # allows for the guard to be done using C++ guards.) If we get
        # ShapeEnv guards to go into C++ guards, this will stop being a thing
        # though!
        assert self.should_exit

        # Miniature DCE pass, but only for obviously trivial operations
        def is_static_true(b_node: fx.node.Argument):
            if b_node is True:
                return True
            if not isinstance(b_node, fx.Node):
                return False
            b = b_node.meta.get("example_value")
            if b is None:
                return False
            if b is True:
                return True
            if (
                isinstance(b, torch.SymBool)
                and (r := b.node.maybe_as_bool()) is not None
            ):
                return r
            # TODO: We can also technically remove all cases when the input
            # doesn't have unbacked inputs, since it's all in the ShapeEnv
            return False

        def is_symnode_arg(a: fx.node.Argument):
            from torch.fx.experimental.sym_node import SymTypes

            if isinstance(a, (int, float, bool)):
                return True
            if isinstance(a, fx.Node):
                return isinstance(a.meta.get("example_value"), SymTypes)
            return False

        # NB: We assume that you cannot do mutations on int/float/bool,
        # because they are immutable types, and therefore it is always safe to
        # DCE.
        def is_symnode_compute_node(node):
            from torch.fx.experimental.sym_node import SymTypes

            if node.op != "call_function":
                return False
            # TODO: I don't think it's possible to have a bare int/float here?
            if not isinstance(node.meta.get("example_value"), SymTypes):
                return False
            # TODO: This will bail here if you ever end up with a more complicated
            # computation function, like sum(list_of_ints), even though it
            # should be DCE'able
            if not all(is_symnode_arg(a) for a in node.args):
                return False
            if not all(is_symnode_arg(a) for a in node.kwargs.values()):
                return False
            return True

        # NB: You could try to expand this to cover more cases by simply
        # detecting whenever you have an int output, but this is a bit
        # dangerous in case someone adds a function that returns an int but is
        # mutating. So manually whitelist for now.
        def is_accessor_node(node):
            if (
                node.op == "call_method"
                and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
                and node.target in ["size", "stride", "storage_offset", "item"]
            ):
                return True
            if node.op == "call_function" and node.target in [
                torch.ops.aten.sym_size,
                torch.ops.aten.sym_size.default,
                torch.ops.aten.sym_size.int,
                torch.ops.aten.sym_stride,
                torch.ops.aten.sym_stride.default,
                torch.ops.aten.sym_stride.int,
                torch.ops.aten.sym_storage_offset,
                torch.ops.aten.sym_storage_offset.default,
            ]:
                return True
            return False

        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if (
                    node.op == "get_attr"
                    or (node.op == "call_function" and node.target is operator.getitem)
                    or (
                        node.op == "call_function"
                        and node.target is torch._check
                        and is_static_true(node.args[0])
                    )
                    or is_symnode_compute_node(node)
                    or is_accessor_node(node)
                ):
                    self.remove_node(node)

        def placeholder_binds_symbol(node):
            arg = node.meta["grapharg"]
            example = arg.example
            if isinstance(example, torch.SymInt) and isinstance(
                example.node.expr, sympy.Symbol
            ):
                return example.node.expr
            return None

        def remove_unused(node):
            log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
            # I'm not really sure why you need to delete these from the
            # node since the node is going to get removed
            del node.meta["grapharg"]
            self.remove_node(node)
            self.real_value_cache.pop(node, None)

        used_symbols: Set[sympy.Symbol] = set()

        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
            used_symbols |= free_symbols(fake)

        recheck_placeholders = []
        for node in self.placeholders:
            binds_symbol = placeholder_binds_symbol(node) is not None
            # Don't delete symbol bindings yet
            if binds_symbol:
                if not node.users:
                    recheck_placeholders.append(node)
            else:
                if not node.users and not isinstance(
                    node.meta["grapharg"], BackwardStateGraphArg
                ):
                    remove_unused(node)
                else:
                    # Register the free symbols as uses
                    arg = node.meta["grapharg"]
                    if isinstance(arg, BackwardStateGraphArg):
                        continue
                    if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                        real_script_obj = node.meta["grapharg"].example
                        fake_script_obj = node.meta["grapharg"].example_strong_ref
                        flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
                        for attr in flat_dict.keys():
                            fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
                            pytree.tree_map_only(
                                (torch.SymInt, torch.Tensor),
                                lambda t: update_used_symbols(used_symbols, t),
                                fake_attr_val,
                            )
                        continue
                    fake = (
                        arg.fake_tensor if arg.fake_tensor is not None else arg.example
                    )
                    update_used_symbols(used_symbols, fake)

        # After removing unused graphargs, prune unused binds_symbol
        for node in recheck_placeholders:
            symbol = placeholder_binds_symbol(node)
            if symbol is not None:
                if symbol not in used_symbols:
                    remove_unused(node)
                else:
                    # Make sure we delete later occurrences of the same symbol
                    used_symbols.remove(symbol)
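
    # Illustrative sketch of the pruning above (the placeholder names are an
    # assumption, not real output).  Given placeholders
    #
    #     l_x_   # tensor input with no users
    #     s0     # SymInt placeholder binding l_x_.size(0), with no users
    #     l_y_   # tensor input that the graph actually uses
    #
    # the first loop removes l_x_ (unused, does not bind a symbol), while s0 is
    # only queued in recheck_placeholders.  The second loop then removes s0 if
    # the symbol is no longer free in any surviving grapharg's fake value;
    # otherwise the placeholder is kept as the binding site for that symbol.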

    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global_unsafe(self, name, value) -> None:
        """
        WARNING: prefer the safer `install_global_by_id/install_global`.
        torch.compile instances should be independent of each other;
        one footgun is to have one instance depend on the existence of
        a global installed by another instance. This can happen if we mangle
        a global the same way across both instances.
        """
        assert name not in self.installed_globals
        self.installed_globals.add(name)
        self.cleanups.append(CleanupHook.create(self.global_scope, name, value))

    def install_global_by_id(self, prefix, value) -> str:
        """
        Installs a global if it hasn't been installed already.
        This is determined by (prefix, id(value)) pair.

        Returns the name of the newly installed global.
        """
        # NB: need self.compile_id to distinguish this global
        # from another global created in a different torch.compile instance
        name = f"{prefix}_{id(value)}_c{self.compile_id}"
        if name in self.installed_globals:
            return name
        self.install_global_unsafe(name, value)
        return name

    def install_global(self, prefix, value) -> str:
        """
        Installs a global, generating a unique name for it.

        Returns the name of the newly installed global.
        """
        # NB: unique_id is unique, even across torch.compile instances
        name = unique_id(prefix)
        self.install_global_unsafe(name, value)
        return name
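
    # Hedged usage sketch for the two safe installers above (variable names are
    # illustrative only):
    #
    #     name1 = self.install_global_by_id("__import_helper", helper)
    #     name2 = self.install_global_by_id("__import_helper", helper)
    #     assert name1 == name2   # deduplicated by the (prefix, id(value)) pair
    #
    #     name3 = self.install_global("__wrapped_fn", fn)
    #     name4 = self.install_global("__wrapped_fn", fn)
    #     assert name3 != name4   # unique_id always generates a fresh name
    #
    # Both paths go through install_global_unsafe, so the globals land in
    # self.global_scope and are registered in self.cleanups for later teardown.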

    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.
        self.root_tx = None
        self.nn_modules.clear()
        self.param_name_to_source = None

        for node in self.graph.nodes:
            if "grapharg" in node.meta:
                del node.meta["grapharg"]
        self.real_value_cache.clear()
        self.input_name_to_proxy.clear()
        self.side_effects.clear()
        self.variable_tracker_cache.clear()
        self.register_finalizer_fns.clear()
        self.dynamo_flat_name_to_original_fqn.clear()
        self.tracing_context.clear()

    def set_torch_function_state(self, enabled: bool) -> None:
        self.torch_function_enabled = enabled

    def add_graph_finalizer(
        self, register_finalizer: Callable[[fx.GraphModule], None]
    ) -> None:
        self.register_finalizer_fns.append(register_finalizer)

    def example_value_from_input_node(self, node: torch.fx.Node):
        """Extract the non-fake example tensor"""
        if node.op == "placeholder":
            return node.meta["grapharg"].example
        assert node.op == "get_attr"
        return self.nn_modules[node.target]  # type: ignore[index]


err_epilogue = (
    "With the current config, we will graph break "
    "(and fall back to eager-mode PyTorch) on all ops "
    "that do not have the 'pt2_compliant_tag'. "
    "Please see the following doc for how to mark this op as PT2 compliant "
    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
)
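

# The compliance check below ultimately inspects op.tags.  A hedged sketch of
# the check it performs (the op name is illustrative, not a real registration):
#
#     op = torch.ops.mylib.foo.default          # some custom OpOverload
#     is_compliant = torch.Tag.pt2_compliant_tag in op.tags
#
# Builtin prim/prims/aten ops are skipped by encountered_compliant_op; custom
# ops can opt in to the tag when they are defined via the torch.library APIs
# described in the tutorial linked in err_epilogue.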
def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
    if kind != "call_function":
        return

    def encountered_compliant_op(target):
        if target.namespace in {"prim", "prims", "aten"}:
            return
        output_graph.compliant_custom_ops.add(target)

    def encountered_non_compliant_op(target, msg):
        output_graph.non_compliant_ops.add(target)
        if config.only_allow_pt2_compliant_ops:
            unimplemented(msg + " " + err_epilogue)

    if isinstance(target, torch._ops.OpOverload):
        if torch.Tag.pt2_compliant_tag in target.tags:
            encountered_compliant_op(target)
            return
        encountered_non_compliant_op(
            target,
            f"Encountered the torch.ops.OpOverload {target} "
            f"that is not PT2 compliant.",
        )
        return

    if isinstance(target, torch._ops.OpOverloadPacket):
        overloads = tuple(target.overloads())
        # Optimization: Overload resolution is expensive.
        # If there's only one overload, we know what it will resolve to.
        if len(overloads) == 1:
            op = getattr(target, overloads[0])
            if torch.Tag.pt2_compliant_tag in op.tags:
                encountered_compliant_op(op)
                return
            encountered_non_compliant_op(
                op,
                f"Encountered the non-overloaded "
                f"torch.ops.OpOverloadPacket {target} "
                f"that is not PT2 compliant. ",
            )
            return

        args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
            output_graph.current_tx, (args, kwargs), False
        )
        try:
            overload = torch._C._jit_resolve_packet(
                target._qualified_op_name, *args, **kwargs
            )
        except RuntimeError as e:
            unimplemented(str(e))

        op = getattr(target, overload)
        if torch.Tag.pt2_compliant_tag in op.tags:
            encountered_compliant_op(op)
        else:
            encountered_non_compliant_op(
                op,
                f"Encountered the torch.ops.OpOverloadPacket {target} "
                f"which resolves to the overload ({overload}) that is "
                f"not PT2 compliant.",
            )


_compile_id_counter = itertools.count()


class SubgraphTracer(fx.Tracer):
    """
    Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
    and the separation of responsibilities is that SubgraphTracer is
    responsible for building the graph while OutputGraph is responsible for
    compiling and executing the graph.
    """

    def __init__(
        self, output_graph, parent=None, export_root=False, source_target=None
    ):
        super().__init__()
        self.output_graph = weakref.proxy(output_graph)
        self.graph = torch.fx.Graph()

        # The export is only ever set for the ROOT tracer. It controls
        # whether or not certain inputs are allowed to be added or not.
        # Look at call sites of create_graph_input to see how it is used.
        if export_root:
            assert parent is None
        self.export_root = export_root
        # Map from graph input name to its placeholder proxy object, where the
        # map's keys give all current placeholder node names and can be used to
        # create unique node names
        self.input_name_to_proxy: Dict[str, fx.Proxy] = {}
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
        self.parent = parent
        # A dict mapping previously free variables (Proxy objects)
        # to new Proxy objects that wrap inputs to this subgraph.
        #
        # This dict serves two purposes:
        # - Proxies are associated with VariableTrackers. If we see
        #   the same VariableTracker twice (and it is a free variable),
        #   then we want to use the same Proxy in the current subgraph to
        #   record the tracing.
        # - If we are tracing a HigherOrderOperator's body_fn, then we
        #   need to keep track of what free variables were lifted so we can
        #   rewrite the HigherOrderOperator call using the traced body_fn.
        #   Dicts maintain the order of args for the HigherOrderOperator call.
        self.lifted_freevars = {}
        self.prev_inst = None

        self._cur_code = None
        self._orig_gm_meta = None
        self._orig_gm_lineno_map = None
        self._orig_gm_firstlineno = None
        # Each SubgraphTracer is associated with a source target, which indicates
        # which operator this subgraph is attached to. We compute a source_fn_stack
        # based on the source target. For the root tracer, it's set to [].
        # This is useful for debugging and transforming the exported graph.
        if self.parent is None:
            self.source_fn_stack = []
        else:
            self.source_fn_stack = self.parent.source_fn_stack + [
                (self.graph._target_to_str(source_target), source_target)
            ]
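
    # Illustrative sketch (assumed names) of lifted_freevars for a nested tracer
    # after one free variable has been lifted:
    #
    #     parent graph:  %l_x_ : placeholder
    #     child graph:   %l_x_ : placeholder   (created by lift_tracked_freevar_to_input)
    #     child.lifted_freevars == {<parent proxy for l_x_>: <child placeholder proxy>}
    #
    # Keys are proxies owned by an outer tracer; values are the placeholder
    # proxies created in this subgraph to stand in for them.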

    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        # NOTE: [Nested SubgraphTracer and free_variable handling]
        # --------------------------------------------------------
        # Read NOTE [HigherOrderOperator tracing design] first.
        #
        # Let's say we're in the middle of introspecting the body of a possibly
        # nested HigherOrderOperator, and we see a free variable.
        #
        # There are two cases:
        # 1. We see a free variable that is already tracked by Dynamo.
        # 2. We see a free variable that has not been tracked by Dynamo.
        #
        # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
        # which will lift the freevar to be an input of this subgraph
        # and also recursively lift it to be an input on the parent(s).
        #
        # In case 2, before the call to `create_proxy`, the InstructionTranslator
        # will see the freevar when it gets loaded by Python bytecode.
        # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
        # LOAD_GLOBAL.
        # There, the InstructionTranslator asks Dynamo to begin tracking the
        # freevar by building a new Variable.
        # Building a new Variable automatically lifts the freevar to be an
        # input of the root SubgraphTracer.
        #
        # The implications for the code below are:
        # - We will always be in Case 1 when we get to this code.
        # - Any "free variable" we encounter here is guaranteed to already be
        #   bound, that is, it is either a graph input of the root graph, or
        #   some local variable of the root graph or a subgraph.
        # - The additional work we need to do here is *only* that we need to
        #   lift this free variable into inputs (recursively) of each nested
        #   higher-order-op subgraph until we hit the subgraph where the free
        #   variable is bound.
        if self.parent is not None:
            flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
            new_flat_args = []
            for arg in flat_args:
                maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
                new_flat_args.append(maybe_new_arg)
            args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )

        # append stack trace to fx node
        tx = self.output_graph.current_tx

        # log detailed location of line of code in 3.11
        if sys.version_info >= (3, 11) and kind in (
            "call_function",
            "call_method",
            "call_module",
        ):
            cur_inst = tx.current_instruction
            if (
                cur_inst is not self.prev_inst
                and cur_inst.positions is not None
                and cur_inst.positions.lineno is not None
            ):
                tx_code = tx.f_code
                header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)

                def get_trace_call_log_str():
                    line = get_instruction_source_311(tx_code, cur_inst).rstrip()
                    return f"TRACE FX call {rv.node.name} from {header}\n{line}"

                trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
                self.prev_inst = cur_inst

        # update reference to original meta if we're tracing a new code object
        is_retracing = False
        if tx.f_code is not self._cur_code:
            orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
                "orig_graphmodule", lambda: None
            )()
            if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
                is_retracing = True
                self._orig_gm_meta = [
                    nd.meta for nd in orig_graphmodule_maybe.graph.nodes
                ]
                self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
                self._orig_gm_firstlineno = (
                    orig_graphmodule_maybe.forward.__code__.co_firstlineno
                )
            else:
                self._orig_gm_meta = None
                self._orig_gm_lineno_map = None
                self._orig_gm_firstlineno = None

        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if kind in {"call_function", "call_method"}:
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (rv.node.name, target)
            ]
        elif kind == "call_module":
            if self.parent is not None:
                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
            # For modules we store the class
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (
                    rv.node.name,
                    rv.node.meta["nn_module_stack"][target][1],
                )
            ]

        # preserve original meta if it is available
        if (
            self._orig_gm_meta
            and self._orig_gm_lineno_map
            and self._orig_gm_firstlineno
        ):
            lineno = tx.current_instruction.starts_line
            node_idx = None
            if lineno is not None:
                node_idx = self._orig_gm_lineno_map.get(
                    lineno - self._orig_gm_firstlineno, None
                )
            if node_idx is not None:
                meta = self._orig_gm_meta[node_idx]
                for field in fx.proxy._COPY_META_FIELDS:
                    if field in meta:
                        rv.node.meta[field] = meta[field]
                if "stack_trace" in meta:
                    rv.node.meta["stack_trace"] = meta["stack_trace"]

        if not is_retracing:
            if "nn_module_stack" not in rv.node.meta:
                nn_module_stack = tx.nn_module_stack
                if nn_module_stack:
                    rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

            if "source_fn_stack" not in rv.node.meta:
                if kind in {"call_function", "call_method"}:
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (rv.node.name, target)
                    ]
                elif kind == "call_module":
                    if self.parent is not None:
                        unimplemented(
                            "Invoking an nn.Module inside HigherOrderOperator"
                        )
                    # For modules we store the class
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (
                            rv.node.name,
                            rv.node.meta["nn_module_stack"][target][1],
                        )
                    ]

        if "stack_trace" not in rv.node.meta:
            frame_summaries: List[traceback.FrameSummary] = []
            while tx:
                frame_summaries.append(tx.frame_summary())
                tx = getattr(tx, "parent", None)
            # Reverse the frame_summaries, so that the innermost frame is last
            frame_summaries.reverse()

            # official from_list stub doesn't have new-style type
            msgs = traceback.StackSummary.from_list(frame_summaries).format()
            rv.node.stack_trace = "".join(msgs)

        return rv

    def create_node(
        self, op, target, args=None, kwargs=None, name=None, type_expr=None
    ):
        check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
        if self.parent is not None:
            flat_args = pytree.arg_tree_leaves(*args, **kwargs)
            for arg in flat_args:
                if not isinstance(arg, torch.fx.Node):
                    continue
                assert (
                    arg.graph == self.graph
                ), "create_node using arg not from this SubgraphTracer"

        node = super().create_node(op, target, args, kwargs, name, type_expr)
        node.meta["creation_timestamp"] = self.output_graph.timestamp
        return node

    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        if len(node.users) > 0:
            user_graph_nodes: List[torch.fx.Node] = []
            for user in node.users.keys():
                # For the case where user.graph == self.graph, that is a real bug and will raise
                # properly.
                if user.graph != self.graph:
                    # This is a nested graph, which needs to be deleted.
                    # If we do not do this, we will raise on attempting to remove this.
                    # As we only get here during restoration cleanup, this is sound.
                    user_graph_nodes.extend(reversed(list(user.graph.nodes)))
            for other_graph_node in user_graph_nodes:
                other_graph_node.graph.erase_node(other_graph_node)
        self.graph.erase_node(node)
        self.input_name_to_proxy.pop(node.name, None)

    # when before=True, we will insert this input before the most recent
    # inserted proxy. This is a hack to get around an ordering problem,
    # where we first insert a tensor argument, and then insert bindings
    # for SymInts that may occur in the tensor argument.
    # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
    # fixed.
    def create_graph_input(self, name, type_expr=None, before=False, source=None):
        log.debug(
            "create_graph_input %s %s",
            name,
            source.name() if source is not None else "(none)",
        )
        if source is None:
            assert (
                self.parent is not None
            ), "you are required to provide a source for inputs on the root tracer"

        # In eager, we are generally OK with adding graph inputs whenever we
        # want, because we take care of writing the bytecode that knows how
        # to source all the inputs.
        #
        # In export, this is bad, because you want a self-contained export
        # object which only depends on the inputs you explicitly passed to it.
        # So we are a bit more strict about what sources can become inputs
        # in export
        if self.export_root:
            if not is_from_local_source(source, allow_cell_or_freevar=False):
                self.output_graph.source_to_user_stacks.setdefault(source, []).append(
                    TracingContext.extract_stack()
                )

        # unique
        if name in self.input_name_to_proxy:
            for i in itertools.count():
                candidate_name = f"{name}_{i}"
                if candidate_name not in self.input_name_to_proxy:
                    name = candidate_name
                    break

        if self.input_name_to_proxy:
            prev_name = next(reversed(self.input_name_to_proxy))
            node = self.input_name_to_proxy[prev_name].node
            if before:
                ctx = self.graph.inserting_before(node)
            else:
                ctx = self.graph.inserting_after(node)
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            if self.input_name_to_proxy and before:
                k, v = self.input_name_to_proxy.popitem()
                self.input_name_to_proxy[name] = proxy
                self.input_name_to_proxy[k] = v
            else:
                self.input_name_to_proxy[name] = proxy
            return proxy
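
    # Hedged example of the before=True ordering hack documented above (names
    # are illustrative).  If a tensor argument is bound first and a SymInt that
    # appears in its sizes is bound afterwards:
    #
    #     tracer.create_graph_input("l_x_", source=...)             # tensor
    #     tracer.create_graph_input("s0", before=True, source=...)  # SymInt binding
    #
    # the second call inserts the SymInt placeholder before the most recently
    # inserted one in the graph, and the popitem() dance above applies the same
    # reordering to input_name_to_proxy.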

    # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
    def lift_tracked_freevar_to_input(self, proxy):
        # You're doing something wrong if we are the root SubgraphTracer because
        # Dynamo adds tensors to graph inputs before creating a proxy for them.
        assert (
            self.parent is not None
        ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
        # Proxies are associated with VariableTracker.
        # It is possible that we've already lifted the Proxy to be an input.
        # If that is the case, just return the already lifted Proxy.
        if proxy in self.lifted_freevars:
            return self.lifted_freevars[proxy]
        new_proxy = self.create_graph_input(proxy.node.name)
        set_example_value(new_proxy.node, proxy.node.meta["example_value"])
        self.lifted_freevars[proxy] = new_proxy
        if self.parent is not None and proxy.tracer != self.parent:
            self.parent.lift_tracked_freevar_to_input(proxy)
        return new_proxy

    def maybe_lift_tracked_freevar_to_input(self, arg):
        """
        If arg is a free variable, then lift it to be an input.
        Returns the new lifted arg (if arg was a freevar), else the
        original arg.
        """
        if not isinstance(arg, torch.fx.Proxy):
            return arg
        elif arg.tracer == self:
            return arg
        return self.lift_tracked_freevar_to_input(arg)


# NOTE: [HigherOrderOperator tracing design]
# Ignoring HigherOrderOperators for a moment,
# OutputGraph represents the graph being built by Dynamo that may be compiled
# and executed. It holds a root SubgraphTracer where the FX graph is built.
#
# HigherOrderOperators are operators that take functions as their arguments.
# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect
# the function passed to it (call this the "body function"), capture it into a
# GraphModule, and rewrite the call to the HigherOrderOperator to use the
# GraphModule.
#
# The way we handle the capture of body functions is through having
# (possibly nested) SubgraphTracers, one per body function.
#
# Mechanically, we do the introspection by:
# - Creating a new SubgraphTracer via OutputGraph.subtracer
# - Executing the body function.
# This constructs the graph of the body function in the new SubgraphTracer
# while modifying the state of the OutputGraph. For example:
# - the OutputGraph can receive new GraphArgs (if we discover any new
#   untracked Tensors)
# - side effects from the body function get accumulated into
#   OutputGraph.side_effects
# - guards produced by the body function get accumulated into OutputGraph.guards
#
# The traced function has some special properties that make it easier for us
# to transform later down the line:
# - we lift all free variables to being inputs.
#
# If the introspection fails (due to the existence of graph breaks), then
# we roll back the current OutputGraph state and graph break on the
# HigherOrderOperator.
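
# A hedged end-to-end illustration of the design above, written as user-level
# code rather than anything defined in this module.  Tracing a
# HigherOrderOperator such as torch.cond with a closed-over tensor:
#
#     free = torch.randn(3)
#
#     def true_fn(x):
#         return x + free      # `free` is a free variable of the body function
#
#     def false_fn(x):
#         return x - free
#
#     def f(pred, x):
#         return torch.cond(pred, true_fn, false_fn, (x,))
#
# While the bodies are introspected in nested SubgraphTracers, `free` is lifted
# to an extra input of each body graph (and, recursively, of any enclosing
# subgraph), and the rewritten cond call passes it explicitly.  Exact names and
# the torch.cond entry point depend on the PyTorch version; this is a sketch,
# not output produced by this file.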