ddp.py

# mypy: allow-untyped-defs
from typing import Any, List, Tuple

import torch.nn as nn

from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)

__all__ = []  # type: ignore[var-annotated]


def _get_submodule_n_params(module: nn.Module, path: str):
    """
    Get the submodule that directly owns the parameter and the parameter's local name.
    """
    if "." in path:
        path_list = path.split(".")
        parent_module_path = ".".join(path_list[:-1])
        module = module.get_submodule(parent_module_path)
        path = path_list[-1]
    return module, path
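

# A minimal sketch (hypothetical demo helper, not part of the upstream file):
# illustrates how a dotted parameter path such as "0.weight" resolves to its
# parent submodule plus leaf attribute name, while a dot-free path refers to
# a parameter owned directly by the module passed in.
def _demo_get_submodule_n_params():
    m = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    parent, leaf = _get_submodule_n_params(m, "0.weight")
    assert parent is m[0] and leaf == "weight"
    parent, leaf = _get_submodule_n_params(m[0], "bias")
    assert parent is m[0] and leaf == "bias"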


def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
    """
    Update parameters within the module
    """
    for item in param_list:
        parent_module, module_path, t = item
        assert hasattr(parent_module, module_path)
        delattr(parent_module, module_path)
        setattr(parent_module, module_path, t)
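

# A minimal sketch (hypothetical demo helper, not part of the upstream file):
# shows the delattr/setattr swap performed by _update_module_param, replacing
# a registered parameter with a new tensor object on its parent module.
def _demo_update_module_param():
    lin = nn.Linear(2, 2)
    new_weight = nn.Parameter(lin.weight.detach().clone())
    _update_module_param([(lin, "weight", new_weight)])
    assert lin.weight is new_weight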


def _reconstruct_dtensor(module: nn.Module, _input: Any):
    """
    Reconstruct DTensor parameters from local tensors
    """
    param_list = []
    # TODO: add perf optimizations to this iteration
    for name, t in module.named_parameters():
        if hasattr(t, "_st_info"):
            dtensor = _unflatten_tensor(t, t._st_info)
            param_list.append((*_get_submodule_n_params(module, name), dtensor))
    _update_module_param(param_list)  # type: ignore[arg-type]


def _localize_dtensor(module: nn.Module, *_: Any):
    """
    Convert DTensor parameters to local tensors
    """
    param_list = []
    for name, param in module.named_parameters():
        t, sharding_info = _flatten_tensor(param)
        if sharding_info is not None:
            t = nn.Parameter(t)
            t._st_info = sharding_info  # type: ignore[attr-defined]
            param_list.append((*_get_submodule_n_params(module, name), t))
    _update_module_param(param_list)  # type: ignore[arg-type]
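

# A minimal round-trip sketch (hypothetical demo helper, not part of the
# upstream file). Assumptions: a multi-process launch (e.g. via torchrun) and
# a PyTorch build exposing distribute_module under torch.distributed.tensor.
# It shows that _localize_dtensor swaps DTensor parameters for local tensors
# carrying their sharding info in `_st_info`, and that _reconstruct_dtensor
# restores the DTensors, mirroring the post-/pre-forward hooks registered by
# _pre_dp_module_transform below.
def _demo_localize_roundtrip():
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, distribute_module

    if not dist.is_initialized():
        dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    lin = distribute_module(nn.Linear(4, 4), mesh)  # params become DTensors

    _localize_dtensor(lin)                      # what the post-forward hook does
    assert not isinstance(lin.weight, DTensor)  # now a plain local tensor ...
    assert hasattr(lin.weight, "_st_info")      # ... carrying its sharding info

    _reconstruct_dtensor(lin, ())               # what the pre-forward hook does
    assert isinstance(lin.weight, DTensor)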


def _pre_dp_module_transform(module: nn.Module):
    """
    Enable the composability between Tensor Parallelism (TP) and Data
    Parallelism (DP) in PyTorch when using DDP. We need to convert Parameters which
    are DTensors to local tensors before wrapping with the data parallelism API.
    We then register two hooks, one for converting local tensors back to DTensor
    pre-forward and one to convert DTensors back to local tensors after forward. By
    integrating this way, we avoid any special handling of DTensor parameters by DDP
    and get DTensor's gradients propagated back to DP, e.g., into DDP's gradient buckets.

    For now, this API only works with ``DistributedDataParallel``. It will later support
    other DP methods such as FSDP.

    Args:
        module (:class:`nn.Module`):
            Module on which TP has been applied.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
        >>>
        >>> # Define the module.
        >>> m = module(...)
        >>> parallelize_module(m, PairwiseParallel())
        >>> pre_dp_module_transform(m)
        >>> m = DDP(m)
        >>>
    """

    _localize_dtensor(module, None, None)
    # TODO: add test cases and ensure that it works for nested modules
    module.register_forward_pre_hook(_reconstruct_dtensor)
    module.register_forward_hook(_localize_dtensor)
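

# A minimal end-to-end sketch (hypothetical demo, not part of the upstream
# file). Assumptions: a 4-process launch (e.g. via torchrun), CPU tensors with
# the gloo backend for simplicity, and a PyTorch build providing
# init_device_mesh, ColwiseParallel, and RowwiseParallel; the model and the
# 2 x 2 mesh layout are illustrative only. It follows the pattern described
# above: tensor-parallelize first, call _pre_dp_module_transform, then wrap
# with DDP over the data-parallel mesh dimension.
def _demo_tp_then_ddp():
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )
    from torch.nn.parallel import DistributedDataParallel as DDP

    if not dist.is_initialized():
        dist.init_process_group("gloo")
    # world_size == dp_size * tp_size, here 4 ranks laid out as 2 x 2.
    mesh_2d = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    # Shard the two Linear layers over the "tp" mesh dimension.
    parallelize_module(
        model, mesh_2d["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()}
    )

    # Localize DTensor parameters and register the pre-/post-forward hooks,
    # then replicate across the "dp" mesh dimension with DDP.
    _pre_dp_module_transform(model)
    model = DDP(model, process_group=mesh_2d["dp"].get_group())

    out = model(torch.randn(8, 16))
    out.sum().backward()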