replicate.py

# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract

_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    def __init__(self) -> None:
        super().__init__()
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
        self._orig_module = self.module
        self._param_names: List[str] = []
        self._no_sync: bool = False
        self._init_args: Optional[Tuple[Any, ...]] = None
        self._init_kwargs: Dict[str, Any] = {}
        self._comm_hook_args: List[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        # Skip if managed by the fully_shard API.
        if _is_fully_sharded(module):
            return

        # If a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def lazy_init(self) -> None:
        @torch._disable_dynamo(recursive=True)
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            self._init_args = tuple()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True

        device_mesh = kwargs.get("device_mesh", None)
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to the device_ids that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids=None.
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device].
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def register_comm_hook(self) -> None:
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "DDP does not support deepcopy. Please use state dict for serialization."
    )


# Follow the same pattern as FSDP/fully_shard.
class DDP:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
        # and index 1 is the `DDP` class itself.
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)
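
    # Illustration (assumed example, not executed here): after
    # `replicate(nn.Linear(3, 3))`, the wrapped module's class is built as
    # `type("DDPLinear", (DDP, nn.Linear), ...)`, so its MRO is roughly
    # [DDPLinear, DDP, nn.Linear, nn.Module, object] and `cls.__mro__[2]`
    # recovers the original `nn.Linear` class.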

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        replicate.state(self)._no_sync = not requires_gradient_sync
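
    # A minimal usage sketch (illustration only, not part of this module):
    # gradient accumulation with `set_requires_gradient_sync`. The names
    # `MyModel`, `loss_fn`, `optimizer`, `loader`, and `accum_steps` are
    # assumed to be defined by the caller.
    #
    #   model = replicate(MyModel().cuda())
    #   for step, (inp, target) in enumerate(loader):
    #       # Only all-reduce gradients on the last micro-batch of each group.
    #       is_last = (step + 1) % accum_steps == 0
    #       model.set_requires_gradient_sync(is_last)
    #       loss_fn(model(inp), target).backward()
    #       if is_last:
    #           optimizer.step()
    #           optimizer.zero_grad()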

    def register_comm_hook(self, *args, **kwargs) -> None:
        replicate.state(self)._comm_hook_args.append((args, kwargs))
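
    # A minimal usage sketch (illustration only): queueing a built-in DDP comm
    # hook, assuming the default process group is already initialized and
    # `MyModel` is caller-defined. The recorded (args, kwargs) are forwarded
    # verbatim to `DistributedDataParallel.register_comm_hook()` at lazy init.
    #
    #   from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
    #
    #   model = replicate(MyModel().cuda())
    #   model.register_comm_hook(None, default_hooks.fp16_compress_hook)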


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module.

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if _is_fully_sharded(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    if ignored_modules is None:
        ignored_modules = {}
    else:
        ignored_modules = set(ignored_modules)

    state = cast(_ReplicateState, replicate.state(module))
    module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)

    device_mesh = kwargs.get("device_mesh", None)
    if device_mesh is not None:
        from torch.distributed.device_mesh import _mesh_resources

        if _mesh_resources.get_parent_mesh(device_mesh) is not None:
            # TODO: This is a temporary workaround to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in the DDP class, as the
            # module that replicate passes to DDP is NOT the original module.
            from torch.distributed.tensor.parallel.ddp import (
                _localize_dtensor,
                _reconstruct_dtensor,
            )

            module.register_forward_pre_hook(_reconstruct_dtensor)
            module.register_forward_hook(_localize_dtensor)

    module.register_forward_hook(state.forward_post_hook)  # type: ignore[arg-type]

    state.record_init_args(module, ignored_modules, **kwargs)

    # Place DDP leftmost for highest priority in the method resolution order.
    cls = module.__class__
    dct = {"__deepcopy__": unimplemented_deepcopy}
    new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
    module.__class__ = new_cls
    return module


def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    if registry is None:
        return False
    return "fully_shard" in registry
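

# A minimal end-to-end sketch (illustration only, not executed by this module):
# using `replicate()` in a script launched with torchrun, assuming one GPU per
# rank. The `nccl` backend and the CUDA calls are assumptions for a GPU setup;
# a CPU run would use `gloo` and omit `device_id`.
#
#   import torch.distributed as dist
#
#   dist.init_process_group("nccl")
#   rank = dist.get_rank()
#   torch.cuda.set_device(rank)
#   model = replicate(nn.Linear(3, 3).cuda(), device_id=rank)
#   out = model(torch.randn(8, 3, device="cuda"))
#   out.sum().backward()  # gradients are all-reduced across ranks in backward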