ctx_manager.py
# mypy: ignore-errors
import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional

import torch._C
from torch._guards import Guard

from .. import variables
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)


@dataclasses.dataclass
class ContextMangerState:
    """
    Mutating `self` in VariableTracker is not allowed because we copy
    them. This is a mutable container pointed to by context managers
    that won't get copied, so it is safe to mutate.
    """

    cleanup_fn: Optional[Callable] = None
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None

    def cleanup_assert(self):
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()


class ContextWrappingVariable(VariableTracker):
    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextMangerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx, fn=None):
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        codegen(
            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
        )

    def reconstruct(self, codegen):
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
        self.reconstruct_type(codegen)
        target_values = self.target_values
        if not target_values:
            target_values = ()
        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
        codegen.extend_output(create_call_function(len(target_values), False))

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)
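
# Illustrative sketch (added commentary, not part of the upstream module): subclasses
# typically follow the same shape -- flip some global torch state eagerly in enter(),
# register a cleanup hook so the state is restored when tracing finishes or aborts, and
# record the same mutation as a node in the FX graph so the compiled code replays it.
# A hypothetical subclass (getter/setter names are placeholders) looks roughly like:
#
#     class MyFlagVariable(ContextWrappingVariable):
#         def enter(self, tx):
#             [value] = self.target_values
#             prev = some_torch_flag_getter()               # hypothetical getter
#             some_torch_flag_setter(value)                 # hypothetical setter
#             self.set_cleanup_hook(tx, lambda: some_torch_flag_setter(prev))
#             self.state.proxy = tx.output.create_node(
#                 "call_function", some_torch_flag_setter, (value,), {}
#             )
#             return variables.ConstantVariable.create(None)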


class GenericContextWrappingVariable(ContextWrappingVariable):
    def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
        assert cm_obj is not None
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.cm_obj = cm_obj

    def enter(self, tx):
        source = None if self.source is None else AttrSource(self.source, "__enter__")
        try:
            return variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(tx, [], {})
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function",
                from_exc=e,
            )

    def exit(self, tx, *args):
        source = None if self.source is None else AttrSource(self.source, "__exit__")
        try:
            x = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(
                tx,
                [
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                ],
                {},
            )
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function",
                from_exc=e,
            )

        tx.generic_context_manager_depth -= 1
        return x
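
# Illustrative example (added commentary): GenericContextWrappingVariable is the fallback
# for arbitrary user-defined context managers, whose __enter__/__exit__ are traced as
# ordinary user methods. User code along these lines is intended to take this path:
#
#     import torch
#
#     class MyCtx:
#         def __enter__(self):
#             return self
#         def __exit__(self, *exc):
#             return False
#
#     @torch.compile
#     def f(x):
#         with MyCtx():
#             return x + 1
#
# If either dunder cannot be traced, the Unsupported exception is converted into an
# unimplemented() graph break above.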


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch grad requires grad"""

    @staticmethod
    def create(tx, target_values, **kwargs):
        return GradInplaceRequiresGradCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [enabled] = self.target_values
        self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
        torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
                self.prev_state
            ),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (enabled,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.jvp increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph
    # This is fine if jvp is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a jvp
    # call from eager that calls the compiled function, as the jvp levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, **kwargs):
        var = JvpIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
        self.set_cleanup_hook(
            tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._jvp_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(jvp_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class SetFwdGradEnabledContextManager(ContextWrappingVariable):
    """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""

    @staticmethod
    def create(tx, target_values, **kwargs):
        return SetFwdGradEnabledContextManager(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [mode] = self.target_values
        self.prev_state = torch._C._is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (mode,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


class DualLevelContextManager(ContextWrappingVariable):
    """Represents torch.autograd.forward_ad.dual_level ctx manager"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)

    @staticmethod
    def create(tx, **kwargs):
        return DualLevelContextManager(
            target_values=None,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        self.new_level = torch.autograd.forward_ad.enter_dual_level()
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._enter_dual_level,
            (),
            {},
        )
        return variables.ConstantVariable.create(self.new_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._exit_dual_level,
            (self.new_level,),
            {},
        )
        return variables.ConstantVariable.create(None)
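
# Illustrative example (added commentary): forward-mode AD dual levels entered inside a
# compiled region are modeled by DualLevelContextManager; roughly:
#
#     import torch
#     import torch.autograd.forward_ad as fwAD
#
#     @torch.compile
#     def f(x, t):
#         with fwAD.dual_level():
#             y = fwAD.make_dual(x, t)
#             return y * y
#
# The level is entered/exited eagerly during tracing, matching _enter_dual_level /
# _exit_dual_level nodes are recorded in the FX graph, and the DUAL_LEVEL guard protects
# against the level changing between compile time and run time.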


class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.grad increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph
    # This is fine if grad is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a grad
    # call from eager that calls the compiled function, as the grad levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, **kwargs):
        var = GradIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        grad_level = torch._C._functorch._grad_increment_nesting()
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._grad_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(grad_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)
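
# Illustrative example (added commentary): functorch grad transforms traced inside a
# compiled function bump the functorch interpreter stack, which this variable mirrors:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         return torch.func.grad(lambda y: (y * y).sum())(x)
#
# The FUNCTORCH_STACK_MATCH guard (see the comment above) rejects the compiled graph if
# the eager-side functorch stack differs at call time.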


class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
    """Delay a call to warnings.catch_warnings"""

    @staticmethod
    def create(tx, catch_warnings_args):
        return CatchWarningsCtxManagerVariable(
            catch_warnings_args=catch_warnings_args,
            target_values=None,
            initial_values=None,
        )

    def __init__(self, catch_warnings_args, **kwargs):
        assert isinstance(catch_warnings_args, dict), catch_warnings_args
        super().__init__(**kwargs)
        self.catch_warnings_args = catch_warnings_args

    def enter(self, tx):
        kwargs = {
            k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
        }
        ctx_val = warnings.catch_warnings(**kwargs)
        self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
        return variables.ConstantVariable.create(ctx_val.__enter__())

    def reconstruct(self, cg):
        cg.load_import_from("warnings", "catch_warnings")
        cg.foreach(self.catch_warnings_args.values())
        keys = tuple(self.catch_warnings_args.keys())
        cg.extend_output(cg.create_call_function_kw(len(keys), keys, True))
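
# Illustrative example (added commentary): warnings.catch_warnings() used inside a
# compiled function is re-created eagerly at trace time with the constant kwargs captured
# in catch_warnings_args; roughly:
#
#     import warnings
#     import torch
#
#     @torch.compile
#     def f(x):
#         with warnings.catch_warnings(record=True):
#             return x + 1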


class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch VMap increment/decrement nesting"""

    # A guard is needed as the vmap level is baked into the torch FX graph
    # generated. This is fine if vmap is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a vmap
    # call from eager that calls the compiled function, as the vmap levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, target_values, **kwargs):
        var = VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(vmap_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)
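
# Illustrative example (added commentary): a vmap call traced inside a compiled function
# increments the functorch vmap nesting level, which this variable records and replays:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         return torch.vmap(torch.sin)(x)
#
# (batch_size, randomness) come from the traced call; the same FUNCTORCH_STACK_MATCH
# guard as for grad/jvp protects against stale vmap levels.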


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_mode}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx, target_value, initialized=False, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            var._call_func(tx, var.target_values)
        return var

    def __init__(self, target_values, initial_values=None, initialized=True, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        # Coalesce grad mode mutations
        if torch.is_grad_enabled() != value:
            tx.output.create_node(
                "call_function", torch._C._set_grad_enabled, (value,), {}
            )
            torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"
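
# Illustrative example (added commentary): torch.no_grad / torch.enable_grad /
# torch.set_grad_enabled inside a compiled region are modeled by GradModeVariable, e.g.:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         with torch.no_grad():
#             return x * 2
#
# The mode flip is applied eagerly during tracing and, when it actually changes the
# current mode, also recorded as a torch._C._set_grad_enabled node in the FX graph
# (see the "Coalesce grad mode mutations" check in _call_func).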


class InferenceModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_value, **kwargs):
        var = InferenceModeVariable(
            [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ):
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"
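
# Illustrative example (added commentary): torch.inference_mode() in compiled code is
# handled by InferenceModeVariable, e.g.:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         with torch.inference_mode():
#             return x + 1
#
# enter() records _enter_inference_mode in the graph and keeps the returned proxy in
# self.state.proxy so that exit() can pass it to _exit_inference_mode.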


class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx, **kwargs):
        var = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        var._call_func(tx, [False])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        tx.output.set_torch_function_state(values[0])


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"
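
# Illustrative example (added commentary): toggling deterministic algorithms inside a
# compiled region is modeled by DeterministicAlgorithmsVariable; roughly:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         torch.use_deterministic_algorithms(True)
#         return x + 1
#
# The flag is flipped eagerly at trace time, mirrored as a _set_deterministic_algorithms
# node in the graph, and restored via the cleanup hook installed in create().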


class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hooks."""

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[
                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
            ],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        if value is not None:
            # Either disable `saved_tensors_hooks` with the given message (`value`),
            # or we are exiting this context and restoring the previous message.
            tx.output.create_node(
                "call_function",
                torch._C._autograd._saved_tensors_hooks_disable,
                (value,),
                {},
            )
            torch._C._autograd._saved_tensors_hooks_disable(value)
        else:
            # We are exiting this context and, because the previous message was None,
            # we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"
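
# Illustrative example (added commentary): torch.autograd.graph.disable_saved_tensors_hooks
# used inside a compiled region maps onto DisabledSavedTensorsHooksVariable; roughly:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         with torch.autograd.graph.disable_saved_tensors_hooks("not allowed here"):
#             return x.sin().cos()
#
# A None value in _call_func means "re-enable", which is how the previous state is
# restored when the prior error message was None.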


class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # autocast signature:
        # device_type : str,
        # dtype : Optional[_dtype] = None,
        # enabled : bool = True,
        # cache_enabled : Optional[bool] = None
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]

            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )

    def enter(self, tx):
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"
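
# Illustrative example (added commentary): all three autocast entry points accepted in
# create() normalize to the same (device_type, dtype, enabled, cache_enabled) tuple;
# roughly:
#
#     import torch
#
#     @torch.compile
#     def f(x):
#         with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
#             return x @ x
#
# At trace time the autocast state is entered eagerly via torch.amp._enter_autocast and
# the same call is recorded in the FX graph so the compiled code re-enters it.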


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class StreamContextVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        # stream generated inside the traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # stream passed from outside the traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))

    def exit(self, tx, *args):
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.state.cleanup_assert()
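
# Illustrative example (added commentary): device stream context managers (e.g.
# torch.cuda.stream) are modeled by StreamContextVariable, which swaps the current
# stream both eagerly and in the FX graph; roughly:
#
#     import torch
#
#     s = torch.cuda.Stream()
#
#     @torch.compile
#     def f(x):
#         with torch.cuda.stream(s):
#             return x + 1
#
# A stream created outside the compiled region (as above) is set by id via
# _set_stream_by_id, while a stream created inside the traced code is set through its
# proxy (the if/else in enter()).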


class PreserveVersionContextVariable(ContextWrappingVariable):
    """
    Wraps torch.autograd._unsafe_preserve_version_counter
    """

    @staticmethod
    def constructor(tx):
        return variables.LambdaVariable(
            lambda tensor: PreserveVersionContextVariable(
                tensor,
                tensor.var_getattr(tx, "_version"),
            )
        )

    def __init__(self, tensor, prev_version, **kwargs):
        kwargs.setdefault("target_values", None)
        super().__init__(**kwargs)
        self.tensor = tensor
        self.prev_version = prev_version

    def enter(self, tx):
        pass

    def exit(self, tx, *args):
        from ..tensor_version_op import _unsafe_set_version_counter

        return variables.TorchInGraphFunctionVariable(
            _unsafe_set_version_counter
        ).call_function(tx, [self.tensor, self.prev_version], {})

    def reconstruct(self, codegen):
        unimplemented(
            "torch.autograd._unsafe_preserve_version_counter with graph break"
        )


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f"unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global.
        assert not self.source
        # Since we just proved that, reconstruction would normally be fine and sound for
        # other such structures (lists, dicts) according to Dynamo's principles for
        # treating collections. However, streams are special: we want the reconstructed
        # stream to preserve the identity of the stream used in the graph. Normally we
        # would do this via codegen mapping the proxy to an output, but we cannot do that
        # yet, as we do not have a plan for handling streams used as inputs or outputs.
        # Pending that design, to unblock current work, we lift the stream into a global
        # and then codegen bytecode to load it from there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(
            codegen.create_load_global(name, push_null=False, add=True)
        )


class EventVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"event method {name} unsupported")

    def as_proxy(self):
        return self.proxy


class WithExitFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "target",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        self.ctx.reconstruct_type(codegen)
        if codegen.tx.output.partial_convert:
            if sys.version_info >= (3, 11):
                codegen.append_output(create_instruction("PUSH_NULL"))
                codegen.append_output(create_instruction("SWAP", arg=2))
            codegen.extend_output(
                [codegen.create_load_const(val) for val in self.ctx.target_values]
            )
            codegen.extend_output(
                create_call_function(len(self.ctx.target_values), False)
            )
            codegen.append_output(create_setup_with(self.target))
            codegen.append_output(create_instruction("POP_TOP"))