||
- # mypy: allow-untyped-defs
- import contextlib
- import platform
- import uuid
- import warnings
- import weakref
- from collections import defaultdict
- from itertools import count
- from typing import (
- Any,
- Callable,
- ContextManager,
- DefaultDict,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- )
- from weakref import ReferenceType
- import torch
- import torch.fx.traceback as fx_traceback
- from torch._functorch._aot_autograd.functional_utils import is_fun
- from torch.utils._pytree import tree_map
- from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
- from torch.utils._python_dispatch import TorchDispatchMode
# Names exported by ``from <this module> import *``.
__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
]

# Default value of the ``determinism_check`` argument of ``checkpoint``.
_DEFAULT_DETERMINISM_MODE = "default"
# Module-wide debug override installed by ``set_checkpoint_debug_enabled``.
# ``None`` means no override is active.
_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """
    Temporarily override whether checkpoint prints extra debug information.

    While the context is active, the given value takes precedence over the
    ``debug`` flag passed to :func:`~torch.utils.checkpoint.checkpoint` (see
    that flag for what is printed). Pass ``None`` to defer to the per-call
    ``debug`` value instead. The previous override is restored on exit, even
    if the body raises.

    Args:
        enabled (bool or None): Whether checkpoint should print debug
            information, or ``None`` to defer to the local setting.
    """
    global _checkpoint_debug_enabled
    saved = _checkpoint_debug_enabled
    _checkpoint_debug_enabled = enabled
    try:
        yield
    finally:
        _checkpoint_debug_enabled = saved
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    """Return a copy of ``inputs`` in which every tensor is detached.

    Each tensor is replaced by ``tensor.detach()`` with ``requires_grad``
    copied from the original, so the result is a fresh autograd leaf.
    Non-tensor entries are passed through unchanged.

    Args:
        inputs: A tuple of arbitrary values, some of which may be tensors.

    Returns:
        A tuple of the same length and order as ``inputs``.

    Raises:
        RuntimeError: If ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build one message string: passing the type name as a second
        # positional argument made the exception render as a tuple.
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            f"{type(inputs).__name__}"
        )

    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn if no tensor among ``inputs`` has ``requires_grad=True``.

    Non-tensor entries are ignored. Nothing is returned and nothing is raised;
    the only effect is a possible warning.
    """
    for candidate in inputs:
        if isinstance(candidate, torch.Tensor) and candidate.requires_grad:
            # One differentiable input is enough.
            return
    warnings.warn(
        "None of the inputs have requires_grad=True. Gradients will be None"
    )
def _get_device_module(device="cuda"):
    """Return the attribute of ``torch`` named ``device`` (e.g. ``torch.cuda``).

    Raises:
        AttributeError: If ``torch`` has no attribute with that name.
    """
    return getattr(torch, device)
class DefaultDeviceType:
    r"""
    Holder for the default device type used by checkpointing.

    The stored device type is the fallback used when no non-CPU tensors are
    present; it determines which device states are saved and restored for
    recomputation. It starts out as ``'cuda'``.
    """

    # Class-level state shared by every caller; accessed via the static methods.
    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Replace the default device type for checkpointing.

        Args:
            device (str): New default device type. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Report the default device type currently in effect for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type
def _infer_device_type(*args):
    """Choose the device type whose RNG state checkpoint should manage.

    Every leaf of ``args`` (walked with ``tree_map``) that is a non-CPU tensor
    contributes its ``device.type``. The result is:

    * the ``DefaultDeviceType`` default when no such tensor exists,
    * ``"cuda"`` when any of them is a CUDA tensor,
    * otherwise the type of the first non-CPU tensor encountered.

    A warning is emitted when more than one distinct device type is seen.
    """
    seen = []

    def _record(leaf):
        if isinstance(leaf, torch.Tensor) and leaf.device.type != "cpu":
            seen.append(leaf.device.type)

    tree_map(_record, args)

    distinct = set(seen)
    if len(distinct) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
            f"\nDevice types: {sorted(distinct)} first device type: {seen[0]}"
        )

    if not seen:
        return DefaultDeviceType.get_device_type()
    if "cuda" in distinct:
        return "cuda"
    return seen[0]
