| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607 |
- # mypy: allow-untyped-defs
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the BSD license found in the
- # LICENSE file in the root directory of this source tree.
- import contextlib
- import copy
- from abc import ABC, abstractmethod
- from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Generator,
- Iterable,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
- )
- import torch.nn as nn
# Names exported by ``from <this module> import *``: the function-style auto
# wrap policies, the ``enable_wrap``/``wrap`` helpers, and the policy classes.
__all__ = [
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    "enable_wrap",
    "wrap",
    "CustomPolicy",
    "ModuleWrapPolicy",
]
- # NOTE: We intentionally keep this function simple and isolate the complexity
- # to `fn` to enable using this function generically. We may move this to a
- # non-FSDP-specific folder and/or make it public in the future.
- def _post_order_apply(
- root_module: nn.Module,
- fn: Callable[[nn.Module], Optional[nn.Module]],
- ):
- """
- This applies ``fn`` to every module in the module tree of ``root_module``
- following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
- then this replaces the original module with the newly returned one in the
- tree. Otherwise, ``fn`` should return ``None``, in which case the module is
- not changed.
- """
- # Track visited modules to avoid visiting shared modules multiple times
- visited_modules: Set[nn.Module] = {root_module}
- def _post_order_apply_inner(
- module: nn.Module,
- module_name: str,
- parent_module: Optional[nn.Module],
- ):
- for child_module_name, child_module in module.named_children():
- if child_module not in visited_modules:
- visited_modules.add(child_module)
- _post_order_apply_inner(child_module, child_module_name, module)
- optional_module = fn(module)
- if optional_module is not None:
- assert isinstance(parent_module, nn.Module), (
- "Non-root modules should have their parent module set but got "
- f"{parent_module} for {module}"
- )
- assert module_name, (
- "Non-root modules should have their module name set but got "
- f"an empty module name for {module}"
- )
- assert isinstance(
- optional_module, nn.Module
- ), f"fn should return None or an nn.Module but got {optional_module}"
- setattr(parent_module, module_name, optional_module)
- _post_order_apply_inner(root_module, "", None)
- def _construct_wrap_fn(
- root_module: nn.Module,
- target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
- fsdp_fn: Callable,
- ) -> Callable[[nn.Module], Optional[nn.Module]]:
- """
- This constructs the "wrap" function to pass to :func:`_post_order_apply`
- based on ``target_module_to_kwargs``, which should be constructed from the
- wrapping policy.
- """
- def fn(module: nn.Module) -> Optional[nn.Module]:
- # Explicitly avoid wrapping the root module since for FSDP, it is
- # handled by the caller
- if module in target_module_to_kwargs and module is not root_module:
- kwargs = target_module_to_kwargs[module]
- return fsdp_fn(module, **kwargs)
- return None
- return fn
- def _run_mixed_precision_override_policy(
- root_module: nn.Module,
- module_classes: Iterable[Type[nn.Module]],
- ignored_modules: Set[nn.Module],
- root_kwargs: Dict[str, Any],
- target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
- ):
- module_classes_tuple = tuple(set(module_classes))
- for module in root_module.modules():
- if module in ignored_modules:
- continue
- elif isinstance(module, module_classes_tuple):
- # This policy overrides any existing policy
- if module not in target_module_to_kwargs:
- # Only inherit from the root kwargs if not already specified
- target_module_to_kwargs[module] = root_kwargs
- target_module_to_kwargs[module]["mixed_precision"] = None
- return target_module_to_kwargs
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Wrap policy that unconditionally answers ``True``.

    All positional and keyword arguments are accepted and ignored, so when
    this is used as the policy for :func:`_recursive_wrap`, every submodule
    is recursed into and wrapped.
    """
    return True
- class _Policy(ABC):
- """
- This defines an abstract base class that represents a policy for applying
- a module-level API.
- """
- @abstractmethod
- def _run_policy(
- self,
- root_module: nn.Module,
- ignored_modules: Set[nn.Module],
- root_kwargs: Dict[str, Any],
- ) -> Dict[nn.Module, Dict[str, Any]]:
- """
- This should return a dict ``target_module_to_kwargs`` that maps from
- each target module to wrap to its kwargs.
- """
- ...
- def _module_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- module_classes: Set[Type[nn.Module]],
- ) -> bool:
- """
- This auto wrap policy wraps every module that is an instance of any type in
- ``module_classes`` as its own FSDP instance. The root module given by
- ``module`` is always wrapped as an FSDP instance regardless. Since the
- wrapping proceeds bottom up, each FSDP instance manages the parameters in
- its subtree excluding any already managed by a child FSDP instance.
- Args:
- module (nn.Module): Current module being considered.
- recurse (bool): If ``False``, then this function must decide whether
- ``module`` should be wrapped as an FSDP instance or not. If
- ``True``, then the function is still recursing down the module
- tree as a part of the DFS.
- nonwrapped_numel (int): Parameter numel not yet wrapped.
- module_classes (Set[Type[nn.Module]]): Set of module classes that are
- wrapped as FSDP instances.
- Returns:
- ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
- if ``recurse=False``.
- """
- if recurse:
- return True # always recurse
- return isinstance(module, tuple(module_classes))
class ModuleWrapPolicy(_Policy):
    """
    Policy that targets every module that is an instance of one of the given
    module classes, giving each a copy of the kwargs passed to the root.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        # Deduplicate the classes once and cache their string form for
        # ``__repr__``
        self._module_classes = set(module_classes)
        self._module_classes_str = str(self._module_classes)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Map each non-ignored module of a target class to its own shallow copy
        of ``root_kwargs``.
        """
        target_classes = tuple(self._module_classes)
        # One shallow copy per module so that later changes to one module's
        # kwargs are not coupled to the others
        return {
            submodule: copy.copy(root_kwargs)
            for submodule in root_module.modules()
            if submodule not in ignored_modules
            and isinstance(submodule, target_classes)
        }

    def __call__(self, module, recurse, *args, **kwargs):
        """
        Callable-policy form: delegates to :func:`_module_wrap_policy`. Extra
        positional and keyword arguments are ignored.
        """
        # nonwrapped_numel is not used, so pass a placeholder.
        return _module_wrap_policy(module, recurse, -1, self._module_classes)

    def __repr__(self) -> str:
        default_repr = super().__repr__()
        return f"{default_repr}({self._module_classes_str})"
class CustomPolicy(_Policy):
    """
    Policy driven by a user function that maps each ``nn.Module`` to
    ``False``, ``True``, or a dictionary of kwargs.

    - ``False`` or an empty dictionary: the API is not applied to the module.
    - ``True``: the API is applied to the module with the root's kwargs.
    - A non-empty dictionary: the API is applied to the module, with the
      dictionary's entries overriding the root's kwargs.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Evaluate the user function on every non-ignored module and collect the
        kwargs for each selected module.

        Raises:
            ValueError: If the user function returns something other than a
                ``bool`` or a ``dict``.
        """
        selected: Dict[nn.Module, Dict[str, Any]] = {}
        for submodule in root_module.modules():
            if submodule in ignored_modules:
                continue
            verdict = self._lambda_fn(submodule)
            if not isinstance(verdict, (dict, bool)):
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {verdict}"
                )
            if verdict:
                # Start from a private copy of the root kwargs; a dict verdict
                # is layered on top as per-module overrides
                module_kwargs = copy.copy(root_kwargs)
                if isinstance(verdict, dict):
                    module_kwargs.update(verdict)
                selected[submodule] = module_kwargs
        return selected
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    Auto wrap policy that defers the wrapping decision to a user function.

    While recursing, this always answers ``True``. When asked whether to wrap
    ``module``, it returns ``lambda_fn(module)``.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``True``, the caller is asking whether to keep
            recursing down the module tree; if ``False``, the caller is
            asking whether ``module`` itself should be wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped. Not used by
            this policy.
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``,
            then ``module`` will be wrapped.

    Returns:
        ``True`` if ``recurse`` is set; otherwise the result of
        ``lambda_fn(module)``.
    """
    if not recurse:
        return lambda_fn(module)
    # always recurse
    return True
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Thin alias of :func:`_module_wrap_policy` in which
    ``transformer_layer_cls`` plays the role of ``module_classes``.

    Note that shared parameters must be wrapped in the same FSDP instance, so
    this auto wrap policy can help wrap shared embeddings into the same FSDP
    instance for transformer models.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
