| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- # mypy: allow-untyped-defs
- import warnings
- from contextlib import contextmanager
- from typing import Any, Iterator
- import torch._C
- # These are imported so users can access them from the `torch.jit` module
- from torch._jit_internal import (
- _Await,
- _drop,
- _IgnoreContextManager,
- _isinstance,
- _overload,
- _overload_method,
- export,
- Final,
- Future,
- ignore,
- is_scripting,
- unused,
- )
- from torch.jit._async import fork, wait
- from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
- from torch.jit._decomposition_utils import _register_decomposition
- from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
- from torch.jit._fuser import (
- fuser,
- last_executed_optimized_graph,
- optimized_execution,
- set_fusion_strategy,
- )
- from torch.jit._ir_utils import _InsertPoint
- from torch.jit._script import (
- _ScriptProfile,
- _unwrap_optional,
- Attribute,
- CompilationUnit,
- interface,
- RecursiveScriptClass,
- RecursiveScriptModule,
- script,
- script_method,
- ScriptFunction,
- ScriptModule,
- ScriptWarning,
- )
- from torch.jit._serialization import (
- jit_module_from_flatbuffer,
- load,
- save,
- save_jit_module_to_flatbuffer,
- )
- from torch.jit._trace import (
- _flatten,
- _get_trace_graph,
- _script_if_tracing,
- _unique_state_dict,
- is_tracing,
- ONNXTracedModule,
- TopLevelTracedModule,
- trace,
- trace_module,
- TracedModule,
- TracerWarning,
- TracingCheckError,
- )
- from torch.utils import set_module
# Public names exported by a star-import of this module: the capitalised
# names first, then the lower-case ones, in the established order.
__all__ = [
    "Attribute", "CompilationUnit", "Error", "Future", "ScriptFunction", "ScriptModule",
    "annotate", "enable_onednn_fusion", "export", "export_opnames", "fork", "freeze",
    "interface", "ignore", "isinstance", "load", "onednn_fusion_enabled",
    "optimize_for_inference", "save", "script", "script_if_tracing",
    "set_fusion_strategy", "strict_fusion", "trace", "trace_module", "unused", "wait",
]

# Underscore-prefixed aliases retained for backwards compatibility with
# callers that still use the old private names.
_fork, _wait, _set_fusion_strategy = fork, wait, set_fusion_strategy
def export_opnames(m):
    r"""
    Return the operator names a Script module would need, as computed by the current code base.

    The list is produced by ``torch._C._export_opnames`` from the module's underlying
    C++ object (``m._c``). For a LiteScriptModule, call ``_export_operator_list``
    instead to get the ops it currently contains.
    """
    cpp_module = m._c
    return torch._C._export_opnames(cpp_module)
# ``torch.jit.Error``: the C++ JIT exception type, re-exported under a friendlier name.
Error = torch._C.JITException
set_module(Error, "torch.jit")
# Rename the class itself so it presents as ``torch.jit.Error``. This is not
# perfect, but it covers the common cases.
Error.__name__ = Error.__qualname__ = "Error"
# Eager-mode (plain Python) counterpart of ``torch.jit.annotate``.
def annotate(the_type, the_value):
    """Tell the TorchScript compiler the type of `the_value`.

    Outside of TorchScript this is a pass-through: it simply hands `the_value` back
    and ignores `the_type`. Inside TorchScript it acts as a type hint for the compiler.

    TorchScript infers the right type for most Python expressions, but inference can
    go wrong in some cases, for example:

    - Empty containers such as `[]` and `{}`, which TorchScript assumes hold `Tensor` values
    - A value of type `T` assigned where `Optional[T]` is intended; TorchScript would
      infer `T` rather than `Optional[T]`

    `annotate()` does not help inside the `__init__` method of `torch.nn.Module` subclasses,
    because that method runs in eager mode. To annotate the types of `torch.nn.Module`
    attributes, use :meth:`~torch.jit.Attribute` instead.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        @torch.jit.script
        def fn():
            # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
            # instead of default dictionary type of (str -> Tensor).
            d = torch.jit.annotate(Dict[str, int], {})

            # Without `torch.jit.annotate` above, following statement would fail because of
            # type mismatch.
            d["name"] = 20

    .. testcleanup::

        del fn

    Args:
        the_type: Python type passed to the TorchScript compiler as the type hint for `the_value`
        the_value: Value or expression whose type is being hinted.

    Returns:
        `the_value`, unchanged.
    """
    return the_value
