- # mypy: allow-untyped-defs
- import collections
- import contextlib
- import copy
- import functools
- import itertools
- import logging
- import operator
- import re
- import sys
- import traceback
- import weakref
- from dataclasses import dataclass
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
- import sympy
- import torch._guards
- import torch._logging
- import torch.nn
- import torch.utils._pytree as pytree
- from torch import fx
- from torch._guards import GlobalContextCheckpointState, Source, TracingContext
- from torch._utils_internal import signpost_event
- from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
- from torch.fx.experimental._backward_state import BackwardState
- from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
- from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from . import config, logging as torchdynamo_logging, variables
- from .backends.registry import CompiledFn, CompilerFn
- from .bytecode_transformation import (
- create_call_function,
- create_instruction,
- Instruction,
- unique_id,
- )
- from .code_context import code_context
- from .codegen import PyCodegen
- from .current_scope_id import enter_new_scope
- from .exc import (
- BackendCompilerFailed,
- exceptions_allowed_to_be_fallback,
- SkipFrame,
- unimplemented,
- unimplemented_with_warning,
- )
- from .guards import GuardBuilder, install_guard
- from .mutation_guard import is_dynamic_nn_module
- from .side_effects import AttributeMutationExisting, SideEffects
- from .source import (
- AttrSource,
- BackwardStateSource,
- ConstantSource,
- GetItemSource,
- GlobalStateSource,
- is_constant_source,
- is_from_local_source,
- LocalSource,
- ParamBufferSource,
- ShapeEnvSource,
- SyntheticLocalSource,
- TensorProperty,
- TensorPropertySource,
- )
- from .utils import (
- checkpoint_params,
- CleanupHook,
- clone_inputs,
- count_calls,
- counters,
- dynamo_timed,
- get_instruction_source_311,
- get_locals_to_steal,
- get_static_address_type,
- graph_break_reasons,
- increment_op_count,
- lazy_format_graph_code,
- LazyString,
- nn_module_proxy,
- same,
- set_example_value,
- )
- from .variables.base import VariableTracker
- from .variables.builder import (
- BackwardStateGraphArg,
- GraphArg,
- TrackedFake,
- VariableBuilder,
- wrap_fx_proxy,
- )
- from .variables.lists import BaseListVariable
- from .variables.misc import NullVariable
- from .variables.nn_module import NNModuleVariable
- from .variables.tensor import (
- NumpyNdarrayVariable,
- SymNodeVariable,
- TensorVariable,
- UnspecializedPythonVariable,
- )
- from .variables.torch_function import TensorWithTFOverrideVariable
- if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslatorBase
- log = logging.getLogger(__name__)
- graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
- graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
- graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
- trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
- @dataclass(frozen=True)
- class VariableTrackerCacheKey:
- vt_id: int
- # Two different sources can point to the same object. However, Dynamo handles
- # global and local sources differently when it comes to guards and possibly
- # some other parts as well. So the cache also relies on the source.
- source: Source
- class VariableTrackerCache:
- def __init__(self):
- self.cache = {}
- def lookup(self, value, source):
- key = VariableTrackerCacheKey(id(value), source)
- if key not in self.cache:
- return None
- return self.cache[key]
- def add(self, value, source, vt):
- key = VariableTrackerCacheKey(id(value), source)
- self.cache[key] = vt
- def clone(self):
- # Needed for copying and restoring graph state
- new_cache = VariableTrackerCache()
- new_cache.cache.update(self.cache)
- return new_cache
- def clear(self):
- self.cache.clear()
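- # Illustrative usage sketch (assumed call pattern, not code from this module): the cache
- # keys on (id(value), source), so the same Python object reached through two different
- # sources still gets two distinct trackers:
- #
- #   vt = cache.lookup(obj, source)              # None the first time
- #   if vt is None:
- #       vt = VariableBuilder(tx, source)(obj)   # hypothetical builder call
- #       cache.add(obj, source, vt)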
- @functools.lru_cache(None)
- def _step_logger():
- return torchdynamo_logging.get_step_logger(log)
- @dataclass
- class GraphCompileReason:
- """Stores why a given output graph was compiled; i.e. what caused the graph break."""
- reason: str
- user_stack: List[traceback.FrameSummary]
- # Indicates if this was a graph compile reason due to a graph break.
- graph_break: bool = True
- def __post_init__(self):
- if self.graph_break:
- graph_break_reasons.append(self)
- def _get_gen_rand_values_fn(random_calls):
- def _gen_rand_values():
- return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
- return _gen_rand_values
- class FakeRootModule(torch.nn.Module):
- """Trick the constructor of fx.GraphModule"""
- def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
- super().__init__()
- for k, v in nn_modules.items():
- setattr(self, k, v)
- def __repr__(self):
- return "FakeRootModule(...)"
- class WrapperBackend:
- def __init__(self, backend: CompilerFn):
- self.backend: CompilerFn = backend
- def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
- self.restore = checkpoint_params(gm)
- self.gm = gm
- copy_gm = copy.deepcopy(self.gm)
- self.candidate = self.backend(copy_gm, example_inputs)
- if self.candidate is None or self.candidate is self.gm.forward:
- return self.gm.forward
- if not config.verify_correctness:
- return self.candidate
- # if verify_correctness=True
- try:
- correct = self.gm.forward(*clone_inputs(example_inputs))
- result = self.candidate(*clone_inputs(example_inputs))
- # TODO: replace `same` function with the one in testing
- if same(correct, result):
- return self.candidate
- raise RuntimeError(f"incorrect results of backend {self}")
- return self.gm.forward
- except Exception:
- log.exception("error in verify_correctness")
- raise
- finally:
- self.restore()
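- # Illustrative usage sketch (my_compiler_fn is hypothetical): with config.verify_correctness
- # enabled, call_user_compiler wraps the user backend so the candidate's outputs are compared
- # against eager results via `same`, raising on mismatch:
- #
- #   compiled = WrapperBackend(my_compiler_fn)(gm, example_inputs)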
- Scope = Dict[str, object]
- class OutputGraph:
- """
- Wrapper class to hold outputs of InstructionTranslator. Mainly the
- generated fx.Graph.
- OutputGraph is 1:1 with a frame being processed. Each frame is associated
- with some root InstructionTranslator. When user code calls a function,
- we construct an InliningInstructionTranslator that continues to write into
- the root InstructionTranslator's OutputGraph.
- """
- def __init__(
- self,
- code_options: Dict[str, Any],
- compiler_fn: Optional[CompilerFn],
- root_tx,
- export: bool,
- export_constraints,
- frame_state,
- local_scope: Scope,
- global_scope: Scope,
- f_code,
- ):
- super().__init__()
- self.tracers = [SubgraphTracer(self, export_root=export)]
- # Map from graph input's `Source` to its `VariableTracker` to
- # de-duplicate graph inputs by source and reuse the tracker
- self.input_source_to_var: Dict[Source, VariableTracker] = {}
- self.export = export
- self.export_constraints = export_constraints
- self.frame_state = frame_state
- # Map from graph input's `Source` to sizes / strides metadata
- self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {}
- self.cleanup_hooks: List[Callable[[], Any]] = []
- # compile_id is an id number for the current torch.compile
- self.compile_id: int = next(_compile_id_counter)
- # Set of globals installed via install_global* APIs
- self.installed_globals: Set[str] = set()
- # TODO: maybe should just pass the entire f_code in here? Not
- # sure...
- self.co_fields = {
- "co_name": f_code.co_name,
- "co_filename": f_code.co_filename,
- "co_firstlineno": f_code.co_firstlineno,
- }
- # tracked_fakes says where any tensor that was wrapped to fake came
- # from. It is similar to GraphArg, in that all GraphArgs will get
- # added to TrackedFakes, but TrackedFakes also contains GraphArgs
- # that got pruned, and things like Tensor attributes which aren't
- # explicit graph inputs. Used by shape guards.
- self.tracked_fakes: List[TrackedFake] = []
- # List of symbols for which we have exact bindings in the arguments
- # already
- self.bound_symbols: Set[sympy.Symbol] = set()
- shape_env = ShapeEnv(
- # Reference Cycle!
- # Share a reference to the list of TrackedFake.
- #
- # ShapeEnv needs this in order to be able to reproduce the call
- # to produce_guards at an arbitrary time point. That is because
- # TrackedFake instances may have their metadata changed throughout
- # the program execution.
- tracked_fakes=self.tracked_fakes,
- allow_scalar_outputs=config.capture_scalar_outputs,
- allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
- prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
- _allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts,
- co_fields=self.co_fields,
- )
- # In export mode, we force the shape_env to strictly disallow any constraining
- # of the user-marked dynamic dims.
- import torch._functorch.config as _config
- with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
- fake_mode = torch._subclasses.FakeTensorMode(
- shape_env=shape_env,
- # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
- allow_non_fake_inputs=self.export,
- export=self.export,
- )
- self.tracing_context: TracingContext = TracingContext(fake_mode)
- self.init_ambient_guards()
- # Map each tensor id to a list of sources. This is necessary because
- # tensor ids cannot be recovered from tracked fakes (in general).
- # We use this map to interpret (i.e., check for violations of) constraints,
- # specifically equality constraints, which have shared tensor ids in them.
- # This map should also be generally useful, e.g., for (de)serialization.
- self.tracked_fakes_id_to_source: Dict[
- int, List[Source]
- ] = collections.defaultdict(list)
- # Stores the full fqn of a param or buffer to the relevant source.
- self.param_name_to_source: Optional[Dict[str, Source]] = dict()
- self.side_effects = SideEffects()
- # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
- # and LOAD_ATTR for same python objects free.
- self.variable_tracker_cache = VariableTrackerCache()
- self.unique_var_id = itertools.count()
- self.code_options = dict(code_options)
- self.output_instructions: List[Instruction] = []
- # used to track nodes that are added between calls of copy_graphstate
- # and restore_graphstate
- self.timestamp = 0
- # A list of register_finalizer_fns to apply to the output graph module
- self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []
- # Not checkpointed
- self.compiler_fn: Optional[CompilerFn] = compiler_fn
- self.global_scope = global_scope
- self.local_scope = local_scope
- self.root_tx = root_tx
- # Given a source, what are the user stacks of all locations that
- # accessed it?
- #
- # For efficiency, we only populate this:
- # - During export, and
- # - If the source could potentially lead to a spurious export input
- #
- # Feel free to populate this more frequently if other use-cases arise,
- # but be aware that we have to generate full stacks for each
- # recording!
- self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}
- self._current_tx: List[InstructionTranslatorBase] = []
- self.cleanups: List[CleanupHook] = []
- self.should_exit = False
- self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
- self.torch_function_enabled = torch._C._is_torch_function_enabled()
- # Tracks if the output graph has a user defined allowed function in the
- # graph. This is used later to determine if we should fall back to eager
- # for certain exceptions. The idea is that if the user has applied
- # allow_in_graph, they would like to see the error instead of falling
- # back for backend errors.
- self.has_user_defined_allowed_in_graph = False
- # Tracks the set of called ops that were not tagged with "pt2_compliant_tag".
- # This information is useful for logging.
- self.non_compliant_ops: Set[torch._ops.OpOverload] = set()
- # Tracks the set of called custom ops that were tagged with "pt2_compliant_tag".
- # This information is useful for logging.
- self.compliant_custom_ops: Set[torch._ops.OpOverload] = set()
- # We save the global torch state here to be restored in case of graph
- # breaks. The relevant issue is seen here
- # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
- # where inlining of a function changes the global state (because of the
- # presence of torch.no_grad) and there is a graph break.
- self.save_global_state()
- # Tracks the original FQNs of the constant tensors from the original graph,
- # i.e. buffers and parameters.
- self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {}
- # All calls to random() are replaced with a single call to a __gen_rand_values
- # function that returns a list of random values, one for each original call.
- # random_calls tracks calls to random() and random_values_var stores the name of
- # the variable that stores __gen_rand_values results.
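- # Illustrative contents (hypothetical entries): random_calls might hold
- #   [(random.random, (), {}), (random.uniform, (0, 1), {})]
- # and the installed __gen_rand_values function would then return one fresh value
- # per recorded call, in recording order.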
- self.random_calls: List[
- Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
- ] = []
- self.random_values_var = None
- # Bytecode to insert right before we call the graph
- self.pregraph_bytecode: List[Instruction] = []
- # Used to pass values to backward hooks when using compiled autograd
- self.backward_state: Dict[str, VariableTracker] = {}
- self.backward_state_proxy: Optional[torch.fx.Proxy] = None
- self.backward_state_var: Optional[str] = None
- self.name_of_builtins_dict_key_in_fglobals: str = (
- self.install_builtins_dict_in_fglobals()
- )
- self.guard_on_key_order: Set[str] = set()
- def install_builtins_dict_in_fglobals(self):
- # f_globals["__builtins__"] can be a dict or a module. This is an
- # implemenation detail -
- # https://docs.python.org/3/library/builtins.html.
- # This makes guarding on any builtin messy because the guard check_fn
- # has to check if the __builtins__ is a module or dict, and then access
- # by either using getattr or getitem respectively.
- # To solve this problem, we insert a new entry in f_globals which points
- # to the builtins __dict__ and then we guard any builtin on this dict.
- # To avoid any collision with the pre-existing keys, we use the
- # install_global to give us a unique dict key.
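- # Illustrative effect (the actual key is generated by install_global; the name below is
- # only an example):
- #   f_globals["__builtins_dict__0"] = __builtins__.__dict__   # or __builtins__ if it is already a dict
- # so guard code can always use plain dict indexing.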
- f_builtins = self.global_scope["__builtins__"]
- if not isinstance(f_builtins, dict):
- f_builtins = f_builtins.__dict__
- return self.install_global("__builtins_dict__", f_builtins)
- def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
- name = f"{prefix}{len(self.backward_state)}"
- assert name not in self.backward_state
- self.backward_state[name] = hook
- return name, self.get_backward_state_proxy()
- def get_backward_state_proxy(self):
- if self.backward_state_proxy is None:
- if self.export:
- unimplemented("backward_state does not support export")
- self.backward_state_proxy = self.root_tracer.create_graph_input(
- "dynamo_backward_state", BackwardState, source=BackwardStateSource()
- )
- self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
- set_example_value(self.backward_state_proxy.node, BackwardState())
- self.backward_state_var = self.new_var()
- return self.backward_state_proxy
- # This gets its own helper function so guards DEBUG logs are more informative
- def init_ambient_guards(self):
- # Register a SHAPE_ENV guard to make sure we set up shape guards
- # that show up in ShapeEnv
- self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
- self.guards.add(
- GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
- )
- self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
- self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
- self.guards.add(
- GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
- )
- ci = torch._C._functorch.peek_interpreter_stack()
- if ci is not None:
- self.guards.add(
- GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
- )
- def synthetic_graph_input(self, fn, args):
- """
- call fn(*args) before the graph runs and turn the result into a fake input.
- """
- example_value = fn(*args)
- varname = self.new_var()
- cg = PyCodegen(self.root_tx)
- cg.load_import_from(
- fn.__module__,
- fn.__name__,
- )
- cg.foreach(map(variables.ConstantVariable.create, args))
- cg.call_function(len(args), True)
- cg.store(varname)
- self.pregraph_bytecode.extend(cg.get_instructions())
- source = SyntheticLocalSource(varname)
- result = VariableBuilder(self.root_tx, source)(example_value)
- TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
- source
- )
- return result
- def add_cleanup_hook(self, fn: Callable[[], Any]):
- self.cleanup_hooks.append(fn)
- def call_cleanup_hooks(self):
- for hook in reversed(self.cleanup_hooks):
- hook()
- self.cleanup_hooks.clear()
- @property
- def root_tracer(self):
- return self.tracers[0]
- @property
- def current_tracer(self):
- return self.tracers[-1]
- def is_root_tracer(self):
- # Helper to tell if we are inside higher order operator tracing.
- return len(self.tracers) == 1
- @property
- def graph(self):
- return self.current_tracer.graph
- # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
- @graph.setter
- def graph(self, value):
- self.current_tracer.graph = value
- @property
- def input_name_to_proxy(self):
- return self.current_tracer.input_name_to_proxy
- @property
- def real_value_cache(self):
- return self.current_tracer.real_value_cache
- # If you are here, and you're looking for create_graph_input,
- # to avoid ambiguity, please call one of the following:
- # - self.current_tracer.create_graph_input
- # - self.root_tracer.create_graph_input
- # See NOTE [HigherOrderOperator tracing design] for more context.
- def create_proxy(self, *args, **kwargs):
- return self.current_tracer.create_proxy(*args, **kwargs)
- def create_node(self, *args, **kwargs):
- return self.current_tracer.create_node(*args, **kwargs)
- def remove_node(self, *args, **kwargs):
- return self.current_tracer.remove_node(*args, **kwargs)
- @contextlib.contextmanager
- def subtracer(self, source_target, prior_tracer):
- new_scope_ctx = enter_new_scope()
- try:
- if prior_tracer:
- # Lineage MUST stay preserved
- assert prior_tracer.parent is self.current_tracer
- new_scope_ctx.__enter__()
- tracer = (
- prior_tracer
- if prior_tracer
- else SubgraphTracer(
- self, parent=self.current_tracer, source_target=source_target
- )
- )
- self.tracers.append(tracer)
- yield tracer
- finally:
- new_scope_ctx.__exit__(None, None, None)
- self.tracers.pop()
- @property
- def output(self):
- return self
- @property
- def fake_mode(self):
- return self.tracing_context.fake_mode
- @property
- def shape_env(self):
- return self.tracing_context.fake_mode.shape_env
- @property
- def guards(self) -> torch._guards.GuardsSet:
- return self.tracing_context.guards_context.dynamo_guards
- @property
- def nn_modules(self) -> Dict[str, Any]:
- return self.tracing_context.module_context.nn_modules
- def save_global_state(self, out=None):
- """
- Saves to out if it is provided. Else saves to the tracing context's global_state.
- """
- global_state = (
- out if out is not None else self.tracing_context.global_context.global_state
- )
- # TODO - Consider having a torch level API for torch_function_state. As
- # of now, we create a ref cycle by passing the
- # output.set_torch_function_state to
- # output.tracing_context.global_context.global_state. In the interim,
- # the problem can be solved by manually setting
- # output.tracing_context.global_context.global_state to None at cleanup.
- global_state["torch_function_enabled"] = (
- self.set_torch_function_state,
- self.torch_function_enabled,
- )
- global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
- global_state["autocast_enabled"] = (
- functools.partial(torch.set_autocast_enabled, "cuda"),
- torch.is_autocast_enabled("cuda"),
- )
- global_state["autocast_cpu_enabled"] = (
- functools.partial(torch.set_autocast_enabled, "cpu"),
- torch.is_autocast_enabled("cpu"),
- )
- global_state["autocast_gpu_dtype"] = (
- functools.partial(torch.set_autocast_dtype, "cuda"),
- torch.get_autocast_dtype("cuda"),
- )
- global_state["autocast_cpu_dtype"] = (
- functools.partial(torch.set_autocast_dtype, "cpu"),
- torch.get_autocast_dtype("cpu"),
- )
- global_state["autocast_cache_enabled"] = (
- torch.set_autocast_cache_enabled,
- torch.is_autocast_cache_enabled(),
- )
- def push_tx(self, tx):
- self._current_tx.append(tx)
- def pop_tx(self):
- return self._current_tx.pop()
- @property
- def current_tx(self):
- return self.root_tx if not self._current_tx else self._current_tx[-1]
- def add_symbol_bindings(self, arg: GraphArg):
- # Insert implicit size vars as necessary. With dynamic shapes, we
- # maintain the invariant that every sizevar gets a direct SymInt input
- # into the graph. This means downstream graph transforms can assume
- # every size variable is explicitly bound and accessible, instead of
- # having to pull it out implicitly from tensors.
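- # Illustrative example (hypothetical source): for a fake tensor bound to L['x'] with
- # size (s0, 3), bind_symint below adds a SymInt graph input named "s0" whose source is
- # TensorPropertySource(L['x'], TensorProperty.SIZE, 0); symbols already in
- # self.bound_symbols are skipped.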
- if self.export:
- return
- assert arg.fake_tensor is not None
- def bind_symint(s, prop):
- if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
- return
- s0 = s.node.expr
- if s0 in self.bound_symbols:
- return
- self.bound_symbols.add(s0)
- log.debug("bind_symint %s %s", s, prop.name())
- # TODO: don't re-add the symint if we already have it in the graph
- # (this is harmless because we do remove the unused ones later)
- proxy = self.root_tracer.create_graph_input(
- str(s0),
- torch.SymInt,
- before=True,
- source=prop,
- )
- set_example_value(proxy.node, s)
- proxy.node.meta["grapharg"] = GraphArg(
- prop,
- s,
- pass_arg_as_tensor=False,
- fake_tensor=None,
- is_tensor=False,
- )
- def handle_tensor(t, src):
- for i, s in enumerate(t.size()):
- bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
- if t.layout is torch.strided:
- for i, s in enumerate(t.stride()):
- bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
- bind_symint(
- t.storage_offset(),
- TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
- )
- elif t.layout is torch.sparse_coo:
- handle_tensor(t._indices(), src)
- handle_tensor(t._values(), src)
- elif t.layout in {torch.sparse_csr, torch.sparse_bsr}:
- handle_tensor(t.crow_indices(), src)
- handle_tensor(t.col_indices(), src)
- elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
- handle_tensor(t.ccol_indices(), src)
- handle_tensor(t.row_indices(), src)
- if is_traceable_wrapper_subclass(t):
- attrs, ctx = t.__tensor_flatten__()
- for attr in attrs:
- inner_t = getattr(t, attr)
- handle_tensor(inner_t, AttrSource(src, attr))
- handle_tensor(arg.fake_tensor, arg.source)
- def count_calls(self):
- return count_calls(self.graph)
- def is_empty_graph(self):
- return len(list(self.graph.nodes)) == 0
- def get_submodule(self, keys):
- assert keys
- obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
- for k in keys.split("."):
- if isinstance(obj, dict):
- obj = obj[k]
- else:
- obj = getattr(obj, k)
- return obj
- def new_var(self, name="tmp"):
- existing = set(self.code_options["co_varnames"])
- # In the common case, this will be O(1)
- while True:
- var = f"{name}_{next(self.unique_var_id)}"
- if var not in existing:
- self.code_options["co_varnames"] += (var,)
- return var
- def update_co_names(self, name):
- """Ensure self.code_options.co_names contains name"""
- if name not in self.code_options["co_names"]:
- self.code_options["co_names"] += (name,)
- @staticmethod
- def module_key_name(*names):
- # create a new unique name
- name = "_".join(map(str, names))
- # Strip the guard lookup L/G access
- name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
- # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
- name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
- # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
- name = re.sub(r"[^a-zA-Z0-9]", "_", name)
- if not name or not name[0].isalpha():
- name = "sub" + name
- return name
- def register_attr_or_module(
- self,
- target: Union[torch.nn.Module, torch.Tensor, Any],
- *names,
- **options,
- ):
- if is_dynamic_nn_module(target, self.root_tx.export):
- return variables.UnspecializedNNModuleVariable(target, **options)
- options = dict(options)
- assert "source" in options
- source = options["source"]
- assert not isinstance(source, ParamBufferSource)
- if isinstance(target, torch.Tensor):
- tracer = self.current_tracer
- if not self.is_root_tracer():
- # For higher order ops, we don't want to insert the get_attr in the
- # innermost graph. Instead, we want to raise the params/buffers
- # as inputs to the higher-order graph, and register them as
- # get_attrs in the root tracer.
- # Note that Dynamo will still call lift_tracked_freevar_to_input
- # when these inputs are encountered for the inner graph. The
- # only difference is what happens at the root tracer for
- # nn.Parameters vs free inputs. The free inputs are registered
- # as placeholders in the root graph, whereas the nn.Parameters
- # are registered as get_attr nodes in the root graph.
- tracer = self.root_tracer
- def wrap_name(module_key):
- assert self.param_name_to_source is not None
- self.param_name_to_source[module_key] = source
- # Check if the attr has already been registered. This can happen
- # when two different sources point to the same tensor.
- if target in self.root_tx.output.side_effects:
- return self.root_tx.output.side_effects[target]
- if get_static_address_type(target) == "guarded":
- install_guard(source.make_guard(GuardBuilder.ID_MATCH))
- elif not is_constant_source(source):
- install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
- vt = wrap_fx_proxy(
- self.root_tx,
- tracer.create_proxy("get_attr", module_key, tuple(), {}),
- example_value=target,
- **options,
- )
- # Track the object to avoid duplicate registration in case of
- # different sources pointing to the same tensor object.
- vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
- return vt
- elif isinstance(target, torch.nn.Module):
- assert isinstance(target, torch.nn.Module)
- if source:
- install_guard(source.make_guard(GuardBuilder.NN_MODULE))
- def wrap_name(module_key):
- return NNModuleVariable(type(target), module_key, target, **options)
- else:
- # This is a Dynamo-created graph module, e.g., a graph module coming
- # from higher order ops. An NNModuleVariable tracker can't be
- # sourceless, so let's return an UnspecializedNNModuleVariable
- # tracker.
- def wrap_name(module_key):
- return variables.UnspecializedNNModuleVariable(target, **options)
- elif isinstance(target, (torch.SymInt, torch.SymFloat)):
- # HACKY CODE REGION BEGIN
- # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
- # This ultimately gets written to self.nn_modules, which is unfortunate
- # Attrs that are tensors and symints and such need to be migrated to have their
- # own storage
- # alas, this is like this for now
- def wrap_name(module_key):
- return SymNodeVariable.create(
- self,
- self.create_proxy("get_attr", module_key, tuple(), {}),
- sym_num=target,
- **options,
- )
- # HACKY CODE REGION END
- else:
- def wrap_name(module_key):
- self.output.update_co_names(module_key)
- self.global_scope[module_key] = target
- return VariableBuilder(self, ConstantSource(source_name=module_key))(
- target
- )
- for k, v in self.nn_modules.items():
- if v is target:
- # it already exists
- return wrap_name(k)
- name = OutputGraph.module_key_name(*names)
- base = name
- for i in itertools.count():
- if name not in self.nn_modules:
- self.nn_modules[name] = target
- if isinstance(target, torch.nn.Module):
- def register_leaf_name(leaf_name):
- assert self.param_name_to_source is not None
- new_source = ParamBufferSource(source, leaf_name)
- new_name = f"{name}.{leaf_name}"
- self.param_name_to_source[new_name] = new_source
- if isinstance(source, LocalSource):
- self.dynamo_flat_name_to_original_fqn[
- OutputGraph.module_key_name(new_source.name())
- ] = leaf_name
- # annoying, but there are cases when we do not have parameters
- # see test_nn_moduledict_contains
- if hasattr(target, "_parameters"):
- for leaf_name, _ in target.named_parameters():
- register_leaf_name(leaf_name)
- if hasattr(target, "_buffers"):
- for leaf_name, _ in target.named_buffers():
- register_leaf_name(leaf_name)
- return wrap_name(name)
- name = f"{base}_{i}"
- raise AssertionError("unreachable")
- def handle_aliases_for_stolen_lists(self, tx):
- # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
- maybe_gm = self.local_scope.get("self")
- stolen_list_names = get_locals_to_steal(maybe_gm)
- if not stolen_list_names:
- return []
- alias_insts = []
- needs_alias: Dict[
- str, List[Union[VariableTracker, AttributeMutationExisting]]
- ] = {}
- queue = [
- *tx.stack,
- *tx.symbolic_locals.values(),
- *self.side_effects.store_attr_mutations.keys(),
- ]
- while queue:
- x = queue.pop()
- if isinstance(x, BaseListVariable):
- assert isinstance(x.items, List)
- queue += x.items
- continue
- if not (
- isinstance(x, (VariableTracker, AttributeMutationExisting))
- and isinstance(x.source, GetItemSource)
- and isinstance(x.source.base, LocalSource)
- and x.source.base.local_name in stolen_list_names
- ):
- continue
- stolen_name = x.source.base.local_name
- if stolen_name not in needs_alias:
- needs_alias[stolen_name] = []
- needs_alias[stolen_name].append(x)
- visited = {}
- for arg in self.graphargs:
- if not (
- isinstance(arg._example, list)
- and isinstance(arg.source, LocalSource)
- and arg.source.local_name in needs_alias
- ):
- continue
- # arg is a list that will be cleared by the compiled function
- list_name = arg.source.local_name
- assert list_name in self.code_options["co_varnames"]
- for x in needs_alias[list_name]:
- list_idx = x.source.index
- if list_idx not in visited:
- alias_name = self.new_var(
- f"{list_name}_ref"
- ) # self.new_var already adds unique id suffix
- visited[list_idx] = alias_name
- # bytecode of `alias_name = list_name[list_idx]`
- alias_insts.extend(
- [
- create_instruction("LOAD_FAST", argval=list_name),
- create_instruction("LOAD_CONST", argval=list_idx),
- create_instruction("BINARY_SUBSCR"),
- create_instruction("STORE_FAST", argval=alias_name),
- ]
- )
- # operate on alias, handled by suffix codegen
- x.source = LocalSource(visited[list_idx])
- return alias_insts
- def compile_subgraph(
- self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
- ):
- """
- Generate a subgraph to continue execution on user code.
- Automatically restore live variables.
- """
- assert reason is not None
- from .decorators import disable
- self.partial_convert = partial_convert
- self.compile_subgraph_reason = reason
- self.should_exit = True
- log.debug("COMPILING GRAPH due to %s", reason)
- if not all(block.can_restore() for block in tx.block_stack):
- unimplemented("compile_subgraph with block_depth != 0")
- prefix_insts: List[Instruction] = []
- if sys.version_info >= (3, 11):
- # prefix instructions (Python 3.11+)
- for inst in tx.prefix_insts:
- if inst.opname == "MAKE_CELL":
- prefix_insts.append(
- create_instruction("MAKE_CELL", argval=inst.argval)
- )
- elif inst.opname == "COPY_FREE_VARS":
- prefix_insts.append(
- create_instruction(
- "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
- )
- )
- else:
- prefix_insts.append(copy.copy(inst))
- assert not (
- self.pregraph_bytecode and self.export
- ), "export does not support pregraph_bytecode"
- prefix_insts.extend(self.pregraph_bytecode)
- prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx))
- def append_prefix_insts():
- self.add_output_instructions(prefix_insts)
- prefix_insts.clear()
- for block in reversed(tx.block_stack):
- block.exit(tx)
- self.cleanup_graph()
- tx.prune_dead_locals()
- stack_values = list(tx.stack)
- # realize any unrealized tensor VTs in case they
- # need to be added to self.nn_modules as attributes
- for value in stack_values:
- value.realize()
- # Use nn.Module "proxies" in the constructed GraphModule so that
- # the resulting GM does not hold additional strong references to the original modules.
- # This prevents a strong ref cycle where Dynamo created code holds on to references
- # to modules that also have Dynamo code cache invalidation checks.
- # When cache invalidation runs, the generated GM will be invalidated, which also deletes
- # the proxies.
- nn_modules_proxies = {
- name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
- }
- root = FakeRootModule(nn_modules_proxies)
- # Add all the local vars to the "stack" so we restore them at the end
- restore_vars = []
- val_to_names: Dict[VariableTracker, List[str]] = {}
- if stack_values:
- val_to_names[stack_values[-1]] = list()
- # NB: Typically (i.e., for graph compile from RETURN_VALUE),
- # symbolic_locals will be empty at this point, as prune_dead_locals
- # will clear out all of symbolic_locals because RETURN_VALUE is the
- # last instruction and no more locals are used. The fanciness here
- # is only needed for partial graphs.
- for k, v in tx.symbolic_locals.items():
- # Note! This explicitly uses .local_name for matching.
- # Failure to do so will cause spurious registrations in val_to_names.
- # This will in turn result in spurious variables showing up in the graph.
- # This was very tricky to debug. For an example, dump the graph at call_user_compiler
- # while running test_subgraphs.py
- if isinstance(v.source, LocalSource) and v.source.local_name == k:
- continue # no need to restore initial state
- # Do not load variable if it is NULL.
- if sys.version_info >= (3, 12):
- # Continuation function will load the NULL for v.
- if type.__instancecheck__(NullVariable, v):
- continue
- else:
- # A variable should never be NULL in < 3.12
- assert not type.__instancecheck__(NullVariable, v)
- if v not in val_to_names:
- val_to_names[v] = list()
- val_to_names[v].append(k)
- for v in val_to_names.keys():
- restore_vars.extend(val_to_names[v])
- stack_values.extend([v] * len(val_to_names[v]))
- # to handle random calls
- if len(self.random_calls) > 0:
- append_prefix_insts()
- random_calls_instructions = []
- self.random_values_var = self.new_var("random_values")
- rand_fn = disable(_get_gen_rand_values_fn(self.random_calls))
- rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
- codegen = PyCodegen(tx, root)
- random_calls_instructions.extend(
- codegen.load_function_name(rand_fn_name, True)
- )
- random_calls_instructions.extend(create_call_function(0, False))
- random_calls_instructions.append(
- codegen.create_store(tx.output.random_values_var),
- )
- self.add_output_instructions(random_calls_instructions)
- if (
- stack_values
- and all(
- not isinstance(
- v,
- (
- UnspecializedPythonVariable,
- NumpyNdarrayVariable,
- TensorWithTFOverrideVariable,
- ),
- )
- and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
- for v in stack_values
- )
- and all(isinstance(x, TensorVariable) for x in stack_values)
- and len(set(stack_values)) == len(stack_values)
- and self.side_effects.is_empty()
- and not tx.debug_locals
- and not self.backward_state
- ):
- append_prefix_insts()
- # optimization to generate better code in a common case
- self.add_output_instructions(
- self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
- + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
- )
- # restore all the live local vars
- self.add_output_instructions(
- [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
- )
- else:
- graph_output_var = self.new_var("graph_out")
- pass1 = PyCodegen(tx, root, graph_output_var)
- self.codegen_suffix(tx, stack_values, pass1)
- # one more time now that we have established tempvars
- pass2 = PyCodegen(
- tx,
- root,
- graph_output_var,
- tempvars={val: None for val, count in pass1.uses.items() if count > 1},
- )
- self.codegen_suffix(tx, stack_values, pass2)
- stored_graph_output_var = False
- output = []
- if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
- output.extend(
- self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
- )
- if len(pass2.graph_outputs) != 0:
- output.append(pass2.create_store(graph_output_var))
- stored_graph_output_var = True
- else:
- output.append(create_instruction("POP_TOP"))
- append_prefix_insts()
- self.add_output_instructions(output + pass2.get_instructions())
- # restore all the live local vars
- self.add_output_instructions(
- [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
- )
- if stored_graph_output_var:
- self.add_output_instructions(
- [PyCodegen(tx).create_delete(graph_output_var)]
- )
- def codegen_suffix(self, tx, stack_values, cg):
- if self.backward_state:
- assert not self.export
- for name, val in self.backward_state.items():
- cg(val)
- cg.append_output(cg.create_load(self.backward_state_var))
- cg.store_attr(name)
- self.side_effects.codegen_hooks(cg)
- self.side_effects.codegen_save_tempvars(cg)
- # Return variables used for logging at the end
- for debug_var, args in tx.debug_locals:
- cg(debug_var)
- for arg in args:
- cg(arg)
- cg.extend_output(create_call_function(len(args), True))
- cg.extend_output([create_instruction("POP_TOP")])
- cg.restore_stack(stack_values, value_from_source=not tx.export)
- self.side_effects.codegen_update_mutated(cg)
- def cleanup_graph(self):
- """
- Remove "creation_timestamp" from node meta
- Remove this pattern from the graph:
- torch._C._set_grad_enabled(False)
- torch._C._set_grad_enabled(True)
- """
- assert self.should_exit
- nodes = list(self.graph.nodes)
- for node in nodes:
- node.meta.pop("creation_timestamp", None)
- grad_enabled = torch.is_grad_enabled()
- for node1, node2 in zip(nodes, nodes[1:]):
- if (
- node1.target is torch._C._set_grad_enabled
- and tuple(node1.args) == (not grad_enabled,)
- and not node1._erased
- ):
- grad_enabled = node1.args[0]
- if (
- node2.target is torch._C._set_grad_enabled
- and tuple(node2.args) == (not grad_enabled,)
- and not node2._erased
- ):
- grad_enabled = node2.args[0]
- self.graph.erase_node(node1)
- self.graph.erase_node(node2)
- def get_graph_sizes_structured(self):
- ret = {}
- for node in self.graph.nodes:
- example_value = node.meta.get("example_value", None)
- if isinstance(example_value, torch._subclasses.FakeTensor):
- size = example_value.size()
- ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
- return ret
- def get_graph_sizes(self, name: str):
- graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
- graph_sizes_str += f"===== {name} =====\n"
- for node in self.graph.nodes:
- example_value = node.meta.get("example_value", None)
- if isinstance(example_value, torch._subclasses.FakeTensor):
- size = example_value.size()
- graph_sizes_str += f"{node.name}: {tuple(size)}\n"
- concrete_size = []
- has_symint = False
- for sz in size:
- if isinstance(sz, int):
- concrete_size.append(sz)
- elif isinstance(sz, torch.SymInt):
- has_symint = True
- concrete_size.append(sz.node.hint)
- else:
- break
- else:
- if has_symint:
- graph_sizes_str += (
- f"{node.name} (concrete): {tuple(concrete_size)}\n"
- )
- return graph_sizes_str
- @contextlib.contextmanager
- def restore_global_state(self):
- """
- Momentarily restores the global state to what it was prior to tracing the current output
- """
- prior_global_state = self.tracing_context.global_context.copy_graphstate()
- current_global_state: Dict[str, Tuple[Any, bool]] = {}
- self.save_global_state(out=current_global_state)
- try:
- # Set to state prior to tracing the graph
- self.tracing_context.global_context.restore_graphstate(prior_global_state)
- yield
- finally:
- # Reset to state at the current time (e.g. before calling the user compiler)
- self.tracing_context.global_context.restore_graphstate(
- GlobalContextCheckpointState(current_global_state)
- )
- @torch._guards.TracingContext.clear_frame()
- def compile_and_call_fx_graph(self, tx, rv, root):
- """
- Generate code from self.graph and return the Instruction()s to
- call that generated code.
- """
- from .decorators import disable
- assert self.should_exit
- name = unique_id("__compiled_fn")
- assert isinstance(rv, list)
- assert isinstance(root, FakeRootModule)
- self.create_node(
- "output",
- "output",
- (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
- {},
- )
- if not config.do_not_emit_runtime_asserts:
- insert_deferred_runtime_asserts(
- fx.GraphModule(root, self.graph),
- self.shape_env,
- name,
- )
- # NB: deferred runtime asserts can keep graphargs live, so make sure
- # those are inserted before pruning
- self.remove_unused_graphargs()
- ncalls = count_calls(self.graph)
- counters["stats"]["calls_captured"] += ncalls
- # free a bit of memory
- self.real_value_cache.clear()
- gm = _make_graph_module(root, self.graph)
- for register_finalizer in self.register_finalizer_fns:
- register_finalizer(gm)
- gm.compile_subgraph_reason = self.compile_subgraph_reason
- gm.meta[
- "dynamo_flat_name_to_original_fqn"
- ] = self.dynamo_flat_name_to_original_fqn.copy()
- graph_code_log.debug(
- "%s",
- lazy_format_graph_code(name, gm, include_stride=True, include_device=True),
- )
- torch._logging.trace_structured(
- "dynamo_output_graph",
- lambda: {"sizes": self.get_graph_sizes_structured()},
- payload_fn=lambda: gm.print_readable(
- print_output=False, include_stride=True, include_device=True
- ),
- )
- self.call_cleanup_hooks()
- old_fake_mode = self.tracing_context.fake_mode
- if not self.export:
- import torch._functorch.config as _config
- with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
- # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
- backend_fake_mode = torch._subclasses.FakeTensorMode(
- shape_env=old_fake_mode.shape_env,
- )
- # TODO(voz): Ostensibly, this should be scoped and
- # restored back to old_fake_mode, but doing so currently violates
- # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
- self.tracing_context.fake_mode = backend_fake_mode
- with self.restore_global_state():
- compiled_fn = self.call_user_compiler(gm)
- from torch.fx._lazy_graph_module import _LazyGraphModule
- if isinstance(compiled_fn, _LazyGraphModule) or (
- isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
- and compiled_fn.__name__ == "_lazy_forward"
- ):
- # Since dynamo will run the forward method for the GraphModule shortly
- # anyway, it does not hurt to do the real recompilation here if
- # this is a _LazyGraphModule. This makes it easier for dynamo to
- # optimize a _LazyGraphModule.
- lazy_gm = (
- compiled_fn
- if isinstance(compiled_fn, _LazyGraphModule)
- else compiled_fn.__self__
- )
- _LazyGraphModule.force_recompile(lazy_gm)
- if not isinstance(compiled_fn, _LazyGraphModule):
- # replace compiled_fn with the real forward method
- compiled_fn = lazy_gm.forward
- compiled_fn = disable(compiled_fn)
- counters["stats"]["unique_graphs"] += 1
- # This is safe because we pre-process name to be unique
- self.install_global_unsafe(name, compiled_fn)
- cg = PyCodegen(tx)
- cg.make_call_generated_code(name)
- return cg.get_instructions()
- @property
- def placeholders(self) -> List[fx.Node]:
- return self.graph.find_nodes(op="placeholder")
- @property
- def graphargs(self) -> List[GraphArg]:
- return [node.meta["grapharg"] for node in self.placeholders]
- @dynamo_timed(phase_name="backend_compile")
- def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
- assert self.compiler_fn is not None
- tot = 0
- placeholders = []
- for node in gm.graph.nodes:
- if node.op in ("call_function", "call_method", "call_module"):
- tot += 1
- if node.op == "placeholder":
- placeholders.append(node)
- increment_op_count(tot)
- for pl in placeholders:
- arg = pl.meta["grapharg"]
- # TODO: Why isn't this stored in meta :think:
- pl._dynamo_source = arg.source
- gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
- gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
- try:
- name = (
- self.compiler_fn.__name__
- if hasattr(self.compiler_fn, "__name__")
- else ""
- )
- _step_logger()(logging.INFO, f"calling compiler function {name}")
- compiler_fn = self.compiler_fn
- if config.verify_correctness:
- compiler_fn = WrapperBackend(compiler_fn)
- compiled_fn = compiler_fn(gm, self.example_inputs())
- _step_logger()(logging.INFO, f"done compiler function {name}")
- assert callable(compiled_fn), "compiler_fn did not return callable"
- except exceptions_allowed_to_be_fallback as e:
- if self.has_user_defined_allowed_in_graph:
- raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
- e.__traceback__
- ) from None
- msg = (
- "Backend compiler failed with a fake tensor exception at \n"
- f"{self.root_tx.format_frame_summary()}"
- "Adding a graph break."
- )
- unimplemented_with_warning(e, self.root_tx.f_code, msg)
- except SkipFrame as e:
- # The backend compiler has requested that we skip the frame, instead of
- # aborting execution.
- raise e
- except Exception as e:
- raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
- e.__traceback__
- ) from None
- signpost_event(
- "dynamo",
- "OutputGraph.call_user_compiler",
- {
- **self.co_fields,
- "op_count": tot,
- "node_count": len(gm.graph.nodes),
- "input_count": len(placeholders),
- },
- )
- return compiled_fn
- def example_inputs(self) -> List[torch.Tensor]:
- result = []
- for arg in self.graphargs:
- result.append(arg.example)
- return result
- def remove_unused_graphargs(self) -> None:
- # NB: It's always OK to drop GraphArg for symbols that ended up being
- # specialized. You don't even have to make a guard for it, because
- # ShapeEnv produce_guards operates on tracked_fakes, which never gets
- # pruned. That being said, you'll get marginally better generated
- # guard code if you promote the guard into a Dynamo guard (since that
- # allows for the guard to be done using C++ guards.) If we get
- # ShapeEnv guards to go into C++ guards, this will stop being a thing
- # though!
- assert self.should_exit
- # Miniature DCE pass, but only for obviously trivial operations
- def is_static_true(b_node: fx.node.Argument):
- if b_node is True:
- return True
- if not isinstance(b_node, fx.Node):
- return False
- b = b_node.meta.get("example_value")
- if b is None:
- return False
- if b is True:
- return True
- if (
- isinstance(b, torch.SymBool)
- and (r := b.node.maybe_as_bool()) is not None
- ):
- return r
- # TODO: We can also technically remove all cases when the input
- # doesn't have unbacked inputs, since it's all in the ShapeEnv
- return False
- def is_symnode_arg(a: fx.node.Argument):
- from torch.fx.experimental.sym_node import SymTypes
- if isinstance(a, (int, float, bool)):
- return True
- if isinstance(a, fx.Node):
- return isinstance(a.meta.get("example_value"), SymTypes)
- return False
- # NB: We assume that you cannot do mutations on int/float/bool,
- # because they are immutable types, and therefore it is always safe to
- # DCE.
- def is_symnode_compute_node(node):
- from torch.fx.experimental.sym_node import SymTypes
- if node.op != "call_function":
- return False
- # TODO: I don't think it's possible to have a bare int/float here?
- if not isinstance(node.meta.get("example_value"), SymTypes):
- return False
- # TODO: This will bail here if you ever end up with a more complicated
- # computation function, like sum(list_of_ints), even though it
- # should be DCE'able
- if not all(is_symnode_arg(a) for a in node.args):
- return False
- if not all(is_symnode_arg(a) for a in node.kwargs.values()):
- return False
- return True
- # NB: You could try to expand this to cover more cases by simply
- # detecting whenever you have an int output, but this is a bit
- # dangerous in case someone adds a function that returns an int but is
- # mutating. So manually whitelist for now.
- def is_accessor_node(node):
- if (
- node.op == "call_method"
- and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
- and node.target in ["size", "stride", "storage_offset", "item"]
- ):
- return True
- if node.op == "call_function" and node.target in [
- torch.ops.aten.sym_size,
- torch.ops.aten.sym_size.default,
- torch.ops.aten.sym_size.int,
- torch.ops.aten.sym_stride,
- torch.ops.aten.sym_stride.default,
- torch.ops.aten.sym_stride.int,
- torch.ops.aten.sym_storage_offset,
- torch.ops.aten.sym_storage_offset.default,
- ]:
- return True
- return False
- for node in reversed(list(self.graph.nodes)):
- if len(list(node.users)) == 0:
- if (
- node.op == "get_attr"
- or (node.op == "call_function" and node.target is operator.getitem)
- or (
- node.op == "call_function"
- and node.target is torch._check
- and is_static_true(node.args[0])
- )
- or is_symnode_compute_node(node)
- or is_accessor_node(node)
- ):
- self.remove_node(node)
        def placeholder_binds_symbol(node):
            arg = node.meta["grapharg"]
            example = arg.example
            if isinstance(example, torch.SymInt) and isinstance(
                example.node.expr, sympy.Symbol
            ):
                return example.node.expr
            return None

        def remove_unused(node):
            log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
            # I'm not really sure why you need to delete these from the
            # node since the node is going to get removed
            del node.meta["grapharg"]
            self.remove_node(node)
            self.real_value_cache.pop(node, None)

        used_symbols: Set[sympy.Symbol] = set()

        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
            used_symbols |= free_symbols(fake)
        recheck_placeholders = []
        for node in self.placeholders:
            binds_symbol = placeholder_binds_symbol(node) is not None
            # Don't delete symbol bindings yet
            if binds_symbol:
                if not node.users:
                    recheck_placeholders.append(node)
            else:
                if not node.users and not isinstance(
                    node.meta["grapharg"], BackwardStateGraphArg
                ):
                    remove_unused(node)
                else:
                    # Register the free symbols as uses
                    arg = node.meta["grapharg"]
                    if isinstance(arg, BackwardStateGraphArg):
                        continue
                    if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                        real_script_obj = node.meta["grapharg"].example
                        fake_script_obj = node.meta["grapharg"].example_strong_ref
                        flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
                        for attr in flat_dict.keys():
                            fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
                            pytree.tree_map_only(
                                (torch.SymInt, torch.Tensor),
                                lambda t: update_used_symbols(used_symbols, t),
                                fake_attr_val,
                            )
                        continue
                    fake = (
                        arg.fake_tensor if arg.fake_tensor is not None else arg.example
                    )
                    update_used_symbols(used_symbols, fake)

        # After removing unused graphargs, prune unused binds_symbol
        for node in recheck_placeholders:
            symbol = placeholder_binds_symbol(node)
            if symbol is not None:
                if symbol not in used_symbols:
                    remove_unused(node)
                else:
                    # Make sure we delete later occurrences of the same symbol
                    used_symbols.remove(symbol)
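        # For example, if two user-less placeholders both bind the symbol s0 and
        # some surviving tensor arg still mentions s0, the first binding is kept
        # and the duplicate is pruned; if nothing mentions s0, both are pruned.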
    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global_unsafe(self, name, value) -> None:
        """
        WARNING: prefer the safer `install_global_by_id/install_global`.
        torch.compile instances should be independent of each other;
        one footgun is to have one instance depend on the existence of
        a global installed by another instance. This can happen if we mangle
        a global the same way across both instances.
        """
        assert name not in self.installed_globals
        self.installed_globals.add(name)
        self.cleanups.append(CleanupHook.create(self.global_scope, name, value))
    def install_global_by_id(self, prefix, value) -> str:
        """
        Installs a global if it hasn't been installed already.
        This is determined by the (prefix, id(value)) pair.
        Returns the name of the newly installed global.
        """
        # NB: need self.compile_id to distinguish this global
        # from another global created in a different torch.compile instance
        name = f"{prefix}_{id(value)}_c{self.compile_id}"
        if name in self.installed_globals:
            return name
        self.install_global_unsafe(name, value)
        return name

    def install_global(self, prefix, value) -> str:
        """
        Installs a global, generating a unique name for it.
        Returns the name of the newly installed global.
        """
        # NB: unique_id is unique, even across torch.compile instances
        name = unique_id(prefix)
        self.install_global_unsafe(name, value)
        return name
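    # Illustrative sketch of the difference (object names and ids below are
    # made up; `og` stands for an OutputGraph instance):
    #
    #   >>> name1 = og.install_global_by_id("cls", some_obj)
    #   >>> name1
    #   'cls_140012345678_c0'
    #   >>> og.install_global_by_id("cls", some_obj) == name1  # reused, not reinstalled
    #   True
    #   >>> og.install_global("cls", some_obj)                 # always a fresh name
    #   'cls_7'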
    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.
        self.root_tx = None
        self.nn_modules.clear()
        self.param_name_to_source = None

        for node in self.graph.nodes:
            if "grapharg" in node.meta:
                del node.meta["grapharg"]
        self.real_value_cache.clear()
        self.input_name_to_proxy.clear()
        self.side_effects.clear()
        self.variable_tracker_cache.clear()
        self.register_finalizer_fns.clear()
        self.dynamo_flat_name_to_original_fqn.clear()
        self.tracing_context.clear()

    def set_torch_function_state(self, enabled: bool) -> None:
        self.torch_function_enabled = enabled

    def add_graph_finalizer(
        self, register_finalizer: Callable[[fx.GraphModule], None]
    ) -> None:
        self.register_finalizer_fns.append(register_finalizer)

    def example_value_from_input_node(self, node: torch.fx.Node):
        """Extract the non-fake example tensor"""
        if node.op == "placeholder":
            return node.meta["grapharg"].example
        assert node.op == "get_attr"
        return self.nn_modules[node.target]  # type: ignore[index]
err_epilogue = (
    "With the current config, we will graph break "
    "(and fall back to eager-mode PyTorch) on all ops "
    "that do not have the 'pt2_compliant_tag'. "
    "Please see the following doc for how to mark this op as PT2 compliant "
    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
)
def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
    if kind != "call_function":
        return

    def encountered_compliant_op(target):
        if target.namespace in {"prim", "prims", "aten"}:
            return
        output_graph.compliant_custom_ops.add(target)

    def encountered_non_compliant_op(target, msg):
        output_graph.non_compliant_ops.add(target)
        if config.only_allow_pt2_compliant_ops:
            unimplemented(msg + " " + err_epilogue)

    if isinstance(target, torch._ops.OpOverload):
        if torch.Tag.pt2_compliant_tag in target.tags:
            encountered_compliant_op(target)
            return
        encountered_non_compliant_op(
            target,
            f"Encountered the torch.ops.OpOverload {target} "
            f"that is not PT2 compliant.",
        )
        return

    if isinstance(target, torch._ops.OpOverloadPacket):
        overloads = tuple(target.overloads())
        # Optimization: Overload resolution is expensive.
        # If there's only one overload, we know what it will resolve to.
        if len(overloads) == 1:
            op = getattr(target, overloads[0])
            if torch.Tag.pt2_compliant_tag in op.tags:
                encountered_compliant_op(op)
                return
            encountered_non_compliant_op(
                op,
                f"Encountered the non-overloaded "
                f"torch.ops.OpOverloadPacket {target} "
                f"that is not PT2 compliant. ",
            )
            return

        args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
            output_graph.current_tx, (args, kwargs), False
        )
        try:
            overload = torch._C._jit_resolve_packet(
                target._qualified_op_name, *args, **kwargs
            )
        except RuntimeError as e:
            unimplemented(str(e))

        op = getattr(target, overload)
        if torch.Tag.pt2_compliant_tag in op.tags:
            encountered_compliant_op(op)
        else:
            encountered_non_compliant_op(
                op,
                f"Encountered the torch.ops.OpOverloadPacket {target} "
                f"which resolves to the overload ({overload}) that is "
                f"not PT2 compliant.",
            )
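# Illustrative sketch of how an op takes the compliant path above: custom ops
# opt in by carrying torch.Tag.pt2_compliant_tag. One way to attach the tag
# when defining an op (the exact torch.library API surface varies across
# PyTorch versions, so treat this as a sketch):
#
#   import torch
#
#   torch.library.define(
#       "mylib::foo",
#       "(Tensor x) -> Tensor",
#       tags=(torch.Tag.pt2_compliant_tag,),
#   )
#
#   @torch.library.impl("mylib::foo", "CompositeExplicitAutograd")
#   def foo_impl(x):
#       return x.clone()
#
# With the tag present, check_pt2_compliant_op records torch.ops.mylib.foo as
# compliant instead of graph breaking when
# torch._dynamo.config.only_allow_pt2_compliant_ops is enabled.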
_compile_id_counter = itertools.count()


class SubgraphTracer(fx.Tracer):
    """
    Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
    and the separation of responsibilities is that SubgraphTracer is
    responsible for building the graph while OutputGraph is responsible for
    compiling and executing the graph.
    """
    def __init__(
        self, output_graph, parent=None, export_root=False, source_target=None
    ):
        super().__init__()
        self.output_graph = weakref.proxy(output_graph)
        self.graph = torch.fx.Graph()

        # The export flag is only ever set for the ROOT tracer. It controls
        # whether or not certain inputs are allowed to be added.
        # Look at call sites of create_graph_input to see how it is used.
        if export_root:
            assert parent is None
        self.export_root = export_root
        # Map from graph input name to its placeholder proxy object, where the
        # map's keys give all current placeholder node names and can be used to
        # create unique node names
        self.input_name_to_proxy: Dict[str, fx.Proxy] = {}
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
        self.parent = parent
        # A dict mapping previously free variables (Proxy objects)
        # to new Proxy objects that wrap inputs to this subgraph.
        #
        # This dict serves two purposes:
        # - Proxies are associated with VariableTrackers. If we see
        #   the same VariableTracker twice (and it is a free variable),
        #   then we want to use the same Proxy in the current subgraph to
        #   record the tracing.
        # - If we are tracing a HigherOrderOperator's body_fn, then we
        #   need to keep track of what free variables were lifted so we can
        #   rewrite the HigherOrderOperator call using the traced body_fn.
        #   Dicts maintain the order of args for the HigherOrderOperator call.
        self.lifted_freevars = {}
        self.prev_inst = None

        self._cur_code = None
        self._orig_gm_meta = None
        self._orig_gm_lineno_map = None
        self._orig_gm_firstlineno = None
        # Each SubgraphTracer is associated with a source target, which indicates
        # which operator this subgraph is attached to. We compute a source_fn_stack
        # based on the source target. For the root tracer, it's set to [].
        # This is useful for debugging and transforming the exported graph.
        if self.parent is None:
            self.source_fn_stack = []
        else:
            self.source_fn_stack = self.parent.source_fn_stack + [
                (self.graph._target_to_str(source_target), source_target)
            ]
    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        # NOTE: [Nested SubgraphTracer and free_variable handling]
        # --------------------------------------------------------
        # Read NOTE [HigherOrderOperator tracing design] first.
        #
        # Let's say we're in the middle of introspecting the body of a possibly
        # nested HigherOrderOperator, and we see a free variable.
        #
        # There are two cases:
        # 1. We see a free variable that is already tracked by Dynamo.
        # 2. We see a free variable that has not been tracked by Dynamo.
        #
        # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
        # which will lift the freevar to be an input of this subgraph
        # and also recursively lift it to be an input on the parent(s).
        #
        # In case 2, before the call to `create_proxy`, the InstructionTranslator
        # will see the freevar when it gets loaded by Python bytecode.
        # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
        # LOAD_GLOBAL.
        # There, the InstructionTranslator asks Dynamo to begin tracking the
        # freevar by building a new Variable.
        # Building a new Variable automatically lifts the freevar to be an
        # input of the root SubgraphTracer.
        #
        # The implications for the code below are:
        # - We will always be in Case 1 when we get to this code.
        # - Any "free variable" we encounter here is guaranteed to already be
        #   bound, that is, it is either a graph input of the root graph, or
        #   some local variable of the root graph or a subgraph.
        # - The additional work we need to do here is *only* that we need to
        #   lift this free variable into inputs (recursively) of each nested
        #   higher-order-op subgraph until we hit the subgraph where the free
        #   variable is bound.
        if self.parent is not None:
            flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
            new_flat_args = []
            for arg in flat_args:
                maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
                new_flat_args.append(maybe_new_arg)

            args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        # append stack trace to fx node
        tx = self.output_graph.current_tx

        # log detailed location of line of code in 3.11
        if sys.version_info >= (3, 11) and kind in (
            "call_function",
            "call_method",
            "call_module",
        ):
            cur_inst = tx.current_instruction
            if (
                cur_inst is not self.prev_inst
                and cur_inst.positions is not None
                and cur_inst.positions.lineno is not None
            ):
                tx_code = tx.f_code
                header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)

                def get_trace_call_log_str():
                    line = get_instruction_source_311(tx_code, cur_inst).rstrip()
                    return f"TRACE FX call {rv.node.name} from {header}\n{line}"

                trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
                self.prev_inst = cur_inst

        # update reference to original meta if we're tracing a new code object
        is_retracing = False
        if tx.f_code is not self._cur_code:
            orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
                "orig_graphmodule", lambda: None
            )()
            if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
                is_retracing = True
                self._orig_gm_meta = [
                    nd.meta for nd in orig_graphmodule_maybe.graph.nodes
                ]
                self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
                self._orig_gm_firstlineno = (
                    orig_graphmodule_maybe.forward.__code__.co_firstlineno
                )
            else:
                self._orig_gm_meta = None
                self._orig_gm_lineno_map = None
                self._orig_gm_firstlineno = None

        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if kind in {"call_function", "call_method"}:
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (rv.node.name, target)
            ]
        elif kind == "call_module":
            if self.parent is not None:
                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
            # For modules we store the class
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (
                    rv.node.name,
                    rv.node.meta["nn_module_stack"][target][1],
                )
            ]
        # preserve original meta if it is available
        if (
            self._orig_gm_meta
            and self._orig_gm_lineno_map
            and self._orig_gm_firstlineno
        ):
            lineno = tx.current_instruction.starts_line
            node_idx = None
            if lineno is not None:
                node_idx = self._orig_gm_lineno_map.get(
                    lineno - self._orig_gm_firstlineno, None
                )
            if node_idx is not None:
                meta = self._orig_gm_meta[node_idx]
                for field in fx.proxy._COPY_META_FIELDS:
                    if field in meta:
                        rv.node.meta[field] = meta[field]
                if "stack_trace" in meta:
                    rv.node.meta["stack_trace"] = meta["stack_trace"]

        if not is_retracing:
            if "nn_module_stack" not in rv.node.meta:
                nn_module_stack = tx.nn_module_stack
                if nn_module_stack:
                    rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

            if "source_fn_stack" not in rv.node.meta:
                if kind in {"call_function", "call_method"}:
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (rv.node.name, target)
                    ]
                elif kind == "call_module":
                    if self.parent is not None:
                        unimplemented(
                            "Invoking an nn.Module inside HigherOrderOperator"
                        )
                    # For modules we store the class
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (
                            rv.node.name,
                            rv.node.meta["nn_module_stack"][target][1],
                        )
                    ]

        if "stack_trace" not in rv.node.meta:
            frame_summaries: List[traceback.FrameSummary] = []
            while tx:
                frame_summaries.append(tx.frame_summary())
                tx = getattr(tx, "parent", None)
            # Reverse the frame_summaries, so that the innermost frame is last
            frame_summaries.reverse()

            # official from_list stub doesn't have new-style type
            msgs = traceback.StackSummary.from_list(frame_summaries).format()
            rv.node.stack_trace = "".join(msgs)

        return rv
    def create_node(
        self, op, target, args=None, kwargs=None, name=None, type_expr=None
    ):
        check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
        if self.parent is not None:
            flat_args = pytree.arg_tree_leaves(*args, **kwargs)
            for arg in flat_args:
                if not isinstance(arg, torch.fx.Node):
                    continue
                assert (
                    arg.graph == self.graph
                ), "create_node using arg not from this SubgraphTracer"

        node = super().create_node(op, target, args, kwargs, name, type_expr)
        node.meta["creation_timestamp"] = self.output_graph.timestamp
        return node
    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        if len(node.users) > 0:
            user_graph_nodes: List[torch.fx.Node] = []
            for user in node.users.keys():
                # For the case where user.graph == self.graph, that is a real bug and will raise
                # properly.
                if user.graph != self.graph:
                    # This is a nested graph, which needs to be deleted.
                    # If we do not do this, we will raise on attempting to remove this.
                    # As we only get here during restoration cleanup, this is sound.
                    user_graph_nodes.extend(reversed(list(user.graph.nodes)))
            for other_graph_node in user_graph_nodes:
                other_graph_node.graph.erase_node(other_graph_node)
        self.graph.erase_node(node)
        self.input_name_to_proxy.pop(node.name, None)

    # when before=True, we will insert this input before the most recent
    # inserted proxy. This is a hack to get around an ordering problem,
    # where we first insert a tensor argument, and then insert bindings
    # for SymInts that may occur in the tensor argument.
    # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
    # fixed.
    def create_graph_input(self, name, type_expr=None, before=False, source=None):
        log.debug(
            "create_graph_input %s %s",
            name,
            source.name() if source is not None else "(none)",
        )
        if source is None:
            assert (
                self.parent is not None
            ), "you are required to provide a source for inputs on the root tracer"

        # In eager, we are generally OK with adding graph inputs whenever we
        # want, because we take care of writing the bytecode that knows how
        # to source all the inputs.
        #
        # In export, this is bad, because you want a self-contained export
        # object which only depends on the inputs you explicitly passed to it.
        # So we are a bit more strict about what sources can become inputs
        # in export.
        if self.export_root:
            if not is_from_local_source(source, allow_cell_or_freevar=False):
                self.output_graph.source_to_user_stacks.setdefault(source, []).append(
                    TracingContext.extract_stack()
                )

        # Make the input name unique among existing placeholders
        if name in self.input_name_to_proxy:
            for i in itertools.count():
                candidate_name = f"{name}_{i}"
                if candidate_name not in self.input_name_to_proxy:
                    name = candidate_name
                    break

        if self.input_name_to_proxy:
            prev_name = next(reversed(self.input_name_to_proxy))
            node = self.input_name_to_proxy[prev_name].node
            if before:
                ctx = self.graph.inserting_before(node)
            else:
                ctx = self.graph.inserting_after(node)
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            if self.input_name_to_proxy and before:
                k, v = self.input_name_to_proxy.popitem()
                self.input_name_to_proxy[name] = proxy
                self.input_name_to_proxy[k] = v
            else:
                self.input_name_to_proxy[name] = proxy
            return proxy
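    # Illustrative sketch: if a placeholder named "x" already exists, a second
    # create_graph_input(name="x") call yields "x_0" (the first free
    # "{name}_{i}"); with before=True the new placeholder is inserted ahead of
    # the most recently created one, which is how SymInt bindings end up before
    # the tensor input that refers to them.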
    # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
    def lift_tracked_freevar_to_input(self, proxy):
        # You're doing something wrong if we are the root SubgraphTracer because
        # Dynamo adds tensors to graph inputs before creating a proxy for them.
        assert (
            self.parent is not None
        ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
        # Proxies are associated with VariableTrackers.
        # It is possible that we've already lifted the Proxy to be an input.
        # If that is the case, just return the already lifted Proxy.
        if proxy in self.lifted_freevars:
            return self.lifted_freevars[proxy]
        new_proxy = self.create_graph_input(proxy.node.name)
        set_example_value(new_proxy.node, proxy.node.meta["example_value"])
        self.lifted_freevars[proxy] = new_proxy
        if self.parent is not None and proxy.tracer != self.parent:
            self.parent.lift_tracked_freevar_to_input(proxy)
        return new_proxy
    def maybe_lift_tracked_freevar_to_input(self, arg):
        """
        If arg is a free variable, then lift it to be an input.
        Returns the new lifted arg (if arg was a freevar), else the
        original arg.
        """
        if not isinstance(arg, torch.fx.Proxy):
            return arg
        elif arg.tracer == self:
            return arg
        return self.lift_tracked_freevar_to_input(arg)
# NOTE: [HigherOrderOperator tracing design]
# Ignoring HigherOrderOperators for a moment,
# OutputGraph represents the graph being built by Dynamo that may be compiled
# and executed. It holds a root SubgraphTracer where the FX graph is built.
#
# HigherOrderOperators are operators that take functions as their arguments.
# When Dynamo encounters a HigherOrderOperator, it attempts to introspect
# the function passed to it (call this the "body function"), capture it into a
# GraphModule, and rewrite the call to the HigherOrderOperator to use the
# GraphModule.
#
# The way we handle the capture of body functions is through having
# (possibly nested) SubgraphTracers, one per body function.
#
# Mechanically, we do the introspection by:
# - Creating a new SubgraphTracer via OutputGraph.subtracer
# - Executing the body function.
# This constructs the graph of the body function in the new SubgraphTracer
# while modifying the state of the OutputGraph. For example:
# - the OutputGraph can receive new GraphArgs (if we discover any new
#   untracked Tensors)
# - side effects from the body function get accumulated into
#   OutputGraph.side_effects
# - guards produced by the body function get accumulated into OutputGraph.guards
#
# The traced function has some special properties that make it easier for us
# to transform later down the line:
# - we lift all free variables to being inputs.
#
# If the introspection fails (due to the existence of graph breaks), then
# we roll back the current OutputGraph state and graph break on the
# HigherOrderOperator.
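# Illustrative sketch of the free-variable lifting described above (torch.cond
# is used as an example HigherOrderOperator; its exact import location and
# signature vary across PyTorch versions):
#
#   free = torch.randn(3)
#
#   @torch.compile(backend="eager")
#   def f(pred, x):
#       return torch.cond(pred, lambda x: x + free, lambda x: x - free, (x,))
#
# Each branch body is traced under its own nested SubgraphTracer. `free` is a
# free variable of both branches, so it is lifted (recursively, via
# lift_tracked_freevar_to_input) to an explicit input of each branch graph and
# of the parent graph, and the rewritten cond() call passes it explicitly.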