_optim_utils.py

  1. # mypy: allow-untyped-defs
  2. import copy
  3. import functools
  4. import logging
  5. import warnings
  6. from contextlib import ExitStack
  7. from dataclasses import dataclass, field
  8. from typing import (
  9. Any,
  10. cast,
  11. Dict,
  12. Iterable,
  13. Iterator,
  14. List,
  15. NamedTuple,
  16. no_type_check,
  17. Optional,
  18. Sequence,
  19. Set,
  20. Tuple,
  21. TYPE_CHECKING,
  22. Union,
  23. )
  24. import torch
  25. import torch.distributed as dist
  26. import torch.distributed.fsdp._traversal_utils as traversal_utils
  27. import torch.nn as nn
  28. from torch.distributed._state_dict_utils import _gather_state_dict
  29. from torch.distributed._tensor import DTensor, Replicate
  30. from torch.distributed.distributed_c10d import _get_pg_default_device
  31. from torch.distributed.fsdp._common_utils import (
  32. _apply_to_modules,
  33. _FSDPState,
  34. _get_module_fsdp_state_if_fully_sharded_module,
  35. _get_param_to_fqns,
  36. _module_handle,
  37. _named_parameters_with_duplicates,
  38. clean_tensor_name,
  39. )
  40. from torch.distributed.fsdp._debug_utils import SimpleProfiler
  41. from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle
  42. from torch.distributed.fsdp._fsdp_extensions import (
  43. _ext_chunk_dtensor,
  44. _ext_chunk_tensor,
  45. )
  46. from torch.distributed.fsdp._runtime_utils import (
  47. _lazy_init,
  48. _reset_flat_param_grad_info_if_needed,
  49. )
  50. from torch.distributed.fsdp.api import (
  51. ShardingStrategy,
  52. StateDictSettings,
  53. StateDictType,
  54. )
  55. from torch.utils._pytree import tree_map_only
  56. if TYPE_CHECKING:
  57. from torch.distributed._shard.sharded_tensor import ShardedTensor
  58. logger = logging.getLogger(__name__)
  59. @dataclass
  60. class FSDPParamInfo:
  61. state: _FSDPState
  62. handle: FlatParamHandle
  63. param_indices: Dict[str, int]
  64. param_requires_grad: List[bool]
  65. def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
  66. keys = sorted(dictionary.keys())
  67. for k in keys:
  68. yield k, dictionary[k]
  69. @dataclass
  70. class _ConsolidatedOptimState:
  71. """
  72. This holds the consolidated optimizer state on the target rank. Positive-
  73. dimension tensor state is communicated across ranks, while zero-dimension
74. tensor state and non-tensor state are taken directly from the target rank.
75. PyTorch version 1.12 moved to using zero-dimension tensors for scalar
76. values, but user-implemented optimizers may still use float (i.e., a
  77. non-tensor). Thus, we support both and handle them identically.
  78. Attributes:
  79. tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension
  80. tensor state name to the unsharded flat tensor representing the
  81. state.
  82. zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero-
  83. dimension tensor state name to its value.
  84. non_tensor_state (Dict[str, Any]): Mapping from non-tensor state
  85. name to its value.
  86. """
  87. tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
  88. zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
  89. non_tensor_state: Dict[str, Any] = field(default_factory=dict)
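# Illustrative sketch (comment only, not executed by this module): for an Adam
# optimizer, the consolidated state of one flat parameter typically splits into
# the three buckets above as follows (shapes and the "lr_scale" entry are
# hypothetical):
#
#     state = _ConsolidatedOptimState()
#     state.tensor_state["exp_avg"] = torch.zeros(1024)         # positive-dimension
#     state.tensor_state["exp_avg_sq"] = torch.zeros(1024)      # positive-dimension
#     state.zero_dim_tensor_state["step"] = torch.tensor(10.0)  # zero-dimension
#     state.non_tensor_state["lr_scale"] = 1.0                  # hypothetical non-tensor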
  90. class _PosDimTensorInfo(NamedTuple):
  91. """
92. Metadata for positive-dimension tensors used internally for
  93. :meth:`scatter_full_optim_state_dict`.
  94. Attributes:
  95. shape (torch.Size): Sharded tensor shape (which is equal to the
  96. unsharded tensor shape if the tensor is optimizer state for a
  97. non-FSDP parameter and is hence not sharded).
  98. dtype (torch.dtype): Data type of the tensor.
  99. """
  100. shape: torch.Size
  101. dtype: torch.dtype
  102. class _OptimStateKey(NamedTuple):
  103. """
  104. This represents an optimizer state key that may be used commonly across
  105. ranks. It is based on the unflattened parameter names rather than parameter
  106. IDs to make it independent of each rank's own optimizer construction.
  107. """
  108. unflat_param_names: Tuple[str, ...]
  109. is_fsdp_managed: bool
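# Illustrative sketch: an ``_OptimStateKey`` for a flat parameter covering two
# original parameters might look like the following (the FQNs are hypothetical):
#
#     key = _OptimStateKey(
#         unflat_param_names=("layer1.weight", "layer1.bias"),
#         is_fsdp_managed=True,
#     )
#
# Keying on unflattened FQNs instead of parameter IDs keeps the key identical
# across ranks even when each rank constructs its optimizer differently.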
  110. def _unflatten_optim_state(
  111. fsdp_param_info: FSDPParamInfo,
  112. flat_param_state: Dict[str, Any],
  113. to_save: bool,
  114. shard_state: bool,
  115. cpu_offload: bool,
  116. ) -> List[Dict[str, Any]]:
  117. """
  118. Unflattens the optimizer state, consisting of the "state" part and the
  119. "param_groups" part. Unflattening the "state" part involves consolidating
  120. the state on the target rank and remapping from flattened to unflattened
  121. parameter IDs, and the "param_groups" part only involves remapping from
  122. flattened to unflattened parameter IDs.
  123. Args:
  124. fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
  125. mapping from FQN to original parameter index.
  126. flat_param_state (Dict[str, Any]): Entry for the flat parameter in the
  127. "state" part of the optimizer state dict.
  128. to_save (bool): Whether to save the state on this rank.
  129. Returns:
  130. List[Dict[str, Any]]: A :class:`list` holding the entries in the
  131. "state" part of the optimizer state dict corresponding to the
  132. unflattened parameters comprising the flat parameter if on the target
  133. rank or an empty :class:`list` otherwise. The final optimizer state
  134. dict will need to map these entries using the proper unflattened
  135. parameter IDs.
  136. """
  137. assert (
  138. not shard_state or to_save
  139. ), "If ``shard_state`` is True, ``to_save`` has to be True."
  140. consolidated_state = _communicate_optim_state(
  141. fsdp_param_info,
  142. flat_param_state,
  143. )
  144. if to_save:
  145. unflat_param_state = _unflatten_communicated_optim_state(
  146. fsdp_param_info,
  147. consolidated_state,
  148. shard_state,
  149. )
  150. for optim_state in unflat_param_state:
151. # We can't use .items() below because we'd run into a concurrent modification error
  152. if cpu_offload:
  153. for key in list(optim_state.keys()):
  154. state = optim_state[key]
  155. if not isinstance(state, torch.Tensor):
  156. continue
  157. optim_state[key] = state.cpu()
  158. return unflat_param_state
  159. else:
  160. return []
  161. def _is_zero_dim_tensor(x: Any) -> bool:
  162. return torch.is_tensor(x) and x.dim() == 0
  163. def _communicate_optim_state(
  164. fsdp_param_info: FSDPParamInfo,
  165. flat_param_state: Dict[str, Any],
  166. ) -> _ConsolidatedOptimState:
  167. """
  168. Communicates the optimizer state for a flat parameter across ranks. All
  169. ranks will hold the entire non-sharded optimizer state on GPU.
  170. If ``N`` is the number of tensor optimizer states in the optimizer state
  171. dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1``
  172. otherwise (where the plus 1 comes from all-gathering the padding per rank).
  173. Args:
  174. fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
  175. mapping from FQN to original parameter index.
  176. flat_param_state (Dict[str, Any]): The entry in the "state" part of the
  177. optimizer state dict corresponding to the flat parameter.
  178. Returns:
179. _ConsolidatedOptimState: Consolidated optimizer state for the target
  180. flat parameter.
  181. """
  182. fsdp_state = fsdp_param_info.state
  183. flat_param = fsdp_param_info.handle.flat_param
  184. state = _ConsolidatedOptimState()
  185. tensor_state, zero_dim_tensor_state, non_tensor_state = (
  186. state.tensor_state,
  187. state.zero_dim_tensor_state,
  188. state.non_tensor_state,
  189. )
  190. for state_name, value in sorted_items(flat_param_state):
  191. # Positive-dimension tensor state: communicate across ranks
  192. if torch.is_tensor(value) and value.dim() > 0:
  193. # If the parameter is not sharded, then neither is the
  194. # positive-dimension tensor state, so no need to communicate it --
  195. # we take the target rank's value
  196. if (
  197. fsdp_state.world_size == 1
  198. or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
  199. ):
  200. tensor_state[state_name] = value
  201. continue
  202. assert (
  203. fsdp_state.compute_device is not None
  204. ), "compute_device has not been initialized"
  205. if value.device.type != fsdp_state.compute_device.type:
  206. value = value.to(fsdp_state.compute_device)
  207. # Assume that positive-dimension tensor optimizer state
  208. # has the same shape as the sharded flat parameter
  209. buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined]
  210. tensor_buffer = value.new_zeros(*buffer_size)
  211. dist.all_gather_into_tensor(
  212. tensor_buffer, value, group=fsdp_state.process_group
  213. )
  214. fsdp_state._device_handle.synchronize()
  215. unpadded_numel = cast(
  216. nn.Parameter, flat_param._unpadded_unsharded_size
  217. ).numel()
  218. tensor_state[state_name] = tensor_buffer[:unpadded_numel]
  219. # Zero-dimension tensor state and non-tensor state: take this rank's
  220. # value directly
  221. else:
  222. if _is_zero_dim_tensor(value):
  223. zero_dim_tensor_state[state_name] = value.detach().clone()
  224. else:
  225. non_tensor_state[state_name] = value
  226. return state
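# Illustrative sketch of the gather performed above for one positive-dimension
# state (comment only; world size 2, padded full numel 8, unpadded numel 7, and
# ``local_shard``/``pg`` are all hypothetical):
#
#     buffer = local_shard.new_zeros(8)                        # padded full size
#     dist.all_gather_into_tensor(buffer, local_shard, group=pg)
#     full_exp_avg = buffer[:7]                                 # drop trailing padding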
  227. def _unflatten_communicated_optim_state(
  228. fsdp_param_info: FSDPParamInfo,
  229. state: _ConsolidatedOptimState,
  230. shard_state: bool,
  231. ) -> List[Dict[str, Any]]:
  232. """
  233. Unflattens the communicated optimizer state (given by ``tensor_state``,
  234. ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat
  235. parameter. This should only be called on the target rank.
  236. Args:
  237. fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
  238. mapping from FQN to original parameter index.
  239. state (_ConsolidatedOptimState): Consolidated optimizer state.
  240. Returns:
  241. List[Dict[str, Any]]: A :class:`list` holding the entries in the
  242. "state" part of the optimizer state dict corresponding to the
  243. unflattened parameters comprising the flat parameter. The final
  244. optimizer state dict will need to map these entries using the proper
  245. unflattened parameter IDs.
  246. """
  247. fsdp_state = fsdp_param_info.state
  248. handle = fsdp_param_info.handle
  249. flat_param = handle.flat_param
  250. unflat_param_state: List[Dict[str, Any]] = []
  251. flat_param_views: Dict[str, Iterator] = {}
  252. num_unflat_params = flat_param._num_params
  253. tensor_state, zero_dim_tensor_state, non_tensor_state = (
  254. state.tensor_state,
  255. state.zero_dim_tensor_state,
  256. state.non_tensor_state,
  257. )
  258. for _ in range(num_unflat_params):
  259. unflat_state_param = {}
  260. # Add positive-dimension tensor state: unflatten with views
  261. for state_name, flat_tensor in sorted_items(tensor_state):
  262. views_generated = state_name in flat_param_views
  263. if not views_generated:
  264. views = handle._get_unflat_views(flat_tensor)
  265. flat_param_views[state_name] = views
  266. else:
  267. views = flat_param_views[state_name]
  268. optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views)
  269. if shard_state:
  270. osd_config = fsdp_state._optim_state_dict_config
  271. if getattr(osd_config, "_use_dtensor", False):
  272. assert fsdp_state._device_mesh is not None
  273. optim_state = _ext_chunk_dtensor(
  274. optim_state,
  275. fsdp_state.rank,
  276. fsdp_state._device_mesh,
  277. fsdp_state._fsdp_extension,
  278. )
  279. else:
  280. assert fsdp_state.process_group is not None
  281. optim_state = _ext_chunk_tensor(
  282. optim_state,
  283. fsdp_state.rank,
  284. fsdp_state.world_size,
  285. fsdp_state._device_handle.device_count(),
  286. fsdp_state.process_group,
  287. fsdp_state._fsdp_extension,
  288. )
  289. unflat_state_param[state_name] = optim_state
  290. # Add zero-dimension tensor state: take the target rank's value
  291. for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state):
  292. unflat_state_param[state_name] = zero_dim_tensor
  293. # Add non-tensor state: take the target rank's value
  294. for state_name, non_tensor in sorted_items(non_tensor_state):
  295. unflat_state_param[state_name] = non_tensor
  296. unflat_param_state.append(unflat_state_param)
  297. return unflat_param_state
  298. def _broadcast_processed_state(
  299. fsdp_state: _FSDPState,
  300. optim_state: Dict[str, Any],
  301. group: Optional[dist.ProcessGroup],
  302. ) -> Dict[str, Any]:
  303. objects: List[Any] = [None]
  304. if dist.get_rank(group) == 0:
  305. objects[0] = tree_map_only(
  306. torch.Tensor,
  307. lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr]
  308. optim_state,
  309. )
  310. dist.broadcast_object_list(objects, src=0, group=group)
  311. if dist.get_rank(group) == 0:
  312. return optim_state
  313. else:
  314. return objects[0]
  315. def _broadcast_state(
  316. fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]
  317. ) -> Any:
  318. if dist.get_rank(group) == 0:
  319. if not isinstance(state, torch.Tensor) or state.dim() == 0:
  320. return state
  321. tensor = state.to(fsdp_state.compute_device)
  322. else:
  323. if isinstance(state, torch.Tensor):
  324. assert state.dim() == 0, (
  325. "For non-zero ranks, a tensor state should have zero dimension, "
326. f"but got a state with shape {tuple(state.shape)}."
  327. )
  328. return state
  329. elif not isinstance(state, _PosDimTensorInfo):
  330. return state
  331. tensor = torch.zeros(
  332. state.shape, dtype=state.dtype, device=fsdp_state.compute_device
  333. )
  334. dist.broadcast(tensor, src=0, group=group)
  335. return tensor
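# Illustrative sketch of how the two helpers above cooperate when
# ``rank0_only=True`` (comment only; ``osd``, ``group``, and the "layer1.weight"
# FQN are hypothetical):
#
#     # 1) Rank 0 replaces every positive-dimension tensor with _PosDimTensorInfo
#     #    and broadcasts the resulting lightweight object tree.
#     osd = _broadcast_processed_state(fsdp_state, osd, group=group)
#     # 2) Each tensor is then broadcast individually; non-zero ranks allocate a
#     #    buffer from the (shape, dtype) metadata and receive into it.
#     for name, value in osd["state"]["layer1.weight"].items():
#         osd["state"]["layer1.weight"][name] = _broadcast_state(fsdp_state, value, group)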
  336. def _shard_orig_param_state(
  337. fsdp_param_info: FSDPParamInfo,
  338. fqn: str,
  339. optim_state: Dict[str, Any],
  340. ) -> Dict[str, Any]:
  341. """
  342. Shard the optimizer state for the original parameter with the name ``fqn``.
  343. This API should only be used when ``use_orig_params`` is True.
  344. """
  345. if not optim_state:
  346. return {}
  347. fsdp_state = fsdp_param_info.state
  348. flat_param = fsdp_param_info.handle.flat_param
  349. param_idx = fsdp_param_info.param_indices[fqn]
  350. shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined]
  351. optim_state = _gather_state_dict(
  352. optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device
  353. )
  354. if not shard_param_info.in_shard:
  355. return {}
  356. # Flatten and shard the state.
  357. new_optim_state: Dict[str, Any] = {}
  358. intra_param_start_idx = shard_param_info.intra_param_start_idx
  359. intra_param_end_idx = shard_param_info.intra_param_end_idx
  360. for state_name, value in optim_state.items():
  361. if (
  362. torch.is_tensor(value)
  363. and value.dim() > 0
  364. and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
  365. ):
  366. value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator]
  367. new_optim_state[state_name] = value
  368. return new_optim_state
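# Illustrative example of the slicing above (hypothetical numbers): if an
# original parameter of numel 10 lands in this rank's shard at offsets [3, 7],
# then ``intra_param_start_idx=3`` and ``intra_param_end_idx=7``, and a
# positive-dimension state such as "exp_avg" is reduced to
#
#     exp_avg.flatten()[3:8].clone()   # the 5 elements owned by this rank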
  369. def _flatten_optim_state_dict(
  370. optim_state_dict: Dict[str, Any],
  371. model: nn.Module,
  372. use_orig_params: bool = False,
  373. optim: Optional[torch.optim.Optimizer] = None,
  374. rank0_only: bool = False,
  375. group: Optional[dist.ProcessGroup] = None,
  376. ) -> Dict[str, Any]:
  377. """
  378. Flattens the full optimizer state dict, still keying by unflattened parameter
  379. names.
  380. If ``use_orig_params`` is True, each rank will have all FSDP-managed
  381. parameters but some of these parameters may be empty due to the sharding.
  382. For a regular optim.Optimizer, states for those empty parameters will
383. not be initialized. So, when aggregating the FQNs across ranks, no assert
384. will be raised on a rank even if it does not have all the states -- this is
385. valid and FSDP knows how to aggregate them. However, FSDP has to ignore
386. those parameters that are not managed by FSDP and do not exist on
387. the local rank -- they are managed by other parallelisms and FSDP does not
388. know how to handle/aggregate them.
  389. Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to
  390. flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require
  391. all the states even if the corresponding parameters are empty. To this end,
392. ``optim`` will be used to get the initial state of the empty parameters.
393. ``optim`` should only be non-None if it is a KeyedOptimizer or
  394. NamedOptimizer.
  395. Returns:
  396. Dict[str, Any]: The flattened optimizer state dict.
  397. """
  398. SimpleProfiler.reset()
  399. unflat_osd = optim_state_dict
  400. if "state" not in unflat_osd and not rank0_only:
  401. raise ValueError(
402. '`optim_state_dict` must have the key "state" '
403. "to be a valid optimizer state dict"
  404. )
  405. param_to_fqns = _get_param_to_fqns(model)
  406. fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
  407. fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state
  408. # Broadcast unflat_osd without non-scalar tensor if rank0_only is True.
  409. if rank0_only:
  410. unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)
  411. # Construct the "state" part
  412. flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
  413. unflat_osd_state = unflat_osd["state"]
  414. all_state_keys = set(unflat_osd_state.keys())
  415. for param, fqns in param_to_fqns.items():
  416. fqn = fqns[0]
  417. if fqn not in unflat_osd_state:
  418. continue
  419. all_state_keys.difference_update(fqns)
  420. if rank0_only:
  421. for fqn in fqns:
  422. if not unflat_osd_state[fqn]:
  423. continue
  424. for state_name in unflat_osd_state[fqn].keys():
  425. unflat_osd_state[fqn][state_name] = _broadcast_state(
  426. fsdp_state, unflat_osd_state[fqn][state_name], group=group
  427. )
  428. fqn = fqns[0]
  429. if fqn in fqn_to_fsdp_param_info:
  430. fsdp_param_info = fqn_to_fsdp_param_info[fqn]
  431. if use_orig_params:
  432. with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
  433. flat_state = _shard_orig_param_state(
  434. fsdp_param_info,
  435. fqn,
  436. unflat_osd_state[fqn],
  437. )
  438. else:
  439. flat_state = _flatten_optim_state(
  440. fsdp_param_info,
  441. unflat_osd_state,
  442. fqns,
  443. )
  444. key = _OptimStateKey(tuple(fqns), True)
445. # Only include non-empty states, as expected by
446. # `torch.optim.Optimizer`, unless the optimizer is KeyedOptimizer
447. # or NamedOptimizer.
  448. if flat_state:
  449. flat_osd_state[key] = flat_state
  450. elif use_orig_params:
  451. assert (
  452. len(fqns) == 1
  453. ), f"use_orig_params is True but there are multiple FQNs, {fqns}."
  454. if optim is not None: # NamedOptimizer or KeyedOptimizer case.
  455. state = optim.state.get(param, None) # type: ignore[call-overload]
  456. if state is not None:
  457. flat_osd_state[key] = copy.deepcopy(state)
  458. else:
  459. warnings.warn(
460. f"optim_state[{key}] is not on rank {fsdp_state.rank}."
  461. )
  462. else:
  463. raise RuntimeError(
464. f"The state of {key} is empty. This should only happen when "
  465. "use_orig_params=True."
  466. )
  467. else: # do not flatten non-FSDP parameters' states
  468. assert len(fqns) == 1
  469. key = _OptimStateKey(tuple(fqns), False)
  470. flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])
  471. if rank0_only:
  472. for fqn in fqns:
  473. if not unflat_osd_state[fqn]:
  474. continue
  475. for state_name, param_state in list(unflat_osd_state[fqn].items()):
  476. if fsdp_state.rank > 0:
477. # Dereference the tensor so that PyTorch can collect the memory.
  478. del unflat_osd_state[fqn][state_name]
  479. else:
  480. # Move the tensor in the original osd back to CPU to make the
  481. # original osd unaffected.
  482. unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][
  483. state_name
  484. ].cpu()
  485. # Handle user-defined state, states that are not associated with parameters.
  486. for key in all_state_keys:
  487. user_state = unflat_osd_state[key]
  488. if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params:
  489. user_state = _broadcast_state(fsdp_state, user_state, group=group)
  490. flat_osd_state[key] = copy.copy(user_state)
  491. SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ")
  492. # Construct the "param_groups" part -- copy as is since it will be
  493. # rekeyed later according to the target rank's optimizer
  494. # Only copy param_groups if it exists in unflat_osd
  495. if "param_groups" in unflat_osd:
  496. flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"])
  497. return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
  498. else:
  499. return {"state": flat_osd_state}
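# Illustrative usage sketch (public API; ``model`` is assumed to be an
# FSDP-wrapped module and ``optim`` its optimizer): loading a consolidated
# optimizer state dict roughly funnels through ``_flatten_optim_state_dict``
# above, e.g.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     full_osd = FSDP.full_optim_state_dict(model, optim)             # consolidate on rank 0
#     sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model) # flatten + shard
#     optim.load_state_dict(sharded_osd)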
  500. def _flatten_optim_state(
  501. fsdp_param_info: FSDPParamInfo,
  502. unflat_osd_state: Dict[str, Dict[str, Any]],
  503. unflat_param_names: List[str],
  504. ) -> Dict[str, Any]:
  505. """
  506. Flattens the optimizer state in ``full_optim_state_dict`` for a single
  507. flat parameter in ``fsdp_param_info`` corresponding to the unflattened
  508. parameter names in ``unflat_param_names``.
  509. Args:
  510. fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
  511. mapping from FQN to original parameter index.
  512. unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the
  513. optimizer state dict corresponding to the unflattened parameters.
  514. unflat_param_names (List[str]): A :class:`list` of unflattened
  515. parameter names corresponding to the flat parameter ``flat_param``.
  516. Returns:
  517. Dict[str, Any]: A :class:`dict` mapping state names to their values for
  518. a particular flat parameter. The sharded optimizer state dict's "state"
  519. part will map a key to this returned value.
  520. """
  521. fsdp_state = fsdp_param_info.state
  522. handle = fsdp_param_info.handle
  523. flat_param = handle.flat_param
  524. num_unflat_params = len(unflat_param_names)
  525. assert num_unflat_params > 0, (
  526. "Expects at least one unflattened parameter corresponding to the "
  527. "flat parameter"
  528. )
  529. unflat_param_shapes = flat_param._shapes
  530. num_unflat_param_shapes = len(unflat_param_shapes)
  531. assert (
  532. num_unflat_params == num_unflat_param_shapes
  533. ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
  534. # Check if these unflattened parameters have any optimizer state
  535. has_state = [
  536. bool(unflat_param_name in unflat_osd_state)
  537. for unflat_param_name in unflat_param_names
  538. ]
  539. # If none of the unflattened parameters comprising this flat parameter have
  540. # any state, then we do not want an entry in the optimizer state dict
  541. if not any(has_state):
  542. return {} # no need to flatten any state
  543. # There may still be some unflattened parameters with state and some
  544. # without
  545. unflat_param_states = [
  546. _gather_state_dict(
  547. unflat_osd_state[unflat_param_name],
  548. pg=fsdp_state.process_group,
  549. device=fsdp_state.compute_device,
  550. )
  551. if unflat_param_name in unflat_osd_state
  552. else None
  553. for unflat_param_name in unflat_param_names
  554. ]
  555. # Check that the unflattened parameters have the same state names
  556. state_names = None
  557. for unflat_param_state in unflat_param_states:
  558. if unflat_param_state is None:
  559. continue
  560. if state_names is None:
  561. state_names = set(unflat_param_state.keys())
  562. else:
  563. if state_names != set(unflat_param_state.keys()):
  564. raise ValueError(
  565. "Differing optimizer state names for the unflattened "
  566. f"parameters: {unflat_param_names}"
  567. )
  568. assert state_names is not None
  569. # Flatten the state
  570. flat_state: Dict[str, Any] = {}
  571. for state_name in state_names:
  572. state_values = [
  573. unflat_param_state[state_name] if unflat_param_state is not None else None
  574. for unflat_param_state in unflat_param_states
  575. ]
  576. non_none_state_values = [v for v in state_values if v is not None]
  577. # If all ranks have None, this is a None value
  578. if not non_none_state_values:
  579. flat_state[state_name] = None
  580. continue
  581. are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True
  582. for v in non_none_state_values:
  583. are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0
  584. are_zero_dim_tensors &= _is_zero_dim_tensor(v)
  585. are_non_tensors &= not torch.is_tensor(v)
  586. types = {type(v) for v in non_none_state_values}
  587. if len(types) != 1 or not (
  588. are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors
  589. ):
  590. raise ValueError(
  591. f"Differing optimizer state types for state {state_name}, "
  592. f"values {non_none_state_values}, and unflattened parameter "
  593. f"names {unflat_param_names}"
  594. )
  595. if are_pos_dim_tensors:
  596. flat_tensor = _flatten_tensor_optim_state(
  597. state_name,
  598. state_values,
  599. unflat_param_names,
  600. unflat_param_shapes,
  601. handle,
  602. )
  603. # Shard the flattened tensor immediately to minimize max memory
  604. # usage
  605. if (
  606. fsdp_state.world_size != 1
  607. and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
  608. ):
  609. sharded_flat_tensor, _ = FlatParamHandle._get_shard(
  610. flat_tensor,
  611. fsdp_state.rank,
  612. fsdp_state.world_size,
  613. )
  614. else:
  615. sharded_flat_tensor = flat_tensor
  616. flat_state[state_name] = sharded_flat_tensor
  617. elif are_zero_dim_tensors:
  618. flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
  619. state_name,
  620. state_values,
  621. unflat_param_names,
  622. )
  623. else:
  624. assert are_non_tensors
  625. flat_state[state_name] = _flatten_non_tensor_optim_state(
  626. state_name,
  627. state_values,
  628. unflat_param_names,
  629. )
  630. return flat_state
  631. def _flatten_tensor_optim_state(
  632. state_name: str,
  633. pos_dim_tensors: List[torch.Tensor],
  634. unflat_param_names: List[str],
  635. unflat_param_shapes: Sequence[torch.Size],
  636. handle: FlatParamHandle,
  637. ) -> torch.Tensor:
  638. """
  639. Flattens the positive-dimension tensor optimizer state given by the values
640. ``pos_dim_tensors`` for the state ``state_name`` for a single flat parameter
641. from ``handle`` corresponding to the unflattened parameter names
642. ``unflat_param_names`` and unflattened parameter shapes
  643. ``unflat_param_shapes``. This flattens each unflattened parameter's tensor
  644. state into one tensor.
  645. NOTE: We use zero tensors for any unflattened parameters without state
  646. since some value is required to fill those entries. This assumes that the
  647. zero tensor is mathematically equivalent to having no state, which is true
  648. for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all
  649. optimizers.
  650. Args:
  651. state_name (str): Optimizer state name.
  652. pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor
  653. optimizer state values for the unflattened parameters corresponding
  654. to the single flat parameter.
  655. unflat_param_names (List[str]): A :class:`list` of unflattened
  656. parameter names corresponding to the single flat parameter.
  657. unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes
  658. corresponding to the single flat parameter.
  659. handle (FlatParamHandle): The flat parameter's handle.
  660. Returns:
  661. torch.Tensor: A flat tensor containing the optimizer state
  662. corresponding to ``state_name`` constructed by concatenating the
  663. unflattened parameter tensor states in ``pos_dim_tensors`` (using zero
  664. tensors for any unflattened parameters without the state).
  665. """
  666. flat_param = handle.flat_param
  667. non_none_tensors = [t for t in pos_dim_tensors if t is not None]
  668. # Check that all are tensors with the same dtype
  669. dtypes = {t.dtype for t in non_none_tensors}
  670. if len(dtypes) != 1:
  671. raise ValueError(
  672. "All unflattened parameters comprising a single flat "
  673. "parameter must have positive-dimension tensor state with the "
  674. f"same dtype but got dtypes {dtypes} for state {state_name} and "
  675. f"unflattened parameter names {unflat_param_names}"
  676. )
  677. dtype = next(iter(dtypes))
  678. # Check that each tensor state matches its parameter's shape
  679. for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes):
  680. if tensor is None and len(shape) == 0:
  681. raise ValueError("Flattening a zero-dimension parameter is not supported")
  682. elif tensor is not None and tensor.shape != shape:
  683. raise ValueError(
  684. "Tensor optimizer state does not have same shape as its "
  685. f"parameter: {tensor.shape} {shape}"
  686. )
  687. # Flatten the tensor states: we do not need to add any right-hand-side
  688. # padding since the flat optimizer state tensor is sharded via
  689. # `_get_shard()`, which pads the shard as needed (just like for the flat
  690. # parameter)
  691. cpu_device = torch.device("cpu")
  692. tensors_to_flatten = [
  693. torch.flatten(state_value.to(cpu_device))
  694. if state_value is not None
  695. else torch.flatten(
  696. torch.zeros(
  697. size=shape,
  698. dtype=dtype,
  699. device=cpu_device,
  700. )
  701. )
  702. for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes)
  703. ]
  704. flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
  705. flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
  706. assert flat_tensor.shape == flat_param_shape, (
  707. f"tensor optim state: {flat_tensor.shape} "
  708. f"flat parameter: {flat_param_shape}"
  709. )
  710. return flat_tensor
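# Illustrative sketch of the flattening above (comment only; hypothetical
# shapes, ignoring the alignment padding handled by ``handle.flatten_tensors``):
# for two original parameters of shapes (2, 3) and (4,) where only the first has
# "exp_avg" state, the flat state tensor is built roughly as
#
#     flat = torch.cat([exp_avg_0.flatten(), torch.zeros(4)])   # numel 10
#
# and the caller then shards it with ``FlatParamHandle._get_shard`` just like
# the flat parameter itself.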
  711. def _flatten_zero_dim_tensor_optim_state(
  712. state_name: str,
  713. zero_dim_tensors: List[torch.Tensor],
  714. unflat_param_names: List[str],
  715. ) -> torch.Tensor:
  716. """
  717. Flattens the zero-dimension tensor optimizer state given by the values
  718. ``zero_dim_tensors`` for the state ``state_name`` for a single flat
  719. parameter corresponding to the unflattened parameter names
  720. ``unflat_param_names`` by enforcing that all tensors are the same and using
  721. that common value.
  722. NOTE: The requirement that the tensors are the same across all unflattened
  723. parameters comprising the flat parameter is needed to maintain the
  724. invariant that FSDP performs the same computation as its non-sharded
  725. equivalent. This means that none of the unflattened parameters can be
  726. missing this state since imposing a value may differ from having no value.
  727. For example, for Adam's "step", no value means maximum bias correction,
  728. while having some positive value means less bias correction.
  729. Args:
  730. state_name (str): Optimizer state name.
  731. zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state
  732. for the unflattened parameters corresponding to the single
  733. flat parameter.
  734. unflat_param_names (List[str]): A :class:`list` of unflattened
  735. parameter names corresponding to the single flat parameter.
  736. Returns:
  737. torch.Tensor: A zero-dimensional tensor giving the value of the state
  738. ``state_name`` for all unflattened parameters corresponding to the
  739. names ``unflat_param_names``.
  740. """
  741. non_none_tensors = [t for t in zero_dim_tensors if t is not None]
  742. # Enforce that all have the same value and dtype
  743. values_set = {t.item() if t is not None else None for t in zero_dim_tensors}
  744. dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors}
  745. if (
  746. len(non_none_tensors) != len(zero_dim_tensors)
  747. or len(values_set) != 1
  748. or len(dtypes) != 1
  749. ):
  750. raise ValueError(
  751. "All unflattened parameters comprising a single flat "
  752. "parameter must have scalar state with the same value and dtype "
  753. f"but got values {values_set} and dtypes {dtypes} for state "
  754. f"{state_name} and unflattened parameter names "
  755. f"{unflat_param_names}"
  756. )
  757. value = next(iter(values_set))
  758. dtype = next(iter(dtypes))
  759. return torch.tensor(value, dtype=dtype, device=torch.device("cpu"))
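# Illustrative example: for Adam's "step" state, all original parameters in one
# flat parameter must agree, e.g. (hypothetical FQNs)
#
#     _flatten_zero_dim_tensor_optim_state(
#         "step",
#         [torch.tensor(10.0), torch.tensor(10.0)],
#         ["layer1.weight", "layer1.bias"],
#     )                                              # -> tensor(10.)
#
# whereas mixing tensor(10.) with tensor(11.), or omitting the state for one
# parameter, raises a ValueError.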
  760. def _flatten_non_tensor_optim_state(
  761. state_name: str,
  762. non_tensors: List[Any],
  763. unflat_param_names: List[str],
  764. ) -> Any:
  765. """
  766. Flattens the non-tensor optimizer state given by the values ``non_tensors``
  767. for the state ``state_name`` for a single flat parameter corresponding
  768. to the unflattened parameter names ``unflat_param_names`` by enforcing that
  769. all values are the same and using that common value.
  770. See the note in :func:`_flatten_zero_dim_tensor_optim_state`.
  771. Args:
  772. state_name (str): Optimizer state name.
  773. non_tensors (List[Any]): Non-tensor optimizer state for the unflattened
  774. parameters corresponding to the single flat parameter.
  775. unflat_param_names (List[str]): A :class:`list` of unflattened
  776. parameter names corresponding to the single flat parameter.
  777. Returns:
  778. Any: A non-tensor giving the value of the state ``state_name`` for all
  779. unflattened parameters corresponding to the names
  780. ``unflat_param_names``.
  781. """
  782. non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
  783. # Enforce that all have the same value (same type already checked)
  784. non_tensor_set = set(non_tensors)
  785. if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1:
  786. raise ValueError(
  787. "All unflattened parameters comprising a single flat "
788. "parameter must have non-tensor state with the same value "
  789. f"but got values {non_tensor_set} for state {state_name} and "
  790. f"unflattened parameter names {unflat_param_names}"
  791. )
  792. non_tensor = next(iter(non_tensor_set))
  793. return non_tensor
  794. def _rekey_sharded_optim_state_dict(
  795. sharded_osd: Dict[str, Any],
  796. model: nn.Module,
  797. optim: torch.optim.Optimizer,
  798. optim_input: Optional[
  799. Union[
  800. List[Dict[str, Any]],
  801. Iterable[nn.Parameter],
  802. ]
  803. ],
  804. using_optim_input: bool,
  805. is_named_optimizer: bool = False,
  806. ) -> Dict[str, Any]:
  807. """
  808. Rekeys the optimizer state dict from unflattened parameter names to flat
  809. parameter IDs according to the calling rank's ``optim``, which may be
  810. different across ranks. In particular, the unflattened parameter names are
  811. represented as :class:`_OptimStateKey` s.
  812. """
  813. param_to_fqns = _get_param_to_fqns(model)
  814. flat_param_to_fqn = _get_flat_param_to_fqn(model)
  815. param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast(
  816. Dict[nn.Parameter, Union[int, str]],
  817. (
  818. _get_param_to_param_id_from_optim_input(model, optim_input)
  819. if using_optim_input
  820. else _get_param_to_param_key(
  821. optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
  822. )
  823. ),
  824. )
  825. # All parameter keys in `param_to_param_key` should be in
  826. # `param_to_fqns` -- strict inequality follows when not all parameters are
  827. # passed to the optimizer
  828. assert len(param_to_param_key) <= len(param_to_fqns)
  829. unflat_param_names_to_flat_param_key: Dict[
  830. Tuple[str, ...], Union[int, str]
  831. ] = {} # for "state"
  832. unflat_param_name_to_flat_param_key: Dict[
  833. str, Union[int, str]
  834. ] = {} # for "param_groups"
  835. for param, unflat_param_names in param_to_fqns.items():
  836. if param not in param_to_param_key:
  837. # This parameter was not passed to the optimizer
  838. continue
  839. flat_param_key = param_to_param_key[param]
  840. unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key
  841. for unflat_param_name in unflat_param_names:
  842. unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key
  843. sharded_osd_state = sharded_osd["state"]
  844. rekeyed_osd_state: Dict[Union[str, int], Any] = {}
  845. for key, param_state in sharded_osd_state.items():
  846. if isinstance(key, str):
  847. rekeyed_osd_state[key] = param_state
  848. continue
  849. flat_param_key = unflat_param_names_to_flat_param_key.get(
  850. key.unflat_param_names, key.unflat_param_names
  851. )
  852. rekeyed_osd_state[flat_param_key] = param_state
  853. # Only process param_groups if it exists in sharded_osd
  854. if "param_groups" in sharded_osd:
  855. rekeyed_osd_param_groups: List[Dict[str, Any]] = []
  856. for unflat_param_group in sharded_osd["param_groups"]:
  857. flat_param_group = copy.deepcopy(unflat_param_group)
  858. flat_param_keys = sorted(
  859. {
  860. unflat_param_name_to_flat_param_key[unflat_param_name]
  861. for unflat_param_name in unflat_param_group["params"]
  862. }
  863. )
  864. flat_param_group["params"] = flat_param_keys
  865. rekeyed_osd_param_groups.append(flat_param_group)
  866. return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups}
  867. else:
  868. return {"state": rekeyed_osd_state}
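# Illustrative usage sketch (public API; ``model`` and ``osd`` are assumed to
# exist): rekeying between parameter names and parameter IDs is exposed via
# ``FSDP.rekey_optim_state_dict``:
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from torch.distributed.fsdp import OptimStateKeyType
#
#     osd_by_name = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, model)
#     osd_by_id = FSDP.rekey_optim_state_dict(osd_by_name, OptimStateKeyType.PARAM_ID, model)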
  869. def _get_param_id_to_param_from_optim_input(
  870. model: nn.Module,
  871. optim_input: Optional[
  872. Union[
  873. List[Dict[str, Any]],
  874. Iterable[nn.Parameter],
  875. ]
  876. ] = None,
  877. ) -> Dict[int, nn.Parameter]:
  878. """
  879. Constructs a mapping from parameter IDs to parameters. This may be used
  880. both for models with ``FlatParameter`` s and without.
  881. NOTE: This method is only preserved for backward compatibility. The method
  882. :meth:`_get_param_key_to_param` is the preferred code path that does not
  883. rely on ``optim_input``.
  884. NOTE: We critically assume that, whether the optimizer input is a list of
  885. parameters or a list of parameter groups, :class:`torch.optim.Optimizer`
  886. enumerates the parameter IDs in order. In other words, for a parameter list
  887. input, the parameter IDs should be in that list order, and for a parameter
  888. groups input, the parameter IDs should be in order within each parameter
  889. group and in order across parameter groups.
  890. Args:
  891. model (nn.Module): Model whose parameters are passed into the
  892. optimizer.
  893. optim_input (Optional[Union[List[Dict[str, Any]],
  894. Iterable[nn.Parameter]]]): Input passed into the optimizer
  895. representing either a :class:`list` of parameter groups or an
  896. iterable of parameters; if ``None``, then this method assumes the
  897. input was ``model.parameters()``. (Default: ``None``)
  898. Returns:
899. Dict[int, nn.Parameter]: Mapping from parameter IDs to parameters,
900. where the parameter ID is implicitly the enumeration index.
  901. """
  902. # Assume the standard case of passing `model.parameters()` to the optimizer
  903. # if `optim_input` is not specified
  904. if optim_input is None:
  905. return dict(enumerate(model.parameters()))
  906. try:
  907. params = cast(List[nn.Parameter], list(optim_input))
  908. except TypeError as e:
  909. raise TypeError(
  910. "Optimizer input should be an iterable of Tensors or dicts, "
  911. f"but got {optim_input}"
  912. ) from e
  913. if len(params) == 0:
  914. raise ValueError("Optimizer input should not be empty")
  915. # Check if the optimizer input represents tensors or parameter groups
  916. all_tensors = True
  917. all_dicts = True
  918. for param in params:
  919. all_tensors &= isinstance(param, torch.Tensor)
  920. all_dicts &= isinstance(param, dict)
  921. if not all_tensors and not all_dicts:
  922. raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
  923. if all_tensors:
  924. return dict(enumerate(params))
  925. assert all_dicts
  926. param_id_to_param: List[nn.Parameter] = []
  927. for param_group in params:
  928. has_params_key = "params" in param_group # type: ignore[operator]
  929. assert has_params_key, (
  930. 'A parameter group should map "params" to a list of the '
  931. "parameters in the group"
  932. )
  933. # Implicitly map `flat_param_id` (current length of the list) to
  934. # `param`
  935. param_id_to_param.extend(param_group["params"]) # type: ignore[index]
  936. return dict(enumerate(param_id_to_param))
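# Illustrative example (hypothetical module ``model``): for
#
#     optim_input = [
#         {"params": [model.a.weight, model.a.bias]},
#         {"params": [model.b.weight]},
#     ]
#
# the mapping is {0: a.weight, 1: a.bias, 2: b.weight}, i.e. IDs run in order
# within each group and then across groups, mirroring torch.optim.Optimizer.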
  937. def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
  938. """
  939. Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes
  940. from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical"
  941. because ``FlatParameter`` s do not come from the original module but are
  942. registered only after FSDP has been applied. This function returns the FSDP-given
  943. name for the ``FlatParameter`` (usually module._flat_param) as opposed to the
944. canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``.
  945. Consequently, this function will only return a non-empty mapping if FSDP was
  946. applied with ``use_orig_params=False`` as, otherwise, the original parameters
  947. are used within the module and there would be no ``FlatParameter`` s in the module.
  948. """
  949. def module_fn(module, prefix, tree_level, flat_param_to_fqn):
  950. for param_name, param in _named_parameters_with_duplicates(
  951. module, recurse=False
  952. ):
  953. if not isinstance(param, FlatParameter):
  954. continue
  955. fqn = clean_tensor_name(prefix + param_name)
  956. flat_param_to_fqn[param] = fqn
  957. def return_fn(flat_param_to_fqn):
  958. return flat_param_to_fqn
  959. flat_param_to_fqn_ret: Dict[FlatParameter, str] = {}
  960. return _apply_to_modules(
  961. model,
  962. module_fn,
  963. return_fn,
  964. [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
  965. flat_param_to_fqn_ret,
  966. )
  967. def _get_param_key_to_param(
  968. optim: torch.optim.Optimizer,
  969. model: Optional[nn.Module] = None,
  970. is_named_optimizer: bool = False,
  971. param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
  972. flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
  973. ) -> Dict[Union[int, str], nn.Parameter]:
  974. """
  975. Constructs a mapping from parameter keys to parameters. For the regular
  976. optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
  977. are FQNs. This API may be used both for models with ``FlatParameter`` s and
  978. without.
  979. """
  980. clean_fqn_to_curr_fqn: Dict[str, str] = {}
  981. if is_named_optimizer:
  982. assert (
  983. param_to_fqns is not None and flat_param_to_fqn is not None
984. ), "For a NamedOptimizer, both `param_to_fqns` and `flat_param_to_fqn` must not be None."
  985. assert model is not None
  986. for key, _ in _named_parameters_with_duplicates(model):
  987. clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
  988. param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
  989. pid = 0
  990. for param_group in optim.param_groups:
  991. if is_named_optimizer:
  992. for param in param_group["params"]:
  993. assert flat_param_to_fqn is not None
  994. if param in flat_param_to_fqn:
  995. # FlatParameter case
  996. key = flat_param_to_fqn[param]
  997. else:
  998. assert param_to_fqns is not None
  999. # use_orig_params case
  1000. assert len(param_to_fqns[param]) == 1
  1001. key = param_to_fqns[param][0]
  1002. try:
  1003. key = clean_fqn_to_curr_fqn[key]
  1004. except KeyError as e:
  1005. raise KeyError(
  1006. f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}."
  1007. ) from e
  1008. param_key_to_param[key] = param
  1009. else:
  1010. for param in param_group["params"]:
  1011. param_key_to_param[pid] = param
  1012. pid += 1
  1013. return param_key_to_param
  1014. def _get_param_to_param_key(
  1015. optim: torch.optim.Optimizer,
  1016. model: Optional[nn.Module] = None,
  1017. is_named_optimizer: bool = False,
  1018. param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
  1019. flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
  1020. ) -> Dict[nn.Parameter, Union[int, str]]:
  1021. """
  1022. Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API
  1023. only supports the case where `optim` is a regular optimizer, not NamedOptimizer.
  1024. So the parameter keys will be parameter ids.
  1025. """
  1026. param_id_to_param = _get_param_key_to_param(
  1027. optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
  1028. )
  1029. return {param: param_id for param_id, param in param_id_to_param.items()}
  1030. def _get_param_to_param_id_from_optim_input(
  1031. model: nn.Module,
  1032. optim_input: Optional[
  1033. Union[
  1034. List[Dict[str, Any]],
  1035. Iterable[nn.Parameter],
  1036. ]
  1037. ] = None,
  1038. ) -> Dict[nn.Parameter, int]:
  1039. """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`."""
  1040. param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input)
  1041. return {param: param_id for param_id, param in param_id_to_param.items()}
  1042. def _check_missing_keys_on_rank(
  1043. r0_optim_state_keys: List[_OptimStateKey],
  1044. optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]],
  1045. param_key_to_param: Dict[Union[str, int], nn.Parameter],
  1046. group: Optional[dist.ProcessGroup],
  1047. ) -> None:
  1048. # Ensure that all ranks have at least the optimizer states needed by
  1049. # rank 0's optimizer
  1050. missing_keys: List[_OptimStateKey] = []
  1051. for r0_optim_state_key in r0_optim_state_keys:
  1052. if r0_optim_state_key not in optim_state_key_to_param_key:
  1053. # A parameter from rank 0's optimizer does not exist for this
  1054. # rank's optimizer
  1055. missing_keys.append(r0_optim_state_key)
  1056. continue
  1057. param_key = optim_state_key_to_param_key[r0_optim_state_key]
  1058. if isinstance(param_key, int):
  1059. assert param_key >= 0 and param_key < len(
  1060. param_key_to_param
  1061. ), "Check the `param_key_to_param` construction"
  1062. # We cannot use FSDPState.compute_device as this API is a global view.
  1063. device = _get_pg_default_device(group)
  1064. num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
  1065. dist.all_reduce(num_missing, group=group)
  1066. if num_missing.item() > 0:
  1067. obj_list = [None for _ in range(dist.get_world_size(group))]
  1068. dist.all_gather_object(obj_list, missing_keys, group=group)
  1069. error_msg = (
  1070. "FSDP currently requires each rank to have at least the "
  1071. "optimizer states needed by rank 0's optimizer but some ranks "
  1072. "are missing some of those states"
  1073. )
  1074. for rank, keys in enumerate(obj_list):
  1075. keys = cast(List[_OptimStateKey], keys)
  1076. if len(keys) > 0:
  1077. error_msg += (
  1078. f"\nRank {rank} is missing states for the parameters: "
  1079. f"{[key.unflat_param_names for key in keys]}"
  1080. )
  1081. raise RuntimeError(error_msg)
  1082. def _map_param_key_to_optim_keys(
  1083. optim_state_dict: Dict[str, Any],
  1084. group: Optional[dist.ProcessGroup],
  1085. param_key_to_param: Dict[Union[int, str], nn.Parameter],
  1086. param_to_fqns: Dict[nn.Parameter, List[str]],
  1087. fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
  1088. merge_keys: bool = False,
  1089. ) -> Tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]:
  1090. """
1091. Construct the local mapping between the ``_OptimStateKey`` s and parameter keys
1092. and gather all the ``_OptimStateKey`` s across ranks. If ``merge_keys`` is False,
1093. rank 0 must contain all the ``_OptimStateKey`` s; otherwise, an exception is raised.
1094. Note that ``merge_keys`` should be equal to ``use_orig_params``.
  1095. """
  1096. rank = dist.get_rank(group)
  1097. optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {} # local
  1098. all_optim_state_keys: List[_OptimStateKey] = []
  1099. for param_key, param in param_key_to_param.items():
  1100. # Do not include parameters without state to avoid empty mappings
  1101. # just like in normal `torch.optim.Optimizer.state_dict()`
  1102. if param_key not in optim_state_dict["state"]:
  1103. continue
  1104. fqns = param_to_fqns[param]
  1105. is_fsdp_managed = isinstance(param, FlatParameter)
  1106. if is_fsdp_managed:
  1107. assert fqns[0] in fqn_to_fsdp_param_info, (
  1108. fqns[0],
  1109. list(fqn_to_fsdp_param_info.keys()),
  1110. )
  1111. is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
  1112. optim_state_key = _OptimStateKey(
  1113. unflat_param_names=tuple(fqns),
  1114. is_fsdp_managed=is_fsdp_managed,
  1115. )
  1116. if rank == 0 or merge_keys:
  1117. all_optim_state_keys.append(optim_state_key)
  1118. optim_state_key_to_param_key[optim_state_key] = param_key
  1119. if merge_keys:
  1120. all_keys: List[List[_OptimStateKey]] = [
  1121. [] for _ in range(dist.get_world_size(group))
  1122. ]
  1123. dist.all_gather_object(all_keys, all_optim_state_keys, group=group)
  1124. merge_all_optim_state_keys = [
  1125. key for local_keys in all_keys for key in local_keys
  1126. ]
  1127. all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
  1128. else:
  1129. key_obj_list: List[Optional[List[_OptimStateKey]]] = (
  1130. [all_optim_state_keys] if rank == 0 else [None]
  1131. )
  1132. dist.broadcast_object_list(key_obj_list, src=0, group=group)
  1133. assert key_obj_list[0] is not None
  1134. all_optim_state_keys = key_obj_list[0]
  1135. _check_missing_keys_on_rank(
  1136. all_optim_state_keys,
  1137. optim_state_key_to_param_key,
  1138. param_key_to_param,
  1139. group,
  1140. )
  1141. return all_optim_state_keys, optim_state_key_to_param_key
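
# A rough illustration (not executed, made-up FQNs and parameter IDs) of what
# `_map_param_key_to_optim_keys` returns for a vanilla optimizer whose state is
# keyed by parameter IDs and whose single FlatParameter wraps two original
# parameters:
#
#   all_optim_state_keys = [
#       _OptimStateKey(
#           unflat_param_names=("net.0.weight", "net.0.bias"),
#           is_fsdp_managed=True,
#       ),
#   ]
#   optim_state_key_to_param_key = {all_optim_state_keys[0]: 0}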


def _unflatten_param_groups(
    state_dict: Dict[str, Any],
    param_key_to_param: Dict[Union[int, str], nn.Parameter],
    param_to_fqns: Dict[nn.Parameter, List[str]],
) -> List[Dict[str, Any]]:
    param_groups: List[Dict[str, Any]] = []
    for flat_param_group in state_dict["param_groups"]:
        unflat_param_group = copy.deepcopy(flat_param_group)
        param_group_params = [
            param_key_to_param[flat_param_key]
            for flat_param_key in flat_param_group["params"]
        ]
        nested_unflat_param_names = [
            param_to_fqns[param] for param in param_group_params
        ]
        unflat_param_group["params"] = [
            unflat_param_name
            for unflat_param_names in nested_unflat_param_names
            for unflat_param_name in unflat_param_names
        ]  # flatten the list of lists
        param_groups.append(unflat_param_group)
    return param_groups
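
# Hedged example (made-up keys) of the param-group translation above: a flat
# group {"lr": 1e-3, "params": [0]} whose single parameter key maps to a
# FlatParameter with FQNs ["net.weight", "net.bias"] becomes
# {"lr": 1e-3, "params": ["net.weight", "net.bias"]}.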


def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
    """
    Returns whether the state_dict is from a NamedOptimizer.
    This function checks whether the keys in ``state_dict['state']`` are strings
    (which usually are FQNs) versus integers (which usually refer to param_ids
    from a vanilla torch.optim.Optimizer).
    """
    state = optim_state_dict.get("state", None)
    if not state:
        # If we cannot find a state, assume it is not NamedOptimizer as
        # NamedOptimizer has eager initialization.
        return False
    try:
        key = next(iter(state.keys()))
    except Exception as e:
        raise Exception(optim_state_dict) from e  # noqa: TRY002
    return isinstance(key, str)
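
# Illustration only (hypothetical state dicts): a vanilla optimizer keys its
# state by parameter ID, while a NamedOptimizer keys it by FQN, which is what
# `_is_named_optimizer` distinguishes:
#
#   >>> _is_named_optimizer({"state": {0: {"exp_avg": ...}}})
#   False
#   >>> _is_named_optimizer({"state": {"net.weight": {"exp_avg": ...}}})
#   True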


@dataclass
class StateInfo:
    # The keys of these dictionaries are the state names, e.g., `exp_avg`.
    tensors: Dict[str, _PosDimTensorInfo]
    scalar_tensors: Dict[str, torch.Tensor]
    non_tensors: Dict[str, Any]


def _allgather_state_info(
    fsdp_state: _FSDPState,
    input_states: Dict[str, Any],
) -> List[Dict[str, StateInfo]]:
    """
    Given the ``input_states``, allgather StateInfo for each state. The function
    uses all_gather_object to gather StateInfo so no GPU tensors are sent.
    """
    processed_state_dict: Dict[str, StateInfo] = {}
    gathered_state_info: List[Dict[str, StateInfo]] = [
        {} for _ in range(fsdp_state.world_size)
    ]

    for fqn, optim_state in input_states.items():
        # Allgather the scalar tensor state, non-tensor states and tensors metadata.
        processed_state = StateInfo({}, {}, {})
        for state_name, value in sorted_items(optim_state):
            if torch.is_tensor(value):
                if value.dim() == 0:
                    # Ensure that `step` is on CPU.
                    processed_state.scalar_tensors[state_name] = value.cpu()
                else:
                    processed_state.tensors[state_name] = _PosDimTensorInfo(
                        value.shape, value.dtype
                    )
            else:
                processed_state.non_tensors[state_name] = value
        processed_state_dict[fqn] = processed_state

    dist.all_gather_object(
        gathered_state_info,
        processed_state_dict,
        group=fsdp_state.process_group,
    )
    return gathered_state_info
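
# Sketch (assumed Adam-style state, made-up shapes) of how one rank's local
# state for an FQN is summarized into a StateInfo before all_gather_object:
# positive-dimension tensors are reduced to metadata, scalar tensors are moved
# to CPU, and everything else is kept as-is.
#
#   local state: {"step": tensor(10.), "exp_avg": <tensor of shape (1024,)>}
#   becomes:
#       StateInfo(
#           tensors={"exp_avg": _PosDimTensorInfo(torch.Size([1024]), torch.float32)},
#           scalar_tensors={"step": tensor(10.)},
#           non_tensors={},
#       )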


def _convert_all_state_info(
    fsdp_param_info: FSDPParamInfo,
    gathered_state_info: List[Dict[str, StateInfo]],
    input_states: Dict[str, Any],
    output_states: Dict[str, Dict[str, Any]],
) -> Tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]:
    """
    Given the ``gathered_state_info`` and ``input_states``, this API converts
    the StateInfo back into the original state if the state is not a non-scalar
    tensor. For a multi-dimensional tensor, the local state is stored in
    ``state_buffers`` in the correct order for the later allgather.
    """
    state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {}

    for fqn, gathered_state in output_states.items():
        state_info = [s[fqn] for s in gathered_state_info]
        all_tensor_states = sorted(
            {n for state in state_info for n in state.tensors.keys()}
        )
        empty_ranks: Set[int] = set()
        dtype: Optional[torch.dtype] = None
        # First check all the non-scalar states and get the information of
        # states on each rank.
        for state_name in all_tensor_states:
            numels = []
            _empty_ranks: Set[int] = set()
            for rank, object_state in enumerate(state_info):
                numels.append(0)
                info = object_state.tensors.get(state_name, None)
                if info is not None:
                    numels[-1] = info.shape.numel()
                    if not dtype:
                        dtype = info.dtype
                    else:
                        assert dtype == info.dtype
                if numels[-1] == 0:
                    _empty_ranks.add(rank)

            assert not empty_ranks or empty_ranks == _empty_ranks
            empty_ranks = _empty_ranks
            if state_name not in state_buffers:
                state_buffers[state_name] = [
                    None for _ in fsdp_param_info.param_indices
                ]
            local_state = input_states[fqn].get(state_name, None)
            # N.B. We need to move the state to compute_device. The reason is
            # not yet clear and we need to figure out why the state may be on a
            # different device.
            if local_state is not None:
                local_state = local_state.to(fsdp_param_info.state.compute_device)
            state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state

        # Restore the scalar and non-tensor states. If the corresponding
        # non-scalar states do not exist on the rank, we also skip the scalar
        # non-tensor states on that rank.
        for rank, object_state in enumerate(state_info):
            if rank in empty_ranks:
                continue
            for name, non_tensor_value in object_state.non_tensors.items():
                curr_non_tensor_value = gathered_state.get(name, None)
                assert (
                    curr_non_tensor_value is None
                    or curr_non_tensor_value == non_tensor_value
                ), (
                    f"Rank {rank} has different values for {name}: {non_tensor_value}."
                    + f" Other ranks: {curr_non_tensor_value}"
                )
                gathered_state[name] = non_tensor_value

            for name, scalar_tensor_value in object_state.scalar_tensors.items():
                curr_scalar_tensor_value = gathered_state.get(name, None)
                assert curr_scalar_tensor_value is None or torch.equal(
                    scalar_tensor_value, curr_scalar_tensor_value
                ), (
                    f"Rank {rank} has different values for {name}: {scalar_tensor_value}."
                    + f" Other ranks: {curr_scalar_tensor_value}"
                )
                gathered_state[name] = scalar_tensor_value

    return dtype, state_buffers  # type: ignore[possibly-undefined]
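
# For reference, a hedged example (made-up indices) of the returned
# ``state_buffers`` layout for a FlatParameter covering two original
# parameters: each list is indexed by ``fsdp_param_info.param_indices[fqn]``
# and holds this rank's local (sharded) tensor state, or None when the rank
# owns no part of that parameter.
#
#   state_buffers == {
#       "exp_avg": [<local shard for param 0>, None],
#       "exp_avg_sq": [<local shard for param 0>, None],
#   }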


def _unflatten_orig_param_states(
    fsdp_param_info: FSDPParamInfo,
    output_states: Dict[str, Dict[str, Any]],
    state_name: str,
    shard_state: bool,
    to_save: bool,
    cpu_offload: bool,
) -> None:
    """
    Given an output state dict, ``output_states``, whose keys are FQNs of the
    original parameters (not FlatParameters nor parameter IDs) and whose values
    are gathered states, unflatten the states to the original dimensions.
    This function performs the unflattening process in-place.
    """
    if not to_save:
        return
    flat_param = fsdp_param_info.handle.flat_param
    fsdp_state = fsdp_param_info.state
    for fqn, gathered_state in output_states.items():
        value = gathered_state[state_name]
        param_idx = fsdp_param_info.param_indices[fqn]

        # TODO: This solution is not general and only applies to the PTD TP solution.
        if isinstance(value, DTensor):
            placement = value.placements[0]
            # If gathered state is a DTensor and its TP placement is not Replicate(), we need to
            # gather the tensor on its TP dimension before chunking them into DTensor again.
            if placement != Replicate():
                placement_dim = placement.dim  # type: ignore[attr-defined]
                value_local = value.redistribute(placements=(Replicate(),))
                reshape_size = list(flat_param._shapes[param_idx])
                reshape_size[placement_dim] *= value.device_mesh.size(0)
                reshape_size = torch.Size(reshape_size)
                value = value.reshape(reshape_size)
            # If gathered state is a replicated DTensor, we directly reshape it.
            else:
                value = value.reshape(flat_param._shapes[param_idx])
        else:
            # If gathered state is a tensor, we directly reshape it into the
            # unflattened state.
            value = value.reshape(flat_param._shapes[param_idx])

        if shard_state:
            osd_config = fsdp_state._optim_state_dict_config
            if getattr(osd_config, "_use_dtensor", False):
                assert fsdp_state._device_mesh is not None
                value = _ext_chunk_dtensor(
                    value,
                    fsdp_state.rank,
                    fsdp_state._device_mesh,
                    fsdp_state._fsdp_extension,
                )
            else:
                assert fsdp_state.process_group is not None
                value = _ext_chunk_tensor(
                    value,
                    fsdp_state.rank,
                    fsdp_state.world_size,
                    fsdp_state._device_handle.device_count(),
                    fsdp_state.process_group,
                    fsdp_state._fsdp_extension,
                )
        elif not cpu_offload:
            with SimpleProfiler.profile("clone"):
                value = value.detach().clone()

        if cpu_offload:
            with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
                value = value.cpu()
        gathered_state[state_name] = value


def _allgather_orig_param_states(
    fsdp_param_info: FSDPParamInfo,
    gathered_state_info: List[Dict[str, StateInfo]],
    input_states: Dict[str, Any],
    shard_state: bool,
    to_save: bool,
    cpu_offload: bool,
) -> Dict[str, Dict[str, Any]]:
    """
    Given the ``gathered_state_info`` and ``input_states``, this API allgathers
    all tensor states and restores non-tensor states from ``gathered_state_info``.
    """
    fsdp_state = fsdp_param_info.state
    if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL:
        logger.info(
            "Memory Summary before calling _allgather_orig_param_states %s",
            fsdp_state._device_handle.memory_summary(),
        )

    output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}

    dtype, state_buffers = _convert_all_state_info(
        fsdp_param_info, gathered_state_info, input_states, output_states
    )

    if len(state_buffers) == 0:
        return output_states

    has_state_params: List[bool] = [
        True if fqn in output_states else False
        for fqn, idx in fsdp_param_info.param_indices.items()
    ]

    # Loop through the ``state_buffers`` and construct the flattened, concatenated,
    # sharded states. The size of the constructed state will be the same size as
    # flat_param (also sharded).
    # Then we perform an allgather_into_tensor to get the full flat_param state.
    # The full flat_param state is the result of concatenating multiple states
    # in the order of flat_param._fqns.
    # The final step is to split the flat_param state into original param states
    # and return the result.
    flat_param = fsdp_param_info.handle.flat_param
    empty_func = functools.partial(
        torch.empty, dtype=dtype, device=fsdp_state.compute_device
    )
    gathered_tensor = empty_func(flat_param._padded_unsharded_size)
    # Synchronize can be slow but this will be easier for us to debug.
    fsdp_state._device_handle.synchronize()
    for state_name, buffers in state_buffers.items():
        local_buffers: List[torch.Tensor] = []
        begin = fsdp_state.rank * flat_param._sharded_size.numel()
        # End is inclusive.
        end = begin + flat_param._sharded_size.numel() - 1
        # param_idx corresponds to the parameter index in the FlatParameter.
        mem_offset, param_idx = 0, 0
        for numel, is_padding in zip(
            flat_param._numels_with_padding, flat_param._is_padding_mask
        ):
            frozen_and_no_state = not is_padding and (
                not fsdp_param_info.param_requires_grad[param_idx]
                and not has_state_params[param_idx]
            )

            if is_padding or frozen_and_no_state:
                # This memory range is a padding or the param is frozen and does
                # not require gradient. For the latter case, we treat it as a
                # padding and add empty values to the local_buffers.
                padding_begin, padding_end = mem_offset, mem_offset + numel - 1
                if padding_begin <= begin <= padding_end:
                    # The range is an align padding before the first parameter in
                    # the shard. The shard includes parts of this align padding.
                    padding_len = (
                        padding_end - begin + 1
                        if end >= padding_end
                        else end - begin + 1
                    )
                elif padding_begin <= end <= padding_end:
                    # The range is an align padding after the last parameter in
                    # the shard. The shard includes parts of this align padding.
                    padding_len = (
                        end - padding_begin + 1
                        if begin <= padding_begin
                        else end - begin + 1
                    )
                elif begin < padding_begin <= padding_end < end:
                    # The range is an align padding that is completely in the
                    # shard.
                    padding_len = numel
                else:
                    padding_len = 0
                if padding_len:
                    local_buffers.append(empty_func(padding_len))

            if not is_padding:
                # This memory range is a parameter in FlatParameter. So there
                # should be a corresponding state in the optimizer unless the
                # parameter is frozen, which we treat as padding above.
                # We need to check if this rank owns the buffer. If this is None:
                # 1.) the rank does not own any part of the original parameter.
                #     As a result, there is no corresponding optimizer state on
                #     the rank as well.
                # 2.) the parameter is frozen AND there is no optimizer state for
                #     the parameter. If a parameter is frozen, there can still be
                #     optimizer state if the parameter was not frozen in
                #     previous steps.
                if buffers[param_idx] is not None:
                    local_buffers.append(cast(torch.Tensor, buffers[param_idx]))
                param_idx += 1

            mem_offset += numel

        shard_numel_padded = flat_param._sharded_size.numel() - (
            sum(t.numel() for t in local_buffers)
        )

        assert flat_param._shard_numel_padded == shard_numel_padded, (
            "Manually calculated _shard_numel_padded is incorrect. "
            f"_shard_numel_padded={flat_param._shard_numel_padded}, "
            f"shard_numel_padded={shard_numel_padded}, "
            f"_sharded_size.numel={flat_param._sharded_size.numel()}, "
            f"_numels_with_padding={flat_param._numels_with_padding}, "
            f"begin={begin}, end={end},"
        )
        if shard_numel_padded > 0:
            # Add right-handed padding.
            local_buffers.append(empty_func(shard_numel_padded))
        local_shard = torch.cat(local_buffers)
        assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), (
            "The size of local shard times the world size should equal the "
            "gathered tensor size. The inconsistency may be from a bug of "
            "FlatParameter's metadata or the reconstruction logic in optimizer "
            "state dict."
        )
        fsdp_state._device_handle.synchronize()
        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
            dist.all_gather_into_tensor(
                gathered_tensor, local_shard, group=fsdp_state.process_group
            )
        # Synchronize can be slow but this will be easier for us to debug.
        fsdp_state._device_handle.synchronize()

        unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()]
        flat_param_handle = fsdp_param_info.handle
        orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor)
        assert len(orig_states) == len(fsdp_param_info.param_indices), (
            "The number of parameters from FlatParameter is not consistent with "
            "the number of states used by the optimizer state dict reconstruction "
            "logic."
        )
        for fqn, idx in fsdp_param_info.param_indices.items():
            if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
                output_states[fqn][state_name] = orig_states[idx]

        _unflatten_orig_param_states(
            fsdp_param_info,
            output_states,
            state_name,
            shard_state,
            to_save,
            cpu_offload,
        )

    del gathered_tensor
    return output_states
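
# Worked example (made-up sizes) of the shard-range arithmetic above: with
# world_size=2, flat_param._sharded_size.numel()=6, and
# flat_param._numels_with_padding=(5, 1, 5, 1) where the second and fourth
# entries are alignment padding, the flat layout is
# param0=[0, 4], pad=[5, 5], param1=[6, 10], pad=[11, 11].
# Rank 0 covers [begin, end] = [0, 5], so it contributes param0 plus the first
# one-element padding; rank 1 covers [6, 11], so it contributes param1 plus the
# final one-element padding.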


def _gather_all_orig_param_state(
    fsdp_param_info: FSDPParamInfo,
    input_states: Dict[str, Any],
    shard_state: bool,
    to_save: bool,
    cpu_offload: bool,
) -> Dict[str, Any]:
    """
    Given an optimizer state dict, ``input_states``, whose keys are FQNs of the
    original parameters (not FlatParameters nor parameter IDs), gather all the
    states and unflatten them to the original dimensions. Note that all the
    params referred to by ``input_states`` must be managed by FSDP.
    """
    fsdp_state = fsdp_param_info.state
    if (
        fsdp_state.world_size == 1
        or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
    ):
        return input_states if to_save else {}

    with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ):
            gathered_state_info = _allgather_state_info(fsdp_state, input_states)
        output_states = _allgather_orig_param_states(
            fsdp_param_info,
            gathered_state_info,
            input_states,
            shard_state,
            to_save,
            cpu_offload,
        )
    if to_save:
        for key, idx in fsdp_param_info.param_indices.items():
            if key in output_states:
                continue
            if not fsdp_param_info.param_requires_grad[idx]:
                continue
            raise RuntimeError(
                f"{key} is not in the output state. "
                "The FSDPParamInfo has the param keys "
                f"{sorted(fsdp_param_info.param_indices.keys())} while "
                "the output_states has the param keys "
                f"{sorted(output_states.keys())}."
            )
        return output_states
    else:
        return {}
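
# Note (sketch): when sharding is a no-op (world_size == 1 or NO_SHARD), the
# local state already equals the consolidated state, so the function above
# simply returns ``input_states`` as-is (or ``{}`` when ``to_save`` is False)
# without any collective communication.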


def _convert_state_with_orig_params(
    all_optim_state_keys: List[_OptimStateKey],
    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
    optim_state_dict: Dict[Union[str, int], Any],
    to_save: bool,
    shard_state: bool,
    cpu_offload: bool = True,
) -> Dict[str, Any]:
    fsdp_osd_state: Dict[str, Any] = {}
    # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo
    # usually corresponds to multiple parameters. We cannot use FSDPParamInfo
    # as the key because FSDPParamInfo is not hashable. As a result, we fall back
    # to `id(FSDPParamInfo)`, which is an integer.
    all_states: Dict[int, Dict[str, Any]] = {}
    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
    # across ranks
    for optim_state_key in all_optim_state_keys:
        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
            optim_state_key, None
        )

        if param_key is None and not optim_state_key.is_fsdp_managed:
            continue

        if optim_state_key.is_fsdp_managed:
            fqn = optim_state_key.unflat_param_names[0]
            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
            if fsdp_param_info is None:
                # This can happen if not all FSDP instances have all the
                # parameters. This can happen with FSDP + some MPMD style
                # parallelism.
                # TODO: it is unclear if we need to do the same check with
                # non-FSDP managed keys.
                continue
            state = {} if param_key is None else optim_state_dict[param_key]
            if id(fsdp_param_info) not in all_states:
                all_states[id(fsdp_param_info)] = {}
            all_states[id(fsdp_param_info)][fqn] = state

        elif to_save:
            assert len(optim_state_key.unflat_param_names) == 1
            unflat_param_name = optim_state_key.unflat_param_names[0]
            with SimpleProfiler.profile("none_fsdp_managed_copy"):
                param_key = cast(Union[str, int], param_key)
                fsdp_osd_state[unflat_param_name] = copy.copy(
                    optim_state_dict[param_key]
                )
                if cpu_offload:
                    for state_name, value in sorted_items(
                        fsdp_osd_state[unflat_param_name]
                    ):
                        if not torch.is_tensor(value):
                            continue
                        fsdp_osd_state[unflat_param_name][state_name] = value.cpu()

    # Instead of gathering the state of each parameter individually, we perform
    # the gathering all at once to speed up the process.
    for _all_states in all_states.values():
        fqn = next(iter(_all_states.keys()))
        fsdp_param_info = fqn_to_fsdp_param_info[fqn]
        assert len(fsdp_param_info.param_requires_grad) > 0, (
            "With use_orig_params, FSDPParamInfo should have requires_grad "
            "information. However, the length is zero."
        )
        for key, idx in fsdp_param_info.param_indices.items():
            if key in _all_states:
                continue
            if not fsdp_param_info.param_requires_grad[idx]:
                continue
            raise RuntimeError(
                f"{key} is not in the optimizer state. "
                "The FSDPParamInfo has the param keys "
                f"{sorted(fsdp_param_info.param_indices.keys())} while "
                "the optimizer has the param keys "
                f"{sorted(_all_states.keys())}."
            )
        fsdp_osd_state.update(
            _gather_all_orig_param_state(
                fsdp_param_info,
                _all_states,
                shard_state,
                to_save,
                cpu_offload,
            )
        )

    return fsdp_osd_state


def _convert_state_with_flat_params(
    all_optim_state_keys: List[_OptimStateKey],
    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
    optim_state_dict: Dict[Union[str, int], Any],
    to_save: bool,
    shard_state: bool,
    cpu_offload: bool = True,
) -> Dict[str, Any]:
    fsdp_osd_state: Dict[str, Any] = {}
    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
    # across ranks
    for optim_state_key in all_optim_state_keys:
        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
            optim_state_key, None
        )

        assert param_key is not None, (
            "If use_orig_params is False, we must be able to find the "
            f"corresponding param id. {optim_state_key} {param_key}"
        )

        if optim_state_key.is_fsdp_managed:
            # If there are multiple unflat_param_names (not use_orig_params),
            # they share the same FSDPParamInfo. So the first unflat_param_name
            # is sufficient to fetch the FSDPParamInfo.
            fqn = optim_state_key.unflat_param_names[0]
            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
            unflat_state = _unflatten_optim_state(
                fsdp_param_info,
                optim_state_dict[param_key],
                to_save,
                shard_state,
                cpu_offload,
            )
            if to_save:
                assert len(unflat_state) == len(optim_state_key.unflat_param_names)
                for unflat_param_name, unflat_param_state in zip(
                    optim_state_key.unflat_param_names,
                    unflat_state,
                ):
                    fsdp_osd_state[unflat_param_name] = unflat_param_state
        elif to_save:
            assert len(optim_state_key.unflat_param_names) == 1
            unflat_param_name = optim_state_key.unflat_param_names[0]
            fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
            if cpu_offload:
                for state_name, value in sorted_items(
                    fsdp_osd_state[unflat_param_name]
                ):
                    if not torch.is_tensor(value):
                        continue
                    fsdp_osd_state[unflat_param_name][state_name] = value.cpu()

    return fsdp_osd_state


@torch.no_grad()
def _optim_state_dict(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: Dict[str, Any],
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[nn.Parameter],
        ]
    ],
    rank0_only: bool,
    shard_state: bool,
    group: Optional[dist.ProcessGroup],
    using_optim_input: bool,
    use_orig_params: bool = False,
    cpu_offload: bool = True,
) -> Dict[str, Any]:
    """
    Consolidates the optimizer state and returns it as a :class:`dict`
    following the convention of :meth:`torch.optim.Optimizer.state_dict`,
    i.e. with keys ``"state"`` and ``"param_groups"``.
    The flat parameters in ``FSDP`` modules contained in ``model`` are mapped
    back to their unflattened parameters.

    Parameter keys are not well-defined. For a regular optimizer, the optimizer
    state_dict contains a mapping from parameter IDs to parameter states.
    Parameter IDs are the order of parameters in ``optim.param_groups`` across
    all the groups. This API also allows the user to pass ``optim_input`` for the
    mapping between parameters and parameter IDs. Using ``optim_input`` is being
    deprecated.

    If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not
    contain a parameter ID mapping but a mapping from parameter FQNs to parameter
    states. This API finds the mapping from FQNs to parameters if the optimizer
    is a ``NamedOptimizer``.

    If ``use_orig_params`` is True, each rank will have all FSDP-managed
    parameters but some of these parameters may be empty due to the sharding.
    For a regular optim.Optimizer, states for those empty parameters will
    not be initialized. So, when aggregating the FQNs across ranks, no assert
    will be raised on a rank even if it does not have all the states -- it is
    valid and FSDP knows how to aggregate them. However, FSDP has to ignore
    handling those parameters that are not managed by FSDP and do not exist on
    the local rank -- those are managed by other parallelisms and FSDP does not
    know how to handle/aggregate them.

    Args:
        model (nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            were passed into the optimizer ``optim``.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.
        rank0_only (bool): If ``True``, saves the populated :class:`dict`
            only on rank 0; if ``False``, saves it on all ranks. (Default:
            ``True``)
        shard_state (bool): If ``True``, shard and distribute all
            non-zero-dimension states.

    Returns:
        Dict[str, Any]: A :class:`dict` containing the optimizer state for
        ``model`` 's original unflattened parameters and including keys
        "state" and "param_groups" following the convention of
        :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
        then nonzero ranks return an empty :class:`dict`.
    """
    SimpleProfiler.reset()
    cm = ExitStack()
    cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL))
    _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model))
    to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state

    with SimpleProfiler.profile("preprocessing"):
        param_to_fqns = _get_param_to_fqns(model)
        flat_param_to_fqn = _get_flat_param_to_fqn(model)
        is_named_optimizer = _is_named_optimizer(optim_state_dict)

        param_key_to_param = cast(
            Dict[Union[int, str], nn.Parameter],
            (
                _get_param_id_to_param_from_optim_input(model, optim_input)
                if using_optim_input
                else _get_param_key_to_param(
                    optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
                )
            ),
        )
        fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

    with SimpleProfiler.profile("preprocessing_with_comm"):
        (
            all_optim_state_keys,
            optim_state_key_to_param_key,
        ) = _map_param_key_to_optim_keys(
            optim_state_dict,
            group,
            param_key_to_param,
            param_to_fqns,
            fqn_to_fsdp_param_info,
            merge_keys=use_orig_params,
        )

    with SimpleProfiler.profile("state_converting"):
        convert_fn = (
            _convert_state_with_orig_params
            if use_orig_params
            else _convert_state_with_flat_params
        )
        fsdp_osd_state = convert_fn(
            all_optim_state_keys,
            optim_state_key_to_param_key,
            fqn_to_fsdp_param_info,
            optim_state_dict["state"],
            to_save,
            shard_state,
            cpu_offload,
        )

    # At this point, communication is complete and ranks can return early if nothing
    # will be saved on that rank.
    if not to_save:
        return {}

    fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state}

    flat_param_fqns = set(flat_param_to_fqn.values())
    for key, value in optim_state_dict["state"].items():
        if key in fsdp_osd_state:
            continue
        if key in flat_param_fqns:
            continue
        if key in param_key_to_param:
            continue
        # This key is not recognized by FSDP. It may be a user-defined state
        # or some parameter state that FSDP is unable to map from
        # ``optim.param_groups``.
        warnings.warn(
            f"Found an optim state, {key}, that FSDP cannot process. FSDP "
            "will directly copy everything to the returned state_dict. In "
            "most cases, this is a user-defined state that is not "
            "associated with any particular parameter. Another possible "
            "case is this state is managed by TorchRec. Otherwise, there may "
            "be a mismatched assumption about optim_state_dict in this mode."
        )
        fsdp_osd_state[key] = value

    if "param_groups" in optim_state_dict:
        fsdp_osd["param_groups"] = _unflatten_param_groups(
            optim_state_dict, param_key_to_param, param_to_fqns
        )

    cm.close()
    SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ")

    return fsdp_osd
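
# A minimal usage sketch of how this internal API is typically reached through
# the public FSDP interface; assumes `model` is already wrapped with FSDP and
# `optim` was constructed over `model.parameters()`:
#
#   >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   >>> osd = FSDP.optim_state_dict(model, optim)
#   >>> sorted(osd.keys())  # follows torch.optim.Optimizer.state_dict()
#   ['param_groups', 'state']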


def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    """
    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
    if the param is managed by FSDP. Shared parameters, or original parameters that
    are shared across multiple nn.Modules, are required to belong to one and only
    one FSDP instance and thus correspond to one ``FlatParameter``. Within the one
    ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared
    parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters.
    """

    def module_fn(module, prefix, tree_level, fqn_to_param_info):
        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
        if fsdp_state is None:
            return
        _lazy_init(fsdp_state, module)
        handle = _module_handle(fsdp_state, module)
        if not handle:
            return
        flat_param = handle.flat_param
        fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, [])
        # NOTE: `idx` indexes into the data structures *without* padding
        # elements
        for idx, local_fqn in enumerate(flat_param._fqns):
            fqn = clean_tensor_name(prefix + local_fqn)
            if fqn in fqn_to_param_info:
                assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
            fqn_to_param_info[fqn] = fsdp_param_info
            fsdp_param_info.param_indices[fqn] = idx
            if flat_param._params is not None:
                fsdp_param_info.param_requires_grad.append(
                    flat_param._params[idx].requires_grad
                )

    def return_fn(fqn_to_param_info):
        return fqn_to_param_info

    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
    # FlatParameter._fqns stores the local fqn, starting from the root of the
    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
    # module) allows us to construct the global fqn.
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
        fqn_to_param_info,
    )
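
# Illustration (hypothetical module names): for a model whose FSDP-wrapped
# submodule owns a single FlatParameter over ["layer.weight", "layer.bias"],
# the returned mapping would look like
#
#   {"layer.weight": <FSDPParamInfo for that FlatParameter>,
#    "layer.bias": <the same FSDPParamInfo>}
#
# with that FSDPParamInfo's ``param_indices`` equal to
# ``{"layer.weight": 0, "layer.bias": 1}``.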


@no_type_check
def _set_optim_use_dtensor(
    fsdp_state: _FSDPState,
    state_dict_settings: StateDictSettings,
) -> None:
    # If device_mesh is passed in when initializing FSDP, we automatically turn the
    # _use_dtensor flag to True for ShardedOptimStateDictConfig(), in which case
    # state_dict_type has to be set to SHARDED_STATE_DICT.
    if getattr(fsdp_state, "_device_mesh", None):
        state_dict_type = state_dict_settings.state_dict_type
        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise RuntimeError(
                "Found state_dict_type LOCAL_STATE_DICT.",
                "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
            )
        else:
            state_dict_settings.optim_state_dict_config._use_dtensor = True