_guards.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import enum
  6. import functools
  7. import logging
  8. import threading
  9. import traceback
  10. import unittest.mock
  11. import weakref
  12. from abc import abstractmethod
  13. from contextlib import contextmanager
  14. from typing import (
  15. Any,
  16. Callable,
  17. Dict,
  18. Generic,
  19. List,
  20. NamedTuple,
  21. Optional,
  22. Set,
  23. Tuple,
  24. TYPE_CHECKING,
  25. TypeVar,
  26. )
  27. from torch.utils import _pytree as pytree
  28. from torch.utils._traceback import CapturedTraceback
  29. from torch.utils.weak import WeakTensorKeyDictionary
  30. log = logging.getLogger(__name__)
  31. if TYPE_CHECKING:
  32. import sympy
  33. # Import the following modules during type checking to enable code intelligence features,
  34. # such as auto-completion in tools like pylance, even when these modules are not explicitly
  35. # imported in user code.
  36. import torch
  37. """
  38. torch._guards is the definitional source of truth for general purpose guard structures.
  39. An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
  40. and no guard installation notions here.
  41. """
  42. class CompileId(NamedTuple):
  43. frame_id: int
  44. # This id is per-frame, and counts how many times we've compiled this
  45. # frame. This could have been a global id but having this be per-frame
  46. # gives you a better intuitive sense for how many recompiles have occurred
  47. # so far.
  48. frame_compile_id: int
  49. # TODO: consider also tracking the recompilation count
  50. def __str__(self):
  51. return f"{self.frame_id}/{self.frame_compile_id}"
  52. class TraceId(NamedTuple):
  53. compile_id: CompileId
  54. # This starts off as 0, and every time we restart analysis it goes
  55. # up by one
  56. attempt: int
  57. def __str__(self):
  58. if self.attempt == 0:
  59. return str(self.compile_id)
  60. else:
  61. return f"{self.compile_id}_{self.attempt}"
  62. class GuardSource(enum.Enum):
  63. LOCAL = 0
  64. GLOBAL = 1
  65. LOCAL_NN_MODULE = 2
  66. GLOBAL_NN_MODULE = 3
  67. CONSTANT = 4
  68. RANDOM_VALUE = 5
  69. SHAPE_ENV = 6
  70. LOCAL_FSDP_MODULE = 7
  71. GLOBAL_FSDP_MODULE = 8
  72. BACKWARD_STATE = 9
  73. EPHEMERAL = 10
  74. SYNTHETIC_LOCAL = 11
  75. def is_fsdp_module(self) -> bool:
  76. return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
  77. def is_nn_module(self) -> bool:
  78. return (
  79. self
  80. in (
  81. GuardSource.GLOBAL_NN_MODULE,
  82. GuardSource.LOCAL_NN_MODULE,
  83. )
  84. or self.is_fsdp_module()
  85. )
  86. def is_local(self):
  87. return self in (
  88. GuardSource.LOCAL,
  89. GuardSource.LOCAL_NN_MODULE,
  90. GuardSource.LOCAL_FSDP_MODULE,
  91. )
