- # mypy: allow-untyped-defs
- # mypy: disable-error-code="method-assign"
- """
- Functions in this file are responsible for modifying the eval frame
- handler at RUNTIME. Therefore, all functions in this file are hot.
- Functions that only execute at compile time should be placed
- in torch._dynamo.convert_frame.
- """
- from __future__ import annotations
- import contextlib
- import functools
- import inspect
- import logging
- import os
- import sys
- import textwrap
- import traceback
- import types
- import warnings
- import weakref
- from enum import Enum
- from os.path import dirname, join
- from typing import (
- Any,
- Callable,
- Dict,
- List,
- NamedTuple,
- Optional,
- Set,
- Tuple,
- TYPE_CHECKING,
- Union,
- )
- from unittest.mock import patch
- import torch
- import torch.fx
- import torch.utils._pytree as pytree
- import torch.utils.checkpoint
- from torch import _guards
- from torch._utils_internal import log_export_usage
- from torch.export.dynamic_shapes import _process_dynamic_shapes
- from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
- from torch.fx.experimental.symbolic_shapes import (
- ConstraintViolationError,
- DimDynamic,
- StatelessSymbolicContext,
- )
- from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
- from ..fx import GraphModule
- from .backends.registry import CompilerFn, lookup_backend
- from .hooks import Hooks
- # see discussion at https://github.com/pytorch/pytorch/issues/120699
- reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
- set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
- set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401
- skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
- unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
- from . import config, convert_frame, external_utils, trace_rules, utils
- from .code_context import code_context
- from .exc import CondOpArgsMismatchError, UserError, UserErrorType
- from .mutation_guard import install_generation_tagging_init
- from .utils import common_constant_types, compile_times
- log = logging.getLogger(__name__)
- from torch._dispatch.python import enable_python_dispatcher
- always_optimize_code_objects = utils.ExactWeakKeyDictionary()
- null_context = contextlib.nullcontext
- import sympy
- if TYPE_CHECKING:
- from torch._subclasses import fake_tensor
- from .types import CacheEntry, DynamoCallback
- # See https://github.com/python/typing/pull/240
- class Unset(Enum):
- token = 0
- cached_backends: Dict[int, CompilerFn] = {}
- unset = Unset.token
- def _reset_guarded_backend_cache():
- global cached_backends
- for backend in cached_backends.values():
- if hasattr(backend, "reset"):
- backend.reset()
- cached_backends.clear()
- DONT_WRAP_FILES = {
- # For tracing into fx modules
- inspect.getsourcefile(GraphModule),
- join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
- }
- def _debug_get_cache_entry_list(
- code: Union[types.CodeType, Callable[..., Any]]
- ) -> List[CacheEntry]:
- """
- Given a code object or a callable object, retrieve the cache entries
- stored in this code.
- """
- if callable(code):
- code = code.__code__
- return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
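- # A minimal usage sketch (illustrative, not part of the API surface): after a
- # function has been compiled and run, its Dynamo cache entries hang off the
- # original code object and can be inspected like this:
- #
- #   def add(x, y):
- #       return x + y
- #   opt_add = torch.compile(add)
- #   opt_add(torch.randn(2), torch.randn(2))
- #   entries = _debug_get_cache_entry_list(add)
- #   print(len(entries))  # one entry per compiled variant of add's code object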
- class OptimizedModule(torch.nn.Module):
- """
- Wraps the original nn.Module object and later patches its
- forward method with the optimized self.forward method.
- """
- _torchdynamo_orig_callable: Callable[..., Any]
- get_compiler_config: Callable[[], Any]
- _opt_mod_attributes = {
- "_orig_mod",
- "dynamo_ctx",
- "_torchdynamo_orig_callable",
- "get_compiler_config",
- "forward",
- "_forward",
- "__dict__",
- "named_children_walk",
- }
- def __init__(self, mod: torch.nn.Module, dynamo_ctx):
- super().__init__()
- # Installs the params/buffers
- self._orig_mod = mod
- self.dynamo_ctx = dynamo_ctx
- self._initialize()
- def _initialize(self):
- # Do this stuff in constructor to lower overhead slightly
- if isinstance(self.dynamo_ctx, DisableContext):
- # No need to check trace rules
- self.forward = self.dynamo_ctx(self._orig_mod.__call__)
- elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
- self._orig_mod.forward
- ):
- # This may be a torch.nn.* instance listed in trace_rules.py which
- # won't trigger a frame evaluation on its own, so wrap it to add an
- # extra frame we can capture
- self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
- else:
- # Invoke hooks outside of dynamo, then pick up the inner frame
- self.forward = self.dynamo_ctx(self._orig_mod.__call__)
- if hasattr(self._orig_mod, "_initialize_hook"):
- self._forward = self.forward
- self.forward = self._call_lazy_check
- def __reduce__(self):
- return (self.__class__, (self._orig_mod, self.dynamo_ctx))
- def __getstate__(self):
- state = dict(self.__dict__)
- state.pop("forward", None)
- state.pop("__call__", None)
- return state
- def __setstate__(self, state):
- self.__dict__ = state
- self._initialize()
- def __getattr__(self, name):
- if name == "_orig_mod":
- return self._modules["_orig_mod"]
- return getattr(self._orig_mod, name)
- def __setattr__(self, name, val):
- # Allow patching over class attributes
- if hasattr(type(self), name):
- return super().__setattr__(name, val)
- if name in OptimizedModule._opt_mod_attributes:
- return super().__setattr__(name, val)
- return setattr(self._orig_mod, name, val)
- def _call_lazy_check(self, *args, **kwargs):
- if hasattr(self._orig_mod, "_initialize_hook"):
- # In the case of a lazy module, we want to run
- # the pre-hooks which initialize it.
- # Afterwards, the lazy module deletes its pre-hooks
- # so it is not treated as lazy on subsequent recompiles.
- self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
- return self._forward(*args, **kwargs)
- def __dir__(self):
- orig_mod_attrs = self._orig_mod.__dir__()
- return orig_mod_attrs + [
- attr for attr in super().__dir__() if attr not in orig_mod_attrs
- ]
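- # Sketch of the delegation behaviour implemented by __getattr__/__setattr__ above
- # (illustrative names):
- #
- #   lin = torch.nn.Linear(2, 2)
- #   opt_lin = torch.compile(lin)          # returns an OptimizedModule
- #   assert isinstance(opt_lin, OptimizedModule)
- #   assert opt_lin.weight is lin.weight   # attribute reads fall through to _orig_mod
- #   opt_lin.some_flag = True              # unknown attributes are set on _orig_mod
- #   assert lin.some_flag is True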
- def remove_from_cache(f):
- """
- Make sure f.__code__ is not cached to force a recompile
- """
- if isinstance(f, types.CodeType):
- reset_code(f)
- elif hasattr(f, "__code__"):
- reset_code(f.__code__)
- elif hasattr(getattr(f, "forward", None), "__code__"):
- reset_code(f.forward.__code__)
- else:
- from . import reset # type: ignore[attr-defined]
- reset()
- log.warning("could not determine __code__ for %s", f)
- def nothing():
- pass
- def always_false():
- return False
- def innermost_fn(fn):
- """
- In case of nesting of _TorchDynamoContext calls, find the innermost
- function. TorchDynamo caches on the fn.__code__ object, so it's necessary to find
- the innermost function to pass to optimize, run, disable, etc.
- """
- unaltered_fn = fn
- while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
- unaltered_fn = unaltered_fn._torchdynamo_orig_callable
- assert callable(unaltered_fn)
- return unaltered_fn
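- # For example (sketch): when decorators are nested, each wrapper records
- # _torchdynamo_orig_callable, so innermost_fn unwraps all of them:
- #
- #   def f(x):
- #       return x + 1
- #   wrapped = torch._dynamo.disable(torch.compile(f))
- #   assert innermost_fn(wrapped) is f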
- def make_set_enable_dynamic(enable: bool):
- assert isinstance(enable, bool)
- if enable:
- # Assume everything is dynamic by default
- return config._make_closure_patcher(assume_static_by_default=False)
- else:
- return config._make_closure_patcher(
- automatic_dynamic_shapes=False, assume_static_by_default=True
- )
- class _TorchDynamoContext:
- def __init__(
- self,
- callback: DynamoCallback,
- on_enter=nothing,
- backend_ctx_ctor=null_context,
- patch_fn=nothing,
- first_ctx=False,
- *,
- export=False,
- dynamic=None,
- compiler_config=None,
- ):
- super().__init__()
- assert callable(callback) or callback is False or callback is None
- self.callback: DynamoCallback = callback
- self._backend_ctx_ctor = backend_ctx_ctor
- self.prior: Union[Unset, DynamoCallback] = unset
- self.first_ctx = first_ctx
- self.export = export
- self._dynamic = dynamic
- self.compiler_config = compiler_config
- self.cleanup_fns: List[Callable[[], Any]] = []
- self.enter_exit_hooks = []
- patch_fn()
- # Save the backends so that we can reset them during torch._dynamo.reset
- backend = innermost_fn(callback)
- cached_backends.setdefault(id(backend), backend)
- if dynamic is not None:
- self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
- if on_enter is not nothing:
- # this case is not common
- def call_on_enter():
- on_enter()
- return nothing
- self.enter_exit_hooks.append(call_on_enter)
- if backend_ctx_ctor is not contextlib.nullcontext:
- # this case is not common
- def call_backend_ctx():
- ctx = backend_ctx_ctor()
- ctx.__enter__()
- return functools.partial(ctx.__exit__, None, None, None)
- self.enter_exit_hooks.append(call_backend_ctx)
- def __enter__(self):
- if config.raise_on_ctx_manager_usage:
- raise RuntimeError(
- "torch._dynamo.optimize(...) is used with a context manager. "
- "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
- "to use torch._dynamo.optimize(...) as an annotation/decorator. "
- )
- self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
- self.prior = set_eval_frame(self.callback)
- def __exit__(self, exc_type, exc_val, exc_tb):
- assert self.prior is not unset
- set_eval_frame(self.prior)
- self.prior = unset
- for cleanup in self.cleanup_fns:
- cleanup()
- self.cleanup_fns.clear()
- def __call__(self, fn):
- # public api for compiler config/options
- def get_compiler_config():
- return self.compiler_config
- fn = innermost_fn(fn)
- # add context containing GraphModule to any GraphModule forward functions
- if isinstance(fn, GraphModule):
- # add context containing GraphModule to any GraphModule forward functions
- code_context.get_context(fn.forward.__code__)[
- "orig_graphmodule"
- ] = weakref.ref(fn)
- # Optimize the forward method of torch.nn.Module object
- if isinstance(fn, torch.nn.Module):
- mod = fn
- new_mod = OptimizedModule(mod, self)
- # Save the function pointer to find the original callable when nesting
- # decorators.
- new_mod._torchdynamo_orig_callable = mod.forward
- # when compiling torch.nn.Module,
- # provide public api OptimizedModule.get_compiler_config()
- assert not hasattr(new_mod, "get_compiler_config")
- new_mod.get_compiler_config = get_compiler_config
- return new_mod
- if inspect.isclass(fn):
- # User has wrapped the class with compile/disable decorator. Apply
- # disable to init/call method.
- cls_obj = fn
- cls_obj.__call__ = self(cls_obj.__call__)
- if issubclass(cls_obj, torch.nn.Module):
- # NN module variable tracker directly inlines the _call_impl.
- cls_obj._call_impl = self(cls_obj._call_impl)
- return cls_obj
- assert callable(fn)
- try:
- filename = inspect.getsourcefile(fn)
- except TypeError:
- filename = None
- if (
- (filename is None or trace_rules.check(fn))
- and (
- getattr(fn, "__name__", "")
- not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
- )
- and filename not in DONT_WRAP_FILES
- ):
- # call to a builtin without a frame for us to capture
- fn = external_utils.wrap_inline(fn)
- def do_nothing(*arg, **kwargs):
- pass
- if hasattr(self, "callback"):
- callback = self.callback
- else:
- callback = do_nothing
- is_jit_tracing = torch._C._is_tracing
- is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
- @functools.wraps(fn)
- def _fn(*args, **kwargs):
- if is_fx_tracing():
- if config.error_on_nested_fx_trace:
- raise RuntimeError(
- "Detected that you are using FX to symbolically trace "
- "a dynamo-optimized function. This is not supported at the moment."
- )
- else:
- return fn(*args, **kwargs)
- if is_jit_tracing():
- if config.error_on_nested_jit_trace:
- raise RuntimeError(
- "Detected that you are using FX to torch.jit.trace "
- "a dynamo-optimized function. This is not supported at the moment."
- )
- else:
- return fn(*args, **kwargs)
- cleanups = [enter() for enter in self.enter_exit_hooks]
- prior = set_eval_frame(callback)
- # Ensure that if an assertion occurs after graph pushes
- # something onto the DynamicLayerStack then we pop it off (the
- # constructed graph code isn't guarded with try/finally).
- #
- # This used to be a context but putting a `with` here is a noticeable
- # perf regression (#126293)
- saved_dynamic_layer_stack_depth = (
- torch._C._functorch.get_dynamic_layer_stack_depth()
- )
- try:
- return fn(*args, **kwargs)
- finally:
- # Restore the dynamic layer stack depth if necessary.
- torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
- saved_dynamic_layer_stack_depth
- )
- set_eval_frame(prior)
- for cleanup in cleanups:
- cleanup()
- # hooks to properly handle inlining
- _fn._torchdynamo_inline = fn # type: ignore[attr-defined]
- # Save the function pointer to find the original callable when nesting
- # decorators.
- _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
- # when compiling user function instead of nn.Module
- # provide public api _fn.get_compiler_config()
- assert not hasattr(_fn, "get_compiler_config")
- _fn.get_compiler_config = get_compiler_config # type: ignore[attr-defined]
- # If the function is called using torch._dynamo.optimize decorator, we
- # should prevent any type of skipping.
- if callback not in (None, False):
- if not hasattr(fn, "__code__"):
- raise RuntimeError(
- textwrap.dedent(
- """
- torch._dynamo.optimize is called on a non-function object.
- If this is a callable class, please wrap the relevant code into a function and optimize the
- wrapper function.
- >> class CallableClass:
- >> def __init__(self):
- >> super().__init__()
- >> self.relu = torch.nn.ReLU()
- >>
- >> def __call__(self, x):
- >> return self.relu(torch.sin(x))
- >>
- >> def print_hello(self):
- >> print("Hello world")
- >>
- >> mod = CallableClass()
- If you want to optimize the __call__ function and other code, wrap that up in a function
- >> def wrapper_fn(x):
- >> y = mod(x)
- >> return y.sum()
- and then optimize the wrapper_fn
- >> opt_wrapper_fn = torch._dynamo.optimize()(wrapper_fn)
- """
- )
- )
- always_optimize_code_objects[fn.__code__] = True
- return _fn
- class OptimizeContext(_TorchDynamoContext):
- def __init__(
- self,
- callback,
- backend_ctx_ctor,
- first_ctx=False,
- *,
- export=False,
- dynamic=None,
- compiler_config=None,
- rebuild_ctx: Optional[
- Callable[[], Union[OptimizeContext, _NullDecorator]]
- ] = None,
- ):
- def on_enter():
- install_generation_tagging_init()
- super().__init__(
- callback=callback,
- on_enter=on_enter,
- backend_ctx_ctor=backend_ctx_ctor,
- patch_fn=TorchPatcher.patch,
- first_ctx=first_ctx,
- export=export,
- dynamic=dynamic,
- compiler_config=compiler_config,
- )
- if config.compiled_autograd:
- def call_compiled_autograd():
- assert rebuild_ctx is not None
- compiler_fn = rebuild_ctx()
- ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
- ctx.__enter__()
- return functools.partial(ctx.__exit__, None, None, None)
- self.enter_exit_hooks.append(call_compiled_autograd)
- def __reduce__(self):
- return (
- self.__class__,
- (self.callback, self._backend_ctx_ctor, self.first_ctx),
- {
- "export": self.export,
- "dynamic": self._dynamic,
- "compiler_config": self.compiler_config,
- },
- )
- class RunOnlyContext(_TorchDynamoContext):
- def __init__(self):
- # cudagraph trees rely on the generation increment
- def on_enter():
- torch._dynamo.mutation_guard.GenerationTracker.generation += 1
- super().__init__(callback=False, on_enter=on_enter)
- def __reduce__(self):
- return (self.__class__, ())
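- # Sketch: RunOnlyContext is what backs torch._dynamo.run(), which reuses already
- # compiled cache entries but never triggers new compilation (callback=False), e.g.
- #
- #   run_only_fn = torch._dynamo.run(opt_fn)   # opt_fn: a previously compiled callable (illustrative name)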
- class DisableContext(_TorchDynamoContext):
- def __init__(self):
- super().__init__(callback=None)
- def __call__(self, fn):
- # Earlier this code was in the base class _TorchDynamoContext. But we
- # moved it here to have better code organization. For disable, we just
- # want the callback to be None. We don't have to check trace_rules or
- # create any wrapper.
- fn = innermost_fn(fn)
- if isinstance(fn, torch.nn.Module):
- mod = fn
- new_mod = OptimizedModule(mod, self)
- new_mod._torchdynamo_orig_callable = mod.forward
- return new_mod
- if inspect.isclass(fn):
- # User has wrapped the class with compile/disable decorator. Apply
- # disable to init/call method.
- cls_obj = fn
- # Disable on init is useful for reconstruction of bytecodes where we
- # want to prevent Dynamo from tracing into the init function. Check
- # test_reconstruction in test_model_output.py.
- cls_obj.__init__ = self(cls_obj.__init__)
- cls_obj.__call__ = self(cls_obj.__call__)
- if issubclass(cls_obj, torch.nn.Module):
- # NN module variable tracker directly inlines the _call_impl. Disable it.
- cls_obj._call_impl = self(cls_obj._call_impl)
- return cls_obj
- assert callable(fn)
- callback = self.callback
- @functools.wraps(fn)
- def _fn(*args, **kwargs):
- prior = set_eval_frame(callback)
- try:
- return fn(*args, **kwargs)
- finally:
- set_eval_frame(prior)
- _fn._torchdynamo_disable = True # type: ignore[attr-defined]
- # Save the function pointer to find the original callable when nesting
- # decorators.
- _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
- return _fn
- def __reduce__(self):
- return (self.__class__, ())
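- # Sketch (illustrative): DisableContext is what backs torch._dynamo.disable, e.g.
- #
- #   @torch._dynamo.disable
- #   def data_dependent_helper(x):
- #       return x.tolist()   # runs eagerly; Dynamo will not trace into this frame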
- def _optimize_catch_errors(
- compile_fn,
- hooks: Hooks,
- backend_ctx_ctor=null_context,
- export=False,
- dynamic=None,
- compiler_config=None,
- rebuild_ctx=None,
- ):
- return OptimizeContext(
- convert_frame.catch_errors_wrapper(compile_fn, hooks),
- backend_ctx_ctor=backend_ctx_ctor,
- first_ctx=True,
- export=export,
- dynamic=dynamic,
- compiler_config=compiler_config,
- rebuild_ctx=rebuild_ctx,
- )
- def get_compiler_fn(compiler_fn):
- from .repro.after_dynamo import wrap_backend_debug
- if hasattr(compiler_fn, "compiler_name"):
- compiler_str = compiler_fn.compiler_name
- elif isinstance(compiler_fn, str):
- compiler_str = compiler_fn
- else:
- compiler_str = None
- compiler_fn = lookup_backend(compiler_fn)
- return wrap_backend_debug(compiler_fn, compiler_str)
- class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
- def __call__(self, fn):
- assert callable(fn)
- return fn
- def check_if_dynamo_supported():
- if sys.version_info >= (3, 13):
- raise RuntimeError("Python 3.13+ not yet supported for torch.compile")
- def is_dynamo_supported():
- try:
- check_if_dynamo_supported()
- return True
- except Exception:
- return False
- def check_if_inductor_supported():
- check_if_dynamo_supported()
- if sys.platform == "win32":
- raise RuntimeError("Windows not yet supported for inductor")
- def is_inductor_supported():
- try:
- check_if_inductor_supported()
- return True
- except Exception:
- return False
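- # Illustrative feature-gating sketch (assumes these helpers are re-exported on
- # torch._dynamo, as in the package __init__):
- #
- #   if torch._dynamo.is_inductor_supported():
- #       fn = torch.compile(fn, backend="inductor")
- #   elif torch._dynamo.is_dynamo_supported():
- #       fn = torch.compile(fn, backend="eager")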
- def optimize(*args, **kwargs):
- def rebuild_ctx():
- return optimize(*args, **kwargs)
- return _optimize(rebuild_ctx, *args, **kwargs)
- def _optimize(
- rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
- backend="inductor",
- *,
- nopython=False,
- guard_export_fn=None,
- guard_fail_fn=None,
- disable=False,
- dynamic=None,
- ) -> Union[OptimizeContext, _NullDecorator]:
- """
- The main entrypoint of TorchDynamo. Do graph capture and call
- backend() to optimize extracted graphs.
- Args:
- backend: One of the two things:
- - Either, a function/callable taking a torch.fx.GraphModule and
- example_inputs and returning a python callable that runs the
- graph faster.
- One can also provide additional context for the backend, like
- torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
- See AOTAutogradMemoryEfficientFusionWithContext for the usage.
- - Or, a string backend name in `torch._dynamo.list_backends()`
- nopython: If True, graph breaks will be errors and there will
- be a single whole-program graph.
- disable: If True, turn this decorator into a no-op
- dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
- disable all dynamic shapes support (always specialize). If None, automatically
- detect when sizes vary and generate dynamic kernels upon recompile.
- Example Usage::
- @torch._dynamo.optimize()
- def toy_example(a, b):
- ...
- """
- check_if_dynamo_supported()
- # Note: The hooks object could be global instead of passed around, *however* that would make
- # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
- # There is some prior art around this w.r.t. nesting: backend calls are enforced to be the same
- # compiler. However, this feels onerous for callbacks and hooks, and it feels better to give our users an
- # easier-to-understand UX at the cost of a little more plumbing on our end.
- hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
- torch._C._log_api_usage_once("torch._dynamo.optimize")
- if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
- return _NullDecorator()
- backend = get_compiler_fn(backend)
- # Find if backend has any extra context manager
- backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
- if nopython:
- return optimize_assert(
- backend,
- dynamic=dynamic,
- hooks=hooks,
- rebuild_ctx=rebuild_ctx,
- )
- # The backend function is stashed in the callable returned by
- # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
- # be used by eval_frame.c to insert a guard on the backend.
- return _optimize_catch_errors(
- convert_frame.convert_frame(backend, hooks=hooks),
- hooks,
- backend_ctx_ctor,
- dynamic=dynamic,
- compiler_config=backend.get_compiler_config()
- if hasattr(backend, "get_compiler_config")
- else None,
- rebuild_ctx=rebuild_ctx,
- )
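- # Sketch of the two code paths above (illustrative):
- #
- #   opt_fn = torch._dynamo.optimize("inductor")(fn)                   # graph breaks allowed
- #   strict_fn = torch._dynamo.optimize("eager", nopython=True)(fn)    # routed to optimize_assert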
- # TODO(voz): Consider making "explain" output alongside a run / part of a run
- @patch("torch._dynamo.symbolic_convert.explain", True)
- def explain(f, *extra_args, **extra_kwargs):
- def inner(*args, **kwargs):
- # TODO(voz): Do we want a decorator for this?
- from . import reset # type: ignore[attr-defined]
- reset()
- graphs: List[torch.fx.GraphModule] = []
- break_reasons: List[Any] = []
- op_count: int = 0
- ops_per_graph: List[torch.fx.Node] = []
- out_guards: List[_guards.Guard] = []
- def dynamo_graph_accumulating_compiler(
- gm: torch.fx.GraphModule, example_inputs
- ):
- from .backends.debugging import _explain_graph_detail
- nonlocal graphs
- nonlocal op_count
- nonlocal ops_per_graph
- nonlocal break_reasons
- gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
- gm, graphs, op_count, ops_per_graph, break_reasons
- )
- return gm.forward
- def guard_export_print(guards):
- nonlocal out_guards
- out_guards.extend(guards)
- opt_f = optimize(
- dynamo_graph_accumulating_compiler,
- nopython=False,
- guard_export_fn=guard_export_print,
- )(f)
- # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
- opt_f(*args, **kwargs)
- graph_count = len(graphs)
- graph_break_count = graph_count - 1
- compile_time = compile_times(repr="str")
- # TODO(voz): Do we want a decorator for this?
- reset()
- from .backends.debugging import ExplainOutput
- return ExplainOutput(
- graphs,
- graph_count,
- graph_break_count,
- break_reasons,
- op_count,
- ops_per_graph,
- out_guards,
- compile_time,
- )
- if extra_args or extra_kwargs:
- warnings.warn(
- "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
- "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
- "conflict with future kwargs added to explain(f).",
- FutureWarning,
- stacklevel=2,
- )
- return inner(*extra_args, **extra_kwargs)
- else:
- return inner
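- # Usage sketch (illustrative): explain is called in its deferred form and returns
- # an ExplainOutput summarizing graphs, breaks, and guards:
- #
- #   def toy(a, b):
- #       x = a + b
- #       print("boom")            # forces a graph break
- #       return x * b
- #   report = torch._dynamo.explain(toy)(torch.randn(4), torch.randn(4))
- #   print(report.graph_break_count, report.break_reasons)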
- class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
- def __init__(
- self,
- m: torch.fx.GraphModule,
- flat_args: Tuple[Any],
- matched_input_elements_positions: List[int],
- flat_results: List[Any],
- matched_output_elements_positions: List[int],
- example_fake_inputs: List[torch.Tensor],
- flat_args_dynamic_dims: List[Set[int]],
- fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
- ):
- super().__init__(m)
- assert len(flat_args_dynamic_dims) == len(flat_args)
- matched_input_elements_to_fake = {
- val: example_fake_inputs[ix]
- for ix, val in enumerate(matched_input_elements_positions)
- }
- self.new_args = []
- for i in range(0, len(flat_args)):
- arg = super().placeholder(f"arg{i}", (), {})
- if i in matched_input_elements_to_fake:
- arg.node.meta["val"] = matched_input_elements_to_fake[i]
- else:
- # Fill node.meta["val"] with a fake tensor from the input,
- # if it's not found in matched_input_elements_positions
- if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
- # TODO(zhxchen17) Also preserve all the user constraints here.
- arg.node.meta["val"] = fake_mode.from_tensor(
- flat_args[i],
- symbolic_context=StatelessSymbolicContext(
- dynamic_sizes=[
- DimDynamic.DYNAMIC
- if d in flat_args_dynamic_dims[i]
- else DimDynamic.STATIC
- for d in range(len(flat_args[i].shape))
- ],
- constraint_sizes=[None] * len(flat_args[i].shape),
- ),
- )
- self.new_args.append(arg)
- self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
- self.matched_output_elements_positions = matched_output_elements_positions
- self.flat_results = flat_results
- def placeholder(self, target, args, kwargs):
- arg = next(self.old_args_gen)
- if "val" in self.current_node.meta:
- arg.node.meta["val"] = self.current_node.meta["val"]
- if "tensor_dict" in self.current_node.meta:
- arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
- if "example_value" in self.current_node.meta:
- # NB: intentionally do not use set_example_value
- arg.node.meta["example_value"] = self.current_node.meta["example_value"]
- if "unbacked_bindings" in self.current_node.meta:
- arg.node.meta["unbacked_bindings"] = self.current_node.meta[
- "unbacked_bindings"
- ]
- return arg
- def output(self, target, args, kwargs):
- dynamo_result_flat = args[0]
- lookup = [*dynamo_result_flat, *self.new_args]
- new_results_flat = []
- for i in range(len(self.flat_results)):
- if self.matched_output_elements_positions[i] is not None:
- new_results_flat.append(
- lookup[self.matched_output_elements_positions[i]]
- )
- else:
- const_val = self.flat_results[i]
- assert isinstance(const_val, tuple(common_constant_types))
- new_results_flat.append(const_val)
- return super().output(target, (new_results_flat,), {})
- def run_node(self, n):
- self.current_node = n
- result_proxy = super().run_node(n)
- if "val" in self.current_node.meta:
- result_proxy.node.meta["val"] = self.current_node.meta["val"]
- if "example_value" in self.current_node.meta:
- # NB: intentionally do not use set_example_value
- result_proxy.node.meta["example_value"] = self.current_node.meta[
- "example_value"
- ]
- if "unbacked_bindings" in self.current_node.meta:
- result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
- "unbacked_bindings"
- ]
- if self.current_node.op != "output":
- result_proxy.node._rename(
- getattr(self.current_node, "name", result_proxy.node.name)
- )
- return result_proxy
- def transform(self):
- result_gm = super().transform()
- if "dynamo_flat_name_to_original_fqn" in self.module.meta:
- result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
- "dynamo_flat_name_to_original_fqn"
- ]
- return result_gm
- class ExportResult(NamedTuple):
- graph_module: torch.fx.GraphModule
- guards: _guards.GuardsSet
- # NB: Do not add new fields without overriding __iter__; people are
- # destructuring so it is BC-breaking
- def check_signature_rewritable(graph):
- input_errors = []
- for node in graph.graph.find_nodes(op="placeholder"):
- assert hasattr(node, "_dynamo_source")
- source = node._dynamo_source
- user_stacks = graph._source_to_user_stacks.get(source)
- if user_stacks is None:
- continue
- assert len(user_stacks) > 0
- # In some cases we may not have a useful stack. Look for a
- # useful stack
- stack = None
- for s in user_stacks:
- if len(s) == 0:
- continue
- stack = s
- break
- if stack is None:
- msg = f"{source.name()}, a closed over free variable"
- else:
- tb = "".join(traceback.format_list(stack))
- extra = ""
- if len(user_stacks) > 1:
- extra = f"(elided {len(user_stacks) - 1} more accesses)"
- msg = f"{source.name()}, accessed at:\n{tb}{extra}"
- # TODO: option to print ALL of the stack traces at once
- input_errors.append(msg)
- if input_errors:
- raise UserError(
- UserErrorType.INVALID_INPUT,
- "Cannot export model which references tensors that are neither "
- "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd "
- "like this tensor to be an explicit input, add it as a dummy argument "
- "to the top-level model definition you are exporting; if you would "
- "like its value to be embedded as an exported constant, wrap its access "
- "in a function marked with @assume_constant_result.\n\n"
- + "\n\n".join(input_errors),
- )
- def rewrite_signature(
- f_sig,
- graph,
- fake_mode,
- flat_args,
- in_spec,
- example_fake_inputs,
- graph_captured_input,
- graph_captured_output,
- dynamo_traced_result,
- flat_args_dynamic_dims,
- ):
- orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
- def check_user_input_output(flat_values, error_type):
- supported_types = [
- torch.Tensor,
- torch.SymInt,
- torch.SymFloat,
- torch.SymBool,
- torch._C.ScriptObject,
- ] + list(common_constant_types)
- def is_supported_type(val):
- return isinstance(val, tuple(supported_types))
- value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
- # We only check that the outputs are not None. Inputs can be None.
- for v in flat_values:
- if not is_supported_type(v):
- if error_type == UserErrorType.INVALID_INPUT and v is None:
- continue
- raise UserError(
- error_type,
- f"It looks like one of the {value_type}s with type `{type(v)}` "
- "is not supported or pytree-flattenable. \n"
- f"Exported graphs {value_type}s can only contain the "
- f"following supported types: {supported_types}. \n"
- "If you are using a custom class object, "
- "please register a pytree_flatten/unflatten function "
- "using `torch.utils._pytree.register_pytree_node` or "
- "`torch.export.register_dataclass`.",
- )
- check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
- flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
- check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
- def produce_matching(debug_type, sources, candidates):
- matched_elements_positions: List[Optional[int]] = []
- dict_of_source_vals = {}
- for i, val in enumerate(sources):
- dict_of_source_vals[id(val)] = i
- for i, val in enumerate(candidates):
- if isinstance(val, tuple(common_constant_types)):
- matched_elements_positions.append(None)
- elif id(val) not in dict_of_source_vals:
- raise AssertionError(
- f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
- 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"'
- )
- else:
- matched_elements_positions.append(dict_of_source_vals[id(val)])
- return matched_elements_positions
- matched_input_elements_positions = produce_matching(
- "inputs", flat_args, graph_captured_input
- )
- assert graph_captured_output is not None
- matched_output_elements_positions = produce_matching(
- "outputs", list(graph_captured_output) + flat_args, flat_results_traced
- )
- new_graph = FlattenInputOutputSignature(
- graph,
- flat_args,
- matched_input_elements_positions,
- flat_results_traced,
- matched_output_elements_positions,
- example_fake_inputs,
- flat_args_dynamic_dims,
- fake_mode,
- ).transform()
- # Make the dynamo graph have the same input/output spec as the user code
- def argument_names(f_sig, args, kwargs) -> List[str]:
- def signature_to_fullargspec(sig: inspect.Signature):
- # Get a list of Parameter objects from the Signature object
- params = list(sig.parameters.values())
- # Separate positional arguments, keyword-only arguments and varargs/varkw
- args = [
- p.name
- for p in params
- if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
- ]
- kwonlyargs = [
- p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
- ]
- varargs = next(
- (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
- None,
- )
- varkw = next(
- (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
- None,
- )
- # Get default values for positional arguments and keyword-only arguments
- defaults = tuple(
- p.default
- for p in params
- if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
- and p.default is not inspect.Parameter.empty
- )
- kwonlydefaults = {
- p.name: p.default
- for p in params
- if p.kind == inspect.Parameter.KEYWORD_ONLY
- and p.default is not inspect.Parameter.empty
- }
- # Get annotations for parameters and return value
- annotations = {}
- if sig.return_annotation:
- annotations = {"return": sig.return_annotation}
- for parameter in params:
- annotations[parameter.name] = parameter.annotation
- # Return a FullArgSpec object with the extracted attributes
- return inspect.FullArgSpec(
- args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
- )
- fullargspec = signature_to_fullargspec(f_sig)
- # 1. Map `args` 1-to-1 to positional arguments in original signature.
- input_strs = fullargspec.args[: len(args)]
- if len(args) > len(fullargspec.args):
- # 2. If there are more arguments left in `args`, they map to varargs in original
- # signature. Assign names as {varargs}_0, {varargs}_1, ...
- assert fullargspec.varargs is not None, "More arguments than expected"
- input_strs += [
- f"{fullargspec.varargs}_{i}"
- for i in range(0, len(args) - len(input_strs))
- ]
- elif len(args) < len(fullargspec.args):
- # 3. If there are fewer arguments in `args` than `fullargspec.args`,
- # it implies these are arguments either with default values, or provided in
- `kwargs`. The former can be safely ignored, because Dynamo.export does not
- # export them as part of the function signature. The latter will be handled
- # in the next step.
- for unprovided_arg in fullargspec.args[
- len(args) : -len(fullargspec.defaults or [])
- ]:
- assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
- # 4. Keyword arguments provided in `kwargs`.
- input_strs += list(kwargs.keys())
- # 5. Keyword-only arguments with default values if not provided are not exported
- # as part of the function signature.
- for kwonly_arg in fullargspec.kwonlyargs:
- kwonlydefaults = fullargspec.kwonlydefaults or {}
- assert (
- kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
- ), f"Missing keyword only argument {kwonly_arg}"
- return input_strs
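- # Worked example for argument_names (illustrative): for
- #   def g(a, b=1, *args, c, **kw): ...
- # called as g(10, 20, 30, c=4, d=5), the steps above produce
- # ["a", "b", "args_0", "c", "d"].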
- new_graph.graph._codegen = _PyTreeCodeGen(
- _PyTreeInfo(
- argument_names(f_sig, orig_args, orig_kwargs),
- in_spec,
- out_spec_traced,
- )
- )
- new_graph.recompile()
- return new_graph
- def export(
- f: Callable[..., Any],
- *extra_args,
- aten_graph: bool = False,
- pre_dispatch: bool = False,
- decomposition_table: Optional[
- Dict[torch._ops.OpOverload, Callable[..., Any]]
- ] = None,
- tracing_mode: str = "symbolic",
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
- assume_static_by_default: bool = False,
- same_signature: bool = True,
- disable_constraint_solver: bool = False,
- prefer_deferred_runtime_asserts_over_guards: bool = False,
- _allow_complex_guards_as_runtime_asserts: bool = False,
- _log_export_usage: bool = True,
- **extra_kwargs,
- ) -> Callable[..., ExportResult]:
- """
- Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
- Args:
- f (callable): A PyTorch function to be exported.
- aten_graph (bool): If True, exports a graph with ATen operators.
- If False, exports a graph with Python operators. Default is False.
- pre_dispatch (bool): If True, exports a graph with ATen operators,
- but before any logic in the PyTorch dispatcher has run.
- This can be useful if you want to apply further transformations on a graph before running it
- through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
- This flag is only valid if aten_graph=True is set.
- Default is False.
- decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
- Required if aten_graph or tracing_mode is specified. Default is None.
- tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
- dynamic_shapes:
- An optional argument where the type should either be:
- 1) a dict from argument names of ``f`` to their dynamic shape specifications,
- 2) a tuple that specifies dynamic shape specifications for each input in original order.
- If you are specifying dynamism on keyword args, you will need to pass them in the order that
- is defined in the original function signature.
- The dynamic shape of a tensor argument can be specified as either
- (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
- not required to include static dimension indices in this dict, but when they are,
- they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
- where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
- are denoted by None. Arguments that are dicts or tuples / lists of tensors are
- recursively specified by using mappings or sequences of contained specifications.
- same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.
- disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.
- Returns:
- A function that given args and kwargs, returns a tuple of (graph, guards)
- Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
- Guards: The guards we accumulated during tracing f above
- Raises:
- AssertionError: If decomposition_table is specified without setting aten_graph=True,
- or if graph breaks during tracing in export.
- AssertionError: If Dynamo input and output is not consistent with traced input/output.
- Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
- """
- if _log_export_usage:
- log_export_usage(event="export.private_api", flags={"_dynamo"})
- # Deal with "local variable referenced before assignment"
- _f = f
- _assume_static_by_default = assume_static_by_default
- def inner(*args, **kwargs):
- constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes)
- f = _f
- assume_static_by_default = _assume_static_by_default
- check_if_dynamo_supported()
- torch._C._log_api_usage_once("torch._dynamo.export")
- if decomposition_table is not None:
- assert (
- aten_graph
- ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
- if pre_dispatch:
- assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
- f = innermost_fn(f)
- call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
- original_signature = inspect.signature(call_to_inspect)
- graph = None
- out_guards = None
- graph_captured_input = None
- graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
- fake_mode = None
- def guard_export_print(guards: _guards.GuardsSet):
- nonlocal out_guards
- assert (
- out_guards is None
- ), "whole graph export entails exactly one guard export"
- out_guards = guards
- example_inputs = []
- def dynamo_normalization_capturing_compiler(
- gm: torch.fx.GraphModule, inner_example_inputs
- ):
- nonlocal graph
- assert (
- graph is None
- ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
- graph = gm
- nonlocal fake_mode, example_inputs
- # NB: do NOT pass inner_example_inputs here, we are detecting the
- # Dynamo allocated fake mode, which should be DISTINCT from a
- # potential outer ambient fake mode which the user provided.
- # example_inputs is always the user specified inputs, so they
- # would have the wrong fake mode attached to them
- fake_mode = _guards.detect_fake_mode()
- example_inputs = inner_example_inputs
- def result_capturing_wrapper(*graph_inputs):
- nonlocal graph_captured_result
- nonlocal graph_captured_input
- graph_captured_input = graph_inputs
- assert graph is not None
- named_parameters = dict(graph.named_parameters(remove_duplicate=False))
- named_buffers = dict(graph.named_buffers(remove_duplicate=False))
- ambient_fake_mode = (
- _guards.detect_fake_mode(graph_inputs)
- if _guards.detect_fake_mode(graph_inputs) is not None
- else fake_mode
- )
- # We reran fake tensor propagation, but we didn't do
- # anything with the resulting unbacked SymInts. Drop them
- # from the pending list.
- # NB: this is wrong if graph_captured_result has
- # data-dependent output size!
- ignore_fresh_unbacked = null_context()
- if shape_env := ambient_fake_mode.shape_env:
- ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
- with (
- ambient_fake_mode
- ), enable_python_dispatcher(), ignore_fresh_unbacked:
- params_and_buffers = {
- **named_parameters,
- **named_buffers,
- }
- fake_params_buffers = dict()
- for name, value in params_and_buffers.items():
- fake_params_buffers[name] = ambient_fake_mode.from_tensor(
- value, static_shapes=True
- )
- fake_graph_inputs = pytree.tree_map(
- ambient_fake_mode.from_tensor, graph_inputs
- )
- graph_captured_result = torch.func.functional_call(
- graph, fake_params_buffers, fake_graph_inputs
- )
- return graph_captured_result
- return result_capturing_wrapper
- # Note: This is needed by rewrite_signature. We need to put it before
- # optimize_assert since the user program may mutate the inputs.
- flat_args, in_spec = pytree.tree_flatten((args, kwargs))
- remove_from_cache(f)
- constraint_violation_error = None
- if tracing_mode != "symbolic":
- assume_static_by_default = True
- with config.patch(
- specialize_int=True,
- assume_static_by_default=assume_static_by_default,
- automatic_dynamic_shapes=False,
- capture_dynamic_output_shape_ops=True,
- capture_scalar_outputs=True,
- prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
- _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
- ):
- opt_f = optimize_assert(
- dynamo_normalization_capturing_compiler,
- hooks=Hooks(
- guard_export_fn=guard_export_print,
- guard_fail_fn=None,
- ),
- export=True,
- export_constraints=constraints,
- )(f)
- # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
- try:
- result_traced = opt_f(*args, **kwargs)
- except ConstraintViolationError as e:
- constraint_violation_error = e
- remove_from_cache(f)
- if (
- not disable_constraint_solver
- and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
- and (dim_constraints := shape_env.dim_constraints) is not None
- and not isinstance(
- call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
- )
- and not trace_rules.check(call_to_inspect)
- ):
- dim_constraints.solve()
- dim_constraints.remove_redundant_dynamic_results()
- forced_specializations = dim_constraints.forced_specializations()
- msg = dim_constraints.prettify_results(
- original_signature,
- dynamic_shapes,
- constraint_violation_error,
- forced_specializations,
- )
- if constraint_violation_error:
- constraint_violation_error.args = (
- constraint_violation_error.args[0] + msg,
- )
- else:
- if forced_specializations:
- constraint_violation_error = ConstraintViolationError(msg)
- else:
- log.info(
- "Summary of dimension constraints:%s",
- msg,
- )
- # Error if we have any constraints on static values
- for k in shape_env.var_to_range.keys():
- if isinstance(k, sympy.Integer):
- constraint_violation_error = ConstraintViolationError(
- f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
- "It appears that you're trying to set a constraint on a "
- f"value which we evaluated to have a static value of {k}. "
- 'Set TORCH_LOGS="+export" for more information.'
- )
- if constraint_violation_error:
- raise constraint_violation_error
- assert (
- graph is not None
- ), "Failed to produce a graph during tracing as no tensor operations were found."
- assert hasattr(graph, "_source_to_user_stacks")
- assert out_guards is not None, "Failed to produce guards during tracing"
- assert fake_mode is not None
- log.info(
- "Dynamo captured graph:\n\n%s", graph.print_readable(print_output=False)
- )
- # This check needs to happen before the aten_graph pass
- # because the placeholder's _source_node attribute is not preserved by make_fx
- if same_signature:
- check_signature_rewritable(graph)
- # NB: This is mostly hitting the cache; Dynamo already converted these
- example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
- if aten_graph:
- # Running the graph with an interpreter is needed to propagate the stack_trace
- def graph_with_interpreter(*args):
- with torch.fx.traceback.preserve_node_meta():
- return torch.fx.Interpreter(graph).run(*args)
- with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
- fake_mode
- ):
- try:
- graph = make_fx(
- graph_with_interpreter,
- decomposition_table=decomposition_table,
- tracing_mode="real",
- _allow_non_fake_inputs=True,
- pre_dispatch=pre_dispatch,
- _allow_fake_constant=False,
- )(*example_fake_inputs)
- except CondOpArgsMismatchError as e:
- # Wrap the internal error to the user-facing error
- raise UserError( # noqa: B904
- UserErrorType.DYNAMIC_CONTROL_FLOW,
- str(e),
- case_name="cond_operands",
- )
- assert graph is not None
- for node in graph.graph.find_nodes(op="get_attr"):
- if isinstance(getattr(graph, node.target), torch.Tensor):
- node.meta["val"] = fake_mode.from_tensor(
- getattr(graph, node.target), static_shapes=True
- )
- if same_signature:
- flat_args_dynamic_dims = [
- {c.dim for c in (constraints or ()) if c.w_tensor() is x}
- for x in flat_args
- ]
- graph = rewrite_signature(
- original_signature,
- graph,
- fake_mode,
- flat_args,
- in_spec,
- example_fake_inputs,
- graph_captured_input,
- graph_captured_result,
- result_traced, # type: ignore[possibly-undefined]
- flat_args_dynamic_dims,
- )
- # Store constraints and inputs as metadata for user passes, e.g. turning constraints into runtime checks
- assert graph is not None
- graph.meta["input_shape_constraints"] = (
- [constraint.serializable_spec for constraint in constraints]
- if constraints
- else []
- )
- return ExportResult(graph, out_guards)
- if extra_args or extra_kwargs:
- warnings.warn(
- "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
- "If you don't migrate, we may break your export call in the future if your user defined kwargs "
- "conflict with future kwargs added to export(f).",
- FutureWarning,
- stacklevel=2,
- )
- return inner(*extra_args, **extra_kwargs)
- else:
- return inner
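- # Usage sketch (illustrative): export is typically used in its deferred form and
- # the ExportResult can be destructured into (graph_module, guards):
- #
- #   def f(x):
- #       return x.sin() + x.cos()
- #   gm, guards = torch._dynamo.export(f)(torch.randn(3))
- #   print(gm.print_readable(print_output=False))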
- def optimize_assert(
- backend,
- *,
- hooks=Hooks(None, None),
- export=False,
- export_constraints=None,
- dynamic=None,
- rebuild_ctx=None,
- ):
- """
- The same as `torch._dynamo.optimize(backend, nopython=True)`
- """
- backend = get_compiler_fn(backend)
- # Find if backend has any extra context manager
- backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
- return _optimize_catch_errors(
- convert_frame.convert_frame_assert(
- backend, export=export, export_constraints=export_constraints
- ),
- hooks,
- backend_ctx_ctor,
- export=export,
- dynamic=dynamic,
- rebuild_ctx=rebuild_ctx,
- )
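- # Sketch (illustrative): the following are equivalent ways to require a single
- # whole-program graph, assuming optimize_assert is re-exported on torch._dynamo:
- #
- #   strict_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
- #   strict_fn = torch._dynamo.optimize_assert("eager")(fn)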
- class TorchPatcher:
- @staticmethod
- @functools.lru_cache(None)
- def patch():
- # A better way to disable the following would be to decorate the source
- # functions with @torch._disable_dynamo. However, this causes issues
- # with torch.deploy internally.
- from .decorators import disable
- torch.jit.trace = disable(torch.jit.trace)
- torch.jit.trace_module = disable(torch.jit.trace_module)
- torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
- torch.fx._symbolic_trace.Tracer.trace = disable(
- torch.fx._symbolic_trace.Tracer.trace
- )
- torch.distributions.Distribution.set_default_validate_args(False)
- from ..optim import (
- adadelta,
- adagrad,
- adam,
- adamax,
- adamw,
- asgd,
- lbfgs,
- nadam,
- radam,
- rmsprop,
- rprop,
- sgd,
- sparse_adam,
- )
- optimizer_modules = {
- adadelta,
- adagrad,
- adam,
- adamax,
- adamw,
- asgd,
- lbfgs,
- nadam,
- radam,
- rmsprop,
- rprop,
- sgd,
- sparse_adam,
- }
- for opt_mod in optimizer_modules:
- opt_name = opt_mod.__name__.split(".")[-1]
- fused_fn_name = f"_fused_{opt_name}"
- single_tensor_fn_name = f"_single_tensor_{opt_name}"
- if hasattr(opt_mod, fused_fn_name):
- setattr(
- opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
- )
- optimizer_classes = [
- opt
- for opt in torch.optim.__dict__.values()
- if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
- ]
- # Note: we don't support sparsity or tracing through backwards
- excluded_optimizer_classes = {
- torch.optim.SparseAdam,
- torch.optim.LBFGS,
- }
- for opt in optimizer_classes:
- if opt in excluded_optimizer_classes:
- opt.step = disable(opt.step)
- if hasattr(opt, "_init_group"):
- opt._init_group = disable(opt._init_group)
- @staticmethod
- def suppress_torch_distributed_warnings(fn):
- def inner_fn(*args, **kwargs):
- warnings.filterwarnings(
- "ignore", category=UserWarning, module="torch.distributed"
- )
- return fn(*args, **kwargs)
- return inner_fn
|