join.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from abc import ABC, abstractmethod
  4. from types import TracebackType
  5. from typing import Any, List, NamedTuple, Optional, Type
  6. import torch
  7. import torch.distributed as dist
  8. __all__ = ['JoinHook', 'Joinable', 'Join']
  9. class JoinHook:
  10. r"""
  11. This defines a join hook, which provides two entry points in the join context manager.
  12. Entry points : a main hook, which is called repeatedly while there exists a non-joined
  13. process, and a post-hook, which is called once all processes have joined.
  14. To implement a join hook for the generic join context manager, define a
  15. class that inherits from :class:`JoinHook` and override ``main_hook()`` and
  16. ``post_hook()`` as appropriate.
  17. """
  18. def main_hook(self) -> None:
  19. r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
  20. Training iteration i.e., in one forward pass, backward pass, and optimizer step.
  21. """
  22. ...
  23. def post_hook(self, is_last_joiner: bool) -> None:
  24. r"""
  25. Call hook after all processes have joined.
  26. It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
  27. Arguments:
  28. is_last_joiner (bool): ``True`` if the rank is one of the last to
  29. join; ``False`` otherwise.
  30. """
  31. ...
  32. class Joinable(ABC):
  33. r"""
  34. This defines an abstract base class for joinable classes.
  35. A joinable class
  36. (inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
  37. which returns a :class:`JoinHook` instance, in addition to
  38. :meth:`join_device` and :meth:`join_process_group` that return device and
  39. process group information, respectively.
  40. """
  41. @abstractmethod
  42. def __init__(self):
  43. super().__init__()
  44. self._join_config = _JoinConfig.construct_disabled_join_config()
  45. @abstractmethod
  46. def join_hook(self, **kwargs) -> JoinHook:
  47. r"""
  48. Return a :class:`JoinHook` instance for the given :class:`Joinable`.
  49. Arguments:
  50. kwargs (dict): a :class:`dict` containing any keyword arguments
  51. to modify the behavior of the join hook at run time; all
  52. :class:`Joinable` instances sharing the same join context
  53. manager are forwarded the same value for ``kwargs``.
  54. """
  55. ...
  56. @property
  57. @abstractmethod
  58. def join_device(self) -> torch.device:
  59. r"""Return the device from which to perform collective communications needed by the join context manager."""
  60. ...
  61. @property
  62. @abstractmethod
  63. def join_process_group(self) -> Any:
  64. r"""Returns the process group for the collective communications needed by the join context manager itself."""
  65. ...
  66. class _JoinConfig(NamedTuple):
  67. r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
  68. enable: bool
  69. throw_on_early_termination: bool
  70. is_first_joinable: bool
  71. @staticmethod
  72. def construct_disabled_join_config():
  73. r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
  74. e.g. if the caller is not in a join context manager.
  75. """
  76. return _JoinConfig(
  77. enable=False,
  78. throw_on_early_termination=False,
  79. is_first_joinable=False
  80. )
class Join:
    r"""
    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

    These hooks should shadow the
    collective communications of non-joined processes to prevent hanging and
    erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
    for details about the hook definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own per-
        iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that all ``process_group`` attributes in
        the :class:`JoinHook` objects are the same. If there are multiple
        :class:`JoinHook` objects, then the ``device`` of the first is used.
        The process group and device information is used for checking for non-
        joined processes and for notifying processes to throw an exception if
        ``throw_on_early_termination`` is enabled, both of which using an all-
        reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.

        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).

        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> import torch.nn.parallel.DistributedDataParallel as DDP
        >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>> # All ranks reach here without hanging/erroring
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        # Every joinable's hook is constructed with the same ``kwargs``.
        self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable
            )
            # Only the first joinable performs the collective communications
            # in :meth:`notify_join_context`; the rest are vacuous there.
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError("Using join context manager with multiple process groups")
            if device is None:
                # The first joinable's device wins (see class docstring).
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        # Intentionally a no-op: all join logic runs in ``__exit__`` once
        # this rank falls out of the ``with`` block.
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType]
    ):
        r"""
        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True``.
        """
        # Skip the join procedure entirely if disabled or if an exception
        # is already propagating out of the ``with`` block.
        if not self._enable or type:
            return  # propagate the exception directly if one was raised
        all_procs_joined = False
        is_last_joiner = True
        i = 0
        WARN_THRESHOLD = 1000
        # NOTE(review): this mutates the process-global warnings filter, not
        # just this context manager's scope -- confirm that is intended.
        warnings.simplefilter("once")
        while not all_procs_joined:
            if i > WARN_THRESHOLD:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    f"fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    # Raises ``RuntimeError`` on this rank; non-joined ranks
                    # raise in :meth:`notify_join_context`.
                    self._notify_procs_to_terminate()
                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()
                # Some other rank is still active, so this rank is not
                # among the last to join.
                is_last_joiner = False
                i += 1
        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
        # This rank contributes 0; each non-joined rank contributes 1 via the
        # matching all-reduce scheduled in :meth:`notify_join_context`, so the
        # reduced value counts the non-joined processes.
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
        """
        # Contributing 1 here makes the second all-reduce in
        # :meth:`notify_join_context` see a nonzero value on non-joined ranks,
        # which then raise as well.
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notifies the join context manager that the calling process has not yet joined.

        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
        (i.e. if one process has already joined) and throws an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.
        """
        assert hasattr(joinable, "_join_config"), \
            f"Check that the {type(joinable)} constructor calls the " \
            "``Joinable`` constructor"
        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None
        device = joinable.join_device
        process_group = joinable.join_process_group
        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)
        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work