# fully_shard.py

from typing import Callable, Iterable, Optional, Union

from typing_extensions import deprecated

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hook,
    _register_pre_forward_hook,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy
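
# NOTE: As the deprecation warning below states, this composable `fully_shard`
# API is deprecated in favor of the wrapper based FSDP. A minimal sketch of the
# wrapper alternative (illustrative only; `model` and `MyBlock` are stand-ins,
# and a default process group is assumed to be initialized):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
#
#     sharded_model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({MyBlock}))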


@contract(state_cls=_FSDPState)
@deprecated(
    "`torch.distributed._composable.fully_shard` is being deprecated. "
    "You can continue to use the wrapper based FSDP. "
    "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
    "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.",
    category=FutureWarning,
)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``."""
    torch._C._log_api_usage_once("torch.distributed.fully_shard")
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _Policy):
        raise ValueError(f"Expects a `_Policy` but got {policy}")
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
    state = _init_process_group_state(state, process_group, strategy, policy)
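    # When an auto wrap policy is given, `_auto_wrap` applies `fully_shard` to
    # the submodules selected by the policy, forwarding the top-level kwargs.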
    if policy is not None:
        root_kwargs = {
            "process_group": process_group,
            "strategy": strategy,
            "mixed_precision": mixed_precision,
            "cpu_offload": cpu_offload,
            "ignored_modules": ignored_modules,
            "device_id": device_id,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "ignored_states": ignored_states,
        }
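        # Hybrid strategies shard within `state.process_group` and replicate
        # across `state._inter_node_pg`, so pass both groups down.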
        if strategy in HYBRID_SHARDING_STRATEGIES:
            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )
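    # The composable path always uses the original parameters and rate-limits
    # all-gathers; the sharding strategy defaults to FULL_SHARD.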
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
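    # Backward prefetching is fixed to BACKWARD_PRE; forward prefetching is
    # opt-in via `forward_prefetch`.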
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)
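    # Register the state_dict hooks and the forward hooks that unshard
    # parameters before each wrapped module's forward and reshard them after.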
    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
    _insert_module_state(module, state)
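    # Also point every submodule managed by this state (and without a state of
    # its own) back to the same `_FSDPState`.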
    for submodule in module.modules():
        if (
            submodule in state._fully_sharded_module_to_handle
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module