# _script.pyi (9.4 KB)
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="type-arg"
  3. from typing import (
  4. Any,
  5. Callable,
  6. Dict,
  7. List,
  8. NamedTuple,
  9. Optional,
  10. overload,
  11. Tuple,
  12. Type,
  13. TypeVar,
  14. Union,
  15. )
  16. from typing_extensions import Never, TypeAlias
  17. from _typeshed import Incomplete
  18. import torch
  19. from torch._classes import classes as classes
  20. from torch._jit_internal import _qualified_name as _qualified_name
  21. from torch.jit._builtins import _register_builtin as _register_builtin
  22. from torch.jit._fuser import (
  23. _graph_for as _graph_for,
  24. _script_method_graph_for as _script_method_graph_for,
  25. )
  26. from torch.jit._monkeytype_config import (
  27. JitTypeTraceConfig as JitTypeTraceConfig,
  28. JitTypeTraceStore as JitTypeTraceStore,
  29. monkeytype_trace as monkeytype_trace,
  30. )
  31. from torch.jit._recursive import (
  32. _compile_and_register_class as _compile_and_register_class,
  33. infer_methods_to_compile as infer_methods_to_compile,
  34. ScriptMethodStub as ScriptMethodStub,
  35. wrap_cpp_module as wrap_cpp_module,
  36. )
  37. from torch.jit._state import (
  38. _enabled as _enabled,
  39. _set_jit_function_cache as _set_jit_function_cache,
  40. _set_jit_overload_cache as _set_jit_overload_cache,
  41. _try_get_jit_cached_function as _try_get_jit_cached_function,
  42. _try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
  43. )
  44. from torch.jit.frontend import (
  45. get_default_args as get_default_args,
  46. get_jit_class_def as get_jit_class_def,
  47. get_jit_def as get_jit_def,
  48. )
  49. from torch.nn import Module as Module
  50. from torch.overrides import (
  51. has_torch_function as has_torch_function,
  52. has_torch_function_unary as has_torch_function_unary,
  53. has_torch_function_variadic as has_torch_function_variadic,
  54. )
  55. from torch.package import (
  56. PackageExporter as PackageExporter,
  57. PackageImporter as PackageImporter,
  58. )
  59. from torch.utils import set_module as set_module
  60. from ._serialization import validate_map_location as validate_map_location
# Alias for the ``ScriptFunction`` class exposed on ``torch._C``.
ScriptFunction = torch._C.ScriptFunction
# Module-level name declared with type ``JitTypeTraceStore``.
type_trace_db: JitTypeTraceStore
# Defined in torch/csrc/jit/python/script_init.cpp
# A resolution callback takes a name (``str``) and returns a callable.
ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
# Type variable restricted to classes (``bound=type``); the ``script`` overload
# that accepts a class uses it to return the same class type it was given.
_ClassVar = TypeVar("_ClassVar", bound=type)
  66. def _reduce(cls) -> None: ...
  67. class Attribute(NamedTuple):
  68. value: Incomplete
  69. type: Incomplete
  70. def _get_type_trace_db(): ...
  71. def _get_function_from_type(cls, name): ...
  72. def _is_new_style_class(cls): ...
  73. class OrderedDictWrapper:
  74. _c: Incomplete
  75. def __init__(self, _c) -> None: ...
  76. def keys(self): ...
  77. def values(self): ...
  78. def __len__(self) -> int: ...
  79. def __delitem__(self, k) -> None: ...
  80. def items(self): ...
  81. def __setitem__(self, k, v) -> None: ...
  82. def __contains__(self, k) -> bool: ...
  83. def __getitem__(self, k): ...
  84. class OrderedModuleDict(OrderedDictWrapper):
  85. _python_modules: Incomplete
  86. def __init__(self, module, python_dict) -> None: ...
  87. def items(self): ...
  88. def __contains__(self, k) -> bool: ...
  89. def __setitem__(self, k, v) -> None: ...
  90. def __getitem__(self, k): ...
  91. class ScriptMeta(type):
  92. def __init__(cls, name, bases, attrs) -> None: ...
  93. class _CachedForward:
  94. def __get__(self, obj, cls): ...
  95. class ScriptWarning(Warning): ...
  96. def script_method(fn): ...
  97. class ConstMap:
  98. const_mapping: Incomplete
  99. def __init__(self, const_mapping) -> None: ...
  100. def __getattr__(self, attr): ...
  101. def unpackage_script_module(
  102. importer: PackageImporter, script_module_id: str
  103. ) -> torch.nn.Module: ...
