utils.py 975 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. from typing import Tuple, Union
  4. import torch
  5. from torch._C._functorch import (
  6. get_single_level_autograd_function_allowed,
  7. set_single_level_autograd_function_allowed,
  8. unwrap_if_dead,
  9. )
  10. from torch.utils._exposed_in import exposed_in
# Public API of this module, i.e. the names bound by ``from <module> import *``.
# ``exposed_in`` is not defined here; it is imported from
# ``torch.utils._exposed_in`` and re-exported.
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
  17. @contextlib.contextmanager
  18. def enable_single_level_autograd_function():
  19. try:
  20. prev_state = get_single_level_autograd_function_allowed()
  21. set_single_level_autograd_function_allowed(True)
  22. yield
  23. finally:
  24. set_single_level_autograd_function_allowed(prev_state)
  25. def unwrap_dead_wrappers(args):
  26. # NB: doesn't use tree_map_only for performance reasons
  27. result = tuple(
  28. unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
  29. )
  30. return result
# Type alias: either a single ``int`` or a tuple of ``int``s.
# NOTE(review): going by the name, this presumably annotates ``argnums``
# parameters that select argument positions; confirm against callers.
argnums_t = Union[int, Tuple[int, ...]]