- # We can't know if the run_fn will internally move some args to different devices,
- # which would require logic to preserve rng states for those devices as well.
- # We could paranoically stash and restore ALL the rng states for all visible devices,
- # but that seems very wasteful for most cases. Compromise: Stash the RNG state for
- # the device of all Tensor args.
- #
- # To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """Capture the RNG state of every device holding a non-CPU tensor argument.

    Leaves of ``args`` are visited with ``tree_map``; CPU tensors and
    non-tensor values are skipped. The device module is chosen by
    ``_infer_device_type(*args)``.

    Returns:
        ``(device_ids, rng_states)``: two lists of equal length, where
        ``rng_states[k]`` was read while ``device_ids[k]`` was the active device.
    """
    device_ids = []

    def _collect(leaf):
        # The isinstance check short-circuits, so non-tensors never reach
        # ``.device``.
        if isinstance(leaf, torch.Tensor) and leaf.device.type != "cpu":
            device_ids.append(leaf.get_device())

    tree_map(_collect, args)

    backend = _get_device_module(_infer_device_type(*args))
    rng_states = []
    for dev_id in device_ids:
        with backend.device(dev_id):
            rng_states.append(backend.get_rng_state())

    return device_ids, rng_states
def set_device_states(devices, states) -> None:
    """Restore per-device RNG states.

    ``devices`` and ``states`` are consumed pairwise; each state is installed
    while its device is the active one. The device module is inferred from
    the device type of the ``states`` tensors.
    """
    backend = _get_device_module(_infer_device_type(*states))
    for dev, rng_state in zip(devices, states):
        with backend.device(dev):
            backend.set_rng_state(rng_state)
- def _get_autocast_kwargs(device="cuda"):
- if torch.amp.is_autocast_available(device):
- device_autocast_kwargs = {
- "enabled": torch.is_autocast_enabled(device),
- "dtype": torch.get_autocast_dtype(device),
- "cache_enabled": torch.is_autocast_cache_enabled(),
- }
- else:
- device_autocast_kwargs = None
- cpu_autocast_kwargs = {
- "enabled": torch.is_autocast_enabled('cpu'),
- "dtype": torch.get_autocast_dtype('cpu'),
- "cache_enabled": torch.is_autocast_cache_enabled(),
- }
- return device_autocast_kwargs, cpu_autocast_kwargs
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that re-runs ``run_function`` in ``backward``.

    ``forward`` executes ``run_function`` under ``torch.no_grad()`` and stores
    only the inputs (plus, optionally, RNG state and autocast settings).
    ``backward`` replays ``run_function`` on detached copies of those inputs
    with grad enabled and back-propagates the incoming gradients through the
    replayed graph.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: Autograd context; used to stash everything ``backward`` needs.
            run_function: Callable to checkpoint.
            preserve_rng_state: If truthy, record the CPU RNG state (and device
                RNG states when the device module is already initialized) so
                the replay in ``backward`` sees the same random numbers.
            *args: Positional inputs forwarded to ``run_function``.

        Returns:
            Whatever ``run_function(*args)`` returns.
        """
        # Warns (does not raise) when no tensor input requires grad.
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device)
            # Modules without an ``_initialized`` attribute are treated as
            # uninitialized, so their RNG state is not recorded.
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # Tensors go through save_for_backward; their positions were recorded
        # in ctx.tensor_indices above.
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Replay ``run_function`` with grad enabled and back-propagate ``args``.

        Args:
            ctx: The context populated by ``forward``.
            *args: Incoming gradients, one per output of ``forward``.

        Returns:
            ``(None, None, *grads)``: no gradient for ``run_function`` and
            ``preserve_rng_state``, followed by one entry per original input
            (``.grad`` of the detached copy for tensors, ``None`` otherwise).

        Raises:
            RuntimeError: If ``torch.autograd._is_checkpoint_valid()`` is false,
                or if no replayed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        # NOTE(review): ``device_module`` is not used again in this method.
        device_module = _get_device_module(ctx.device)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
            # Detached copies become new leaves, so their ``.grad`` can be
            # read after the inner backward below.
            detached_inputs = detach_variable(tuple(inputs))

            # ctx.device_autocast_kwargs is None when autocast is unavailable
            # for the device, hence the nullcontext fallback.
            device_autocast_ctx = torch.amp.autocast(
                device_type=ctx.device, **ctx.device_autocast_kwargs
            ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        # Normalize a single-tensor result to a tuple so it can be indexed
        # in parallel with the incoming gradients.
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        # The inner backward accumulates into ``.grad`` of the detached inputs.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        # Two leading Nones correspond to ``run_function`` and
        # ``preserve_rng_state`` in ``forward``'s signature.
        return (None, None) + grads
def noop_context_fn():
    """Return a pair of independent do-nothing context managers.

    This is the default value of ``checkpoint``'s ``context_fn`` argument.
    """
    first_ctx = contextlib.nullcontext()
    second_ctx = contextlib.nullcontext()
    return first_ctx, second_ctx
- # TorchDynamo does not step inside the utils.checkpoint function. The flow
- # looks like this:
- # 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
- # speculatively checking if the forward function is safe to trace.
- # 2) If yes, then Dynamo-generated Fx graph has the wrapped higher
- # order op. As a result, TorchDynamo does not look inside utils.checkpoint.
- # 3) If not, then TorchDynamo falls back to eager by performing a graph
- # break. And here, the following disable wrapper ensures that
- # TorchDynamo does not trigger again on the frames created by
- # utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an
        error being raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``True``, stash the RNG state
            before the forward pass and restore it for recomputation; pass
            ``False`` to omit stashing and restoring the RNG state during each
            checkpoint. Note that under torch.compile, this flag doesn't take
            effect and we always preserve RNG state.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``,
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`

    Raises:
        ValueError: If ``use_reentrant`` is true and keyword arguments other
            than ``preserve_rng_state`` are passed, or ``context_fn`` / ``debug``
            differ from their defaults.
    """
    if use_reentrant is None:
        # stacklevel=2 attributes the warning to the caller of checkpoint().
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.4 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants.",
            stacklevel=2
        )
        # An omitted argument currently falls back to the reentrant variant.
        use_reentrant = True

    # ``preserve_rng_state`` is accepted through **kwargs and removed here so
    # it is never forwarded to ``function``.
    preserve = kwargs.pop("preserve_rng_state", True)
    # Only the non-reentrant path forwards keyword arguments to ``function``.
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        # Identity checks: any non-default context_fn or debug value is rejected.
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        ret = function(*args, **kwargs)
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            # The generator finishing is the expected way out.
            return ret
        # NOTE(review): if the generator yields a second time, execution falls
        # through here and the function implicitly returns None.
def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model. Each chunk holds
            ``len(functions) // segments`` functions; the last chunk also
            takes the remainder and is run without checkpointing.
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): If ``True``, stash the RNG state
            before each checkpointed segment and restore it for
            recomputation; pass ``False`` to omit this.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Raises:
        ValueError: If keyword arguments other than ``preserve_rng_state`` are
            passed, or if a segment would contain no function (for example
            when ``segments`` exceeds the number of functions).

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        # stacklevel=2 points the warning at the caller, as checkpoint() does.
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.4 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants.",
            stacklevel=2,
        )
        use_reentrant = True

    # ``preserve_rng_state`` is the only keyword accepted through **kwargs.
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        # Build a callable applying functions[start..end] (inclusive) in order.
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    if segment_size == 0:
        # A zero step would otherwise surface from range() below as the
        # unhelpful "range() arg 3 must not be zero".
        raise ValueError(
            "checkpoint_sequential requires at least one function per segment, "
            f"but got {len(functions)} function(s) for segments={segments}"
        )

    # The last chunk has to be non-volatile: it is run outside checkpoint so
    # that its activations are kept.
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
def _internal_assert(cond):
    """Raise ``AssertionError`` with a bug-report message when ``cond`` is falsy.

    Unlike a bare ``assert`` statement, this check is not removed when Python
    runs with ``-O``.
    """
    if cond:
        return
    raise AssertionError(
        "Something went unexpectedly wrong in activation checkpoint. "
        "Please report this bug by filing an issue to PyTorch."
    )
- # NOTE [ Nestable Checkpoint ]
- #
- # The semantics of nested checkpoint can be defined by two basic rules.
- # Following the two rules leads to an important implication that is central
- # to motivating the design.
- #
- # Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
- # from any outer layers of checkpoint.
- #
- # Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
- # parent checkpoint.
- #
- # Implication: To recompute any given saved tensor, we need to recompute all of
- # the checkpoints wrapping it.
- #
- # Why is this implied? To unpack a saved tensor X during backward we need to
- # recompute the inner-most checkpoint (#1), and in order to recompute that
- # checkpoint I need to have its inputs, which are managed by that checkpoint's
- # parent (#2), which thus also needs to be recomputed first. Continue this line
- # of reasoning and we realize that in order to unpack X, all checkpoints that
- # were active at the time X was saved need to be recomputed. (unless we have
- # already done so in that backward for some other saved tensor).
- #
- # In practice, we use a noop autograd Function to save inputs as saved tensors.
- # During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
- # recompute.
- #
- # Rule 3. We should start recomputation as if there are no checkpoints currently
- # active. Checkpoints encountered during recomputation are still
- # respected.
- #
- # When we start recomputation, we push the saved variable hook meant for
- # recomputation on the stack. See examples in Rule 6 for more context.
- #
- # * * * *
- #
- # Beyond the basic semantics specific to nested checkpoint, we impose several
- # more constraints that may apply to checkpointing in general.
- #
- # Rule 4. Lifetime of recomputed tensors
- #
- # Recomputed tensors are considered specific to particular invocations
- # of backward and are always cleared immediately as they are unpacked.
- # Particularly, we require this to happen even if retain_graph=True.
- #
- # [ Implementation details of Rule 4 ]
- #
- # If we were okay with recomputed tensors staying alive after backward is run
- # with retain_graph=True, we would store recomputed variables as the values of a
- # WeakKeyDictionary and pack strong references to the keys, so that as we
- # backward, those packed keys would be cleared as long as retain_graph=False.
- # Clearing the packed key clears the corresponding entry in the WKD.
- #
- # If we wish recomputed variables to be immediately cleared as we unpack them in
- # the retain_graph=True case, we cannot rely on the packed keys to be cleared by
- # backward automatically. Instead of packing the strong reference to the key
- # directly, we pack a container object, which we manually clear as we unpack.
- #
- # An important detail is that if a second backward happens, the second
- # recomputation needs to reset the container with a newly created key.
- #
- # Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
- # know we need.
- #
- # [ Implementation details of Rule 5 ]
- #
- # During recomputation, raise an exception if the number of recomputed tensors
- # matches the number of tensors that we expected to recompute. We wrap the
- # recomputation call with a try-catch to catch this specific exception. See
- # Rule #6 below for some examples.
- #
- # Rule 6. We support doing backward inside checkpoint context
- #
- # [ retain_graph is True]
- #
- # def fn(x):
- # y = x.sin()
- # z = y.cos()
- # gx, = torch.autograd.grad(z, x, retain_graph=True)
- # return gx, z
- #
- # out = checkpoint(fn)(inp)
- # out.backward()
- #
- # Because z is saved by cos while checkpoint is enabled, it would not be
- # actually saved, and so the .grad() call inside must trigger a recomputation.
- #
- # During recomputation the "inner pack hook" has two responsibilities:
- #
- # 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
- # 2) Pack the actual tensor (detached) so that one may perform backward on the
- # recomputed graph. The tensors saved to this graph will live until the end
- # of recomputation, or die earlier if someone performs backward with
- # retain_graph=False.
- #
- # More generally performing backward on the recomputed graph occurs in the
- # following cases:
- # - If backward is performed inside forward,
- # - During the original forward IF early-stop is disabled
- # - During the original backward
- # - If there are multiple .grad()/.backward() calls, we would perform backward
- # on the recomputed graph even if early-stop is enabled (see the example below)
- #
- # [ retain_graph is False ]
- #
- # The example below shows what happens if during recomputation we find that some
- # of the tensors we are trying to recompute have already been cleared.
- #
- # Spoiler: we don't do anything special, we just skip over them!
- #
- # def fn(x):
- # y = x.sin() # (1)
- # z = y.cos() # (2)
- # gx, = torch.autograd.grad(z, x) # (3)
- # return x.cos() * gx # (4)
- #
- # out = checkpoint(fn)(inp)
- # out.backward() # (5)
- #
- # 1, 2. Don't save x and y since we are inside a checkpoint.
- # 3. Trigger a recompute of fn since x and y weren't saved.
- # And depending on whether early stop is enabled, either stop at (2) or
- # continue running the function.
- # Because we are running backward with retain_graph=False, we clear x and y's
- # holders.
- # 4. Don't save x since we are inside a checkpoint.
- # 5. Calling backward triggers another recompute of fn. During recompute, we see
- # that x and y have already been cleared in the original graph as indicated
- # by holder=None. We skip over them. We still save x at (4) (since its holder
- # is still alive.)
# Module-wide switch toggled by ``set_checkpoint_early_stop``; on by default.
_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation early.

    Non-reentrant checkpoint normally ends recomputation as soon as every
    needed Tensor has been computed. Use this context manager to turn that
    behavior off when it causes problems for your application.

    The context only has to be active while forward runs; it does not need to
    be active during backward. The previous setting is restored on exit.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    previous = _enable_checkpoint_early_stop
    _enable_checkpoint_early_stop = enable
    try:
        yield
    finally:
        _enable_checkpoint_early_stop = previous
class _Handle:
    """Stateless token object.

    Instances appear as values in ``_Holder.handles`` and as keys of the
    ``WeakKeyDictionary`` objects kept in ``_CheckpointFrame.recomputed``.
    """
class _Holder:
    """Container for a per-key mapping of ``_Handle`` objects."""

    def __init__(self):
        # int key -> handle; a value may also be None.
        self.handles: Dict[int, Optional[_Handle]] = {}
class _NoopSaveInputs(torch.autograd.Function):
    """No-op autograd function whose only job is to save its inputs.

    ``forward`` returns an empty tensor. ``setup_context`` stores tensor
    inputs via ``ctx.save_for_backward`` and everything else in a closure,
    exposed as ``ctx.get_args``, which rebuilds the argument list from
    ``ctx.saved_tensors``. ``backward`` is never expected to run.
    """

    @staticmethod
    def forward(*args):
        # The output carries no data; it only anchors the autograd node.
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        """Record ``inputs`` on ``ctx`` and attach ``ctx.get_args``."""
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx.
        # Built with comprehensions (not ``zip(*...)``) so that an input tuple
        # without tensors yields empty collections instead of an unpack error.
        tensor_indices = [
            i for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)
        ]
        tensors = [inputs[i] for i in tensor_indices]
        # Position in ``inputs`` -> position in ``ctx.saved_tensors``.
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            # Membership is tested on the dict: constant time per argument
            # rather than a linear scan of the index sequence.
            ret = [
                saved_tensors[idx2saved_idx[i]] if i in idx2saved_idx else o
                for i, o in enumerate(args)
            ]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")
class _CheckpointFrame:
    """Per-checkpoint bookkeeping for saved-tensor holders and recomputation.

    Holds the recompute callable, weak references to the holders created for
    tensors saved during the original forward, and, keyed by ``gid``, the
    tensors and counters produced by each recomputation.
    (``gid`` presumably identifies one backward/graph task; confirm at the
    call sites.)
    """

    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
        ] = defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop

        # Debugging
        self.metadata_fn = metadata_fn
        self.unpack_error_cb = unpack_error_cb
        # x_metadatas[i] is the metadata recorded for weak_holders[i].
        self.x_metadatas = []
        self.forward_completed = False
        self.ignore_saved_mismatch = False

    def check_recomputed_tensors_match(self, gid):
        """Verify that recomputation ``gid`` saved the same tensors as forward.

        Returns ``None`` when the check passes or is disabled through
        ``self.ignore_saved_mismatch``.

        Raises:
            CheckpointError: If the number of tensors saved differs between
                forward and recomputation, or if any still-alive saved tensor
                has different metadata (per ``self.metadata_fn``) after
                recomputation.
        """
        if self.ignore_saved_mismatch:
            # TODO: we can probably make this check stricter by checking that
            #       the metadata of the first tensors still match.
            return

        # NOTE [ Error handling for checkpoint ]
        #
        # At a high level, we need to check that the tensors saved
        # during original forward matches tensors saved during recompute.
        # This means handling 3 cases:
        #
        # 1. During recompute, more tensors were saved.
        #
        #    Usually this is hidden due to the StopRecomputationError
        #    but if early stop is not enabled, or we would have errored
        #    anyway because there aren't enough weak_holders. But we
        #    do want to have a nice error. See the _recomputation_hook
        #    for details.
        #
        # 2. During recompute, fewer tensors were saved.
        #
        #    Every time we save something during the original forward we
        #    append to weak_holders, and every time we save a tensor during
        #    recompute we increment recomp_counter, so comparing the two
        #    detects this case.
        num_saved = len(self.weak_holders)
        num_recomputed = self.recomp_counter[gid]
        if num_saved != num_recomputed:
            raise CheckpointError(
                "torch.utils.checkpoint: A different number of tensors was saved "
                "during the original forward and recomputation.\n"
                f"Number of tensors saved during forward: {num_saved}\n"
                f"Number of tensors saved during recomputation: {num_recomputed}"
            )

        # 3. During recompute, the same tensors were saved, but they
        #    have different metadata.
        nb_meta_different = []
        for idx, weak_holder in enumerate(self.weak_holders):
            holder = weak_holder()
            if holder is None:
                # Holder already released; nothing left to compare.
                continue
            # We've seen all holders since we iterate over them in order
            # For every holder that is still alive now, it must've been
            # alive when we saw it during recompute, therefore, the
            # gid must be set.
            _internal_assert(gid in holder.handles)
            handle = holder.handles[gid]
            # We know this is the first unpack, so it couldn't have been set
            # to None yet.
            _internal_assert(handle is not None)
            # We always set these together in the recomputation hook
            recomputed_for_gid = self.recomputed[gid]
            _internal_assert(handle in recomputed_for_gid)
            # see pack hook, x_metadata is 1:1 with weak_holders.
            x_meta = self.x_metadatas[idx]
            # Evaluate metadata_fn once so that the value reported below is
            # exactly the one that failed the comparison.
            recomputed_meta = self.metadata_fn(recomputed_for_gid[handle])
            if x_meta != recomputed_meta:
                nb_meta_different.append((idx, x_meta, recomputed_meta))

        if nb_meta_different:
            mismatched_tensors = "".join(
                f"tensor at position {idx}:\n"
                f"saved metadata: {x_meta}\n"
                f"recomputed metadata: {recomputed_meta}\n"
                for idx, x_meta, recomputed_meta in nb_meta_different
            )
            raise CheckpointError(
                "torch.utils.checkpoint: Recomputed values for the following tensors "
                "have different metadata than during the forward pass.\n"
                f"{mismatched_tensors}"
            )
- _checkpoint_error_template = """ \
- An error happened while unpacking tensors; dumping logs of latest computation
- because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
- Scroll all the way down for guidance on how to navigate these logs.
- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
- | 1. Stack traces of the operators that ran in the original forward |
- +------------------------------------------------------------------------------+
- {forward_traces}
- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
- | 2. Stack traces of the operators that ran during recomputation |
- +------------------------------------------------------------------------------+
- {recompute_traces}
- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
- | 3. Log of operators in the original forward and recomputation |
- +------------------------------------------------------------------------------+
- (Scroll up to correlate stack traces with each operation listed below. This
- helps identify their source in the code.)
- IMPORTANT: Differences in "detach" calls between the original forward and the
- recomputation are expected. They are introduced by the checkpointing
- mechanism and can be ignored.
- Operations executed during the original forward:
- {forward_ops}
- Operations executed during recomputation:
- {recompute_ops}
- +------------------------------------------------------------------------------+
- ERROR: Detected non-determinism while running activation checkpointing
- You are seeing this error because you passed `debug=True` to checkpoint and
- tensors to be saved during the original forward and differ between those saved
- during recomputation. This can happen if different operators were ran in the
- original forward and in the recomputation.
- To identify where the mismatch may be coming from, you can do the following:
- 1) Compare the operators ran during original forward and recomputation to
- see where they differ. These operators are printed above in the order they
- were executed.
- 2) Review the stack trace for each operator to locate its invocation source.
- Each operator's stack trace is printed in their execution order.
- Note that the logs can be quite long. Here's how they are structured:
- (Tip: you can Ctrl-f for these headers)
- 1. Stack traces of the operators that ran in the original forward
- 2. Stack traces of the operators that ran during recomputation
- 3. Log of operators in the original forward and recomputation
- 4. Error message <--- You are here
- --------------------------------------------------------------------------------
- """
class CheckpointError(RuntimeError):
    """Error raised when activation checkpointing detects an inconsistency.

    For example, it is raised when the tensors saved during recomputation do
    not match (in number or metadata) those saved during the original forward.
    """
def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
    """Build the ``context_fn`` / ``error_cb`` pair used when debugging.

    Returns:
        A ``(context_fn, unpack_error_cb)`` tuple. ``context_fn()`` returns two
        context managers (forward, recompute) that record the operators run
        and their stack traces. ``unpack_error_cb`` is invoked when an error
        is detected during unpack and re-raises it as a
        :class:`CheckpointError` that embeds the recorded logs.
    """
    # record_context_cpp is not support on non-linux non-x86_64 platforms
    on_x86_64 = platform.machine() == 'x86_64'
    cpp_tb = on_x86_64 and platform.system() == 'Linux'

    class CaptureLogs:
        # Holds the operator logs and tracebacks captured by one context.
        def __init__(self):
            self.logs = None
            self.tbs = None

        def get_context_manager(self):
            @contextlib.contextmanager
            def logging_mode():
                with LoggingTensorMode():
                    with capture_logs(
                        True, python_tb=True, script_tb=True, cpp_tb=cpp_tb
                    ) as logs_and_tb:
                        self.logs, self.tbs = logs_and_tb
                        yield logs_and_tb

            return logging_mode()

    fwd_capture = CaptureLogs()
    recompute_capture = CaptureLogs()

    def unpack_error_cb(e: CheckpointError):
        def render_traces(label, captured):
            num_ops = len(captured.logs)
            pieces = []
            for position, (op_log, frames) in enumerate(
                zip(captured.logs, captured.tbs), start=1
            ):
                pieces.append(f"{op_log} ({position} of {num_ops} in {label})\n\n")
                seen_torch_dispatch = False
                for frame in frames:
                    # Start printing stack trace only after __torch_dispatch__
                    # is found; the __torch_dispatch__ frames themselves are
                    # never printed.
                    if frame['name'] == '__torch_dispatch__':
                        seen_torch_dispatch = True
                    elif seen_torch_dispatch:
                        pieces.append(
                            f"{frame['filename']}:{frame['line']}:{frame['name']}\n"
                        )
                pieces.append("\n\n")
            return "".join(pieces)

        assert fwd_capture.logs is not None
        assert recompute_capture.logs is not None
        report = _checkpoint_error_template.format(
            forward_traces=render_traces("original", fwd_capture),
            recompute_traces=render_traces("recompute", recompute_capture),
            forward_ops="\n".join(fwd_capture.logs),
            recompute_ops="\n".join(recompute_capture.logs)
        )
        raise CheckpointError(report) from e

    def context_fn():
        forward_ctx = fwd_capture.get_context_manager()
        recompute_ctx = recompute_capture.get_context_manager()
        return forward_ctx, recompute_ctx

    return context_fn, unpack_error_cb
- def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
- # These properties are fast to check, easy to understand
- return {
- "shape": x.shape,
- "dtype": x.dtype,
- "device": x.device
- }
# Maps each supported ``determinism_check`` mode to the function that extracts
# the metadata compared between the original forward and the recomputation.
# The "none" mode maps every tensor to ``None``, so all comparisons succeed.
_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
    "none": lambda _tensor: None,
}
- # See Rule 5
- class _StopRecomputationError(Exception):
- pass
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks that are active while the forward is being recomputed.

    The pack hook pairs the i-th tensor saved during recomputation with the
    i-th holder recorded in ``target_frame.weak_holders`` and stores a
    detached copy in ``target_frame.recomputed[gid]``. The unpack hook is the
    identity.

    Args:
        target_frame_ref: Weak reference to the ``_CheckpointFrame`` being
            recomputed.
        gid: Graph task id under which the recomputed tensors are stored.
    """

    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            # Index of this save within the recomputation; the counter is
            # bumped before any of the checks below run.
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            # More tensors are being saved than were recorded in forward.
            if recomp_idx >= len(target_frame.weak_holders):
                assert not target_frame.early_stop
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and do
                    # grad within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x.detach()
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass."
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

            # With early stop enabled, abort the recomputation as soon as the
            # last recorded holder has been visited.
            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError
            # See Rule 6: [ retain_graph is True ] above
            return x.detach()

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)
- class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
- def __init__(self, frame):
- def pack_hook(x):
- # See Rule 4 above
- holder = _Holder()
- frame.weak_holders.append(weakref.ref(holder))
- # Save metadata to detect non-determinism
- if frame.metadata_fn is not None:
- with torch.no_grad():
- frame.x_metadatas.append(frame.metadata_fn(x))
- return holder
- def unpack_hook(holder):
- gid = torch._C._current_graph_task_id()
- if gid == -1:
- # generate a temporary id if we trigger unpack outside of a backward call
- gid = int(uuid.uuid4())
- if not frame.is_recomputed[gid]:
- ctx = frame.input_saver.grad_fn
- args = ctx.get_args(ctx.saved_tensors)
- try:
- with _recomputation_hook(
- weakref.ref(frame), gid
- ), torch.autograd.enable_grad():
- frame.recompute_fn(*args)
- except _StopRecomputationError:
- pass
- frame.is_recomputed[gid] = True
- frame.check_recomputed_tensors_match(gid)
- _internal_assert(gid in holder.handles)
- if holder.handles[gid] is None:
- raise CheckpointError(
- "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
- "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
- "so only once. Otherwise please open an issue with details on your use case."
- )
- _internal_assert(holder.handles[gid] in frame.recomputed[gid])
- ret = frame.recomputed[gid][holder.handles[gid]]
- holder.handles[gid] = None
- return ret
- if frame.unpack_error_cb is not None:
- def unpack_hook_with_error_cb(holder):
- try:
- return unpack_hook(holder)
- except CheckpointError as e:
- frame.unpack_error_cb(e)
- super().__init__(pack_hook, unpack_hook_with_error_cb)
- else:
- super().__init__(pack_hook, unpack_hook)
- def _is_compiling(func, args, kwargs):
- # Check if we are under AOTAutograd tracing
- # There should probably be a better way to do this...
- # TODO: unify _is_compiling across all compile stacks
- for arg in args:
- if isinstance(arg, torch.Tensor) and is_fun(arg):
- return True
- return False
- def _detach(x):
- if isinstance(x, torch.Tensor):
- return x.detach()
- return x
- uid = count(1)
- # NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times
- # (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call),
- # so we ignore these ops and just always recompute them.
- _ignored_ops = {
- torch.ops.prim.device.default,
- torch.ops.aten.detach.default,
- } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
- class _CachingTorchDispatchMode(TorchDispatchMode):
- r"""
- A :class:`TorchDispatchMode` to implement selective activation checkpointing
- that's compatible with torch.compile. Used together with _CachedTorchDispatchMode.
- """
- def __init__(self, policy_fn, storage):
- self.policy_fn = policy_fn
- self.storage = storage
- def push_into_storage(self, out, func, args, kwargs):
- out_detached = tree_map(_detach, out)
- self.storage[func].append(out_detached)
- def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs):
- if should_not_recompute:
- fx_traceback.current_meta["recompute"] = 0
- # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode
- # we decide and handle recomputation in the partitioner.
- out = func(*args, **kwargs)
- self.push_into_storage(out, func, args, kwargs)
- return out
- def __torch_dispatch__(self, func, types, args=(), kwargs=None):
- if kwargs is None:
- kwargs = {}
- if func in _ignored_ops:
- return func(*args, **kwargs)
- should_not_recompute = self.policy_fn("forward", func, *args, **kwargs)
- if _is_compiling(func, args, kwargs):
- return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs)
- else:
- if should_not_recompute:
- out = func(*args, **kwargs)
- self.push_into_storage(out, func, args, kwargs)
- else:
- out = func(*args, **kwargs)
- return out
- class _CachedTorchDispatchMode(TorchDispatchMode):
- r"""
- A :class:`TorchDispatchMode` to implement selective activation checkpointing
- that's compatible with torch.compile. Used together with _CachingTorchDispatchMode.
- """
- def __init__(self, policy_fn, storage):
- self.policy_fn = policy_fn
- self.storage = storage
- def pop_from_storage(self, func, args, kwargs):
- assert func in self.storage
- out = self.storage[func].pop(0)
- return out
- def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs):
- out = self.pop_from_storage(func, args, kwargs)
- return out
- def __torch_dispatch__(self, func, types, args=(), kwargs=None):
- if kwargs is None:
- kwargs = {}
- if func in _ignored_ops:
- return func(*args, **kwargs)
- should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs)
- if _is_compiling(func, args, kwargs):
- return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs)
- else:
- if should_not_recompute:
- out = self.pop_from_storage(func, args, kwargs)
- else:
- out = func(*args, **kwargs)
- return out
def _pt2_selective_checkpoint_context_fn_gen(policy_fn):
    """
    A helper function that generates a pair of contexts to be later passed into
    `torch.utils.checkpoint` API to implement selective checkpointing.

    .. warning::
        This context_fn is intended for use with torch.compile only.

    Args:
        policy_fn (Callable[..., bool]): Policy function to decide whether a
            particular op should be recomputed in backward pass or not. It is
            called as ``policy_fn(mode, func, *args, **kwargs)``, where ``mode``
            is ``"forward"`` or ``"recompute"``, ``func`` is the op being
            dispatched and ``args`` / ``kwargs`` are the op's arguments.
            In eager mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op is guaranteed to be recomputed.
            In torch.compile mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op may or may not be recomputed
                (it's up to the partitioner to decide).

    Returns:
        A pair of generated contexts: the first one is meant for the original
        forward, the second one for the recomputation. Both share the same
        storage of cached op outputs.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>>
        >>> def get_custom_policy():
        >>>     no_recompute_list = [
        >>>         torch.ops.aten.mm.default,
        >>>     ]
        >>>     def custom_policy(mode, func, *args, **kwargs):
        >>>         return func in no_recompute_list
        >>>     return custom_policy
        >>>
        >>> def selective_checkpointing_context_fn():
        >>>     return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy())
        >>>
        >>> def gn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> def fn(x, y):
        >>>     return torch.utils.checkpoint.checkpoint(
        >>>         gn, x, y,
        >>>         use_reentrant=False,
        >>>         context_fn=selective_checkpointing_context_fn,
        >>>     )
        >>>
        >>> x = torch.randn(4, 4, requires_grad=True)
        >>> y = torch.randn(4, 4, requires_grad=True)
        >>>
        >>> compiled_fn = torch.compile(fn)
    """
    # Both modes must see the same storage: the forward mode fills it and the
    # recompute mode drains it.
    shared_storage: Dict[Any, List[Any]] = defaultdict(list)
    forward_mode = _CachingTorchDispatchMode(policy_fn, shared_storage)
    recompute_mode = _CachedTorchDispatchMode(policy_fn, shared_storage)
    return forward_mode, recompute_mode
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
#     saving/restoring of global state is handled here.
def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    This is a generator that yields exactly once. ``fn`` is not called here
    during the forward pass: the code driving the generator is expected to run
    it while the generator is suspended at the ``yield``, at which point the
    checkpoint saved-tensor hooks and the forward context are active (unless
    grad mode is disabled, in which case nothing is installed). ``fn`` is only
    invoked by this function from ``recompute_fn`` during recomputation.

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``True``, stash the RNG state
            before the forward pass and restore it for the recomputation.
            Set to ``False`` to omit stashing and restoring the RNG state
            during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation. The module-level debug override, when
            it is not ``None``, takes precedence over this argument.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.

    Yields:
        ``None``, once.

    Raises:
        ValueError: If debugging is enabled together with a non-default
            ``context_fn``, or if ``determinism_check`` is not a supported
            value.
        RuntimeError: If the device state was initialized during the forward
            pass while ``preserve_rng_state`` is ``True``.
    """
    unpack_error_cb = None

    # The global override wins over the per-call `debug` flag when it is set.
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device = _infer_device_type(*args)
    device_module = _get_device_module(device)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        # The first input is the kwargs dict; the remaining ones are the
        # positional arguments.
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states)

            device_autocast_ctx = torch.amp.autocast(
                device_type=device, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
    # The dummy tensor requires grad, so the saver output gets a grad_fn
    # whenever grad mode is enabled.
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
|