# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract

_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    def __init__(self) -> None:
        super().__init__()
        # Placeholder module; replaced with the real module in init().
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
        self._orig_module = self.module
        self._param_names: List[str] = []
        self._no_sync: bool = False
        self._init_args: Optional[Tuple[Any, ...]] = None
        self._init_kwargs: Dict[str, Any] = {}
        self._comm_hook_args: List[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        # Skip modules already managed by the fully_shard API.
        if _is_fully_sharded(module):
            return

        # If a module is ignored, all of its descendants are ignored as well.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )
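
    # Illustrative note (hypothetical example, not from the original source):
    # for a root module with a single child ``fc = nn.Linear(4, 4)`` and no
    # ignored modules, the collected names would be ``["fc.weight", "fc.bias"]``,
    # mirroring ``module.named_parameters()`` minus any ignored entries.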

    def lazy_init(self) -> None:
        @torch._disable_dynamo(recursive=True)
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            self._init_args = tuple()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True

        device_mesh = kwargs.get("device_mesh", None)
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement: the user can
            # pass device_id as a Union[int, torch.device] even for CPU devices,
            # so the same code works for CPU and GPU runs. We derive the right
            # device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to the device_ids format that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids=None.
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device].
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # The weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
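
    # Illustrative sketch of the device_id normalization above (the call sites
    # shown are hypothetical): ``replicate(module, device_id=torch.device("cpu"))``
    # ends up constructing ``DistributedDataParallel(..., device_ids=None)``,
    # while ``replicate(module, device_id=0)`` ends up constructing
    # ``DistributedDataParallel(..., device_ids=[0])``.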

    def register_comm_hook(self) -> None:
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "DDP does not support deepcopy. Please use state dict for serialization."
    )


# Follow the same pattern as FSDP/fully_shard.
class DDP:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since index 0 is the dynamically constructed `DDP<...>`
        # class and index 1 is the `DDP` class itself.
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)
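
    # Illustrative sketch (hypothetical example, not part of the original file):
    # after ``replicate(nn.Sequential(...))`` the instance's class becomes a
    # dynamically created ``DDPSequential`` whose MRO is roughly
    # ``(DDPSequential, DDP, Sequential, Module, object)``, so ``__mro__[2]``
    # recovers the original ``Sequential`` class when a new instance is built,
    # e.g. when slicing the container.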

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets whether the module should sync gradients. This can be used to
        implement gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        replicate.state(self)._no_sync = not requires_gradient_sync

    def register_comm_hook(self, *args, **kwargs) -> None:
        replicate.state(self)._comm_hook_args.append((args, kwargs))
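
# Hedged usage sketch for gradient accumulation via set_requires_gradient_sync
# (assumptions: a default process group is already initialized and the model has
# been wrapped with replicate(); MyModel, loader, loss_fn, optimizer, and
# accum_steps are hypothetical names, not part of this file):
#
#     model = replicate(MyModel().cuda())
#     for step, (inp, target) in enumerate(loader):
#         # Skip the gradient all-reduce on all but every accum_steps-th step.
#         model.set_requires_gradient_sync((step + 1) % accum_steps == 0)
#         loss = loss_fn(model(inp), target)
#         loss.backward()
#         if (step + 1) % accum_steps == 0:
#             optimizer.step()
#             optimizer.zero_grad()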


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module.

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if _is_fully_sharded(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    if ignored_modules is None:
        ignored_modules = {}
    else:
        ignored_modules = set(ignored_modules)

    state = cast(_ReplicateState, replicate.state(module))
    module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)

    device_mesh = kwargs.get("device_mesh", None)
    if device_mesh is not None:
        from torch.distributed.device_mesh import _mesh_resources

        if _mesh_resources.get_parent_mesh(device_mesh) is not None:
            # TODO: This is a temporary workaround to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in the DDP class, as the
            # module that replicate is going to pass in is NOT the original
            # module.
            from torch.distributed.tensor.parallel.ddp import (
                _localize_dtensor,
                _reconstruct_dtensor,
            )

            module.register_forward_pre_hook(_reconstruct_dtensor)
            module.register_forward_hook(_localize_dtensor)

    module.register_forward_hook(state.forward_post_hook)  # type: ignore[arg-type]

    state.record_init_args(module, ignored_modules, **kwargs)

    # Place DDP leftmost for the highest priority in the method resolution order.
    cls = module.__class__
    dct = {"__deepcopy__": unimplemented_deepcopy}
    new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
    module.__class__ = new_cls
    return module
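
# Hedged usage sketch for replicate() itself (assumptions: torch.distributed is
# initialized, e.g. via init_process_group, and CUDA is available; Net, rank,
# frozen_head, and batch are hypothetical names, not part of this file):
#
#     net = Net().cuda(rank)
#     replicate(
#         net,
#         ignored_modules=[net.frozen_head],  # parameters excluded from DDP
#         device_id=rank,                     # normalized to device_ids=[rank]
#     )
#     out = net(batch)  # forward/backward then behave like DDP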


def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    if registry is None:
        return False
    return "fully_shard" in registry
|