# __init__.py (8.1 KB) -- initializer of the `torch.jit` package
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from contextlib import contextmanager
  4. from typing import Any, Iterator
  5. import torch._C
  6. # These are imported so users can access them from the `torch.jit` module
  7. from torch._jit_internal import (
  8. _Await,
  9. _drop,
  10. _IgnoreContextManager,
  11. _isinstance,
  12. _overload,
  13. _overload_method,
  14. export,
  15. Final,
  16. Future,
  17. ignore,
  18. is_scripting,
  19. unused,
  20. )
  21. from torch.jit._async import fork, wait
  22. from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
  23. from torch.jit._decomposition_utils import _register_decomposition
  24. from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
  25. from torch.jit._fuser import (
  26. fuser,
  27. last_executed_optimized_graph,
  28. optimized_execution,
  29. set_fusion_strategy,
  30. )
  31. from torch.jit._ir_utils import _InsertPoint
  32. from torch.jit._script import (
  33. _ScriptProfile,
  34. _unwrap_optional,
  35. Attribute,
  36. CompilationUnit,
  37. interface,
  38. RecursiveScriptClass,
  39. RecursiveScriptModule,
  40. script,
  41. script_method,
  42. ScriptFunction,
  43. ScriptModule,
  44. ScriptWarning,
  45. )
  46. from torch.jit._serialization import (
  47. jit_module_from_flatbuffer,
  48. load,
  49. save,
  50. save_jit_module_to_flatbuffer,
  51. )
  52. from torch.jit._trace import (
  53. _flatten,
  54. _get_trace_graph,
  55. _script_if_tracing,
  56. _unique_state_dict,
  57. is_tracing,
  58. ONNXTracedModule,
  59. TopLevelTracedModule,
  60. trace,
  61. trace_module,
  62. TracedModule,
  63. TracerWarning,
  64. TracingCheckError,
  65. )
  66. from torch.utils import set_module
  67. __all__ = [
  68. "Attribute",
  69. "CompilationUnit",
  70. "Error",
  71. "Future",
  72. "ScriptFunction",
  73. "ScriptModule",
  74. "annotate",
  75. "enable_onednn_fusion",
  76. "export",
  77. "export_opnames",
  78. "fork",
  79. "freeze",
  80. "interface",
  81. "ignore",
  82. "isinstance",
  83. "load",
  84. "onednn_fusion_enabled",
  85. "optimize_for_inference",
  86. "save",
  87. "script",
  88. "script_if_tracing",
  89. "set_fusion_strategy",
  90. "strict_fusion",
  91. "trace",
  92. "trace_module",
  93. "unused",
  94. "wait",
  95. ]
  96. # For backwards compatibility
  97. _fork = fork
  98. _wait = wait
  99. _set_fusion_strategy = set_fusion_strategy
  100. def export_opnames(m):
  101. r"""
  102. Generate new bytecode for a Script module.
  103. Returns what the op list would be for a Script Module based off the current code base.
  104. If you have a LiteScriptModule and want to get the currently present
  105. list of ops call _export_operator_list instead.
  106. """
  107. return torch._C._export_opnames(m._c)
  108. # torch.jit.Error
  109. Error = torch._C.JITException
  110. set_module(Error, "torch.jit")
  111. # This is not perfect but works in common cases
  112. Error.__name__ = "Error"
  113. Error.__qualname__ = "Error"
  114. # for use in python if using annotate
  115. def annotate(the_type, the_value):
  116. """Use to give type of `the_value` in TorchScript compiler.
  117. This method is a pass-through function that returns `the_value`, used to hint TorchScript
  118. compiler the type of `the_value`. It is a no-op when running outside of TorchScript.
  119. Though TorchScript can infer correct type for most Python expressions, there are some cases where
  120. type inference can be wrong, including:
  121. - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
  122. - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
  123. it is type `T` rather than `Optional[T]`
  124. Note that `annotate()` does not help in `__init__` method of `torch.nn.Module` subclasses because it
  125. is executed in eager mode. To annotate types of `torch.nn.Module` attributes,
  126. use :meth:`~torch.jit.Attribute` instead.
  127. Example:
  128. .. testcode::
  129. import torch
  130. from typing import Dict
  131. @torch.jit.script
  132. def fn():
  133. # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
  134. # instead of default dictionary type of (str -> Tensor).
  135. d = torch.jit.annotate(Dict[str, int], {})
  136. # Without `torch.jit.annotate` above, following statement would fail because of
  137. # type mismatch.
  138. d["name"] = 20
  139. .. testcleanup::
  140. del fn
  141. Args:
  142. the_type: Python type that should be passed to TorchScript compiler as type hint for `the_value`
  143. the_value: Value or expression to hint type for.
  144. Returns:
  145. `the_value` is passed back as return value.
  146. """
  147. return the_value
  148. def script_if_tracing(fn):
  149. """
  150. Compiles ``fn`` when it is first called during tracing.
  151. ``torch.jit.script`` has a non-negligible start up time when it is first called due to
  152. lazy-initializations of many compiler builtins. Therefore you should not use
  153. it in library code. However, you may want to have parts of your library work
  154. in tracing even if they use control flow. In these cases, you should use
  155. ``@torch.jit.script_if_tracing`` to substitute for
  156. ``torch.jit.script``.
  157. Args:
  158. fn: A function to compile.
  159. Returns:
  160. If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned.
  161. Otherwise, the original function `fn` is returned.
  162. """
  163. return _script_if_tracing(fn)
  164. # for torch.jit.isinstance
  165. def isinstance(obj, target_type):
  166. """
  167. Provide container type refinement in TorchScript.
  168. It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
  169. ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
  170. refine basic types such as bools and ints that are available in TorchScript.
  171. Args:
  172. obj: object to refine the type of
  173. target_type: type to try to refine obj to
  174. Returns:
  175. ``bool``: True if obj was successfully refined to the type of target_type,
  176. False otherwise with no new type refinement
  177. Example (using ``torch.jit.isinstance`` for type refinement):
  178. .. testcode::
  179. import torch
  180. from typing import Any, Dict, List
  181. class MyModule(torch.nn.Module):
  182. def __init__(self):
  183. super().__init__()
  184. def forward(self, input: Any): # note the Any type
  185. if torch.jit.isinstance(input, List[torch.Tensor]):
  186. for t in input:
  187. y = t.clamp(0, 0.5)
  188. elif torch.jit.isinstance(input, Dict[str, str]):
  189. for val in input.values():
  190. print(val)
  191. m = torch.jit.script(MyModule())
  192. x = [torch.rand(3,3), torch.rand(4,3)]
  193. m(x)
  194. y = {"key1":"val1","key2":"val2"}
  195. m(y)
  196. """
  197. return _isinstance(obj, target_type)
  198. class strict_fusion:
  199. """
  200. Give errors if not all nodes have been fused in inference, or symbolically differentiated in training.
  201. Example:
  202. Forcing fusion of additions.
  203. .. code-block:: python
  204. @torch.jit.script
  205. def foo(x):
  206. with torch.jit.strict_fusion():
  207. return x + x + x
  208. """
  209. def __init__(self):
  210. if not torch._jit_internal.is_scripting():
  211. warnings.warn("Only works in script mode")
  212. pass
  213. def __enter__(self):
  214. pass
  215. def __exit__(self, type: Any, value: Any, tb: Any) -> None:
  216. pass
  217. # Context manager for globally hiding source ranges when printing graphs.
  218. # Note that these functions are exposed to Python as static members of the
  219. # Graph class, so mypy checks need to be skipped.
  220. @contextmanager
  221. def _hide_source_ranges() -> Iterator[None]:
  222. old_enable_source_ranges = torch._C.Graph.global_print_source_ranges # type: ignore[attr-defined]
  223. try:
  224. torch._C.Graph.set_global_print_source_ranges(False) # type: ignore[attr-defined]
  225. yield
  226. finally:
  227. torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined]
  228. def enable_onednn_fusion(enabled: bool):
  229. """Enable or disables onednn JIT fusion based on the parameter `enabled`."""
  230. torch._C._jit_set_llga_enabled(enabled)
  231. def onednn_fusion_enabled():
  232. """Return whether onednn JIT fusion is enabled."""
  233. return torch._C._jit_llga_enabled()
  234. del Any
  235. if not torch._C._jit_init():
  236. raise RuntimeError("JIT initialization failed")