zero_redundancy_optimizer.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651
  1. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  2. #
  3. # This source code is licensed under the BSD license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. r"""Zero Redundancy Optimizer."""
  6. import collections
  7. import copy
  8. import enum
  9. import inspect
  10. import io
  11. import logging
  12. from itertools import chain
  13. from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
  14. import torch
  15. import torch.distributed as dist
  16. from torch.distributed.algorithms.join import Join, Joinable, JoinHook
  17. from torch.distributed.optim.utils import functional_optim_map
  18. from torch.optim import Optimizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Only the optimizer class is exported by `from ... import *`.
__all__ = ["ZeroRedundancyOptimizer"]
  21. # Credits: classy_vision/generic/distributed_util.py
  22. def _recursive_copy_to_device(
  23. value: Any,
  24. non_blocking: bool,
  25. device: torch.device,
  26. ) -> Any:
  27. r"""
  28. Recursively searches lists, tuples, dicts and copies tensors to device if possible.
  29. Non-tensor values are passed as-is in the result.
  30. .. note: These are all copies, so if there are two objects that reference
  31. the same object, then after this call, there will be two different objects
  32. referenced on the device.
  33. """
  34. if isinstance(value, torch.Tensor):
  35. return value.to(device, non_blocking=non_blocking)
  36. if isinstance(value, (list, tuple)):
  37. values = [
  38. _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
  39. for val in value
  40. ]
  41. return values if isinstance(value, list) else tuple(values)
  42. if isinstance(value, collections.abc.Mapping):
  43. return {
  44. key: _recursive_copy_to_device(
  45. val, non_blocking=non_blocking, device=device
  46. )
  47. for key, val in value.items()
  48. }
  49. return value
  50. def _is_trainable(param: torch.Tensor) -> bool:
  51. r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient."""
  52. return param.requires_grad
  53. def _broadcast_object(
  54. obj: Any,
  55. src_rank: int,
  56. group: object = dist.group.WORLD,
  57. device: torch.device = torch.device("cpu"),
  58. ) -> Any:
  59. r"""
  60. Broadcasts an object to the given group.
  61. It will be sending the object if called from the source rank and receiving
  62. the object otherwise.
  63. Arguments:
  64. obj: object to broadcast; only used if called on the source rank.
  65. src_rank (int): source rank.
  66. group (``ProcessGroup``, optional): group used for the broadcast
  67. (default: ``dist.group.WORLD``).
  68. device (``torch.device``, optional): device to send from or receive
  69. to (default: ``torch.device("cpu")``).
  70. Returns:
  71. The broadcasted object.
  72. """
  73. if dist.get_rank() == src_rank:
  74. # Send the object
  75. buffer = io.BytesIO()
  76. torch.save(obj, buffer)
  77. data = bytearray(buffer.getbuffer())
  78. length_tensor = torch.LongTensor([len(data)]).to(device)
  79. data_send_tensor = torch.ByteTensor(data).to(device)
  80. dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
  81. dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
  82. else:
  83. # Receive the object
  84. length_tensor = torch.LongTensor([0]).to(device)
  85. dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
  86. data_recv_tensor = torch.empty(
  87. [int(length_tensor.item())], dtype=torch.uint8, device=device
  88. )
  89. dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
  90. buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
  91. obj = torch.load(buffer, map_location=device, weights_only=False)
  92. return obj
  93. class _ZeROJoinHook(JoinHook):
  94. def __init__(self, zero):
  95. assert isinstance(zero, ZeroRedundancyOptimizer), (
  96. "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
  97. "instance as the state"
  98. )
  99. self.zero = zero
  100. super().__init__()
  101. def main_hook(self):
  102. """
  103. Perform an optimizer step.
  104. This step updates the joined process's shard of
  105. the parameters and broadcasts those parameters.
  106. """
  107. self.zero.step()
  108. class _DDPBucketAssignment:
  109. r"""
  110. Represent a :class:`DistributedDataParallel` bucket assignment.
  111. This means that a (possibly non-strict) subset of the parameters corresponding to
  112. a DDP bucket assigned to a rank to update.
  113. Attributes:
  114. bucket_index (int): index of the bucket determined by the DDP gradient
  115. bucket all-reduce order.
  116. parameters (List[torch.Tensor]): model parameters in the bucket
  117. assigned to this rank.
  118. offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
  119. giving the index of the first element in the passed-in
  120. ``parameters``; this equivalently indexes into the
  121. :class:`GradBucket` 's :meth:`gradients`.
  122. device (torch.device): device on which the parameters are stored.
  123. tensor (torch.Tensor): flattened tensor giving the data of the
  124. parameter subset assigned to the rank.
  125. """
  126. def __init__(
  127. self,
  128. bucket_index: int,
  129. parameters: List[torch.Tensor],
  130. offset: int,
  131. ):
  132. self.bucket_index = bucket_index
  133. self.parameters = parameters
  134. self.offset = offset
  135. if len(self.parameters) == 0:
  136. raise ValueError("Empty bucket assignment")
  137. # DDP guarantees all parameters in the bucket have the same device
  138. self.device: torch.device = self.parameters[0].device
  139. self.tensor: Optional[torch.Tensor] = None
  140. class _OverlapStatus(enum.IntEnum):
  141. r"""
  142. Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`.
  143. Attributes:
  144. ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
  145. is waiting for DDP to finalize its bucketing.
  146. ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
  147. its bucketing is finalized. The ZeRO instance can now collect the
  148. necessary information about the DDP bucketing.
  149. ``INITIALIZED``: The ZeRO instance is fully initialized and can now
  150. optimize parameters.
  151. """
  152. UNINITIALIZED = 0
  153. DDP_HAS_REBUILT_BUCKETS = 1
  154. INITIALIZED = 2
  155. class _OverlapInfo:
  156. r"""
  157. Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`.
  158. Arguments:
  159. world_size (int): world size of the process group being used.
  160. Attributes:
  161. shard_buckets (bool): if ``True``, then the assignment of each
  162. :class:`DistributedDataParallel` bucket is partitioned across
  163. possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
  164. across possibly multiple ranks) to approximate uniformity following
  165. a threshold given by the total parameter size divided by the world
  166. size; if ``False``, then each bucket is wholly assigned to a single
  167. :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
  168. this should be set to the value passed into the hook constructor.
  169. status (_OverlapStatus): current status; see :class:`_OverlapStatus`
  170. for more information.
  171. params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
  172. gives the model parameters in the ``i``th bucket.
  173. params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
  174. gives the model parameters assigned to the ``i``th rank, where the
  175. parameters are grouped by increasing bucket indices.
  176. offsets (Dict[int, int]): maps from bucket index to the offset in
  177. ``self.params_per_rank[rank]`` giving the index of the first
  178. parameter in that bucket, where ``rank`` is this process's own
  179. rank; the keys of this :class:`dict` are the bucket indices
  180. assigned to this rank.
  181. num_bucket_assignments (int): total number of bucket assignments across
  182. all ranks; this is equal to the number of
  183. :class:`DistributedDataParallel` gradient buckets if
  184. ``shard_buckets=False`` and possibly greater otherwise.
  185. total_size (int, optional): total size of all buckets (i.e. sum of
  186. ``param.numel()`` for all ``param`` across all buckets) if
  187. ``shard_buckets=True``; otherwise, ``None``.
  188. broadcast_handles (List[Work]): :class:`list` of async work handles for
  189. the parameter broadcasts.
  190. bucket_index_to_future (Dict[int, torch.futures.Future]):
  191. :class:`dict` mapping bucket index to the corresponding all-reduce
  192. future.
  193. bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
  194. mapping bucket index to the corresponding bucket.
  195. bucket_indices_seen (List[int]): :class:`list` of the bucket indices
  196. seen on this iteration.
  197. """
  198. def __init__(self, world_size) -> None:
  199. self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
  200. self.shard_buckets: bool = False
  201. # Modified per bucket reconstruction
  202. self.params_per_bucket: List[List[torch.Tensor]] = []
  203. self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
  204. self.offsets: Dict[int, int] = {}
  205. # Group Ranks
  206. self.assigned_ranks_per_bucket: List[Set[int]] = []
  207. self.num_bucket_assignments: int = 0
  208. self.total_size: Optional[int] = None
  209. # Modified per iteration
  210. self.broadcast_handles: List[Any] = []
  211. self.bucket_indices_seen: List[int] = []
  212. # Used by `hook_with_zero_step()`
  213. self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
  214. self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
  215. def wait_for_broadcasts(self) -> None:
  216. r"""
  217. Wait for all parameter broadcasts.
  218. This function should be called once all broadcasts have been scheduled,
  219. meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
  220. in preparation for the next iteration.
  221. """
  222. assert (
  223. len(self.broadcast_handles) == self.num_bucket_assignments
  224. ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
  225. _ = [x.wait() for x in self.broadcast_handles]
  226. self.broadcast_handles.clear()
  227. def clear_per_iter_info(self) -> None:
  228. r"""
  229. Clear the data structures that are modified per-iteration.
  230. This function should be called at the end of an iteration.
  231. """
  232. self.bucket_indices_seen.clear()
  233. self.bucket_index_to_future.clear()
  234. self.bucket_index_to_bucket.clear()
  235. class ZeroRedundancyOptimizer(Optimizer, Joinable):
  236. r"""
  237. Wrap an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` and shard its states across ranks in the group.
  238. The sharding is done as described by ZeRO_.
  239. The local optimizer instance in each rank is only
  240. responsible for updating approximately ``1 / world_size`` parameters and
  241. hence only needs to keep ``1 / world_size`` optimizer states. After
  242. parameters are updated locally, each rank will broadcast its parameters to
  243. all other peers to keep all model replicas in the same state.
  244. ``ZeroRedundancyOptimizer`` can be used in conjunction with
  245. :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
  246. memory consumption.
  247. ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
  248. of parameters at each rank. Each parameter belongs to a single rank and is
  249. not divided among ranks. The partition is arbitrary and might not match the
  250. parameter registration or usage order.
  251. Arguments:
  252. params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
  253. or :class:`dict` s giving all parameters, which will be sharded
  254. across ranks.
  255. Keyword Args:
  256. optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
  257. optimizer.
  258. process_group (``ProcessGroup``, optional): ``torch.distributed``
  259. ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
  260. :meth:`torch.distributed.init_process_group`).
  261. parameters_as_bucket_view (bool, optional): if ``True``, parameters are
  262. packed into buckets to speed up communication, and ``param.data``
  263. fields point to bucket views at different offsets; if ``False``,
  264. each individual parameter is communicated separately, and each
  265. ``params.data`` stays intact (default: ``False``).
  266. overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
  267. overlapped with :class:`DistributedDataParallel` 's gradient
  268. synchronization; this requires (1) either a functional optimizer
  269. for the ``optimizer_class`` argument or one with a functional
  270. equivalent and (2) registering a DDP communication hook
  271. constructed from one of the functions in ``ddp_zero_hook.py``;
  272. parameters are packed into buckets matching those in
  273. :class:`DistributedDataParallel`, meaning that the
  274. ``parameters_as_bucket_view`` argument is ignored.
  275. If ``False``, :meth:`step` runs disjointly after the backward pass
  276. (per normal).
  277. (default: ``False``)
  278. **defaults: any trailing arguments, which are forwarded to the local
  279. optimizer.
  280. Example::
  281. >>> # xdoctest: +SKIP
  282. >>> import torch.nn as nn
  283. >>> from torch.distributed.optim import ZeroRedundancyOptimizer
  284. >>> from torch.nn.parallel import DistributedDataParallel as DDP
  285. >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
  286. >>> ddp = DDP(model, device_ids=[rank])
  287. >>> opt = ZeroRedundancyOptimizer(
  288. >>> ddp.parameters(),
  289. >>> optimizer_class=torch.optim.Adam,
  290. >>> lr=0.01
  291. >>> )
  292. >>> ddp(inputs).sum().backward()
  293. >>> opt.step()
  294. .. warning::
  295. Currently, ``ZeroRedundancyOptimizer`` requires that all of the
  296. passed-in parameters are the same dense type.
  297. .. warning::
  298. If you pass ``overlap_with_ddp=True``, be wary of the following: Given
  299. the way that overlapping :class:`DistributedDataParallel` with
  300. :class:`ZeroRedundancyOptimizer` is currently implemented, the first
  301. two or three training iterations do not perform parameter updates in
  302. the optimizer step, depending on if ``static_graph=False`` or
  303. ``static_graph=True``, respectively. This is because it needs
  304. information about the gradient bucketing strategy used by
  305. :class:`DistributedDataParallel`, which is not finalized until the
  306. second forward pass if ``static_graph=False`` or until the third
  307. forward pass if ``static_graph=True``. To adjust for this, one option
  308. is to prepend dummy inputs.
  309. .. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
  310. .. _ZeRO: https://arxiv.org/abs/1910.02054
  311. """
  312. def __init__(
  313. self,
  314. params,
  315. optimizer_class: Type[Optimizer],
  316. process_group: Optional[Any] = None,
  317. parameters_as_bucket_view: bool = False,
  318. overlap_with_ddp: bool = False,
  319. **defaults: Any,
  320. ):
  321. r"""Init."""
  322. # Perform type and assumption checks on the input parameters
  323. params = self._verify_and_init_params(params)
  324. self._verify_same_dense_param_type()
  325. # NOTE: The parent constructor uses `add_param_group()` which is
  326. # partially overloaded in ZeroRedundancyOptimizer, so we use the
  327. # `initialized` flag to dissociate the behaviour of `add_param_group()`
  328. # between the parent and child.
  329. self.initialized = False
  330. Optimizer.__init__(self, params, defaults)
  331. Joinable.__init__(self)
  332. # Now, all parameters are held in both `self._all_params` and
  333. # `self.param_groups`
  334. # Internal data structures (`_cache` indicates lazily evaluated)
  335. self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
  336. self._param_to_index_cache: Dict[torch.Tensor, int] = {}
  337. self._partition_parameters_cache: List[List[Dict]] = []
  338. self._index_to_param_cache: List[torch.Tensor] = []
  339. self._device_to_params_per_rank_cache: Dict[
  340. torch.device, List[List[torch.Tensor]]
  341. ] = {}
  342. self._bucket_assignments_per_rank_cache: List[
  343. Dict[int, _DDPBucketAssignment]
  344. ] = []
  345. self._is_trainable_mask = self._get_is_trainable_mask()
  346. # Default device for collective communication and buckets
  347. self._default_device = self._all_params[0].device
  348. self.process_group = (
  349. process_group if process_group is not None else dist.group.WORLD
  350. )
  351. self.world_size: int = dist.get_world_size(self.process_group)
  352. self.rank: int = dist.get_rank(self.process_group)
  353. self.global_rank: int = dist.distributed_c10d.get_global_rank(
  354. self.process_group, self.rank
  355. )
  356. self._overlap_with_ddp: bool = overlap_with_ddp
  357. self._optim_defaults = defaults
  358. self._optim_constructor = self._get_optimizer_constructor(optimizer_class)
  359. # If `overlap_with_ddp=True`, local optimizer initialization is delayed
  360. # to run time after the necessary information has been collected
  361. if not overlap_with_ddp:
  362. self._init_local_optimizer()
  363. else:
  364. self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
  365. if parameters_as_bucket_view:
  366. logger.warning(
  367. "`parameters_as_bucket_view=True` will be ignored since "
  368. "`overlap_with_ddp=True`; instead, a different bucketing "
  369. "strategy will be used"
  370. )
  371. # `self._buckets` is used if `parameters_as_bucket_view=True`, in
  372. # which case parameter data is flattened into contiguous bucket tensors
  373. self.parameters_as_bucket_view = parameters_as_bucket_view
  374. self._buckets: List[List[torch.Tensor]] = []
  375. self._build_param_buckets()
  376. # Optional consolidated optimizer state, only populated if this rank
  377. # is the target in `consolidate_state_dict()`
  378. self._all_state_dicts: List[Dict[str, Any]] = []
  379. self.initialized = True
  380. def _clear_cache(self) -> None:
  381. r"""Clear the cached data structures giving partition information."""
  382. self._partition_parameters_cache.clear()
  383. self._param_to_rank_cache.clear()
  384. self._index_to_param_cache.clear()
  385. self._param_to_index_cache.clear()
  386. self._device_to_params_per_rank_cache.clear()
  387. self._bucket_assignments_per_rank_cache.clear()
  388. def add_param_group(self, param_group: Dict[str, Any]) -> None:
  389. r"""
  390. Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
  391. This can be useful when fine tuning a pre-trained network, as frozen
  392. layers can be made trainable and added to the :class:`Optimizer` as
  393. training progresses.
  394. Arguments:
  395. param_group (dict): specifies the parameters to be optimized and
  396. group-specific optimization options.
  397. .. warning:: This method handles updating the shards on all partitions
  398. but needs to be called on all ranks. Calling this on a subset of
  399. the ranks will cause the training to hang because communication
  400. primitives are called depending on the managed parameters and
  401. expect all the ranks to participate on the same set of parameters.
  402. """
  403. if self.initialized and self._overlap_with_ddp:
  404. raise RuntimeError(
  405. "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
  406. "supports a single parameter group"
  407. )
  408. super().add_param_group(param_group)
  409. # NOTE: The rest of the method assumes that the call to the parent's
  410. # `add_param_group()` appends the new parameter group and preserves
  411. # the previous parameter-group ordering
  412. if self.initialized:
  413. # Force a re-partitioning of the parameters
  414. self._clear_cache()
  415. param_groups = self._partition_parameters()[self.rank]
  416. # NOTE: All parameters in the old parameter groups should be
  417. # assigned to the same ranks so that the local optimizers do not
  418. # need to be reinitialized
  419. # Add the parameters assigned to this rank from the new parameter
  420. # group to the local optimizer, if any
  421. if len(param_groups) == len(self.optim.param_groups) + 1:
  422. self.optim.add_param_group(param_groups[-1])
  423. # Update the bucketing strategy accordingly
  424. if self.parameters_as_bucket_view:
  425. self._build_param_buckets()
  426. def consolidate_state_dict(self, to: int = 0) -> None:
  427. r"""
  428. Consolidate a list of ``state_dict`` s (one per rank) on the target rank.
  429. Arguments:
  430. to (int): the rank that receives the optimizer states (default: 0).
  431. Raises:
  432. RuntimeError: if ``overlap_with_ddp=True`` and this method is
  433. called before this :class:`ZeroRedundancyOptimizer` instance
  434. has been fully initialized, which happens once
  435. :class:`DistributedDataParallel` gradient buckets have been
  436. rebuilt.
  437. .. warning:: This needs to be called on all ranks.
  438. """
  439. self._check_overlap_initialized()
  440. # Sync the exposed `param_groups` attributes to the local optimizer in
  441. # case they have been updated
  442. self._sync_param_groups(self.param_groups, self.optim.param_groups)
  443. # Pull the sharded state from all ranks and store them in rank order
  444. empty_messenger = torch.tensor(
  445. [0], dtype=torch.uint8, device=self._default_device
  446. )
  447. # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
  448. # due to compatibility issues with NCCL backend; a possible follow-up
  449. # is to move all sharded state management to RPC RRef
  450. self._all_state_dicts = []
  451. for rank in range(self.world_size):
  452. global_rank = dist.distributed_c10d.get_global_rank(
  453. self.process_group, rank
  454. )
  455. if self.rank == to:
  456. # Consolidate all local `state_dict`s on this rank, storing on
  457. # CPU to save GPU memory
  458. if rank == self.rank:
  459. # Directly append own optimizer state
  460. self._all_state_dicts.append(
  461. _recursive_copy_to_device(
  462. self.optim.state_dict(),
  463. non_blocking=True,
  464. device=torch.device("cpu"),
  465. )
  466. )
  467. else:
  468. # Receive the optimizer state from the source rank
  469. local_state_dict = _broadcast_object(
  470. empty_messenger,
  471. src_rank=global_rank,
  472. group=self.process_group,
  473. device=self._default_device,
  474. )
  475. self._all_state_dicts.append(
  476. _recursive_copy_to_device(
  477. local_state_dict,
  478. non_blocking=True,
  479. device=torch.device("cpu"),
  480. )
  481. )
  482. else:
  483. if rank == self.rank:
  484. # Send the optimizer state to the target rank
  485. _ = _broadcast_object(
  486. self.optim.state_dict(),
  487. src_rank=self.global_rank,
  488. group=self.process_group,
  489. device=self._default_device,
  490. )
  491. elif rank != to:
  492. # Discard the received object; `broadcast()` is used for
  493. # compatibility reasons
  494. _ = _broadcast_object(
  495. empty_messenger,
  496. src_rank=global_rank,
  497. group=self.process_group,
  498. device=self._default_device,
  499. )
  500. def _verify_params_per_rank(
  501. self,
  502. params_per_rank: List[List[torch.Tensor]],
  503. ) -> None:
  504. r"""
  505. Verify ``params_per_rank`` for :meth:`_partition_parameters`.
  506. The verification is done by checking that ``params_per_rank`` has length equal
  507. to the world size and that it does not contain any parameters not passed into the
  508. :class:`ZeroRedundancyOptimizer` constructor.
  509. The parameters in ``params_per_rank`` being a strict subset of those
  510. passed into the constructor is valid since some parameters may be
  511. frozen.
  512. Raises:
  513. ValueError: if ``params_per_rank`` does not have length equal to
  514. the world size or if it contains a parameter that was not
  515. passed into the :class:`ZeroRedundancyOptimizer` constructor.
  516. """
  517. if len(params_per_rank) != self.world_size:
  518. raise ValueError(
  519. "`params_per_rank` must have length equal to the world size"
  520. )
  521. all_params_set = set(self._all_params)
  522. for params in params_per_rank:
  523. for param in params:
  524. if param not in all_params_set:
  525. raise ValueError(
  526. "Passing a new parameter in `params_per_rank` that "
  527. "was not passed into the ZeroRedundancyOptimizer "
  528. "constructor"
  529. )
  530. def _partition_param_group(
  531. self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
  532. ) -> None:
  533. r"""
  534. Partition the parameter group ``param_group`` according to ``params_per_rank``.
  535. The partition will modify the ``self._partition_parameters_cache``. This method should
  536. only be used as a subroutine for :meth:`_partition_parameters`.
  537. Arguments:
  538. param_group (dict[str, Any]): a parameter group as normally defined
  539. in an optimizer state.
  540. params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
  541. length world size containing :class:`list` s of parameters to
  542. assign to each rank.
  543. """
  544. for rank, params in enumerate(params_per_rank):
  545. rank_param_group = copy.copy(param_group)
  546. rank_param_group["params"] = params
  547. self._partition_parameters_cache[rank].append(rank_param_group)
  548. def _partition_parameters(
  549. self,
  550. params_per_rank: Optional[List[List[torch.Tensor]]] = None,
  551. ) -> List[List[Dict]]:
  552. r"""
  553. Partitions parameters across distributed data parallel ranks.
  554. Arguments:
  555. params_per_rank (list[list[torch.Tensor]], optional): a
  556. :class:`list` of length world size containing :class:`list` s
  557. of parameters to assign to each rank; this provides a way to
  558. specify a partition manually.
  559. If ``None``, the parameters are partitioned according to an
  560. internal algorithm.
  561. (default: ``None``)
  562. Returns:
  563. A :class:`list` where each element of the list contains the
  564. ``param_groups`` for a rank (which itself is a :class:`list` of
  565. :class:`dict`); element 0 corresponds to rank 0, etc.; each rank
  566. stores the ``param_groups`` for all ranks for the collective
  567. communication in :meth:`step`.
  568. Raises:
  569. ValueError: see :meth:`_validate_params_per_rank`.
  570. RuntimeError: if ``params_per_rank`` is not ``None`` and this
  571. :class:`ZeroRedundancyOptimizer` instance is using more than
  572. one parameter group.
  573. """
  574. if params_per_rank is None:
  575. # Partition the parameters optimizing for uniformity
  576. if len(self._partition_parameters_cache) == 0:
  577. self._partition_parameters_cache = [[] for _ in range(self.world_size)]
  578. sizes = [0] * self.world_size
  579. for param_group in self.param_groups:
  580. param_group_params_per_rank: List[List] = [
  581. [] for _ in range(self.world_size)
  582. ]
  583. # Sort the parameters by size (largest first)
  584. params_sorted = sorted(
  585. param_group["params"], key=lambda t: t.numel(), reverse=True
  586. )
  587. for param in params_sorted:
  588. # Greedily add the parameter to rank with smallest size so far
  589. rank = self._get_min_index(sizes)
  590. param_group_params_per_rank[rank].append(param)
  591. sizes[rank] += param.numel()
  592. # Apply the constructed partition of the parameter group
  593. self._partition_param_group(
  594. param_group, param_group_params_per_rank
  595. )
  596. return self._partition_parameters_cache
  597. # Partition the parameters according to `params_per_rank`
  598. assert len(self._partition_parameters_cache) == 0, (
  599. "Specifying `params_per_rank` should only be done when the "
  600. "parameters have not been partitioned yet"
  601. )
  602. if len(self.param_groups) != 1:
  603. raise RuntimeError(
  604. "Specifying `params_per_rank` only supports a single parameter group"
  605. )
  606. self._verify_params_per_rank(params_per_rank)
  607. self._partition_parameters_cache = [[] for _ in range(self.world_size)]
  608. # Apply the passed-in partition of the parameter group
  609. param_group = self.param_groups[0]
  610. self._partition_param_group(param_group, params_per_rank)
  611. return self._partition_parameters_cache
  612. @property
  613. def _param_to_rank(self) -> Dict[torch.Tensor, int]:
  614. r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
  615. if len(self._param_to_rank_cache) == 0:
  616. for rank, param_groups in enumerate(self._partition_parameters()):
  617. for param_group in param_groups:
  618. for param in param_group["params"]:
  619. self._param_to_rank_cache[param] = rank
  620. return self._param_to_rank_cache
  621. @property
  622. def _param_to_index(self) -> Dict[torch.Tensor, int]:
  623. r"""
  624. :class:`dict` mapping parameters to their indices in the global optimizer state.
  625. NOTE: This assumes that the global optimizer state's indexing (in
  626. ``state_dict``) follows a linear ordering over the parameter groups.
  627. """
  628. if len(self._param_to_index_cache) == 0:
  629. self._param_to_index_cache = {
  630. p: i
  631. for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
  632. }
  633. return self._param_to_index_cache
  634. @property
  635. def _index_to_param(self) -> List[torch.Tensor]:
  636. r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
  637. if len(self._index_to_param_cache) == 0:
  638. self._index_to_param_cache = list(
  639. chain(*(g["params"] for g in self.param_groups))
  640. )
  641. return self._index_to_param_cache
  def _broadcast_params_from_rank(self, rank: int):
      r"""
      Broadcast the shard of parameters from a given rank to all other ranks asynchronously.

      With ``parameters_as_bucket_view=True`` one broadcast is issued per
      device bucket belonging to ``rank`` (``self._buckets[i][rank]``);
      otherwise one broadcast is issued per parameter assigned to ``rank``
      in the partition.

      Arguments:
          rank (int): the source rank, given as a rank within
              ``self.process_group``.

      Returns:
          A :class:`list` of async work handles for the ``broadcast()`` s
          performed to synchronize the parameters.
      """
      assert not self._overlap_with_ddp, (
          "`_broadcast_params_from_rank()` should not be used if "
          "`overlap_with_ddp=True`; instead, the broadcasting should "
          "happen in the DDP communication hook"
      )
      # `dist.broadcast` takes `src` as a global rank while `rank` is a rank
      # within `self.process_group`; the translation does not depend on the
      # bucket/parameter being sent, so do it once up front.
      global_rank = dist.distributed_c10d.get_global_rank(self.process_group, rank)
      if self.parameters_as_bucket_view:
          tensors = [dev_i_buckets[rank] for dev_i_buckets in self._buckets]
      else:
          tensors = []
          for param_group in self._partition_parameters()[rank]:
              tensors.extend(param.data for param in param_group["params"])
      return [
          dist.broadcast(
              tensor=tensor,
              src=global_rank,
              group=self.process_group,
              async_op=True,
          )
          for tensor in tensors
      ]
  687. def _sync_params(self):
  688. r"""
  689. Sync all parameter shards across the ranks.
  690. This rank sends its shard of the parameters to all other ranks and
  691. receives a shard from each other rank. This is done using
  692. ``broadcast()``. Parameters are sent bucket-by-bucket if
  693. ``parameters_as_bucket_view=True``and sent parameter-by-parameter
  694. otherwise.
  695. """
  696. handles = []
  697. for rank in range(self.world_size):
  698. handles.extend(self._broadcast_params_from_rank(rank))
  699. _ = [x.wait() for x in handles]
  700. @property
  701. def _device_to_params_per_rank(
  702. self,
  703. ) -> Dict[torch.device, List[List[torch.Tensor]]]:
  704. r"""
  705. Return device parameters assigned per rank.
  706. :class:`dict` mapping each device to a :class:`list` of the per-rank parameter
  707. lists filtered to only include the parameters stored on that device.
  708. Each per-rank parameter list gives the parameters assigned to that rank
  709. to update.
  710. This is used for constructing the parameter buckets if
  711. ``parameters_as_bucket_view=True``.
  712. Let ``dev_i`` denote the ``i``th device for this rank. Then:
  713. ``dev_0`` maps to a list containing:
  714. rank 0's assigned parameters stored on ``dev_0``,
  715. rank 1's assigned parameters stored on ``dev_0``,
  716. ...
  717. ``dev_1`` maps to a list containing:
  718. rank 0's assigned parameters stored on ``dev_1``,
  719. rank 1's assigned parameters stored on ``dev_1``,
  720. ...
  721. ...
  722. """
  723. assert self.parameters_as_bucket_view, (
  724. "`_device_to_params_per_rank` should only be used if "
  725. "`parameters_as_bucket_view=True`"
  726. )
  727. if len(self._device_to_params_per_rank_cache) == 0:
  728. for rank, param_groups in enumerate(self._partition_parameters()):
  729. for param_group in param_groups:
  730. for param in param_group["params"]:
  731. device = param.device
  732. if device not in self._device_to_params_per_rank_cache:
  733. self._device_to_params_per_rank_cache[device] = [
  734. [] for _ in range(self.world_size)
  735. ]
  736. self._device_to_params_per_rank_cache[device][rank].append(
  737. param
  738. )
  739. return self._device_to_params_per_rank_cache
  740. def _get_min_index(
  741. self,
  742. values: List[int],
  743. disallowed_indices: Optional[Set[int]] = None,
  744. ) -> int:
  745. r"""
  746. Return ``values.index(min(values))``, except only uses one pass.
  747. It also excludes any indices in ``disallowed_indices`` if provided.
  748. Arguments:
  749. values: (List[int]): :class:`list` of values.
  750. disallowed_indices (Optional[Set[int]]): indices that are
  751. disallowed from being the returned min index.
  752. """
  753. min_index = -1
  754. min_value = float("inf")
  755. for i, value in enumerate(values):
  756. if disallowed_indices and i in disallowed_indices:
  757. continue
  758. if value < min_value:
  759. min_value = value
  760. min_index = i
  761. assert min_index >= 0, "All indices are disallowed"
  762. return min_index
  763. def _assign_bucket_subset_to_rank(
  764. self,
  765. bucket_index: int,
  766. bucket_params: List[torch.Tensor],
  767. bucket_offset: int,
  768. assigned_rank: int,
  769. assigned_ranks_per_bucket: List[Set[int]],
  770. ) -> None:
  771. r"""
  772. Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information.
  773. The model parameters given by ``bucket_params`` represents a (possibly non-strict)
  774. subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket.
  775. Arguments:
  776. bucket_index (int): index of the :class:`DistributedDataParallel`
  777. gradient bucket.
  778. bucket_params (List[torch.Tensor]): subset of the parameters
  779. corresponding to the bucket to assign.
  780. bucket_offset (int): offset giving the index of the first element
  781. in ``bucket_params`` in the bucket's full parameter list.
  782. assigned_rank (int): group rank to assign to.
  783. assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
  784. assigned to each bucket.
  785. """
  786. overlap_info = self._overlap_info
  787. if len(bucket_params) == 0:
  788. raise ValueError("Empty bucket assignment")
  789. params_per_rank = overlap_info.params_per_rank
  790. offsets = overlap_info.offsets
  791. self._bucket_assignments_per_rank_cache[assigned_rank][
  792. bucket_index
  793. ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
  794. if self.global_rank == assigned_rank:
  795. offsets[bucket_index] = len(params_per_rank[assigned_rank])
  796. params_per_rank[assigned_rank].extend(bucket_params)
  797. assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
  798. self._overlap_info.num_bucket_assignments += 1
  @property
  def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
      r"""
      Return DDP bucket parameters assigned per rank.

      :class:`list` of length world size consisting of :class:`dict` s
      mapping bucket indices to :class:`_DDPBucketAssignment` s for each
      rank. Computed on first access and cached in
      ``self._bucket_assignments_per_rank_cache``.

      Without ``shard_buckets``, each DDP bucket is assigned whole to the
      rank given by :meth:`_get_assigned_rank`. With ``shard_buckets``,
      buckets are visited in increasing total ``numel()`` and cut into
      contiguous subsets that stay below ``total_size / world_size`` elements
      where possible. Each subset goes to the least-loaded rank that does
      not already hold part of that bucket. Because each rank's :class:`dict`
      is keyed by bucket index, a rank holds at most one subset per bucket.
      Once only one eligible rank remains, the rest of the bucket goes to it
      without further cuts.
      """
      assert self._overlap_with_ddp, (
          "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
      )
      if len(self._bucket_assignments_per_rank_cache) > 0:
          return self._bucket_assignments_per_rank_cache

      overlap_info = self._overlap_info
      assert overlap_info.status == _OverlapStatus.INITIALIZED

      self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
      params_per_bucket = overlap_info.params_per_bucket

      if overlap_info.shard_buckets:
          # Define the assignment threshold to approximate uniformity
          assert overlap_info.total_size is not None, "`total_size` was not computed"
          threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
          size_per_rank = [0 for _ in range(self.world_size)]

      num_buckets = len(params_per_bucket)
      overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
      assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
      if not overlap_info.shard_buckets:
          # Assign each DDP bucket entirely to a single rank
          for bucket_index, bucket_params in enumerate(params_per_bucket):
              assert len(bucket_params) > 0, "Empty bucket"
              assigned_rank = self._get_assigned_rank(bucket_index)
              self._assign_bucket_subset_to_rank(
                  bucket_index,
                  bucket_params,
                  0,
                  assigned_rank,
                  assigned_ranks_per_bucket,
              )
      else:
          # Assign each DDP bucket to possibly multiple ranks
          # Specifically, sort the DDP buckets by increasing size, and for
          # each bucket, iteratively assign the maximal unassigned subset
          # with size less than `threshold` to the rank with the least total
          # size so far -- each such assignment is represented by a
          # `_DDPBucketAssignment` instance and only contains parameters from
          # a single DDP bucket
          params_per_bucket_enum = sorted(
              enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
          )
          for bucket_index, bucket_params in params_per_bucket_enum:
              assert len(bucket_params) > 0, "Empty bucket"
              # Same set object that `_assign_bucket_subset_to_rank` updates
              ranks_holding_bucket = assigned_ranks_per_bucket[bucket_index]
              bucket_offset = 0
              assignment_size = 0
              for param_index, param in enumerate(bucket_params):
                  param_numel = param.numel()
                  if (
                      assignment_size + param_numel >= threshold
                      and param_index > bucket_offset
                      # Only cut if another rank will still be eligible for
                      # the remainder; otherwise the final `_get_min_index`
                      # call below would find every rank disallowed and fail
                      # (e.g. world size 2 with bucket numels [1, 10, 1, 10]).
                      and len(ranks_holding_bucket) < self.world_size - 1
                  ):
                      assigned_rank = self._get_min_index(
                          size_per_rank, ranks_holding_bucket
                      )
                      # Include up to but not including the parameter that
                      # exceeded the threshold
                      self._assign_bucket_subset_to_rank(
                          bucket_index,
                          bucket_params[bucket_offset:param_index],
                          bucket_offset,
                          assigned_rank,
                          assigned_ranks_per_bucket,
                      )
                      size_per_rank[assigned_rank] += assignment_size
                      bucket_offset = param_index
                      assignment_size = 0
                  assignment_size += param_numel
              # Assign the remainder of the bucket so that no assignment
              # spans across two buckets
              assigned_rank = self._get_min_index(
                  size_per_rank, ranks_holding_bucket
              )
              self._assign_bucket_subset_to_rank(
                  bucket_index,
                  bucket_params[bucket_offset:],
                  bucket_offset,
                  assigned_rank,
                  assigned_ranks_per_bucket,
              )
              size_per_rank[assigned_rank] += assignment_size

      return self._bucket_assignments_per_rank_cache
  def _local_step(
      self,
      gradients: Optional[List[Optional[torch.Tensor]]] = None,
      closure: Optional[Callable[[], float]] = None,
      **kwargs: Any,
  ) -> Optional[float]:
      r"""
      Run the local optimizer on this rank's shard only; parameters are not synced across ranks.

      Before stepping, the exposed ``param_groups`` attributes are pushed into
      the local optimizer. After stepping, the local optimizer's attributes
      are copied back. If parameter trainability changed since the last
      step, the parameter buckets are rebuilt first. That is not supported
      with ``overlap_with_ddp=True``.

      Arguments:
          gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
              of length equal to the number of parameters assigned to this
              rank containing gradient tensors or ``None`` as its elements;
              a ``None`` in the :class:`list` indicates that the
              corresponding parameter should not be updated.
              If the argument itself is ``None``, then all parameters are
              updated, and the gradients are assumed to be already populated.
              (default: ``None``)
          closure (Callable): a closure that re-evaluates the model and
              returns the loss; optional for most optimizers and should be
              ``None`` if ``gradients`` is not ``None``; (default: ``None``)
          kwargs: forwarded to the local optimizer's ``step()`` when
              ``gradients`` is ``None``; ignored otherwise.

      Returns:
          Optional loss depending on the underlying local optimizer.

      Raises:
          RuntimeError: if trainability changed while ``overlap_with_ddp=True``.

      .. warning::
          The argument ``gradients`` should only be specified (i.e. not
          ``None``) if ``overlap_with_ddp=True``, in which case
          :class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
      """
      Join.notify_join_context(self)

      # Rebuild buckets if any parameter's trainability flipped since last time
      current_mask = self._get_is_trainable_mask()
      if current_mask != self._is_trainable_mask:
          if self._overlap_with_ddp:
              raise RuntimeError(
                  "ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
                  "does not support changing parameter trainability at run "
                  "time"
              )
          logger.warning(
              "ZeroRedundancyOptimizer detected that the trainable "
              "parameters changed; rebuilding the parameter buckets if "
              "enabled"
          )
          self._build_param_buckets()
          self._is_trainable_mask = current_mask

      # Push user-visible hyperparameter edits (e.g. lr) into the local optimizer
      self._sync_param_groups(self.param_groups, self.optim.param_groups)

      if gradients is not None:
          assert self._overlap_with_ddp, (
              "Specifying `gradients` should not "
              "be used when `overlap_with_ddp=False`"
          )
          assert closure is None, (
              "`closure` is not supported when using a local functional optimizer"
          )
          loss = self.optim.step(gradients=gradients)
      elif closure is not None:
          loss = self.optim.step(closure=closure, **kwargs)
      else:
          loss = self.optim.step(**kwargs)

      # Mirror any attributes the local optimizer changed back to the exposed groups
      self._sync_param_groups(self.optim.param_groups, self.param_groups)
      return loss
  def step(
      self,
      closure: Optional[Callable[[], float]] = None,
      **kwargs: Any,
  ) -> Optional[float]:
      r"""
      Perform a single optimizer step and syncs parameters across all ranks.

      This rank first updates its own shard via :meth:`_local_step`, then all
      shards are exchanged via :meth:`_sync_params`. With
      ``overlap_with_ddp=True`` nothing is done: a warning is logged and
      ``None`` is returned.

      Arguments:
          closure (Callable): a closure that re-evaluates the model and
              returns the loss; optional for most optimizers.

      Returns:
          Optional loss depending on the underlying local optimizer.

      .. note: Any extra parameters are passed to the base optimizer as-is.
      """
      if not self._overlap_with_ddp:
          loss = self._local_step(closure=closure, **kwargs)
          self._sync_params()
          return loss
      logger.warning(
          "`step()` should not be included in the training loop when "
          "`overlap_with_ddp=True`"
      )
      return None
  def join_hook(self, **kwargs):
      r"""
      Return the ZeRO join hook.

      The hook supports training on uneven inputs by shadowing the
      collective communications performed in the optimizer step.
      Gradients must be properly set before this hook is called.

      Arguments:
          kwargs (dict): a :class:`dict` containing any keyword arguments
              to modify the behavior of the join hook at run time; all
              :class:`Joinable` instances sharing the same join context
              manager are forwarded the same value for ``kwargs``.

      This hook accepts no keyword arguments of its own, so ``kwargs`` is
      ignored.
      """
      hook = _ZeROJoinHook(self)
      return hook
  @property
  def join_device(self) -> torch.device:
      r"""Return this optimizer's default device (``self._default_device``)."""
      device = self._default_device
      return device
  @property
  def join_process_group(self) -> Any:
      r"""Return the process group (``self.process_group``) this optimizer communicates over."""
      group = self.process_group
      return group
  def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
      r"""
      Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.

      The method looks up each ``state_dict["state"]`` entry with
      ``self._param_to_rank``:

      - Entries for parameters owned by this rank are copied into the local
        optimizer onto the parameter's device. Any zero-dimensional tensor
        entries are then moved to CPU.
      - Entries for other ranks' parameters are set to ``None`` before the
        dict is handed to the base :class:`Optimizer`.

      This is done on a shallow copy, so the caller's ``state_dict`` is not
      modified.

      Arguments:
          state_dict (dict): optimizer state; should be an object returned
              from a call to :meth:`state_dict`.

      Raises:
          RuntimeError: if ``overlap_with_ddp=True`` and this method is
              called before this :class:`ZeroRedundancyOptimizer` instance
              has been fully initialized, which happens once
              :class:`DistributedDataParallel` gradient buckets have been
              rebuilt.
      """
      self._check_overlap_initialized()

      # Shallow-copy the top level and the "state" mapping: the loop below
      # overwrites other ranks' entries with `None`, which would otherwise
      # leak into the caller's object (e.g. breaking a later load of the same
      # dict on another instance).
      state_dict = dict(state_dict)
      state_dict["state"] = dict(state_dict["state"])

      for index, value in state_dict["state"].items():
          param = self._index_to_param[index]
          if self._param_to_rank[param] != self.rank:
              # Clear any state irrelevant to this rank
              state_dict["state"][index] = None
          else:
              # Load the parameter state to the local optimizer
              self.optim.state[param] = _recursive_copy_to_device(
                  value, non_blocking=True, device=param.device
              )
              # Force zero-dimensional tensors (like Adam "step") on CPU
              for state_name, state_value in self.optim.state[param].items():
                  if torch.is_tensor(state_value) and state_value.dim() == 0:
                      self.optim.state[param][state_name] = state_value.cpu()

      super().load_state_dict(state_dict)

      # Sync the input state with the exposed and local optimizer states
      self._sync_param_groups(state_dict["param_groups"], self.param_groups)
      self._sync_param_groups(self.param_groups, self.optim.param_groups)
  def state_dict(self) -> Dict[str, Any]:
      r"""
      Return the last global optimizer state known to this rank.

      The base-class state (indexed by global parameter position) is
      overlaid with the per-rank local states gathered in
      ``self._all_state_dicts``. Each rank's local parameter indices are
      translated into global ones via the partition. The ``"state"``
      entries of the result are sorted by global index.

      .. warning:
          If the state has not been consolidated to this rank, this raises a
          runtime error, and even if it has, the state may not be up-to-date,
          depending on when :meth:`consolidate_state_dict` was last called.

      Raises:
          RuntimeError: if ``overlap_with_ddp=True`` and this method is
              called before this :class:`ZeroRedundancyOptimizer` instance
              has been fully initialized, which happens once
              :class:`DistributedDataParallel` gradient buckets have been
              rebuilt; or if this method is called without a preceding call
              to :meth:`consolidate_state_dict`.
      """
      self._check_overlap_initialized()
      if not self._all_state_dicts:
          raise RuntimeError(
              "Optimizer state has not been consolidated on this rank. "
              f"Please call `consolidate_state_dict(to={self.rank})` on "
              "all ranks beforehand if you meant to save the global state."
          )

      # Possibly-stale state keyed by global parameter index
      state_dict = super().state_dict()
      global_state = state_dict["state"]
      partition = self._partition_parameters()

      for rank, local_state_dict in enumerate(self._all_state_dicts):
          local_groups = local_state_dict["param_groups"]
          global_groups = partition[rank]
          assert len(local_groups) == len(
              global_groups
          ), "Mismatch between number of local and global parameter groups"
          local_state = local_state_dict["state"]
          for local_group, global_group in zip(local_groups, global_groups):
              # Local groups list indices; global groups list the tensors themselves
              local_indices = local_group["params"]
              global_params = global_group["params"]
              assert len(local_indices) == len(
                  global_params
              ), "Mismatch between number of local and global parameters in parameter group"
              for local_idx, global_param in zip(local_indices, global_params):
                  if local_idx not in local_state:
                      continue
                  global_idx = self._param_to_index[global_param]
                  global_state[global_idx] = local_state[local_idx]

      state_dict["state"] = dict(sorted(global_state.items()))
      return state_dict
  1089. @staticmethod
  1090. def _sync_param_groups(
  1091. src_param_groups: List[Dict[Any, Any]],
  1092. dst_param_groups: List[Dict[Any, Any]],
  1093. ) -> None:
  1094. r"""
  1095. Sync the attributes from the source parameter groups to the destination parameter groups.
  1096. Example attributes include learning rate or scheduler attributes. The
  1097. two parameter groups should have the same length (i.e. same number of
  1098. parameter groups).
  1099. Arguments:
  1100. src_param_groups (list[dict]): parameter groups giving the
  1101. attribute settings to copy.
  1102. dst_param_groups (list[dict]): parameter groups giving the
  1103. attribute settings to set.
  1104. """
  1105. assert len(src_param_groups) == len(
  1106. dst_param_groups
  1107. ), "Mismatch between number of source and destination parameter groups"
  1108. for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
  1109. # Sync all attributes except the parameters
  1110. for attr in filter(lambda x: x != "params", src_param_group.keys()):
  1111. dst_param_group[attr] = src_param_group[attr]
  def _build_param_buckets(self) -> None:
      r"""
      Build parameter buckets if ``parameters_as_bucket_view=True``.

      Does nothing unless ``parameters_as_bucket_view=True`` and
      ``overlap_with_ddp=False``. Otherwise ``self._buckets[i][j]`` becomes a
      flat tensor on device ``i`` holding rank ``j``'s trainable parameters
      on that device. Each such parameter's ``.data`` is rebound to a view
      into that tensor. Non-trainable parameters get a private clone of their
      data instead. A rank with no trainable elements on a device gets a
      one-element zero placeholder.

      This method is called in the constructor and any time parameter
      trainability is changed.

      .. warning::
          The current implementation assumes that all of the parameters in a
          bucket are of the same dense type when allocating the bucket's
          tensor.

      .. warning::
          If the model parameters are stored across more than one device,
          then the storage partitioning must be the same across all
          processes in order for parameter synchronization to work.
      """
      bucketing_enabled = self.parameters_as_bucket_view and not self._overlap_with_ddp
      if not bucketing_enabled:
          return

      device_map = self._device_to_params_per_rank
      self._buckets = [[] for _ in range(len(device_map))]  # type: ignore[assignment]
      for dev_i, (device, params_per_rank) in enumerate(device_map.items()):
          for rank_params in params_per_rank:
              trainable = []
              for param in rank_params:
                  if _is_trainable(param):
                      trainable.append(param)
                  else:
                      # Detach from any bucket this parameter may previously
                      # have aliased so its data is not lost
                      param.data = param.data.detach().clone()

              total_numel = sum(param.numel() for param in trainable)
              if total_numel == 0:
                  # Placeholder so every (device, rank) slot holds a tensor
                  bucket = torch.zeros(1, device=device)
              else:
                  # dtype taken from the last trainable param (all assumed equal)
                  bucket = torch.empty(
                      total_numel, dtype=trainable[-1].dtype, device=device
                  )
                  start = 0
                  for param in trainable:
                      stop = start + param.numel()
                      bucket[start:stop].copy_(param.data.flatten())
                      param.data = bucket[start:stop].view_as(param.data)
                      start = stop
              self._buckets[dev_i].append(bucket)  # type: ignore[arg-type]
  def _build_ddp_param_buckets(self) -> None:
      r"""
      Build the DDP bucket with parameters assigned to this rank.

      For every :class:`_DDPBucketAssignment` in
      ``self._bucket_assignments_per_rank``, the assigned parameters' data
      is copied into one flat tensor on ``assignment.device``. Each
      parameter's ``.data`` is rebound to a view into that tensor, and the
      tensor is stored on the assignment's ``tensor`` attribute.

      :class:`DistributedDataParallel` guarantees that the parameters
      corresponding to a gradient bucket have the same device and the same
      dtype.
      """
      for assignments in self._bucket_assignments_per_rank:
          for assignment in assignments.values():
              members = assignment.parameters
              for param in members:
                  assert _is_trainable(param), (
                      "Model parameter "
                      "corresponding to a gradient in a DDP bucket should "
                      "require a gradient"
                  )
              total_numel = sum(param.numel() for param in members)
              assert total_numel > 0, "Empty bucket"
              # dtype taken from the last member (all assumed equal)
              flat = torch.empty(
                  total_numel, dtype=members[-1].dtype, device=assignment.device
              )
              start = 0
              for param in members:
                  stop = start + param.numel()
                  flat[start:stop].copy_(param.data.flatten())
                  param.data = flat[start:stop].view_as(param.data)
                  start = stop
              assignment.tensor = flat
  1202. def _verify_and_init_params(
  1203. self,
  1204. params: Any,
  1205. ) -> Union[List[torch.Tensor], List[dict]]:
  1206. r"""
  1207. Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters.
  1208. The initializagtion will first make sure that provided ``params`` is valid.
  1209. Arguments:
  1210. params (Any): Candidate parameter list or parameter groups to verify.
  1211. Raises:
  1212. TypeError: ``params`` has an invalid type.
  1213. ValueError: ``params`` is empty.
  1214. Returns:
  1215. The persistent form of ``params`` to be passed into the parent
  1216. :class:`Optimizer` constructor -- i.e. returns ``params`` as a
  1217. :class:`list` to ensure that it can be iterated over again.
  1218. """
  1219. if isinstance(params, torch.Tensor):
  1220. raise TypeError(
  1221. "`params` argument should be an iterable of "
  1222. f"Tensors, but got {torch.typename(params)}"
  1223. )
  1224. try:
  1225. all_params = list(params)
  1226. except TypeError as e:
  1227. raise TypeError(
  1228. "`params` argument should be an iterable of Tensors"
  1229. f" or dicts, but got {torch.typename(params)}"
  1230. ) from e
  1231. if len(all_params) == 0:
  1232. raise ValueError("ZeroRedundancyOptimizer got an empty parameter list")
  1233. all_tensors = True
  1234. all_dicts = True
  1235. for param in all_params:
  1236. all_tensors &= isinstance(param, torch.Tensor)
  1237. all_dicts &= isinstance(param, dict)
  1238. if not all_tensors and not all_dicts:
  1239. raise TypeError(
  1240. "`params` argument should be an iterable of Tensors or dicts"
  1241. )
  1242. # Ensure that `self._all_params` contains a list of all parameters
  1243. if all_tensors:
  1244. self._all_params = all_params
  1245. elif all_dicts:
  1246. self._all_params = []
  1247. # `all_params` contains parameter groups (not parameters)
  1248. for param_group in all_params:
  1249. if "params" not in param_group:
  1250. raise ValueError(
  1251. "Each parameter group passed-in via `params` must "
  1252. "have a 'params' key mapping to the parameters in "
  1253. "the group"
  1254. )
  1255. self._all_params.extend(param_group["params"])
  1256. return all_params
  1257. def _verify_same_dense_param_type(self) -> None:
  1258. r"""
  1259. Verify that all parameters are of the same dense type.
  1260. The method assumes that ``self._all_params`` has been initialized
  1261. and is non-empty.
  1262. Raises:
  1263. ValueError: ``params`` contains sparse parameters or parameters
  1264. of varying dense types.
  1265. NOTE: This method can be removed once support for sparse parameters
  1266. and varying parameter types is added.
  1267. """
  1268. typename = torch.typename(self._all_params[0])
  1269. if self._all_params[0].is_sparse:
  1270. raise ValueError(
  1271. "ZeroRedundancyOptimizer only supports using "
  1272. "the same dense type for all parameters but got "
  1273. f"{typename}"
  1274. )
  1275. for param in self._all_params[1:]:
  1276. other_typename = torch.typename(param)
  1277. if other_typename != typename:
  1278. raise ValueError(
  1279. "ZeroRedundancyOptimizer only supports "
  1280. "using the same dense type for all "
  1281. f"parameters but got both {typename} and "
  1282. f"{other_typename}"
  1283. )
  def _get_is_trainable_mask(self) -> List[bool]:
      r"""Return one flag per entry of ``self._all_params``: the result of ``_is_trainable`` for that parameter."""
      return [_is_trainable(param) for param in self._all_params]
  1287. def _init_local_optimizer(self) -> None:
  1288. r"""
  1289. Initialize this rank's local optimizer, responsible for its subset of the parameters.
  1290. The local optimizer is saved in ``self.optim``.
  1291. """
  1292. assert (
  1293. self._optim_constructor is not None
  1294. ), "The local optimizer class has not been set"
  1295. param_groups = self._partition_parameters()[self.rank]
  1296. # `overlap_with_ddp=True` requires a local functional optimizer
  1297. if self._overlap_with_ddp:
  1298. # Functional optimizers only support a single parameter group and
  1299. # require passing in the parameters as a list
  1300. assert len(param_groups) == 1, (
  1301. "Initializing the local "
  1302. "functional optimizer with more than one parameter group"
  1303. )
  1304. params = param_groups[0]["params"]
  1305. # Try to pass `_allow_empty_param_list=True` to avoid erroring
  1306. if (
  1307. "_allow_empty_param_list"
  1308. in inspect.signature(self._optim_constructor).parameters
  1309. ):
  1310. self.optim: Any = self._optim_constructor(
  1311. params, **self._optim_defaults, _allow_empty_param_list=True
  1312. )
  1313. else:
  1314. logger.warning(
  1315. "%s does not support the argument "
  1316. "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
  1317. "error due to an empty parameter list",
  1318. self._optim_constructor
  1319. )
  1320. self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef]
  1321. # Log information about the DDP and ZeRO bucketing
  1322. if dist.get_debug_level() != dist.DebugLevel.OFF:
  1323. local_numel = sum(p.numel() for p in params)
  1324. num_assigned_buckets = len(
  1325. self._bucket_assignments_per_rank[self.global_rank]
  1326. )
  1327. logger.info(
  1328. "rank %s with %s parameters "
  1329. "across %s buckets",
  1330. self.global_rank, local_numel, num_assigned_buckets
  1331. )
  1332. if self.global_rank == 0:
  1333. logger.info(
  1334. "%s DDP "
  1335. "buckets and "
  1336. "%s bucket "
  1337. "assignments",
  1338. len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments
  1339. )
  1340. else:
  1341. # NOTE: Passing `param_groups` into the local optimizer constructor
  1342. # bypasses the empty parameter list check
  1343. self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef]
  1344. # TODO: Manually add `self.param_groups` if using a functional
  1345. # optimizer; remove this if/when the functional optimizers support
  1346. # multiple parameter groups
  1347. if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
  1348. assert hasattr(self.optim, "param_group"), (
  1349. "The functional optimizer should set at least one of the "
  1350. "attributes `param_group` or `param_groups`"
  1351. )
  1352. self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined]
  1353. self._sync_param_groups(self.optim.param_groups, self.param_groups)
  1354. def _init_zero_for_overlap(self) -> None:
  1355. r"""Perform a delayed initialization of the local optimizer and the supporting data structures."""
  1356. assert self._overlap_with_ddp, (
  1357. "`_init_zero_for_overlap()` should only be called when "
  1358. "`overlap_with_ddp=True`"
  1359. )
  1360. self._overlap_info.status = _OverlapStatus.INITIALIZED
  1361. self._clear_cache()
  1362. self._partition_parameters(self._overlap_info.params_per_rank)
  1363. self._build_ddp_param_buckets()
  1364. self._init_local_optimizer()
  1365. def _get_assigned_rank(self, bucket_index: int) -> int:
  1366. r"""
  1367. Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket.
  1368. Arguments:
  1369. bucket_index (int): index of the :class:`DistributedDataParallel`
  1370. bucket for which to get the assigned rank.
  1371. """
  1372. assert not self._overlap_info.shard_buckets, (
  1373. "The bucket assignment requires global bucket information and "
  1374. "will be computed later; there should be no need to use this "
  1375. "method"
  1376. )
  1377. return bucket_index % self.world_size
  1378. def _check_overlap_initialized(self):
  1379. r"""
  1380. Check the delayed initialization depending on the value of ``overlap_with_ddp``.
  1381. The delayed initialization has occurred (see
  1382. :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and
  1383. raises a ``RuntimeError`` if not. This should preface methods that
  1384. should not be run before that delayed initialization.
  1385. Raises:
  1386. RuntimeError: if ``overlap_with_ddp=True`` and
  1387. :meth:`_init_zero_for_overlap` has not been called.
  1388. """
  1389. if (
  1390. self._overlap_with_ddp
  1391. and self._overlap_info.status != _OverlapStatus.INITIALIZED
  1392. ):
  1393. raise RuntimeError(
  1394. "This method should not be called until this "
  1395. "ZeroRedundancyOptimizer instance has been fully "
  1396. "initialized"
  1397. )
  1398. def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
  1399. r"""
  1400. Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``.
  1401. Returns:
  1402. - ``optimizer_class`` if ``overlap_with_ddp=False`` and
  1403. ``optimizer_class`` is not a functional optimizer.
  1404. - ``optimizer_class`` if ``overlap_with_ddp=True`` and
  1405. ``optimizer_class`` is already a functional optimizer.
  1406. - The functional equivalent of ``optimizer_class`` if
  1407. ``overlap_with_ddp=True`` and ``optimizer_class`` is not
  1408. already a functional optimizer (assuming the equivalent
  1409. exists).
  1410. Raises:
  1411. ValueError:
  1412. - if ``overlap_with_ddp=True`` but ``optimizer_class`` is
  1413. neither a functional optimizer nor translatable to a
  1414. functional optimizer.
  1415. - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
  1416. functional optimizer.
  1417. """
  1418. functional_optims = functional_optim_map.values()
  1419. if not self._overlap_with_ddp:
  1420. if optimizer_class in functional_optims:
  1421. # Using a functional optimizer is only supported when
  1422. # `overlap_with_ddp=True`
  1423. raise ValueError(
  1424. f"Passing in a functional optimizer {optimizer_class} "
  1425. "when `overlap_with_ddp=False`"
  1426. )
  1427. else:
  1428. return optimizer_class
  1429. else:
  1430. if optimizer_class in functional_optims:
  1431. # Already a functional optimizer
  1432. return optimizer_class
  1433. elif optimizer_class in functional_optim_map:
  1434. # Translate the passed-in optimizer class to its functional
  1435. # equivalent if `overlap_with_ddp=True`
  1436. optim_constructor = functional_optim_map[optimizer_class]
  1437. logger.info(
  1438. "Using the functional optimizer %s "
  1439. "instead of %s since "
  1440. "`overlap_with_ddp=True`",
  1441. optim_constructor, optimizer_class
  1442. )
  1443. return optim_constructor
  1444. else:
  1445. raise ValueError(
  1446. "Using `ddp_with_overlap=True` requires using a "
  1447. "functional optimizer, but there is no supported functional "
  1448. f"optimizer equivalent for {optimizer_class}"
  1449. )