_runtime_utils.py

  1. # mypy: allow-untyped-defs
  2. import functools
  3. import logging
  4. from enum import auto, Enum
  5. from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
  6. import torch
  7. import torch.distributed as dist
  8. import torch.distributed.fsdp._traversal_utils as traversal_utils
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from torch.autograd import Variable
  12. from torch.autograd.graph import register_multi_grad_hook
  13. from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
  14. from torch.distributed.fsdp._common_utils import (
  15. _assert_in_training_states,
  16. _FSDPState,
  17. _get_module_fsdp_state,
  18. _is_composable,
  19. _log_post_backward_hook,
  20. _no_dispatch_record_stream,
  21. clean_tensor_name,
  22. TrainingState,
  23. )
  24. from torch.distributed.fsdp._flat_param import (
  25. FlatParameter,
  26. FlatParamHandle,
  27. HandleShardingStrategy,
  28. HandleTrainingState,
  29. RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
  30. )
  31. from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
  32. from torch.distributed.fsdp.api import BackwardPrefetch
  33. from torch.distributed.utils import (
  34. _apply_to_tensors,
  35. _cast_forward_inputs,
  36. _p_assert,
  37. _to_kwargs,
  38. )
  39. from torch.utils import _pytree as pytree
  40. logger = logging.getLogger(__name__)
  41. # Do not include "process_group" to enable hybrid shard and MoE cases
  42. HOMOGENEOUS_ATTR_NAMES = (
  43. "_use_orig_params",
  44. "limit_all_gathers",
  45. "_use_full_prec_in_eval",
  46. )
  47. class _PrefetchMode(Enum):
  48. BACKWARD = auto()
  49. FORWARD = auto()
  50. def _get_fsdp_root_states_with_modules(
  51. module: nn.Module,
  52. ) -> Tuple[List[_FSDPState], List[nn.Module]]:
  53. """
  54. Returns a tuple containing:
  55. 1. A list of the root ``_FSDPState`` instances in the module tree rooted at
  56. ``module`` without any duplicates and following the ``module.modules()``
  57. traversal order (which is assumed to be depth-first).
  58. 2. A corresponding list of the root modules owning the states in the first
  59. list.
  60. This is similar to :func:`_get_fsdp_states_with_modules` except that we
  61. must call :func:`_is_fsdp_root` to force a lazy initialization to determine
  62. the FSDP root in case lazy initialization has not yet happened.
  63. """
  64. fsdp_root_states: List[_FSDPState] = []
  65. fsdp_root_modules: List[nn.Module] = []
  66. visited_fsdp_states: Set[_FSDPState] = set()
  67. # NOTE: This function assumes that `module.modules()` proceeds top-down.
  68. for submodule in module.modules():
  69. optional_state = _get_module_fsdp_state(submodule)
  70. if (
  71. optional_state is not None
  72. and optional_state not in visited_fsdp_states
  73. and _is_fsdp_root(optional_state, submodule)
  74. ):
  75. visited_fsdp_states.add(optional_state)
  76. fsdp_root_states.append(optional_state)
  77. fsdp_root_modules.append(submodule)
  78. return fsdp_root_states, fsdp_root_modules
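# --- Illustrative sketch (not part of the original file): a standalone check of
# the "top-down" traversal assumption noted above. `nn.Module.modules()` yields
# the module itself first and then descends depth-first, so a parent FSDP state
# is always visited before any state nested under it. The helper name is
# hypothetical.
def _example_modules_traversal_order() -> None:
    outer = nn.Sequential(nn.Sequential(nn.Linear(2, 2)), nn.Linear(2, 2))
    names = [type(m).__name__ for m in outer.modules()]
    # Parent containers appear before their children
    assert names == ["Sequential", "Sequential", "Linear", "Linear"]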
  79. def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
  80. """See :func:`_get_fsdp_root_states_with_modules`."""
  81. fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
  82. return fsdp_root_states
  83. def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
  84. """
  85. Returns if ``state`` corresponds to that of an FSDP root.
  86. For the wrapper code path, ``state`` and ``module`` should be the same. For
  87. the non-wrapper code path, ``state`` should be ``module`` 's state.
  88. """
  89. # Force a lazy initialization to determine the FSDP root
  90. _lazy_init(state, module)
  91. assert state._is_root is not None # mypy
  92. return state._is_root
  93. @no_type_check
  94. def _lazy_init(
  95. state: _FSDPState,
  96. root_module: nn.Module,
  97. ) -> _FSDPState:
  98. """
  99. Performs initialization lazily, typically right before the first forward
  100. pass. The laziness is needed to ensure that the parameter device/dtype and
  101. the FSDP hierarchy have finalized. This method's actual logic only runs on
  102. the root FSDP instance, which performs initialization for all non-root FSDP
  103. instances to avoid partial initialization.
  104. For the non-composable code path, ``state`` and ``root_module`` should be
  105. the same, namely the FSDP instance itself.
  106. """
  107. if state._is_root is not None:
  108. return # no-op: already lazily initialized
  109. if not state._device_handle.is_available():
  110. # Allow the FSDP constructor to run even without CUDA but check this
  111. # once we start real execution
  112. raise RuntimeError("FSDP does not support CPU only execution")
  113. # The following logic is only run on the root FSDP instance since it will
  114. # set `_is_root=False` for the non-root instances
  115. state._is_root = True
  116. _assert_in_training_states(state, [TrainingState.IDLE])
  117. _check_flat_params_on_expected_device(state, root_module)
  118. state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
  119. _init_streams(state)
  120. buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
  121. _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
  122. state._exec_order_data.init(state, root_module, state.process_group)
  123. _share_state_and_init_handle_attrs(state, root_module)
  124. return state
  125. def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
  126. """
  127. Checks that all ``FlatParameter``s in ``module`` 's tree managed by
  128. ``state`` are on the expected device for *lazy initialization*.
  129. """
  130. cpu_device = torch.device("cpu")
  131. for handle in traversal_utils._get_fsdp_handles(module):
  132. if (
  133. not handle._offload_params
  134. and handle.flat_param.device != state.compute_device
  135. ):
  136. raise RuntimeError(
  137. "An FSDP-managed module unexpectedly has parameters on "
  138. f"{handle.flat_param.device}. Make sure to move the module to "
  139. f"{state.compute_device} before training."
  140. )
  141. elif handle._offload_params and handle.flat_param.device != cpu_device:
  142. raise RuntimeError(
  143. "An FSDP-managed module with parameter CPU offloading enabled "
  144. f"has parameters on {handle.flat_param.device}. Make sure to "
  145. f"not move the module from CPU when offloading parameters."
  146. )
  147. @no_type_check
  148. def _share_state_and_init_handle_attrs(
  149. root_state: _FSDPState,
  150. root_module: nn.Module,
  151. ) -> None:
  152. """
  153. Shares data structure state from the ``root_state`` to all FSDP states in
  154. ``root_module`` 's module tree, and initializes handle attributes. These
  155. are done together to require a single loop over the states.
  156. """
  157. handle = root_state._handle
  158. if handle:
  159. handle.init_flat_param_attributes()
  160. attr_name_to_values: Dict[str, Set[Any]] = {}
  161. for attr_name in HOMOGENEOUS_ATTR_NAMES:
  162. attr_name_to_values[attr_name] = set()
  163. root_state._all_handles = root_state._exec_order_data.all_handles # share reference
  164. # Update _has_optim_in_backward for each handle.
  165. for handle in root_state._all_handles:
  166. flat_param = handle.flat_param
  167. if hasattr(flat_param, "_in_backward_optimizers"):
  168. raise RuntimeError(
  169. "FSDP optimizer in backward only supported with use_orig_params=True!"
  170. )
  171. handle._has_optim_in_backward = flat_param._params is not None and any(
  172. hasattr(param, "_in_backward_optimizers") for param in flat_param._params
  173. )
  174. if handle._has_optim_in_backward:
  175. torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
  176. for fsdp_state in root_state._all_fsdp_states:
  177. for attr_name in HOMOGENEOUS_ATTR_NAMES:
  178. _p_assert(
  179. hasattr(fsdp_state, attr_name),
  180. f"FSDP state missing attribute {attr_name}",
  181. )
  182. attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
  183. if fsdp_state is root_state:
  184. continue
  185. # Relax the assert for non-root FSDP instances in case the nested
  186. # initialized module is wrapped again in FSDP later (e.g. after
  187. # training to run inference)
  188. _p_assert(
  189. fsdp_state._is_root is None or not fsdp_state._is_root,
  190. "Non-root FSDP instance's `_is_root` should not have been "
  191. "set yet or should have been set to `False`",
  192. )
  193. fsdp_state._is_root = False
  194. fsdp_state._unshard_stream = root_state._unshard_stream
  195. fsdp_state._post_backward_stream = root_state._post_backward_stream
  196. fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
  197. fsdp_state._all_reduce_stream = root_state._all_reduce_stream
  198. fsdp_state._default_stream = root_state._default_stream
  199. fsdp_state._exec_order_data = root_state._exec_order_data
  200. fsdp_state._free_event_queue = root_state._free_event_queue
  201. if fsdp_state._fsdp_extension is not None:
  202. fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
  203. handle = fsdp_state._handle
  204. if handle:
  205. handle.init_flat_param_attributes()
  206. for attr_name, attr_values in attr_name_to_values.items():
  207. if len(attr_values) != 1:
  208. raise ValueError(
  209. f"Expects one homogeneous value for {attr_name} but got {attr_values}"
  210. )
  211. @no_type_check
  212. def _init_streams(
  213. state: _FSDPState,
  214. ) -> None:
  215. """
  216. Initializes CUDA streams for overlapping communication, computation, and
  217. data transfers. The streams should be shared across FSDP instances.
  218. """
  219. assert state._is_root
  220. assert state._device_handle.is_available()
  221. uses_hybrid_sharding = any(
  222. fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
  223. for fsdp_state in state._all_fsdp_states
  224. )
  225. # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
  226. # preserve the default priority of 0 otherwise
  227. high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0
  228. # Default stream for computation
  229. state._default_stream = state._device_handle.current_stream()
  230. if state._fsdp_extension is not None:
  231. # set the compute stream to the FSDP extension
  232. state._fsdp_extension.compute_stream = state._default_stream
  233. # Stream for unshard logic, including allocating the all-gather destination
  234. # tensors and the all-gathers themselves
  235. state._unshard_stream = state._device_handle.Stream(priority=high_priority)
  236. # Stream for overlapping gradient reduction with the backward pass gradient
  237. # computation
  238. state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
  239. # Stream for pre-unshard logic, namely allocations and writes for CPU
  240. # offloading (H2D copy) and mixed precision (low precision cast)
  241. state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority)
  242. # Stream to run HSDP's all-reduce as async (if using HSDP)
  243. state._all_reduce_stream = (
  244. state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream
  245. )
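# --- Illustrative sketch (hypothetical helper, assumes a CUDA device): how a
# high-priority side stream plus `wait_stream()` lets communication-style work
# overlap computation, mirroring the streams created in `_init_streams` above.
# FSDP goes through `state._device_handle`, which resolves to `torch.cuda` on
# CUDA devices; this sketch uses `torch.cuda` directly.
def _example_stream_overlap() -> None:
    if not torch.cuda.is_available():
        return
    compute_stream = torch.cuda.current_stream()
    comm_stream = torch.cuda.Stream(priority=-1)  # -1 == high priority, as above
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # runs on the default (compute) stream
    comm_stream.wait_stream(compute_stream)  # order the side stream after the producer
    with torch.cuda.stream(comm_stream):
        y_host = y.to("cpu", non_blocking=True)  # stand-in for a collective
        y.record_stream(comm_stream)  # inform the caching allocator of the cross-stream use
    compute_stream.wait_stream(comm_stream)  # wait before consuming results on the default stream
    del y_host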
  246. @no_type_check
  247. def _unshard(
  248. state: _FSDPState,
  249. handle: FlatParamHandle,
  250. unshard_stream: torch.Stream,
  251. pre_unshard_stream: torch.Stream,
  252. ) -> None:
  253. """
  254. Unshards ``handle`` 's ``FlatParameter``. If the handle is in
  255. :meth:`summon_full_params` and is using mixed precision, then it is
  256. forced to full precision.
  257. Postcondition: handle's ``FlatParameter`` 's data is the padded
  258. unsharded flat parameter on the compute device.
  259. """
  260. if not handle:
  261. return
  262. with state._device_handle.stream(pre_unshard_stream):
  263. ran_pre_unshard = handle.pre_unshard()
  264. if ran_pre_unshard:
  265. unshard_stream.wait_stream(pre_unshard_stream)
  266. if state.limit_all_gathers:
  267. event = state._free_event_queue.dequeue_if_needed()
  268. if event:
  269. with torch.profiler.record_function(
  270. "FullyShardedDataParallel.rate_limiter"
  271. ):
  272. event.synchronize()
  273. with state._device_handle.stream(unshard_stream):
  274. handle.unshard()
  275. handle.post_unshard()
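# --- Illustrative sketch (hypothetical names): the `limit_all_gathers` rate
# limiter consulted in `_unshard` above. A bounded queue of device events caps
# how many freed unsharded flat parameters can be in flight, so the CPU thread
# cannot issue all-gathers arbitrarily far ahead of the GPU. The real queue is
# `state._free_event_queue`; this standalone version only shows the mechanism.
from collections import deque


class _ExampleFreeEventQueue:
    def __init__(self, max_in_flight: int = 2) -> None:
        self._queue: deque = deque()
        self._max_in_flight = max_in_flight

    def enqueue(self, free_event: torch.cuda.Event) -> None:
        # Called after freeing an unsharded flat parameter (cf. `_reshard` below)
        self._queue.append(free_event)

    def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
        # Called before the next all-gather (cf. `_unshard` above); the caller
        # synchronizes on the returned event to throttle itself
        if len(self._queue) >= self._max_in_flight:
            return self._queue.popleft()
        return None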
  276. @no_type_check
  277. def _reshard(
  278. state: _FSDPState,
  279. handle: FlatParamHandle,
  280. free_unsharded_flat_param: bool,
  281. ):
  282. """
  283. Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
  284. free the handle's padded unsharded flat parameter.
  285. """
  286. handle.reshard(free_unsharded_flat_param)
  287. if state.limit_all_gathers and free_unsharded_flat_param:
  288. if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
  289. # We don't run an event queue for freeing under torch compile atm
  290. # But maybe we need to? TODO(voz): Look into this
  291. free_event = state._device_handle.Event()
  292. free_event.record()
  293. state._free_event_queue.enqueue(free_event)
  294. handle.post_reshard()
  295. # Flat parameter freed or not, we always have to "unshard" the parameter
  296. # upon next access to get its shape correct.
  297. handle._prefetched = False
  298. def _unshard_grads(
  299. handle: Optional[FlatParamHandle],
  300. ) -> None:
  301. if handle:
  302. handle.unshard_grad()
  303. def _reshard_grads(
  304. handle: Optional[FlatParamHandle],
  305. ) -> None:
  306. if handle:
  307. handle.reshard_grad()
  308. @no_type_check
  309. def _pre_forward(
  310. state: _FSDPState,
  311. handle: Optional[FlatParamHandle],
  312. unshard_fn: Callable,
  313. module: nn.Module,
  314. args: Tuple[Any, ...],
  315. kwargs: Dict[str, Any],
  316. ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
  317. """
  318. Runs the pre-forward logic. This includes an opportunity to unshard
  319. currently sharded parameters such as those for the current forward and
  320. registering post-backward hooks for these current parameters. This function
  321. also converts forward ``args`` and ``kwargs`` to the given precision.
  322. Args:
  323. handle (Optional[FlatParamHandle]): Handle giving the parameters used in
  324. the current forward.
  325. unshard_fn (Optional[Callable]): A callable to unshard any currently
  326. sharded parameters or ``None`` to not do any unsharding.
  327. module (nn.Module): Module whose forward this method runs right before;
  328. expected by the hook signature.
  329. args (Tuple[Any, ...]): Module forward ``args``.
  330. kwargs (Dict[str, Any]): Module forward ``kwargs``.
  331. """
  332. with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
  333. # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
  334. # recomputed forward
  335. if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
  336. # For both checkpoint implementations, we do not need to re-cast
  337. # inputs here since they will be checkpointed in the low precision
  338. # either by AC or normally by autograd as long as the AC region is
  339. # nested within FSDP
  340. return args, kwargs
  341. state.training_state = TrainingState.FORWARD_BACKWARD
  342. state._exec_order_data.record_pre_forward(handle, module.training)
  343. if handle:
  344. handle._training_state = HandleTrainingState.FORWARD
  345. if unshard_fn is not None:
  346. unshard_fn(state, handle)
  347. # Register post-backward hooks to reshard the parameters and reduce-scatter
  348. # their gradients. They must be re-registered every forward pass in case
  349. # the `grad_fn` is mutated.
  350. _register_post_backward_hook(state, handle)
  351. # We have to reallocate the _cpu_grad if optimizer overlap
  352. # set the grad to None in the backward pass.
  353. if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
  354. handle.flat_param._cpu_grad = torch.zeros_like(
  355. handle.flat_param._local_shard, device=torch.device("cpu")
  356. ).pin_memory(device=state.compute_device)
  357. should_cast_forward_inputs = (
  358. state._handle and not state._handle._force_full_precision
  359. )
  360. if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
  361. # Recursively convert args and kwargs to specified precision.
  362. input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
  363. args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
  364. _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
  365. return args, kwargs
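# --- Illustrative sketch (hypothetical helper): what "casting forward inputs"
# amounts to. `_cast_forward_inputs` from `torch.distributed.utils` casts the
# floating-point tensors in `args`/`kwargs` to the mixed-precision param dtype;
# a minimal pytree-based version looks like this.
def _example_cast_forward_inputs(dtype: torch.dtype, *args: Any, **kwargs: Any):
    def cast_fn(x: Any) -> Any:
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    return pytree.tree_map(cast_fn, args), pytree.tree_map(cast_fn, kwargs)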
  366. @no_type_check
  367. def _pre_forward_unshard(
  368. state: _FSDPState,
  369. handle: Optional[FlatParamHandle],
  370. ) -> None:
  371. """Unshards parameters in the pre-forward."""
  372. if not handle:
  373. return
  374. # If the handles have been prefetched, then there is no need to call
  375. # `_unshard()` again
  376. if not handle._prefetched:
  377. _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  378. handle._needs_pre_forward_unshard = False
  379. # Don't wait during trace
  380. if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
  381. current_stream = state._device_handle.current_stream()
  382. if state._unshard_event is not None:
  383. current_stream.wait_event(state._unshard_event)
  384. state._unshard_event = None
  385. else:
  386. current_stream.wait_stream(state._unshard_stream)
  387. with torch.profiler.record_function(
  388. "FullyShardedDataParallel._pre_forward_prefetch"
  389. ):
  390. _prefetch_handle(state, handle, _PrefetchMode.FORWARD)
  391. @no_type_check
  392. def _post_forward(
  393. state: _FSDPState,
  394. handle: Optional[FlatParamHandle],
  395. reshard_fn: Callable,
  396. module: nn.Module,
  397. input: Any,
  398. output: Any,
  399. ) -> Any:
  400. """
  401. Runs the post-forward logic. This includes an opportunity to reshard
  402. currently unsharded parameters such as those used in the current forward
  403. and registering pre-backward hooks on the forward outputs.
  404. Args:
  405. handle (Optional[FlatParamHandle]): Handle giving the parameters used in
  406. the current forward.
  407. reshard_fn (Optional[Callable]): A callable to reshard any currently
  408. unsharded parameters (e.g. from the current forward) or ``None`` to
  409. not do any resharding.
  410. module (nn.Module): Module whose forward just ran, which should be a
  411. fully sharded module (see [Note: Fully Sharded Module]); expected
  412. by the hook signature.
  413. input (Any): Unused; expected by the hook signature.
  414. output (Any): Forward pass output; pre-backward hooks are registered on
  415. the tensors that require gradients in this output.
  416. Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
  417. parameter.
  418. """
  419. with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
  420. # For `fully_shard` + `checkpoint`, skip post-forward logic in the
  421. # recomputed forward
  422. if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
  423. return output
  424. state._exec_order_data.record_post_forward(handle)
  425. if reshard_fn is not None:
  426. reshard_fn(state, handle)
  427. # Register pre-backward hooks to unshard the flat parameters for the
  428. # gradient computation (if needed)
  429. output = _register_pre_backward_hooks(state, module, output, handle)
  430. state.training_state = TrainingState.IDLE
  431. if handle:
  432. handle._training_state = HandleTrainingState.IDLE
  433. return output
  434. @no_type_check
  435. def _post_forward_reshard(
  436. state: _FSDPState,
  437. handle: FlatParamHandle,
  438. ) -> None:
  439. """Reshards parameters in the post-forward."""
  440. if not handle:
  441. return
  442. # Do not free the root's parameters in the post-forward for `FULL_SHARD`
  443. # with the intention that they are immediately used for backward
  444. # computation (though this may not be true)
  445. free_unsharded_flat_param = (
  446. not state._is_root
  447. and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
  448. )
  449. _reshard(state, handle, free_unsharded_flat_param)
  450. @no_type_check
  451. def _root_pre_forward(
  452. state: _FSDPState,
  453. module: nn.Module,
  454. args,
  455. kwargs,
  456. ) -> None:
  457. """
  458. Runs pre-forward logic specific to the root FSDP instance, which should run
  459. before any individual module's pre-forward. This starts with an attempt at
  460. lazy initialization (which only runs non-vacuously once). Otherwise, if
  461. this is called on a non-root FSDP instance, then it returns directly.
  462. Args:
  463. module (nn.Module): Module for which this logic tries to run. It may or
  464. may not be the root. If not, then this method does not do anything.
  465. """
  466. with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
  467. _lazy_init(state, module)
  468. _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
  469. if not state._is_root:
  470. # Always cast forward inputs at the root of this local FSDP unit for mixed
  471. # precision, as this is where mixed precision could be configured.
  472. # This is more useful for auto wrapping, which is recommended in the composable path.
  473. # For manual wrapping, casting forward inputs at each local FSDP unit root would
  474. # add some overhead, so it is not enabled for the model-wrapper path right now,
  475. # where manual wrapping is more broadly used.
  476. if _is_composable(state):
  477. return _root_cast_forward_input(state, module, args, kwargs)
  478. return args, kwargs
  479. # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
  480. # are in full precision and if we should cast them back to lower precision, which happens when
  481. # exiting eval() mode.
  482. handle = state._handle
  483. if handle:
  484. should_cast_buffers_to_full_prec = handle._force_full_precision
  485. else:
  486. should_cast_buffers_to_full_prec = True
  487. if should_cast_buffers_to_full_prec:
  488. _cast_buffers_to_dtype_and_device(
  489. buffers=dict(module.named_buffers()).values(),
  490. buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
  491. device=state.compute_device,
  492. )
  493. # This flag is only set when we cast buffers to full precision, to avoid the
  494. # CPU overhead that can stem from retrieving all buffers and their types in the
  495. # following else branch.
  496. state._needs_buffer_dtype_restore_check = True
  497. elif getattr(state, "_needs_buffer_dtype_restore_check", False):
  498. # Check if buffers are in full precision and we need to cast them
  499. # back down.
  500. (
  501. buffers,
  502. buffer_dtypes_for_computation,
  503. ) = _get_buffers_and_dtypes_for_computation(state, module)
  504. if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
  505. if any(
  506. buffer.dtype != buffer_dtype_for_computation
  507. for buffer, buffer_dtype_for_computation in zip(
  508. buffers, buffer_dtypes_for_computation
  509. )
  510. ):
  511. # Assume we have to cast everything if there is one mismatch
  512. _cast_buffers_to_dtype_and_device(
  513. buffers, buffer_dtypes_for_computation, state.compute_device
  514. )
  515. # We don't have to check this again until we cast buffers to full precision again.
  516. state._needs_buffer_dtype_restore_check = False
  517. if state.forward_prefetch:
  518. handles = []
  519. for fsdp_state in state._all_fsdp_states:
  520. if fsdp_state._handle:
  521. handles.append(fsdp_state._handle)
  522. for handle in handles:
  523. handle._needs_pre_forward_unshard = True
  524. handle._prefetched = False
  525. _wait_for_computation_stream(
  526. state._device_handle.current_stream(),
  527. state._unshard_stream,
  528. state._pre_unshard_stream,
  529. )
  530. _reset_flat_param_grad_info_if_needed(state._all_handles)
  531. # Prepares the forward inputs by moving them to ``compute_device``
  532. # TODO: Do not use the side stream for tensor copies for now; investigate
  533. # the perf with/without it.
  534. with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
  535. args_tuple, kwargs_tuple = _to_kwargs(
  536. args, kwargs, state.compute_device, False
  537. )
  538. args = args_tuple[0]
  539. kwargs = kwargs_tuple[0]
  540. return _root_cast_forward_input(state, module, args, kwargs)
  541. @no_type_check
  542. def _root_cast_forward_input(
  543. state: _FSDPState, module: torch.nn.Module, args, kwargs
  544. ) -> Tuple[Any, Any]:
  545. if state._handle:
  546. force_full_precision = not state._handle._force_full_precision
  547. else:
  548. force_full_precision = True
  549. should_cast_forward_inputs = (
  550. (module.training or not state._use_full_prec_in_eval) and force_full_precision
  551. ) and state.mixed_precision.cast_root_forward_inputs
  552. if should_cast_forward_inputs:
  553. input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
  554. args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
  555. return args, kwargs
  556. @no_type_check
  557. def _pre_backward_hook(
  558. state: _FSDPState,
  559. module: nn.Module,
  560. handle: FlatParamHandle,
  561. grad,
  562. *unused: Any,
  563. ) -> Any:
  564. """
  565. Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation.
  566. Args:
  567. module (nn.Module): Fully sharded module (see [Note: Fully Sharded
  568. Module]).
  569. """
  570. # Only run the pre-backward hook once per group of handles involved in the
  571. # same module forward computation
  572. if (
  573. handle
  574. and hasattr(handle, "_ran_pre_backward_hook")
  575. and handle._ran_pre_backward_hook
  576. ):
  577. logger.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
  578. return grad
  579. with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
  580. # Queue the post-backward callback once for the root FSDP instance to
  581. # attach it to the outermost backward graph task so that it is called
  582. # after all backward calls complete
  583. if state._is_root and not state._post_backward_callback_queued:
  584. _register_post_backward_final_callback(state, module)
  585. _reset_flat_param_grad_info_if_needed(state._all_handles)
  586. elif handle:
  587. allowed_states = [TrainingState.IDLE]
  588. if _is_composable(state):
  589. allowed_states.append(TrainingState.FORWARD_BACKWARD)
  590. _assert_in_training_states(state, allowed_states)
  591. state.training_state = TrainingState.FORWARD_BACKWARD
  592. # Queueing the post-backward callback is the only logic that is not
  593. # per-handle in the pre-backward hook, so we can return early here if
  594. # there are no handles.
  595. if not handle:
  596. return grad
  597. handle._training_state = HandleTrainingState.BACKWARD_PRE
  598. if handle._needs_pre_backward_unshard:
  599. # If the handles have been prefetched, then there is no need to
  600. # call `_unshard()` again
  601. if not handle._prefetched:
  602. _unshard(
  603. state,
  604. handle,
  605. state._unshard_stream,
  606. state._pre_unshard_stream,
  607. )
  608. # Don't wait during trace
  609. if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
  610. state._device_handle.current_stream().wait_stream(state._unshard_stream)
  611. # Set this to `False` to ensure that a mistargeted prefetch does not
  612. # actually unshard these handles
  613. handle._needs_pre_backward_unshard = False
  614. with torch.profiler.record_function(
  615. "FullyShardedDataParallel._pre_backward_prefetch"
  616. ):
  617. _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
  618. handle.prepare_gradient_for_backward()
  619. handle._ran_pre_backward_hook = True
  620. return grad
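# --- Illustrative sketch (standalone, hypothetical): how a gradient hook on a
# forward output runs "pre-backward", i.e. before gradients flow back into the
# module that produced the output. FSDP attaches `_pre_backward_hook` to the
# output tensors that require grad (see `_register_pre_backward_hooks`); here
# the hook only records that it fired.
def _example_pre_backward_tensor_hook() -> None:
    fired: List[bool] = []
    lin = nn.Linear(4, 4)
    out = lin(torch.randn(2, 4))

    def hook(grad: torch.Tensor) -> torch.Tensor:
        fired.append(True)  # e.g. unshard the parameters needed for this backward
        return grad

    out.register_hook(hook)
    out.sum().backward()
    assert fired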
  621. @no_type_check
  622. @torch.no_grad()
  623. def _post_backward_hook(
  624. state: _FSDPState,
  625. handle: FlatParamHandle,
  626. flat_param,
  627. *unused: Any,
  628. ):
  629. """
  630. Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
  631. Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
  632. unsharded gradient for the local batch.
  633. Postcondition:
  634. - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
  635. unsharded gradient.
  636. - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
  637. gradient (accumulating with any existing gradient).
  638. """
  639. _log_post_backward_hook(state, handle, logger)
  640. flat_param = handle.flat_param
  641. flat_param._post_backward_called = True
  642. with torch.autograd.profiler.record_function(
  643. "FullyShardedDataParallel._post_backward_hook"
  644. ):
  645. _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
  646. # For multiple applications of reentrant AC across submodules sharing
  647. # the same `FlatParameter`, the post-backward hook may run multiple
  648. # times in one backward, in which case we permit the state to already
  649. # be in `BACKWARD_POST`.
  650. _p_assert(
  651. handle._training_state
  652. in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
  653. f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
  654. )
  655. handle._training_state = HandleTrainingState.BACKWARD_POST
  656. if flat_param.grad is None:
  657. return
  658. if flat_param.grad.requires_grad:
  659. raise RuntimeError("FSDP does not support gradients of gradients")
  660. _post_backward_reshard(state, handle)
  661. if not state._sync_gradients:
  662. if handle._use_orig_params:
  663. handle._use_unsharded_grad_views()
  664. return
  665. # Wait for all ops in the current stream (e.g. gradient computation) to
  666. # finish before reduce-scattering the gradient
  667. if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
  668. state._post_backward_stream.wait_stream(
  669. state._device_handle.current_stream()
  670. )
  671. with state._device_handle.stream(state._post_backward_stream):
  672. autograd_computed_grad = flat_param.grad.data
  673. if (
  674. not _low_precision_hook_enabled(state)
  675. and flat_param.grad.dtype != handle._reduce_dtype
  676. # If we are forcing full precision but communicating grads
  677. # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
  678. and not handle._force_full_precision
  679. ):
  680. flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
  681. if handle.uses_sharded_strategy:
  682. _reduce_grad(state, handle)
  683. else:
  684. _reduce_grad_no_shard(state, handle)
  685. # Since the unsharded gradient is produced in the computation
  686. # stream and consumed in the post-backward stream, inform the
  687. # caching allocator (before it goes out of scope)
  688. _no_dispatch_record_stream(
  689. autograd_computed_grad, state._post_backward_stream
  690. )
  691. def _post_backward_reshard_only_hook(
  692. state: _FSDPState,
  693. handle: FlatParamHandle,
  694. *unused: Any,
  695. ) -> None:
  696. with torch.profiler.record_function(
  697. "FullyShardedDataParallel._post_backward_hook_reshard_only"
  698. ):
  699. # `_pre_backward_hook` may not get executed
  700. # if forward output does not require grad
  701. # overwrite IDLE state for post-backward prefetching
  702. state.training_state = TrainingState.FORWARD_BACKWARD
  703. handle._training_state = HandleTrainingState.BACKWARD_POST
  704. _post_backward_reshard(state, handle)
  705. def _post_backward_reshard(
  706. state: _FSDPState,
  707. handle: FlatParamHandle,
  708. *unused: Any,
  709. ) -> None:
  710. free_unsharded_flat_param = _should_free_in_backward(state, handle)
  711. _reshard(state, handle, free_unsharded_flat_param)
  712. # TODO: Post-backward prefetching does not support the multiple handles
  713. # per module case since the post-backward hook runs per handle, not per
  714. # group of handles.
  715. with torch.profiler.record_function(
  716. "FullyShardedDataParallel._post_backward_prefetch"
  717. ):
  718. _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
  719. @no_type_check
  720. def _should_free_in_backward(
  721. state: _FSDPState,
  722. handle: FlatParamHandle,
  723. ) -> bool:
  724. """
  725. Returns whether FSDP should free the unsharded flat parameter in the
  726. post-backward or not.
  727. """
  728. if not handle.uses_sharded_strategy:
  729. return False
  730. # If not syncing gradients, then we do not free for strategies that do not
  731. # reshard after forward as a *heuristic* to tradeoff higher memory for
  732. # higher throughput.
  733. return (
  734. state._sync_gradients
  735. or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
  736. )
  737. @no_type_check
  738. def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
  739. """
  740. For sharded strategies, this runs gradient reduction, sharded gradient
  741. accumulation if needed, and the post-reduction callback.
  742. """
  743. flat_param = handle.flat_param
  744. uses_hybrid_sharded_strategy = handle._sharding_strategy in (
  745. HandleShardingStrategy.HYBRID_SHARD,
  746. HandleShardingStrategy._HYBRID_SHARD_ZERO2,
  747. )
  748. # We clear `.grad` to permit multiple backwards. This avoids a race where
  749. # the second backward pass's computation runs ahead of the first backward
  750. # pass's reduction, which is possible since the reduction is issued in a
  751. # separate stream asynchronously, and it would result in reducing the wrong
  752. # gradient.
  753. unsharded_grad = flat_param.grad.data
  754. flat_param.grad = None
  755. padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(
  756. state, unsharded_grad
  757. )
  758. if state._comm_hook is None: # default path
  759. _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
  760. pg = (
  761. handle._fake_process_group
  762. if handle._use_fake_reduce
  763. else state.process_group
  764. )
  765. dist.reduce_scatter_tensor(
  766. new_sharded_grad,
  767. padded_unsharded_grad,
  768. group=pg,
  769. )
  770. if uses_hybrid_sharded_strategy:
  771. # Don't wait during trace
  772. if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
  773. state._all_reduce_stream.wait_stream(state._post_backward_stream)
  774. with state._device_handle.stream(state._all_reduce_stream):
  775. # Since the new sharded gradient is produced in the post-
  776. # backward stream and consumed in the all-reduce stream,
  777. # inform the caching allocator
  778. _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream)
  779. dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
  780. _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
  781. grad_to_offload = _accumulate_sharded_grad(
  782. state, handle, new_sharded_grad
  783. )
  784. _post_reduce_grad_callback(state, handle, grad_to_offload)
  785. return
  786. _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
  787. else:
  788. state._comm_hook(
  789. state._comm_hook_state, padded_unsharded_grad, new_sharded_grad
  790. )
  791. # NOTE: HSDP variants do not support a communication hook.
  792. grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad)
  793. _post_reduce_grad_callback(state, handle, grad_to_offload)
  794. @no_type_check
  795. def _get_reduce_scatter_tensors(
  796. state: _FSDPState, unsharded_grad: torch.Tensor
  797. ) -> Tuple[torch.Tensor, torch.Tensor]:
  798. """
  799. Returns the input and output tensors to reduce-scatter, respectively.
  800. """
  801. chunks = list(unsharded_grad.chunk(state.world_size))
  802. numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()
  803. padded_unsharded_grad = (
  804. F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
  805. )
  806. new_sharded_grad = torch.empty_like(chunks[0]) # padded
  807. return padded_unsharded_grad, new_sharded_grad
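# --- Illustrative sketch (hypothetical, single-process): the padding arithmetic
# used by `_get_reduce_scatter_tensors` above. With world_size=4 and a
# 10-element gradient, `chunk(4)` yields chunks of 3 elements (the last one
# partially filled), so 2 elements of padding make the input evenly divisible
# and each rank's reduce-scatter output holds 3 elements. The collective itself
# appears only in a comment since it needs an initialized process group.
def _example_reduce_scatter_padding() -> None:
    world_size = 4
    unsharded_grad = torch.arange(10, dtype=torch.float32)
    chunks = list(unsharded_grad.chunk(world_size))
    numel_to_pad = world_size * chunks[0].numel() - unsharded_grad.numel()
    padded = F.pad(unsharded_grad, [0, numel_to_pad])
    new_sharded_grad = torch.empty_like(chunks[0])
    assert numel_to_pad == 2 and padded.numel() == 12 and new_sharded_grad.numel() == 3
    # With a process group: dist.reduce_scatter_tensor(new_sharded_grad, padded, group=pg)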
  808. @no_type_check
  809. def _accumulate_sharded_grad(
  810. state: _FSDPState,
  811. handle: FlatParamHandle,
  812. sharded_grad: torch.Tensor,
  813. ) -> torch.Tensor:
  814. """
  815. Accumulates the reduce-scattered sharded gradient with any existing sharded
  816. gradient if needed, returning the gradient to offload (if CPU offloading is
  817. enabled).
  818. """
  819. flat_param = handle.flat_param
  820. _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
  821. # Save the sharded gradient in `_saved_grad_shard` to support gradient
  822. # accumulation -- for multiple backwards, the gradient reductions may
  823. # happen in arbitrary order
  824. accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
  825. if accumulate_grad:
  826. _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
  827. flat_param._saved_grad_shard += sharded_grad
  828. else:
  829. flat_param._saved_grad_shard = sharded_grad
  830. grad_to_offload = flat_param._saved_grad_shard
  831. return grad_to_offload
  832. @no_type_check
  833. def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
  834. """
  835. For no-shard, this runs gradient reduction (which implicitly covers any
  836. gradient accumulation) and the post-reduction callback.
  837. """
  838. flat_param = handle.flat_param
  839. if state._comm_hook is None: # default path
  840. _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
  841. dist.all_reduce(flat_param.grad, group=state.process_group)
  842. _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
  843. else:
  844. state._comm_hook(state._comm_hook_state, flat_param.grad)
  845. # For `NO_SHARD`, we can keep the low precision gradients by simply
  846. # omitting the cast altogether
  847. if not handle._keep_low_precision_grads:
  848. _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
  849. grad_to_offload = flat_param.grad.data
  850. _post_reduce_grad_callback(state, handle, grad_to_offload)
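# --- Illustrative sketch (hypothetical factors, assumes an initialized default
# process group): the pre-/post-divide pattern used in the no-shard path above.
# Splitting the division by world size into a factor applied before the
# all-reduce and one applied after keeps the running sum bounded, which matters
# for fp16 gradients; the exact default factors FSDP computes are not
# reproduced here.
def _example_pre_post_divide(
    grad: torch.Tensor, predivide_factor: float, world_size: int
) -> None:
    grad.div_(predivide_factor)
    dist.all_reduce(grad)  # sums the pre-divided gradients across ranks
    grad.div_(world_size / predivide_factor)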
  851. @no_type_check
  852. def _post_reduce_grad_callback(
  853. state: _FSDPState,
  854. handle: FlatParamHandle,
  855. # Additional arguments needed for the callback logic
  856. grad_to_offload: torch.Tensor,
  857. ):
  858. """
  859. This callback captures any logic to run after the gradient reduction
  860. finishes. Currently, this offloads the gradient to CPU if CPU offloading is
  861. enabled and uses sharded gradient views if ``use_orig_params=True``.
  862. """
  863. _offload_grad(state, handle, grad_to_offload)
  864. _post_backward_use_sharded_grad_views(handle)
  865. @no_type_check
  866. def _offload_grad(
  867. state: _FSDPState,
  868. handle: FlatParamHandle,
  869. grad_to_offload: torch.Tensor,
  870. ):
  871. if not handle._offload_params:
  872. return
  873. # Offload the gradient to CPU to ensure parameters and gradients are on the
  874. # same device as required by the optimizer
  875. # TODO: Investigate why `NO_SHARD` breaks correctness when using
  876. # `non_blocking=True` here.
  877. # TODO (rohan-varma): When CPU offload and optimizer overlap,
  878. # non_blocking=True won't work since the copy may have not finished before
  879. # the optimizer step executes on CPU. If we want to use non-blocking=True
  880. # here, we'll have to synchronize before using result on CPU.
  881. non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
  882. handle.flat_param._cpu_grad.copy_(
  883. grad_to_offload.detach(), non_blocking=non_blocking
  884. ) # synchronized in the post-backward callback
  885. # Since the gradient being offloaded may have been produced in the
  886. # computation stream and is being consumed here in the post-backward
  887. # stream, inform the caching allocator
  888. _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)
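# --- Illustrative sketch (hypothetical, assumes a CUDA device): why the CPU
# gradient buffer is pinned and why a non-blocking device-to-host copy must be
# synchronized before the CPU reads it, which is what the `synchronize()` call
# in `_post_backward_final_callback` below takes care of.
def _example_grad_offload_copy() -> None:
    if not torch.cuda.is_available():
        return
    device_grad = torch.randn(1024, device="cuda")
    cpu_grad = torch.empty(1024).pin_memory()  # pinned memory enables async D2H copies
    cpu_grad.copy_(device_grad, non_blocking=True)
    torch.cuda.current_stream().synchronize()  # required before consuming `cpu_grad` on CPU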
  889. @no_type_check
  890. def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
  891. if not handle._use_orig_params:
  892. return
  893. # Since the handle's `FlatParameter` completed its gradient computation, we
  894. # should reset the gradient noneness mask
  895. handle._reset_is_grad_none()
  896. # Delay using sharded gradient views until after the reduce-scatter instead
  897. # of immediately after resharding
  898. handle._use_sharded_grad_views()
  899. if handle._has_optim_in_backward:
  900. handle.prepare_gradient_for_optim()
  901. for orig_param in handle.flat_param._params:
  902. # Check for `None` gradient to filter parameters not in the rank
  903. if orig_param.grad is not None and hasattr(
  904. orig_param, "_in_backward_optimizers"
  905. ):
  906. # TODO (rohan-varma): For CPU offload, this unfortunately
  907. # operates on CPU because the parameters and gradients have
  908. # already been offloaded. We should run this on GPU after
  909. # refactoring.
  910. for optim in orig_param._in_backward_optimizers:
  911. optim.step()
  912. optim.zero_grad(set_to_none=True)
  913. handle._reset_flat_param_grad_info_if_needed()
  914. if handle._offload_params:
  915. handle.flat_param._cpu_grad = None
  916. def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
  917. if div_factor > 1:
  918. tensor.div_(div_factor)
  919. @no_type_check
  920. def _cast_grad_to_param_dtype(
  921. state: _FSDPState,
  922. sharded_grad: torch.Tensor,
  923. param: FlatParameter,
  924. ):
  925. """
  926. Casts ``sharded_grad`` back to the full parameter dtype so that the
  927. optimizer step runs with that dtype. This performs an actual cast if
  928. 1. parameters were in reduced precision during the forward since then
  929. gradients would be in that reduced precision, or
  930. 2. parameters were not in reduced precision but gradients were in
  931. reduced precision for communication.
  932. However, if a low precision communication hook is registered, then this
  933. dtype cast happens in the hook instead.
  934. """
  935. _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
  936. if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
  937. low_prec_grad_data = sharded_grad.data
  938. sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
  939. # Since for `NO_SHARD`, the gradient is produced in the computation
  940. # stream and consumed here in the post-backward stream, inform the
  941. # caching allocator; for the sharded strategies, the gradient is
  942. # produced in the post-backward stream, so this `record_stream()`
  943. # should be a no-op
  944. _no_dispatch_record_stream(
  945. low_prec_grad_data, state._device_handle.current_stream()
  946. )
  947. def _check_grad_to_accumulate(
  948. new_sharded_grad: torch.Tensor,
  949. accumulated_grad: torch.Tensor,
  950. ) -> None:
  951. _p_assert(
  952. accumulated_grad.shape == new_sharded_grad.shape,
  953. "Shape mismatch when accumulating gradients: "
  954. f"existing gradient shape={accumulated_grad.shape} "
  955. f"new gradient shape={new_sharded_grad.shape}",
  956. )
  957. _p_assert(
  958. accumulated_grad.device == new_sharded_grad.device,
  959. "Device mismatch when accumulating gradients: "
  960. f"existing gradient device={accumulated_grad.device} "
  961. f"new gradient device={new_sharded_grad.device}",
  962. )
  963. @no_type_check
  964. def _low_precision_hook_enabled(state: _FSDPState) -> bool:
  965. return state._comm_hook in LOW_PRECISION_HOOKS
  966. @no_type_check
  967. @torch.no_grad()
  968. def _post_backward_final_callback(
  969. state: _FSDPState,
  970. module: nn.Module,
  971. ):
  972. """
  973. This waits for the post-backward to finish and performs some final cleanup.
  974. This runs at the end of the entire backward pass and should only be called
  975. on the root FSDP instance.
  976. """
  977. _p_assert(
  978. state._is_root,
  979. "The post-backward callback should only be called on the root FSDP instance",
  980. )
  981. root_state = state
  982. if root_state._sync_gradients:
  983. current_stream = state._device_handle.current_stream()
  984. # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
  985. # since it currently runs in the post-backward stream. That can be
  986. # pushed to the next forward if run in a different stream
  987. current_stream.wait_stream(root_state._post_backward_stream)
  988. if root_state._all_reduce_stream is not current_stream: # uses HSDP
  989. current_stream.wait_stream(root_state._all_reduce_stream)
  990. if root_state.cpu_offload.offload_params:
  991. # Wait for non-blocking GPU -> CPU sharded gradient copies from the
  992. # post-backward hooks to finish explicitly since CPU gradients do
  993. # not automatically synchronize with the GPU
  994. state._device_handle.current_stream().synchronize()
  995. root_state._exec_order_data.next_iter()
  996. for fsdp_state in state._all_fsdp_states:
  997. _catch_all_reshard(fsdp_state)
  998. _finalize_params(fsdp_state)
  999. fsdp_state.training_state = TrainingState.IDLE
  1000. handle = fsdp_state._handle
  1001. if handle:
  1002. handle._ran_pre_backward_hook = False
  1003. handle._needs_pre_backward_unshard = False
  1004. handle._post_forward_index = None
  1005. handle._training_state = HandleTrainingState.IDLE
  1006. handle._prefetched = False
  1007. # Reset for cases like one forward and multiple backwards
  1008. root_state._post_backward_callback_queued = False
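# --- Illustrative sketch (standalone, uses an internal API): queueing a
# callback on the autograd engine so that it runs once the current backward
# pass finishes. This is the mechanism by which
# `_register_post_backward_final_callback` arranges for
# `_post_backward_final_callback` above to run at the end of backward; the
# callback must be queued while backward is executing, e.g. from a hook.
def _example_queue_final_backward_callback() -> None:
    ran: List[bool] = []
    x = torch.randn(3, requires_grad=True)

    def hook(grad: torch.Tensor) -> torch.Tensor:
        Variable._execution_engine.queue_callback(lambda: ran.append(True))
        return grad

    x.register_hook(hook)
    x.sum().backward()
    assert ran  # the queued callback ran after the backward pass completed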
  1009. @no_type_check
  1010. def _catch_all_reshard(
  1011. state: _FSDPState,
  1012. ) -> None:
  1013. """
  1014. Reshards the parameters that may not have been resharded in the
  1015. post-backward hook. This can happen when a module's output is used in the
  1016. forward pass, meaning that its pre-backward hook runs (unsharding the
  1017. parameter), but the post-backward hook does not run because the output was
  1018. not used in the loss computation corresponding to this backward pass.
  1019. """
  1020. # Wrap with a try-except to provide a more informative traceback if an
  1021. # error is raised
  1022. try:
  1023. if state._handle:
  1024. # TODO: This already-resharded check is brittle:
  1025. # https://github.com/pytorch/pytorch/issues/83956
  1026. already_resharded = (
  1027. state._handle.flat_param.data_ptr()
  1028. == state._handle.flat_param._local_shard.data_ptr()
  1029. # If FSDP skipped using sharded views, then the flat parameter
  1030. # still points to the sharded data, so we need to reshard to
  1031. # use sharded views
  1032. and not state._handle._skipped_use_sharded_views
  1033. )
  1034. if already_resharded:
  1035. return
  1036. free_unsharded_flat_param = _should_free_in_backward(state, state._handle)
  1037. _reshard(state, state._handle, free_unsharded_flat_param)
  1038. except Exception as e:
  1039. _p_assert(
  1040. False,
  1041. f"Got exception in the catch-all reshard for {state}: {str(e)}",
  1042. raise_assertion_error=False,
  1043. )
  1044. raise e
  1045. @no_type_check
  1046. def _finalize_params(
  1047. state: _FSDPState,
  1048. ) -> None:
  1049. """Finalizes the parameters before the next iteration."""
  1050. handle = state._handle
  1051. if not handle:
  1052. return
  1053. flat_param = handle.flat_param
  1054. if torch.distributed._functional_collectives.is_torchdynamo_compiling():
  1055. if hasattr(flat_param, "_post_backward_hook_handle"):
  1056. pbhs_handle = flat_param._post_backward_hook_handle
  1057. pbhs_handle.remove()
  1058. del flat_param._post_backward_hook_handle
  1059. else:
  1060. if hasattr(flat_param, "_post_backward_hook_state"):
  1061. post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
  1062. expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
  1063. _p_assert(
  1064. post_backward_hook_state_len == expected_post_backward_hook_state_len,
  1065. f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
  1066. )
  1067. flat_param._post_backward_hook_state[-1].remove()
  1068. delattr(flat_param, "_post_backward_hook_state")
  1069. if flat_param.requires_grad:
  1070. if not state._sync_gradients:
  1071. # Preserve the gradient accumulation state if not synchronizing
  1072. # gradients: `.grad` remains the unsharded gradient from prior
  1073. # `no_sync()` iterations, and `_saved_grad_shard` remains the
  1074. # sharded gradient from the last synchronized iteration
  1075. return
  1076. if not handle._has_optim_in_backward:
  1077. handle.prepare_gradient_for_optim()
  1078. _p_assert(
  1079. hasattr(flat_param, "_post_backward_called"),
  1080. "Expects `_post_backward_called` to be set on the `FlatParameter`",
  1081. )
  1082. flat_param._post_backward_called = False
  1083. @no_type_check
  1084. def _prefetch_handle(
  1085. state: _FSDPState,
  1086. current_handle: Optional[FlatParamHandle],
  1087. prefetch_mode: _PrefetchMode,
  1088. ) -> None:
  1089. """
  1090. Prefetches the next handle if needed (without synchronization). An empty
  1091. handle cannot prefetch.
  1092. """
  1093. if not current_handle:
  1094. return
  1095. handle = _get_handle_to_prefetch(state, current_handle)
  1096. if not handle:
  1097. return
  1098. # Temporarily emulate the training state while calling `_unshard` to
  1099. # ensure the correct `as_params` for `_use_unsharded_views()`
  1100. prev_training_state = handle._training_state
  1101. if prefetch_mode == _PrefetchMode.BACKWARD:
  1102. handle._training_state = HandleTrainingState.BACKWARD_PRE
  1103. elif prefetch_mode == _PrefetchMode.FORWARD:
  1104. handle._training_state = HandleTrainingState.FORWARD
  1105. else:
  1106. raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
  1107. # Prefetch the next set of handles without synchronizing to allow
  1108. # the sync to happen as late as possible to maximize overlap
  1109. _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  1110. handle._training_state = prev_training_state
  1111. handle._prefetched = True
  1112. @no_type_check
  1113. def _get_handle_to_prefetch(
  1114. state: _FSDPState,
  1115. current_handle: FlatParamHandle,
  1116. ) -> FlatParamHandle:
  1117. """
  1118. Returns the handle to prefetch for the next module(s) (or ``None``), where
  1119. ``current_handle`` represents the current module.
  1120. "Prefetching" refers to running the unshard logic early (without
  1121. synchronization), and the "next" modules depend on the recorded execution
  1122. order and the current training state.
  1123. """
  1124. training_state = _get_training_state(current_handle)
  1125. valid_training_states = (
  1126. HandleTrainingState.BACKWARD_PRE,
  1127. HandleTrainingState.BACKWARD_POST,
  1128. HandleTrainingState.FORWARD,
  1129. )
  1130. _p_assert(
  1131. training_state in valid_training_states,
  1132. f"Prefetching is only supported in {valid_training_states} but "
  1133. f"currently in {training_state}",
  1134. )
  1135. eod = state._exec_order_data
  1136. target_handle: Optional[FlatParamHandle] = None
  1137. if (
  1138. training_state == HandleTrainingState.BACKWARD_PRE
  1139. and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
  1140. ) or (
  1141. training_state == HandleTrainingState.BACKWARD_POST
  1142. and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
  1143. ):
  1144. target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
  1145. if (
  1146. target_handle_candidate
  1147. and target_handle_candidate._needs_pre_backward_unshard
  1148. and not target_handle_candidate._prefetched
  1149. ):
  1150. target_handle = target_handle_candidate
  1151. else:
  1152. target_handle = None
  1153. elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
  1154. target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
  1155. if (
  1156. target_handle_candidate
  1157. and target_handle_candidate._needs_pre_forward_unshard
  1158. and not target_handle_candidate._prefetched
  1159. ):
  1160. target_handle = target_handle_candidate
  1161. else:
  1162. target_handle = None
  1163. return target_handle
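

# --- Illustrative sketch (not part of the FSDP implementation) ---
# How the prefetch settings consulted above are enabled from user code:
# `backward_prefetch` and `forward_prefetch` are plain FSDP constructor
# arguments. This sketch assumes an already-initialized process group and a
# CUDA device and returns ``None`` otherwise; `_example_enable_prefetch` is a
# hypothetical helper.
def _example_enable_prefetch(module: nn.Module) -> Optional[nn.Module]:
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    if not (dist.is_initialized() and torch.cuda.is_available()):
        return None
    return FSDP(
        module,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next backward all-gather
        forward_prefetch=True,  # prefetch the next forward all-gather (static graphs)
    )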


def _get_training_state(
    handle: FlatParamHandle,
) -> HandleTrainingState:
    """Returns the training state of ``handle``."""
    _p_assert(handle, "Expects a non-empty handle")
    return handle._training_state


@no_type_check
def _register_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a pre-forward hook on ``module``.
    """
    for forward_handle in state._pre_forward_handles:
        forward_handle.remove()
    state._pre_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _pre_forward, state, module_param_handle, _pre_forward_unshard
    )
    state._pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_post_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a post-forward hook on ``module``. Even if the module has no
    handles, we should register the hook since it will register the module's
    pre-backward hook.
    """
    for forward_handle in state._post_forward_handles:
        forward_handle.remove()
    state._post_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _post_forward,
        state,
        module_param_handle,
        _post_forward_reshard,
    )
    state._post_forward_handles.append(module.register_forward_hook(hook))
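

# --- Illustrative sketch (not part of the FSDP implementation) ---
# A standalone look at the two `nn.Module` hook APIs used above, outside of
# FSDP: a forward pre-hook registered with `prepend=True, with_kwargs=True`
# receives `(module, args, kwargs)` and may rewrite both, while a forward hook
# receives `(module, args, output)` and may replace the output. All names
# prefixed with `_example_` are hypothetical.
def _example_forward_hooks() -> torch.Tensor:
    lin = nn.Linear(4, 4)

    def pre_hook(module, args, kwargs):
        # Runs before `forward`; returning `(new_args, new_kwargs)` overrides the inputs
        return (args[0] * 2,), kwargs

    def post_hook(module, args, output):
        # Runs after `forward`; returning a value overrides the output
        return output + 1

    lin.register_forward_pre_hook(pre_hook, prepend=True, with_kwargs=True)
    lin.register_forward_hook(post_hook)
    return lin(torch.ones(2, 4))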


@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers root pre-forward hook on ``module``, which should be the local
    FSDP root.

    NOTE: For the current composable FSDP design, each application of
    ``fully_shard()`` to a module indicates that the module is the local FSDP
    root. We may remove this assumption in the future, in which case we will
    need to register this root pre-forward hook on any candidate module that
    may be the local FSDP root.
    """
    for forward_handle in state._root_pre_forward_handles:
        forward_handle.remove()
    state._root_pre_forward_handles.clear()
    hook = functools.partial(_root_pre_forward, state)
    state._root_pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_pre_backward_hooks(
    state: _FSDPState,
    module: nn.Module,
    outputs: Any,
    handle: FlatParamHandle,
) -> Any:
    """
    Registers pre-backward hooks on the tensors in the forward pass outputs
    ``outputs`` that require gradients, which were computed using the
    ``FlatParameter`` of ``handle``.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs
    if state._is_root:
        state._post_backward_callback_queued = False  # only defined on the root

    if handle:
        handle._needs_pre_backward_unshard = False
        # Since the handle's `FlatParameter` participated in a forward, we
        # conservatively assume that it will be used in the backward
        handle._ran_pre_backward_hook = False

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(
                torch.utils.hooks.unserializable_hook(
                    functools.partial(_pre_backward_hook, state, module, handle)
                )
            )
            if handle:
                handle._needs_pre_backward_unshard = True
        return t

    return _apply_to_tensors(_register_hook, outputs)
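

# --- Illustrative sketch (not part of the FSDP implementation) ---
# A minimal version of the mechanism above: a hook registered on a forward
# *output* tensor fires when that tensor's gradient is computed, i.e. right as
# the backward pass reaches this module, which is when FSDP needs to unshard
# its parameters again. Names prefixed with `_example_` are hypothetical.
def _example_pre_backward_tensor_hook() -> List[str]:
    events: List[str] = []
    x = torch.ones(3, requires_grad=True)
    out = (x * 2).sum()
    # The hook receives the gradient flowing into `out`; returning ``None``
    # leaves it unchanged
    out.register_hook(lambda grad: events.append("pre-backward fired"))
    out.backward()
    return events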


def _register_post_backward_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """
    Registers a post-backward hook on the ``FlatParameter``'s
    ``AccumulateGrad`` object to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object represents the last function that finalizes
    the ``FlatParameter``'s gradient, so it only runs after its entire
    gradient computation has finished.

    We register the post-backward hook only once in the *first* forward that a
    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.

    NOTE: We follow this heuristic to prefer the *first* forward to target the
    parameter mixed precision case, where there are *separate*
    ``AccumulateGrad`` objects across the different forwards. (Without
    parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If
    we instead preferred the *last* forward, then the hook would run early.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
        if already_registered or not flat_param.requires_grad:
            return
        hook = functools.partial(_post_backward_hook, state, handle)
        hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined]
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or not flat_param.requires_grad:
            return
        # Get the `AccumulateGrad` object
        temp_flat_param = flat_param.expand_as(flat_param)
        _p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
        assert acc_grad is not None
        hook_handle = acc_grad.register_hook(
            functools.partial(_post_backward_hook, state, handle)
        )
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
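

# --- Illustrative sketch (not part of the FSDP implementation) ---
# The `expand_as` trick above in isolation: expanding a leaf tensor creates a
# graph node whose `next_functions[0][0]` is the leaf's `AccumulateGrad` node,
# and a hook registered on that node fires only once the leaf's full gradient
# has been accumulated. Names prefixed with `_example_` are hypothetical.
def _example_accumulate_grad_hook() -> List[str]:
    events: List[str] = []
    param = torch.ones(4, requires_grad=True)
    acc_grad = param.expand_as(param).grad_fn.next_functions[0][0]
    # Node hooks receive `(grad_inputs, grad_outputs)`; returning ``None``
    # leaves the gradients unchanged
    hook_handle = acc_grad.register_hook(lambda *unused: events.append("post-backward"))
    (param * 2).sum().backward()
    hook_handle.remove()
    return events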


def _register_post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Registers post-backward hooks to reshard flat parameters that do not
    require gradient. We register these using multi-post-grad hooks on the
    input activations to ensure that all gradients that may depend on the
    parameters have been computed before resharding.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
    # where each flat parameter requires gradient
    inp_tensors: Optional[List[torch.Tensor]] = None
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")

    if already_registered or flat_param.requires_grad:
        return
    if inp_tensors is None:
        args_flat = pytree.arg_tree_leaves(*args, **kwargs)
        inp_tensors = [
            obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
        ]
    assert inp_tensors is not None  # mypy
    hook_handle = register_multi_grad_hook(
        inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
    )
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined, assignment]
    else:
        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]
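

# --- Illustrative sketch (not part of the FSDP implementation) ---
# Standalone version of the multi-grad hook used above: the hook fires only
# once gradients for *all* of the given tensors have been computed, which is
# what makes it a safe point to reshard parameters that themselves do not
# require gradients. Names prefixed with `_example_` are hypothetical.
def _example_multi_grad_hook() -> List[str]:
    events: List[str] = []
    a = torch.ones(2, requires_grad=True)
    b = torch.ones(2, requires_grad=True)
    hook_handle = register_multi_grad_hook(
        (a, b), lambda grads: events.append("all input gradients are ready")
    )
    (a * b).sum().backward()
    hook_handle.remove()
    return events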


@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Registers the post-backward final callback that runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward.
    """
    _p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    if state._post_backward_callback_queued:
        return
    _assert_in_training_states(state, [TrainingState.IDLE])
    # Trace does not need this callback
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        state._post_backward_callback_queued = True
        Variable._execution_engine.queue_callback(
            functools.partial(_post_backward_final_callback, state, module)
        )
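

# --- Illustrative sketch (not part of the FSDP implementation) ---
# A standalone demonstration of `Variable._execution_engine.queue_callback`:
# when queued from inside a hook that runs during the backward pass, the
# callback is deferred until the autograd engine finishes the whole backward,
# which is how the end-of-backward finalization above gets scheduled. Names
# prefixed with `_example_` are hypothetical.
def _example_queue_final_callback() -> List[str]:
    events: List[str] = []
    x = torch.ones(3, requires_grad=True)
    loss = (x * 2).sum()

    def _pre_backward(grad):
        # Queue the "final" callback as the backward pass starts
        Variable._execution_engine.queue_callback(
            lambda: events.append("backward finished")
        )
        events.append("backward started")

    loss.register_hook(_pre_backward)
    loss.backward()
    return events  # ["backward started", "backward finished"]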


def _wait_for_computation_stream(
    computation_stream: torch.Stream,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
):
    """
    Has the unshard and pre-unshard streams wait for the computation stream.
    For example, this should be called in the FSDP root's pre-forward to
    respect optimizer step computation.
    """
    # Tracing does not need to wait
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        return
    unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
    # Having the pre-all-gather stream wait for the current stream even if we
    # do not leverage the pre-all-gather stream is tolerable since this only
    # runs once per iteration
    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
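

# --- Illustrative sketch (not part of the FSDP implementation) ---
# The `wait_stream` pattern above in isolation: work later submitted to a side
# stream is ordered after everything already enqueued on the current stream,
# without blocking the CPU. Guarded on CUDA availability so it degrades to a
# no-op on CPU-only builds. Names prefixed with `_example_` are hypothetical.
def _example_wait_stream() -> Optional[float]:
    if not torch.cuda.is_available():
        return None
    computation_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()
    x = torch.ones(1 << 20, device="cuda") * 2  # enqueued on the computation stream
    # Order all future work on `side_stream` after the multiply above
    side_stream.wait_stream(computation_stream)
    with torch.cuda.stream(side_stream):
        y = x + 1  # safe: runs only after `x` has been produced
    # Symmetrically, have the computation stream wait before consuming `y`
    computation_stream.wait_stream(side_stream)
    return y.sum().item()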


def _reset_flat_param_grad_info_if_needed(
    handles: List[FlatParamHandle],
):
    """
    Clears the original parameters' gradients if needed. This method's CPU
    overhead is minimal, so we may call it throughout FSDP methods, which
    serve as callsites to free the gradient memory earlier.
    """
    if not isinstance(handles, list):
        handles = [handles]
    for handle in handles:
        if handle._use_orig_params:
            handle._reset_flat_param_grad_info_if_needed()


@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Returns all buffers in the module tree rooted at ``root_module`` and a
    corresponding list of the buffer dtypes for computation. Each buffer dtype
    is either ``None`` if buffer mixed precision is not enabled or the buffer
    low precision dtype otherwise.
    """
    _p_assert(state._is_root, "Expects the root to cast buffers")
    buffers: List[torch.Tensor] = []
    buffer_dtypes: List[Optional[torch.dtype]] = []
    visited_buffers: Set[torch.Tensor] = set()
    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
    # instance's mixed precision setting for each buffer
    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
        root_module
    )
    for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
        for buffer_name, buffer in fsdp_module.named_buffers():
            if buffer in visited_buffers:
                continue
            visited_buffers.add(buffer)
            if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names:
                continue
            buffers.append(buffer)
            buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
    return buffers, buffer_dtypes
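

# --- Illustrative sketch (not part of the FSDP implementation) ---
# Why the `visited_buffers` set above is needed: nested traversals (as with
# nested FSDP instances) see the same buffer object more than once, since a
# child module's buffers also appear in its parent's `named_buffers()`, so the
# cast must deduplicate by identity. Names prefixed with `_example_` are
# hypothetical.
def _example_visited_buffer_dedup() -> int:
    child = nn.Module()
    child.register_buffer("stats", torch.zeros(4))
    parent = nn.Module()
    parent.child = child

    visited: Set[torch.Tensor] = set()
    num_unique = 0
    for module in (parent, child):  # stand-in for the per-FSDP-state loop
        for _, buffer in module.named_buffers():
            if buffer in visited:
                continue
            visited.add(buffer)
            num_unique += 1
    return num_unique  # 1: the buffer is visited once despite two traversals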


@no_type_check
def _get_orig_buffer_dtypes(
    state: _FSDPState,
    buffer_names: List[str],
) -> List[torch.dtype]:
    """
    Returns the original buffer dtypes of the given buffer names.
    """
    buffer_dtypes: List[torch.dtype] = []
    for buffer_name in buffer_names:
        _p_assert(
            buffer_name in state._buffer_name_to_orig_dtype,
            f"{buffer_name} is missing from pre-computed dict on rank "
            f"{state.rank}, which only has keys "
            f"{state._buffer_name_to_orig_dtype.keys()}",
        )
        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
    return buffer_dtypes


def _cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: List[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    """
    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
    corresponding buffer is only moved to ``device``.
    """
    _p_assert(
        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
        f"`buffer_dtypes` is specified but got {len(buffers)} and "
        f"{len(buffer_dtypes) if buffer_dtypes is not None else 'None'}",
    )
    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
        if not torch.is_floating_point(buffer) or buffer_dtype is None:
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
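

# --- Illustrative sketch (not part of the FSDP implementation) ---
# The casting rule above on a toy module: floating-point buffers are cast to
# the requested low-precision dtype, while integer buffers (e.g. step counters
# or masks) only move devices and keep their dtype. Names prefixed with
# `_example_` are hypothetical.
def _example_cast_buffers() -> Tuple[torch.dtype, torch.dtype]:
    module = nn.Module()
    module.register_buffer("running_mean", torch.zeros(4))  # float32
    module.register_buffer("num_batches", torch.zeros(1, dtype=torch.long))
    buffers = list(module.buffers())
    _cast_buffers_to_dtype_and_device(
        buffers, [torch.float16, torch.float16], torch.device("cpu")
    )
    # -> (torch.float16, torch.int64): only the floating-point buffer was cast
    return buffers[0].dtype, buffers[1].dtype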