# mypy: ignore-errors
import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional

import torch._C
from torch._guards import Guard

from .. import variables
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)

@dataclasses.dataclass
class ContextMangerState:
    """
    Mutating `self` in VariableTracker is not allowed because we copy
    them. This is a mutable container pointed to by context managers
    that won't get copied, so it is safe to mutate.
    """

    cleanup_fn: Optional[Callable] = None
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None

    def cleanup_assert(self):
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()

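# Illustrative sketch (not part of the Dynamo API): the state object lets a
# context-manager VariableTracker record side effects even though the tracker
# itself may be copied during tracing. Roughly:
#
#     state = ContextMangerState()
#     vt_a = SomeCtxVariable(target_values=[True], state=state)  # hypothetical subclass
#     vt_b = copy_of(vt_a)                                       # hypothetical copy made while tracing
#     state.cleanup_fn = lambda: None
#     assert vt_a.state is vt_b.state    # both trackers observe the same mutation
#
# `SomeCtxVariable` and `copy_of` above are placeholders; only the shared,
# uncopied-state idea is the point.
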
class ContextWrappingVariable(VariableTracker):
    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextMangerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx, fn=None):
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        codegen(
            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
        )

    def reconstruct(self, codegen):
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
        self.reconstruct_type(codegen)
        target_values = self.target_values
        if not target_values:
            target_values = ()
        codegen.extend_output(
            [codegen.create_load_const(val) for val in target_values]
        )
        codegen.extend_output(create_call_function(len(target_values), False))

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)

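# Illustrative user code (a sketch, not a normative spec of this class): a
# subclass such as GradModeVariable ends up modeling both of these spellings:
#
#     with torch.no_grad():          # enter()/exit() are traced around the block
#         y = x * 2
#
#     @torch.no_grad()               # call_function() wraps the decorated function
#     def f(x):
#         return x * 2
#
# In the decorator form the single argument is the user function, which is why
# call_function() asserts len(args) == 1 and returns a Wrapped*Variable.
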
class GenericContextWrappingVariable(ContextWrappingVariable):
    def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
        assert cm_obj is not None
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.cm_obj = cm_obj

    def enter(self, tx):
        source = None if self.source is None else AttrSource(self.source, "__enter__")
        try:
            return variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(tx, [], {})
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function",
                from_exc=e,
            )

    def exit(self, tx, *args):
        source = None if self.source is None else AttrSource(self.source, "__exit__")
        try:
            x = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                variables.UserDefinedObjectVariable(self.cm_obj),
                source=source,
            ).call_function(
                tx,
                [
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                ],
                {},
            )
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function",
                from_exc=e,
            )

        tx.generic_context_manager_depth -= 1
        return x

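# Illustrative sketch of the kind of user object this class covers: an arbitrary
# Python context manager whose __enter__/__exit__ are inlined rather than
# modeled specially. The names below are made up for the example:
#
#     class ScopedFlag:
#         def __enter__(self):
#             self.prev = some_module.flag
#             some_module.flag = True
#         def __exit__(self, exc_type, exc_val, exc_tb):
#             some_module.flag = self.prev
#
#     with ScopedFlag():
#         ...
#
# The three ConstantVariable(None) arguments passed to __exit__ above mirror the
# (exc_type, exc_val, exc_tb) triple of a normal, exception-free exit.
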
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch._C._functorch.set_inplace_requires_grad_allowed()"""

    @staticmethod
    def create(tx, target_values, **kwargs):
        return GradInplaceRequiresGradCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [enabled] = self.target_values
        self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
        torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
                self.prev_state
            ),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (enabled,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)

class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.jvp increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph.
    # This is fine if jvp is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a jvp
    # call from eager that calls the compiled function, as the jvp levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, **kwargs):
        var = JvpIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
        self.set_cleanup_hook(
            tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._jvp_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(jvp_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class SetFwdGradEnabledContextManager(ContextWrappingVariable):
    """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""

    @staticmethod
    def create(tx, target_values, **kwargs):
        return SetFwdGradEnabledContextManager(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [mode] = self.target_values
        self.prev_state = torch._C._is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (mode,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)

class DualLevelContextManager(ContextWrappingVariable):
    """Represents torch.autograd.forward_ad.dual_level ctx manager"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)

    @staticmethod
    def create(tx, **kwargs):
        return DualLevelContextManager(
            target_values=None,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        self.new_level = torch.autograd.forward_ad.enter_dual_level()
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._enter_dual_level,
            (),
            {},
        )
        return variables.ConstantVariable.create(self.new_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._exit_dual_level,
            (self.new_level,),
            {},
        )
        return variables.ConstantVariable.create(None)


class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.grad increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph.
    # This is fine if grad is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a grad
    # call from eager that calls the compiled function, as the grad levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, **kwargs):
        var = GradIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        grad_level = torch._C._functorch._grad_increment_nesting()
        self.set_cleanup_hook(
            tx, lambda: torch._C._functorch._grad_decrement_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._grad_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(grad_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)

class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
    """Delay a call to warnings.catch_warnings"""

    @staticmethod
    def create(tx, catch_warnings_args):
        return CatchWarningsCtxManagerVariable(
            catch_warnings_args=catch_warnings_args,
            target_values=None,
            initial_values=None,
        )

    def __init__(self, catch_warnings_args, **kwargs):
        assert isinstance(catch_warnings_args, dict), catch_warnings_args
        super().__init__(**kwargs)
        self.catch_warnings_args = catch_warnings_args

    def enter(self, tx):
        kwargs = {
            k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
        }
        ctx_val = warnings.catch_warnings(**kwargs)
        self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
        return variables.ConstantVariable.create(ctx_val.__enter__())

    def reconstruct(self, cg):
        cg.load_import_from("warnings", "catch_warnings")
        cg.foreach(self.catch_warnings_args.values())
        keys = tuple(self.catch_warnings_args.keys())
        cg.extend_output(cg.create_call_function_kw(len(keys), keys, True))

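# Illustrative user code (sketch): the kwargs captured above come from calls
# such as
#
#     import warnings
#     with warnings.catch_warnings(record=True):
#         ...
#
# where each keyword value must be a Python constant so that
# as_python_constant() in enter() can recover it at trace time.
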
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch VMap increment/decrement nesting"""

    # A guard is needed as the vmap level is baked into the torch FX graph
    # generated. This is fine if vmap is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a vmap
    # call from eager that calls the compiled function, as the vmap levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx, target_values, **kwargs):
        var = VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        vmap_level = torch._C._functorch._vmap_increment_nesting(
            batch_size, randomness
        )
        self.set_cleanup_hook(
            tx, lambda: torch._C._functorch._vmap_decrement_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(vmap_level)

    def exit(self, tx, *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)

class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx, target_value, initialized=False, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            var._call_func(tx, var.target_values)
        return var

    def __init__(self, target_values, initial_values=None, initialized=True, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        # Coalesce grad mode mutations
        if torch.is_grad_enabled() != value:
            tx.output.create_node(
                "call_function", torch._C._set_grad_enabled, (value,), {}
            )
            torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"

class InferenceModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_value, **kwargs):
        var = InferenceModeVariable(
            [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ):
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"

class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx, **kwargs):
        var = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        var._call_func(tx, [False])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        tx.output.set_torch_function_state(values[0])

class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"

class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hooks."""

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[
                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
            ],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        if value is not None:
            # Disable `saved_tensors_hooks` with message (`value`)
            # OR
            # we are exiting this context and restoring the previous message.
            tx.output.create_node(
                "call_function",
                torch._C._autograd._saved_tensors_hooks_disable,
                (value,),
                {},
            )
            torch._C._autograd._saved_tensors_hooks_disable(value)
        else:
            # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"

class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # autocast signature:
        #   device_type: str,
        #   dtype: Optional[_dtype] = None,
        #   enabled: bool = True,
        #   cache_enabled: Optional[bool] = None
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]

            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx, *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )

    def enter(self, tx):
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"

class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"

class StreamContextVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx, target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        # stream generated inside the traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # stream passed from outside the traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(
            tx, lambda: self.set_stream(self.initial_values[0].value)
        )

    def exit(self, tx, *args):
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.state.cleanup_assert()

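# Illustrative sketch of the two branches in enter() above: a stream created
# inside the compiled region has a proxy and is switched to via set_stream on
# that proxy, while a pre-existing stream captured from outside the region is
# switched to by id (stream_id, device_index, device_type). For example, with
# the CUDA instances of the device-interface hooks used here:
#
#     s = torch.cuda.Stream()        # created outside the region -> set by id
#
#     @torch.compile
#     def f(x):
#         with torch.cuda.stream(s):
#             return x + 1
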
class PreserveVersionContextVariable(ContextWrappingVariable):
    """
    Wraps torch.autograd._unsafe_preserve_version_counter
    """

    @staticmethod
    def constructor(tx):
        return variables.LambdaVariable(
            lambda tensor: PreserveVersionContextVariable(
                tensor,
                tensor.var_getattr(tx, "_version"),
            )
        )

    def __init__(self, tensor, prev_version, **kwargs):
        kwargs.setdefault("target_values", None)
        super().__init__(**kwargs)
        self.tensor = tensor
        self.prev_version = prev_version

    def enter(self, tx):
        pass

    def exit(self, tx, *args):
        from ..tensor_version_op import _unsafe_set_version_counter

        return variables.TorchInGraphFunctionVariable(
            _unsafe_set_version_counter
        ).call_function(tx, [self.tensor, self.prev_version], {})

    def reconstruct(self, codegen):
        unimplemented(
            "torch.autograd._unsafe_preserve_version_counter with graph break"
        )

class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f"unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global.
        assert not self.source
        # Since we just proved that, reconstruction would be fine and sound for other
        # such structures (like lists and dicts) under Dynamo's usual principles for
        # treating collections. However, streams are special in that we want to preserve
        # the identity of the stream to be the same as in the graph. Normally we would do
        # this via codegen for the proxy mapping to an output, but we cannot do that yet,
        # as we do not yet have a plan for handling the case where the stream is used as
        # an input or an output. Pending that design, to unblock current work, we lift the
        # stream into a global and then codegen bytecode to load it from there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(
            codegen.create_load_global(name, push_null=False, add=True)
        )

class EventVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"event method {name} unsupported")

    def as_proxy(self):
        return self.proxy

class WithExitFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "target",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        self.ctx.reconstruct_type(codegen)

        if codegen.tx.output.partial_convert:
            if sys.version_info >= (3, 11):
                codegen.append_output(create_instruction("PUSH_NULL"))
                codegen.append_output(create_instruction("SWAP", arg=2))
            codegen.extend_output(
                [codegen.create_load_const(val) for val in self.ctx.target_values]
            )
            codegen.extend_output(
                create_call_function(len(self.ctx.target_values), False)
            )
            codegen.append_output(create_setup_with(self.target))
            codegen.append_output(create_instruction("POP_TOP"))

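# Illustrative sketch (hypothetical user code): when a graph break happens
# inside a with-block, e.g.
#
#     def f(x):
#         with torch.no_grad():
#             torch._dynamo.graph_break()   # triggers partial conversion
#             return x + 1
#
# WithExitFunctionVariable.reconstruct emits bytecode that rebuilds the
# context-manager object (here torch.no_grad()) and re-enters it via a
# setup-with instruction in the resume function, so the remainder of the block
# still runs under the manager after the break.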
|