# mypy: allow-untyped-defs
from typing import Any, List, Tuple

import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)

__all__ = []  # type: ignore[var-annotated]

def _get_submodule_n_params(module: nn.Module, path: str):
    """
    Get the parent submodule and the leaf attribute name for a dotted parameter path.
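
    Example (an illustrative sketch; the module and names below are made up)::

        >>> # xdoctest: +SKIP("illustrative")
        >>> m = nn.Sequential(nn.Linear(4, 4))
        >>> parent, name = _get_submodule_n_params(m, "0.weight")
        >>> parent is m[0], name
        (True, 'weight')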
- """
- if "." in path:
- path_list = path.split(".")
- parent_module_path = ".".join(path_list[:-1])
- module = module.get_submodule(parent_module_path)
- path = path_list[-1]
- return module, path
def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
    """
    Replace each listed parameter attribute on its parent module with the given tensor.
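
    Example (an illustrative sketch; the layer and tensor below are made up)::

        >>> # xdoctest: +SKIP("illustrative")
        >>> lin = nn.Linear(4, 4)
        >>> new_weight = nn.Parameter(lin.weight.detach().clone())
        >>> _update_module_param([(lin, "weight", new_weight)])
        >>> lin.weight is new_weight
        True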
- """
- for item in param_list:
- parent_module, module_path, t = item
- assert hasattr(parent_module, module_path)
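        # Delete the old attribute first so that setattr re-registers the new
        # value cleanly, whether it is an nn.Parameter or a plain tensor.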
        delattr(parent_module, module_path)
        setattr(parent_module, module_path, t)

def _reconstruct_dtensor(module: nn.Module, _input: Any):
    """
    Reconstruct DTensor parameters from local tensors.
    """
    param_list = []
    # TODO: Add performance optimizations to this iteration.
    for name, t in module.named_parameters():
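        # `_st_info` is the sharding info stashed by `_localize_dtensor`; its
        # presence marks a local tensor that was flattened from a DTensor.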
        if hasattr(t, "_st_info"):
            dtensor = _unflatten_tensor(t, t._st_info)
            param_list.append((*_get_submodule_n_params(module, name), dtensor))
    _update_module_param(param_list)  # type: ignore[arg-type]

def _localize_dtensor(module: nn.Module, *_: Any):
    """
    Convert DTensor parameters to local tensors.
    """
    param_list = []
    for name, param in module.named_parameters():
        t, sharding_info = _flatten_tensor(param)
        if sharding_info is not None:
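            # Wrap the local shard as an nn.Parameter and stash the sharding
            # info so that `_reconstruct_dtensor` can rebuild the DTensor later.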
            t = nn.Parameter(t)
            t._st_info = sharding_info  # type: ignore[attr-defined]
            param_list.append((*_get_submodule_n_params(module, name), t))
    _update_module_param(param_list)  # type: ignore[arg-type]

def _pre_dp_module_transform(module: nn.Module):
    """
    Enable composability between Tensor Parallelism (TP) and Data Parallelism
    (DP) in PyTorch when using DDP. We need to convert parameters which are
    DTensors to local tensors before wrapping them with the data parallelism
    API. We then register two hooks: one that converts local tensors back to
    DTensors before the forward pass, and one that converts DTensors back to
    local tensors after the forward pass. By integrating this way, we avoid any
    special handling of DTensor parameters by DDP and get DTensor's gradients
    propagated back to DP, e.g., the gradient buckets of DDP.

    For now, this API only works with ``DistributedDataParallel``. It will
    later support other DP methods such as FSDP.

    Args:
        module (:class:`nn.Module`):
            Module on which TP has been applied.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
        >>>
        >>> # Define the module.
        >>> m = module(...)
        >>> parallelize_module(m, PairwiseParallel())
        >>> _pre_dp_module_transform(m)
        >>> m = DDP(m)
        >>>
    """
    _localize_dtensor(module, None, None)
    # TODO: Add test cases and ensure that this works for nested modules.
    module.register_forward_pre_hook(_reconstruct_dtensor)
    module.register_forward_hook(_localize_dtensor)