# Module-level name declared with the placeholder type ``Incomplete``; its
# concrete type is not specified in this stub.
_magic_methods: Incomplete
  105. class RecursiveScriptClass:
  106. _c: Incomplete
  107. _props: Incomplete
  108. def __init__(self, cpp_class) -> None: ...
  109. def __getattr__(self, attr): ...
  110. def __setattr__(self, attr, value): ...
  111. def forward_magic_method(self, method_name, *args, **kwargs): ...
  112. def __getstate__(self) -> None: ...
  113. def __iadd__(self, other): ...
  114. def method_template(self, *args, **kwargs): ...
  115. class ScriptModule(Module, metaclass=ScriptMeta):
  116. __jit_unused_properties__: Incomplete
  117. def __init__(self) -> None: ...
  118. forward: Callable[..., Any]
  119. def __getattr__(self, attr): ...
  120. def __setattr__(self, attr, value): ...
  121. def define(self, src): ...
  122. def _replicate_for_data_parallel(self): ...
  123. def __reduce_package__(self, exporter: PackageExporter): ...
  124. # add __jit_unused_properties__
  125. @property
  126. def code(self) -> str: ...
  127. @property
  128. def code_with_constants(self) -> Tuple[str, ConstMap]: ...
  129. @property
  130. def graph(self) -> torch.Graph: ...
  131. @property
  132. def inlined_graph(self) -> torch.Graph: ...
  133. @property
  134. def original_name(self) -> str: ...
  135. class RecursiveScriptModule(ScriptModule):
  136. _disable_script_meta: bool
  137. _c: Incomplete
  138. def __init__(self, cpp_module) -> None: ...
  139. @staticmethod
  140. def _construct(cpp_module, init_fn): ...
  141. @staticmethod
  142. def _finalize_scriptmodule(script_module) -> None: ...
  143. _concrete_type: Incomplete
  144. _modules: Incomplete
  145. _parameters: Incomplete
  146. _buffers: Incomplete
  147. __dict__: Incomplete
  148. def _reconstruct(self, cpp_module) -> None: ...
  149. def save(self, f, **kwargs): ...
  150. def _save_for_lite_interpreter(self, *args, **kwargs): ...
  151. def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
  152. def save_to_buffer(self, *args, **kwargs): ...
  153. def get_debug_state(self, *args, **kwargs): ...
  154. def extra_repr(self): ...
  155. def graph_for(self, *args, **kwargs): ...
  156. def define(self, src) -> None: ...
  157. def __getattr__(self, attr): ...
  158. def __setattr__(self, attr, value): ...
  159. def __copy__(self): ...
  160. def __deepcopy__(self, memo): ...
  161. def forward_magic_method(self, method_name, *args, **kwargs): ...
  162. def __iter__(self): ...
  163. def __getitem__(self, idx): ...
  164. def __len__(self) -> int: ...
  165. def __contains__(self, key) -> bool: ...
  166. def __dir__(self): ...
  167. def __bool__(self) -> bool: ...
  168. def _replicate_for_data_parallel(self): ...
  169. def _get_methods(cls): ...