def script_if_tracing(fn):
    """
    Compile ``fn`` the first time it is called during tracing.

    The first call to ``torch.jit.script`` has a noticeable start-up cost, because many
    compiler builtins are initialised lazily at that point. Library code should therefore
    avoid it. When parts of a library use control flow and still need to work under
    tracing, decorate them with ``@torch.jit.script_if_tracing`` instead of
    ``torch.jit.script``.

    Args:
        fn: A function to compile.

    Returns:
        If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned.
        Otherwise, the original function `fn` is returned.
    """
    tracing_aware = _script_if_tracing(fn)
    return tracing_aware
# Python-level implementation behind ``torch.jit.isinstance``.
def isinstance(obj, target_type):
    """
    Refine the type of container values in TorchScript.

    Parameterized ``List``, ``Dict``, ``Tuple`` and ``Optional`` types can be refined,
    e.g. ``List[str]``, ``Dict[str, List[torch.Tensor]]`` or ``Optional[Tuple[int,str,int]]``.
    Basic types available in TorchScript, such as bools and ints, can be refined as well.

    Args:
        obj: object whose type should be refined
        target_type: type to try to refine ``obj`` to

    Returns:
        ``bool``: True if ``obj`` was successfully refined to ``target_type``;
        False otherwise, in which case no new type refinement takes place

    Example (using ``torch.jit.isinstance`` for type refinement):
    .. testcode::

        import torch
        from typing import Any, Dict, List

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input: Any): # note the Any type
                if torch.jit.isinstance(input, List[torch.Tensor]):
                    for t in input:
                        y = t.clamp(0, 0.5)
                elif torch.jit.isinstance(input, Dict[str, str]):
                    for val in input.values():
                        print(val)

        m = torch.jit.script(MyModule())
        x = [torch.rand(3,3), torch.rand(4,3)]
        m(x)
        y = {"key1":"val1","key2":"val2"}
        m(y)
    """
    refined = _isinstance(obj, target_type)
    return refined
class strict_fusion:
    """
    Give errors if not all nodes have been fused in inference, or symbolically differentiated in training.

    Outside of TorchScript this context manager does nothing apart from warning when it
    is constructed. ``__exit__`` returns ``None``, so exceptions raised inside the
    ``with`` body propagate normally.

    Example:
    Forcing fusion of additions.

    .. code-block:: python

        @torch.jit.script
        def foo(x):
            with torch.jit.strict_fusion():
                return x + x + x

    """

    def __init__(self):
        # In eager mode there is nothing to enforce, so tell the user. ``stacklevel=2``
        # attributes the warning to the caller's ``strict_fusion()`` line rather than
        # to this constructor.
        if not torch._jit_internal.is_scripting():
            warnings.warn("Only works in script mode", stacklevel=2)

    def __enter__(self):
        pass

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        pass
# Globally suppress source ranges when graphs are printed. The switch lives as
# static members on the Graph class, which mypy does not know about, hence the
# ``type: ignore`` markers below.
@contextmanager
def _hide_source_ranges() -> Iterator[None]:
    """Turn off source-range printing for graphs for the duration of the ``with`` block.

    The previous global setting is read on entry and restored on exit, including when
    the body raises.
    """
    graph_cls = torch._C.Graph
    saved_setting = graph_cls.global_print_source_ranges  # type: ignore[attr-defined]
    try:
        graph_cls.set_global_print_source_ranges(False)  # type: ignore[attr-defined]
        yield
    finally:
        graph_cls.set_global_print_source_ranges(saved_setting)  # type: ignore[attr-defined]
def enable_onednn_fusion(enabled: bool):
    """Enable or disable oneDNN JIT fusion according to ``enabled``.

    Forwards the flag to ``torch._C._jit_set_llga_enabled``.
    """
    set_llga_enabled = torch._C._jit_set_llga_enabled
    set_llga_enabled(enabled)
def onednn_fusion_enabled():
    """Report whether oneDNN JIT fusion is currently enabled.

    The value comes straight from ``torch._C._jit_llga_enabled``.
    """
    is_enabled = torch._C._jit_llga_enabled()
    return is_enabled
# ``Any`` was imported only for annotations evaluated while this module executes
# (e.g. on ``strict_fusion.__exit__``); remove it so it is not left behind as a
# module attribute.
del Any

# Run the C extension's JIT initialisation hook at import time; a falsy result is
# treated as fatal so that importing this module fails loudly.
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")
|