- def _wrap_module_cls_individually(
- module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
- ):
- if recurse:
- # always recurse
- return True
- else:
- # if not recursing, decide whether we should wrap based on whether the type of module
- # is in `module_classes`.
- return isinstance(module, tuple(module_classes))
- def _or_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- policies,
- ) -> bool:
- """
- A policy that wraps ``module`` if any policy in the passed in iterable of
- ``policies`` returns ``True``.
- """
- return any(
- policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
- for policy in policies
- )
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy based on the amount of not-yet-wrapped parameters.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``True``, the caller is asking whether to keep
            recursing down the module tree; if ``False``, the caller is
            asking whether ``module`` itself should be wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Size threshold, in units of numel, at or above
            which a module qualifies.
        force_leaf_modules (Set[Type[nn.Module]]): Module types to keep as
            leaves, i.e. the policy does not recurse into them. Defaults to
            ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types that are
            never wrapped. Defaults to
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.

    Returns:
        When recursing, whether to recurse into ``module``; otherwise whether
        ``module`` should be wrapped.
    """
    if force_leaf_modules is None:
        force_leaf_modules = size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
    if exclude_wrap_modules is None:
        exclude_wrap_modules = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]

    # The name `min_num_params` is kept for BC for now, but it is compared
    # against the non-wrapped *numel*
    meets_threshold = nonwrapped_numel >= min_num_params

    # Recursing: a large module is descended into unless it is a forced leaf.
    # Not recursing: a large module is wrapped unless its type is excluded.
    blocked_types = force_leaf_modules if recurse else exclude_wrap_modules
    return meets_threshold and not isinstance(module, tuple(blocked_types))


# Attach the defaults to the function object so that they are easy to import
# and inspect together with the policy itself.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager that activates :func:`wrap` with a given wrapper.

    Useful for when you'd like to apply the same configuration arguments to
    all child modules that you wrap. A particularly important use case is
    wrapping large layers so that they get sharded (in-place) during
    initialization, to avoid running out of system memory. Large layers can
    indicate that they should be sharded via the ``wrap`` annotation, and this
    context manager provides the exact configuration for these nested
    instances.

    Usage::

        with enable_wrap(wrapper_cls=wrapper_cls, **params):
            # Wraps layer with ``wrapper_cls`` since we are within the context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that the ``wrap`` annotation will wrap modules with, such as
            ``FullyShardedDataParallel``. Must be passed by keyword.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context.
    """
    # ``wrapper_cls`` travels together with the remaining settings;
    # ``_ConfigAutoWrap`` separates them again when the context is entered
    with _ConfigAutoWrap(wrapper_cls=wrapper_cls, **wrapper_kwargs):
        yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module as one that should be wrapped.

    The module is only actually wrapped when this is called inside an
    :func:`enable_wrap` context; outside of one, ``module`` is returned
    unchanged. This allows a module to be initialized both with and without a
    wrapper without code change.

    The wrapper applied is the ``wrapper_cls`` given to ``enable_wrap``. Both
    ``enable_wrap`` and ``wrap`` accept kwargs describing how to construct the
    ``wrapper_cls`` instance; when the same kwarg is given to both, the value
    passed to ``wrap`` wins.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that take priority over
            the values provided by the :func:`enable_wrap` context

    Returns:
        The wrapped module inside an :func:`enable_wrap` context, otherwise
        ``module`` itself.
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        # Annotation is a no-op outside of an ``enable_wrap`` context
        return module

    assert _ConfigAutoWrap.wrapper_cls is not None

    # Context-wide settings first, then per-call overrides on top
    merged_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **merged_kwargs)
