# mypy: allow-untyped-defs
import contextlib
import functools
import logging
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    no_type_check,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _named_parameters_with_duplicates,
    _no_dispatch_record_stream,
    _set_fsdp_flattened,
    HandleTrainingState,
)
from torch.distributed.utils import (
    _alloc_storage,
    _data_ptr_allocated,
    _free_storage,
    _p_assert,
)
from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

from ._fsdp_extensions import (
    _ext_post_unflatten_transform,
    _ext_pre_flatten_transform,
    FSDPExtensions,
)

__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]

logger = logging.getLogger(__name__)


"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.

For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
  runs the unshard/reshard on behalf of the fully sharded module by overriding
  ``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
  ``FullyShardedDataParallel`` constructor's ``module`` argument.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
  or a submodule chosen by the provided wrapping policy.
"""
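
# Illustrative sketch (assumptions, not part of this module): the two code
# paths above roughly correspond to the following user-facing constructions.
# `MyModel` is a placeholder, and the exact import location of `fully_shard`
# has varied across releases.
#
#   # Wrapper path: FSDP overrides `nn.Module.forward` to unshard/reshard.
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   fsdp_model = FSDP(MyModel())  # `MyModel()` is the fully sharded module
#
#   # Non-wrapper path: hooks on the fully sharded module run unshard/reshard.
#   fully_shard(model)  # `model` or a policy-chosen submodule is fully sharded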

# Environment variable toggling whether to use unsafe `setattr()` for view
# setting in `_use_sharded_views()` and `_use_unsharded_views()`
# We should use 'safe' by default since it respects method overrides, but for
# special cases such as high CPU overhead or intentionally bypassing checks in
# the overrides, we may use 'unsafe'.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"

# Environment variable toggling whether to check for parameter/gradient
# writeback in case their storages change after FSDP initialization
# We should check by default since it prevents silent correctness errors, but
# since such changes are atypical, we may want to skip the check to save CPU
# overhead, especially since the check happens in the pre-forward and
# pre-backward each iteration.
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"

# Environment variable toggling whether to run in full precision (fp32) or in
# the reduced precision when the model is in ``.eval()`` mode
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"

# Value used to fill padding in tensors, for debuggability
_FLAT_PARAM_PADDING_VALUE = 42

# Environment variables for disabling the all-gather and reduce-scatter
# communication ops for ablation studies. Note that without these communication
# ops the training won't converge, and you probably need to disable correctness
# checks in your model.
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
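
# Example (illustrative only): these toggles are read from `os.environ` at
# `FlatParamHandle` construction time, so they must be set before FSDP wraps
# the model, e.g.:
#
#   import os
#   os.environ["FSDP_SKIP_WRITEBACK_CHECK"] = "1"   # skip writeback checks
#   os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = "1"  # full precision in eval
#   # ... then construct the FSDP-wrapped model ...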


# TODO: Define this for now to avoid circular imports. See if we can remove.
class HandleShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
)

NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)


class ParamInfo(NamedTuple):
    """Information for an original parameter."""

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str


class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str


class _ShardParamInfo(NamedTuple):
    """Shard-related information for an original parameter."""

    in_shard: bool
    # Used to index into the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Used to get the part of the parameter in the local shard from a flattened
    # version of the unsharded parameter, e.g.
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive
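
# Worked example (illustrative numbers, not from the original source): suppose
# an original parameter spans indices [100, 299] of the unsharded flat
# parameter and this rank owns unsharded indices [250, 499]. Then for that
# parameter:
#   in_shard = True
#   offset_in_shard = 0          # the parameter starts at the front of the shard
#   numel_in_shard = 50          # unsharded indices [250, 299]
#   intra_param_start_idx = 150  # 250 - 100, into `param.flatten()`
#   intra_param_end_idx = 199    # 299 - 100, inclusive
# so `flat_param[0:50]` and `param.flatten()[150:200]` refer to the same data.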


class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]


class _FlatParameterMeta(_ParameterMeta):
    # Make `isinstance(t, FlatParameter)` return True for custom tensor
    # instances that have the _is_flat_param flag for BC
    def __instancecheck__(self, instance):
        # NB: do NOT test the super implementation
        return isinstance(instance, torch.Tensor) and getattr(
            instance, "_is_flat_param", False
        )


class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`.

    It comprises one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.
        - In the :class:`FullyShardedDataParallel` constructor, the parameter
          is initialized as unsharded and then sharded in-place.
        - At runtime, the parameter is lazily (re)-initialized. The sharded
          parameter data is saved in ``self._local_shard``, and a new ``Tensor``
          ``self._full_param_padded`` is created, which is the all-gather
          destination and owns the unsharded parameter storage thereafter. (See
          :meth:`FlatParamHandle.init_flat_param_attributes`.)
        - Throughout runtime, the parameter data changes storages as needed,
          e.g. to the sharded flat parameter, low precision sharded flat
          parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may have length longer than the other data
    structures, while the latter has the same length as the number of actual
    original parameters like the other per-parameter data structures.

    NOTE: This is not a real class; instead, you will always get a Parameter
    back out if you try to create one of these. This is similar to the trick
    we implemented for Parameter to get it to work with subclasses; this
    is primarily so that FlatParameter supports combination with FakeTensor.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)
        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
            prefixed from the ``_fully_sharded_module``. The names are
            guaranteed to be unique in the subtree rooted at that module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.
        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_handle (RemovableHandle):
            Flat parameter's post-backward hook handle. (Compile only)
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle. (Eager only)
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.
        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change
            in which case the post-backward hook does not run. This is relevant
            for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating if it is
            logically ``None`` or not; otherwise, ``None``. This does not
            include entries for padding. This mask is needed because only some
            of the parameters may have ``None`` gradient, in which case the
            flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    _unpadded_unsharded_size: torch.Size
    _padded_unsharded_size: torch.Size
    _sharded_size: torch.Size
    _num_params: int
    _param_infos: Tuple[ParamInfo, ...]
    _shapes: Tuple[torch.Size, ...]
    _fqns: Tuple[str, ...]
    _param_extensions: Tuple[Optional[Any], ...]
    _numels_with_padding: Tuple[int, ...]
    _numels: Tuple[int, ...]
    _shard_param_infos: Tuple[_ShardParamInfo, ...]
    _shared_param_infos: Tuple[SharedParamInfo, ...]
    _modules: Set[nn.Module]
    _shard_numel_padded: int
    _local_shard: Tensor
    _full_param_padded: Tensor
    _full_prec_full_param_padded: Tensor
    # Eager only
    _post_backward_hook_state: Tuple[Any, Any]
    # Compile only
    _post_backward_hook_handle: Any
    _mp_shard: Tensor
    _cpu_grad: Tensor
    _saved_grad_shard: Tensor
    _params: Optional[List[nn.Parameter]]
    _shared_params: Optional[List[nn.Parameter]]
    _tensors: Optional[List[Optional[Tensor]]]
    _is_grad_none_mask: Optional[List[bool]]
    _is_padding_mask: List[bool]

    def __new__(cls, data=None, requires_grad=True):
        assert cls is FlatParameter, "subclasses FlatParameter not supported"
        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
        r._is_flat_param = True  # type: ignore[attr-defined]
        return r
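
    # Illustrative note (not from the original source): because `__new__`
    # returns a plain `nn.Parameter` tagged with `_is_flat_param`, and
    # `_FlatParameterMeta.__instancecheck__` keys off that flag, the following
    # holds even though no object is ever a "real" `FlatParameter` instance:
    #
    #   p = FlatParameter(torch.zeros(4), requires_grad=False)
    #   assert type(p) is nn.Parameter
    #   assert isinstance(p, FlatParameter)  # True via the metaclass check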

    # NB: This is not a regular method, because FlatParameters are not actually
    # instances of this class (see __new__ above). So you must call it
    # indirectly through this classmethod.
    @classmethod
    def _init_metadata(
        cls,
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Initialize attributes holding metadata about the original parameters comprising the flat parameter.

        We expose this method separate from the constructor to keep the
        constructor only responsible for the flat parameter's tensor data. This
        method should only be called once per model, while the constructor may
        be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        Args:
            See the Attributes in the class docstring.
        """
        assert len(param_infos) == len(shapes)
        assert len(param_infos) == len(fqns)
        assert len(param_infos) == len(param_extensions)
        self._num_params = len(param_infos)
        self._param_infos = param_infos
        self._shapes = shapes
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask
        numels_without_padding: List[int] = []
        for numel, is_padding in zip(numels, is_padding_mask):
            if not is_padding:
                numels_without_padding.append(numel)
        self._numels = tuple(numels_without_padding)
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params
        self._shared_param_infos = tuple(shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            self._params = []
            for param, is_padding in zip(params, is_padding_mask):
                if not is_padding:
                    self._params.append(param)
            self._shared_params = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none_mask = [False for _ in range(self._num_params)]
            self._tensors = [None for _ in range(self._num_params)]
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False
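
# Worked example (illustrative, not part of the module): with alignment
# padding, `numels` passed to `_init_metadata` may interleave padding entries
# with real parameter numels. For instance, two parameters of numel 10 and 7
# with one padding entry of numel 6 in between would give
#   _numels_with_padding == (10, 6, 7)   # used with `torch.split()` for views
#   _is_padding_mask     == [False, True, False]
#   _numels              == (10, 7)      # length == _num_params == 2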


class FlatParamHandle:
    """
    A handle that manages a flat parameter (:class:`FlatParameter`).

    This includes sharding and view management.

    Args:
        params (Sequence[nn.Parameter]): The parameters to flatten into the
            flat parameter.
        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
        device (torch.device): The compute and communication device, which
            should be a non-CPU device. We refer to it as the compute device.
        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
            this handle's ``FlatParameter``.
        offload_params (bool): Whether to offload the handle's
            ``FlatParameter`` to CPU.
        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
            setting passed to the FSDP constructor.
        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
            precision setting passed to the FSDP constructor.
        keep_low_precision_grads (bool): Whether to keep gradients in low
            precision.
        use_orig_params (bool): If ``True``, then FSDP preserves the original
            parameter variables and returns them from ``named_parameters()``
            (e.g. to support different optimizer hyperparameters within one
            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
            parameters every iteration and returns the :class:`FlatParameter` s
            from ``named_parameters()``.
    """

    ##################
    # INITIALIZATION #
    ##################
    def __init__(
        self,
        params: Sequence[Union[nn.Parameter, Tensor]],
        fully_sharded_module: nn.Module,
        device: torch.device,
        sharding_strategy: HandleShardingStrategy,
        offload_params: bool,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
        keep_low_precision_grads: bool,
        process_group: dist.ProcessGroup,
        use_orig_params: bool,
        *,
        fsdp_extension: Optional[FSDPExtensions] = None,
    ):
        super().__init__()
        params = list(params)
        if len(params) == 0:
            raise ValueError(
                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
            )
        self._init_setattr_fns()
        self._skip_writeback_check = (
            os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
        )
        self._use_full_prec_in_eval = (
            os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
        )
        self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
        self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
        if self._skip_writeback_check:
            _warn_skip_writeback_check(
                logger,
                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
                "for parameter or gradient writeback. Changing parameter or "
                "gradient storages may lead to silent correctness errors.",
            )
        if self._use_fake_all_gather:
            _warn_use_fake_all_gather(
                logger,
                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
                "all-gather ops. Your training will be incorrect, but "
                "can reveal how much time is spent on all-gather ops.",
            )
        if self._use_fake_reduce:
            _warn_use_fake_reduce(
                logger,
                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
                "reduce-scatter ops. Your training will be incorrect, but "
                "can reveal how much time is spent on reduce-scatter ops.",
            )
        # Only align addresses for `use_orig_params=True` (for now)
        align_addresses = use_orig_params
        self._init_get_unflat_views_fn(align_addresses)
        self.device = device
        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
        self.process_group = process_group
        if self._use_fake_all_gather or self._use_fake_reduce:
            self._fake_process_group = FakeProcessGroup(
                rank=process_group.rank(), world_size=process_group.size()
            )
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        self._sharding_strategy = sharding_strategy
        self._offload_params = offload_params
        self._use_orig_params = use_orig_params
        self._keep_low_precision_grads = keep_low_precision_grads
        self._training_state = HandleTrainingState.IDLE
        self._debug_level = dist.get_debug_level()
        self._fully_sharded_module = fully_sharded_module
        # For strategies that do not free after forward, we skip using sharded
        # views after forward since the unsharded data exists. We still switch
        # `self.flat_param` to point to the sharded flat parameter since what
        # it points to parameterizes behavior. We use the following attribute
        # to track which tensor data the parameters are unsharded views into.
        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
        # The index in the state's `all_handles`, which must be the
        # same across ranks for the execution order validation to work
        self._handle_index: Optional[int] = None
        # Index in `handles_to_pre_forward_order`
        self._pre_forward_order_index: Optional[int] = None
        # Index in `handles_post_forward_order`
        self._post_forward_index: Optional[int] = None
        # Used for guarding against mistargeted forward prefetches
        self._needs_pre_forward_unshard = False
        # Used for guarding against mistargeted backward prefetches
        self._needs_pre_backward_unshard = False
        # Was the handle prefetched? Set on successful `_prefetch_handle` and unshard
        self._prefetched = False
        # Optimistically assume a valid input `params` and set dtype attributes
        # before `_init_flat_param()`, which performs the actual validation
        self._orig_param_dtype = params[0].dtype
        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
        assert self._fwd_bwd_param_dtype is not None  # mypy
        self._aligned_numel = (
            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
            if align_addresses
            else 0
        )
        self._fsdp_extension = fsdp_extension
        self._init_flat_param_and_metadata(
            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
        )
        self._use_unsharded_views(as_params=False)

    def _init_setattr_fns(self):
        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
        if use_unsafe_setattr:
            self._setattr_tensor = _unsafe_setattr_tensor
            self._setattr_param = _unsafe_setattr_param
        else:
            self._setattr_tensor = _safe_setattr_tensor_or_param
            self._setattr_param = _safe_setattr_tensor_or_param

    def _init_get_unflat_views_fn(self, align_addresses: bool):
        self._get_unflat_views = (
            self._get_unflat_views_aligned
            if align_addresses
            else self._get_unflat_views_unaligned
        )

    def _init_flat_param_and_metadata(
        self,
        params: List[Union[Tensor, nn.Parameter]],
        module: nn.Module,
        aligned_numel: int,
        use_orig_params: bool,
    ) -> None:
        """
        Initialize the ``FlatParameter`` and its metadata.

        NOTE: This should only be called once at construction time, after which
        the ``FlatParameter`` metadata is assumed to be static.

        NOTE: The elements of ``params`` should only be ``Tensor`` s when
        composing with ``DTensor`` -based tensor parallelism, in which case the
        elements may be ``DTensor`` local shards.
        """
        if len(params) == 0:
            raise ValueError("Expects non-empty `params`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        (
            dtype,
            flat_param_requires_grad,
            device,
        ) = self._validate_tensors_to_flatten(params)
        params_set = set(params)
        # For alignment padding, only `numels` gets strictly non-`None`
        # elements, and all other lists get `None` elements for padding.
        param_infos: List[ParamInfo] = []
        numels: List[int] = []
        shapes: List[torch.Size] = []
        fqns: List[str] = []
        shared_param_infos: List[SharedParamInfo] = []
        shared_param_memo: Dict[
            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
        ] = {}
        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
        shared_params: List[Union[Tensor, nn.Parameter]] = []
        param_extensions: List[Any] = []
        is_padding_mask: List[bool] = []
        total_numel = total_numel_without_padding = 0
        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param not in params_set:
                    continue
                if param in shared_param_memo:  # shared reference
                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
                        param
                    ]
                    shared_params.append(param)
                    shared_param_infos.append(
                        SharedParamInfo(
                            param_name,
                            submodule,
                            submodule_name,
                            prim_param_name,
                            prim_module,
                            prim_module_name,
                        )
                    )
                else:
                    if aligned_numel > 0:
                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                            padding_tensor = _construct_padding_tensor(
                                numel_to_pad, dtype, False, device
                            )
                            params_to_flatten.append(padding_tensor)
                            is_padding_mask.append(True)
                            numels.append(numel_to_pad)
                            total_numel += numel_to_pad
                    transform_t, extension = _ext_pre_flatten_transform(
                        param,
                        self._fsdp_extension,
                    )
                    param = cast(nn.Parameter, transform_t)
                    param_extensions.append(extension)
                    shared_param_memo[param] = (submodule, submodule_name, param_name)
                    params_to_flatten.append(param)
                    is_padding_mask.append(False)
                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                    numels.append(param.numel())
                    shapes.append(param.shape)
                    fqn = (
                        submodule_name + "." + param_name
                        if submodule_name
                        else param_name
                    )
                    fqns.append(fqn)
                    total_numel += param.numel()
                    total_numel_without_padding += param.numel()
        if len(params_to_flatten) == 0:
            raise ValueError(
                f"`params` were not found in `module`'s tree\n"
                f"params: {params}\nmodule: {module}"
            )
        if (
            self.rank == 0
            and aligned_numel > 0
            and total_numel != total_numel_without_padding
        ):
            logger.debug(
                "FSDP FlatParameter address alignment created "
                "%s numel of padding (%s vs. %s)",
                total_numel - total_numel_without_padding,
                total_numel,
                total_numel_without_padding,
            )
        if aligned_numel > 0:
            # Pad to be divisible by world size to avoid a copy for the
            # post-backward reduce-scatter
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                if self.rank == 0:
                    logger.info(
                        "FSDP FlatParameter world size divisibility created "
                        "%s numel of padding",
                        numel_to_pad,
                    )
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                params_to_flatten.append(padding_tensor)
                is_padding_mask.append(True)
                numels.append(numel_to_pad)
                total_numel += numel_to_pad
        # Pass `aligned_numel=0` since we already included padding tensors
        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
            params_to_flatten,
            aligned_numel=0,
            requires_grad=flat_param_requires_grad,
        )
        FlatParameter._init_metadata(
            self.flat_param,
            param_infos,
            numels,
            shapes,
            fqns,
            shared_param_infos,
            param_extensions,
            _convert_to_params(params_to_flatten) if use_orig_params else None,
            _convert_to_params(shared_params) if use_orig_params else None,
            is_padding_mask,
        )

    def _validate_tensors_to_flatten(
        self, tensors: List[Union[Tensor, nn.Parameter]]
    ) -> Tuple:
        """Validate the tensors to flatten and return any necessary metadata."""
        dtype: Optional[torch.dtype] = None
        # Return as the logical OR over each tensor's value
        flat_param_requires_grad: Optional[bool] = None
        device: Optional[torch.device] = None
        # For `use_orig_params=True`, permit non-uniform `requires_grad`
        for tensor in tensors:
            if isinstance(tensor, FlatParameter):
                raise ValueError("Cannot flatten a `FlatParameter`")
            if dtype is None and not tensor.is_floating_point():
                raise ValueError("Cannot flatten integer dtype tensors")
            if dtype is not None and tensor.dtype != dtype:
                raise ValueError(
                    f"Must flatten tensors with uniform dtype but got {dtype} "
                    f"and {tensor.dtype}"
                )
            if (
                not self._use_orig_params
                and flat_param_requires_grad is not None
                and tensor.requires_grad != flat_param_requires_grad
            ):
                raise ValueError(
                    "Must flatten tensors with uniform `requires_grad` when "
                    "`use_orig_params=False`"
                )
            if device is not None and tensor.device != device:
                raise ValueError(
                    "Must flatten tensors on the same device but got both "
                    f"{device} and {tensor.device}"
                )
            dtype = tensor.dtype
            flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
            device = tensor.device
        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
        return dtype, flat_param_requires_grad, device

    def flatten_tensors(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
    ) -> Tensor:
        """
        Flatten ``tensors`` into a single flat tensor.

        The flattening optionally includes padding if ``aligned_numel`` is
        greater than 0, where ``aligned_numel`` gives the numel required to
        have address alignment.

        NOTE: The padding alignment algorithm must be kept in sync with
        :meth:`_init_flat_param_and_metadata`. We separate the two methods
        because the initialization happens once, whereas this method may be
        called multiple times throughout training (e.g. for checkpointing).
        """
        if len(tensors) == 0:
            raise ValueError("Expects non-empty `tensors`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        dtype, _, device = self._validate_tensors_to_flatten(tensors)
        flat_tensors: List[Tensor] = []
        if aligned_numel > 0:
            total_numel = 0
            for tensor in tensors:
                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                    padding_tensor = _construct_padding_tensor(
                        numel_to_pad, dtype, False, device
                    )
                    flat_tensors.append(padding_tensor)
                    total_numel += numel_to_pad
                flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
                total_numel += tensor.numel()
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                flat_tensors.append(padding_tensor)
                total_numel += numel_to_pad
        else:
            flat_tensors = [
                torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
            ]
        return torch.cat(flat_tensors, dim=0)
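
    # Worked example (illustrative numbers only): with `aligned_numel=4` and
    # `world_size=4`, flattening two tensors of numel 6 and 3 proceeds as:
    #   append t0 (6)                            -> total 6
    #   pad 4 - (6 % 4) = 2 before t1            -> total 8
    #   append t1 (3)                            -> total 11
    #   pad 4 - (11 % 4) = 1 for world-size div. -> total 12
    # i.e. the resulting flat tensor has 12 elements, 3 per rank.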

    def flatten_tensors_into_flat_param(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
        requires_grad: bool,
    ) -> FlatParameter:
        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
        return FlatParameter(flat_param_data, requires_grad=requires_grad)

    def _init_param_reduce_dtypes(
        self,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
    ) -> None:
        """
        Initialize param and reduce dtypes.

        Precondition: ``self.flat_param`` is set. This ensures that this
        handle's parameters have a single dtype.

        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
        is ``None``, then we assume the original parameter dtype. One special
        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
        is ``None``, in which case we assume the gradient reduction dtype
        matches the forward/backward parameter dtype.
        """
        # Save whether these dtypes were specified so that we permit the
        # parameter dtype to change up until the lazy initialization
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if (
            self._low_prec_param_dtype_specified
            and not self._low_prec_reduce_dtype_specified
        ):
            # Special case: infer gradient reduction mixed precision
            self._fwd_bwd_param_dtype = mp_param_dtype
            self._reduce_dtype = self._fwd_bwd_param_dtype
        else:
            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
        assert self._fwd_bwd_param_dtype is not None
        assert self._reduce_dtype is not None
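
    # Illustrative summary (not from the original source) of the dtype
    # resolution above, assuming an fp32 original parameter dtype:
    #   mp_param_dtype  mp_reduce_dtype  fwd/bwd param dtype  reduce dtype
    #   None            None             fp32                 fp32
    #   bf16            None             bf16                 bf16  (special case)
    #   None            bf16             fp32                 bf16
    #   bf16            fp16             bf16                 fp16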

    ###################################
    # SHARD INITIALIZATION & METADATA #
    ###################################
    @torch.no_grad()
    def shard(self):
        """
        Shard the handle's ``FlatParameter``.

        This allocates new memory for the sharded flat parameter and frees the
        unsharded flat parameter's storage.

        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
        metadata attributes are set for all sharding strategies.
        """
        flat_param = self.flat_param
        if not self.uses_sharded_strategy:
            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
        else:
            _p_assert(
                flat_param.storage_offset() == 0,
                "The `FlatParameter` is not the sole occupant of its storage",
            )
            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
                flat_param, self.rank, self.world_size
            )
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                allocated = flat_param._typed_storage()._size() > 0
                if allocated:
                    flat_param._typed_storage()._resize_(0)
            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
            start_idx = sharded_flat_param.numel() * self.rank
            end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
            self._init_shard_metadata(numel_padded, start_idx, end_idx)
        if self._use_orig_params:
            self._use_sharded_views()
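
    # Illustrative arithmetic (not from the original source): if the padded
    # unsharded flat parameter has 48 numel and `world_size=4`, each rank's
    # sharded flat parameter has 12 numel, so rank 2 covers unsharded indices
    # start_idx = 12 * 2 = 24 through end_idx = 12 * 3 - 1 = 35 (inclusive).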

    def _init_shard_metadata(
        self,
        numel_padded: int,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> None:
        """
        Initialize shard-related metadata for this rank's shard of the flat parameter.

        This includes ``_sharded_size``, ``_shard_param_infos``, and
        ``_shard_numel_padded``.

        Args:
            numel_padded (int): Numel padded for this rank's sharded flat
                parameter.
            unsharded_start_idx (int): Start index in the unsharded flat
                parameter assigned to this rank.
            unsharded_end_idx (int): End index (inclusive) in the unsharded
                flat parameter assigned to this rank.

        Precondition: ``self.flat_param`` 's data is the sharded flat
        parameter.
        """
        flat_param = self.flat_param
        flat_param._sharded_size = flat_param.size()  # type: ignore[attr-defined]
        sharded_flat_param_numel = flat_param.numel()  # includes `numel_padded`
        _p_assert(
            unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
        )
        _p_assert(
            numel_padded <= sharded_flat_param_numel,
            f"numel_padded: {numel_padded} "
            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
        )
        shard_param_infos = self._get_shard_metadata(
            unsharded_start_idx, unsharded_end_idx
        )
        assert (
            len(shard_param_infos) == flat_param._num_params
        ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
        flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
        flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]

    def _get_shard_metadata(
        self,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> Tuple[_ShardParamInfo, ...]:
        """
        Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).

        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of
        the unsharded flat parameter specifying the shard.
        """
        flat_param_offsets = self._get_flat_param_offsets()
        assert len(flat_param_offsets) == len(
            self.flat_param._numels_with_padding
        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
        shard_param_infos: List[_ShardParamInfo] = []
        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
        # into the unsharded flat parameter (inclusive) of the given parameter
        for i, (
            (unsharded_param_start_idx, unsharded_param_end_idx),
            is_padding,
        ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
            if is_padding:
                continue
            in_sharded_flat_param = (
                unsharded_start_idx <= unsharded_param_end_idx
                and unsharded_end_idx >= unsharded_param_start_idx
            )
            if not in_sharded_flat_param:
                shard_param_info = _ShardParamInfo(False, None, None, None, None)
            else:
                if unsharded_start_idx <= unsharded_param_start_idx:
                    # This branch can only happen once since the rank's
                    # unsharded start index can only intersect one parameter
                    intra_param_start_idx = 0
                    offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
                else:
                    intra_param_start_idx = (
                        unsharded_start_idx - unsharded_param_start_idx
                    )
                    offset_in_shard = 0
                assert (
                    offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
                ), (
                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
                )
                intra_param_end_idx = (
                    min(unsharded_param_end_idx, unsharded_end_idx)
                    - unsharded_param_start_idx
                )
                numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
                shard_param_info = _ShardParamInfo(
                    True,
                    offset_in_shard,
                    numel_in_shard,
                    intra_param_start_idx,
                    intra_param_end_idx,
                )
            shard_param_infos.append(shard_param_info)
        return tuple(shard_param_infos)

    @staticmethod
    def _get_unpadded_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.

        The returned value is a tuple of the shard of ``tensor`` without any
        padding and the numel to pad for that shard.

        If ``tensor`` is already flattened or may be viewed in the flattened
        shape (which is true in the expected usage), then this method does not
        allocate any new tensor memory.
        """
        chunks = torch.flatten(tensor).chunk(world_size)
        if len(chunks) < (rank + 1):
            # This rank gets an empty chunk fully padded with zeros since there
            # are not enough chunks across ranks
            chunk = chunks[0].new_empty(0)
        else:
            chunk = chunks[rank]
        numel_to_pad = chunks[0].numel() - chunk.numel()
        assert (
            numel_to_pad >= 0
        ), "Chunk's size should be at most the first chunk's size"
        return chunk, numel_to_pad
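
    # Worked example (illustrative numbers only): a tensor with 10 numel and
    # `world_size=4` chunks as sizes (3, 3, 3, 1), so rank 3 gets a 1-numel
    # chunk and `numel_to_pad = 3 - 1 = 2`; with 6 numel and `world_size=4`,
    # `chunk()` returns only 3 chunks of size 2, so rank 3 gets an empty chunk
    # and `numel_to_pad = 2`.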

    @staticmethod
    def _get_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard.

        This method allocates new memory (via :meth:`clone`) since the
        unsharded ``tensor`` may be deallocated after this method returns.
        """
        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        shard = chunk.clone()
        if numel_to_pad > 0:
            shard = F.pad(shard, [0, numel_to_pad])
        return shard, numel_to_pad

    @staticmethod
    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
        """
        Return the shape of ``tensor`` after sharding including padding.

        This requires ``tensor`` to have 1D shape and ensures that the returned
        shape is 1D.
        """
        assert len(tensor.shape) == 1, f"{tensor.shape}"
        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        unpadded_sharded_size = unpadded_sharded_tensor.size()
        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])

    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
        """
        Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).

        NOTE: The returned list includes elements for alignment padding.
        """
        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
        starts = [0] + cumulative_sum[:-1]
        ends = [end - 1 for end in cumulative_sum]  # inclusive
        param_offsets = list(zip(starts, ends))
        return param_offsets
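
    # Worked example (illustrative numbers only): with
    # `_numels_with_padding == (10, 6, 7)`, `accumulate` gives (10, 16, 23), so
    # the returned offsets are [(0, 9), (10, 15), (16, 22)], where the middle
    # entry corresponds to a padding element.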
  1021. @no_type_check
  1022. def shard_metadata(
  1023. self,
  1024. ) -> FlatParamShardMetadata:
  1025. """
  1026. Return the shard-related metadata specific to this rank's shard of the flat parameter.
  1027. NOTE: The returned tuple does not include elements for alignment
  1028. padding but does account for the padding.
  1029. """
  1030. fqns_list = []
  1031. shapes_list = []
  1032. numels_list = []
  1033. shard_param_offsets = []
  1034. for fqn, shape, numel, shard_param_info in zip(
  1035. self.flat_param._fqns,
  1036. self.flat_param._shapes,
  1037. self.flat_param._numels,
  1038. self.flat_param._shard_param_infos,
  1039. ):
  1040. if not shard_param_info.in_shard:
  1041. continue
  1042. fqns_list.append(fqn)
  1043. shapes_list.append(shape)
  1044. numels_list.append(numel)
  1045. shard_param_offsets.append(
  1046. (
  1047. shard_param_info.intra_param_start_idx,
  1048. shard_param_info.intra_param_end_idx,
  1049. )
  1050. )
  1051. return FlatParamShardMetadata(
  1052. tuple(fqns_list),
  1053. tuple(shapes_list),
  1054. tuple(numels_list),
  1055. shard_param_offsets,
  1056. )
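# NOTE: Illustrative sketch (not part of the original module). For a
# hypothetical rank whose shard covers a 4x4 ``weight`` fully and the first 5
# elements of an 8-element ``bias``, the returned metadata would look roughly
# like (fields in the same order as the constructor call above):
#
#     FlatParamShardMetadata(
#         ("layer.weight", "layer.bias"),         # FQNs present in this shard
#         (torch.Size([4, 4]), torch.Size([8])),  # original shapes
#         (16, 8),                                # original numels
#         [(0, 15), (0, 4)],                      # intra-parameter [start, end], inclusive
#     )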
  1057. @no_type_check
  1058. @torch.no_grad()
  1059. def init_flat_param_attributes(self) -> None:
  1060. """
1061. Initialize some attributes on the handle's ``FlatParameter``.
  1062. This should be called during lazy initialization since it requires the
1063. parameter to be on the compute device when not offloading to CPU, and we
  1064. want to give users the chance to move the parameter appropriately after
  1065. the FSDP constructor.
  1066. For each tensor attribute on the ``FlatParameter``, see the unshard and
  1067. reshard methods in this class for the allocation and free pattern.
  1068. """
  1069. flat_param = self.flat_param
  1070. if flat_param.dtype != self._orig_param_dtype:
  1071. # Entering this branch means that the user changed the parameter
  1072. # dtype after FSDP initialization, in which case we may need to
  1073. # refresh some saved dtype attributes (dtypes specified as a part
  1074. # of mixed precision take precedence).
  1075. if not self._low_prec_param_dtype_specified:
  1076. self._fwd_bwd_param_dtype = flat_param.dtype
  1077. # For `reduce_dtype`, require `param_dtype` was not specified since
  1078. # then we infer the `reduce_dtype` from the specified `param_dtype`
  1079. if (
  1080. not self._low_prec_reduce_dtype_specified
  1081. and not self._low_prec_param_dtype_specified
  1082. ):
  1083. self._reduce_dtype = flat_param.dtype
  1084. self._orig_param_dtype = flat_param.dtype
  1085. cpu_device = torch.device("cpu")
  1086. if self._offload_params:
  1087. _p_assert(
  1088. flat_param.device == cpu_device,
  1089. f"Expects the `FlatParameter` to be on CPU when parameter CPU "
  1090. f"offloading is enabled, not {flat_param.device}",
  1091. )
  1092. else:
  1093. self._check_on_compute_device(self.flat_param)
  1094. flat_param._local_shard = flat_param.data
  1095. if self._offload_params:
  1096. # Pin the memory for faster H2D transfer
  1097. flat_param._local_shard = flat_param._local_shard.pin_memory(
  1098. device=self.device
  1099. )
  1100. # Pre-allocate the sharded gradient on CPU to enable non-blocking
  1101. # D2H transfer during the backward pass
  1102. flat_param._cpu_grad = torch.zeros_like(
  1103. flat_param._local_shard, device=cpu_device
  1104. ).pin_memory(device=self.device)
  1105. if self._uses_param_mixed_precision:
  1106. # For parameter mixed precision, we maintain a low precision
  1107. # sharded tensor on the compute device to be all-gathered (for
  1108. # sharded strategies) or directly used (for `NO_SHARD`) for
  1109. # computation.
  1110. flat_param._mp_shard = torch.empty_like(
  1111. flat_param._local_shard,
  1112. device=self.device,
  1113. dtype=self._fwd_bwd_param_dtype,
  1114. )
  1115. _free_storage(flat_param._mp_shard)
  1116. if self.uses_sharded_strategy:
  1117. # We maintain a padded unsharded tensor that serves as the
  1118. # all-gather destination and owns the original parameter storages.
  1119. unsharded_param_dtype = (
  1120. self._fwd_bwd_param_dtype
  1121. if self._uses_param_mixed_precision
  1122. else flat_param.dtype
  1123. ) # use low precision if parameter mixed precision is enabled
  1124. padded_unsharded_numel = flat_param.numel() * self.world_size
  1125. flat_param._full_param_padded = torch.empty(
  1126. padded_unsharded_numel,
  1127. device=self.device,
  1128. dtype=unsharded_param_dtype,
  1129. )
  1130. flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
  1131. _free_storage(flat_param._full_param_padded)
  1132. if self._uses_param_mixed_precision:
  1133. # For parameter mixed precision, we maintain a full precision
  1134. # padded unsharded tensor for when we force full precision.
  1135. flat_param._full_prec_full_param_padded = torch.empty(
  1136. padded_unsharded_numel,
  1137. device=self.device,
  1138. dtype=flat_param.dtype, # full precision
  1139. )
  1140. _free_storage(flat_param._full_prec_full_param_padded)
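# NOTE: Illustrative sketch (not part of the original module). The size
# relationship set up above, assuming a hypothetical sharded numel of 8 per
# rank and ``world_size = 4``:
#
# >>> sharded_numel, world_size = 8, 4
# >>> sharded_numel * world_size  # padded unsharded numel, the all-gather destination size
# 32
#
# The unpadded unsharded flat parameter is later taken as a prefix view
# ``_full_param_padded[:unpadded_numel]`` with ``unpadded_numel <= 32``.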
  1141. ###################
  1142. # UNSHARD/RESHARD #
  1143. ###################
  1144. def pre_unshard(self) -> bool:
  1145. """
  1146. Return ``False`` if this is a no-op and ``True`` otherwise.
  1147. Postcondition: ``self.flat_param`` 's data is on the device for
  1148. communication and is what should be all-gathered. This means that it
  1149. matches the dtype of the expected unsharded parameter.
  1150. """
  1151. if (
  1152. self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
  1153. and self._skipped_use_sharded_views
  1154. ):
  1155. # Since this path imposes special semantics for the unsharded flat
  1156. # parameter (e.g. forcing full precision), use sharded views to
  1157. # reuse the existing logic for that special handling
  1158. self._use_sharded_views()
  1159. ret = False
  1160. if self._use_orig_params and not self._skip_writeback_check:
  1161. ret = self._writeback_orig_params()
  1162. if (
  1163. self.uses_sharded_strategy
  1164. and not self._offload_params
  1165. and not self.needs_unshard()
  1166. ):
  1167. pass # no-op
  1168. elif self._uses_param_mixed_precision and not self._force_full_precision:
  1169. self._use_low_precision_shard()
  1170. ret = True
  1171. elif self._offload_params and self.flat_param.device != self.device:
  1172. # NOTE: This creates a new tensor distinct from any attributes.
  1173. self.flat_param_to(self.device, non_blocking=True)
  1174. ret = True
  1175. self._check_on_compute_device(self.flat_param)
  1176. return ret
  1177. def _use_low_precision_shard(self):
  1178. """Allocate on the compute device and switch to using the low precision sharded flat parameter."""
  1179. self._check_low_precision_shard()
  1180. flat_param = self.flat_param
  1181. _alloc_storage(
  1182. flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
  1183. )
  1184. # `copy_()` implicitly casts to the low precision
  1185. flat_param._mp_shard.copy_( # type: ignore[attr-defined]
  1186. flat_param._local_shard.to( # type: ignore[attr-defined]
  1187. self.device, non_blocking=True
  1188. )
  1189. )
  1190. # Invariant: `_mp_shard` is always on the compute device.
  1191. flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
  1192. def unshard(self):
  1193. """
  1194. Run the unshard logic.
  1195. This includes all-gathering the flat parameter
  1196. and switching to using the unsharded flat parameter. If the handle does
  1197. not need unsharding, then this only switches to using the unsharded
  1198. flat parameter. For ``NO_SHARD``, this is a no-op.
  1199. If FSDP is in :meth:`summon_full_params` and the handle uses parameter
  1200. mixed precision, then the parameter is forced to full precision.
  1201. """
  1202. if not self.needs_unshard():
  1203. # Even when not needing an unshard, we should switch to using
  1204. # the unsharded flat parameter
  1205. unsharded_flat_param = (
  1206. self._get_padded_unsharded_flat_param()
  1207. if self.uses_sharded_strategy
  1208. else self.flat_param
  1209. )
  1210. self._use_unsharded_flat_param(unsharded_flat_param)
  1211. return
  1212. unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
  1213. padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
  1214. self._use_unsharded_flat_param(padded_unsharded_flat_param)
  1215. def needs_unshard(self) -> bool:
  1216. """Return if the handle's flat parameter needs to be unsharded."""
  1217. if not self.uses_sharded_strategy:
  1218. return False
  1219. unsharded_flat_param = self._get_padded_unsharded_flat_param()
  1220. already_unsharded = _same_storage_size(
  1221. unsharded_flat_param, unsharded_flat_param.numel()
  1222. )
  1223. return not already_unsharded
  1224. def _alloc_padded_unsharded_flat_param(self):
  1225. """
  1226. Allocate the *padded* unsharded flat parameter.
  1227. The unpadded unsharded
  1228. flat parameter is always a view into the padded one. This padded
  1229. parameter is saved to a different attribute on the ``FlatParameter``
1230. depending on whether we force full precision.
  1231. """
  1232. self._check_sharded_strategy()
  1233. flat_param = self.flat_param
  1234. unsharded_flat_param = self._get_padded_unsharded_flat_param()
  1235. self._check_storage_freed(unsharded_flat_param)
  1236. _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
  1237. return unsharded_flat_param
  1238. def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
  1239. """
  1240. Return a reference to the padded unsharded flat parameter depending on the calling context.
  1241. This should only be called if using a sharded strategy.
  1242. """
  1243. self._check_sharded_strategy()
  1244. flat_param = self.flat_param
  1245. if self._force_full_precision and self._uses_param_mixed_precision:
  1246. # When parameter mixed precision is enabled, we use a different
  1247. # tensor as the all-gather destination to preserve the invariant
  1248. # that `_full_param_padded` is in the low precision
  1249. unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
  1250. _p_assert(
  1251. unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
  1252. f"Expects full precision but got {self._fwd_bwd_param_dtype}",
  1253. )
  1254. # For no-reshard-after-forward strategies, `_full_param_padded` may
  1255. # still be allocated from a previous forward. As we are forcing
  1256. # full precision here, the full-precision unsharded copy may be
  1257. # modified, invalidating the existing low-precision unsharded copy,
1258. # so we free the low-precision copy here to force a fresh all-gather
1259. # in the next forward/backward computation, persisting the modifications.
  1260. if flat_param._full_param_padded.untyped_storage().size() > 0:
  1261. _free_storage(flat_param._full_param_padded)
  1262. else:
  1263. unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
  1264. return unsharded_flat_param
  1265. def _all_gather_flat_param(
  1266. self,
  1267. padded_unsharded_flat_param: Tensor,
  1268. ) -> Tensor:
  1269. """
  1270. All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
  1271. Then switch to use the all-gathered tensor.
  1272. """
  1273. _p_assert(
  1274. hasattr(self, "process_group") and hasattr(self, "world_size"),
  1275. "Expects a process group and world size to have been set via `shard()`",
  1276. )
  1277. sharded_flat_param = self.flat_param.data
  1278. expected_numel = sharded_flat_param.numel() * self.world_size
  1279. _p_assert(
  1280. padded_unsharded_flat_param.numel() == expected_numel,
  1281. f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
  1282. )
  1283. pg = (
  1284. self._fake_process_group
  1285. if self._use_fake_all_gather
  1286. else self.process_group
  1287. )
  1288. # HACK this should be handled by C10D
  1289. if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
  1290. tensor_list = list(
  1291. torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
  1292. )
  1293. dist.all_gather(tensor_list, sharded_flat_param, group=pg)
  1294. else:
  1295. dist.all_gather_into_tensor(
  1296. padded_unsharded_flat_param,
  1297. sharded_flat_param,
  1298. pg,
  1299. )
  1300. if self._offload_params:
  1301. # In case of offloading, `flat_param.data` (i.e. sharded param) is
  1302. # created on the pre-unshard stream. We need to hand it over to the
  1303. # unshard stream for all-gather
  1304. _no_dispatch_record_stream(
  1305. sharded_flat_param,
  1306. self._device_handle.current_stream(), # unshard_stream
  1307. )
  1308. return padded_unsharded_flat_param
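# NOTE: Illustrative sketch (not part of the original module). The shape
# contract of the all-gather above, assuming a hypothetical initialized
# process group with ``world_size = 4`` and a sharded numel of 8:
#
#     sharded_flat_param.numel()          == 8
#     padded_unsharded_flat_param.numel() == 8 * 4 == 32
#     # `dist.all_gather_into_tensor` writes rank r's shard into
#     # padded_unsharded_flat_param[8 * r : 8 * (r + 1)]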
  1309. def _use_unsharded_flat_param(
  1310. self,
  1311. padded_unsharded_flat_param: torch.Tensor,
  1312. ) -> None:
  1313. """
  1314. Switch to use the *unpadded* unsharded flat parameter.
  1315. This is a view into the *padded* unsharded flat parameter.
  1316. """
  1317. unsharded_size = self.flat_param._unpadded_unsharded_size
  1318. flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
  1319. # slicing [:] is not visible to autograd because of .data
  1320. self.flat_param.data = flat_param_part
  1321. in_forward = self._training_state == HandleTrainingState.FORWARD
  1322. in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
  1323. if self._use_orig_params:
  1324. if self._skipped_use_sharded_views and in_pre_backward:
  1325. # This call corresponds to the complementary pre-backward
  1326. # `_use_unsharded_views()` to the skipped pre-forward
  1327. # `_use_sharded_views()`, so we should skip this one too.
  1328. return
  1329. # We use `Tensor` views in the forward so that they are tracked by
  1330. # autograd. We use them in the pre-backward as well to support
  1331. # reentrant activation checkpointing, which needs the views to be
  1332. # tracked by autograd in the backward pass's recomputed forward.
  1333. self._use_unsharded_views(
  1334. as_params=(not in_forward and not in_pre_backward)
  1335. )
  1336. elif in_forward:
  1337. self._use_unsharded_views(as_params=False)
  1338. def post_unshard(self):
  1339. """
  1340. Run the post-unshard logic.
  1341. This includes freeing the low precision shard if needed.
  1342. """
  1343. if self._uses_param_mixed_precision and self.uses_sharded_strategy:
  1344. self._free_low_precision_sharded_param()
  1345. self._check_on_compute_device(self.flat_param)
  1346. def _free_low_precision_sharded_param(self):
  1347. """Frees the low precision sharded flat parameter."""
  1348. self._check_low_precision_shard()
  1349. # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
  1350. # unshard stream for sharded strategies, and consumed in both the
  1351. # unshard and default streams for `NO_SHARD`. For sharded strategies,
  1352. # the current stream here is the unshard stream, and for `NO_SHARD`,
  1353. # it is the default stream. For `NO_SHARD`, only recording for the
  1354. # default stream suffices since the default stream waits for the
  1355. # unshard stream.
  1356. _no_dispatch_record_stream(
  1357. self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
  1358. )
  1359. _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
  1360. @torch.no_grad()
  1361. def unshard_grad(self):
  1362. """
  1363. Unshard the handle's ``FlatParameter``'s gradient.
1364. If all ranks have a
1365. ``None`` gradient, then all original parameters' gradients will be ``None`` as well. This
  1366. method performs an all-reduce and an all-gather. The additional
  1367. all-reduce is tolerable since this method is not meant to be used on
  1368. the computation critical path.
  1369. Postcondition: ``_saved_grad_shard`` is defined and contains the value
  1370. to set ``flat_param.grad`` after gradients are resharded.
  1371. """
  1372. if not self.uses_sharded_strategy:
  1373. self._use_unsharded_grad_views()
  1374. return
  1375. flat_param = self.flat_param
  1376. self._check_unsharded(flat_param)
  1377. # Check if all ranks have a `None` gradient
  1378. num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
  1379. num_grad_none[0] = flat_param.grad is None
  1380. dist.all_reduce(num_grad_none, group=self.process_group)
  1381. if num_grad_none[0] == self.world_size:
  1382. flat_param._saved_grad_shard = None # type: ignore[assignment]
  1383. self._use_unsharded_grad_views()
  1384. return
  1385. if flat_param.grad is None:
  1386. # In the case that only some ranks have `None` gradient, we use
  1387. # zeros to approximate as a best effort attempt
  1388. if self._debug_level == dist.DebugLevel.INFO:
  1389. warnings.warn(
  1390. f"[Rank {self.rank}] Only some but not all ranks have a "
  1391. "`None` `FlatParameter` gradient, so FSDP is using zeros to "
  1392. "approximate those ranks' sharded gradients being `None`"
  1393. )
  1394. flat_param._saved_grad_shard = None # type: ignore[assignment]
  1395. sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined]
  1396. else:
  1397. self._check_sharded(flat_param.grad)
  1398. flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined]
  1399. sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
  1400. padded_unsharded_grad = torch.empty(
  1401. flat_param._padded_unsharded_size, # type: ignore[attr-defined]
  1402. device=self.device,
  1403. dtype=sharded_grad.dtype,
  1404. )
  1405. dist.all_gather_into_tensor(
  1406. padded_unsharded_grad, sharded_grad, self.process_group
  1407. )
  1408. unsharded_size = self.flat_param._unpadded_unsharded_size
  1409. flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
  1410. unsharded_size
  1411. )
  1412. self._use_unsharded_grad_views()
  1413. def reshard_grad(self):
  1414. if self._use_orig_params:
  1415. self._use_sharded_grad_views()
  1416. if not self.uses_sharded_strategy:
  1417. return
  1418. self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
  1419. delattr(self.flat_param, "_saved_grad_shard")
  1420. def prepare_gradient_for_backward(self):
  1421. """
  1422. Prepare the gradient for the backward computation.
  1423. This is done by saving and clearing any existing sharded gradient
  1424. in ``.grad`` to enable computing a new unsharded gradient.
  1425. """
  1426. _p_assert(
  1427. self._training_state
  1428. in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
  1429. "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
  1430. )
  1431. flat_param = self.flat_param
  1432. if flat_param.grad is not None and (
  1433. flat_param.grad.size() != flat_param._unpadded_unsharded_size
  1434. or flat_param.grad.device != flat_param.device # grad on CPU
  1435. ):
  1436. self._check_on_compute_device(self.flat_param)
  1437. grad_offloaded = flat_param.grad.device != self.device
  1438. _p_assert(
  1439. not grad_offloaded or self._offload_params,
  1440. f"Expects the sharded gradient to be on {self.device} "
  1441. f"but got {flat_param.grad.device}",
  1442. )
  1443. prev_iter_synced_gradients = (
  1444. flat_param.grad.size()
  1445. == flat_param._local_shard.size() # type: ignore[attr-defined]
  1446. )
  1447. if prev_iter_synced_gradients:
  1448. # TODO (awgu): Gradient accumulation outside `no_sync()`
  1449. # does not work with CPU offloading. The issue should be
  1450. # that, in the post-backward hook, we cannot do an addition
  1451. # between a CPU tensor (the existing sharded gradient) and
  1452. # a GPU tensor (the new sharded gradient).
  1453. if not grad_offloaded:
  1454. flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
  1455. sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
  1456. else:
  1457. _p_assert(
  1458. hasattr(flat_param, "_cpu_grad"),
  1459. "`_cpu_grad` should be defined if the gradient is on CPU",
  1460. )
  1461. sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined]
  1462. # If user specified to keep the gradient in low precision, then
  1463. # the gradient may still be of the low precision dtype if the
  1464. # user did not set the gradient to `None` after the previous
  1465. # backward, in which case FSDP should cast back to the full
  1466. # precision dtype so that FSDP can accumulate in that dtype in
  1467. # the post-backward hook and assign to `.grad` in that dtype in
  1468. # the post-backward callback.
  1469. local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined]
  1470. if (
  1471. self._keep_low_precision_grads
  1472. and sharded_grad.dtype != local_shard_dtype
  1473. ):
  1474. sharded_grad.data = sharded_grad.to(local_shard_dtype)
  1475. else:
  1476. padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
  1477. _p_assert(
  1478. flat_param.grad.size() == padded_unsharded_size,
  1479. "Expects `.grad` to be the unsharded gradient in "
  1480. f"`no_sync()` with size {padded_unsharded_size} "
  1481. f"but got size {flat_param.grad.size()}",
  1482. )
  1483. flat_param.grad = None
  1484. def prepare_gradient_for_optim(self):
  1485. """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
  1486. def cast_grad_to_param_dtype_if_needed(flat_param):
  1487. # TODO (rohan-varma): test for full precision with keep_low_precision_grads
  1488. if not self._force_full_precision and self._keep_low_precision_grads:
  1489. _p_assert(flat_param.grad is not None, "Unexpected None grad!")
  1490. if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
  1491. flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
  1492. if self._use_orig_params:
  1493. self._use_sharded_grad_views()
  1494. flat_param = self.flat_param
  1495. # TODO (awgu): We should replace these conditional checks to encode
  1496. # the logical intention more directly.
  1497. if hasattr(flat_param, "_cpu_grad"):
  1498. # NOTE: This branch includes `NO_SHARD`.
  1499. self._check_sharded(flat_param)
  1500. self._check_on_cpu(flat_param)
  1501. flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined]
  1502. cast_grad_to_param_dtype_if_needed(flat_param)
  1503. elif hasattr(flat_param, "_saved_grad_shard"):
  1504. self._check_sharded(flat_param)
  1505. self._check_on_compute_device(flat_param)
  1506. if flat_param._saved_grad_shard is not None:
  1507. self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
  1508. # If no sharded gradient was computed this iteration, then there is
  1509. # no need to forward `_saved_grad_shard` to `grad`
  1510. if flat_param._post_backward_called: # type: ignore[attr-defined]
  1511. flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
  1512. if flat_param.grad is not None:
  1513. cast_grad_to_param_dtype_if_needed(flat_param)
  1514. else:
  1515. _p_assert(
  1516. not self.uses_sharded_strategy
  1517. or not flat_param._post_backward_called, # type: ignore[attr-defined]
  1518. "All sharded parameters that received a gradient in the "
  1519. "post-backward should use `_saved_grad_shard`",
  1520. )
  1521. # Delete `_saved_grad_shard` since its existence indicates a previous
  1522. # gradient to accumulate with in the post-backward hook
  1523. if hasattr(flat_param, "_saved_grad_shard"):
  1524. delattr(flat_param, "_saved_grad_shard")
  1525. @contextlib.contextmanager
  1526. def to_cpu(self):
  1527. """
1528. Move the unpadded unsharded flat parameter to CPU while in the context and move it back to the previous device upon exit.
  1529. For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
  1530. since (1) there is no reason to include the padding in the copy and (2)
  1531. there is no use case for the sharded flat parameter.
  1532. Precondition: ``self.flat_param`` 's data is the unpadded unsharded
  1533. flat parameter on the compute device, and the handle uses a sharded
  1534. strategy.
  1535. Postcondition: Same as the precondition.
  1536. """
  1537. self._check_sharded_strategy()
  1538. _p_assert(
  1539. self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
  1540. f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
  1541. )
  1542. self._check_on_compute_device(self.flat_param)
  1543. # Check that the unpadded unsharded flat parameter is a view into the
  1544. # padded unsharded flat parameter as expected
  1545. # NOTE: This check is not strictly needed for correctness but is a
  1546. # useful sanity check since the tensor should only be used internally.
  1547. _p_assert(
  1548. _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
  1549. "Expects the unpadded parameter to be a view into the padded parameter",
  1550. )
  1551. self.flat_param_to(torch.device("cpu"))
  1552. self._free_unsharded_flat_param()
  1553. try:
  1554. yield
  1555. finally:
  1556. _p_assert(
  1557. self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
  1558. f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
  1559. )
  1560. padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
  1561. # Copy from CPU to the compute device
  1562. padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
  1563. self.flat_param
  1564. )
  1565. self._use_unsharded_flat_param(padded_unsharded_flat_param)
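# NOTE: Illustrative usage sketch (not part of the original module), assuming
# ``handle`` is a ``FlatParamHandle`` whose flat parameter is currently the
# unpadded unsharded flat parameter on the compute device:
#
#     with handle.to_cpu():
#         cpu_copy = handle.flat_param.detach().clone()  # now on CPU
#     # On exit, the data is copied back to the compute device into a freshly
#     # allocated padded unsharded flat parameter.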
  1566. def reshard(self, free_unsharded_flat_param: bool):
  1567. """
  1568. Run the reshard logic.
  1569. This includes freeing the unsharded flat
  1570. parameter if ``free_unsharded_flat_param`` and switching to using the
  1571. sharded flat parameter. Note that this also implicitly offloads
  1572. the sharded flat parameter (if CPU offload is enabled) by pointing
  1573. it to the ``_local_shard`` attribute which resides on CPU.
  1574. """
  1575. # Switch to the sharded `FlatParameter` before freeing to prevent
  1576. # "use-after-free"-type bugs with external profiling tools, where for
  1577. # `use_orig_params=True`, the `param` does not point to valid memory
  1578. # when setting `param.data = ...` in `_use_sharded_views()`.
  1579. self._use_sharded_flat_param()
  1580. if free_unsharded_flat_param:
  1581. self._free_unsharded_flat_param()
  1582. def post_reshard(self):
  1583. """
  1584. Run the post-reshard logic.
  1585. This includes freeing any memory that
  1586. can now be freed given that the ``FlatParameter`` points to the full
  1587. precision sharded flat parameter.
  1588. Precondition: ``self.flat_param`` 's data points to the full precision
  1589. sharded flat parameter.
  1590. """
  1591. # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
  1592. # is also the low precision *unsharded* flat parameter. Hence, we delay
  1593. # the free until the reshard.
  1594. if (
  1595. self._uses_param_mixed_precision
  1596. and not self.uses_sharded_strategy
  1597. and not self._force_full_precision # did not use the low precision shard
  1598. ):
  1599. self._free_low_precision_sharded_param()
  1600. def _free_unsharded_flat_param(self):
  1601. """
  1602. Free the padded unsharded flat parameter. We allow this
1603. function to be called even when the storage is not allocated.
  1604. The tensor to free depends
  1605. on the calling context since the unshard may have forced full
  1606. precision, in which case a different tensor is used.
  1607. """
  1608. self._check_sharded_strategy()
  1609. unsharded_flat_param = self._get_padded_unsharded_flat_param()
  1610. self._check_on_compute_device(unsharded_flat_param)
  1611. # Do not free the memory until all ops in the current stream finish
  1612. _no_dispatch_record_stream(
  1613. unsharded_flat_param, self._device_handle.current_stream()
  1614. )
  1615. _free_storage(unsharded_flat_param)
  1616. def _use_sharded_flat_param(self) -> None:
  1617. """Switches to using the sharded flat parameter."""
  1618. flat_param = self.flat_param
  1619. if self._use_orig_params:
  1620. in_forward = self._training_state == HandleTrainingState.FORWARD
  1621. skip_use_sharded_views = (
  1622. torch.is_grad_enabled()
  1623. and in_forward
  1624. and self._sharding_strategy
  1625. in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
  1626. )
  1627. # Only incur the extra `.data` call if needed
  1628. if skip_use_sharded_views:
  1629. unsharded_flat_param = flat_param.data
  1630. if self._offload_params:
  1631. device = flat_param._local_shard.device # type: ignore[attr-defined]
  1632. _p_assert(
  1633. device == torch.device("cpu"),
  1634. f"Expects the local shard to be on CPU but got {device}",
  1635. )
  1636. flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
  1637. if self._use_orig_params:
  1638. if skip_use_sharded_views: # type: ignore[possibly-undefined]
  1639. self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
  1640. else:
  1641. self._use_sharded_views()
  1642. # For the post-forward reshard, we may try to use sharded gradient
  1643. # views (or unsharded gradient views if a gradient was accumulated
  1644. # in `no_sync()`), but for the post-backward reshard, we delay the
  1645. # call to after the reduce-scatter.
  1646. if (
  1647. in_forward # type: ignore[possibly-undefined]
  1648. # Skip using gradient views if skipped using sharded views
  1649. # since exposing unsharded parameters with sharded gradients
  1650. # may be confusing to the user
  1651. and not self._skipped_use_sharded_views
  1652. ):
  1653. # TODO: Change `_unpadded_unsharded_size` if we change the
  1654. # gradient to be computed directly with padding.
  1655. accumulated_grad_in_no_sync = (
  1656. flat_param.grad is not None
  1657. and self.uses_sharded_strategy
  1658. and flat_param.grad.shape == flat_param._unpadded_unsharded_size
  1659. )
  1660. if accumulated_grad_in_no_sync:
  1661. self._use_unsharded_grad_views()
  1662. else:
  1663. self._use_sharded_grad_views()
  1664. #########
  1665. # VIEWS #
  1666. #########
  1667. @no_type_check
  1668. def _get_unflat_views_unaligned(
  1669. self,
  1670. tensor: Optional[torch.Tensor] = None,
  1671. ) -> Iterator[Tensor]:
  1672. """
  1673. Return unflattened ``Tensor`` views into ``tensor``.
1674. If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is based
  1675. on ``flat_param`` 's metadata.
  1676. Examples for ``tensor`` include ``flat_param.grad`` or unsharded
  1677. tensor optimizer state.
  1678. """
  1679. flat_param = self.flat_param
  1680. if tensor is None:
  1681. tensor = flat_param
  1682. views = (
  1683. _ext_post_unflatten_transform(
  1684. subtensor.view(shape),
  1685. param_extension,
  1686. self._fsdp_extension,
  1687. )
  1688. for (subtensor, shape, param_extension) in zip(
  1689. torch.split(tensor, flat_param._numels, dim=0),
  1690. flat_param._shapes,
  1691. flat_param._param_extensions,
  1692. )
  1693. )
  1694. return views
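# NOTE: Illustrative sketch (not part of the original module). It shows the
# split-and-view mechanics the generator above relies on, for hypothetical
# numels and shapes:
#
# >>> import torch
# >>> flat = torch.arange(10.)
# >>> numels, shapes = [6, 4], (torch.Size([2, 3]), torch.Size([4]))
# >>> [sub.view(shape).shape
# ...  for sub, shape in zip(torch.split(flat, numels, dim=0), shapes)]
# [torch.Size([2, 3]), torch.Size([4])]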
  1695. @no_type_check
  1696. def _get_unflat_views_aligned(
  1697. self,
  1698. tensor: Optional[Tensor] = None,
  1699. ) -> List[Tensor]:
  1700. """
  1701. Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
  1702. This method has the same contract as :meth:`_get_unflat_views_unaligned`
  1703. except it checks for ``None`` placeholders representing padding for
  1704. alignment, which may incur slightly more CPU overhead.
  1705. """
  1706. flat_param = self.flat_param
  1707. if tensor is None:
  1708. tensor = flat_param
  1709. splits: List[Tensor] = torch.split(
  1710. tensor, flat_param._numels_with_padding, dim=0
  1711. )
  1712. idx = 0
  1713. views: List[Tensor] = []
  1714. for split, is_padding in zip(splits, flat_param._is_padding_mask):
  1715. if is_padding:
  1716. continue
  1717. views.append(
  1718. _ext_post_unflatten_transform(
  1719. split.view(flat_param._shapes[idx]),
  1720. flat_param._param_extensions[idx],
  1721. self._fsdp_extension,
  1722. )
  1723. )
  1724. idx += 1
  1725. return views
  1726. @no_type_check
  1727. @torch.enable_grad()
  1728. def _use_unsharded_views(self, as_params: bool) -> None:
  1729. """
  1730. Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
  1731. Args:
  1732. as_params (bool): If ``True``, then registers the original
  1733. parameters as ``nn.Parameter`` s; if ``False``, then registers
  1734. the original parameters only as ``Tensor`` s. ``False`` should
  1735. be used during forward/backward computation and when hiding the
  1736. original parameters from :meth:`nn.Module.named_parameters`.
  1737. Note:
1738. When prefetching for the next forward, the current forward may be
1739. annotated with ``@torch.no_grad()``.
1740. ``@torch.enable_grad()`` ensures a non-empty ``view.grad_fn``;
1741. otherwise, ``_post_backward_hook`` will not get called.
  1742. """
  1743. flat_param = self.flat_param
  1744. self._check_unsharded(flat_param)
  1745. views = self._get_unflat_views()
  1746. from torch.distributed._tensor import DTensor
  1747. for i, (view, (param_name, module, _)) in enumerate(
  1748. zip(views, flat_param._param_infos)
  1749. ):
  1750. if self._use_orig_params and as_params:
  1751. if type(view) is DTensor:
  1752. # A `DTensor` `view` is not compatible with assigning
  1753. # `param.data = view`, so we cannot preserve the parameter
  1754. # variable.
  1755. self._setattr_param(
  1756. module,
  1757. param_name,
  1758. nn.Parameter(view, requires_grad=flat_param.requires_grad),
  1759. )
  1760. continue
  1761. param = self.flat_param._params[i]
  1762. self._setattr_param(module, param_name, param)
  1763. param.data = view
  1764. elif as_params:
  1765. self._setattr_param(
  1766. module,
  1767. param_name,
  1768. nn.Parameter(view, requires_grad=flat_param.requires_grad),
  1769. )
  1770. else: # `as_params=False`
  1771. param_var: Tensor = view
  1772. if self._use_orig_params:
  1773. if self._training_state == HandleTrainingState.FORWARD:
  1774. # Save the `Tensor` for the pre-backward
  1775. self.flat_param._tensors[i] = view # save for pre-backward
  1776. elif self._training_state == HandleTrainingState.BACKWARD_PRE:
  1777. # Use the saved `Tensor` variable from the forward to
  1778. # preserve the autograd graph so that the post-backward
  1779. # hook fires (e.g. for reentrant AC)
  1780. tensor = self.flat_param._tensors[i]
  1781. tensor.data = view
  1782. param_var = tensor
  1783. self._setattr_tensor(module, param_name, param_var)
  1784. if (
  1785. self._use_orig_params
  1786. and self._training_state == HandleTrainingState.FORWARD
  1787. ):
  1788. module._parameters[param_name] = param_var
  1789. for i, (
  1790. param_name,
  1791. module,
  1792. _,
  1793. prim_param_name,
  1794. prim_module,
  1795. _,
  1796. ) in enumerate(self.flat_param._shared_param_infos):
  1797. prim_param: Union[Tensor, nn.Parameter] = getattr(
  1798. prim_module, prim_param_name
  1799. )
  1800. _p_assert(
  1801. not as_params or isinstance(prim_param, nn.Parameter),
  1802. f"as_params={as_params} type(prim_param)={type(prim_param)}",
  1803. )
  1804. if self._use_orig_params and as_params:
  1805. shared_param = self.flat_param._shared_params[i]
  1806. self._setattr_param(module, param_name, shared_param)
  1807. shared_param.data = prim_param
  1808. elif as_params:
  1809. self._setattr_param(module, param_name, prim_param)
  1810. else:
  1811. self._setattr_tensor(module, param_name, prim_param)
  1812. if (
  1813. self._use_orig_params
  1814. and self._training_state == HandleTrainingState.FORWARD
  1815. ):
  1816. module._parameters[param_name] = prim_param
  1817. @no_type_check
  1818. def _use_unsharded_grad_views(self) -> None:
  1819. """
  1820. Unflatten the unsharded flat parameter's gradient.
  1821. The original parameter variables' gradients are set to be views into
  1822. the unsharded flat parameter's gradient.
  1823. """
  1824. # Expects the gradient to be in `flat_param.grad`
  1825. if self.flat_param.grad is None:
  1826. for param in chain(self.flat_param._params, self.flat_param._shared_params):
  1827. param.grad = None
  1828. return
  1829. self._check_unsharded(self.flat_param.grad)
  1830. views = self._get_unflat_views(self.flat_param.grad)
  1831. for i, (view, (param_name, module, _)) in enumerate(
  1832. zip(views, self.flat_param._param_infos)
  1833. ):
  1834. _p_assert(
  1835. hasattr(module, param_name),
  1836. f"{self.flat_param._fqns[i]} is missing",
  1837. )
  1838. param = getattr(module, param_name)
  1839. if (
  1840. param.shape != view.shape
  1841. or param.dtype != view.dtype
  1842. or param.device != view.device
  1843. ):
  1844. # NOTE: This is a hack using `.data` to side step the check
  1845. # that parameter/gradient sizes/dtypes/devices match. From
  1846. # calling `reshard()`, `param` has the sharded size, has the
  1847. # full precision dtype, and if CPU offloading is enabled, is on
  1848. # CPU. Thus, one or more of the following cases can hold when
  1849. # in `no_sync()`, where `view` is the original parameter's
  1850. # gradient:
  1851. # 1. `view` can have the unsharded size.
  1852. # 2. `view` can have the parameter low precision dtype.
  1853. # 3. `view` can be on GPU.
  1854. if param.grad is None:
  1855. param.grad = torch.empty_like(param)
  1856. param.grad.data = view
  1857. else:
  1858. param.grad = view
  1859. for i, (
  1860. param_name,
  1861. module,
  1862. module_name,
  1863. prim_param_name,
  1864. prim_module,
  1865. _,
  1866. ) in enumerate(self.flat_param._shared_param_infos):
  1867. _p_assert(
  1868. hasattr(module, param_name),
  1869. f"{module_name + '.' + param_name if module_name else param_name} is missing",
  1870. ) # did not save FQN info in `_shared_param_infos`
  1871. param = getattr(module, param_name)
  1872. prim_param = getattr(prim_module, prim_param_name)
  1873. if (
  1874. param.shape != prim_param.grad.shape
  1875. or param.dtype != prim_param.grad.dtype
  1876. or param.device != prim_param.grad.device
  1877. ):
  1878. # NOTE: This is the same hack to use `.data` to side step the
  1879. # size check.
  1880. if param.grad is None:
  1881. param.grad = torch.empty_like(param)
  1882. param.grad.data = prim_param.grad
  1883. else:
  1884. param.grad = prim_param.grad
  1885. @contextlib.contextmanager
  1886. def unflatten_as_params(self) -> Generator:
  1887. """
  1888. Unflatten the original parameters.
  1889. The function assumes that the flat parameter is unsharded. When in the context,
  1890. unflattens the original parameters as ``nn.Parameter`` views into the
  1891. flat parameter, and after the context, restores the original parameters
  1892. as ``Tensor`` views into the flat parameter.
  1893. """
  1894. self._use_unsharded_views(as_params=True)
  1895. try:
  1896. yield
  1897. finally:
  1898. self._use_unsharded_views(as_params=False)
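# NOTE: Illustrative usage sketch (not part of the original module), assuming
# ``handle`` is a ``FlatParamHandle`` whose flat parameter is currently
# unsharded:
#
#     with handle.unflatten_as_params():
#         ...  # original parameters appear as `nn.Parameter` views here
#     # Afterwards, they are restored as plain `Tensor` views.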
  1899. @no_type_check
  1900. @torch.no_grad()
  1901. def _use_sharded_views(self) -> None:
  1902. """
  1903. Set the original parameter variables' data to be flattened views into the sharded flat parameter.
  1904. The views are kept as flattened to simplify the case where a parameter
  1905. is sharded across ranks. Parameters whose data is not present in the
  1906. sharded flat parameter have their data set to a size-0 empty tensor. We
1907. do not delete them in order to preserve expected behaviors such as model
  1908. printability. Parameters whose data is present must preserve their
  1909. variables to be passable to an optimizer.
  1910. """
  1911. self._unsharded_flat_param_for_skipped_views = None
  1912. if not self.uses_sharded_strategy:
  1913. # For `NO_SHARD`, use the *unflattened* unsharded views since we
  1914. # have the unsharded parameter
  1915. self._use_unsharded_views(as_params=True)
  1916. return
  1917. flat_param = self.flat_param
  1918. self._check_sharded(flat_param)
  1919. # Construct once and reuse for all parameters not in the local shard
  1920. size_0_empty_tensor = torch.empty(
  1921. 0,
  1922. dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
  1923. device=self.flat_param.device,
  1924. requires_grad=False,
  1925. )
  1926. for param, shard_param_info, (param_name, module, _) in zip(
  1927. flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
  1928. ):
  1929. self._setattr_param(module, param_name, param)
  1930. if not shard_param_info.in_shard:
  1931. # Allow the original data to be freed via garbage collection
  1932. param.data = size_0_empty_tensor
  1933. else:
  1934. offset = shard_param_info.offset_in_shard
  1935. numel_in_shard = shard_param_info.numel_in_shard
  1936. param.data = flat_param[offset : offset + numel_in_shard]
  1937. assert self.flat_param._shared_params is not None
  1938. for i, (
  1939. param,
  1940. (param_name, module, _, prim_param_name, prim_module, _),
  1941. ) in enumerate(
  1942. zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
  1943. ):
  1944. self._setattr_param(module, param_name, param)
  1945. prim_param = getattr(prim_module, prim_param_name)
  1946. param.data = prim_param # could be both empty and non-empty
  1947. if self._training_state == HandleTrainingState.BACKWARD_POST:
  1948. # Clear the saved `Tensor`s since they are unneeded now
  1949. for i in range(len(self.flat_param._tensors)):
  1950. self.flat_param._tensors[i] = None
  1951. @no_type_check
  1952. @torch.no_grad()
  1953. def _use_sharded_grad_views(self) -> None:
  1954. """
  1955. Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
  1956. This is a no-op if there is no gradient.
  1957. Parameters whose data is not present in the sharded flat parameter and
  1958. parameters with ``requires_grad=False`` have their gradients set to
  1959. ``None``. Since the gradient variables do not need to be preserved,
  1960. this method does not manipulate existing ``Tensor`` data directly and
  1961. creates new ``Tensor`` variables instead.
  1962. """
  1963. flat_param = self.flat_param
  1964. self._check_sharded(flat_param)
  1965. grad = self.sharded_grad
  1966. if grad is None:
  1967. for param in chain(flat_param._params, flat_param._shared_params):
  1968. param.grad = None
  1969. return
  1970. self._check_sharded(grad)
  1971. for param, shard_param_info, is_grad_none in zip(
  1972. flat_param._params,
  1973. flat_param._shard_param_infos,
  1974. flat_param._is_grad_none_mask,
  1975. ):
  1976. if not shard_param_info.in_shard:
  1977. param.grad = None
  1978. else:
  1979. numel_in_shard = shard_param_info.numel_in_shard
  1980. if param.requires_grad and not is_grad_none:
  1981. offset = shard_param_info.offset_in_shard
  1982. if self._keep_low_precision_grads or param.dtype != grad.dtype:
  1983. # NOTE: This is a hack using `.data` to side step the
  1984. # check that parameter/gradient dtypes match. Here,
  1985. # `param` has full precision; `grad` has low precision.
  1986. if param.grad is None:
  1987. # `.grad` must have the same shape as `param`
  1988. param.grad = torch.empty_like(param)
  1989. param.grad.data = grad[
  1990. offset : offset + numel_in_shard
  1991. ].reshape(param.shape)
  1992. else:
  1993. param.grad = grad[offset : offset + numel_in_shard].reshape(
  1994. param.shape
  1995. )
  1996. else:
  1997. param.grad = None
  1998. assert flat_param._shared_params is not None
  1999. for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
  2000. zip(flat_param._shared_params, flat_param._shared_param_infos)
  2001. ):
  2002. in_sharded_flat_param = hasattr(prim_module, prim_param_name)
  2003. if in_sharded_flat_param and param.requires_grad:
  2004. prim_param = getattr(prim_module, prim_param_name)
  2005. param.grad = prim_param.grad # share the same reference
  2006. else:
  2007. param.grad = None
  2008. @no_type_check
  2009. @torch.no_grad()
  2010. def _writeback_orig_params(self) -> bool:
  2011. """
  2012. Write back any parameters that changed storage to the handle's ``FlatParameter``.
  2013. Iterates over the original parameters and writes back any parameters
  2014. that changed storages (due to a non-inplace operator) to the handle's
2015. ``FlatParameter``. This method preserves the ``FlatParameter`` 's
  2016. device even if an original parameter's device changes.
  2017. Raises:
  2018. RuntimeError: If an original parameter or gradient changes storages
  2019. but no longer has the expected flattened shape.
  2020. Returns: ``True`` if some writeback happened, and ``False`` otherwise.
  2021. """
  2022. if (
  2023. self.uses_sharded_strategy
  2024. and not self.is_sharded(self.flat_param)
  2025. and not self._skipped_use_sharded_views
  2026. ):
2027. # For `NO_SHARD`, we may still need to write back
  2028. return False
  2029. flat_param = self.flat_param
  2030. wroteback = False
  2031. if self._skipped_use_sharded_views and self.uses_sharded_strategy:
  2032. # NOTE: We must use the unsharded flat parameter from which the
  2033. # unsharded views were computed, not the one from the current
  2034. # calling context (`_get_padded_unsharded_flat_param()`) since that
  2035. # may be different (e.g. the model changed from train to eval).
  2036. flat_param_tensor = self._unsharded_flat_param_for_skipped_views
  2037. _p_assert(
  2038. _data_ptr_allocated(flat_param_tensor),
  2039. "If skipped using sharded views, the unsharded flat parameter "
  2040. "should be allocated",
  2041. )
  2042. else:
  2043. flat_param_tensor = flat_param
  2044. # NOTE: Since this method is called in the pre-unshard, which is only
  2045. # called during computation in the pre-forward or pre-backward, the
  2046. # sharded gradient should be guaranteed to be in `.grad`, not in
  2047. # `._saved_grad_shard`.
  2048. flat_param_grad = (
  2049. flat_param.grad
  2050. if self.uses_sharded_strategy or not self._offload_params
  2051. else flat_param._cpu_grad
  2052. )
  2053. for i, (
  2054. param,
  2055. (in_shard, offset_in_shard, numel_in_shard, _, _),
  2056. (param_name, module, _),
  2057. ) in enumerate(
  2058. zip(
  2059. flat_param._params,
  2060. flat_param._shard_param_infos,
  2061. flat_param._param_infos,
  2062. )
  2063. ):
  2064. if not in_shard:
  2065. continue
  2066. if not hasattr(module, param_name):
  2067. # Do not writeback if original parameters are deregistered
  2068. # (e.g. during model checkpointing)
  2069. continue
  2070. # Check for parameter writeback
  2071. if self._skipped_use_sharded_views:
  2072. param = flat_param._tensors[i]
  2073. _p_assert(
  2074. param is not None,
  2075. f"Expects to have saved tensor for {flat_param._fqns[i]}",
  2076. )
  2077. param_changed = getattr(module, param_name) is not param
  2078. needs_param_writeback = (
  2079. param_changed # changed parameter variable itself
  2080. or not _same_storage(param, flat_param_tensor)
  2081. )
  2082. if self._skipped_use_sharded_views and (
  2083. param_changed or needs_param_writeback
  2084. ):
  2085. raise AssertionError(
  2086. "FSDP does not support changing the parameters between "
  2087. f"forward and backward for {self._sharding_strategy}"
  2088. )
  2089. if param_changed:
  2090. # NOTE: The gradient is not preserved after a parameter change.
  2091. param = getattr(module, param_name)
  2092. flat_param._params[i] = param
  2093. if needs_param_writeback:
  2094. expected_shape = torch.Size([numel_in_shard])
  2095. self._writeback_tensor(
  2096. param, flat_param, i, expected_shape, offset_in_shard, True
  2097. )
  2098. wroteback = True
  2099. # Check for gradient writeback
  2100. if self._skipped_use_sharded_views:
  2101. # Skip the writeback check because we do not expose gradients
  2102. # when we skipped using sharded views
  2103. continue
  2104. if param.grad is None and flat_param.grad is not None:
  2105. expected_shape = torch.Size([numel_in_shard])
  2106. self._writeback_tensor(
  2107. None, flat_param.grad, i, expected_shape, offset_in_shard, False
  2108. )
  2109. elif param.grad is not None:
  2110. # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
  2111. # memory and owns the gradient storage, so it will never
  2112. # require gradient writeback.
  2113. if not self.uses_sharded_strategy and self._offload_params:
  2114. # Explicitly continue to handle the case of `no_sync()`,
  2115. # where `param.grad` is a view into the GPU gradient
  2116. # referenced by `flat_param.grad`, while `flat_param_grad`
  2117. # is `flat_param._cpu_grad`, which is on CPU
  2118. continue
  2119. needs_grad_writeback = flat_param_grad is None or not _same_storage(
  2120. param.grad, flat_param_grad
  2121. )
  2122. if needs_grad_writeback:
  2123. if flat_param_grad is None:
  2124. flat_param_grad = torch.zeros_like(flat_param)
  2125. expected_shape = torch.Size([numel_in_shard])
  2126. self._writeback_tensor(
  2127. param.grad,
  2128. flat_param_grad,
  2129. i,
  2130. expected_shape,
  2131. offset_in_shard,
  2132. False,
  2133. )
  2134. flat_param.grad = flat_param_grad
  2135. flat_param_grad = flat_param.grad
  2136. # TODO: If we want to handle shared parameters, we need to re-generate
  2137. # the shared parameter data structures in case sharedness changed.
  2138. for i, (
  2139. param_name,
  2140. module,
  2141. _,
  2142. prim_param_name,
  2143. prim_module,
  2144. _,
  2145. ) in enumerate(flat_param._shared_param_infos):
  2146. if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
  2147. raise NotImplementedError(
  2148. "Changing shared parameters is not supported yet"
  2149. )
  2150. return wroteback
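# NOTE: Illustrative sketch (not part of the original module). What counts as
# a change that triggers writeback above, for a hypothetical original
# parameter ``module.weight`` backed by this handle:
#
#     module.weight.data.mul_(2)        # in-place: same variable and storage,
#                                       # so no writeback is needed
#     module.weight.data = module.weight.data.clone()
#                                       # new storage: written back into the
#                                       # sharded flat parameter
#     module.weight = nn.Parameter(     # new variable: recorded in `_params[i]`
#         module.weight.data.clone())   # and written back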
  2151. def _writeback_tensor(
  2152. self,
  2153. src_tensor: Optional[Tensor],
  2154. dst_tensor: Tensor,
  2155. tensor_index: int,
  2156. expected_shape: torch.Size,
  2157. offset: int,
  2158. is_param: bool, # else gradient
  2159. ) -> None:
  2160. """
  2161. Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
  2162. ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
  2163. ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
  2164. instead of copying. ``tensor_index`` gives the index of ``src_tensor``
  2165. in the metadata structures.
  2166. Raises:
  2167. RuntimeError: If the ``src_tensor`` does not have the expected
  2168. shape.
  2169. """
  2170. _p_assert(
  2171. len(expected_shape) == 1,
  2172. f"Expects a 1D expected shape but got {expected_shape}",
  2173. )
  2174. if self._debug_level == dist.DebugLevel.INFO:
  2175. rank = self.rank if hasattr(self, "rank") else dist.get_rank()
  2176. src_shape = src_tensor.shape if src_tensor is not None else None
  2177. src_device = src_tensor.device if src_tensor is not None else None
  2178. warnings.warn(
  2179. f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
  2180. f"writeback in {self._training_state}\n"
  2181. f"expected shape={expected_shape} shape={src_shape} "
  2182. f"expected device={dst_tensor.device} device={src_device}"
  2183. )
  2184. if src_tensor is not None and src_tensor.shape != expected_shape:
  2185. # NOTE: Gradient shape mismatch is not possible in practice since
  2186. # the gradient shape is enforced to match that of the parameter and
  2187. # we already check for parameter shape mismatch.
  2188. raise RuntimeError(
  2189. f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
  2190. f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
  2191. )
  2192. if src_tensor is not None:
  2193. dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
  2194. else:
  2195. dst_tensor[offset : offset + expected_shape.numel()].zero_()
  2196. assert self.flat_param._is_grad_none_mask is not None
  2197. self.flat_param._is_grad_none_mask[tensor_index] = True
  2198. def _reset_flat_param_grad_info_if_needed(self):
  2199. """
  2200. Reset ``flat_param.grad`` if needed.
  2201. When ``use_orig_params=True``:
  2202. (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
  2203. original parameters' ``.grad`` are ``None``, and
  2204. (2) sets ``flat_param.requires_grad=False`` if *none* of the original
  2205. parameters require gradient.
  2206. For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
  2207. which case we want to free the gradients as soon after the
  2208. ``zero_grad()`` call as possible.
  2209. """
  2210. if not self._use_orig_params:
  2211. return
  2212. flat_param = self.flat_param
  2213. assert flat_param._params is not None # mypy
  2214. all_grad_none = True
  2215. requires_grad = False
  2216. for param in flat_param._params:
  2217. all_grad_none &= param.grad is None
  2218. requires_grad |= param.requires_grad
  2219. if all_grad_none:
  2220. flat_param.grad = None
  2221. # As long as one parameter requires gradient, then the flat parameter
  2222. # must require gradient
  2223. flat_param.requires_grad = requires_grad
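# NOTE: Illustrative sketch (not part of the original module). The intended
# interplay with ``optim.zero_grad(set_to_none=True)`` under
# ``use_orig_params=True``, for a hypothetical optimizer over the original
# parameters:
#
#     optim.zero_grad(set_to_none=True)              # every original `.grad` -> None
#     handle._reset_flat_param_grad_info_if_needed()
#     # -> `flat_param.grad` becomes None, so its storage can be freed early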
  2224. def _deregister_orig_params(self):
  2225. for param_info in self.flat_param._param_infos:
  2226. param_name, module, _ = param_info
  2227. if hasattr(module, param_name):
  2228. delattr(module, param_name)
  2229. for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
  2230. if hasattr(module, param_name):
  2231. delattr(module, param_name)
  2232. ###########
  2233. # HELPERS #
  2234. ###########
  2235. def flat_param_to(self, *args, **kwargs):
  2236. """Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
  2237. self.flat_param.data = self.flat_param.to(*args, **kwargs)
  2238. if self._use_orig_params:
  2239. # Refresh the views because their storage may have changed
  2240. if self.is_sharded(self.flat_param):
  2241. self._use_sharded_views()
  2242. else:
  2243. self._use_unsharded_views(as_params=True)
  2244. def _get_modules(self) -> Set[nn.Module]:
  2245. """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
  2246. return {pi.module for pi in self.flat_param._param_infos}.union(
  2247. {spi.module for spi in self.flat_param._shared_param_infos}
  2248. )
  2249. def is_sharded(self, tensor: Tensor) -> bool:
  2250. """
  2251. Return whether ``tensor`` is *currently* sharded.
  2252. For ``NO_SHARD``, we choose to have this always return ``False`` for clarity.
  2253. """
  2254. if (
  2255. not hasattr(self.flat_param, "_sharded_size")
  2256. or not self.uses_sharded_strategy
  2257. ):
  2258. # `_sharded_size` is defined iff `handle.shard()` has been called
  2259. return False
  2260. sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
  2261. return tensor.size() == sharded_size

    def param_module_names(self) -> Iterator[Tuple[str, str]]:
        shared_param_infos = [
            ParamInfo(param_name, module, module_name)
            for (
                param_name,
                module,
                module_name,
                _,
                _,
                _,
            ) in self.flat_param._shared_param_infos
        ]
        for param_info in chain(self.flat_param._param_infos, shared_param_infos):
            param_name, _, module_name = param_info  # type: ignore[misc]
            yield (param_name, module_name)

    def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
        for param_name, _, module_name in [
            ParamInfo(param_name, module, module_name)
            for (
                param_name,
                module,
                module_name,
                _,
                _,
                _,
            ) in self.flat_param._shared_param_infos
        ]:
            yield (param_name, module_name)

    @property
    def _fqns_in_shard(self) -> List[str]:
        """Return the FQNs of the parameters present in this rank's shard."""
        fqns_in_shard: List[str] = []
        for fqn, shard_param_info in zip(
            self.flat_param._fqns, self.flat_param._shard_param_infos  # type: ignore[attr-defined]
        ):
            if shard_param_info.in_shard:
                fqns_in_shard.append(fqn)
        return fqns_in_shard

    @property
    def sharded_grad(self) -> Optional[Tensor]:
        """Return the handle's sharded gradient."""
        flat_param = self.flat_param
        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
        # - CPU offloading: `_cpu_grad`
        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
        # - No CPU offloading + `NO_SHARD`: `grad`
        grad: Optional[Tensor]
        if hasattr(flat_param, "_cpu_grad"):
            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
        elif hasattr(flat_param, "_saved_grad_shard"):
            # In the post-backward hook, the sharded gradient is still in
            # `_saved_grad_shard`.
            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        else:
            # If in IDLE or in FORWARD states, then there may be an
            # (accumulated) gradient. If accessed in IDLE, then this should
            # be due to re-registering the original parameters (e.g. in state
            # dict load).
            _p_assert(
                flat_param.grad is None
                or not self.uses_sharded_strategy
                or self._training_state
                in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
                "unless in IDLE or FORWARD",
            )
            grad = flat_param.grad
        return grad
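
    # Illustration (hypothetical handle configurations, restating the priority
    # above): with CPU offloading the gradient shard lives in `_cpu_grad`; with
    # a sharded strategy and no offloading it lives in `_saved_grad_shard`
    # after the post-backward hook; with `NO_SHARD` (or in IDLE/FORWARD) it is
    # simply `flat_param.grad`. Callers can therefore read this property
    # without branching on the offloading/sharding configuration themselves.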

    def _reset_is_grad_none(self) -> None:
        """
        Reset ``_is_grad_none_mask`` as needed.

        This method should only be called in the post-backward after gradient
        computation, in which case if a parameter requires gradient, then it
        will surely receive a gradient and we may reset its mask entry to
        ``False``.
        """
        if not self._use_orig_params:
            return
        _p_assert(
            self._training_state == HandleTrainingState.BACKWARD_POST,
            "Expects to only be called in the post-backward after gradient computation",
        )
        flat_param = self.flat_param
        assert flat_param._params is not None  # mypy
        for i, param in enumerate(flat_param._params):  # type: ignore[arg-type]
            # As long as the parameter requires gradient, it should receive a
            # meaningful gradient (even if the gradient happens to be zeros)
            if param.requires_grad:
                assert flat_param._is_grad_none_mask is not None  # mypy
                flat_param._is_grad_none_mask[i] = False

    #######################
    # CHECKS & INVARIANTS #
    #######################
    def _check_sharded_strategy(self):
        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")

    def _check_on_compute_device(self, tensor: Tensor):
        _p_assert(
            tensor.device == self.device,
            f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
        )

    def _check_on_cpu(self, tensor: Tensor):
        _p_assert(
            tensor.device == torch.device("cpu"),
            f"Expects tensor to be on CPU but got {tensor.device}",
        )

    @staticmethod
    def _check_storage_freed(tensor: Tensor):
        # Compile does not resize during trace
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            _p_assert(
                _same_storage_size(tensor, 0),
                "Expects storage to be freed but got storage with size > 0",
            )

    @staticmethod
    def _check_storage_allocated(tensor: Tensor):
        _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")

    def _check_low_precision_shard(self):
        _p_assert(
            self._uses_param_mixed_precision,
            "Not using low precision for parameters",
        )
        _p_assert(
            getattr(self.flat_param, "_mp_shard", None) is not None,
            "Expects `_mp_shard` to exist",
        )
        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
        _p_assert(
            device == self.device,
            f"Expects the low precision shard to be on {self.device} but got {device}",
        )

    def _check_unsharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be unsharded "
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        unsharded_size = self.flat_param._unpadded_unsharded_size
        _p_assert(
            tensor.size() == unsharded_size,
            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
        )

    def _check_sharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be sharded "
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
        _p_assert(
            tensor.size() == sharded_size,
            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
        )

    ##############
    # PROPERTIES #
    ##############
    @property
    def uses_sharded_strategy(self) -> bool:
        return self._sharding_strategy != HandleShardingStrategy.NO_SHARD

    @property
    def _uses_param_mixed_precision(self) -> bool:
        return self._fwd_bwd_param_dtype != self._orig_param_dtype

    @property
    def _uses_reduce_mixed_precision(self) -> bool:
        return self._reduce_dtype != self._orig_param_dtype

    @property
    def _force_full_precision(self) -> bool:
        return (
            self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
        ) and (
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
            or
            # Also disable mixed precision in model eval mode, if configured
            (not self._fully_sharded_module.training and self._use_full_prec_in_eval)
        )
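
    # Illustration (hypothetical configuration): if parameter mixed precision
    # casts fp32 originals to bf16 for forward/backward, then during
    # `summon_full_params` (or in eval mode when `_use_full_prec_in_eval` is
    # set) this property is True, so the unshard path should materialize the
    # parameters in their original fp32 precision rather than bf16.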

    @property
    def _skipped_use_sharded_views(self) -> bool:
        """
        This property is used for sharding strategies that do not free after
        forward with ``use_orig_params=True``.

        It returns whether this handle is currently in a state where it has
        skipped using sharded views, in which case it can restore view
        invariants via ``_use_sharded_views()``.
        """
        return self._unsharded_flat_param_for_skipped_views is not None


# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
def _unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    module._parameters[param_name] = param
    # This bypasses any overrides in case `module` is an instance of an
    # `nn.Module` subclass
    super(nn.Module, module).__setattr__(param_name, param)


def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
    module._parameters.pop(param_name, None)
    # This bypasses any overrides in case `module` is an instance of an
    # `nn.Module` subclass
    super(nn.Module, module).__setattr__(param_name, tensor)


def _safe_setattr_tensor_or_param(
    module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
    # Call `delattr()` and `setattr()` to go through `nn.Module` checks
    if hasattr(module, param_name):
        delattr(module, param_name)
    setattr(module, param_name, tensor_or_param)
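

# Illustration (hypothetical `linear`/`flat_view` names): the "unsafe" helpers
# above skip `nn.Module.__setattr__` entirely, e.g.
#   _unsafe_setattr_tensor(linear, "weight", flat_view)
# removes "weight" from `linear._parameters` and stores the plain tensor view
# directly in the instance `__dict__`, whereas
#   _safe_setattr_tensor_or_param(linear, "weight", flat_view)
# goes through `delattr()`/`setattr()` so that `nn.Module`'s bookkeeping and
# any subclass `__setattr__` overrides still run.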


def _convert_to_params(
    tensors: List[Union[torch.Tensor, nn.Parameter]]
) -> List[nn.Parameter]:
    return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]


def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
    return (
        param_or_tensor.detach()
        if isinstance(param_or_tensor, nn.Parameter)
        else param_or_tensor
    )
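

# Illustration (hypothetical tensors): these two helpers normalize between
# `Tensor` and `nn.Parameter`, e.g.
#   _convert_to_params([torch.zeros(2), nn.Parameter(torch.zeros(2))])
# wraps only the first element in `nn.Parameter`, while
#   _detach_if_needed(nn.Parameter(torch.zeros(2)))
# returns a detached plain tensor and a non-Parameter tensor passes through
# unchanged.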


def _get_aligned_numel(unsharded_dtype: torch.dtype):
    # NOTE: This alignment constraint comes from TorchInductor.
    ALIGNMENT = 16  # bytes
    unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
    aligned_numel = ALIGNMENT // unsharded_dtype_size
    return aligned_numel
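

# Illustration of the alignment math above (values follow from standard dtype
# sizes): with the 16-byte alignment, `_get_aligned_numel(torch.float32)` is
# 16 // 4 = 4 elements and `_get_aligned_numel(torch.bfloat16)` is
# 16 // 2 = 8 elements, i.e. flat-parameter padding is expressed in elements,
# not bytes.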


@functools.lru_cache(8)
def _get_dtype_size(dtype):
    return torch.empty((), dtype=dtype).element_size()


def _construct_padding_tensor(
    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
):
    # NOTE: Set the padding value as a magic number for debuggability. The
    # value itself should never be used in any user-facing computation.
    return (
        torch.ones(
            (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
        )
        * _FLAT_PARAM_PADDING_VALUE
    )
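

# Illustration (hypothetical arguments): `_construct_padding_tensor(4,
# torch.float32, False, torch.device("cpu"))` returns a 4-element fp32 tensor
# filled with `_FLAT_PARAM_PADDING_VALUE`, which makes alignment padding easy
# to spot when inspecting a flat parameter in a debugger.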


# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
# message is passed in)
@functools.lru_cache(1)
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
    log.warning(warning)


# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
    log.warning(warning)


# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
    log.warning(warning)
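

# Illustration (placeholder `logger`/`msg` names): because of `lru_cache(1)`,
# repeated calls with the same `(log, warning)` arguments hit the cache, so
# calling e.g. `_warn_use_fake_all_gather(logger, msg)` once per iteration
# still emits only a single warning.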


def _same_storage(a, b):
    # Params are DTensors in backward with SHARD_GRAD_OP + TP
    from torch.distributed._tensor import DTensor

    if isinstance(a, DTensor):
        a = a._local_tensor
    if isinstance(b, DTensor):
        b = b._local_tensor
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
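

# Illustration (hypothetical tensor `t`): views share an untyped storage, so
# `_same_storage(t, t[1:])` is True for a plain tensor, and for `DTensor`
# inputs the comparison is performed on their local shards.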


def _same_storage_size(a: torch.Tensor, b: int):
    return a.untyped_storage().size() // a.element_size() == b


def _storage_size_allocated(tensor: Tensor):
    storage_size: int = tensor.untyped_storage().size()
    return storage_size > 0
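

# Illustration (hypothetical tensor `t`): after a flat parameter's storage is
# freed via `t.untyped_storage().resize_(0)`, `_storage_size_allocated(t)` is
# False and `_same_storage_size(t, 0)` is True, which is what
# `_check_storage_freed` above asserts when the unsharded storage is expected
# to have been released.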