- """
- This file includes public APIs for FSDP such as the classes used for the
- constructor arguments.
- """
- from dataclasses import dataclass
- from enum import auto, Enum
- from typing import Optional, Sequence, Type
- import torch
- from torch.nn.modules.batchnorm import _BatchNorm
- __all__ = [
- "ShardingStrategy",
- "BackwardPrefetch",
- "MixedPrecision",
- "CPUOffload",
- "StateDictType",
- "StateDictConfig",
- "FullStateDictConfig",
- "LocalStateDictConfig",
- "ShardedStateDictConfig",
- "OptimStateDictConfig",
- "FullOptimStateDictConfig",
- "LocalOptimStateDictConfig",
- "ShardedOptimStateDictConfig",
- "StateDictSettings",
- ]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
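
    Example (a minimal sketch; ``MyModel`` is hypothetical, and the process
    group is assumed to be initialized elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model = FSDP(
        >>>     MyModel().cuda(),
        >>>     sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        >>> )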
- """
- FULL_SHARD = auto()
- SHARD_GRAD_OP = auto()
- NO_SHARD = auto()
- HYBRID_SHARD = auto()
- _HYBRID_SHARD_ZERO2 = auto()


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory.
      This prefetches the next set of parameters *after* the current set of
      parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using the NCCL
    backend, any collectives, even if issued from different streams, contend
    for the same per-device NCCL stream, which implies that the relative order
    in which the collectives are issued matters for overlapping. The two
    backward prefetching values correspond to different issue orders.
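
    Example (a minimal sketch; ``MyModel`` is hypothetical, and the
    distributed setup is assumed to exist elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import BackwardPrefetch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model = FSDP(
        >>>     MyModel().cuda(),
        >>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        >>> )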
- """
- # NOTE: For both modes, the ordering that defines "current" and "next" is
- # not always exact in the current implementation. A mistargeted prefetch
- # simply means that the parameter memory is allocated earlier than needed,
- # possibly increasing peak memory usage, but does not affect correctness.
- BACKWARD_PRE = auto()
- BACKWARD_POST = auto()


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> model[1] = FSDP(
        >>>     model[1],
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        >>> )
        >>> model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        >>> )

    The above shows a working example. On the other hand, if ``model[1]``
    were replaced with ``model[0]``, meaning that the submodule using a
    different ``MixedPrecision`` configuration ran its forward first, then
    ``model[1]`` would incorrectly see ``float16`` activations instead of
    ``bfloat16`` ones.
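
    For the typical case of a single, uniform low-precision configuration, a
    minimal sketch (``MyModel`` and ``my_auto_wrap_policy`` are hypothetical
    and assumed to be defined elsewhere) looks like::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> bf16 = MixedPrecision(param_dtype=torch.bfloat16)
        >>> model = FSDP(
        >>>     MyModel().cuda(),
        >>>     auto_wrap_policy=my_auto_wrap_policy,
        >>>     mixed_precision=bf16,
        >>> )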
- """
- param_dtype: Optional[torch.dtype] = None
- reduce_dtype: Optional[torch.dtype] = None
- buffer_dtype: Optional[torch.dtype] = None
- keep_low_precision_grads: bool = False
- cast_forward_inputs: bool = False
- cast_root_forward_inputs: bool = True
- _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
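
    Example (a minimal sketch; ``MyModel`` is hypothetical, and the
    distributed setup is assumed to exist elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import CPUOffload
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model = FSDP(
        >>>     MyModel().cuda(),
        >>>     cpu_offload=CPUOffload(offload_params=True),
        >>> )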
- """
- offload_params: bool = False


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch
    convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict``/``load_state_dict``: this pair of APIs return and
           load the non-sharded, unflattened parameters. The semantics are the
           same as using DDP.
        2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs
           return and load local sharded, flattened parameters. The values
           returned by ``_local_state_dict`` can be directly used by FSDP and
           are only meaningful to FSDP (because parameters are flattened). Note
           that these APIs are meant for use via the :func:`state_dict_type`
           context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # gets the local state dict
        3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of
           APIs return and load sharded, unflattened parameters. The
           ``state_dict`` returned by ``sharded_state_dict`` can be used by all
           other parallel schemes (resharding may be required).
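
    Example (a minimal sketch; ``fsdp`` is an already-constructed FSDP-wrapped
    module, and the distributed setup is assumed to exist elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     fsdp,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     sharded_state = fsdp.state_dict()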
- """
- FULL_STATE_DICT = auto()
- LOCAL_STATE_DICT = auto()
- SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize the FSDP module as usual. The `sync_module_states`
        >>> # argument communicates the loaded checkpoint states from rank 0 to the
        >>> # rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have the FSDP model with the loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedStateDictConfig` and it is used by FSDP to determine the
        type of state dict values. Users should not manually modify
        ``_use_dtensor``.
    """

    _use_dtensor: bool = False


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    ``FullOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT`` for the optimizer state dict.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
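
    Example (a minimal sketch; ``model`` is an FSDP-wrapped module and
    ``optim`` is its optimizer, both assumed to be constructed elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>>     FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>> )
        >>> optim_state = FSDP.optim_state_dict(model, optim)
        >>> # On rank 0, `optim_state` holds the full optimizer state on CPU;
        >>> # nonzero ranks receive an empty dict.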
- """
- rank0_only: bool = False


@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedOptimStateDictConfig` and it is used by FSDP to
        determine the type of state dict values. Users should not manually
        modify ``_use_dtensor``.
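
    Example (a minimal sketch; ``model`` is an FSDP-wrapped module and
    ``optim`` is its optimizer, both assumed to be constructed elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>>     ShardedOptimStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     model_state = model.state_dict()
        >>>     optim_state = FSDP.optim_state_dict(model, optim)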
- """
- _use_dtensor: bool = False


@dataclass
class StateDictSettings:
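    """
    ``StateDictSettings`` bundles a ``state_dict_type`` together with its
    corresponding model and optimizer state dict configs.

    A minimal sketch for reading back the currently active settings
    (``model`` is assumed to be an FSDP-wrapped module)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> settings = FSDP.get_state_dict_type(model)
        >>> settings.state_dict_type, settings.state_dict_config
    """
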
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig