# mypy: allow-untyped-defs
import torch
from typing import List

__all__ = [
    "compile",
    "assume_constant_result",
    "reset",
    "allow_in_graph",
    "list_backends",
    "disable",
    "cudagraph_mark_step_begin",
    "wrap_numpy",
    "is_compiling",
    "is_dynamo_compiling",
]


def compile(*args, **kwargs):
    """
    See :func:`torch.compile` for details on the arguments for this function.
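
    Example (a minimal usage sketch; ``my_model`` is a placeholder for any
    callable or :class:`torch.nn.Module`)::

        >>> my_model = torch.nn.Linear(4, 4)
        >>> # Forwards to torch.compile; see that function for the full argument list.
        >>> compiled_model = torch.compiler.compile(my_model)
        >>> out = compiled_model(torch.randn(2, 4))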
  19. """
  20. return torch.compile(*args, **kwargs)


def reset() -> None:
    """
    This function clears all compilation caches and restores the system to its initial state.
    It is recommended to call this function after using an operation like :func:`torch.compile`
    to ensure a clean state before another, unrelated compilation.
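
    Example (a minimal sketch; ``fn`` and ``other_fn`` are placeholder functions)::

        >>> compiled_fn = torch.compile(fn)
        >>> compiled_fn(torch.randn(4))
        >>> # Drop all cached compilation state before compiling something unrelated.
        >>> torch.compiler.reset()
        >>> compiled_other_fn = torch.compile(other_fn)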
  26. """
  27. import torch._dynamo
  28. torch._dynamo.reset()


def allow_in_graph(fn):
    """
    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
    and instead directly write it to the graph when encountered.

    If you are using :func:`torch.compile` (with the default backend="inductor"), or
    :func:`torch.export.export`, and are trying to black-box a Python function throughout
    all tracing, do not use this API.
    Instead, please create a custom operator (see :ref:`custom-ops-landing-page`).

    .. warning::

        If you're a typical torch.compile user (e.g. you're applying torch.compile to
        a model to make it run faster), you probably don't want to use this function.
        :func:`allow_in_graph` is a footgun because it skips the compiler frontend
        (Dynamo) that is responsible for doing safety checks (graph breaks, handling
        closures, etc.). Incorrect usage will lead to difficult-to-debug silent
        incorrectness issues.

    Given a Python function with no allow_in_graph decorator, regular execution
    of torch.compile traces through the function. :func:`allow_in_graph` changes
    it so that the frontend does not trace inside the function, but the compiler
    backend still traces through it. Compare this to custom operators, which
    treat a function as a black box throughout the torch.compile stack. The following
    table compares these mechanisms.

    +------------------+-------------------+--------------------------------+
    | Mechanism        | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) |
    +==================+===================+================================+
    | no decorator     | trace inside      | trace inside                   |
    +------------------+-------------------+--------------------------------+
    | allow_in_graph   | opaque callable   | trace inside                   |
    +------------------+-------------------+--------------------------------+
    | custom op        | opaque callable   | opaque callable                |
    +------------------+-------------------+--------------------------------+

    One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler
    frontend: if you know the function works w.r.t. the downstream components of the
    compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from
    symbolically introspecting the function properly (or if your code is in C/C++ and
    therefore cannot be introspected with Dynamo), then one can decorate said function
    with :func:`allow_in_graph` to bypass Dynamo.

    We require that ``fn`` adhere to the following restrictions. Failure to adhere
    results in undefined behavior:

    - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
      Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
      Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
    - The outputs of ``fn`` must be Proxy-able types in the FX graph (see previous bullet).
    - All Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
      (as opposed to being captured variables).

    Args:
        fn: A callable representing the function to be included in the graph.
            If ``fn`` is a list or tuple of callables it recursively applies
            :func:`allow_in_graph()` to each function and returns a new list or
            tuple containing the modified functions.

    Example::

        torch.compiler.allow_in_graph(my_custom_function)

        @torch.compile(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    This will capture a single graph containing ``my_custom_function()``.
    """
    import torch._dynamo

    return torch._dynamo.allow_in_graph(fn)


def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to ``torch.compile(..., backend="name")``.

    Args:
        exclude_tags(optional): A tuple of strings representing tags to exclude.
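
    Example (an illustrative sketch; the exact set of backends depends on your
    PyTorch build and installed extensions)::

        >>> torch.compiler.list_backends()  # e.g. ['cudagraphs', 'inductor', 'onnxrt', ...]
        >>> # Pass an empty tuple to also include debug and experimental backends.
        >>> torch.compiler.list_backends(exclude_tags=())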
  96. """
  97. import torch._dynamo
  98. return torch._dynamo.list_backends(exclude_tags)


def assume_constant_result(fn):
    """
    This function is used to mark a function ``fn`` as having a constant result.
    This allows the compiler to optimize away your function.

    Returns the same function ``fn``.

    Args:
        fn: The function to be marked as having a constant result.

    .. warning::
        ``assume_constant_result`` can cause safety and soundness issues if the
        assumption is invalid; :func:`torch.compile` will not attempt to validate
        whether the constant assumption is true or not.
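
    Example (a minimal sketch; ``CONFIG`` and ``get_scale`` are hypothetical names,
    and you promise that the returned value never changes between calls)::

        >>> CONFIG = {"scale": 2.0}
        >>> @torch.compiler.assume_constant_result
        >>> def get_scale():
        >>>     return CONFIG["scale"]
        >>>
        >>> @torch.compile
        >>> def fn(x):
        >>>     return x * get_scale()  # get_scale() is baked into the graph as a constant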
  109. """
  110. import torch._dynamo
  111. return torch._dynamo.assume_constant_result(fn)


def disable(fn=None, recursive=True):
    """
    This function provides both a decorator and a context manager to disable compilation on a function.
    It also provides the option of recursively disabling called functions.

    Args:
        fn (optional): The function to disable.
        recursive (optional): A boolean value indicating whether the disabling should be recursive.
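
    Example (a minimal sketch; ``helper`` and ``inner`` are placeholder functions)::

        >>> @torch.compiler.disable
        >>> def helper(x):
        >>>     return inner(x)  # with recursive=True (the default), inner() is skipped as well
        >>>
        >>> @torch.compiler.disable(recursive=False)
        >>> def shallow_helper(x):
        >>>     return inner(x)  # inner() may still be compiled when reached from compiled code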
  119. """
  120. import torch._dynamo
  121. return torch._dynamo.disable(fn, recursive)


def cudagraph_mark_step_begin():
    """
    Indicates that a new iteration of inference or training is about to begin.

    CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of
    torch.compile, so long as there is not a pending backward that has not been called.

    If that heuristic is wrong, such as in the following example, manually mark it with this API.

    .. code-block:: python

        @torch.compile(mode="reduce-overhead")
        def rand_foo():
            return torch.rand([4], device="cuda")

        for _ in range(5):
            torch.compiler.cudagraph_mark_step_begin()
            rand_foo() + rand_foo()

    For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__.
    """
    from torch._inductor import cudagraph_trees

    cudagraph_trees.mark_step_begin()


def wrap_numpy(fn):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows you to
    compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code
    on CUDA or compute its gradients.

    .. note::

        This decorator does not work without :func:`torch.compile`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Compile a NumPy function as a Tensor -> Tensor function
        >>> @torch.compile(fullgraph=True)
        >>> @torch.compiler.wrap_numpy
        >>> def fn(a: np.ndarray):
        >>>     return np.sum(a * a)
        >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
        >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
        >>> out = fn(x)
        >>> out.backward()
        >>> print(x.grad)
        tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')
    """
    from torch._dynamo.external_utils import wrap_numpy as wrap

    return wrap(fn)


_is_compiling_flag: bool = False


def is_compiling() -> bool:
    """
    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().

    Note that there are 2 other related flags that should be deprecated eventually:
      * torch._dynamo.external_utils.is_compiling()
      * torch._utils.is_compiling()

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_compiling():
        >>>         pass  # ...logic that is not needed in a compiled/traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    if torch.jit.is_scripting():
        return False
    else:
        return _is_compiling_flag


def is_dynamo_compiling() -> bool:
    """
    Indicates whether a graph is traced via TorchDynamo.

    It's stricter than the is_compiling() flag, as it is only set to True when
    TorchDynamo is used.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_dynamo_compiling():
        >>>         pass  # ...logic that is not needed in a TorchDynamo-traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    return False