_common_utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # mypy: allow-untyped-defs
  2. """
  3. This file includes private common utilities for FSDP.
  4. """
  5. import logging
  6. import traceback
  7. import warnings
  8. import weakref
  9. from enum import auto, Enum
  10. from functools import partial
  11. from typing import (
  12. Any,
  13. Callable,
  14. cast,
  15. Dict,
  16. Generator,
  17. Iterable,
  18. List,
  19. no_type_check,
  20. Optional,
  21. Set,
  22. Tuple,
  23. Type,
  24. TYPE_CHECKING,
  25. )
  26. import torch
  27. import torch.distributed as dist
  28. import torch.distributed.fsdp._flat_param as flat_param_file
  29. import torch.nn as nn
  30. from torch.distributed._composable_state import _get_module_state, _State
  31. from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
  32. _CHECKPOINT_PREFIX,
  33. )
  34. from torch.distributed.utils import _apply_to_tensors
  35. from torch.utils._mode_utils import no_dispatch
  36. from .api import (
  37. FullOptimStateDictConfig,
  38. FullStateDictConfig,
  39. OptimStateDictConfig,
  40. ShardingStrategy,
  41. StateDictConfig,
  42. StateDictType,
  43. )
  44. if TYPE_CHECKING:
  45. from torch.distributed.device_mesh import DeviceMesh
  46. from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
  47. from ._flat_param import FlatParamHandle
  48. FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
  49. FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
  50. FSDP_FLATTENED = "_fsdp_flattened"
  51. # Save a global mapping from module to its input tensor dtype to be populated
  52. # during the forward pre-hook and consumed in the forward post-hook when
  53. # overriding a module's mixed precision
  54. # NOTE: We currently take the last input tensor's dtype in the case of multiple
  55. # floating-point input tensors, which may be incorrect. However, since there is
  56. # not a 1:1 correspondence between input and output tensors, we must use *some*
  57. # heuristic like this to predict the desired output dtype.
  58. _MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  59. class _FSDPDeviceHandle:
  60. """
  61. This is a simple abstraction for FSDP computing devices,
  62. which enables custom backends that implement CUDA-like
  63. semantics to be integrated with FSDP.
  64. """
  65. def __init__(self, device: torch.device, backend: Any = None):
  66. if backend is None:
  67. try:
  68. self.__backend = getattr(torch, device.type)
  69. self.__device = device
  70. except AttributeError as exc:
  71. raise AttributeError(
  72. f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'."
  73. ) from exc
  74. else:
  75. self.__backend = backend
  76. @classmethod
  77. def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle":
  78. """
  79. Return an device handle corresponding to the device, and through this handle,
  80. operations with the same semantics as CUDA can be performed on the device.
  81. Just return torch.cuda if the device is cuda to make attribute-access faster.
  82. Custom backend must first register a module with the same name with {device.type} on torch.
  83. """
  84. if device.type == "cuda":
  85. return cast(_FSDPDeviceHandle, torch.cuda)
  86. return cls(device)
  87. def __getattr__(self, __name: str) -> Any:
  88. try:
  89. return getattr(self.__backend, __name)
  90. except AttributeError as exc:
  91. raise AttributeError(
  92. f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'"
  93. ) from exc
  94. class _UninitializedDeviceHandle(_FSDPDeviceHandle):
  95. def __init__(self):
  96. pass
  97. def __getattribute__(self, __name: str) -> Any:
  98. raise RuntimeError("Trying to use an uninitialized device handle.")
  99. class _FSDPState(_State):
  100. def __init__(self) -> None:
  101. # TODO: Move all the attributes to this class to enable typing for
  102. # FSDP/fully_shard.
  103. self._ignored_modules: Set[nn.Module] = set()
  104. self._ignored_params: Set[nn.Parameter] = set()
  105. # Buffer names are cleaned (without wrapper prefixes)
  106. self._ignored_buffer_names: Set[str] = set()
  107. self.process_group: Optional[dist.ProcessGroup] = None
  108. self.rank: int = -1
  109. self.world_size: int = -1
  110. self._device_mesh: Optional[DeviceMesh] = None
  111. self.sharding_strategy = ShardingStrategy.FULL_SHARD
  112. self._use_orig_params: bool = False
  113. self.training_state = TrainingState.IDLE
  114. self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
  115. self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
  116. self._state_dict_config: StateDictConfig = FullStateDictConfig()
  117. self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
  118. self._is_root: Optional[bool] = None
  119. self._handle: Optional[flat_param_file.FlatParamHandle] = None
  120. self._fully_sharded_module_to_handle: Dict[
  121. nn.Module, Optional[flat_param_file.FlatParamHandle]
  122. ] = {}
  123. self.compute_device: Optional[torch.device] = None
  124. self._gradient_predivide_factor: int = 0
  125. self._gradient_postdivide_factor: int = 0
  126. self._comm_hook: Optional[Callable] = None
  127. self._comm_hook_state: Optional[Any] = None
  128. self._unshard_event: Optional[torch.cuda.Event] = None
  129. # Abstract device handle for fsdp compute device. For now,
  130. # the compute device must implement cuda semantics used by fsdp
  131. self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
  132. # All following attributes should only be used for root states:
  133. # Save these static lists to avoid the repeated tree traversals
  134. self._all_fsdp_states: List[_FSDPState] = []
  135. self._all_handles: List[flat_param_file.FlatParamHandle] = []
  136. self._fsdp_extension: Optional[FSDPExtensions] = None
  137. def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
  138. state = _get_module_state(module)
  139. if state is None or not isinstance(state, _FSDPState):
  140. return None
  141. return state
  142. def _get_module_fsdp_state_if_fully_sharded_module(
  143. module: nn.Module,
  144. ) -> Optional[_FSDPState]:
  145. state = _get_module_fsdp_state(module)
  146. if state is None:
  147. return None
  148. if state == module: # FullyShardedDataParallel module case.
  149. return state
  150. if module in state._fully_sharded_module_to_handle: # fully_shard case.
  151. return state
  152. return None
  153. class TrainingState(Enum):
  154. """
  155. An enum that indicates the state of a ``FullyShardedDataParallel` instance.
  156. """
  157. IDLE = auto()
  158. FORWARD_BACKWARD = auto()
  159. SUMMON_FULL_PARAMS = auto()
  160. class HandleTrainingState(Enum):
  161. """
  162. An enum that indicates the state of a ``FlatParamHandle`.
  163. """
  164. IDLE = auto()
  165. FORWARD = auto()
  166. BACKWARD_PRE = auto()
  167. BACKWARD_POST = auto()
  168. SUMMON_FULL_PARAMS = auto()
  169. def _is_composable(state: _FSDPState):
  170. # TODO: This is a temporary hack for differentiate between code paths.
  171. return not isinstance(state, nn.Module)
  172. @no_type_check
  173. def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]:
  174. """
  175. Returns the ``FlatParamHandle`` s corresponding to ``module``. This is
  176. the handle that contains some parameter in ``module``.
  177. """
  178. if _is_composable(state):
  179. # A valid FSDP state may have no managed parameters and hence no
  180. # handles, meaning no entry in `_fully_sharded_module_to_handles`
  181. if state._handle is None:
  182. return None
  183. assert (
  184. module in state._fully_sharded_module_to_handle
  185. ), f"Expects a fully sharded module but got {module} on rank {state.rank}"
  186. return state._fully_sharded_module_to_handle[module]
  187. else:
  188. # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
  189. return module._handle
  190. @no_type_check
  191. def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
  192. """Returns if ``module`` has parameters managed by FSDP."""
  193. return _module_handle(state, module) is not None
  194. def _get_sharding_strategy(handle):
  195. """
  196. Returns the sharding strategy of the handle.
  197. """
  198. return handle._sharding_strategy if handle else None
  199. def clean_tensor_name(tensor_name: str) -> str:
  200. """
  201. Cleans the parameter or buffer name by removing any module wrapper
  202. prefixes.
  203. """
  204. tensor_name = tensor_name.replace(FSDP_PREFIX, "")
  205. # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
  206. # it couples `CheckpointWrapper` and FSDP and also does not scale for more
  207. # module wrappers.
  208. tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
  209. return tensor_name
  210. def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
  211. """
  212. Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
  213. avoid re-flattening it during nested construction.
  214. """
  215. setattr(tensor, FSDP_FLATTENED, True)
  216. def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
  217. """Returns if ``tensor`` has been marked as flattened by FSDP."""
  218. return getattr(tensor, FSDP_FLATTENED, False)
  219. def _named_parameters_with_duplicates(
  220. module: nn.Module, **kwargs: Any
  221. ) -> List[Tuple[str, nn.Parameter]]:
  222. """
  223. This API is required as some modules overwrite `named_parameters()` but do not support
  224. `remove_duplicate`.
  225. """
  226. assert (
  227. "remove_duplicate" not in kwargs
  228. ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
  229. kwargs["remove_duplicate"] = False
  230. try:
  231. ret = list(module.named_parameters(**kwargs))
  232. except AssertionError as e:
  233. kwargs.pop("remove_duplicate")
  234. ret = list(module.named_parameters(**kwargs))
  235. return ret
  236. def _get_param_to_fqns(
  237. model: torch.nn.Module,
  238. dedup_shared_params: bool = True,
  239. ) -> Dict[nn.Parameter, List[str]]:
  240. """
  241. Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here,
  242. we use canonical to mean the fully-qualified name assigned to the parameter
  243. based on its position in the original nn.Module hierarchy before any wrapper
  244. or parallelism has been applied to it. This is in contrast to FQNs that may be
  245. generated after parallelisms or wrappers have been applied to the model.
  246. Each normal parameter maps to a singleton list containing its FQN, while each
  247. ``FlatParameter`` maps to a list of its original parameter FQNs, which may
  248. have length greater than one. All FQNs are prefixed starting from ``model``.
  249. In the case where FSDP was applied with ``use_orig_params=True``, there should be no
  250. ``FlatParameter`` s registered to the model's modules and this mapping will only
  251. contain mappings from ``nn.Parameter`` s to singleton FQN lists.
  252. It is only in the case where FSDP was applied with ``use_orig_params=False`` where
  253. a ``FlatParameter`` will be registered in place of the original parameters and there
  254. will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the
  255. original parameters.
  256. Args:
  257. model (torch.nn.Module): Root module (which may or may not be a
  258. :class:`FullyShardedDataParallel` instance).
  259. dedup_shared_params (bool): For shared parameters, if ``True``, only
  260. includes the FQNs corresponding to the first encounter of the
  261. shared parameter in the module traversal; if ``False``, then
  262. includes the FQNs across all encounters. (Default: ``True``)
  263. """
  264. def module_fn(module, prefix, tree_level, param_to_fqns):
  265. for param_name, param in _named_parameters_with_duplicates(
  266. module, recurse=False
  267. ):
  268. local_fqns = (
  269. param._fqns
  270. if isinstance(param, flat_param_file.FlatParameter)
  271. else [param_name]
  272. ) # prefixed from `module`
  273. global_fqns = [
  274. clean_tensor_name(prefix + name) for name in local_fqns
  275. ] # prefixed from the top level `model` (i.e. including `prefix`)
  276. is_shared_param = param in param_to_fqns
  277. if not is_shared_param:
  278. param_to_fqns[param] = global_fqns
  279. else:
  280. if isinstance(param, flat_param_file.FlatParameter):
  281. # DMP overwrites `named_parameters` and skip (advance to
  282. # the next child module) the wrapped_module (e.g.,
  283. # _dmp_wrapped_module and _fsdp_wrapped_module). When a user
  284. # calls `named_child` to traverse the module recursively and
  285. # calls `named_parameters` with `recurse=False`, parameters
  286. # will be traversed more than once.
  287. # This hack is specified designed for DMP + FSDP. We
  288. # overwrite the flat_parameters traversal result to only obtain
  289. # the last one, which happens to be the correct one.
  290. #
  291. # TODO: Remove this hack once DMP + FSDP is not supported.
  292. warnings.warn(
  293. "FlatParameter is being traversed more than once. "
  294. "This case should only happen when using "
  295. "DistributedModelParallel with FullyShardedDataParallel."
  296. )
  297. param_to_fqns[param] = global_fqns
  298. elif not dedup_shared_params:
  299. param_to_fqns[param].extend(global_fqns)
  300. def return_fn(param_to_fqns):
  301. return param_to_fqns
  302. param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
  303. return _apply_to_modules(
  304. model,
  305. module_fn,
  306. return_fn,
  307. [key for key, _ in _named_parameters_with_duplicates(model)],
  308. param_to_unflat_param_names,
  309. )
  310. @no_type_check
  311. def _log_post_backward_hook(
  312. state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger
  313. ) -> None:
  314. # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
  315. # Below logging of module names this post-bwd hook fires for can help debug certain
  316. # cases where hooks don't fire, such as under certain activation checkpoint configs.
  317. if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
  318. param_fqns = _get_handle_fqns_from_root(state, handle)
  319. logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)
  320. @no_type_check
  321. def _get_handle_fqns_from_root(
  322. state: _FSDPState, handle: "FlatParamHandle"
  323. ) -> Optional[List[str]]:
  324. if handle is None:
  325. return None
  326. param_to_fqn = state._exec_order_data.param_to_fqn
  327. handle_params = handle.flat_param._params # only populated for use_orig_params
  328. param_fqns = [
  329. fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list
  330. ]
  331. return param_fqns
  332. def _apply_to_modules(
  333. root_module: torch.nn.Module,
  334. module_fn: Callable,
  335. return_fn: Callable,
  336. filter_fqns: Optional[List[str]] = None,
  337. *args,
  338. **kwargs,
  339. ):
  340. """
  341. Performs a pre-order traversal of the modules in the hierarchy rooted at
  342. ``root_module``, applying ``module_fn`` at each module and finally
  343. returning a value using ``return_fn``. The traversal constructs the full
  344. module prefix name (e.g. "module.submodule." just like in model state dict)
  345. and makes that available to ``module_fn``.
  346. ``filter_fqns`` is used because some module may have its own prefix similar
  347. to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten
  348. to remove the prefix.
  349. """
  350. def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
  351. # Call the module function before recursing over children (pre-order)
  352. module_fn(module, prefix, tree_level, *args, **kwargs)
  353. for submodule_name, submodule in module.named_children():
  354. if submodule is None:
  355. continue
  356. new_prefix = prefix + submodule_name + "."
  357. new_tree_level = tree_level + 1
  358. if filter_fqns is not None:
  359. for fqn in filter_fqns:
  360. if fqn.startswith(new_prefix):
  361. break
  362. else:
  363. # DMP's named_parameter() will mess up the traversal with
  364. # ``named_children`` + `named_parameter(recurse=False)``.
  365. # This hack is a must to make the traversal work.
  366. # TODO: Remove this hack once DMP + FSDP is not supported.
  367. # It turns out that recursive wrapping may trigger this as
  368. # well.
  369. if (
  370. submodule_name == "_fsdp_wrapped_module"
  371. or submodule_name == "_dmp_wrapped_module"
  372. ):
  373. new_prefix = prefix
  374. elif submodule_name == "module":
  375. new_prefix = prefix
  376. f(submodule, new_prefix, new_tree_level, *args, **kwargs)
  377. f(root_module, "", 0, *args, **kwargs)
  378. return return_fn(*args, **kwargs)
  379. @no_type_check
  380. def _assert_in_training_states(
  381. state: _FSDPState,
  382. training_states: List[TrainingState],
  383. ) -> None:
  384. """Asserts that FSDP is in the states ``_training_states``."""
  385. # Raise a `ValueError` instead of using `assert` to ensure that these
  386. # logical assertions run even if `assert`s are disabled
  387. if state.training_state not in training_states:
  388. msg = (
  389. f"expected to be in states {training_states} but current state is "
  390. f"{state.training_state}"
  391. )
  392. # Print the error on rank 0 in case this is called in the backward pass
  393. if state.rank == 0:
  394. if isinstance(state, nn.Module):
  395. print(f"Asserting FSDP instance is: {state}")
  396. print(f"ERROR: {msg}")
  397. traceback.print_stack()
  398. raise ValueError(msg)
  399. def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
  400. """
  401. Returns:
  402. Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
  403. parent-less) with respect to the modules in the set itself. In other
  404. words, these are the modules in ``modules`` that are not the child of
  405. any other module in ``modules``.
  406. """
  407. root_modules: Set[nn.Module] = set()
  408. module_to_submodules = {module: set(module.modules()) for module in modules}
  409. for candidate_module in modules:
  410. is_root_module = True
  411. for module, submodules in module_to_submodules.items():
  412. is_child_module = (
  413. candidate_module is not module and candidate_module in submodules
  414. )
  415. if is_child_module:
  416. is_root_module = False
  417. break
  418. if is_root_module:
  419. root_modules.add(candidate_module)
  420. return root_modules
  421. def _override_module_mixed_precision(
  422. root: torch.nn.Module,
  423. module_classes_to_override: Iterable[Type[nn.Module]],
  424. wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006
  425. ) -> Set[Type[nn.Module]]:
  426. module_classes_to_override = tuple(set(module_classes_to_override))
  427. # Return a set of the actually overridden module classes
  428. overridden_module_classes: Set[Type[nn.Module]] = set()
  429. for mod in root.modules():
  430. if isinstance(mod, module_classes_to_override):
  431. overridden_module_classes.add(type(mod))
  432. mod._wrap_overrides = wrap_override_dict # type: ignore[assignment]
  433. # TODO: We need to run this mixed precision ignored module in fp32,
  434. # but ensure subsequent modules, that may possibly be running with
  435. # mixed precision, still receive the appropriate precision inputs
  436. # without user having to adjust mixed precision config too much.
  437. # As a result, we attach pre and post forward hooks to up / down
  438. # cast. We should revisit this design.
  439. def cast_fn(
  440. dtype: torch.dtype, module: nn.Module, x: torch.Tensor
  441. ) -> torch.Tensor:
  442. if not torch.is_floating_point(x) or x.dtype == dtype:
  443. return x
  444. _MODULE_TO_INP_DTYPE[module] = x.dtype
  445. return x.to(dtype)
  446. def forward_pre_hook(module, args):
  447. return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)
  448. def forward_post_hook(module, args, output):
  449. # NOTE: If the forward did not have any floating-point tensors,
  450. # then the dtype will not be set for this module, and we do not
  451. # upcast the dtype.
  452. if module in _MODULE_TO_INP_DTYPE:
  453. old_dtype = _MODULE_TO_INP_DTYPE[module]
  454. return _apply_to_tensors(
  455. partial(cast_fn, old_dtype, module), output
  456. )
  457. # We intentionally append both of these hooks so that they run after
  458. # all other hooks.
  459. mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
  460. mod.register_forward_hook(forward_post_hook, prepend=False)
  461. return overridden_module_classes
  462. def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
  463. # FIXME record_stream doesn't work with non-cuda tensors
  464. if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
  465. return
  466. if torch.distributed._functional_collectives.is_torchdynamo_compiling():
  467. return
  468. # from @ezyang:
  469. # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin
  470. # Looking over the PR, it looks like this is because we don't actually support Stream arguments
  471. # in torch dispatch, so it just chokes.
  472. # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False),
  473. # a better version of this would just be to check if there are any modes before disabling dispatch.
  474. # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here.
  475. tensor.record_stream(stream)
  476. else:
  477. with no_dispatch():
  478. tensor.record_stream(stream)