"""
Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder, but it is kept for the sake of avoiding a lot of renames and keeping the
original reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""
  101. class GuardBuilderBase:
  102. pass
class ShapeGuard(NamedTuple):
    # Two-field record pairing a sympy expression with a captured traceback.
    # NOTE(review): presumably `stack` records where the guard expression was
    # produced - confirm against the code that constructs ShapeGuard.
    expr: sympy.Expr
    stack: CapturedTraceback
  106. @dataclasses.dataclass
  107. class Guard:
  108. # originating_source is the source that called the make_guard method to
  109. # construct this guard object. The property name specifies what exactly it
  110. # is the guard is guarding on. The meaning of the name is dependent on the
  111. # create_fn; you must look at the use-site inside create_fn to know what
  112. # name means.
  113. #
  114. # That being said, although you might think this is just a "name", name is
  115. # usually an arbitrary Python expression that will be evaluated with all
  116. # globals (and locals, if you create a LOCAL guard) to extract the Python
  117. # object that we want to perform guard tests on. This evaluation
  118. # typically happens in GuardBuilder.eval. In these cases, name is
  119. # typically produced by originating_source.name() (not to be confused with
  120. # GuardSource - the property source).
  121. #
  122. # Occasionally, name is not a valid Python expression; sometimes
  123. # it is meaningless. Example create_fns that are like this include
  124. # GRAD_MODE and SHAPE_ENV.
  125. originating_source: Source
  126. create_fn: Callable[[GuardBuilderBase, Guard], None]
  127. # Export only. These values are written to at time of guard check_fn creation.
  128. guard_types: Optional[List[str]] = None
  129. code_list: Optional[List[str]] = None
  130. obj_weakref: Optional[object] = None
  131. guarded_class_weakref: Optional[type] = None
  132. stack: Optional[CapturedTraceback] = None
  133. user_stack: Optional[traceback.StackSummary] = None
  134. _hash: Optional[int] = None
  135. def __hash__(self):
  136. if self._hash is None:
  137. self._hash = hash((self.name, self.source, id(self.create_fn)))
  138. return self._hash
  139. def sort_key(self):
  140. # Put the duplicate input guards at the end. The duplicate guards have
  141. # two sources while guard.name only considers one source.
  142. from ._dynamo.guards import GuardBuilder
  143. is_duplicate_input = (
  144. isinstance(self.create_fn, functools.partial)
  145. and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
  146. )
  147. return (
  148. is_duplicate_input,
  149. self.source.value if self.source else -1,
  150. len(self.name),
  151. self.name,
  152. self.inner_create_fn().__code__.co_firstlineno,
  153. )
  154. def __lt__(self, other):
  155. return self.sort_key() < other.sort_key()
  156. def inner_create_fn(self):
  157. if isinstance(self.create_fn, functools.partial):
  158. return self.create_fn.func
  159. else:
  160. return self.create_fn
  161. @property
  162. def name(self) -> str:
  163. return self.originating_source.name()
  164. @property
  165. def source(self) -> GuardSource:
  166. return self.originating_source.guard_source()
  167. @staticmethod
  168. def weakref_to_str(obj_weakref):
  169. """
  170. This is a workaround of a Python weakref bug.
  171. `obj_weakref` is instance returned by `weakref.ref`,
  172. `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
  173. class MyConfig(dict):
  174. def __getattr__(self, x):
  175. return self[x]
  176. obj = MyConfig(offset=5)
  177. obj_weakref = weakref.ref(obj)
  178. str(obj_weakref) # raise error: KeyError: '__name__'
  179. """
  180. if isinstance(obj_weakref, weakref.ReferenceType):
  181. obj = obj_weakref()
  182. if obj is not None:
  183. return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
  184. else:
  185. return f"<weakref at {hex(id(obj_weakref))}; dead>"
  186. else:
  187. return str(obj_weakref)
  188. def __repr__(self):
  189. s = f"""
  190. {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
  191. {{
  192. 'guard_types': {self.guard_types},
  193. 'code': {self.code_list},
  194. 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
  195. 'guarded_class': {self.guarded_class_weakref}
  196. }}
  197. """
  198. return s
  199. def __str__(self):
  200. output = f"Name: {repr(self.name)}\n"
  201. source = self.source.name.lower() if self.source else ""
  202. output += f" Source: {source}\n"
  203. output += f" Create Function: {self.inner_create_fn().__name__}\n"
  204. output += f" Guard Types: {self.guard_types}\n"
  205. output += f" Code List: {self.code_list}\n"
  206. output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
  207. output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
  208. return output
  209. def create(self, builder: GuardBuilderBase):
  210. try:
  211. return self.create_fn(builder, self)
  212. except Exception:
  213. log.exception("Error while creating guard:\n%s", str(self).rstrip())
  214. if self.stack:
  215. log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
  216. raise
  217. def is_nn_module(self):
  218. return self.source.is_nn_module()
  219. def is_fsdp_module(self):
  220. return self.source.is_fsdp_module()
  221. def is_local(self):
  222. return self.source.is_local()
  223. def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
  224. if not self.guard_types:
  225. self.guard_types = list()
  226. self.guard_types.append(guard_type)
  227. assert self.guarded_class_weakref in (
  228. guarded_class,
  229. None,
  230. ), "Guarded class id must be identical, or None"
  231. self.guarded_class_weakref = guarded_class
  232. if not self.code_list:
  233. self.code_list = code_list
  234. else:
  235. self.code_list.extend(code_list)
  236. # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
  237. # multiple guards on the same object, the weakref can die between the
  238. # invocation of set_export_info calls. So a dead weakref is also
  239. # acceptable.
  240. assert (
  241. self.obj_weakref
  242. in (
  243. obj_weakref,
  244. None,
  245. )
  246. or callable(self.obj_weakref)
  247. and self.obj_weakref() is None
  248. ), "Guarded object must be identical, None or ephemeral (dead weakref)"
  249. self.obj_weakref = obj_weakref
# Generic type parameter; used by Checkpointable[T] as the snapshot type.
T = TypeVar("T")
  251. """
  252. Parent structure for guard env expressions.
  253. A GuardEnvExpr can have any subtype.
  254. Note: All subtypes must be handled exhaustively in
  255. torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
  256. """
  257. @dataclasses.dataclass
  258. class GuardEnvExpr:
  259. pass
"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the two inputs we have deduped.
"""
  264. @dataclasses.dataclass
  265. class DuplicateInputs(GuardEnvExpr):
  266. input_source_a: Source
  267. input_source_b: Source
  268. def __post_init__(self):
  269. assert self.input_source_a != self.input_source_b
"""
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
In the future, it will have a closer coupling to a generic Checkpoint management system.
"""
  278. class Checkpointable(Generic[T]):
  279. @abstractmethod
  280. def copy_graphstate(self) -> T:
  281. ...
  282. @abstractmethod
  283. def restore_graphstate(self, state: T):
  284. ...
  285. class GuardsCheckpointState:
  286. """
  287. The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
  288. """
  289. dynamo_guards: Set[Guard] = set()
  290. def __init__(self, dynamo_guards):
  291. self.dynamo_guards = dynamo_guards
  292. def diff(self, other):
  293. """
  294. Produces a delta against another GuardsCheckpointState.
  295. Returns None if no delta is found, otherwise, return a set() of mismatched
  296. Guard type objects.
  297. """
  298. r = self.dynamo_guards.difference(other.dynamo_guards)
  299. if len(r) == 0:
  300. return None
  301. return r
  302. def __eq__(self, other):
  303. return self.diff(other) is None
  304. class ModuleContextCheckpointState:
  305. nn_modules: Dict[str, torch.nn.Module] = {}
  306. def __init__(self, nn_modules):
  307. self.nn_modules = nn_modules
  308. def diff(self, other):
  309. """
  310. Produces a delta against another ModuleContextCheckpointState.
  311. Returns None if no delta is found, otherwise, return a set() of mismatched
  312. module key names.
  313. """
  314. r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
  315. if len(r) == 0:
  316. return None
  317. return r
  318. def __eq__(self, other):
  319. return self.diff(other) is None
  320. class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
  321. def __init__(self):
  322. self.nn_modules: Dict[str, Any] = {}
  323. def copy_graphstate(self):
  324. return ModuleContextCheckpointState(dict(self.nn_modules))
  325. def restore_graphstate(self, state):
  326. assert isinstance(state, ModuleContextCheckpointState)
  327. self.nn_modules = state.nn_modules
  328. class GlobalContextCheckpointState:
  329. global_state: Dict[str, Tuple[Callable, ...]] = {}
  330. def __init__(self, global_states):
  331. self.global_state = global_states
  332. def diff(self, other):
  333. """
  334. Produces a delta against another GlobalContextCheckpointState.
  335. Returns None if no delta is found, otherwise, return a set() of mismatched
  336. global key names.
  337. """
  338. r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
  339. if len(r) == 0:
  340. return None
  341. return r
  342. def __eq__(self, other):
  343. return self.diff(other) is None
  344. class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
  345. """
  346. This keeps track of the global torch state during tracing of a function.
  347. For example, torch.is_grad_enabled.
  348. """
  349. _supported_global_states = {
  350. "grad_enabled",
  351. "torch_function_enabled",
  352. "autocast_enabled",
  353. "autocast_cpu_enabled",
  354. "autocast_gpu_dtype",
  355. "autocast_cpu_dtype",
  356. "autocast_cache_enabled",
  357. }
  358. def __init__(self):
  359. self.global_state: Dict[str, Tuple[Callable, ...]] = {}
  360. def copy_graphstate(self):
  361. return GlobalContextCheckpointState(dict(self.global_state))
  362. def restore_graphstate(self, state):
  363. assert isinstance(state, GlobalContextCheckpointState)
  364. self.global_state = state.global_state
  365. assert (
  366. len(self.global_state) == len(self._supported_global_states)
  367. and set(self.global_state.keys()) == self._supported_global_states
  368. ), "Global state mismatch"
  369. for func, args in self.global_state.values():
  370. func(args)
"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
"""
  377. # Like a Set[Guard] but will record the user stack on all guards at the
  378. # time they were installed at their destination
  379. class GuardsSet:
  380. def __init__(self, inner=None):
  381. if inner is None:
  382. inner = set()
  383. self.inner = inner
  384. def __iter__(self):
  385. return iter(self.inner)
  386. def __len__(self):
  387. return len(self.inner)
  388. # Subtraction along with bool is typically used to determine the delta of
  389. # added guards between checkpoints for higher order ops
  390. def __sub__(self, other):
  391. return GuardsSet(self.inner - other.inner)
  392. def __bool__(self):
  393. return bool(self.inner)
  394. def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
  395. if guard in self.inner:
  396. return
  397. if collect_debug_stack:
  398. if guard.stack is None:
  399. guard.stack = CapturedTraceback.extract(skip=1 + skip)
  400. if guard.user_stack is None:
  401. guard.user_stack = TracingContext.extract_stack()
  402. self.inner.add(guard)
  403. def update(self, *others: Set[Guard]):
  404. for o in others:
  405. for g in o:
  406. self.add(g, skip=1)
  407. def remove_guards_with_source(self, source):
  408. """Delete all guards with a given source"""
  409. self.inner = {g for g in self.inner if g.originating_source != source}
  410. class GuardsContext(Checkpointable[GuardsCheckpointState]):
  411. def __init__(self):
  412. self.dynamo_guards: GuardsSet = GuardsSet()
  413. self.aotautograd_guards: List[GuardEnvExpr] = []
  414. def copy_graphstate(self):
  415. return GuardsCheckpointState(set(self.dynamo_guards.inner))
  416. def restore_graphstate(self, state):
  417. # NB: "steals" the passed in state
  418. assert isinstance(state, GuardsCheckpointState)
  419. self.dynamo_guards = GuardsSet(state.dynamo_guards)
# Thread-local storage for the ambient `compile_context` and `tracing_context`
# attributes that the context classes and helpers in this module read and write.
_TLS = threading.local()
  421. """
  422. TracingContext is the source of truth for all currently accumulated information
  423. needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
  424. are open to managing their own TracingContext with that in mind.
  425. The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
  426. having to plumb complex subsystems across multiple verticals.
  427. Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
  428. Accessing the current tracing context via
  429. TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
  430. to plumb objects back up to where frame interpretation happened.
  431. Note that you can end up with multiple TracingContext for a single compilation
  432. of a frame, as we reset the TracingContext whenever we restart analysis.
  433. CompileContext is a more overarching context that encompasses multiple restarts.
  434. """
  435. class CompileContext:
  436. @staticmethod
  437. def get() -> CompileContext:
  438. assert _TLS.compile_context is not None
  439. return _TLS.compile_context
  440. @staticmethod
  441. def try_get() -> Optional[CompileContext]:
  442. return getattr(_TLS, "compile_context", None)
  443. def __init__(self, compile_id):
  444. assert compile_id is None or isinstance(compile_id, CompileId)
  445. self.compile_id: Optional[CompileId] = compile_id
  446. self.attempt = 0
  447. @staticmethod
  448. def current_compile_id():
  449. self = CompileContext.try_get()
  450. if self is None:
  451. return None
  452. return self.compile_id
  453. @staticmethod
  454. def current_trace_id():
  455. self = CompileContext.try_get()
  456. if self is None:
  457. return None
  458. if self.compile_id is None:
  459. return None
  460. return TraceId(self.compile_id, self.attempt)
  461. class TracingContext:
  462. """
  463. Provides the currently installed TracingContext, or None.
  464. Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
  465. will return None.
  466. """
  467. @staticmethod
  468. def try_get() -> Optional[TracingContext]:
  469. return getattr(_TLS, "tracing_context", None)
  470. @staticmethod
  471. def get() -> TracingContext:
  472. if ctx := TracingContext.try_get():
  473. return ctx
  474. raise RuntimeError(
  475. "TracingContext.get() must be called within an ongoing trace."
  476. )
  477. def __init__(self, fake_mode):
  478. self.guards_context = GuardsContext()
  479. self.module_context = ModuleContext()
  480. self.global_context = GlobalContext()
  481. self.fake_mode = fake_mode
  482. self.frame_summary_stack = []
  483. # This is morally part of frame_summary_stack, but it is kept separate
  484. # for clarity. As we process a frame, this variable gets updated
  485. # to keep track of what line we are in the function. We make a
  486. # function call, this gets cleared and the frame location is pushed
  487. # to frame_summary_stack (prepping this variable for the inner frame's
  488. # progress)
  489. self.loc_in_frame = None
  490. # this is only set after aot_autograd
  491. self.fw_metadata = None
  492. # this is only set after aot_autograd
  493. self.aot_graph_name = None
  494. self.params_flat = None
  495. # this is for extended return calling convention from backend
  496. # compiler to aot_autograd
  497. # Per output, what the compiler specified stride of the output is,
  498. # or None if no stride is known. This is always the HINT, it
  499. # is never a SymInt (it would be better if it was a SymInt, but
  500. # I can't conveniently get this from Inductor atm. Also, be
  501. # careful not to accidentally induce guards on the SymInt if
  502. # you ever do change this in aot_autograd.py; you should check
  503. # on permutations preferentially.)
  504. self.output_strides: Optional[List[Optional[List[int]]]] = None
  505. # When this is True, whenever we encounter an int in Dynamo tracing,
  506. # we will (1) force unspec it and (2) force it as a size-like unbacked
  507. # integer. This is currently used when processing certain lists of
  508. # ints that are known to be size-like and may have 0/1 entries that we
  509. # must not specialize on.
  510. self.force_unspec_int_unbacked_size_like = False
  511. # See note [Tensor Fakification and Symbol Caching]
  512. self.tensor_to_context = WeakTensorKeyDictionary()
  513. # If this true, Aot Autograd will return output Fake Tensors with appropiate
  514. # meta on the first invocation
  515. # see note: [Returning Fake Tensors on First AOT Autograd Call]
  516. self.fakify_first_call = False
  517. def clear(self):
  518. # Look at the note in output_graph.py in function `save_global_state`
  519. # for the context on clearing global context.
  520. self.global_context.global_state = {}
  521. @staticmethod
  522. @contextmanager
  523. def patch(**kwargs):
  524. prior = {}
  525. ctx = TracingContext.get()
  526. for key in kwargs.keys():
  527. # KeyError on invalid entry
  528. prior[key] = getattr(ctx, key)
  529. for key, val in kwargs.items():
  530. setattr(ctx, key, val)
  531. try:
  532. yield
  533. finally:
  534. for key, val in prior.items():
  535. setattr(ctx, key, val)
  536. @staticmethod
  537. def extract_stack():
  538. self = TracingContext.try_get()
  539. if self is None:
  540. return traceback.StackSummary()
  541. stack = self.frame_summary_stack
  542. if self.loc_in_frame is not None:
  543. stack = stack + [self.loc_in_frame]
  544. return traceback.StackSummary.from_list(stack)
  545. # Call this when you want to call into some code that isn't necessarily
  546. # associated with the current frame state
  547. @staticmethod
  548. @contextlib.contextmanager
  549. def clear_frame():
  550. tc = TracingContext.get()
  551. with unittest.mock.patch.object(
  552. tc, "frame_summary_stack", []
  553. ), unittest.mock.patch.object(tc, "loc_in_frame", None):
  554. try:
  555. yield
  556. except Exception as e:
  557. # Prevent real_stack from getting attached
  558. #
  559. # The invariant is that if an Exception as real_stack, we've
  560. # appropriately attached a user stack and we no longer need to
  561. # attach anything. Because we cannot conveniently interpose
  562. # when an exception is thrown, we instead interpose everywhere
  563. # we set what the user stack is set (using the context
  564. # manager). However, our compiler stack does "tail calls"
  565. # (when it calls into user compiler), at which point the
  566. # parent exception frames would incorrectly attach an
  567. # incorrect frame.
  568. #
  569. # However, if, somehow, someone raised an exception with this
  570. # scope that had a stack (for example, because they are
  571. # restoring the user stack state appropriately as they process
  572. # node by node), we should respect it. Thus, we cannot
  573. # unconditionally set None.
  574. if not hasattr(e, "real_stack"):
  575. e.real_stack = None # type: ignore[attr-defined]
  576. raise
  577. @staticmethod
  578. @contextlib.contextmanager
  579. def current_frame(frame_summary):
  580. # frame_summary can be None to solely take advantage of real_stack
  581. # attachment to thrown exceptions
  582. tc = TracingContext.get()
  583. if frame_summary is not None:
  584. tc.frame_summary_stack.append(frame_summary)
  585. old = tc.loc_in_frame
  586. tc.loc_in_frame = None
  587. try:
  588. yield
  589. except Exception as e:
  590. if not hasattr(e, "real_stack"):
  591. e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
  592. raise
  593. finally:
  594. if frame_summary is not None:
  595. tc.frame_summary_stack.pop()
  596. tc.loc_in_frame = old
  597. @staticmethod
  598. @contextlib.contextmanager
  599. def report_output_strides():
  600. tc = TracingContext.try_get()
  601. if tc is None:
  602. yield None
  603. return
  604. old_output_strides = tc.output_strides
  605. tc.output_strides = []
  606. try:
  607. yield tc.output_strides
  608. finally:
  609. tc.output_strides = old_output_strides
  610. @staticmethod
  611. def set_current_loc(filename, lineno, frame_name):
  612. TracingContext.get().loc_in_frame = traceback.FrameSummary(
  613. filename, lineno, frame_name, lookup_line=False
  614. )
  615. @contextmanager
  616. def compile_context(context: Optional[CompileContext]):
  617. old_context = getattr(_TLS, "compile_context", None)
  618. _TLS.compile_context = context
  619. try:
  620. yield context
  621. finally:
  622. _TLS.compile_context = old_context
  623. @contextmanager
  624. def tracing(context: Optional[TracingContext]):
  625. """
  626. This function installs the passed in tracing context as a dynamic scoped
  627. global variable.
  628. Calls to TracingContext.get() while not under a `with tracing()` context
  629. will return None.
  630. """
  631. old_context = getattr(_TLS, "tracing_context", None)
  632. _TLS.tracing_context = context
  633. try:
  634. yield context
  635. except Exception as e:
  636. if not hasattr(e, "real_stack") and context is not None:
  637. e.real_stack = context.extract_stack() # type: ignore[attr-defined]
  638. raise
  639. finally:
  640. if (
  641. context is not None
  642. and context.fake_mode is not None
  643. and context.fake_mode.shape_env is not None
  644. ):
  645. context.fake_mode.shape_env.cleanup()
  646. _TLS.tracing_context = old_context
  647. # Subclasses can be found in torch/_dynamo/source.py
  648. # TODO(voz): Consider a toplevel torch/_source.py
  649. @dataclasses.dataclass(frozen=True)
  650. class Source:
  651. def is_dict_key(self):
  652. return False
  653. def is_ephemeral(self):
  654. return False
  655. def reconstruct(self, codegen):
  656. raise NotImplementedError
  657. def guard_source(self) -> GuardSource:
  658. raise NotImplementedError
  659. def name(self) -> str:
  660. raise NotImplementedError
  661. def make_guard(self, fn) -> Guard:
  662. if self.guard_source() is GuardSource.CONSTANT:
  663. raise NotImplementedError
  664. return Guard(self, fn)
  665. def is_nn_module(self) -> bool:
  666. return self.guard_source().is_nn_module()
  667. def subguards_allowed(self):
  668. """True if you can guard on attributes of this"""
  669. return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
  670. # Subclasses can be found in torch/_dynamo/source.py
  671. @dataclasses.dataclass(frozen=True)
  672. class ChainedSource(Source):
  673. base: Source
  674. def is_dict_key(self):
  675. # Recurse until you either hit a ConstDictKey or a Source
  676. return self.base.is_dict_key()
  677. def is_ephemeral(self):
  678. return self.base.is_ephemeral()
  679. def detect_fake_mode(inputs: Any = None):
  680. """
  681. Attempts to "detect" what the current fake mode is. If there is one ambiently
  682. available from TracingContext, we preferentially use that. Otherwise, we
  683. heuristically detect the fake mode via the following sources, in order of
  684. priority:
  685. - Currently active fake mode on stack
  686. - Fake mode associated with passed in tensors (inputs does not
  687. have to be flattened)
  688. """
  689. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  690. fake_modes = []
  691. if context := TracingContext.try_get():
  692. fake_mode = context.fake_mode
  693. if fake_mode is not None:
  694. fake_modes.append((fake_mode, "tracing context", 0))
  695. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  696. for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  697. if isinstance(m, FakeTensorMode):
  698. fake_modes.append((m, "active fake mode", i))
  699. flat_inputs = pytree.tree_leaves(inputs)
  700. for i, flat_input in enumerate(flat_inputs):
  701. if isinstance(flat_input, FakeTensor):
  702. fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
  703. if fake_modes:
  704. fake_mode, desc1, i1 = fake_modes[0]
  705. for m, desc2, i2 in fake_modes[1:]:
  706. assert fake_mode is m, (
  707. f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
  708. f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
  709. f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
  710. )
  711. return fake_mode
  712. else:
  713. return None
  714. def active_fake_mode():
  715. """
  716. Inspects the dispatch mode stack for an active fake mode and returns it.
  717. Returns None if no fake mode is active.
  718. """
  719. from torch._subclasses.fake_tensor import FakeTensorMode
  720. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  721. for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  722. if isinstance(m, FakeTensorMode):
  723. return m
  724. return None