# mypy: allow-untyped-defs
import copy
import functools
import inspect
import itertools
import logging
import os
import sys
import warnings
import weakref
from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten


RPC_AVAILABLE = False
if dist.is_available():
    from torch.distributed.distributed_c10d import (
        _get_default_group,
        _rank_not_in_group,
        ReduceOp,
    )
    from torch.distributed.utils import (
        _alloc_storage,
        _cast_forward_inputs,
        _free_storage,
        _sync_module_states,
        _to_kwargs,
        _verify_param_shape_across_processes,
    )
if torch.distributed.rpc.is_available():
    RPC_AVAILABLE = True
    from torch.distributed.rpc import RRef

from torch._utils import _get_device_index

from ..modules import Module
from .scatter_gather import gather, scatter_kwargs  # noqa: F401


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


__all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)
@dataclass
class _MixedPrecision:
    """
    This configures DDP-native mixed precision training.

    Attributes:
        param_dtype (torch.dtype): This specifies the dtype for model
            parameters, inputs (when ``cast_forward_inputs`` is set to
            ``True``), and therefore the dtype for computation.
            However, outside the forward and backward passes, parameters are in
            full precision. Model checkpointing always happens in full
            precision.
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
            reduction, which is permitted to differ from ``param_dtype``.
        buffer_dtype (torch.dtype): This specifies the dtype for buffers.

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: ``state_dict`` checkpoints parameters and buffers in full
        precision.

    .. note:: Each low precision dtype must be specified explicitly. For
        example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
        the reduction dtype to be low precision, and DDP will not cast
        parameters or buffers.

    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
        happens in ``param_dtype`` if specified or the original parameter dtype
        otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
        would result in communication occurring in fp16.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    # TODO (rohan-varma): keep_low_precision_grads: bool = False
    # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
    # in full precision. For DDP, this can be implemented by not performing the
    # parameter cast for BN and LN units.
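# Illustrative sketch (a minimal example, not used by the module itself): how
# this private, experimental config might be constructed and handed to DDP via
# its ``mixed_precision`` keyword argument; ``model`` and ``rank`` below are
# placeholder names.
#
#     mp_config = _MixedPrecision(
#         param_dtype=torch.float16,
#         reduce_dtype=torch.float16,
#         buffer_dtype=torch.float16,
#     )
#     ddp_model = DistributedDataParallel(
#         model, device_ids=[rank], mixed_precision=mp_config
#     )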
def _cast_buffers(mixed_precision_config, root_module):
    """Casts buffers to the given ``buffer_dtype``."""
    for buf in root_module.buffers():
        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
            continue

        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)


def _setup_mixed_precision_params(mixed_precision_config, root_module):
    """Create and free storage for the mixed precision parameters."""
    for param in root_module.parameters():
        # Do not setup mixed precision for DDP ignored parameters.
        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
            continue

        if not hasattr(param, "_mp_param"):
            param._mp_param = torch.zeros_like(
                param,
                device=param.device,
                dtype=mixed_precision_config.param_dtype,
                requires_grad=param.requires_grad,
            )
            _free_storage(param._mp_param)
            # _fp_param will point to the full precision param so it can be switched
            # back to at the end of forward / backward.
            param._fp_param = param.data
def _tree_flatten_with_rref(output):
    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
    if output_is_rref:
        output_tensor_list, treespec = tree_flatten(output.local_value())
    else:
        output_tensor_list, treespec = tree_flatten(output)
    # Need to return flattened tensors, spec to re-pack them, as well
    # as if the return type was actually an RRef to reconstruct.
    return output_tensor_list, treespec, output_is_rref


def _tree_unflatten_with_rref(output, treespec, output_is_rref):
    output = tree_unflatten(output, treespec)
    if output_is_rref:
        output = RRef(output)
    return output


def _find_tensors(obj):
    r"""Recursively find all tensors contained in the specified object."""
    if RPC_AVAILABLE and isinstance(obj, RRef):
        # If the current node is the owner of the RRef, unwrap it and try to
        # find Tensors.
        # TODO: Expand to remote RRefs.
        if obj.is_owner():
            return _find_tensors(obj.local_value())
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain.from_iterable(map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
    if is_dataclass(obj):
        return itertools.chain.from_iterable(
            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
        )

    return []
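# Illustrative sketch: ``_find_tensors`` flattens arbitrarily nested containers
# (lists, tuples, dicts, dataclasses, owned RRefs) into an iterable of tensors,
# which DDP uses to collect output tensors when unused-parameter detection is
# enabled. For example:
#
#     list(_find_tensors({"a": torch.ones(1), "b": [torch.zeros(2)]}))
#     # -> [tensor([1.]), tensor([0., 0.])]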
def _dump_DDP_relevant_env_vars():
    relevant_env_vars = [
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "MASTER_PORT",
        "MASTER_ADDR",
        "CUDA_VISIBLE_DEVICES",
        "GLOO_SOCKET_IFNAME",
        "GLOO_DEVICE_TRANSPORT",
        "NCCL_SOCKET_IFNAME",
        "TORCH_NCCL_BLOCKING_WAIT",
        "NCCL_DEBUG",
        "NCCL_DEBUG_SUBSYS",
        "NCCL_IB_DISABLE",
        # More NCCL env vars:
        "NCCL_P2P_DISABLE",
        "NCCL_P2P_LEVEL",
        "NCCL_SHM_DISABLE",
        "NCCL_SOCKET_NTHREADS",
        "NCCL_NSOCKS_PERTHREAD",
        "NCCL_BUFFSIZE",
        "NCCL_NTHREADS",
        "NCCL_RINGS",
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CHECKS_DISABLE",
        "NCCL_CHECK_POINTERS",
        "NCCL_LAUNCH_MODE",
        "NCCL_IB_HCA",
        "NCCL_IB_TIMEOUT",
        "NCCL_IB_RETRY_CNT",
        "NCCL_IB_GID_INDEX",
        "NCCL_IB_SL",
        "NCCL_IB_TC",
        "NCCL_IB_AR_THRESHOLD",
        "NCCL_IB_CUDA_SUPPORT",
        "NCCL_NET_GDR_LEVEL",
        "NCCL_NET_GDR_READ",
        "NCCL_SINGLE_RING_THRESHOLD",
        "NCCL_LL_THRESHOLD",
        "NCCL_TREE_THRESHOLD",
        "NCCL_ALGO",
        "NCCL_PROTO",
        "NCCL_IGNORE_CPU_AFFINITY",
        "NCCL_DEBUG_FILE",
        "NCCL_COLLNET_ENABLE",
        "NCCL_TOPO_FILE",
        "NCCL_TOPO_DUMP_FILE",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    ]
    formatted_output = ""
    for var in relevant_env_vars:
        value = os.environ[var] if var in os.environ else "N/A"
        formatted_output += f"env:{var}={value}\n"
    print(formatted_output)


class _BufferCommHookLocation(Enum):
    PRE_FORWARD = auto()
    POST_FORWARD = auto()


@dataclass
class _BufferCommHook:
    buffer_comm_hook: Callable
    buffer_comm_hook_state: Any
    buffer_comm_hook_location: _BufferCommHookLocation
# Add a DDPSink to run various functions when backwards starts, such as
# queueing the callback of the outermost backward/graph task; this helps
# ensure the callback fires only after all gradient calculations are
# completed.
class _DDPSink(Function):
    @staticmethod
    def forward(ctx, ddp_weakref, *inputs):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        ctx.ddp_weakref = ddp_weakref
        ret = inputs
        if ddp_weakref()._ddp_sink_clone:
            ret = tuple(
                inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
            )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Enqueue delay allreduce for static graph training on the first
        # iteration.
        ddp_weakref = ctx.ddp_weakref()
        reducer = ddp_weakref.reducer
        static_graph = ddp_weakref.static_graph
        delay_ar_enqueued = (
            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
        )
        if static_graph and not delay_ar_enqueued:
            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                reducer._delay_all_reduce
            )
            ddp_weakref._static_graph_delay_allreduce_enqueued = True

        return (None, *grad_outputs)
class _DDPJoinHook(JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """Set config variables for internal usage."""
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel "
            "instance as the state"
        )
        assert ddp.logger is not None
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """Shadow the DDP collective communication operations in the forward and backward passes."""
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        # TODO: make DDP uneven inputs context manager support buffer
        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
            is_joined_rank=True
        )
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """Sync the final model to ensure that the model is the same across all processes."""
        self.ddp._sync_final_model(is_last_joiner)
class DistributedDataParallel(Module, Joinable):
    r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.

    This container provides data parallelism by synchronizing gradients
    across each model replica. The devices to synchronize across are
    specified by the input ``process_group``, which is the entire world
    by default. Note that ``DistributedDataParallel`` does not chunk or
    otherwise shard the input across participating GPUs; the user is
    responsible for defining how to do so, for example through the use
    of a :class:`DistributedSampler`.

    See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
    The same constraints on input as in :class:`torch.nn.DataParallel` apply.

    Creation of this class requires that ``torch.distributed`` be already
    initialized, by calling :func:`torch.distributed.init_process_group`.

    ``DistributedDataParallel`` is proven to be significantly faster than
    :class:`torch.nn.DataParallel` for single-node multi-GPU data
    parallel training.

    To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
    up ``N`` processes, ensuring that each process exclusively works on a single
    GPU from 0 to N-1. This can be done by either setting
    ``CUDA_VISIBLE_DEVICES`` for every process or by calling:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> torch.cuda.set_device(i)

    where i is from 0 to N-1. In each process, you should refer to the following
    to construct this module:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> torch.distributed.init_process_group(
        >>>     backend='nccl', world_size=N, init_method='...'
        >>> )
        >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

    In order to spawn up multiple processes per node, you can use either
    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.

    .. note::
        Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
        for a brief introduction to all features related to distributed training.

    .. note::
        ``DistributedDataParallel`` can be used in conjunction with
        :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
        per-rank optimizer states memory footprint. Please refer to
        `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
        for more details.

    .. note:: The ``nccl`` backend is currently the fastest and the most highly
        recommended backend when using GPUs. This applies to both single-node
        and multi-node distributed training.

    .. note:: This module also supports mixed-precision distributed training.
        This means that your model can have different types of parameters, such
        as a mix of ``fp16`` and ``fp32`` types, and gradient reduction on these
        mixed types of parameters will just work fine.

    .. note:: If you use ``torch.save`` on one process to checkpoint the module,
        and ``torch.load`` on some other processes to recover it, make sure that
        ``map_location`` is configured properly for every process. Without
        ``map_location``, ``torch.load`` would recover the module to devices
        where the module was saved from.

    .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
        gradient will be ``M`` times smaller when compared to the same model
        trained on a single node with ``batch=M*N`` if the loss is summed (NOT
        averaged as usual) across instances in a batch (because the gradients
        between different nodes are averaged). You should take this into
        consideration when you want to obtain a mathematically equivalent
        training process compared to the local training counterpart. But in most
        cases, you can just treat a DistributedDataParallel wrapped model, a
        DataParallel wrapped model and an ordinary model on a single GPU as the
        same (e.g. using the same learning rate for an equivalent batch size).

    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in process of rank
        0, to all other replicas in the system in every iteration.

    .. note::
        If you are using DistributedDataParallel in conjunction with the
        :ref:`distributed-rpc-framework`, you should always use
        :meth:`torch.distributed.autograd.backward` to compute gradients and
        :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
        parameters.
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> import torch.distributed.autograd as dist_autograd
            >>> from torch.nn.parallel import DistributedDataParallel as DDP
            >>> import torch
            >>> from torch import optim
            >>> from torch.distributed.optim import DistributedOptimizer
            >>> import torch.distributed.rpc as rpc
            >>> from torch.distributed.rpc import RRef
            >>>
            >>> t1 = torch.rand((3, 3), requires_grad=True)
            >>> t2 = torch.rand((3, 3), requires_grad=True)
            >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
            >>> ddp_model = DDP(my_model)
            >>>
            >>> # Setup optimizer
            >>> optimizer_params = [rref]
            >>> for param in ddp_model.parameters():
            >>>     optimizer_params.append(RRef(param))
            >>>
            >>> dist_optim = DistributedOptimizer(
            >>>     optim.SGD,
            >>>     optimizer_params,
            >>>     lr=0.05,
            >>> )
            >>>
            >>> with dist_autograd.context() as context_id:
            >>>     pred = ddp_model(rref.to_here())
            >>>     loss = loss_func(pred, target)
            >>>     dist_autograd.backward(context_id, [loss])
            >>>     dist_optim.step(context_id)
    .. note::
        DistributedDataParallel currently offers limited support for gradient
        checkpointing with :meth:`torch.utils.checkpoint`.
        If the checkpoint is done with use_reentrant=False (recommended), DDP
        will work as expected without any limitations.
        If, however, the checkpoint is done with use_reentrant=True (the default),
        DDP will work as expected when there are no unused parameters in the model
        and each layer is checkpointed at most once (make sure you are not passing
        `find_unused_parameters=True` to DDP). We currently do not support the
        case where a layer is checkpointed multiple times, or when there are
        unused parameters in the checkpointed model.

    .. note::
        To let a non-DDP model load a state dict from a DDP model,
        :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
        needs to be applied to strip the prefix "module." in the DDP state dict before loading.
    .. warning::
        Constructor, forward method, and differentiation of the output (or a
        function of the output of this module) are distributed synchronization
        points. Take that into account in case different processes might be
        executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added nor removed later.
        Same applies to buffers.

    .. warning::
        This module assumes all parameters are registered in the model of each
        distributed process in the same order. The module itself will
        conduct gradient ``allreduce`` following the reverse order of the
        registered parameters of the model. In other words, it is users'
        responsibility to ensure that each distributed process has the exact
        same model and thus the exact same parameter registration order.

    .. warning::
        This module allows parameters with non-rowmajor-contiguous strides.
        For example, your model may contain some parameters whose
        :class:`torch.memory_format` is ``torch.contiguous_format``
        and others whose format is ``torch.channels_last``. However,
        corresponding parameters in different processes must have the
        same strides.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. warning::
        If you plan on using this module with a ``nccl`` backend or a ``gloo``
        backend (that uses Infiniband), together with a DataLoader that uses
        multiple workers, please change the multiprocessing start method to
        ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
        Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
        likely experience deadlocks if you don't change this setting.

    .. warning::
        You should never try to change your model's parameters after wrapping
        up your model with ``DistributedDataParallel``. Because, when
        wrapping up your model with ``DistributedDataParallel``, the constructor
        of ``DistributedDataParallel`` will register the additional gradient
        reduction functions on all the parameters of the model itself at the
        time of construction. If you change the model's parameters afterwards,
        gradient reduction functions no longer match the correct set of
        parameters.

    .. warning::
        Using ``DistributedDataParallel`` in conjunction with the
        :ref:`distributed-rpc-framework` is experimental and subject to change.
    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices.
                   1) For single-device modules, ``device_ids`` can
                   contain exactly one device id, which represents the only
                   CUDA device where the input module corresponding to this process resides.
                   Alternatively, ``device_ids`` can also be ``None``.
                   2) For multi-device modules and CPU modules,
                   ``device_ids`` must be ``None``.
                   When ``device_ids`` is ``None`` for both cases,
                   both the input data for the forward pass and the actual module
                   must be placed on the correct device.
                   (default: ``None``)
        output_device (int or torch.device): Device location of output for
                      single-device CUDA modules. For multi-device modules and
                      CPU modules, it must be ``None``, and the module itself
                      dictates the output location. (default: ``device_ids[0]``
                      for single-device modules)
        broadcast_buffers (bool): Flag that enables syncing (broadcasting)
                          buffers of the module at beginning of the ``forward``
                          function. (default: ``True``)
        process_group: The process group to be used for distributed data
                       all-reduction. If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
        bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
                       multiple buckets so that gradient reduction of each
                       bucket can potentially overlap with backward computation.
                       :attr:`bucket_cap_mb` controls the bucket size in
                       MebiBytes (MiB). If ``None``, a default size of 25 MiB
                       will be used. (default: ``None``)
        find_unused_parameters (bool): Traverse the autograd graph from all
                               tensors contained in the return value of the
                               wrapped module's ``forward`` function. Parameters
                               that don't receive gradients as part of this
                               graph are preemptively marked as being ready to
                               be reduced. In addition, parameters that may have
                               been used in the wrapped module's ``forward``
                               function but were not part of loss computation and
                               thus would also not receive gradients are
                               preemptively marked as ready to be reduced.
                               (default: ``False``)
        check_reduction: This argument is deprecated.
        gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
                      pointing to different offsets of ``allreduce`` communication
                      buckets. This can reduce peak memory usage, where the
                      saved memory size will be equal to the total gradients
                      size. Moreover, it avoids the overhead of copying between
                      gradients and ``allreduce`` communication buckets. When
                      gradients are views, ``detach_()`` cannot be called on the
                      gradients. If hitting such errors, please fix it by
                      referring to the :meth:`~torch.optim.Optimizer.zero_grad`
                      function in ``torch/optim/optimizer.py`` as a solution.
                      Note that gradients will be views after first iteration, so
                      the peak memory saving should be checked after first iteration.
        static_graph (bool): When set to ``True``, DDP knows the trained graph is
                     static. Static graph means 1) The set of used and unused
                     parameters will not change during the whole training loop; in
                     this case, it does not matter whether users set
                     ``find_unused_parameters = True`` or not. 2) How the graph is trained
                     will not change during the whole training loop (meaning there is
                     no control flow depending on iterations).
                     When static_graph is set to be ``True``, DDP will support cases that
                     can not be supported in the past:
                     1) Reentrant backwards.
                     2) Activation checkpointing multiple times.
                     3) Activation checkpointing when model has unused parameters.
                     4) There are model parameters that are outside of forward function.
                     5) Potentially improve performance when there are unused parameters,
                     as DDP will not search graph in each iteration to detect unused
                     parameters when static_graph is set to be ``True``.
                     To check whether you can set static_graph to be ``True``, one way is to
                     check ddp logging data at the end of your previous model training,
                     if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you
                     can set ``static_graph = True`` as well.

                     Example::
                         >>> # xdoctest: +SKIP("undefined variables")
                         >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
                         >>> # Training loop
                         >>> ...
                         >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
                         >>> static_graph = ddp_logging_data.get("can_set_static_graph")
        delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
                    of named parameters whose all reduce will be delayed when the gradient of
                    the parameter specified in ``param_to_hook_all_reduce`` is ready. Other
                    arguments of DDP do not apply to named params specified in this argument
                    as these named params will be ignored by DDP reducer.
        param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
                    of parameters specified in ``delay_all_reduce_named_params``.

    Attributes:
        module (Module): the module to be parallelized.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
        >>> net = torch.nn.parallel.DistributedDataParallel(model)
    """
    # used to track whether the given thread is inside ddp forward for torchdynamo purposes
    _active_ddp_module: Optional["DistributedDataParallel"] = None

    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=None,
        find_unused_parameters=False,
        check_reduction=False,
        gradient_as_bucket_view=False,
        static_graph=False,
        delay_all_reduce_named_params=None,
        param_to_hook_all_reduce=None,
        mixed_precision: Optional[_MixedPrecision] = None,
        device_mesh=None,
    ):
        super().__init__()
        Joinable.__init__(self)
        self.logger = None
        if bool(delay_all_reduce_named_params is not None) != bool(
            param_to_hook_all_reduce is not None
        ):
            self._log_and_throw(
                ValueError,
                "delay_all_reduce_named_params and param_to_hook_all_reduce "
                "need to be set at the same time.",
            )

        if process_group and device_mesh is not None:
            raise RuntimeError(
                "Cannot specify both process_group and device_mesh arguments."
            )
        elif process_group is None and device_mesh is None:
            self.process_group = _get_default_group()
        elif device_mesh is None:
            self.process_group = process_group
        else:
            if device_mesh.ndim != 1:
                raise RuntimeError(
                    f"Only 1D device mesh is supported, but got {device_mesh}."
                )
            self.device_mesh = device_mesh
            self.process_group = device_mesh.get_group(mesh_dim=0)
            from torch.distributed.device_mesh import _mesh_resources

            if _mesh_resources.get_parent_mesh(device_mesh) is not None:
                # TODO: This is a temporary work around to enable DDP + TP.
                # We should do the logic in DDP so that the 2D implementation is
                # sound and the state_dict works out of the box.
                # This has to be done before check UninitializedParameter.
                from torch.distributed.tensor.parallel.ddp import (
                    _pre_dp_module_transform,
                )

                _pre_dp_module_transform(module)

        self._delay_all_reduce_params = []
        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
        else:
            self.parameters_to_ignore = set()
        if delay_all_reduce_named_params is not None:
            for name, param in delay_all_reduce_named_params:
                self.parameters_to_ignore.add(name)
                self._delay_all_reduce_params.append(param)

        self._module_parameters = [
            p
            for n, p in module.named_parameters()
            if n not in self.parameters_to_ignore
        ]
        if not any(p.requires_grad for p in self._module_parameters):
            if len(self._delay_all_reduce_params):
                logger.info("Delay the AllReduce of all parameters.")
            else:
                self._log_and_throw(
                    RuntimeError,
                    "DistributedDataParallel is not needed when a module "
                    "doesn't have any parameter that requires a gradient.",
                )

        if device_ids is not None and len(device_ids) > 1:
            self._log_and_throw(
                ValueError,
                "device_ids can only be None or contain a single element.",
            )

        self.is_multi_device_module = (
            len({p.device for p in self._module_parameters}) > 1
        )
        distinct_device_types = {
            p.device.type for p in self._module_parameters if p.device is not None
        }
        if len(distinct_device_types) != 1:
            self._log_and_throw(
                ValueError,
                "DistributedDataParallel's input module must be on "
                f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
            )

        self.device_type = next(iter(distinct_device_types))

        if (
            device_ids is None
            or len(device_ids) == 0  # For backward compatibility.
            or self.device_type == "cpu"
            or self.is_multi_device_module
        ):
            if device_ids or output_device:
                self._log_and_throw(
                    ValueError,
                    "DistributedDataParallel device_ids and output_device arguments "
                    "only work with single-device/multiple-device GPU modules or CPU modules, "
                    f"but got device_ids {device_ids}, output_device {output_device}, "
                    f"and module parameters {({p.device for p in self._module_parameters})}.",
                )

            self.device_ids = None
            self.output_device = None
        else:
            self.device_ids = [_get_device_index(x, True) for x in device_ids]

            if output_device is None:
                output_device = device_ids[0]

            self.output_device = _get_device_index(output_device, True)

        self.static_graph = False
        self.dim = dim
        self.module = module
        self.device = next(iter(self._module_parameters)).device
        self.broadcast_buffers = broadcast_buffers
        self.find_unused_parameters = find_unused_parameters
        self.require_backward_grad_sync = True
        self.require_forward_param_sync = True
        self.gradient_as_bucket_view = gradient_as_bucket_view
        self.mixed_precision = mixed_precision
        if self.mixed_precision is not None:
            logger.warning("Received mixed precision config %s", self.mixed_precision)

        if check_reduction:
            # This argument is no longer used since the reducer
            # will ensure reduction completes even if some parameters
            # do not receive gradients.
            warnings.warn(
                "The `check_reduction` argument in `DistributedDataParallel` "
                "module is deprecated. Please avoid using it.",
                FutureWarning,
                stacklevel=2,
            )

        # Check that a module does not have Uninitialized parameters
        for param in self._module_parameters:
            if isinstance(param, torch.nn.parameter.UninitializedParameter):
                self._log_and_throw(
                    RuntimeError,
                    "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
                    "Run a dummy forward pass to correctly initialize the modules",
                )
        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * 1024 * 1024)

        # reduction bucket size
        if bucket_cap_mb is None:
            # default case (bucket cap is 25 MiB)
            bucket_cap_mb = 25
            self.bucket_bytes_cap_default = True
        else:
            self.bucket_bytes_cap_default = False
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
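        # For example, with the default ``bucket_cap_mb=25`` this works out to
        # 25 * 1024 * 1024 = 26,214,400 bytes per reduction bucket.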
        # Whether to perform input tensor CPU to GPU copies on a side-stream
        self.use_side_stream_for_tensor_copies = (
            os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
        )

        # Initialize gradient buffers and register all reduce hook
        self._delay_grad_buffer = None
        self._delay_grad_views: List[torch.Tensor] = []
        self._delay_all_reduce_all_params = False
        if len(self._delay_all_reduce_params) != 0:
            self._register_delay_all_reduce_hook(
                bucket_cap_mb=bucket_cap_mb,
                param_to_hook_all_reduce=param_to_hook_all_reduce,
                device_ids=device_ids,
            )
            if self._delay_all_reduce_all_params:
                return

        # Build parameters for reducer.
        parameters, expect_sparse_gradient = self._build_params_for_reducer()
        # Verify model equivalence.
        _verify_param_shape_across_processes(self.process_group, parameters)
        # Sync params and buffers. Ensures all DDP models start off at the same value.
        _sync_module_states(
            module=self.module,
            process_group=self.process_group,
            broadcast_bucket_size=self.broadcast_bucket_size,
            src=0,
            params_and_buffers_to_ignore=self.parameters_to_ignore,
            broadcast_buffers=self.broadcast_buffers,
        )
        # In debug mode, build a mapping of parameter index -> parameter.
        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)

        # Builds reducer.
        self._ddp_init_helper(
            parameters,
            expect_sparse_gradient,
            param_to_name_mapping,
            static_graph,
        )
        self._comm_hooks: List[Tuple[Callable, object]] = []

        if self.mixed_precision is not None:
            _setup_mixed_precision_params(self.mixed_precision, self.module)
            _cast_buffers(self.mixed_precision, self.module)
            # Stream used for async low precision copies.
            self._mp_stream = torch.cuda.Stream()
            self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
            # Add forward pre-hook to root module to kick off copies to lower
            # precision.
            self.module.register_forward_pre_hook(
                self._root_copy_hook, prepend=False, with_kwargs=True
            )
            # Add forward pre hook to all submodules to wait for copy events
            # before running computation.
            for module in self.module.modules():
                module.register_forward_pre_hook(
                    self._module_wait_for_copy_hook,
                    prepend=False,
                    with_kwargs=True,
                )
            # Set up callbacks in backward to upcast and use full precision
            # params. TODO (rohan-varma): Make this compose with general
            # comm hooks and apply_optimizer_in_backward. Importing inline to
            # avoid circular import issue.
            from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
                _AllreduceUpcastHookState,
                _reducer_allreduce_and_upcast_hook,
            )

            upcast_hook_state = _AllreduceUpcastHookState(
                ddp_weakref=weakref.ref(self),
                upcast_stream=torch.cuda.Stream(),
            )
            self.register_comm_hook(
                upcast_hook_state,
                _reducer_allreduce_and_upcast_hook,
            )
            # Inform reducer of reduced precision param dtype for correctness
            # of type checks between gradient and bucket.
            self.reducer._set_mixed_precision_param_dtype(  # type: ignore[attr-defined]
                self.mixed_precision.param_dtype
            )

        self._has_rebuilt_buckets = False

        if static_graph:
            self._set_static_graph()

        self._lazy_init_ran = False

        # Register the AccumulateGrad post hooks if optimize_ddp is
        # True. The hooks will be deregistered if compiled_autograd is not
        # enabled.
        self._accum_grad_hooks: List[RemovableHandle] = []
        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
        self._use_python_reducer = optimize_ddp in (
            "python_reducer",
            "python_reducer_without_compiled_forward",
        )
        if self._use_python_reducer:
            torch._inductor.config._fuse_ddp_communication = True
            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
            # Directly adding this to the trace rule will disturb the users
            # who are using DDPOptimizer.
            torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
                "torch.nn.parallel.distributed"
            )
            torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
        self._force_to_disable_cpp_reducer = (
            optimize_ddp == "python_reducer_without_compiled_forward"
        )
        if self._use_python_reducer:
            self._register_accum_grad_hook()

        # Whether or not DDPSink performs a clone.
        self._ddp_sink_clone = True

    def _register_accum_grad_hook(self):
        import torch.distributed._functional_collectives as fcol

        def compiled_accum_grad_hook(
            param,
            *,
            param_index: int,
        ):
            if not self.require_backward_grad_sync:
                return

            if param.grad is None:
                return

            if self._comm_hooks:
                for hook, state in self._comm_hooks:
                    hook(state, (param.grad, param))
            else:
                gradient = param.grad / self.process_group.size()
                gradient = fcol.all_reduce(gradient, "sum", self.process_group)
                param.grad.copy_(gradient)

        for index, param in enumerate(self._module_parameters):
            if not param.requires_grad:
                continue
            self._accum_grad_hooks.append(
                param.register_post_accumulate_grad_hook(
                    functools.partial(
                        compiled_accum_grad_hook,
                        param_index=index,
                    )
                )
            )
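    # Note on the hook above: when no comm hook is registered, it reproduces
    # DDP's default gradient averaging by dividing the local gradient by the
    # process group size, summing across ranks with a functional all_reduce,
    # and copying the result back into ``param.grad``.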
    def _delayed_all_reduce_hook(self, grad):
        world_size = dist.get_world_size(self.process_group)

        self._delay_grad_buffer.div_(world_size)  # type: ignore[union-attr]
        _ = dist.all_reduce(
            self._delay_grad_buffer, group=self.process_group, async_op=True
        )
        return grad

    def _register_delay_all_reduce_hook(
        self,
        bucket_cap_mb,
        param_to_hook_all_reduce,
        device_ids,
    ):
        # 1. Create gradient buffer
        device = torch.device("cpu") if device_ids is None else device_ids[0]
        self._delay_grad_buffer = torch.zeros(
            sum(p.numel() for p in self._delay_all_reduce_params),
            device=device,
        )

        # 2. Broadcast the parameters
        detached_params = [p.detach() for p in self._delay_all_reduce_params]
        dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)

        # 3. Hook all reduce to the specified parameter
        param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)

        # 4. Build tensor views for gradients
        offset = 0
        for param in self._delay_all_reduce_params:
            grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
                param.shape
            )
            self._delay_grad_views.append(grad_view)
            offset = offset + param.numel()

        # 5. Check whether the all reduce of all params requiring grad is delayed.
        for module_name, module in self.module.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    full_name = f"{module_name}.{param_name}"
                    if full_name not in self.parameters_to_ignore:
                        # There is at least a param whose all reduce will not be delayed.
                        # In this case, we should not set self._delay_all_reduce_all_params
                        # to True.
                        return
        self._delay_all_reduce_all_params = True
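    # Illustrative sketch (assumes a model with an ``embedding`` submodule and
    # a ``head.weight`` parameter, both placeholder names): delaying the
    # all-reduce of the embedding parameters until the gradient for
    # ``head.weight`` is ready, using the constructor arguments documented
    # above.
    #
    #     ddp_model = DistributedDataParallel(
    #         model,
    #         device_ids=[rank],
    #         delay_all_reduce_named_params=[
    #             (n, p)
    #             for n, p in model.named_parameters()
    #             if n.startswith("embedding.")
    #         ],
    #         param_to_hook_all_reduce=model.head.weight,
    #     )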
    def _setup_in_backward_optimizers(self):
        # Check if user has used apply_optim_in_backward to overlap optimizer
        # step + DDP backward. Current constraints:
        # 1. Only allreduce is supported at the moment, no custom communication.
        # 2. For DDP-managed parameters that have their optimizer run in
        # backward, their gradients are set to ``None``. If your use case
        # requires DDP parameters grad not to be set to ``None`` after their
        # in-backward optimizer runs, please ping
        # https://github.com/pytorch/pytorch/issues/90052.
        # NOTE: we use self._module_parameters instead of .parameters() since
        # the former excludes ignored (non-DDP managed) parameters.
        if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
            torch._C._log_api_usage_once("ddp.optimizer_in_backward")
            # Remove hooks that apply_optim_in_backward had registered because
            # DDP customizes how optimizer is overlapped with backward due to
            # the allreduce.
            param_to_handle_map = (
                dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
            )
            for p in self._module_parameters:
                for handle in param_to_handle_map.get(p, []):
                    handle.remove()

            # Need a weakref to DDP instance to run all_reduce (from reducer)
            # and get managed DDP parameters.
            ddp_weakref = weakref.ref(self)
            # Note: importing in function, otherwise this will cause a circular
            # import.
            from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
                _apply_optim_in_backward_hook,
            )

            self.register_comm_hook(
                ddp_weakref,
                _apply_optim_in_backward_hook(
                    gradient_is_bucket_view=self.gradient_as_bucket_view
                ),
            )

            self.reducer._set_optimizer_in_backward()  # type: ignore[attr-defined]
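    # Illustrative sketch of the setup the method above detects, assuming the
    # private ``_apply_optimizer_in_backward`` helper from
    # ``torch.distributed.optim`` and its current signature: each parameter's
    # optimizer step runs inside backward, and DDP re-registers the overlap
    # comm hook accordingly when the model is wrapped.
    #
    #     from torch.distributed.optim import _apply_optimizer_in_backward
    #     _apply_optimizer_in_backward(
    #         torch.optim.SGD, model.parameters(), optimizer_kwargs={"lr": 0.01}
    #     )
    #     ddp_model = DistributedDataParallel(model, device_ids=[rank])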
    def _fire_reducer_autograd_hook(self, idx, *unused):
        """
        Fire the reducer's autograd hook to allreduce params in a Reducer bucket.

        Note that this is only used during mixed precision training as the
        Reducer's hooks installed during construction time would not be called
        as we're working in the low precision parameter setting.
        """
        self.reducer._autograd_hook(idx)  # type: ignore[attr-defined]

    def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
        """
        For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.

        When training with DDP mixed precision, this root pre-forward hook kicks
        off low precision copies on a separate stream and creates respective
        events to wait for them.
        """
        # Clear out previous iteration submodule to event. This is because we
        # may have populated some events for modules that didn't end up being
        # used.
        self._submodule_to_event = defaultdict(deque)  # type: ignore[var-annotated]
        with torch.cuda.stream(self._mp_stream):
            for submodule in self.module.modules():
                for param in submodule.parameters(recurse=False):
                    # Do not cast DDP ignored parameters.
                    if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
                        continue
                    _alloc_storage(param._mp_param, param.size())
                    # copy() implicitly casts to low precision
                    with torch.no_grad():
                        param._mp_param.copy_(param.data)
                        # TODO: when zero_grad(set_to_none=False) or in grad
                        # accumulation case, accumulated grads can be in fp32
                        # which can cause errors when running DDP backwards due
                        # to mismatched incoming and accumulated gradient types.
                        # So we manually cast the accumulated grad down for now,
                        # in the future we may shift to FSDP style gradient
                        # accumulation management where the accumulated gradient
                        # is saved and .grad field is set to None, bypassing
                        # this issue.
                        if param.grad is not None:
                            param.grad.data = param.grad.to(
                                self.mixed_precision.param_dtype  # type: ignore[union-attr]
                            )
                    param.data = param._mp_param
                copy_event = torch.cuda.Event()
                copy_event.record()
                self._submodule_to_event[submodule].append(copy_event)

    def _module_wait_for_copy_hook(
        self,
        module,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
        try:
            event = self._submodule_to_event[module].popleft()
        except IndexError:
            # copy event has already been waited on
            return

        event.wait(stream=torch.cuda.current_stream())
        for p in module.parameters(recurse=False):
            # Don't register hooks if param does not require grad
            if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
                continue
            # We need to register autograd hook here instead of DDP's ctor
            # since we're working with the low precision param. Register them
            # via obtaining the gradient accumulator.
            tmp = p.expand_as(p)
            grad_acc = tmp.grad_fn.next_functions[0][0]
            hook = grad_acc.register_hook(
                functools.partial(self._fire_reducer_autograd_hook, p._idx)
            )
            p._ddp_mp_hook_state = (grad_acc, hook)

    def _log_and_throw(self, err_type, err_msg):
        if self.logger is not None:
            self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
        raise err_type(err_msg)
    def _ddp_init_helper(
        self,
        parameters,
        expect_sparse_gradient,
        param_to_name_mapping,
        static_graph,
    ):
        """
        DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.

        Initialization helper function that does the following:
        (1) bucketing the parameters for reductions
        (2) resetting the bucketing states
        (3) registering the grad hooks
        (4) Logging construction-time DDP logging data
        (5) passing a handle of DDP to SyncBatchNorm Layer
        """
        # Note: the parameter order is not necessarily the order in which the
        # parameters are used, especially in models with control flow.
        #
        # Moreover, because parameters are not presented in their real execution
        # order, if a model also happens to
        #   1) have other collective comm ops in its backward graph, and
        #   2) have unused parameters on a subset of ranks in the world,
        # bucketing could insert an ALL-REDUCE comm op too early on the ranks with
        # unused parameters, unexpectedly matching up with other collective comm
        # ops on other ranks.
        #
        # To handle this corner case, when the parameters are not in their real
        # execution order, we don't do bucketing; only one ALL-REDUCE is then
        # inserted after all the gradients of the whole graph are computed.
        #
        # Note that we only disable bucketing for the first iteration. After the
        # first iteration it is OK to rebuild buckets, because "bucket rebuild"
        # bucketizes parameters based on their real execution order in the
        # backward graph.

        # Can remove this branching once #73732 is landed.
  1023. if static_graph is True or self.find_unused_parameters is False:
  1024. bucket_size_limits = [sys.maxsize]
  1025. else:
  1026. if self.bucket_bytes_cap_default:
  1027. bucket_size_limits = [
  1028. dist._DEFAULT_FIRST_BUCKET_BYTES,
  1029. self.bucket_bytes_cap,
  1030. ]
  1031. else:
  1032. bucket_size_limits = [self.bucket_bytes_cap]
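# Illustrative sketch of the resulting limits (assuming the default
# bucket_cap_mb=25; dist._DEFAULT_FIRST_BUCKET_BYTES is the small first-bucket
# size of roughly 1 MiB):
#
#     static_graph=True or find_unused_parameters=False:
#         bucket_size_limits == [sys.maxsize]
#         # one giant bucket for the first iteration; buckets are rebuilt
#         # afterwards based on the real execution order
#     otherwise:
#         bucket_size_limits == [dist._DEFAULT_FIRST_BUCKET_BYTES, 25 * 1024 * 1024]
#         # small first bucket, then bucket_bytes_cap for the rest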
  1033. (
  1034. bucket_indices,
  1035. per_bucket_size_limits,
  1036. ) = dist._compute_bucket_assignment_by_size(
  1037. parameters,
  1038. bucket_size_limits,
  1039. expect_sparse_gradient,
  1040. )
# Remember the index for each parameter if we are in mixed precision, as we
# need to pass the index to the Reducer's autograd hook via Python.
  1043. if self.mixed_precision is not None:
  1044. for i, p in enumerate(parameters):
  1045. p._idx = i
  1046. # Note: reverse list of buckets because we want to approximate the
  1047. # order in which their gradients are produced, and assume they
  1048. # are used in the forward pass in the order they are defined.
  1049. self.reducer = dist.Reducer(
  1050. parameters,
  1051. list(reversed(bucket_indices)),
  1052. list(reversed(per_bucket_size_limits)),
  1053. self.process_group,
  1054. expect_sparse_gradient,
  1055. # The bucket size limit is specified in the constructor.
  1056. # Additionally, we allow for a single small bucket for parameters
  1057. # that are defined first, such that their gradients don't spill into
  1058. # a much larger bucket, adding unnecessary latency after gradient
  1059. # computation finishes. Experiments showed 1MB is a reasonable value.
  1060. self.bucket_bytes_cap,
  1061. self.find_unused_parameters,
  1062. self.gradient_as_bucket_view,
  1063. param_to_name_mapping,
  1064. # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
  1065. # bucket.
  1066. dist._DEFAULT_FIRST_BUCKET_BYTES
  1067. if self.bucket_bytes_cap_default
  1068. else self.bucket_bytes_cap,
  1069. )
  1070. self.logger = dist.Logger(self.reducer)
  1071. # Set as a weak reference to avoid reference cycle between
  1072. # logger and reducer.
  1073. self.reducer.set_logger(self.logger)
  1074. has_sync_bn = False
  1075. for submodule in self.module.modules():
  1076. if isinstance(submodule, torch.nn.SyncBatchNorm):
  1077. has_sync_bn = True
  1078. break
# Set logging data that can be obtained at construction time.
  1080. self.logger.set_construction_data_and_log(
  1081. self.module.__class__.__name__,
  1082. [] if self.device_ids is None else self.device_ids,
  1083. -1 if self.output_device is None else self.output_device,
  1084. self.broadcast_buffers,
  1085. has_sync_bn,
  1086. static_graph,
  1087. )
  1088. # passing a handle to torch.nn.SyncBatchNorm layer
  1089. self._passing_sync_batchnorm_handle(self.module)
  1090. def __getstate__(self):
  1091. self._check_default_group()
  1092. attrs = copy.copy(self.__dict__)
  1093. del attrs["process_group"]
  1094. del attrs["reducer"]
  1095. del attrs["logger"]
  1096. return attrs
  1097. def __setstate__(self, state):
  1098. # If serializable, then the process group should be the default one
  1099. self.process_group = _get_default_group()
  1100. super().__setstate__(state)
  1101. self.__dict__.setdefault("require_forward_param_sync", True)
  1102. self.__dict__.setdefault("require_backward_grad_sync", True)
  1103. parameters, expect_sparse_gradient = self._build_params_for_reducer()
  1104. # In debug mode, build a mapping of parameter index -> parameter.
  1105. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
  1106. # Builds reducer.
  1107. self._ddp_init_helper(
  1108. parameters,
  1109. expect_sparse_gradient,
  1110. param_to_name_mapping,
  1111. self.static_graph,
  1112. )
  1113. if self.static_graph:
  1114. self.reducer._set_static_graph()
  1115. assert self.logger is not None
  1116. self.logger._set_static_graph()
  1117. def _build_params_for_reducer(self):
# Build a list of (module, parameter) tuples for all parameters that require grad.
  1119. modules_and_parameters = [
  1120. (module, parameter)
  1121. for module_name, module in self.module.named_modules()
  1122. for parameter in [
  1123. param
  1124. # Note that we access module.named_parameters instead of
  1125. # parameters(module). parameters(module) is only needed in the
  1126. # single-process multi device case, where it accesses replicated
  1127. # parameters through _former_parameters.
  1128. for param_name, param in module.named_parameters(recurse=False)
  1129. if param.requires_grad
  1130. and f"{module_name}.{param_name}" not in self.parameters_to_ignore
  1131. ]
  1132. ]
  1133. # Deduplicate any parameters that might be shared across child modules.
  1134. memo = set()
  1135. modules_and_parameters = [
  1136. # "p not in memo" is the deduplication check.
  1137. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
  1138. (m, p)
  1139. for m, p in modules_and_parameters
  1140. if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
  1141. ]
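# Example (illustrative): the "p not in memo and not memo.add(p)" idiom above
# is an order-preserving dedup; on plain Python values it behaves like:
#
#     items = ["a", "b", "a", "c", "b"]
#     memo = set()
#     deduped = [x for x in items if x not in memo and not memo.add(x)]
#     # deduped == ["a", "b", "c"]  (set.add returns None, so "not ..." is True)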
  1142. # Build list of parameters.
  1143. parameters = [parameter for _, parameter in modules_and_parameters]
  1144. # Checks if a module will produce a sparse gradient.
  1145. def produces_sparse_gradient(module):
  1146. if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
  1147. return module.sparse
  1148. return False
  1149. # Build list of booleans indicating whether or not to expect sparse
  1150. # gradients for the corresponding parameters.
  1151. expect_sparse_gradient = [
  1152. produces_sparse_gradient(module) for module, _ in modules_and_parameters
  1153. ]
  1154. self._assign_modules_buffers()
  1155. return parameters, expect_sparse_gradient
  1156. def _assign_modules_buffers(self):
  1157. """
  1158. Assign self.module.named_buffers to self.modules_buffers.
Assigns module buffers to self.modules_buffers, which are then used to
broadcast across ranks when broadcast_buffers=True. Note that this
must be called every time buffers need to be synced, because buffers can
be reassigned by the user's module;
see https://github.com/pytorch/pytorch/issues/63916.
  1164. """
  1165. # Collect buffers for modules, filtering out buffers that should be ignored.
  1166. named_module_buffers = [
  1167. (buffer, buffer_name)
  1168. for buffer_name, buffer in self.module.named_buffers()
  1169. if buffer_name not in self.parameters_to_ignore
  1170. ]
  1171. self.modules_buffers = [
  1172. buffer for (buffer, buffer_name) in named_module_buffers
  1173. ]
  1174. # Dict[str, tensor] representing module buffers not ignored by DDP.
  1175. self.named_module_buffers = {
  1176. buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
  1177. }
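# Example (illustrative): a user module may rebind a buffer to a brand-new
# tensor between iterations, which is why buffers are re-collected before
# every sync instead of being cached once at construction time:
#
#     class Counter(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.register_buffer("count", torch.zeros(1))
#         def forward(self, x):
#             self.count = self.count + 1  # rebinds "count" to a new tensor
#             return x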
  1178. def _build_debug_param_to_name_mapping(self, parameters):
  1179. param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
  1180. param_set = set(parameters)
  1181. param_index_to_param_fqn = {}
  1182. for module_name, module in self.module.named_modules():
  1183. for param_name, param in module.named_parameters(recurse=False):
  1184. fqn = f"{module_name}.{param_name}"
  1185. # Bypass ignored parameters since those are not reduced by DDP
  1186. # to begin with.
  1187. if fqn not in self.parameters_to_ignore and param.requires_grad:
  1188. if param not in param_set:
  1189. self._log_and_throw(
  1190. ValueError,
  1191. f"Param with name {fqn} found in module parameters, but not DDP parameters."
  1192. " This indicates a bug in DDP, please report an issue to PyTorch.",
  1193. )
  1194. param_index = param_to_param_index[param]
  1195. param_index_to_param_fqn[param_index] = fqn
  1196. # Ensure we covered all parameters
  1197. if len(param_set) != len(param_index_to_param_fqn):
  1198. self._log_and_throw(
  1199. ValueError,
  1200. (
  1201. "Expected param to name mapping to cover all parameters, but"
  1202. f" got conflicting lengths: {len(param_set)} vs "
  1203. f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
  1204. ", please report an issue to PyTorch."
  1205. ),
  1206. )
  1207. return param_index_to_param_fqn
  1208. def _get_parameters(self, m, recurse=True):
  1209. """Return a generator of module parameters."""
  1210. def model_parameters(m):
  1211. ps = (
  1212. m._former_parameters.values()
  1213. if hasattr(m, "_former_parameters")
  1214. else m.parameters(recurse=False)
  1215. )
  1216. yield from ps
  1217. for mod in m.modules() if recurse else [m]:
  1218. yield from model_parameters(mod)
  1219. def _check_default_group(self):
  1220. pickle_not_supported = False
  1221. try:
  1222. if self.process_group != _get_default_group():
  1223. pickle_not_supported = True
  1224. except RuntimeError:
  1225. pickle_not_supported = True
  1226. if pickle_not_supported:
  1227. self._log_and_throw(
  1228. RuntimeError,
  1229. "DDP Pickling/Unpickling are only supported "
  1230. "when using DDP with the default process "
  1231. "group. That is, when you have called "
  1232. "init_process_group and have not passed "
  1233. "process_group argument to DDP constructor",
  1234. )
  1235. @contextmanager
  1236. def no_sync(self):
  1237. r"""
  1238. Context manager to disable gradient synchronizations across DDP processes.
  1239. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
  1242. Example::
  1243. >>> # xdoctest: +SKIP("undefined variables")
  1244. >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
  1245. >>> with ddp.no_sync():
  1246. >>> for input in inputs:
  1247. >>> ddp(input).backward() # no synchronization, accumulate grads
  1248. >>> ddp(another_input).backward() # synchronize grads
  1249. .. warning::
  1250. The forward pass should be included inside the context manager, or
  1251. else gradients will still be synchronized.
  1252. """
  1253. old_require_backward_grad_sync = self.require_backward_grad_sync
  1254. self.require_backward_grad_sync = False
  1255. try:
  1256. yield
  1257. finally:
  1258. self.require_backward_grad_sync = old_require_backward_grad_sync
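# Example (illustrative sketch): a common gradient-accumulation pattern built
# on top of no_sync. ``ddp``, ``batches`` and ``optimizer`` are assumed to be
# set up by the caller and ``contextlib`` to be imported; only every 4th step
# pays for an allreduce:
#
#     for step, batch in enumerate(batches):
#         sync_now = (step + 1) % 4 == 0
#         ctx = contextlib.nullcontext() if sync_now else ddp.no_sync()
#         with ctx:
#             ddp(batch).sum().backward()
#         if sync_now:
#             optimizer.step()
#             optimizer.zero_grad()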
  1259. @classmethod
  1260. def _get_active_ddp_module(cls):
  1261. """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
  1262. return cls._active_ddp_module
  1263. # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
  1264. # for the 'module_to_run' underneath
  1265. # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
  1266. @contextmanager
  1267. @torch._disable_dynamo(recursive=False)
  1268. def _inside_ddp_forward(self):
  1269. DistributedDataParallel._active_ddp_module = self
  1270. try:
  1271. yield
  1272. finally:
  1273. DistributedDataParallel._active_ddp_module = None
  1274. def _run_ddp_forward(self, *inputs, **kwargs):
  1275. if self._use_python_reducer:
  1276. return self.module(*inputs, **kwargs) # type: ignore[index]
  1277. else:
  1278. with self._inside_ddp_forward():
  1279. return self.module(*inputs, **kwargs) # type: ignore[index]
  1280. def _clear_grad_buffer(self):
# Making param.grad point to the grad buffers before backward is based on the
# assumption that gradient accumulation is done in place in the autograd engine.
# In some edge cases, if gradient accumulation in the autograd engine is not in
# place, then param.grad and the grad buffers become detached.
  1285. if self._delay_grad_buffer is not None:
  1286. # We batch zero_grad for all params by resetting the whole grad
  1287. # buffer when the grad of all params is set to None.
  1288. all_param_grad_none = all(
  1289. param.grad is None for param in self._delay_all_reduce_params
  1290. )
  1291. for index, param in enumerate(self._delay_all_reduce_params):
  1292. if param.grad is None:
  1293. param.grad = self._delay_grad_views[index]
  1294. if not all_param_grad_none:
  1295. param.grad.zero_()
  1296. if all_param_grad_none:
  1297. self._delay_grad_buffer.zero_()
  1298. def _lazy_init(self):
  1299. # Initialization for DDP that occurs after construction, but lazily
  1300. # before the first forward pass.
  1301. self._setup_in_backward_optimizers()
  1302. self._lazy_init_ran = True
  1303. def _should_disable_cpp_reducer(self) -> bool:
  1304. return self._use_python_reducer and (
  1305. torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
  1306. )
  1307. def _pre_forward(self, *inputs, **kwargs):
  1308. if self._should_disable_cpp_reducer():
  1309. return inputs, kwargs
  1310. # Disable the python reducer if compiled_autograd is not enabled.
  1311. if self._accum_grad_hooks:
  1312. for index, h in enumerate(self._accum_grad_hooks):
  1313. h.remove()
  1314. self._accum_grad_hooks.clear()
  1315. if not self._lazy_init_ran and not torch._utils.is_compiling():
  1316. self._lazy_init()
  1317. if self._delay_all_reduce_all_params:
  1318. return inputs, kwargs
  1319. if torch.is_grad_enabled() and self.require_backward_grad_sync:
  1320. assert self.logger is not None
  1321. self.logger.set_runtime_stats_and_log()
  1322. self.reducer.prepare_for_forward()
  1323. # Notify the join context that this process has not joined, if
  1324. # needed
  1325. work = Join.notify_join_context(self)
  1326. if work:
  1327. self.reducer._set_forward_pass_work_handle(
  1328. work, self._divide_by_initial_world_size # type: ignore[arg-type]
  1329. )
# Calling _rebuild_buckets before forward computation may allocate
# new buckets before deallocating old buckets
# inside _rebuild_buckets. To save peak memory usage,
# call _rebuild_buckets before the peak memory usage increases
# during forward computation.
# This should be called only once during the whole training period.
  1336. if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
  1337. logger.info("Reducer buckets have been rebuilt in this iteration.")
  1338. self._has_rebuilt_buckets = True
# Sync buffers at the location (before/after forward) the user
# specified as part of the hook, if a hook was specified.
  1341. if self._check_sync_bufs_pre_fwd():
  1342. self._sync_buffers()
  1343. if self._join_config.enable:
  1344. # Notify joined ranks whether they should sync in backwards pass or not.
  1345. self._check_global_requires_backward_grad_sync(is_joined_rank=False)
  1346. if self.device_ids:
  1347. moved_inputs, moved_kwargs = _to_kwargs(
  1348. inputs,
  1349. kwargs,
  1350. torch.device(self.device_type, self.device_ids[0]),
  1351. self.use_side_stream_for_tensor_copies,
  1352. )
  1353. args, kwargs = moved_inputs[0], moved_kwargs[0]
  1354. # Cast inputs to reduced precision if needed.
  1355. if self.mixed_precision is not None:
  1356. args, kwargs = _cast_forward_inputs(
  1357. self.mixed_precision.param_dtype,
  1358. *args,
  1359. **kwargs,
  1360. )
  1361. return args, kwargs
  1362. else:
  1363. # Cast inputs to reduced precision if needed.
  1364. # TODO (rohan-varma) test this codepath.
  1365. if self.mixed_precision is not None:
  1366. inputs, kwargs = _cast_forward_inputs(
  1367. self.mixed_precision.param_dtype,
  1368. *inputs,
  1369. **kwargs,
  1370. )
  1371. return inputs, kwargs
  1372. def _post_forward(self, output):
  1373. if self._should_disable_cpp_reducer():
  1374. return output
  1375. if self._delay_all_reduce_all_params:
  1376. self._clear_grad_buffer()
  1377. return output
# Sync buffers at the location (before/after forward) the user
# specified as part of the hook, if a hook was specified.
  1380. if self._check_sync_bufs_post_fwd():
  1381. self._sync_buffers()
  1382. if torch.is_grad_enabled() and self.require_backward_grad_sync:
  1383. self.require_forward_param_sync = True
  1384. # We'll return the output object verbatim since it is a freeform
  1385. # object. We need to find any tensors in this object, though,
  1386. # because we need to figure out which parameters were used during
  1387. # this forward pass, to ensure we short circuit reduction for any
  1388. # unused parameters. Only if `find_unused_parameters` is set.
  1389. if self.find_unused_parameters and not self.static_graph:
  1390. # Do not need to populate this for static graph.
  1391. self.reducer.prepare_for_backward(list(_find_tensors(output)))
  1392. else:
  1393. self.reducer.prepare_for_backward([])
  1394. else:
  1395. self.require_forward_param_sync = False
  1396. # TODO: DDPSink is currently enabled for unused parameter detection and
  1397. # static graph training for first iteration.
  1398. if (self.find_unused_parameters and not self.static_graph) or (
  1399. self.static_graph and not self._static_graph_delay_allreduce_enqueued
  1400. ):
  1401. (
  1402. output_tensor_list,
  1403. treespec,
  1404. output_is_rref,
  1405. ) = _tree_flatten_with_rref(output)
  1406. output_placeholders = [None for _ in range(len(output_tensor_list))]
  1407. # Do not touch tensors that have no grad_fn, which can cause issues
  1408. # such as https://github.com/pytorch/pytorch/issues/60733
  1409. for i, output in enumerate(output_tensor_list):
  1410. if torch.is_tensor(output) and output.grad_fn is None:
  1411. output_placeholders[i] = output
  1412. # When find_unused_parameters=True, makes tensors which require grad
  1413. # run through the DDPSink backward pass. When not all outputs are
  1414. # used in loss, this makes those corresponding tensors receive
  1415. # undefined gradient which the reducer then handles to ensure
  1416. # param.grad field is not touched and we don't error out.
  1417. passthrough_tensor_list = _DDPSink.apply(
  1418. weakref.ref(self),
  1419. *output_tensor_list,
  1420. )
  1421. for i in range(len(output_placeholders)):
  1422. if output_placeholders[i] is None:
  1423. output_placeholders[i] = passthrough_tensor_list[i]
  1424. # Reconstruct output data structure.
  1425. output = _tree_unflatten_with_rref(
  1426. output_placeholders, treespec, output_is_rref
  1427. )
  1428. # At the end of the forward pass, reset the grad buffer and grad views
  1429. self._clear_grad_buffer()
  1430. return output
  1431. def forward(self, *inputs, **kwargs):
  1432. with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
  1433. inputs, kwargs = self._pre_forward(*inputs, **kwargs)
  1434. output = (
  1435. self.module.forward(*inputs, **kwargs)
  1436. if self._delay_all_reduce_all_params
  1437. else self._run_ddp_forward(*inputs, **kwargs)
  1438. )
  1439. return self._post_forward(output)
  1440. def scatter(self, inputs, kwargs, device_ids):
  1441. return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  1442. def to_kwargs(self, inputs, kwargs, device_id):
  1443. # Kept for BC
  1444. return _to_kwargs(
  1445. inputs,
  1446. kwargs,
  1447. torch.device(self.device_type, device_id),
  1448. self.use_side_stream_for_tensor_copies,
  1449. )
  1450. def gather(self, outputs, output_device):
  1451. return gather(outputs, output_device, dim=self.dim)
  1452. def train(self, mode=True):
  1453. super().train(mode)
  1454. return self
  1455. # When running in join mode, schedules an allreduce to notify joined ranks
  1456. # of whether backwards pass synchronization will run this iteration or not.
  1457. def _check_global_requires_backward_grad_sync(self, is_joined_rank):
  1458. if not is_joined_rank and self.require_backward_grad_sync:
  1459. requires_sync_tensor = torch.ones(1, device=self.device)
  1460. else:
  1461. requires_sync_tensor = torch.zeros(1, device=self.device)
  1462. work = dist.all_reduce(
  1463. requires_sync_tensor, group=self.process_group, async_op=True
  1464. )
  1465. # (kwen2501) This if condition is a plain translation of previous
  1466. # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()`
  1467. # is not called and it doesn't care about the result. I am guessing
  1468. # that it just wants to fire a matching all-reduce and does not want
  1469. # the main stream to wait.
  1470. if is_joined_rank:
  1471. work.wait()
  1472. should_sync_backwards = requires_sync_tensor.item() != 0
  1473. return should_sync_backwards
  1474. else:
  1475. return None # Return value is not/should not be used.
  1476. # When running in join mode, checks and performs sync of module buffers if
  1477. # the models have buffers that should be synchronized in the forward pass.
  1478. def _check_and_sync_module_buffers(self):
  1479. if self._check_sync_bufs_pre_fwd():
  1480. authoritative_rank = self._find_common_rank(self._distributed_rank, False)
  1481. self._sync_module_buffers(authoritative_rank)
# When running in join mode, agrees upon a common rank and broadcasts model
# parameters to all other ranks.
  1484. def _sync_final_model(self, is_last_joiner):
  1485. # Agree upon the process that will be the authoritative model copy.
  1486. # The current rank is a candidate for being the authoritative copy if
  1487. # is_last_joiner=True. We break ties via picking the larger rank.
  1488. self._authoritative_rank = self._find_common_rank(
  1489. self._distributed_rank, is_last_joiner
  1490. )
  1491. _sync_module_states(
  1492. module=self.module,
  1493. process_group=self.process_group,
  1494. broadcast_bucket_size=self.broadcast_bucket_size,
  1495. src=self._authoritative_rank,
  1496. params_and_buffers_to_ignore=self.parameters_to_ignore,
  1497. broadcast_buffers=self.broadcast_buffers,
  1498. )
  1499. # Schedule comm ops to match those scheduled in the reducer's backward
  1500. # pass.
  1501. def _match_all_reduce_for_bwd_pass(self):
  1502. comm_work = []
  1503. # Schedule comm in the same order as Reducer schedules them, i.e.
  1504. # the order of the buckets. Retrieving the bucket order from the reducer
  1505. # ensures that we keep the same order in join mode, such as when bucket
  1506. # order is rebuilt dynamically.
  1507. # Returns grad_buckets in order, but real tensors are substituted with
  1508. # zero tensors of the same shape.
  1509. grad_buckets = self.reducer._get_zeros_like_grad_buckets()
  1510. for grad_bucket in grad_buckets:
  1511. # Joined processes contribute zero gradient. In the case that
  1512. # divide_by_initial_world_size=True, we divide grads by the static
  1513. # world size, if not, the dividing factor is reduced by the number
  1514. # of joined processes.
  1515. work = self.reducer._run_comm_hook(grad_bucket)
  1516. comm_work.append(work)
  1517. for work in comm_work:
  1518. work.wait()
  1519. # Allreduces the used parameter mapping across ranks.
  1520. def _match_unused_params_allreduce(self):
  1521. locally_used_param_map = self.reducer._get_local_used_map()
  1522. self.process_group.allreduce(locally_used_param_map)
  1523. def join(
  1524. self,
  1525. divide_by_initial_world_size: bool = True,
  1526. enable: bool = True,
  1527. throw_on_early_termination: bool = False,
  1528. ):
  1529. r"""
  1530. Context manager for training with uneven inputs across processes in DDP.
  1531. This context manager will keep track of already-joined DDP processes,
  1532. and "shadow" the forward and backward passes by inserting collective
  1533. communication operations to match with the ones created by non-joined
  1534. DDP processes. This will ensure each collective call has a corresponding
  1535. call by already-joined DDP processes, preventing hangs or errors that
  1536. would otherwise happen when training with uneven inputs across
  1537. processes. Alternatively, if the flag ``throw_on_early_termination`` is
  1538. specified to be ``True``, all trainers will throw an error once one rank
  1539. runs out of inputs, allowing these errors to be caught and handled
  1540. according to application logic.
  1541. Once all DDP processes have joined, the context manager will broadcast
  1542. the model corresponding to the last joined process to all processes to
  1543. ensure the model is the same across all processes
  1544. (which is guaranteed by DDP).
  1545. To use this to enable training with uneven inputs across processes,
  1546. simply wrap this context manager around your training loop. No further
modifications to the model or data loading are required.
  1548. .. warning::
  1549. If the model or training loop this context manager is wrapped around
  1550. has additional distributed collective operations, such as
  1551. ``SyncBatchNorm`` in the model's forward pass, then the flag
  1552. ``throw_on_early_termination`` must be enabled. This is because this
  1553. context manager is not aware of non-DDP collective communication.
  1554. This flag will cause all ranks to throw when any one rank
  1555. exhausts inputs, allowing these errors to be caught and recovered
  1556. from across all ranks.
  1557. Args:
  1558. divide_by_initial_world_size (bool): If ``True``, will divide
  1559. gradients by the initial ``world_size`` DDP training was launched
  1560. with. If ``False``, will compute the effective world size
  1561. (number of ranks that have not depleted their inputs yet) and
  1562. divide gradients by that during allreduce. Set
  1563. ``divide_by_initial_world_size=True`` to ensure every input
sample, including the uneven inputs, has equal weight in terms of
how much it contributes to the global gradient. This is
  1566. achieved by always dividing the gradient by the initial
  1567. ``world_size`` even when we encounter uneven inputs. If you set
  1568. this to ``False``, we divide the gradient by the remaining
  1569. number of nodes. This ensures parity with training on a smaller
  1570. ``world_size`` although it also means the uneven inputs would
  1571. contribute more towards the global gradient. Typically, you
  1572. would want to set this to ``True`` for cases where the last few
  1573. inputs of your training job are uneven. In extreme cases, where
  1574. there is a large discrepancy in the number of inputs, setting
  1575. this to ``False`` might provide better results.
  1576. enable (bool): Whether to enable uneven input detection or not. Pass
  1577. in ``enable=False`` to disable in cases where you know that
  1578. inputs are even across participating processes. Default is
  1579. ``True``.
  1580. throw_on_early_termination (bool): Whether to throw an error
  1581. or continue training when at least one rank has exhausted
  1582. inputs. If ``True``, will throw upon the first rank reaching end
  1583. of data. If ``False``, will continue training with a smaller
  1584. effective world size until all ranks are joined. Note that if
  1585. this flag is specified, then the flag
  1586. ``divide_by_initial_world_size`` would be ignored. Default
  1587. is ``False``.
  1588. Example::
  1589. >>> # xdoctest: +SKIP("Distributed")
  1590. >>> import torch
  1591. >>> import torch.distributed as dist
  1592. >>> import os
  1593. >>> import torch.multiprocessing as mp
  1594. >>> import torch.nn as nn
  1595. >>> # On each spawned worker
  1596. >>> def worker(rank):
  1597. >>> dist.init_process_group("nccl", rank=rank, world_size=2)
  1598. >>> torch.cuda.set_device(rank)
  1599. >>> model = nn.Linear(1, 1, bias=False).to(rank)
  1600. >>> model = torch.nn.parallel.DistributedDataParallel(
  1601. >>> model, device_ids=[rank], output_device=rank
  1602. >>> )
  1603. >>> # Rank 1 gets one more input than rank 0.
  1604. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
  1605. >>> with model.join():
  1606. >>> for _ in range(5):
  1607. >>> for inp in inputs:
  1608. >>> loss = model(inp).sum()
  1609. >>> loss.backward()
  1610. >>> # Without the join() API, the below synchronization will hang
  1611. >>> # blocking for rank 1's allreduce to complete.
  1612. >>> torch.cuda.synchronize(device=rank)
  1613. """
  1614. return Join(
  1615. [self],
  1616. enable,
  1617. throw_on_early_termination,
  1618. divide_by_initial_world_size=divide_by_initial_world_size,
  1619. )
  1620. def join_hook(
  1621. self,
  1622. **kwargs,
  1623. ):
  1624. r"""
  1625. DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.
  1626. Arguments:
  1627. kwargs (dict): a :class:`dict` containing any keyword arguments
  1628. to modify the behavior of the join hook at run time; all
  1629. :class:`Joinable` instances sharing the same join context
  1630. manager are forwarded the same value for ``kwargs``.
  1631. The hook supports the following keyword arguments:
  1632. divide_by_initial_world_size (bool, optional):
  1633. If ``True``, then gradients are divided by the initial world
  1634. size that DDP was launched with.
  1635. If ``False``, then gradients are divided by the effective world
  1636. size (i.e. the number of non-joined processes), meaning that
  1637. the uneven inputs contribute more toward the global gradient.
  1638. Typically, this should be set to ``True`` if the degree of
  1639. unevenness is small but can be set to ``False`` in extreme
  1640. cases for possibly better results.
  1641. Default is ``True``.
  1642. """
  1643. divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
  1644. return _DDPJoinHook(
  1645. self, divide_by_initial_world_size=divide_by_initial_world_size
  1646. )
  1647. @property
  1648. def join_device(self):
  1649. return self.device
  1650. @property
  1651. def join_process_group(self):
  1652. return self.process_group
  1653. def _register_buffer_comm_hook(
  1654. self,
  1655. state,
  1656. hook: Callable,
  1657. comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
  1658. ):
  1659. r"""
Allow custom registration of hooks that define how buffers are synchronized across ranks.
The hook takes in an optional state and is passed a Dict[str, Tensor]
mapping buffer names to the buffers, and can run arbitrary reductions
on the buffers as opposed to DDP's default broadcast from rank 0. This is useful,
for example, if a counter needs to be summed or averaged across ranks every iteration.
  1665. Args:
  1666. state (Any): Optional state that is passed to the hook.
  1667. hook (Callable): Callable with the following signature:
  1668. ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``
  1669. comm_hook_location (_BufferCommHookLocation): Enum value indicating
  1670. where to run the hook.
  1671. _BufferCommHookLocation.PRE_FORWARD means that the
  1672. hook will run _before_ the forward pass, and
  1673. _BufferCommHookLocation.POST_FORWARD means that the
  1674. hook will run _after_ the forward pass.
  1675. NOTE: To maximize performance, users can return a
  1676. List[torch.futures.Future] from their hook, and DDP will
  1677. install and await these hooks appropriately at the end of
  1678. the backward pass. This will ensure all buffers are
  1679. synchronized by the end of the backward pass. If this
  1680. setting is used, it is recommended to pass
  1681. comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
  1682. which will trigger the hook after the forward pass.
  1683. If _BufferCommHookLocation.PRE_FORWARD is used, users must
  1684. ensure appropriate synchronization when manipulating GPU
  1685. buffers in the forward pass.
  1686. """
  1687. assert callable(hook)
  1688. self.buffer_hook = _BufferCommHook(
  1689. buffer_comm_hook=hook,
  1690. buffer_comm_hook_state=state,
  1691. buffer_comm_hook_location=comm_hook_location,
  1692. )
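# Example (illustrative sketch, assuming a backend that supports
# ``Work.get_future``, e.g. NCCL): a buffer hook that averages every buffer
# across ranks instead of DDP's default rank-0 broadcast. ``pg`` (the process
# group used as the hook state) and ``ddp`` are assumed to be set up by the caller:
#
#     def allreduce_buffers_hook(pg, named_buffers):
#         futs = []
#         world_size = dist.get_world_size(pg)
#         for name, buf in named_buffers.items():
#             buf.div_(world_size)
#             futs.append(dist.all_reduce(buf, group=pg, async_op=True).get_future())
#         return futs
#
#     ddp._register_buffer_comm_hook(state=pg, hook=allreduce_buffers_hook)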
  1693. def register_comm_hook(self, state: object, hook: Callable):
  1694. r"""
Register a communication hook for user-defined DDP aggregation of gradients across multiple workers.
This hook can be very useful for researchers trying out new ideas. For
example, it can be used to implement several algorithms, like GossipGrad
and gradient compression, which involve different communication strategies for
parameter syncs while running DistributedDataParallel training.
  1700. Args:
  1701. state (object): Passed to the hook to maintain any state information during the training process.
  1702. Examples include error feedback in gradient compression,
  1703. peers to communicate with next in GossipGrad, etc.
  1704. It is locally stored by each worker
  1705. and shared by all the gradient tensors on the worker.
  1706. hook (Callable): Callable with the following signature:
  1707. ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
  1708. This function is called once the bucket is ready. The
  1709. hook can perform whatever processing is needed and return
  1710. a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it still
must return a completed Future. The Future should hold the
new value of the grad bucket's tensors. Once a bucket is ready,
the c10d reducer calls this hook, takes the tensors returned
by the Future, and copies the grads to the individual parameters.
Note that the future's return type must be a single tensor.
  1717. We also provide an API called ``get_future`` to retrieve a
  1718. Future associated with the completion of ``c10d.ProcessGroup.Work``.
  1719. ``get_future`` is currently supported for NCCL and also supported for most
  1720. operations on GLOO and MPI, except for peer to peer operations (send/recv).
  1721. .. warning ::
Grad bucket's tensors will not be predivided by world_size. The user is
responsible for dividing by the world_size in the case of operations like allreduce.
  1724. .. warning ::
  1725. DDP communication hook can only be registered once and should be registered
  1726. before calling backward.
  1727. .. warning ::
The Future object that the hook returns should contain a single tensor
that has the same shape as the tensors inside the grad bucket.
  1730. .. warning ::
  1731. ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
  1732. for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
  1733. Example::
  1734. Below is an example of a noop hook that returns the same tensor.
  1735. >>> # xdoctest: +SKIP('undefined name')
  1736. >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
  1737. >>> fut = torch.futures.Future()
  1738. >>> fut.set_result(bucket.buffer())
  1739. >>> return fut
  1740. >>> ddp.register_comm_hook(state=None, hook=noop)
  1741. Example::
  1742. Below is an example of a Parallel SGD algorithm where gradients are encoded before
  1743. allreduce, and then decoded after allreduce.
  1744. >>> # xdoctest: +SKIP('undefined name')
  1745. >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
  1746. >>> encoded_tensor = encode(bucket.buffer()) # encode gradients
>>> fut = torch.distributed.all_reduce(encoded_tensor, async_op=True).get_future()
>>> # Define the then callback to decode.
>>> def decode_callback(fut):
>>>     decoded_tensor = decode(fut.value()[0])  # decode gradients
>>>     return decoded_tensor
>>> return fut.then(decode_callback)
  1753. >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
  1754. """
  1755. self._check_comm_hook(hook)
  1756. assert self.logger is not None
  1757. self.logger._set_comm_hook_name(hook.__qualname__)
  1758. self._comm_hooks.append((hook, state))
  1759. dist._register_comm_hook(self.reducer, state, hook)
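# Example (illustrative usage): registering one of the built-in Python hooks,
# e.g. FP16 compression; ``ddp`` and the default process group are assumed to
# be initialized already:
#
#     from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
#     ddp.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)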
  1760. def _register_builtin_comm_hook(self, comm_hook_type):
  1761. r"""
  1762. Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.
  1763. The built-in hooks aim to provide efficient C++ implementations for certain hooks,
  1764. which might not be as efficient if implemented in Python using a Python communication hook.
  1765. Args:
  1766. comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.
  1767. .. warning ::
  1768. DDP communication hook can only be registered once and should be registered
  1769. before calling backward.
  1770. Example::
  1771. Below is an example of a FP16 compression where gradients are
  1772. compressed into 16-bit floating-point numbers before allreduce, and
  1773. then decompressed after allreduce.
  1774. >>> # xdoctest: +SKIP('undefined name')
  1775. >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
  1776. """
  1777. assert self.logger is not None
  1778. self.logger._set_comm_hook_name(str(comm_hook_type))
  1779. dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
  1780. def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
  1781. r"""
  1782. Register an optimizer in DDP to optimize parameter immediately after its gradient reduction.
  1783. Registers an optimizer with DDP such that the optimization for a
  1784. parameter will run immediately when that parameter's gradient is
  1785. finished with reduction, instead of waiting for all parameters'
  1786. gradients to finish reduction. This can result in a training speedup
  1787. depending on your workload since the optimizer can run while gradient
reduction for other parameters is still ongoing. In addition, this has
  1789. the potential to reduce peak memory consumption during training, as it
  1790. only needs to load the per-parameter optimizer states of a single
  1791. parameter at a time, instead of loading all per-parameter optimizer
  1792. states at once.
  1793. Args:
  1794. optim (Type): a ``torch.optim.Optimizer`` class to be registered
  1795. as a fused optimizer.
  1796. *args (Sequence[Any]): Arguments to forward to `optim`.
  1797. optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
  1798. to optimize, similar to `params` argument of traditional `torch.optim`
  1799. Optimizers. If this is omitted, all DDP model parameters will be
  1800. optimized.
  1801. **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`.
  1802. .. warning ::
  1803. _register_fused_optim should only be called once on a DDP instance,
  1804. and registering multiple fused optimizers for the same DDP model
  1805. is not currently supported. Please ping
  1806. https://github.com/pytorch/pytorch/issues/71595 if this is necessary
  1807. for your use case.
  1808. .. warning ::
  1809. _register_fused_optim and register_comm_hook currently do not
  1810. compose together, meaning that custom DDP communication hooks are
  1811. not supported with overlapped optimizers. Please ping
  1812. https://github.com/pytorch/pytorch/issues/71595 if this is necessary
  1813. for your use case.
  1814. .. warning ::
  1815. Gradient accumulation and DDP `no_sync` are currently not supported
  1816. with overlapped optimizer. Please ping
  1817. https://github.com/pytorch/pytorch/issues/71595 if this is necessary
  1818. for your use case.
  1819. Example::
  1820. >>> # xdoctest: +SKIP("No rendezvous handler")
  1821. >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
  1822. >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
  1823. >>> lr = 1e-2
  1824. >>> betas = (0.9, 0.99)
  1825. >>> eps = 1e-6
  1826. >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
  1827. >>> # Example with subset of parameters
  1828. >>> params_to_opt = [list(net.parameters())[0]]
  1829. >>> net._register_fused_optim(
  1830. ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
  1831. ... )
  1832. """
  1833. # Note: importing in function, otherwise this will cause a circular
  1834. # import as optimizer_overlap module needs to import DistributedDataParallel.
  1835. from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim
  1836. overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
  1837. try:
  1838. overlapped_optim.register_ddp(self)
  1839. except NotImplementedError as e:
  1840. raise RuntimeError(
  1841. f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
  1842. ) from e
  1843. def _distributed_broadcast_coalesced(
  1844. self, tensors, buffer_size, authoritative_rank=0
  1845. ):
  1846. dist._broadcast_coalesced(
  1847. self.process_group, tensors, buffer_size, authoritative_rank
  1848. )
  1849. def _check_sync_bufs_post_fwd(self):
  1850. return (
  1851. self.will_sync_module_buffers()
  1852. and hasattr(self, "buffer_hook")
  1853. and self.buffer_hook.buffer_comm_hook_location
  1854. == _BufferCommHookLocation.POST_FORWARD
  1855. )
  1856. def _check_sync_bufs_pre_fwd(self):
  1857. return self.will_sync_module_buffers() and (
  1858. not hasattr(self, "buffer_hook")
  1859. or self.buffer_hook.buffer_comm_hook_location
  1860. == _BufferCommHookLocation.PRE_FORWARD
  1861. )
  1862. def will_sync_module_buffers(self):
  1863. return (
  1864. self.require_forward_param_sync
  1865. and self.broadcast_buffers
  1866. and len(self.modules_buffers) > 0
  1867. )
  1868. def _find_common_rank(self, input_rank, rank_cond):
  1869. # -1 indicates that this rank is not under consideration to be the
  1870. # common_rank
  1871. rank_to_use = torch.tensor(
  1872. [input_rank if rank_cond else -1],
  1873. device=self.device,
  1874. )
  1875. dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
  1876. if rank_to_use.item() == -1:
  1877. self._log_and_throw(
  1878. ValueError,
  1879. "BUG! Expected rank_cond to be true for at least one process."
  1880. " This indicates a bug in PyTorch, please report an issue.",
  1881. )
  1882. return rank_to_use.item()
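# Worked example (illustrative): with world_size == 3 and rank_cond True only
# on ranks 1 and 2, the per-rank contributions are
#
#     rank 0: tensor([-1])   (not a candidate)
#     rank 1: tensor([ 1])
#     rank 2: tensor([ 2])
#
# and the MAX allreduce leaves tensor([2]) on every rank, so all ranks agree
# that rank 2 is the common rank.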
  1883. def _sync_buffers(self):
  1884. with torch.no_grad():
  1885. # module buffer sync
  1886. # Synchronize buffers across processes.
  1887. # If we are running DDP with the join manager, we have to agree
  1888. # upon a rank to sync module buffers from, since rank 0 may
  1889. # already have been joined and have stale module buffers.
  1890. if self._join_config.enable:
  1891. authoritative_rank = self._find_common_rank(
  1892. self._distributed_rank, True
  1893. )
  1894. else:
  1895. # The process with rank 0 is considered the authoritative copy.
  1896. authoritative_rank = 0
# Update self.modules_buffers in case any buffers were
# reassigned.
  1899. self._assign_modules_buffers()
  1900. self._sync_module_buffers(authoritative_rank)
  1901. def _sync_module_buffers(self, authoritative_rank):
  1902. if not hasattr(self, "buffer_hook"):
  1903. self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
  1904. else:
  1905. hook = self.buffer_hook.buffer_comm_hook
  1906. state = self.buffer_hook.buffer_comm_hook_state
  1907. futs = hook(state, self.named_module_buffers)
  1908. if futs is not None:
  1909. self.reducer._install_post_backward_futures(futs)
  1910. def _default_broadcast_coalesced(
  1911. self, bufs=None, bucket_size=None, authoritative_rank=0
  1912. ):
  1913. """
Broadcast buffers from rank 0 to the rest of the workers.
If ``bufs`` or ``bucket_size`` is None, the default values
``self.modules_buffers`` and ``self.broadcast_bucket_size`` are used instead.
  1917. """
  1918. if bufs is None:
  1919. bufs = self.modules_buffers
  1920. if bucket_size is None:
  1921. bucket_size = self.broadcast_bucket_size
  1922. self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
  1923. def _passing_sync_batchnorm_handle(self, module):
  1924. for layer in module.modules():
  1925. if isinstance(layer, torch.nn.modules.SyncBatchNorm):
  1926. if self.device_type == "cpu":
  1927. self._log_and_throw(
  1928. ValueError,
  1929. "SyncBatchNorm layers only work with GPU modules",
  1930. )
  1931. def _check_comm_hook(self, hook):
  1932. if not callable(hook):
  1933. self._log_and_throw(TypeError, "Communication hook must be callable.")
  1934. sig = inspect.signature(hook)
  1935. if (
  1936. sig.parameters["bucket"].annotation != inspect._empty
  1937. and sig.parameters["bucket"].annotation != dist.GradBucket
  1938. ):
  1939. self._log_and_throw(
  1940. ValueError,
  1941. "Communication hook: bucket annotation should be dist.GradBucket.",
  1942. )
  1943. if (
  1944. sig.return_annotation != inspect._empty
  1945. and sig.return_annotation != torch.futures.Future[torch.Tensor]
  1946. ):
  1947. self._log_and_throw(
  1948. ValueError,
  1949. "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
  1950. )
  1951. if hook.__name__ in [
  1952. "bf16_compress_hook",
  1953. "bf16_compress_wrapper_hook",
  1954. ] and (
  1955. (torch.version.cuda is None and torch.version.hip is None)
  1956. or (
  1957. torch.version.cuda is not None
  1958. and int(torch.version.cuda.split(".")[0]) < 11
  1959. )
  1960. or not dist.is_available()
  1961. or not dist.is_nccl_available()
  1962. or torch.cuda.nccl.version() < (2, 10)
  1963. ):
  1964. self._log_and_throw(
  1965. TypeError,
  1966. "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.",
  1967. )
  1968. @property
  1969. def _distributed_rank(self):
  1970. return dist.get_rank(self.process_group)
  1971. @staticmethod
  1972. def _get_data_parallel_params(module, named_params=False):
  1973. """Return a generator of parameters managed by a given DDP unit."""
  1974. for param in (
  1975. module.parameters() if not named_params else module.named_parameters()
  1976. ):
  1977. if not hasattr(param, "_ddp_ignored"):
  1978. yield param
  1979. @staticmethod
  1980. def _set_params_and_buffers_to_ignore_for_model(
  1981. module, params_and_buffers_to_ignore
  1982. ):
  1983. """
  1984. Set parameters and buffers to be ignored by DDP.
  1985. Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and
  1986. similarly, {module_name}.{buffer_name} for buffers. For example:
  1987. params_to_ignore = []
  1988. # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
  1989. for module_name, module in model.named_modules():
  1990. for param_name, param in module.named_parameters(recurse=False):
  1991. if should_ignore(param):
  1992. # Create expected format
  1993. fqn = f"{module_name}.{param_name}"
  1994. params_to_ignore.append(fqn)
  1995. torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
  1996. model,
  1997. params_to_ignore
  1998. )
  1999. """
  2000. # This is a workaround to set parameters and buffers DDP should ignore
  2001. # during synchronization. It will be removed when the API is finalized
  2002. # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
  2003. module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
  2004. for name, param in module.named_parameters():
  2005. if name in params_and_buffers_to_ignore:
  2006. param._ddp_ignored = True
  2007. for name, buffer in module.named_buffers():
  2008. if name in params_and_buffers_to_ignore:
  2009. buffer._ddp_ignored = True
  2010. def _get_ddp_logging_data(self):
  2011. r"""
  2012. Return a dictionary of logging data for debugging and analysis.
  2013. This interface can be called after DistributedDataParallel() is
constructed. It returns a dictionary of logging data that can help
with debugging and analysis. The logging data includes DistributedDataParallel
constructor input parameters, some internal states of DistributedDataParallel,
and performance metrics. Simply print the dictionary to see what
these metrics are.
  2019. This is a prototype interface and subject to change in the future.
  2020. """
  2021. assert self.logger is not None
  2022. ddp_logging_data = self.logger._get_ddp_logging_data()
  2023. return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
  2024. def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
  2025. r"""
  2026. Set sample_rate of collecting runtime stats.
This interface allows users to set the sample rate for collecting
runtime stats. Runtime stats are recorded for the first 10 iterations;
after that, they are recorded once every ``sample_rate`` training
iterations. By default, after the first 10 iterations, runtime stats
are recorded once every ``kDDPRuntimeLoggingSampleRate=100`` training
iterations.
  2034. This is a prototype interface and subject to change in the future.
  2035. """
  2036. if sample_rate < 1:
  2037. self._log_and_throw(
  2038. ValueError,
  2039. "DDP runtime logging sample rate should be equal or greater than 1",
  2040. )
  2041. self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
  2042. def _set_static_graph(self):
  2043. """
  2044. Set static graph for DDP.
  2045. It is recommended to set static graph in the DDP constructor, which will
  2046. call this private API internally.
  2047. """
  2048. # If self.static_graph has been set, no need to set it again
  2049. if self.static_graph:
  2050. warnings.warn(
  2051. "You've set static_graph to be True, no need to set it again."
  2052. )
  2053. return
  2054. self.static_graph = True
  2055. self._static_graph_delay_allreduce_enqueued = False
  2056. self.reducer._set_static_graph()
  2057. assert self.logger is not None
  2058. self.logger._set_static_graph()
  2059. if self.find_unused_parameters:
  2060. warnings.warn(
  2061. "You passed find_unused_parameters=true to DistributedDataParallel, "
  2062. "`_set_static_graph` will detect unused parameters automatically, so "
  2063. "you do not need to set find_unused_parameters=true, just be sure these "
  2064. "unused parameters will not change during training loop while calling "
  2065. "`_set_static_graph`."
  2066. )
  2067. def _remove_autograd_hooks(self):
  2068. """Remove autograd hooks registered by the reducer on the model parameters."""
  2069. self.reducer._remove_autograd_hooks()
  2070. def _check_reducer_finalized(self):
  2071. """
  2072. Check if the reducer has processed all buckets and finalized the backward appropriately.
  2073. It is useful to call this method after calling .backward() in your training loop
in order to avoid subsequent hard-to-debug errors down the road due to the
  2075. reducer not finalizing backward.
  2076. """
  2077. self.reducer._check_reducer_finalized()
  2078. def _set_sparse_metadata(self, global_unique_ids):
  2079. self.reducer._set_sparse_metadata(global_unique_ids)
  2080. def _update_process_group(self, new_process_group):
  2081. """
  2082. Dynamically updates the process group for DDP so that we can shrink/expand DDP
  2083. world size without having to reinitialize DDP.
NOTE: If you are using custom communication hooks via ``register_comm_hook``,
  2085. you need to update the process groups for those hooks separately.
  2086. """
  2087. # Force a rebuild of buckets for a new process group. This ensures all ranks
  2088. # are synchronized in terms of when they will rebuild buckets and also
  2089. # re-evaluates previous assumptions of buckets given the world size might have
  2090. # changed.
  2091. self._has_rebuilt_buckets = False
  2092. self.reducer._reset_state()
  2093. if not _rank_not_in_group(new_process_group):
  2094. self.process_group = new_process_group
  2095. self.reducer._update_process_group(new_process_group)
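# Example (illustrative sketch): shrinking the DDP world to a subgroup of
# surviving ranks. ``surviving_ranks`` is assumed to be known to (and the
# ``dist.new_group`` call made by) every rank, since group creation is itself
# a collective:
#
#     new_pg = dist.new_group(ranks=surviving_ranks)
#     ddp._update_process_group(new_pg)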
  2096. def _set_ddp_sink_clone(self, val: bool):
  2097. """
Set whether DDPSink should clone the output tensors.
The default is True, since if the loss is modified in place we would
otherwise run into the "view is modified in-place" error.
However, cloning the tensors can add a significant memory and
performance hit if the number and size of the tensors are large. As
a result, this can be set to False if you are not modifying the
loss in place.
  2105. """
  2106. self._ddp_sink_clone = val
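# Example (illustrative usage): if the training loop never mutates the loss
# (or any other DDP output) in place, the clone in DDPSink can be skipped to
# save memory and copy overhead:
#
#     ddp._set_ddp_sink_clone(False)
#     loss = ddp(batch).sum()
#     loss.backward()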