external_utils.py 3.0 KB

# mypy: allow-untyped-defs
# This module contains functions that *will be allowed* by dynamo

import functools
from typing import List

import torch
import torch.utils._pytree as pytree

try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    If you need to check specifically whether TorchDynamo is being used, use
    torch.compiler.is_dynamo_compiling().

    TODO(khabinov): we should deprecate this function and use one of these two:
    * torch.compiler.is_compiling(),
    * torch.compiler.is_dynamo_compiling().
    Which one to use will depend on the context.
    """
    return torch.compiler.is_compiling()
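
# Illustrative usage (not part of the original module, hypothetical names): a
# typical pattern is to guard Python side effects that should only run in
# eager mode, e.g.:
#
#     def forward(self, x):
#         if not is_compiling():
#             print("eager forward, shape:", x.shape)
#         return self.linear(x)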


def wrap_inline(fn):
    """
    Create an extra frame around fn that is not in skipfiles.
    """

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def call_hook(hook, *args):
    """
    Used by compiled autograd to handle hooks that return None.
    """
    result = hook(*args)
    if result is None:
        return args[0]
    return result
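
# Illustrative sketch (assumption, not from the original file): a hook that
# only logs returns None, so call_hook falls back to the unmodified first
# argument; a hook that returns a tensor replaces it.
#
#     grad = torch.ones(3)
#     call_hook(lambda g: print(g.norm()), grad)  # -> grad, unchanged
#     call_hook(lambda g: g * 2, grad)            # -> doubled gradient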


def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.
    """
    if not np:
        return f

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, lambda x: x.numpy(), (args, kwargs)
        )
        out = f(*args, **kwargs)
        return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)

    return wrap
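
# Illustrative usage (not part of the original module): wrapping a NumPy-only
# function so it accepts and returns torch.Tensors; numpy_sin is a hypothetical
# example.
#
#     @wrap_numpy
#     def numpy_sin(x):
#         return np.sin(x)            # sees an np.ndarray here
#
#     y = numpy_sin(torch.ones(3))    # returns a torch.Tensor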


class FakeBackwardCFunction:
    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: List[torch.Tensor],
    ):
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name):
        # Route any attribute that isn't defined on this object to the real
        # BackwardCFunction.
        return getattr(self.real, name)


# This function corresponds to the "eager" implementation of a lifted autograd.Function.backward
def call_backward(backward_c_function, saved_tensors, *args):
    fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
    grads = fake._forward_cls.backward(fake, *args)  # type: ignore[attr-defined]

    # in eager, we wrap in a tuple when there's only one grad output
    if type(grads) is not tuple:
        grads = (grads,)

    return grads
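
# Illustrative sketch (assumption): for a user-defined autograd.Function MyFn,
# call_backward(ctx, saved_tensors, grad_out) behaves roughly like calling
# MyFn.backward(ctx, grad_out) in eager mode, except that the saved tensors are
# supplied explicitly through the FakeBackwardCFunction wrapper and a single
# gradient is normalized to a one-element tuple.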


def untyped_storage_size(x: torch.Tensor):
    return x.untyped_storage().size()


def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    return getattr(bw_state, hook_name)(*args, **kwargs)
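
# Illustrative sketch (assumption, hypothetical names): hooks are looked up by
# name on a backward-state object at backward time, e.g.:
#
#     class BwState:
#         pass
#
#     state = BwState()
#     state.scale_grad = lambda g: g * 2
#     grad = torch.ones(3)
#     call_hook_from_backward_state(grad, bw_state=state, hook_name="scale_grad")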


def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    module = getattr(bw_state, module_name)
    hooks = getattr(bw_state, hooks_name)
    for hook in hooks:
        new_result = hook(module, result, *args)
        if new_result is not None:
            result = new_result
    return result