deprecated.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # mypy: allow-untyped-defs
"""
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the torch.func.* APIs; unlike those APIs, these wrappers emit
deprecation warnings -- we're trying to move people to the torch.func.*
equivalents.
NB: We don't use *args, **kwargs in the signatures because doing so would
change the documentation.
"""
  9. import textwrap
  10. import warnings
  11. from typing import Any, Callable, Optional, Tuple, Union
  12. import torch._functorch.apis as apis
  13. import torch._functorch.eager_transforms as _impl
  14. import torch._functorch.make_functional as _nn_impl
  15. import torch.nn as nn
  16. from torch._functorch.eager_transforms import argnums_t
  17. from torch._functorch.vmap import in_dims_t, out_dims_t
  18. def get_warning(api, new_api=None, replace_newlines=False):
  19. if new_api is None:
  20. new_api = f"torch.func.{api}"
  21. warning = (
  22. f"We've integrated functorch into PyTorch. As the final step of the \n"
  23. f"integration, `functorch.{api}` is deprecated as of PyTorch \n"
  24. f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
  25. f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n"
  26. f"and/or the `torch.func` migration guide for more details \n"
  27. f"https://pytorch.org/docs/main/func.migrating.html"
  28. )
  29. if replace_newlines:
  30. warning = warning.replace("\n", "")
  31. return warning
  32. def warn_deprecated(api, new_api=None):
  33. warning = get_warning(api, new_api, replace_newlines=True)
  34. warnings.warn(warning, FutureWarning, stacklevel=3)
  35. def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
  36. api_name = functorch_api.__name__
  37. if torch_func_api is None:
  38. torch_func_api = getattr(_impl, api_name)
  39. # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
  40. if torch_func_api.__doc__ is None:
  41. return
  42. warning = get_warning(api_name, new_api_name)
  43. warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ")
  44. warning_note = textwrap.indent(warning_note, " ")
  45. functorch_api.__doc__ = torch_func_api.__doc__ + warning_note
  46. def vmap(
  47. func: Callable,
  48. in_dims: in_dims_t = 0,
  49. out_dims: out_dims_t = 0,
  50. randomness: str = "error",
  51. *,
  52. chunk_size=None,
  53. ) -> Callable:
  54. warn_deprecated("vmap", "torch.vmap")
  55. return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
  56. def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
  57. warn_deprecated("grad")
  58. return apis.grad(func, argnums, has_aux)
  59. def grad_and_value(
  60. func: Callable, argnums: argnums_t = 0, has_aux: bool = False
  61. ) -> Callable:
  62. warn_deprecated("grad_and_value")
  63. return apis.grad_and_value(func, argnums, has_aux)
  64. def vjp(func: Callable, *primals, has_aux: bool = False):
  65. warn_deprecated("vjp")
  66. return _impl.vjp(func, *primals, has_aux=has_aux)
  67. def jvp(
  68. func: Callable,
  69. primals: Any,
  70. tangents: Any,
  71. *,
  72. strict: bool = False,
  73. has_aux: bool = False,
  74. ):
  75. warn_deprecated("jvp")
  76. return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux)
  77. def jacrev(
  78. func: Callable,
  79. argnums: Union[int, Tuple[int]] = 0,
  80. *,
  81. has_aux=False,
  82. chunk_size: Optional[int] = None,
  83. _preallocate_and_copy=False,
  84. ):
  85. warn_deprecated("jacrev")
  86. return _impl.jacrev(
  87. func,
  88. argnums,
  89. has_aux=has_aux,
  90. chunk_size=chunk_size,
  91. _preallocate_and_copy=_preallocate_and_copy,
  92. )
  93. def jacfwd(
  94. func: Callable,
  95. argnums: argnums_t = 0,
  96. has_aux: bool = False,
  97. *,
  98. randomness: str = "error",
  99. ):
  100. warn_deprecated("jacfwd")
  101. return _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
  102. def hessian(func, argnums=0):
  103. warn_deprecated("hessian")
  104. return _impl.hessian(func, argnums=argnums)
  105. def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
  106. warn_deprecated("functionalize")
  107. return _impl.functionalize(func, remove=remove)
  108. def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
  109. warn_deprecated("make_functional", "torch.func.functional_call")
  110. return _nn_impl.make_functional(model, disable_autograd_tracking)
  111. def make_functional_with_buffers(
  112. model: nn.Module, disable_autograd_tracking: bool = False
  113. ):
  114. warn_deprecated("make_functional_with_buffers", "torch.func.functional_call")
  115. return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
  116. def combine_state_for_ensemble(models):
  117. warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state")
  118. return _nn_impl.combine_state_for_ensemble(models)
  119. setup_docs(vmap, apis.vmap, "torch.vmap")
  120. setup_docs(grad, apis.grad)
  121. setup_docs(grad_and_value, apis.grad_and_value)
  122. setup_docs(vjp)
  123. setup_docs(jvp)
  124. setup_docs(jacrev)
  125. setup_docs(jacfwd)
  126. setup_docs(hessian)
  127. setup_docs(functionalize)
  128. setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
  129. setup_docs(
  130. make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call"
  131. )
  132. setup_docs(
  133. combine_state_for_ensemble,
  134. _nn_impl.combine_state_for_ensemble,
  135. "torch.func.stack_module_state",
  136. )