# mypy: allow-untyped-defs
import logging
import os
import threading
import warnings
from datetime import timedelta
from typing import Generator, Tuple
from urllib.parse import urlparse

import torch
import torch.distributed as dist


logger = logging.getLogger(__name__)

_init_counter = 0
_init_counter_lock = threading.Lock()

__all__ = ["is_available"]


def is_available() -> bool:
    return hasattr(torch._C, "_rpc_init")


if is_available() and not torch._C._rpc_init():
    raise RuntimeError("Failed to initialize torch.distributed.rpc")


if is_available():
    from torch._C._distributed_c10d import Store
    from torch._C._distributed_rpc import (
        _disable_jit_rref_pickle,
        _enable_jit_rref_pickle,
        _disable_server_process_global_profiler,
        _enable_server_process_global_profiler,
        _set_and_start_rpc_agent,
        _reset_current_rpc_agent,
        _delete_all_user_and_unforked_owner_rrefs,
        _destroy_rref_context,
        _set_profiler_node_id,
        _is_current_rpc_agent_set,
        _rref_context_get_debug_info,
        _cleanup_python_rpc_handler,
        _invoke_rpc_builtin,
        _invoke_rpc_python_udf,
        _invoke_rpc_torchscript,
        _invoke_remote_builtin,
        _invoke_remote_python_udf,
        _invoke_remote_torchscript,
        _set_rpc_timeout,
        _get_current_rpc_agent,
        get_rpc_timeout,
        enable_gil_profiling,
        RpcBackendOptions,
        _TensorPipeRpcBackendOptionsBase,
        RpcAgent,
        PyRRef,
        TensorPipeAgent,
        RemoteProfilerManager,
        WorkerInfo,
        _DEFAULT_INIT_METHOD,
        _DEFAULT_NUM_WORKER_THREADS,
        _UNSET_RPC_TIMEOUT,
        _DEFAULT_RPC_TIMEOUT_SEC,
    )  # noqa: F401

    from . import api, backend_registry, functions
    from .api import *  # noqa: F401,F403

    import numbers
    import torch.distributed.autograd as dist_autograd

    from .backend_registry import BackendType
    from .options import TensorPipeRpcBackendOptions  # noqa: F401
    from .server_process_global_profiler import (
        _server_process_global_profile,
    )
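
    # Declared at module scope so init_rpc() can keep the rendezvous generator
    # alive (via `global`) until all processes finish handshaking.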
    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]

    __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
    __all__ = __all__ + api.__all__ + backend_registry.__all__  # noqa: PLE0605

    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        r"""
        Initializes RPC primitives such as the local RPC agent
        and distributed autograd, which immediately makes the current
        process ready to send and receive RPCs.

        Args:
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                The name may only contain numbers, letters, underscores, colons,
                and/or dashes, and must be shorter than 128 characters.
            backend (BackendType, optional): The type of RPC backend
                implementation. The only supported value is
                ``BackendType.TENSORPIPE`` (the default).
                See :ref:`rpc-backends` for more information.
            rank (int): a globally unique id/rank of this node.
            world_size (int): The number of workers in the group.
            rpc_backend_options (RpcBackendOptions, optional): The options
                passed to the RpcAgent constructor. It must be an agent-specific
                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
                and contain agent-specific initialization configurations. By
                default, for all agents, it sets the default timeout to 60
                seconds and performs the rendezvous with an underlying process
                group initialized using ``init_method = "env://"``,
                meaning that the environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` need to be set properly. See
                :ref:`rpc-backends` for more information and to find which
                options are available.
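
        A minimal illustrative sketch of a two-worker setup follows; the worker
        names and ``world_size`` are examples, and ``MASTER_ADDR`` /
        ``MASTER_PORT`` must be set in the environment of both processes
        before calling ``init_rpc``.

        Example::

            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch.distributed.rpc as rpc
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()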
        """
        torch._C._log_api_usage_once("torch.distributed.init_rpc")
        if backend is not None and not isinstance(
            backend, backend_registry.BackendType
        ):
            raise TypeError("Argument backend must be a member of BackendType")

        if rpc_backend_options is not None and not isinstance(
            rpc_backend_options, RpcBackendOptions
        ):
            raise TypeError(
                "Argument rpc_backend_options must be an instance of RpcBackendOptions"
            )

        # Try to detect the backend from the options
        if backend is None and rpc_backend_options is not None:
            for candidate_backend in BackendType:
                if isinstance(
                    rpc_backend_options,
                    type(
                        backend_registry.construct_rpc_backend_options(
                            candidate_backend
                        )
                    ),
                ):
                    backend = candidate_backend
                    break
            else:
                raise TypeError(
                    f"Could not infer backend for options {rpc_backend_options}"
                )
            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                logger.warning(
                    "RPC was initialized with no explicit backend but with options "
                    "corresponding to %(backend)s, hence that backend will be used "
                    "instead of the default BackendType.TENSORPIPE. To silence this "
                    "warning pass `backend=%(backend)s` explicitly.",
                    {"backend": backend},
                )

        if backend is None:
            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

        if rpc_backend_options is None:
            # Default-construct a set of RPC backend options.
            rpc_backend_options = backend_registry.construct_rpc_backend_options(
                backend
            )

        # Create the store and perform rendezvous for a static RPC group.
        if not world_size:
            # If world_size is not set here and also not set via environment
            # variables, the store is created for the dynamic group setting.
            store = dist._create_store_from_options(rpc_backend_options, rank)
        else:
            # This rendezvous state sometimes is destroyed before all processes
            # finish handshaking. To avoid that issue, we make it global to
            # keep it alive.
            global rendezvous_iterator
            rendezvous_iterator = dist.rendezvous(
                rpc_backend_options.init_method, rank=rank, world_size=world_size
            )
            store, _, _ = next(rendezvous_iterator)
            # Use the same timeout as RPC.
            store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))

        # Use a PrefixStore to distinguish multiple invocations.
        with _init_counter_lock:
            global _init_counter
            store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store)
            _init_counter += 1

        # Initialize autograd before RPC since _init_rpc_backend guarantees all
        # processes sync via the store. If we initialized autograd after RPC,
        # there could be a race where some nodes have initialized autograd
        # and others have not. As a result, a node calling
        # torch.distributed.autograd.backward() would run into errors since
        # other nodes might not have been initialized.
        dist_autograd._init(rank)

        _set_profiler_node_id(rank)
        # Initialize RPC.
        _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)

    def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
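        # The mapping below uses each argument's value as its key and the
        # expected type (or tuple of types) as the value; isinstance() is
        # applied pairwise and a RuntimeError is raised on the first mismatch.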
        type_mapping = {
            backend: backend_registry.BackendType,
            store: dist.Store,
            name: str,
            rank: numbers.Integral,
            # world_size can be None for a dynamic group
            world_size: (numbers.Integral, type(None)),
            rpc_backend_options: RpcBackendOptions,
        }
        for arg, arg_type in type_mapping.items():
            if not isinstance(arg, arg_type):  # type: ignore[arg-type]
                raise RuntimeError(
                    f"Argument {arg} must be of type {arg_type} but got type {type(arg)}"
                )

    def _init_rpc_backend(
        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
        store=None,
        name=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

        if _is_current_rpc_agent_set():
            raise RuntimeError("RPC is already initialized")

        # Initialize RPC.
        rpc_agent = backend_registry.init_backend(
            backend,
            store=store,
            name=name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        api._init_rpc_states(rpc_agent)

    @api._require_initialized
    def _get_debug_info():
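        # Aggregate debug metrics from the RRef context, the currently set RPC
        # agent, and distributed autograd into a single dict.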
        info = _rref_context_get_debug_info()
        info.update(api._get_current_rpc_agent().get_debug_info())
        info.update(dist_autograd._get_debug_info())
        return info