eval_frame.py

  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="method-assign"
  3. """
  4. Functions in this file are responsible for modifying the eval frame
  5. handler at RUNTIME. Therefore, all functions in this file are hot.
  6. Functions that only execute at compile time should be placed
  7. in torch._dynamo.convert_frame.
  8. """
  9. from __future__ import annotations
  10. import contextlib
  11. import functools
  12. import inspect
  13. import logging
  14. import os
  15. import sys
  16. import textwrap
  17. import traceback
  18. import types
  19. import warnings
  20. import weakref
  21. from enum import Enum
  22. from os.path import dirname, join
  23. from typing import (
  24. Any,
  25. Callable,
  26. Dict,
  27. List,
  28. NamedTuple,
  29. Optional,
  30. Set,
  31. Tuple,
  32. TYPE_CHECKING,
  33. Union,
  34. )
  35. from unittest.mock import patch
  36. import torch
  37. import torch.fx
  38. import torch.utils._pytree as pytree
  39. import torch.utils.checkpoint
  40. from torch import _guards
  41. from torch._utils_internal import log_export_usage
  42. from torch.export.dynamic_shapes import _process_dynamic_shapes
  43. from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
  44. from torch.fx.experimental.symbolic_shapes import (
  45. ConstraintViolationError,
  46. DimDynamic,
  47. StatelessSymbolicContext,
  48. )
  49. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
  50. from ..fx import GraphModule
  51. from .backends.registry import CompilerFn, lookup_backend
  52. from .hooks import Hooks
  53. # see discussion at https://github.com/pytorch/pytorch/issues/120699
  54. reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
  55. set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
  56. set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401
  57. skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
  58. unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
  59. from . import config, convert_frame, external_utils, trace_rules, utils
  60. from .code_context import code_context
  61. from .exc import CondOpArgsMismatchError, UserError, UserErrorType
  62. from .mutation_guard import install_generation_tagging_init
  63. from .utils import common_constant_types, compile_times
  64. log = logging.getLogger(__name__)
  65. from torch._dispatch.python import enable_python_dispatcher
  66. always_optimize_code_objects = utils.ExactWeakKeyDictionary()
  67. null_context = contextlib.nullcontext
  68. import sympy
  69. if TYPE_CHECKING:
  70. from torch._subclasses import fake_tensor
  71. from .types import CacheEntry, DynamoCallback
  72. # See https://github.com/python/typing/pull/240
  73. class Unset(Enum):
  74. token = 0
  75. cached_backends: Dict[int, CompilerFn] = {}
  76. unset = Unset.token
  77. def _reset_guarded_backend_cache():
  78. global cached_backends
  79. for backend in cached_backends.values():
  80. if hasattr(backend, "reset"):
  81. backend.reset()
  82. cached_backends.clear()
  83. DONT_WRAP_FILES = {
  84. # For tracing into fx modules
  85. inspect.getsourcefile(GraphModule),
  86. join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
  87. }
  88. def _debug_get_cache_entry_list(
  89. code: Union[types.CodeType, Callable[..., Any]]
  90. ) -> List[CacheEntry]:
  91. """
  92. Given a code object or a callable object, retrieve the cache entries
  93. stored in this code.
  94. """
  95. if callable(code):
  96. code = code.__code__
  97. return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
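# A minimal sketch of inspecting the cache (illustrative only; `fn` and the
# input are hypothetical). Dynamo stores cache entries on the original
# function's code object, so they can be read back after one compiled call:
#
#     def fn(x):
#         return x.sin()
#
#     opt_fn = torch.compile(fn)
#     opt_fn(torch.randn(4))
#     entries = _debug_get_cache_entry_list(fn)  # or fn.__code__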
  98. class OptimizedModule(torch.nn.Module):
  99. """
  100. Wraps the original nn.Module object and later patches its
  101. forward method with an optimized self.forward method.
  102. """
  103. _torchdynamo_orig_callable: Callable[..., Any]
  104. get_compiler_config: Callable[[], Any]
  105. _opt_mod_attributes = {
  106. "_orig_mod",
  107. "dynamo_ctx",
  108. "_torchdynamo_orig_callable",
  109. "get_compiler_config",
  110. "forward",
  111. "_forward",
  112. "__dict__",
  113. "named_children_walk",
  114. }
  115. def __init__(self, mod: torch.nn.Module, dynamo_ctx):
  116. super().__init__()
  117. # Installs the params/buffer
  118. self._orig_mod = mod
  119. self.dynamo_ctx = dynamo_ctx
  120. self._initialize()
  121. def _initialize(self):
  122. # Do this stuff in constructor to lower overhead slightly
  123. if isinstance(self.dynamo_ctx, DisableContext):
  124. # No need to check trace rules
  125. self.forward = self.dynamo_ctx(self._orig_mod.__call__)
  126. elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
  127. self._orig_mod.forward
  128. ):
  129. # This may be a torch.nn.* instance listed in trace_rules.py which
  130. # won't trigger a frame evaluation on its own; as a workaround,
  131. # wrap_inline adds an extra frame we can capture.
  132. self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
  133. else:
  134. # Invoke hooks outside of dynamo, then pick up the inner frame
  135. self.forward = self.dynamo_ctx(self._orig_mod.__call__)
  136. if hasattr(self._orig_mod, "_initialize_hook"):
  137. self._forward = self.forward
  138. self.forward = self._call_lazy_check
  139. def __reduce__(self):
  140. return (self.__class__, (self._orig_mod, self.dynamo_ctx))
  141. def __getstate__(self):
  142. state = dict(self.__dict__)
  143. state.pop("forward", None)
  144. state.pop("__call__", None)
  145. return state
  146. def __setstate__(self, state):
  147. self.__dict__ = state
  148. self._initialize()
  149. def __getattr__(self, name):
  150. if name == "_orig_mod":
  151. return self._modules["_orig_mod"]
  152. return getattr(self._orig_mod, name)
  153. def __setattr__(self, name, val):
  154. # Allow patching over class attributes
  155. if hasattr(type(self), name):
  156. return super().__setattr__(name, val)
  157. if name in OptimizedModule._opt_mod_attributes:
  158. return super().__setattr__(name, val)
  159. return setattr(self._orig_mod, name, val)
  160. def _call_lazy_check(self, *args, **kwargs):
  161. if hasattr(self._orig_mod, "_initialize_hook"):
  162. # In the case of a lazy module, we want to run
  163. # the pre-hooks which initialize it.
  164. # Afterwards, the lazy module deletes its pre-hooks
  165. # so it is not treated as lazy on subsequent recompiles.
  166. self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
  167. return self._forward(*args, **kwargs)
  168. def __dir__(self):
  169. orig_mod_attrs = self._orig_mod.__dir__()
  170. return orig_mod_attrs + [
  171. attr for attr in super().__dir__() if attr not in orig_mod_attrs
  172. ]
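# A minimal sketch of OptimizedModule behavior (illustrative; `MyModule` is a
# hypothetical nn.Module). Calls go through the optimized forward, while
# attribute reads/writes fall through to the wrapped module:
#
#     mod = MyModule()
#     opt_mod = torch.compile(mod)      # an OptimizedModule wrapping mod
#     opt_mod(torch.randn(2, 8))        # runs the compiled forward
#     opt_mod.some_flag = True          # stored on mod via __setattr__
#     assert mod.some_flag and opt_mod._orig_mod is mod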
  173. def remove_from_cache(f):
  174. """
  175. Make sure f.__code__ is not cached to force a recompile
  176. """
  177. if isinstance(f, types.CodeType):
  178. reset_code(f)
  179. elif hasattr(f, "__code__"):
  180. reset_code(f.__code__)
  181. elif hasattr(getattr(f, "forward", None), "__code__"):
  182. reset_code(f.forward.__code__)
  183. else:
  184. from . import reset # type: ignore[attr-defined]
  185. reset()
  186. log.warning("could not determine __code__ for %s", f)
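# A minimal sketch (illustrative; `fn` is hypothetical): clearing the cached
# code forces the next call to recompile instead of reusing a cache entry.
#
#     opt_fn = torch.compile(fn)
#     opt_fn(torch.randn(4))    # compiles and caches
#     remove_from_cache(fn)     # drops cache entries on fn.__code__
#     opt_fn(torch.randn(4))    # triggers a fresh compile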
  187. def nothing():
  188. pass
  189. def always_false():
  190. return False
  191. def innermost_fn(fn):
  192. """
  193. In case of nested _TorchDynamoContext calls, find the innermost
  194. function. TorchDynamo caches on the fn.__code__ object, so it's necessary to find
  195. the innermost function to pass to optimize, run, disable, etc.
  196. """
  197. unaltered_fn = fn
  198. while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
  199. unaltered_fn = unaltered_fn._torchdynamo_orig_callable
  200. assert callable(unaltered_fn)
  201. return unaltered_fn
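# A minimal sketch (illustrative): each dynamo wrapper records the original
# callable in _torchdynamo_orig_callable, so innermost_fn unwraps any depth
# of nested decorators back to the user function.
#
#     def fn(x):
#         return x + 1
#
#     wrapped = torch._dynamo.optimize("eager")(fn)
#     rewrapped = torch._dynamo.optimize("eager")(wrapped)
#     assert innermost_fn(rewrapped) is fn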
  202. def make_set_enable_dynamic(enable: bool):
  203. assert isinstance(enable, bool)
  204. if enable:
  205. # Assume everything is dynamic by default
  206. return config._make_closure_patcher(assume_static_by_default=False)
  207. else:
  208. return config._make_closure_patcher(
  209. automatic_dynamic_shapes=False, assume_static_by_default=True
  210. )
  211. class _TorchDynamoContext:
  212. def __init__(
  213. self,
  214. callback: DynamoCallback,
  215. on_enter=nothing,
  216. backend_ctx_ctor=null_context,
  217. patch_fn=nothing,
  218. first_ctx=False,
  219. *,
  220. export=False,
  221. dynamic=None,
  222. compiler_config=None,
  223. ):
  224. super().__init__()
  225. assert callable(callback) or callback is False or callback is None
  226. self.callback: DynamoCallback = callback
  227. self._backend_ctx_ctor = backend_ctx_ctor
  228. self.prior: Union[Unset, DynamoCallback] = unset
  229. self.first_ctx = first_ctx
  230. self.export = export
  231. self._dynamic = dynamic
  232. self.compiler_config = compiler_config
  233. self.cleanup_fns: List[Callable[[], Any]] = []
  234. self.enter_exit_hooks = []
  235. patch_fn()
  236. # Save the backends so that we can reset them during torch._dynamo.reset
  237. backend = innermost_fn(callback)
  238. cached_backends.setdefault(id(backend), backend)
  239. if dynamic is not None:
  240. self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
  241. if on_enter is not nothing:
  242. # this case is not common
  243. def call_on_enter():
  244. on_enter()
  245. return nothing
  246. self.enter_exit_hooks.append(call_on_enter)
  247. if backend_ctx_ctor is not contextlib.nullcontext:
  248. # this case is not common
  249. def call_backend_ctx():
  250. ctx = backend_ctx_ctor()
  251. ctx.__enter__()
  252. return functools.partial(ctx.__exit__, None, None, None)
  253. self.enter_exit_hooks.append(call_backend_ctx)
  254. def __enter__(self):
  255. if config.raise_on_ctx_manager_usage:
  256. raise RuntimeError(
  257. "torch._dynamo.optimize(...) is used with a context manager. "
  258. "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
  259. "to use torch._dynamo.optimize(...) as an annotation/decorator. "
  260. )
  261. self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
  262. self.prior = set_eval_frame(self.callback)
  263. def __exit__(self, exc_type, exc_val, exc_tb):
  264. assert self.prior is not unset
  265. set_eval_frame(self.prior)
  266. self.prior = unset
  267. for cleanup in self.cleanup_fns:
  268. cleanup()
  269. self.cleanup_fns.clear()
  270. def __call__(self, fn):
  271. # public api for compiler config/options
  272. def get_compiler_config():
  273. return self.compiler_config
  274. fn = innermost_fn(fn)
  275. # add context containing GraphModule to any GraphModule forward functions
  276. if isinstance(fn, GraphModule):
  278. code_context.get_context(fn.forward.__code__)[
  279. "orig_graphmodule"
  280. ] = weakref.ref(fn)
  281. # Optimize the forward method of torch.nn.Module object
  282. if isinstance(fn, torch.nn.Module):
  283. mod = fn
  284. new_mod = OptimizedModule(mod, self)
  285. # Save the function pointer to find the original callable when
  286. # decorators are nested.
  287. new_mod._torchdynamo_orig_callable = mod.forward
  288. # when compiling torch.nn.Module,
  289. # provide public api OptimizedModule.get_compiler_config()
  290. assert not hasattr(new_mod, "get_compiler_config")
  291. new_mod.get_compiler_config = get_compiler_config
  292. return new_mod
  293. if inspect.isclass(fn):
  294. # User has wrapped the class with compile/disable decorator. Apply
  295. # disable to init/call method.
  296. cls_obj = fn
  297. cls_obj.__call__ = self(cls_obj.__call__)
  298. if issubclass(cls_obj, torch.nn.Module):
  299. # NN module variable tracker directly inlines the _call_impl.
  300. cls_obj._call_impl = self(cls_obj._call_impl)
  301. return cls_obj
  302. assert callable(fn)
  303. try:
  304. filename = inspect.getsourcefile(fn)
  305. except TypeError:
  306. filename = None
  307. if (
  308. (filename is None or trace_rules.check(fn))
  309. and (
  310. getattr(fn, "__name__", "")
  311. not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
  312. )
  313. and filename not in DONT_WRAP_FILES
  314. ):
  315. # call to a builtin without a frame for us to capture
  316. fn = external_utils.wrap_inline(fn)
  317. def do_nothing(*arg, **kwargs):
  318. pass
  319. if hasattr(self, "callback"):
  320. callback = self.callback
  321. else:
  322. callback = do_nothing
  323. is_jit_tracing = torch._C._is_tracing
  324. is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
  325. @functools.wraps(fn)
  326. def _fn(*args, **kwargs):
  327. if is_fx_tracing():
  328. if config.error_on_nested_fx_trace:
  329. raise RuntimeError(
  330. "Detected that you are using FX to symbolically trace "
  331. "a dynamo-optimized function. This is not supported at the moment."
  332. )
  333. else:
  334. return fn(*args, **kwargs)
  335. if is_jit_tracing():
  336. if config.error_on_nested_jit_trace:
  337. raise RuntimeError(
  338. "Detected that you are using FX to torch.jit.trace "
  339. "a dynamo-optimized function. This is not supported at the moment."
  340. )
  341. else:
  342. return fn(*args, **kwargs)
  343. cleanups = [enter() for enter in self.enter_exit_hooks]
  344. prior = set_eval_frame(callback)
  345. # Ensure that if an assertion occurs after graph pushes
  346. # something onto the DynamicLayerStack then we pop it off (the
  347. # constructed graph code isn't guarded with try/finally).
  348. #
  349. # This used to be a context manager, but putting a `with` here is a noticeable
  350. # perf regression (#126293)
  351. saved_dynamic_layer_stack_depth = (
  352. torch._C._functorch.get_dynamic_layer_stack_depth()
  353. )
  354. try:
  355. return fn(*args, **kwargs)
  356. finally:
  357. # Restore the dynamic layer stack depth if necessary.
  358. torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
  359. saved_dynamic_layer_stack_depth
  360. )
  361. set_eval_frame(prior)
  362. for cleanup in cleanups:
  363. cleanup()
  364. # hooks to properly handle inlining
  365. _fn._torchdynamo_inline = fn # type: ignore[attr-defined]
  366. # Save the function pointer to find the original callable when
  367. # decorators are nested.
  368. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  369. # when compiling user function instead of nn.Module
  370. # provide public api _fn.get_compiler_config()
  371. assert not hasattr(_fn, "get_compiler_config")
  372. _fn.get_compiler_config = get_compiler_config # type: ignore[attr-defined]
  373. # If the function is called using the torch._dynamo.optimize decorator, we
  374. # should prevent any type of skipping.
  375. if callback not in (None, False):
  376. if not hasattr(fn, "__code__"):
  377. raise RuntimeError(
  378. textwrap.dedent(
  379. """
  380. torch._dynamo.optimize was called on a non-function object.
  381. If this is a callable class, please wrap the relevant code into a function and optimize the
  382. wrapper function.
  383. >> class CallableClass:
  384. >> def __init__(self):
  385. >> super().__init__()
  386. >> self.relu = torch.nn.ReLU()
  387. >>
  388. >> def __call__(self, x):
  389. >> return self.relu(torch.sin(x))
  390. >>
  391. >> def print_hello(self):
  392. >> print("Hello world")
  393. >>
  394. >> mod = CallableClass()
  395. If you want to optimize the __call__ function and other code, wrap that up in a function
  396. >> def wrapper_fn(x):
  397. >> y = mod(x)
  398. >> return y.sum()
  399. and then optimize the wrapper_fn
  400. >> opt_wrapper_fn = torch._dynamo.optimize()(wrapper_fn)
  401. """
  402. )
  403. )
  404. always_optimize_code_objects[fn.__code__] = True
  405. return _fn
  406. class OptimizeContext(_TorchDynamoContext):
  407. def __init__(
  408. self,
  409. callback,
  410. backend_ctx_ctor,
  411. first_ctx=False,
  412. *,
  413. export=False,
  414. dynamic=None,
  415. compiler_config=None,
  416. rebuild_ctx: Optional[
  417. Callable[[], Union[OptimizeContext, _NullDecorator]]
  418. ] = None,
  419. ):
  420. def on_enter():
  421. install_generation_tagging_init()
  422. super().__init__(
  423. callback=callback,
  424. on_enter=on_enter,
  425. backend_ctx_ctor=backend_ctx_ctor,
  426. patch_fn=TorchPatcher.patch,
  427. first_ctx=first_ctx,
  428. export=export,
  429. dynamic=dynamic,
  430. compiler_config=compiler_config,
  431. )
  432. if config.compiled_autograd:
  433. def call_compiled_autograd():
  434. assert rebuild_ctx is not None
  435. compiler_fn = rebuild_ctx()
  436. ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
  437. ctx.__enter__()
  438. return functools.partial(ctx.__exit__, None, None, None)
  439. self.enter_exit_hooks.append(call_compiled_autograd)
  440. def __reduce__(self):
  441. return (
  442. self.__class__,
  443. (self.callback, self._backend_ctx_ctor, self.first_ctx),
  444. {
  445. "export": self.export,
  446. "dynamic": self._dynamic,
  447. "compiler_config": self.compiler_config,
  448. },
  449. )
  450. class RunOnlyContext(_TorchDynamoContext):
  451. def __init__(self):
  452. # cudagraph trees rely on the generation increment
  453. def on_enter():
  454. torch._dynamo.mutation_guard.GenerationTracker.generation += 1
  455. super().__init__(callback=False, on_enter=on_enter)
  456. def __reduce__(self):
  457. return (self.__class__, ())
  458. class DisableContext(_TorchDynamoContext):
  459. def __init__(self):
  460. super().__init__(callback=None)
  461. def __call__(self, fn):
  462. # Earlier this code was in the base class _TorchDynamoContext. But we
  463. # moved it here to have better code organization. For disable, we just
  464. # want the callback to be None. We don't have to check trace_rules or
  465. # create any wrapper.
  466. fn = innermost_fn(fn)
  467. if isinstance(fn, torch.nn.Module):
  468. mod = fn
  469. new_mod = OptimizedModule(mod, self)
  470. new_mod._torchdynamo_orig_callable = mod.forward
  471. return new_mod
  472. if inspect.isclass(fn):
  473. # User has wrapped the class with compile/disable decorator. Apply
  474. # disable to init/call method.
  475. cls_obj = fn
  476. # Disabling __init__ is useful for bytecode reconstruction, where we
  477. # want to prevent Dynamo from tracing into the init function. See
  478. # test_reconstruction in test_model_output.py.
  479. cls_obj.__init__ = self(cls_obj.__init__)
  480. cls_obj.__call__ = self(cls_obj.__call__)
  481. if issubclass(cls_obj, torch.nn.Module):
  482. # NN module variable tracker directly inlines the _call_impl. Disable it.
  483. cls_obj._call_impl = self(cls_obj._call_impl)
  484. return cls_obj
  485. assert callable(fn)
  486. callback = self.callback
  487. @functools.wraps(fn)
  488. def _fn(*args, **kwargs):
  489. prior = set_eval_frame(callback)
  490. try:
  491. return fn(*args, **kwargs)
  492. finally:
  493. set_eval_frame(prior)
  494. _fn._torchdynamo_disable = True # type: ignore[attr-defined]
  495. # Save the function pointer to find the original callable when
  496. # decorators are nested.
  497. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  498. return _fn
  499. def __reduce__(self):
  500. return (self.__class__, ())
  501. def _optimize_catch_errors(
  502. compile_fn,
  503. hooks: Hooks,
  504. backend_ctx_ctor=null_context,
  505. export=False,
  506. dynamic=None,
  507. compiler_config=None,
  508. rebuild_ctx=None,
  509. ):
  510. return OptimizeContext(
  511. convert_frame.catch_errors_wrapper(compile_fn, hooks),
  512. backend_ctx_ctor=backend_ctx_ctor,
  513. first_ctx=True,
  514. export=export,
  515. dynamic=dynamic,
  516. compiler_config=compiler_config,
  517. rebuild_ctx=rebuild_ctx,
  518. )
  519. def get_compiler_fn(compiler_fn):
  520. from .repro.after_dynamo import wrap_backend_debug
  521. if hasattr(compiler_fn, "compiler_name"):
  522. compiler_str = compiler_fn.compiler_name
  523. elif isinstance(compiler_fn, str):
  524. compiler_str = compiler_fn
  525. else:
  526. compiler_str = None
  527. compiler_fn = lookup_backend(compiler_fn)
  528. return wrap_backend_debug(compiler_fn, compiler_str)
  529. class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
  530. def __call__(self, fn):
  531. assert callable(fn)
  532. return fn
  533. def check_if_dynamo_supported():
  534. if sys.version_info >= (3, 13):
  535. raise RuntimeError("Python 3.13+ not yet supported for torch.compile")
  536. def is_dynamo_supported():
  537. try:
  538. check_if_dynamo_supported()
  539. return True
  540. except Exception:
  541. return False
  542. def check_if_inductor_supported():
  543. check_if_dynamo_supported()
  544. if sys.platform == "win32":
  545. raise RuntimeError("Windows not yet supported for inductor")
  546. def is_inductor_supported():
  547. try:
  548. check_if_inductor_supported()
  549. return True
  550. except Exception:
  551. return False
  552. def optimize(*args, **kwargs):
  553. def rebuild_ctx():
  554. return optimize(*args, **kwargs)
  555. return _optimize(rebuild_ctx, *args, **kwargs)
  556. def _optimize(
  557. rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
  558. backend="inductor",
  559. *,
  560. nopython=False,
  561. guard_export_fn=None,
  562. guard_fail_fn=None,
  563. disable=False,
  564. dynamic=None,
  565. ) -> Union[OptimizeContext, _NullDecorator]:
  566. """
  567. The main entrypoint of TorchDynamo. Performs graph capture and calls
  568. backend() to optimize the extracted graphs.
  569. Args:
  570. backend: One of two things:
  571. - Either, a function/callable taking a torch.fx.GraphModule and
  572. example_inputs and returning a python callable that runs the
  573. graph faster.
  574. One can also provide additional context for the backend, like
  575. torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
  576. See AOTAutogradMemoryEfficientFusionWithContext for the usage.
  577. - Or, a string backend name in `torch._dynamo.list_backends()`
  578. nopython: If True, graph breaks will be errors and there will
  579. be a single whole-program graph.
  580. disable: If True, turn this decorator into a no-op
  581. dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
  582. disable all dynamic shapes support (always specialize). If None, automatically
  583. detect when sizes vary and generate dynamic kernels upon recompile.
  584. Example Usage::
  585. @torch._dynamo.optimize()
  586. def toy_example(a, b):
  587. ...
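A sketch with an explicit backend and whole-graph capture (illustrative; the
backend string assumes inductor is available)::

    @torch._dynamo.optimize("inductor", nopython=True)
    def toy_example(a, b):
        return a * b + b.sum()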
  588. """
  589. check_if_dynamo_supported()
  590. # Note: The hooks object could be global instead of passed around, *however* that would make
  591. # for a confusing API usage and plumbing story when multiple .optimize calls are nested.
  592. # There is some prior art around this: with respect to nesting, backend calls are enforced to use the
  593. # same compiler. However, that feels onerous for callbacks and hooks, and it feels better to give our users an
  594. # easier-to-understand UX at the cost of a little more plumbing on our end.
  595. hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
  596. torch._C._log_api_usage_once("torch._dynamo.optimize")
  597. if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
  598. return _NullDecorator()
  599. backend = get_compiler_fn(backend)
  600. # Find if backend has any extra context manager
  601. backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
  602. if nopython:
  603. return optimize_assert(
  604. backend,
  605. dynamic=dynamic,
  606. hooks=hooks,
  607. rebuild_ctx=rebuild_ctx,
  608. )
  609. # The backend function is stashed in the callable returned by
  610. # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
  611. # be used by eval_frame.c to insert a guard on the backend.
  612. return _optimize_catch_errors(
  613. convert_frame.convert_frame(backend, hooks=hooks),
  614. hooks,
  615. backend_ctx_ctor,
  616. dynamic=dynamic,
  617. compiler_config=backend.get_compiler_config()
  618. if hasattr(backend, "get_compiler_config")
  619. else None,
  620. rebuild_ctx=rebuild_ctx,
  621. )
  622. # TODO(voz): Consider making "explain" output alongside a run / part of a run
  623. @patch("torch._dynamo.symbolic_convert.explain", True)
  624. def explain(f, *extra_args, **extra_kwargs):
  625. def inner(*args, **kwargs):
  626. # TODO(voz): Do we want a decorator for this?
  627. from . import reset # type: ignore[attr-defined]
  628. reset()
  629. graphs: List[torch.fx.GraphModule] = []
  630. break_reasons: List[Any] = []
  631. op_count: int = 0
  632. ops_per_graph: List[torch.fx.Node] = []
  633. out_guards: List[_guards.Guard] = []
  634. def dynamo_graph_accumulating_compiler(
  635. gm: torch.fx.GraphModule, example_inputs
  636. ):
  637. from .backends.debugging import _explain_graph_detail
  638. nonlocal graphs
  639. nonlocal op_count
  640. nonlocal ops_per_graph
  641. nonlocal break_reasons
  642. gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
  643. gm, graphs, op_count, ops_per_graph, break_reasons
  644. )
  645. return gm.forward
  646. def guard_export_print(guards):
  647. nonlocal out_guards
  648. out_guards.extend(guards)
  649. opt_f = optimize(
  650. dynamo_graph_accumulating_compiler,
  651. nopython=False,
  652. guard_export_fn=guard_export_print,
  653. )(f)
  654. # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
  655. opt_f(*args, **kwargs)
  656. graph_count = len(graphs)
  657. graph_break_count = graph_count - 1
  658. compile_time = compile_times(repr="str")
  659. # TODO(voz): Do we want a decorator for this?
  660. reset()
  661. from .backends.debugging import ExplainOutput
  662. return ExplainOutput(
  663. graphs,
  664. graph_count,
  665. graph_break_count,
  666. break_reasons,
  667. op_count,
  668. ops_per_graph,
  669. out_guards,
  670. compile_time,
  671. )
  672. if extra_args or extra_kwargs:
  673. warnings.warn(
  674. "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
  675. "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
  676. "conflict with future kwargs added to explain(f).",
  677. FutureWarning,
  678. stacklevel=2,
  679. )
  680. return inner(*extra_args, **extra_kwargs)
  681. else:
  682. return inner
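# A minimal sketch of using explain (illustrative; `fn` and the input are
# hypothetical). The data-dependent branch below would typically cause a
# graph break, which shows up in the returned ExplainOutput:
#
#     def fn(x):
#         if x.sum() > 0:
#             return x.cos()
#         return x.sin()
#
#     info = explain(fn)(torch.randn(8))
#     print(info.graph_count, info.graph_break_count, info.break_reasons)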
  683. class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
  684. def __init__(
  685. self,
  686. m: torch.fx.GraphModule,
  687. flat_args: Tuple[Any],
  688. matched_input_elements_positions: List[int],
  689. flat_results: List[Any],
  690. matched_output_elements_positions: List[int],
  691. example_fake_inputs: List[torch.Tensor],
  692. flat_args_dynamic_dims: List[Set[int]],
  693. fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
  694. ):
  695. super().__init__(m)
  696. assert len(flat_args_dynamic_dims) == len(flat_args)
  697. matched_input_elements_to_fake = {
  698. val: example_fake_inputs[ix]
  699. for ix, val in enumerate(matched_input_elements_positions)
  700. }
  701. self.new_args = []
  702. for i in range(0, len(flat_args)):
  703. arg = super().placeholder(f"arg{i}", (), {})
  704. if i in matched_input_elements_to_fake:
  705. arg.node.meta["val"] = matched_input_elements_to_fake[i]
  706. else:
  707. # Fill node.meta["val"] with a fake tensor from the input,
  708. # if it's not found in matched_input_elements_positions
  709. if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
  710. # TODO(zhxchen17) Also preserve all the user constraints here.
  711. arg.node.meta["val"] = fake_mode.from_tensor(
  712. flat_args[i],
  713. symbolic_context=StatelessSymbolicContext(
  714. dynamic_sizes=[
  715. DimDynamic.DYNAMIC
  716. if d in flat_args_dynamic_dims[i]
  717. else DimDynamic.STATIC
  718. for d in range(len(flat_args[i].shape))
  719. ],
  720. constraint_sizes=[None] * len(flat_args[i].shape),
  721. ),
  722. )
  723. self.new_args.append(arg)
  724. self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
  725. self.matched_output_elements_positions = matched_output_elements_positions
  726. self.flat_results = flat_results
  727. def placeholder(self, target, args, kwargs):
  728. arg = next(self.old_args_gen)
  729. if "val" in self.current_node.meta:
  730. arg.node.meta["val"] = self.current_node.meta["val"]
  731. if "tensor_dict" in self.current_node.meta:
  732. arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
  733. if "example_value" in self.current_node.meta:
  734. # NB: intentionally do not use set_example_value
  735. arg.node.meta["example_value"] = self.current_node.meta["example_value"]
  736. if "unbacked_bindings" in self.current_node.meta:
  737. arg.node.meta["unbacked_bindings"] = self.current_node.meta[
  738. "unbacked_bindings"
  739. ]
  740. return arg
  741. def output(self, target, args, kwargs):
  742. dynamo_result_flat = args[0]
  743. lookup = [*dynamo_result_flat, *self.new_args]
  744. new_results_flat = []
  745. for i in range(len(self.flat_results)):
  746. if self.matched_output_elements_positions[i] is not None:
  747. new_results_flat.append(
  748. lookup[self.matched_output_elements_positions[i]]
  749. )
  750. else:
  751. const_val = self.flat_results[i]
  752. assert isinstance(const_val, tuple(common_constant_types))
  753. new_results_flat.append(const_val)
  754. return super().output(target, (new_results_flat,), {})
  755. def run_node(self, n):
  756. self.current_node = n
  757. result_proxy = super().run_node(n)
  758. if "val" in self.current_node.meta:
  759. result_proxy.node.meta["val"] = self.current_node.meta["val"]
  760. if "example_value" in self.current_node.meta:
  761. # NB: intentionally do not use set_example_value
  762. result_proxy.node.meta["example_value"] = self.current_node.meta[
  763. "example_value"
  764. ]
  765. if "unbacked_bindings" in self.current_node.meta:
  766. result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
  767. "unbacked_bindings"
  768. ]
  769. if self.current_node.op != "output":
  770. result_proxy.node._rename(
  771. getattr(self.current_node, "name", result_proxy.node.name)
  772. )
  773. return result_proxy
  774. def transform(self):
  775. result_gm = super().transform()
  776. if "dynamo_flat_name_to_original_fqn" in self.module.meta:
  777. result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
  778. "dynamo_flat_name_to_original_fqn"
  779. ]
  780. return result_gm
  781. class ExportResult(NamedTuple):
  782. graph_module: torch.fx.GraphModule
  783. guards: _guards.GuardsSet
  784. # NB: Do not add new fields without overriding __iter__; people are
  785. # destructuring so it is BC-breaking
  786. def check_signature_rewritable(graph):
  787. input_errors = []
  788. for node in graph.graph.find_nodes(op="placeholder"):
  789. assert hasattr(node, "_dynamo_source")
  790. source = node._dynamo_source
  791. user_stacks = graph._source_to_user_stacks.get(source)
  792. if user_stacks is None:
  793. continue
  794. assert len(user_stacks) > 0
  795. # In some cases we may not have a useful stack; look for a
  796. # useful one.
  797. stack = None
  798. for s in user_stacks:
  799. if len(s) == 0:
  800. continue
  801. stack = s
  802. break
  803. if stack is None:
  804. msg = f"{source.name()}, a closed over free variable"
  805. else:
  806. tb = "".join(traceback.format_list(stack))
  807. extra = ""
  808. if len(user_stacks) > 1:
  809. extra = f"(elided {len(user_stacks) - 1} more accesses)"
  810. msg = f"{source.name()}, accessed at:\n{tb}{extra}"
  811. # TODO: option to print ALL of the stack traces at once
  812. input_errors.append(msg)
  813. if input_errors:
  814. raise UserError(
  815. UserErrorType.INVALID_INPUT,
  816. "Cannot export model which references tensors that are neither "
  817. "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd "
  818. "like this tensor to be an explicit input, add it as a dummy argument "
  819. "to the top-level model definition you are exporting; if you would "
  820. "like its value to be embedded as an exported constant, wrap its access "
  821. "in a function marked with @assume_constant_result.\n\n"
  822. + "\n\n".join(input_errors),
  823. )
  824. def rewrite_signature(
  825. f_sig,
  826. graph,
  827. fake_mode,
  828. flat_args,
  829. in_spec,
  830. example_fake_inputs,
  831. graph_captured_input,
  832. graph_captured_output,
  833. dynamo_traced_result,
  834. flat_args_dynamic_dims,
  835. ):
  836. orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
  837. def check_user_input_output(flat_values, error_type):
  838. supported_types = [
  839. torch.Tensor,
  840. torch.SymInt,
  841. torch.SymFloat,
  842. torch.SymBool,
  843. torch._C.ScriptObject,
  844. ] + list(common_constant_types)
  845. def is_supported_type(val):
  846. return isinstance(val, tuple(supported_types))
  847. value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
  848. # We only check that the outputs are not None. Inputs can be None.
  849. for v in flat_values:
  850. if not is_supported_type(v):
  851. if error_type == UserErrorType.INVALID_INPUT and v is None:
  852. continue
  853. raise UserError(
  854. error_type,
  855. f"It looks like one of the {value_type}s with type `{type(v)}` "
  856. "is not supported or pytree-flattenable. \n"
  857. f"Exported graphs {value_type}s can only contain the "
  858. f"following supported types: {supported_types}. \n"
  859. "If you are using a custom class object, "
  860. "please register a pytree_flatten/unflatten function "
  861. "using `torch.utils._pytree.register_pytree_node` or "
  862. "`torch.export.register_dataclass`.",
  863. )
  864. check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
  865. flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
  866. check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
  867. def produce_matching(debug_type, sources, candidates):
  868. matched_elements_positions: List[Optional[int]] = []
  869. dict_of_source_vals = {}
  870. for i, val in enumerate(sources):
  871. dict_of_source_vals[id(val)] = i
  872. for i, val in enumerate(candidates):
  873. if isinstance(val, tuple(common_constant_types)):
  874. matched_elements_positions.append(None)
  875. elif id(val) not in dict_of_source_vals:
  876. raise AssertionError(
  877. f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
  878. 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"'
  879. )
  880. else:
  881. matched_elements_positions.append(dict_of_source_vals[id(val)])
  882. return matched_elements_positions
  883. matched_input_elements_positions = produce_matching(
  884. "inputs", flat_args, graph_captured_input
  885. )
  886. assert graph_captured_output is not None
  887. matched_output_elements_positions = produce_matching(
  888. "outputs", list(graph_captured_output) + flat_args, flat_results_traced
  889. )
  890. new_graph = FlattenInputOutputSignature(
  891. graph,
  892. flat_args,
  893. matched_input_elements_positions,
  894. flat_results_traced,
  895. matched_output_elements_positions,
  896. example_fake_inputs,
  897. flat_args_dynamic_dims,
  898. fake_mode,
  899. ).transform()
  900. # Make the dynamo graph have the same input/output spec as the user code
  901. def argument_names(f_sig, args, kwargs) -> List[str]:
  902. def signature_to_fullargspec(sig: inspect.Signature):
  903. # Get a list of Parameter objects from the Signature object
  904. params = list(sig.parameters.values())
  905. # Separate positional arguments, keyword-only arguments and varargs/varkw
  906. args = [
  907. p.name
  908. for p in params
  909. if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  910. ]
  911. kwonlyargs = [
  912. p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
  913. ]
  914. varargs = next(
  915. (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
  916. None,
  917. )
  918. varkw = next(
  919. (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
  920. None,
  921. )
  922. # Get default values for positional arguments and keyword-only arguments
  923. defaults = tuple(
  924. p.default
  925. for p in params
  926. if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  927. and p.default is not inspect.Parameter.empty
  928. )
  929. kwonlydefaults = {
  930. p.name: p.default
  931. for p in params
  932. if p.kind == inspect.Parameter.KEYWORD_ONLY
  933. and p.default is not inspect.Parameter.empty
  934. }
  935. # Get annotations for parameters and return value
  936. annotations = {}
  937. if sig.return_annotation:
  938. annotations = {"return": sig.return_annotation}
  939. for parameter in params:
  940. annotations[parameter.name] = parameter.annotation
  941. # Return a FullArgSpec object with the extracted attributes
  942. return inspect.FullArgSpec(
  943. args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
  944. )
  945. fullargspec = signature_to_fullargspec(f_sig)
  946. # 1. Map `args` 1-to-1 to positional arguments in original signature.
  947. input_strs = fullargspec.args[: len(args)]
  948. if len(args) > len(fullargspec.args):
  949. # 2. If there are more arguments left in `args`, they map to varargs in original
  950. # signature. Assign names as {varargs}_0, {varargs}_1, ...
  951. assert fullargspec.varargs is not None, "More arguments than expected"
  952. input_strs += [
  953. f"{fullargspec.varargs}_{i}"
  954. for i in range(0, len(args) - len(input_strs))
  955. ]
  956. elif len(args) < len(fullargspec.args):
  957. # 3. If there are fewer arguments in `args` than `fullargspec.args`,
  958. # it implies these are arguments either with default values, or provided in
  959. `kwargs`. The former can be safely ignored, because Dynamo.export does not
  960. # export them as part of the function signature. The latter will be handled
  961. # in the next step.
  962. for unprovided_arg in fullargspec.args[
  963. len(args) : -len(fullargspec.defaults or [])
  964. ]:
  965. assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
  966. # 4. Keyword arguments provided in `kwargs`.
  967. input_strs += list(kwargs.keys())
  968. # 5. Keyword-only arguments with default values if not provided are not exported
  969. # as part of the function signature.
  970. for kwonly_arg in fullargspec.kwonlyargs:
  971. kwonlydefaults = fullargspec.kwonlydefaults or {}
  972. assert (
  973. kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
  974. ), f"Missing keyword only argument {kwonly_arg}"
  975. return input_strs
  976. new_graph.graph._codegen = _PyTreeCodeGen(
  977. _PyTreeInfo(
  978. argument_names(f_sig, orig_args, orig_kwargs),
  979. in_spec,
  980. out_spec_traced,
  981. )
  982. )
  983. new_graph.recompile()
  984. return new_graph
  985. def export(
  986. f: Callable[..., Any],
  987. *extra_args,
  988. aten_graph: bool = False,
  989. pre_dispatch: bool = False,
  990. decomposition_table: Optional[
  991. Dict[torch._ops.OpOverload, Callable[..., Any]]
  992. ] = None,
  993. tracing_mode: str = "symbolic",
  994. dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
  995. assume_static_by_default: bool = False,
  996. same_signature: bool = True,
  997. disable_constraint_solver: bool = False,
  998. prefer_deferred_runtime_asserts_over_guards: bool = False,
  999. _allow_complex_guards_as_runtime_asserts: bool = False,
  1000. _log_export_usage: bool = True,
  1001. **extra_kwargs,
  1002. ) -> Callable[..., ExportResult]:
  1003. """
  1004. Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
  1005. Args:
  1006. f (callable): A PyTorch function to be exported.
  1007. aten_graph (bool): If True, exports a graph with ATen operators.
  1008. If False, exports a graph with Python operators. Default is False.
  1009. pre_dispatch (bool): If True, exports a graph with ATen operators,
  1010. but before any logic in the PyTorch dispatcher has run.
  1011. This can be useful if you want to apply further transformations on a graph before running it
  1012. through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
  1013. This flag is only valid if aten_graph=True is set.
  1014. Default is False.
  1015. decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
  1016. Required if aten_graph or tracing_mode is specified. Default is None.
  1017. tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".
  1018. dynamic_shapes:
  1019. An optional argument where the type should either be:
  1020. 1) a dict from argument names of ``f`` to their dynamic shape specifications,
  1021. 2) a tuple that specifies dynamic shape specifications for each input in original order.
  1022. If you are specifying dynamism on keyword args, you will need to pass them in the order that
  1023. is defined in the original function signature.
  1024. The dynamic shape of a tensor argument can be specified as either
  1025. (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
  1026. not required to include static dimension indices in this dict, but when they are,
  1027. they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
  1028. where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
  1029. are denoted by None. Arguments that are dicts or tuples / lists of tensors are
  1030. recursively specified by using mappings or sequences of contained specifications.
  1031. same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.
  1032. disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.
  1033. Returns:
  1034. A function that given args and kwargs, returns a tuple of (graph, guards)
  1035. Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
  1036. Guards: The guards we accumulated during tracing f above
  1037. Raises:
  1038. AssertionError: If decomposition_table is specified without setting aten_graph=True,
  1039. or if graph breaks during tracing in export.
  1040. AssertionError: If Dynamo input and output is not consistent with traced input/output.
  1041. Note - this docstring was authored by ChatGPT, with slight modifications by the author.
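Example (a sketch of calling export with a dynamic batch dimension; the
function, input shape, and Dim name are illustrative)::

    def f(x):
        return x.relu()

    batch = torch.export.Dim("batch")
    graph_module, guards = torch._dynamo.export(
        f, dynamic_shapes={"x": {0: batch}}
    )(torch.randn(4, 3))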
  1042. """
  1043. if _log_export_usage:
  1044. log_export_usage(event="export.private_api", flags={"_dynamo"})
  1045. # Deal with "local variable referenced before assignment"
  1046. _f = f
  1047. _assume_static_by_default = assume_static_by_default
  1048. def inner(*args, **kwargs):
  1049. constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes)
  1050. f = _f
  1051. assume_static_by_default = _assume_static_by_default
  1052. check_if_dynamo_supported()
  1053. torch._C._log_api_usage_once("torch._dynamo.export")
  1054. if decomposition_table is not None:
  1055. assert (
  1056. aten_graph
  1057. ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
  1058. if pre_dispatch:
  1059. assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
  1060. f = innermost_fn(f)
  1061. call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
  1062. original_signature = inspect.signature(call_to_inspect)
  1063. graph = None
  1064. out_guards = None
  1065. graph_captured_input = None
  1066. graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
  1067. fake_mode = None
  1068. def guard_export_print(guards: _guards.GuardsSet):
  1069. nonlocal out_guards
  1070. assert (
  1071. out_guards is None
  1072. ), "whole graph export entails exactly one guard export"
  1073. out_guards = guards
  1074. example_inputs = []
  1075. def dynamo_normalization_capturing_compiler(
  1076. gm: torch.fx.GraphModule, inner_example_inputs
  1077. ):
  1078. nonlocal graph
  1079. assert (
  1080. graph is None
  1081. ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
  1082. graph = gm
  1083. nonlocal fake_mode, example_inputs
  1084. # NB: do NOT pass inner_example_inputs here, we are detecting the
  1085. # Dynamo allocated fake mode, which should be DISTINCT from a
  1086. # potential outer ambient fake mode which the user provided.
  1087. # example_inputs is always the user specified inputs, so they
  1088. # would have the wrong fake mode attached to them
  1089. fake_mode = _guards.detect_fake_mode()
  1090. example_inputs = inner_example_inputs
  1091. def result_capturing_wrapper(*graph_inputs):
  1092. nonlocal graph_captured_result
  1093. nonlocal graph_captured_input
  1094. graph_captured_input = graph_inputs
  1095. assert graph is not None
  1096. named_parameters = dict(graph.named_parameters(remove_duplicate=False))
  1097. named_buffers = dict(graph.named_buffers(remove_duplicate=False))
  1098. ambient_fake_mode = (
  1099. _guards.detect_fake_mode(graph_inputs)
  1100. if _guards.detect_fake_mode(graph_inputs) is not None
  1101. else fake_mode
  1102. )
  1103. # We reran fake tensor propagation, but we didn't do
  1104. # anything with the resulting unbacked SymInts. Drop them
  1105. # from the pending list.
  1106. # NB: this is wrong if graph_captured_result has
  1107. # data-dependent output size!
  1108. ignore_fresh_unbacked = null_context()
  1109. if shape_env := ambient_fake_mode.shape_env:
  1110. ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
  1111. with (
  1112. ambient_fake_mode
  1113. ), enable_python_dispatcher(), ignore_fresh_unbacked:
  1114. params_and_buffers = {
  1115. **named_parameters,
  1116. **named_buffers,
  1117. }
  1118. fake_params_buffers = dict()
  1119. for name, value in params_and_buffers.items():
  1120. fake_params_buffers[name] = ambient_fake_mode.from_tensor(
  1121. value, static_shapes=True
  1122. )
  1123. fake_graph_inputs = pytree.tree_map(
  1124. ambient_fake_mode.from_tensor, graph_inputs
  1125. )
  1126. graph_captured_result = torch.func.functional_call(
  1127. graph, fake_params_buffers, fake_graph_inputs
  1128. )
  1129. return graph_captured_result
  1130. return result_capturing_wrapper
  1131. # Note: This is needed by rewrite_signature. We need to put it before
  1132. # optimize_assert since the user program may mutate the inputs.
  1133. flat_args, in_spec = pytree.tree_flatten((args, kwargs))
  1134. remove_from_cache(f)
  1135. constraint_violation_error = None
  1136. if tracing_mode != "symbolic":
  1137. assume_static_by_default = True
  1138. with config.patch(
  1139. specialize_int=True,
  1140. assume_static_by_default=assume_static_by_default,
  1141. automatic_dynamic_shapes=False,
  1142. capture_dynamic_output_shape_ops=True,
  1143. capture_scalar_outputs=True,
  1144. prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
  1145. _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
  1146. ):
  1147. opt_f = optimize_assert(
  1148. dynamo_normalization_capturing_compiler,
  1149. hooks=Hooks(
  1150. guard_export_fn=guard_export_print,
  1151. guard_fail_fn=None,
  1152. ),
  1153. export=True,
  1154. export_constraints=constraints,
  1155. )(f)
  1156. # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
  1157. try:
  1158. result_traced = opt_f(*args, **kwargs)
  1159. except ConstraintViolationError as e:
  1160. constraint_violation_error = e
  1161. remove_from_cache(f)
  1162. if (
  1163. not disable_constraint_solver
  1164. and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
  1165. and (dim_constraints := shape_env.dim_constraints) is not None
  1166. and not isinstance(
  1167. call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
  1168. )
  1169. and not trace_rules.check(call_to_inspect)
  1170. ):
  1171. dim_constraints.solve()
  1172. dim_constraints.remove_redundant_dynamic_results()
  1173. forced_specializations = dim_constraints.forced_specializations()
  1174. msg = dim_constraints.prettify_results(
  1175. original_signature,
  1176. dynamic_shapes,
  1177. constraint_violation_error,
  1178. forced_specializations,
  1179. )
  1180. if constraint_violation_error:
  1181. constraint_violation_error.args = (
  1182. constraint_violation_error.args[0] + msg,
  1183. )
  1184. else:
  1185. if forced_specializations:
  1186. constraint_violation_error = ConstraintViolationError(msg)
  1187. else:
  1188. log.info(
  1189. "Summary of dimension constraints:%s",
  1190. msg,
  1191. )
  1192. # Error if we have any constraints on static values
  1193. for k in shape_env.var_to_range.keys():
  1194. if isinstance(k, sympy.Integer):
  1195. constraint_violation_error = ConstraintViolationError(
  1196. f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
  1197. "It appears that you're trying to set a constraint on a "
  1198. f"value which we evaluated to have a static value of {k}. "
  1199. 'Set TORCH_LOGS="+export" for more information.'
  1200. )
  1201. if constraint_violation_error:
  1202. raise constraint_violation_error
  1203. assert (
  1204. graph is not None
  1205. ), "Failed to produce a graph during tracing as no tensor operations were found."
  1206. assert hasattr(graph, "_source_to_user_stacks")
  1207. assert out_guards is not None, "Failed to produce guards during tracing"
  1208. assert fake_mode is not None
  1209. log.info(
  1210. "Dynamo captured graph:\n\n%s", graph.print_readable(print_output=False)
  1211. )
  1212. # This check needs to happen before aten_graph
  1213. # because the placeholder's _source_node attribute is not preserved by make_fx
  1214. if same_signature:
  1215. check_signature_rewritable(graph)
  1216. # NB: This is mostly hitting the cache; Dynamo already converted these
  1217. example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
  1218. if aten_graph:
  1219. # Running graph with interpreter is needed for propagating the stack_trace
  1220. def graph_with_interpreter(*args):
  1221. with torch.fx.traceback.preserve_node_meta():
  1222. return torch.fx.Interpreter(graph).run(*args)
  1223. with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
  1224. fake_mode
  1225. ):
  1226. try:
  1227. graph = make_fx(
  1228. graph_with_interpreter,
  1229. decomposition_table=decomposition_table,
  1230. tracing_mode="real",
  1231. _allow_non_fake_inputs=True,
  1232. pre_dispatch=pre_dispatch,
  1233. _allow_fake_constant=False,
  1234. )(*example_fake_inputs)
  1235. except CondOpArgsMismatchError as e:
  1236. # Wrap the internal error to the user-facing error
  1237. raise UserError( # noqa: B904
  1238. UserErrorType.DYNAMIC_CONTROL_FLOW,
  1239. str(e),
  1240. case_name="cond_operands",
  1241. )
  1242. assert graph is not None
  1243. for node in graph.graph.find_nodes(op="get_attr"):
  1244. if isinstance(getattr(graph, node.target), torch.Tensor):
  1245. node.meta["val"] = fake_mode.from_tensor(
  1246. getattr(graph, node.target), static_shapes=True
  1247. )
  1248. if same_signature:
  1249. flat_args_dynamic_dims = [
  1250. {c.dim for c in (constraints or ()) if c.w_tensor() is x}
  1251. for x in flat_args
  1252. ]
  1253. graph = rewrite_signature(
  1254. original_signature,
  1255. graph,
  1256. fake_mode,
  1257. flat_args,
  1258. in_spec,
  1259. example_fake_inputs,
  1260. graph_captured_input,
  1261. graph_captured_result,
  1262. result_traced, # type: ignore[possibly-undefined]
  1263. flat_args_dynamic_dims,
  1264. )
  1265. # Store constraints and inputs as metadata for user passes, e.g. turning constraints into runtime checks
  1266. assert graph is not None
  1267. graph.meta["input_shape_constraints"] = (
  1268. [constraint.serializable_spec for constraint in constraints]
  1269. if constraints
  1270. else []
  1271. )
  1272. return ExportResult(graph, out_guards)
  1273. if extra_args or extra_kwargs:
  1274. warnings.warn(
  1275. "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
  1276. "If you don't migrate, we may break your export call in the future if your user defined kwargs "
  1277. "conflict with future kwargs added to export(f).",
  1278. FutureWarning,
  1279. stacklevel=2,
  1280. )
  1281. return inner(*extra_args, **extra_kwargs)
  1282. else:
  1283. return inner
  1284. def optimize_assert(
  1285. backend,
  1286. *,
  1287. hooks=Hooks(None, None),
  1288. export=False,
  1289. export_constraints=None,
  1290. dynamic=None,
  1291. rebuild_ctx=None,
  1292. ):
  1293. """
  1294. The same as `torch._dynamo.optimize(backend, nopython=True)`
  1295. """
  1296. backend = get_compiler_fn(backend)
  1297. # Find if backend has any extra context manager
  1298. backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
  1299. return _optimize_catch_errors(
  1300. convert_frame.convert_frame_assert(
  1301. backend, export=export, export_constraints=export_constraints
  1302. ),
  1303. hooks,
  1304. backend_ctx_ctor,
  1305. export=export,
  1306. dynamic=dynamic,
  1307. rebuild_ctx=rebuild_ctx,
  1308. )
  1309. class TorchPatcher:
  1310. @staticmethod
  1311. @functools.lru_cache(None)
  1312. def patch():
  1313. # A better way to disable the following would be to decorate the source
  1314. # functions with @torch._disable_dynamo. However, this causes issues
  1315. # with torch.deploy internally.
  1316. from .decorators import disable
  1317. torch.jit.trace = disable(torch.jit.trace)
  1318. torch.jit.trace_module = disable(torch.jit.trace_module)
  1319. torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
  1320. torch.fx._symbolic_trace.Tracer.trace = disable(
  1321. torch.fx._symbolic_trace.Tracer.trace
  1322. )
  1323. torch.distributions.Distribution.set_default_validate_args(False)
  1324. from ..optim import (
  1325. adadelta,
  1326. adagrad,
  1327. adam,
  1328. adamax,
  1329. adamw,
  1330. asgd,
  1331. lbfgs,
  1332. nadam,
  1333. radam,
  1334. rmsprop,
  1335. rprop,
  1336. sgd,
  1337. sparse_adam,
  1338. )
  1339. optimizer_modules = {
  1340. adadelta,
  1341. adagrad,
  1342. adam,
  1343. adamax,
  1344. adamw,
  1345. asgd,
  1346. lbfgs,
  1347. nadam,
  1348. radam,
  1349. rmsprop,
  1350. rprop,
  1351. sgd,
  1352. sparse_adam,
  1353. }
  1354. for opt_mod in optimizer_modules:
  1355. opt_name = opt_mod.__name__.split(".")[-1]
  1356. fused_fn_name = f"_fused_{opt_name}"
  1357. single_tensor_fn_name = f"_single_tensor_{opt_name}"
  1358. if hasattr(opt_mod, fused_fn_name):
  1359. setattr(
  1360. opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
  1361. )
  1362. optimizer_classes = [
  1363. opt
  1364. for opt in torch.optim.__dict__.values()
  1365. if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
  1366. ]
  1367. # Note: we don't support sparsity or tracing through backwards
  1368. excluded_optimizer_classes = {
  1369. torch.optim.SparseAdam,
  1370. torch.optim.LBFGS,
  1371. }
  1372. for opt in optimizer_classes:
  1373. if opt in excluded_optimizer_classes:
  1374. opt.step = disable(opt.step)
  1375. if hasattr(opt, "_init_group"):
  1376. opt._init_group = disable(opt._init_group)
  1377. @staticmethod
  1378. def suppress_torch_distributed_warnings(fn):
  1379. def inner_fn(*args, **kwargs):
  1380. warnings.filterwarnings(
  1381. "ignore", category=UserWarning, module="torch.distributed"
  1382. )
  1383. return fn(*args, **kwargs)
  1384. return inner_fn