- def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
- assert wrapper_cls is not None
- if hasattr(module, "_wrap_overrides"):
- # If module has a _wrap_overrides attribute, we force overriding the
- # FSDP config with these attributes for this module. Currently this
- # is only used to disable mixed precision for BatchNorm when
- # auto_wrapping.
- overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
- return wrapper_cls(module, **overrides)
- return wrapper_cls(module, **kwargs)
def _recursive_wrap(
    module: nn.Module,
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    """
    Wrap, bottom up, the submodules of ``module`` that ``auto_wrap_policy``
    selects, using ``wrapper_cls``.

    Args:
        module (nn.Module): Module to recursively wrap.
        auto_wrap_policy (Callable): A callable representing a policy that
            determines which modules to recursively wrap with ``wrapper_cls``.
            It is called with the keyword arguments ``module``, ``recurse``
            and ``nonwrapped_numel``.
        wrapper_cls (Callable): Callable used to wrap a selected module.
        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
            wrapping.
        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
            wrapping; these should be the parameters contained in the modules
            in ``ignored_modules``.
        only_wrap_children (bool): If ``True``, ``module`` itself is never
            wrapped by this call. The flag is not forwarded to the recursive
            calls on the children.
        **kwargs: Forwarded to the recursive calls and to ``wrapper_cls``.

    Returns:
        (nn.Module, int):
            ``module`` after wrapping and the numel recursively wrapped.
    """
    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
    assert wrapper_cls is not None, "Must specify wrapper_cls"

    # Make sure no child is already wrapped.
    for descendant in module.modules():
        if descendant in ignored_modules:
            continue
        try:
            assert not isinstance(descendant, cast(type, wrapper_cls))
        except TypeError:
            # wrapper_cls is a function as opposed to a class type, just
            # bypass above check.
            pass

    # We count all params, assuming none of them are already wrapped.
    nonwrapped_numel = sum(
        param.numel() for param in module.parameters() if param not in ignored_params
    )

    if not auto_wrap_policy(
        module=module, recurse=True, nonwrapped_numel=nonwrapped_numel
    ):
        # The policy declined to look at this subtree at all
        return module, 0

    # Recursively wrap the children first, swapping each one for its
    # (possibly wrapped) result and tallying how much numel got wrapped
    wrapped_numel = 0
    for child_name, child in module.named_children():
        if child in ignored_modules:
            continue
        new_child, child_wrapped_numel = _recursive_wrap(
            module=child,
            auto_wrap_policy=auto_wrap_policy,
            wrapper_cls=wrapper_cls,
            ignored_modules=ignored_modules,
            ignored_params=ignored_params,
            **kwargs,
        )
        setattr(module, child_name, new_child)
        wrapped_numel += child_wrapped_numel

    # Decide whether the current module itself gets wrapped, based on the
    # numel left over after the children took their share
    leftover_numel = nonwrapped_numel - wrapped_numel
    wrap_self = not only_wrap_children and auto_wrap_policy(
        module=module, recurse=False, nonwrapped_numel=leftover_numel
    )
    if wrap_self:
        # Leaf node or final wrapping of the remainder both happen here.
        return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
    return module, wrapped_numel
- class _ConfigAutoWrap:
- """
- Helper class to wrap modules based on default config args via a context manager.
- See :func:`enable_wrap` for more information.
- """
- in_autowrap_context: bool = False # Context flag
- wrapper_cls: Optional[Callable] = None # The wrapper class
- kwargs: Dict[str, Any] = {} # Wrapper's args
- def __init__(self, **kwargs: Dict[str, Any]):
- self.kwargs = kwargs
- @staticmethod
- def enable_autowrap_context(kwargs: Any) -> None:
- if _ConfigAutoWrap.in_autowrap_context:
- raise NotImplementedError(
- "You are already within an autowrap context and we currently do not supported nested autowrap."
- )
- _ConfigAutoWrap.in_autowrap_context = True
- # Get and save the wrapper cls for the context.
- assert (
- "wrapper_cls" in kwargs.keys()
- ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
- _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
- del kwargs["wrapper_cls"]
- # Save the rest.
- _ConfigAutoWrap.kwargs = kwargs
- @staticmethod
- def disable_autowrap_context() -> None:
- _ConfigAutoWrap.in_autowrap_context = False
- _ConfigAutoWrap.wrapper_cls = None
- _ConfigAutoWrap.kwargs = {}
- def __enter__(self) -> None:
- self.enable_autowrap_context(self.kwargs)
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
- self.disable_autowrap_context()
|