_wrap_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import functools
  4. import inspect
  5. import warnings
  6. from functools import partial
  7. from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
  8. import torch.nn as nn
  9. from torch.distributed.fsdp._common_utils import (
  10. _get_module_fsdp_state,
  11. _override_module_mixed_precision,
  12. )
  13. from torch.distributed.fsdp.wrap import (
  14. _construct_wrap_fn,
  15. _or_policy,
  16. _Policy,
  17. _post_order_apply,
  18. _recursive_wrap,
  19. _run_mixed_precision_override_policy,
  20. _wrap_module_cls_individually,
  21. )
  22. def _auto_wrap(
  23. root_module: nn.Module,
  24. policy: Union[Callable, _Policy],
  25. ignored_modules: Set[nn.Module],
  26. ignored_params: Set[nn.Parameter],
  27. root_kwargs: Dict[str, Any],
  28. fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
  29. ):
  30. """
  31. Auto wraps modules in ``root_module`` 's tree according to ``policy``
  32. following a post-order traversal.
  33. Precondition: ``root_kwargs`` should contain all arguments except
  34. ``module``. This function accepts the kwargs dict directly since it gets
  35. forwarded into the post-order traversal function.
  36. """
  37. mixed_precision = root_kwargs["mixed_precision"]
  38. is_wrapper = inspect.isclass(fsdp_fn)
  39. # TODO: We may relax this no-nested-wrapping constraint to support manual
  40. # wrapping followed by auto wrapping.
  41. _check_nested_wrapping(root_module)
  42. if isinstance(policy, _Policy):
  43. root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
  44. target_module_to_kwargs = policy._run_policy(
  45. root_module, ignored_modules, root_kwargs
  46. )
  47. if mixed_precision is not None:
  48. target_module_to_kwargs = _run_mixed_precision_override_policy(
  49. root_module,
  50. mixed_precision._module_classes_to_ignore,
  51. ignored_modules,
  52. root_kwargs,
  53. target_module_to_kwargs,
  54. )
  55. overridden_module_classes = _override_module_mixed_precision(
  56. root_module, mixed_precision._module_classes_to_ignore
  57. )
  58. _warn_on_overridden_mixed_precision(overridden_module_classes)
  59. use_orig_params = root_kwargs.get("use_orig_params", False)
  60. _validate_frozen_params(
  61. root_module,
  62. set(target_module_to_kwargs.keys()),
  63. ignored_params,
  64. use_orig_params,
  65. )
  66. wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
  67. _post_order_apply(root_module, wrap_fn)
  68. return
  69. recursive_wrap_kwargs = {
  70. "module": root_module,
  71. "auto_wrap_policy": policy,
  72. "wrapper_cls": fsdp_fn,
  73. "ignored_modules": ignored_modules,
  74. "ignored_params": ignored_params,
  75. "only_wrap_children": True,
  76. }
  77. if mixed_precision is not None:
  78. # Wrap modules of the ignored types separately and register forward
  79. # hooks to cast to fp32 and back to the original dtype, respectively
  80. overridden_module_classes = _override_module_mixed_precision(
  81. root_module, mixed_precision._module_classes_to_ignore
  82. )
  83. policy = functools.partial(
  84. _or_policy,
  85. policies=[
  86. policy,
  87. partial(
  88. _wrap_module_cls_individually,
  89. module_classes=mixed_precision._module_classes_to_ignore,
  90. ),
  91. ],
  92. )
  93. recursive_wrap_kwargs["auto_wrap_policy"] = policy
  94. _warn_on_overridden_mixed_precision(overridden_module_classes)
  95. _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
  96. def _check_nested_wrapping(root_module: nn.Module):
  97. for module_name, module in root_module.named_modules():
  98. if _get_module_fsdp_state(module) is not None:
  99. raise ValueError(
  100. "FSDP auto wrapping requires modules to not already have "
  101. f"FSDP applied but found {module_name} in\n{root_module}"
  102. )
  103. def _warn_on_overridden_mixed_precision(
  104. overridden_module_classes: Set[Type[nn.Module]],
  105. ):
  106. if len(overridden_module_classes) == 0:
  107. return
  108. warnings.warn(
  109. "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
  110. f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
  111. "These modules will be wrapped as separate FSDP instacnes with mixed "
  112. "precision disabled."
  113. )
  114. def _validate_frozen_params(
  115. root_module: nn.Module,
  116. modules_to_wrap: Set[nn.Module],
  117. ignored_params: Set[nn.Parameter],
  118. use_orig_params: bool,
  119. ):
  120. """
  121. This checks that, given ``modules_to_wrap``, each module would manage
  122. parameters that are uniformly frozen or non-frozen. This uniformity
  123. requirement is strict for ``use_orig_params=False`` (hard error) and highly
  124. recommended for ``use_orig_params=True`` (user warning).
  125. """
  126. post_order_named_modules = _get_post_order_named_modules(root_module)
  127. visited_modules: Set[nn.Module] = set()
  128. for module_name, module in post_order_named_modules:
  129. if module in modules_to_wrap:
  130. param_to_fqn = _get_managed_param_to_fqn(
  131. module, ignored_params, visited_modules, module_name
  132. )
  133. frozen_param_fqns: List[str] = []
  134. frozen_param_numel = 0
  135. nonfrozen_param_fqns: List[str] = []
  136. nonfrozen_param_numel = 0
  137. for param, fqn in param_to_fqn.items():
  138. if param.requires_grad:
  139. nonfrozen_param_fqns.append(fqn)
  140. nonfrozen_param_numel += param.numel()
  141. else:
  142. frozen_param_fqns.append(fqn)
  143. frozen_param_numel += param.numel()
  144. if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
  145. msg = f"{module_name} has both parameters with requires_grad=True and False."
  146. if use_orig_params:
  147. total_param_numel = frozen_param_numel + nonfrozen_param_numel
  148. msg += (
  149. " We do not recommend wrapping such modules since "
  150. "the gradient memory usage will be higher than expected "
  151. f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
  152. "before sharding via reduce-scatter). "
  153. )
  154. else:
  155. msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
  156. msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
  157. msg += (
  158. f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
  159. f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
  160. )
  161. if use_orig_params:
  162. warnings.warn(msg)
  163. else:
  164. raise ValueError(msg)
  165. def _get_post_order_named_modules(
  166. root_module: nn.Module,
  167. ) -> List[Tuple[str, nn.Module]]:
  168. """
  169. This returns the named modules following a post-order traversal, which is a
  170. valid reverse topological sort. We achieve this using the reverse of a
  171. stack-based DFS order instead of reversing ``root_module.named_modules()``
  172. since the former gives the modules in registration order at each level in
  173. the module tree (as opposed to the reverse), which allows us to error/warn
  174. on the first registered module that violates the condition.
  175. For example, consider the following module structure:
  176. M(
  177. S1(),
  178. S2(
  179. SS1(),
  180. SS2(),
  181. ),
  182. S3(),
  183. )
  184. The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
  185. ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
  186. """
  187. visited_modules = {root_module}
  188. stack = [("", root_module)]
  189. # Append and reverse at the end for linear-time algorithm
  190. reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
  191. while stack:
  192. module_name, module = stack.pop()
  193. reverse_post_order_named_modules.append((module_name, module))
  194. for child_module_name, child_module in module.named_children():
  195. if child_module is None: # only for overrides of `named_children()`
  196. continue
  197. if child_module not in visited_modules:
  198. visited_modules.add(child_module)
  199. if module_name != "":
  200. child_module_name = module_name + "." + child_module_name
  201. stack.append((child_module_name, child_module))
  202. post_order_named_modules = list(reversed(reverse_post_order_named_modules))
  203. return post_order_named_modules
  204. def _get_managed_param_to_fqn(
  205. module_to_wrap: nn.Module,
  206. ignored_params: Set[nn.Parameter],
  207. visited_modules: Set[nn.Module],
  208. root_prefix: str,
  209. ) -> Dict[nn.Parameter, str]:
  210. """
  211. This returns a dict that maps managed parameter to its FQN for the given
  212. ``module_to_wrap``. The dict's keys are exactly the parameters that would
  213. be managed by the module, where this is achieved by calling this function
  214. on the modules to wrap in reverse topological order, destructively updating
  215. ``visited_modules``, and not traversing into those modules. The FQNs are
  216. prefixed from the root (via ``root_prefix``) to be more informative.
  217. NOTE: This function is meant to be called pre-wrapping and iteratively in
  218. reverse topological order to cover the full module tree. This differs from
  219. the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
  220. on the full module tree in one shot. Given those differences, we do not try
  221. to unify the two.
  222. """
  223. param_to_fqn: Dict[nn.Parameter, str] = {}
  224. # Run BFS (or any tree traversal works)
  225. queue = collections.deque([(module_to_wrap, root_prefix)])
  226. visited_modules.add(module_to_wrap)
  227. while queue:
  228. module, prefix = queue.popleft()
  229. for param_name, param in module.named_parameters(recurse=False):
  230. if param not in ignored_params:
  231. fqn = param_name if prefix == "" else prefix + "." + param_name
  232. param_to_fqn[param] = fqn
  233. for child_module_name, child_module in module.named_children():
  234. if child_module is None: # only for overrides of `named_children()`
  235. continue
  236. if child_module not in visited_modules:
  237. visited_modules.add(child_module)
  238. child_prefix = (
  239. child_module_name
  240. if prefix == ""
  241. else prefix + "." + child_module_name
  242. )
  243. queue.append((child_module, child_prefix))
  244. return param_to_fqn