# Module-level name declared with the placeholder type ``Incomplete``; its
# concrete type is not specified in this stub.
_compiled_methods_allowlist: Incomplete
  171. def _make_fail(name): ...
  172. def call_prepare_scriptable_func_impl(obj, memo): ...
  173. def call_prepare_scriptable_func(obj): ...
  174. def create_script_dict(obj): ...
  175. def create_script_list(obj, type_hint: Incomplete | None = ...): ...
  176. @overload
  177. def script(
  178. obj: Type[Module],
  179. optimize: Optional[bool] = None,
  180. _frames_up: int = 0,
  181. _rcb: Optional[ResolutionCallback] = None,
  182. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  183. ) -> Never: ...
  184. @overload
  185. def script( # type: ignore[misc]
  186. obj: Dict,
  187. optimize: Optional[bool] = None,
  188. _frames_up: int = 0,
  189. _rcb: Optional[ResolutionCallback] = None,
  190. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  191. ) -> torch.ScriptDict: ...
  192. @overload
  193. def script( # type: ignore[misc]
  194. obj: List,
  195. optimize: Optional[bool] = None,
  196. _frames_up: int = 0,
  197. _rcb: Optional[ResolutionCallback] = None,
  198. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  199. ) -> torch.ScriptList: ...
  200. @overload
  201. def script( # type: ignore[misc]
  202. obj: Module,
  203. optimize: Optional[bool] = None,
  204. _frames_up: int = 0,
  205. _rcb: Optional[ResolutionCallback] = None,
  206. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  207. ) -> RecursiveScriptModule: ...
  208. @overload
  209. def script( # type: ignore[misc]
  210. obj: _ClassVar,
  211. optimize: Optional[bool] = None,
  212. _frames_up: int = 0,
  213. _rcb: Optional[ResolutionCallback] = None,
  214. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  215. ) -> _ClassVar: ...
  216. @overload
  217. def script( # type: ignore[misc]
  218. obj: Callable,
  219. optimize: Optional[bool] = None,
  220. _frames_up: int = 0,
  221. _rcb: Optional[ResolutionCallback] = None,
  222. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  223. ) -> ScriptFunction: ...
  224. @overload
  225. def script(
  226. obj: Any,
  227. optimize: Optional[bool] = None,
  228. _frames_up: int = 0,
  229. _rcb: Optional[ResolutionCallback] = None,
  230. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
  231. ) -> RecursiveScriptClass: ...
  232. @overload
  233. def script(
  234. obj,
  235. optimize: Incomplete | None = ...,
  236. _frames_up: int = ...,
  237. _rcb: Incomplete | None = ...,
  238. example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = ...,
  239. ): ...
  240. def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
  241. def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
  242. def _get_overloads(obj): ...
  243. def _check_directly_compile_overloaded(obj) -> None: ...
  244. def interface(obj): ...
  245. def _recursive_compile_class(obj, loc): ...
# Module-level name declared with the placeholder type ``Incomplete``; its
# concrete type is not specified in this stub.
CompilationUnit: Incomplete
  247. def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
  248. class _ScriptProfileColumn:
  249. header: Incomplete
  250. alignment: Incomplete
  251. offset: Incomplete
  252. rows: Incomplete
  253. def __init__(
  254. self, header: str, alignment: int = ..., offset: int = ...
  255. ) -> None: ...
  256. def add_row(self, lineno: int, value: Any): ...
  257. def materialize(self): ...
  258. class _ScriptProfileTable:
  259. cols: Incomplete
  260. source_range: Incomplete
  261. def __init__(
  262. self, cols: List[_ScriptProfileColumn], source_range: List[int]
  263. ) -> None: ...
  264. def dump_string(self): ...
  265. class _ScriptProfile:
  266. profile: Incomplete
  267. def __init__(self) -> None: ...
  268. def enable(self) -> None: ...
  269. def disable(self) -> None: ...
  270. def dump_string(self) -> str: ...
  271. def dump(self) -> None: ...
  272. def _unwrap_optional(x): ...