# _distributed_rpc.pyi
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="type-arg"
  3. from datetime import timedelta
  4. from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
  5. import torch
  6. from . import Future
  7. from ._autograd import ProfilerEvent
  8. from ._distributed_c10d import Store
  9. from ._profiler import ProfilerConfig
  10. # This module is defined in torch/csrc/distributed/rpc/init.cpp
  11. _DEFAULT_INIT_METHOD: str
  12. _DEFAULT_NUM_WORKER_THREADS: int
  13. _UNSET_RPC_TIMEOUT: float
  14. _DEFAULT_RPC_TIMEOUT_SEC: float
  15. _T = TypeVar("_T")
  16. class RpcBackendOptions:
  17. rpc_timeout: float
  18. init_method: str
  19. def __init__(
  20. self,
  21. rpc_timeout: float = ...,
  22. init_method: str = ...,
  23. ): ...
  24. class WorkerInfo:
  25. def __init__(self, name: str, worker_id: int): ...
  26. @property
  27. def name(self) -> str: ...
  28. @property
  29. def id(self) -> int: ...
  30. def __eq__(self, other: object) -> bool: ...
  31. class RpcAgent:
  32. def join(self, shutdown: bool = False, timeout: float = 0): ...
  33. def sync(self): ...
  34. def shutdown(self): ...
  35. @overload
  36. def get_worker_info(self) -> WorkerInfo: ...
  37. @overload
  38. def get_worker_info(self, workerName: str) -> WorkerInfo: ...
  39. def get_worker_infos(self) -> List[WorkerInfo]: ...
  40. def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
  41. def get_debug_info(self) -> Dict[str, str]: ...
  42. def get_metrics(self) -> Dict[str, str]: ...
  43. class PyRRef(Generic[_T]):
  44. def __init__(self, value: _T, type_hint: Any = None) -> None: ...
  45. def is_owner(self) -> bool: ...
  46. def confirmed_by_owner(self) -> bool: ...
  47. def owner(self) -> WorkerInfo: ...
  48. def owner_name(self) -> str: ...
  49. def to_here(self, timeout: float = ...) -> _T: ...
  50. def local_value(self) -> Any: ...
  51. def rpc_sync(self, timeout: float = ...) -> Any: ...
  52. def rpc_async(self, timeout: float = ...) -> Any: ...
  53. def remote(self, timeout: float = ...) -> Any: ...
  54. def _serialize(self) -> Tuple: ...
  55. @staticmethod
  56. def _deserialize(tp: Tuple) -> PyRRef: ...
  57. def _get_type(self) -> Type[_T]: ...
  58. def _get_future(self) -> Future[_T]: ...
  59. def _get_profiling_future(self) -> Future[_T]: ...
  60. def _set_profiling_future(self, profilingFuture: Future[_T]): ...
  61. class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
  62. num_worker_threads: int
  63. device_maps: Dict[str, Dict[torch.device, torch.device]]
  64. devices: List[torch.device]
  65. def __init__(
  66. self,
  67. num_worker_threads: int,
  68. _transports: Optional[List],
  69. _channels: Optional[List],
  70. rpc_timeout: float = ...,
  71. init_method: str = ...,
  72. device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
  73. devices: List[torch.device] = [], # noqa: B006
  74. ): ...
  75. def _set_device_map(
  76. self,
  77. to: str,
  78. device_map: Dict[torch.device, torch.device],
  79. ): ...
  80. class TensorPipeAgent(RpcAgent):
  81. def __init__(
  82. self,
  83. store: Store,
  84. name: str,
  85. worker_id: int,
  86. world_size: Optional[int],
  87. opts: _TensorPipeRpcBackendOptionsBase,
  88. reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
  89. devices: List[torch.device],
  90. ): ...
  91. def join(self, shutdown: bool = False, timeout: float = 0): ...
  92. def shutdown(self): ...
  93. @overload
  94. def get_worker_info(self) -> WorkerInfo: ...
  95. @overload
  96. def get_worker_info(self, workerName: str) -> WorkerInfo: ...
  97. @overload
  98. def get_worker_info(self, id: int) -> WorkerInfo: ...
  99. def get_worker_infos(self) -> List[WorkerInfo]: ...
  100. def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
  101. def _update_group_membership(
  102. self,
  103. worker_info: WorkerInfo,
  104. my_devices: List[torch.device],
  105. reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
  106. is_join: bool,
  107. ): ...
  108. def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
  109. @property
  110. def is_static_group(self) -> bool: ...
  111. @property
  112. def store(self) -> Store: ...
  113. def _is_current_rpc_agent_set() -> bool: ...
  114. def _get_current_rpc_agent() -> RpcAgent: ...
  115. def _set_and_start_rpc_agent(agent: RpcAgent): ...
  116. def _reset_current_rpc_agent(): ...
  117. def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
  118. def _destroy_rref_context(ignoreRRefLeak: bool): ...
  119. def _rref_context_get_debug_info() -> Dict[str, str]: ...
  120. def _cleanup_python_rpc_handler(): ...
  121. def _invoke_rpc_builtin(
  122. dst: WorkerInfo,
  123. opName: str,
  124. rpcTimeoutSeconds: float,
  125. *args: Any,
  126. **kwargs: Any,
  127. ): ...
  128. def _invoke_rpc_python_udf(
  129. dst: WorkerInfo,
  130. pickledPythonUDF: str,
  131. tensors: List[torch.Tensor],
  132. rpcTimeoutSeconds: float,
  133. isAsyncExecution: bool,
  134. ): ...
  135. def _invoke_rpc_torchscript(
  136. dstWorkerName: str,
  137. qualifiedNameStr: str,
  138. argsTuple: Tuple,
  139. kwargsDict: Dict,
  140. rpcTimeoutSeconds: float,
  141. isAsyncExecution: bool,
  142. ): ...
  143. def _invoke_remote_builtin(
  144. dst: WorkerInfo,
  145. opName: str,
  146. rpcTimeoutSeconds: float,
  147. *args: Any,
  148. **kwargs: Any,
  149. ): ...
  150. def _invoke_remote_python_udf(
  151. dst: WorkerInfo,
  152. pickledPythonUDF: str,
  153. tensors: List[torch.Tensor],
  154. rpcTimeoutSeconds: float,
  155. isAsyncExecution: bool,
  156. ): ...
  157. def _invoke_remote_torchscript(
  158. dstWorkerName: WorkerInfo,
  159. qualifiedNameStr: str,
  160. rpcTimeoutSeconds: float,
  161. isAsyncExecution: bool,
  162. *args: Any,
  163. **kwargs: Any,
  164. ): ...
  165. def get_rpc_timeout() -> float: ...
  166. def enable_gil_profiling(flag: bool): ...
  167. def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
  168. class RemoteProfilerManager:
  169. @staticmethod
  170. def set_current_profiling_key(key: str): ...
  171. def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
  172. def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
  173. def _set_profiler_node_id(default_node_id: int): ...
  174. def _enable_jit_rref_pickle(): ...
  175. def _disable_jit_rref_pickle(): ...