- # mypy: allow-untyped-defs
- import copy
- import functools
- import inspect
- import itertools
- import logging
- import os
- import sys
- import warnings
- import weakref
- from collections import defaultdict, deque
- from contextlib import contextmanager
- from dataclasses import dataclass, fields, is_dataclass
- from enum import auto, Enum
- from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
- import torch
- import torch.distributed as dist
- from torch.autograd import Function, Variable
- from torch.distributed.algorithms.join import Join, Joinable, JoinHook
- from torch.utils._pytree import tree_flatten, tree_unflatten
- RPC_AVAILABLE = False
- if dist.is_available():
- from torch.distributed.distributed_c10d import (
- _get_default_group,
- _rank_not_in_group,
- ReduceOp,
- )
- from torch.distributed.utils import (
- _alloc_storage,
- _cast_forward_inputs,
- _free_storage,
- _sync_module_states,
- _to_kwargs,
- _verify_param_shape_across_processes,
- )
- if torch.distributed.rpc.is_available():
- RPC_AVAILABLE = True
- from torch.distributed.rpc import RRef
- from torch._utils import _get_device_index
- from ..modules import Module
- from .scatter_gather import gather, scatter_kwargs # noqa: F401
- if TYPE_CHECKING:
- from torch.utils.hooks import RemovableHandle
- __all__ = ["DistributedDataParallel"]
- logger = logging.getLogger(__name__)
- @dataclass
- class _MixedPrecision:
- """
- This configures DDP-native mixed precision training.
- Attributes:
- param_dtype (torch.dtype): This specifies the dtype for model
- parameters, inputs (when ``cast_forward_inputs`` is set to
- ``True``), and therefore the dtype for computation.
- However, outside the forward and backward passes, parameters are in
- full precision. Model checkpointing always happens in full
- precision.
- reduce_dtype (torch.dtype): This specifies the dtype for gradient
- reduction, which is permitted to differ from ``param_dtype``.
- buffer_dtype (torch.dtype): This specifies the dtype for buffers.
- .. note:: This API is experimental and subject to change.
- .. note:: Only floating point tensors are cast to their specified dtypes.
- .. note:: ``state_dict`` checkpoints parameters and buffers in full
- precision.
- .. note:: Each low precision dtype must be specified explicitly. For
- example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
- the reduction dtype to be low precision, and DDP will not cast
- parameters or buffers.
- .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
- happens in ``param_dtype`` if specified or the original parameter dtype
- otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
- would result in communication occurring in fp16.
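- .. note:: A minimal usage sketch is shown below; ``model`` and ``rank`` are
- assumed to be defined by the caller, and the config is passed through the
- experimental ``mixed_precision`` argument of ``DistributedDataParallel``.
- Example::
- >>> # xdoctest: +SKIP("undefined variables")
- >>> mp_config = _MixedPrecision(param_dtype=torch.float16)
- >>> ddp = torch.nn.parallel.DistributedDataParallel(
- >>> model, device_ids=[rank], mixed_precision=mp_config
- >>> )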
- """
- param_dtype: Optional[torch.dtype] = None
- reduce_dtype: Optional[torch.dtype] = None
- buffer_dtype: Optional[torch.dtype] = None
- # TODO (rohan-varma): keep_low_precision_grads: bool = False
- # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
- # in full precision. For DDP, this can be implemented by not performing the
- # parameter cast for BN and LN units.
- def _cast_buffers(mixed_precision_config, root_module):
- """Casts buffers to the given ``buffer_dtype``."""
- for buf in root_module.buffers():
- if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
- continue
- buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
- def _setup_mixed_precision_params(mixed_precision_config, root_module):
- """Create and free storage for the mixed precision parameters."""
- for param in root_module.parameters():
- # Do not setup mixed precision for DDP ignored parameters.
- if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
- continue
- if not hasattr(param, "_mp_param"):
- param._mp_param = torch.zeros_like(
- param,
- device=param.device,
- dtype=mixed_precision_config.param_dtype,
- requires_grad=param.requires_grad,
- )
- _free_storage(param._mp_param)
- # _fp_param will point to the full precision param so it can be switched
- # back to at the end of forward / backward.
- param._fp_param = param.data
- def _tree_flatten_with_rref(output):
- output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
- if output_is_rref:
- output_tensor_list, treespec = tree_flatten(output.local_value())
- else:
- output_tensor_list, treespec = tree_flatten(output)
- # Return the flattened tensors, the spec needed to re-pack them, and
- # whether the return type was actually an RRef, so that it can be reconstructed.
- return output_tensor_list, treespec, output_is_rref
- def _tree_unflatten_with_rref(output, treespec, output_is_rref):
- output = tree_unflatten(output, treespec)
- if output_is_rref:
- output = RRef(output)
- return output
- def _find_tensors(obj):
- r"""Recursively find all tensors contained in the specified object."""
- if RPC_AVAILABLE and isinstance(obj, RRef):
- # If the current node is the owner of the RRef, unwrap it and try to
- # find Tensors.
- # TODO: Expand to remote RRefs.
- if obj.is_owner():
- return _find_tensors(obj.local_value())
- if isinstance(obj, torch.Tensor):
- return [obj]
- if isinstance(obj, (list, tuple)):
- return itertools.chain.from_iterable(map(_find_tensors, obj))
- if isinstance(obj, dict):
- return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
- if is_dataclass(obj):
- return itertools.chain.from_iterable(
- map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
- )
- return []
- def _dump_DDP_relevant_env_vars():
- relevant_env_vars = [
- "RANK",
- "LOCAL_RANK",
- "WORLD_SIZE",
- "MASTER_PORT",
- "MASTER_ADDR",
- "CUDA_VISIBLE_DEVICES",
- "GLOO_SOCKET_IFNAME",
- "GLOO_DEVICE_TRANSPORT",
- "NCCL_SOCKET_IFNAME",
- "TORCH_NCCL_BLOCKING_WAIT",
- "NCCL_DEBUG",
- "NCCL_DEBUG_SUBSYS",
- "NCCL_IB_DISABLE",
- # More NCCL env vars:
- "NCCL_P2P_DISABLE",
- "NCCL_P2P_LEVEL",
- "NCCL_SHM_DISABLE",
- "NCCL_SOCKET_NTHREADS",
- "NCCL_NSOCKS_PERTHREAD",
- "NCCL_BUFFSIZE",
- "NCCL_NTHREADS",
- "NCCL_RINGS",
- "NCCL_MAX_NCHANNELS",
- "NCCL_MIN_NCHANNELS",
- "NCCL_CHECKS_DISABLE",
- "NCCL_CHECK_POINTERS",
- "NCCL_LAUNCH_MODE",
- "NCCL_IB_HCA",
- "NCCL_IB_TIMEOUT",
- "NCCL_IB_RETRY_CNT",
- "NCCL_IB_GID_INDEX",
- "NCCL_IB_SL",
- "NCCL_IB_TC",
- "NCCL_IB_AR_THRESHOLD",
- "NCCL_IB_CUDA_SUPPORT",
- "NCCL_NET_GDR_LEVEL",
- "NCCL_NET_GDR_READ",
- "NCCL_SINGLE_RING_THRESHOLD",
- "NCCL_LL_THRESHOLD",
- "NCCL_TREE_THRESHOLD",
- "NCCL_ALGO",
- "NCCL_PROTO",
- "NCCL_IGNORE_CPU_AFFINITY",
- "NCCL_DEBUG_FILE",
- "NCCL_COLLNET_ENABLE",
- "NCCL_TOPO_FILE",
- "NCCL_TOPO_DUMP_FILE",
- "TORCH_NCCL_ASYNC_ERROR_HANDLING",
- ]
- formatted_output = ""
- for var in relevant_env_vars:
- value = os.environ.get(var, "N/A")
- formatted_output += f"env:{var}={value}\n"
- print(formatted_output)
- class _BufferCommHookLocation(Enum):
- PRE_FORWARD = auto()
- POST_FORWARD = auto()
- @dataclass
- class _BufferCommHook:
- buffer_comm_hook: Callable
- buffer_comm_hook_state: Any
- buffer_comm_hook_location: _BufferCommHookLocation
- # Add a _DDPSink to run various functions when the backward pass starts,
- # such as queueing the callback of the outermost backward/graph task;
- # this ensures that the callback is fired after the computation of all
- # gradients is completed.
- class _DDPSink(Function):
- @staticmethod
- def forward(ctx, ddp_weakref, *inputs):
- # set_materialize_grads(False) will ensure that None gradients stay as
- # None and are not filled with zeros.
- ctx.set_materialize_grads(False)
- ctx.ddp_weakref = ddp_weakref
- ret = inputs
- if ddp_weakref()._ddp_sink_clone:
- ret = tuple(
- inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
- )
- return ret
- @staticmethod
- def backward(ctx, *grad_outputs):
- # Enqueue delay allreduce for static graph training on the first
- # iteration.
- ddp_weakref = ctx.ddp_weakref()
- reducer = ddp_weakref.reducer
- static_graph = ddp_weakref.static_graph
- delay_ar_enqueued = (
- static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
- )
- if static_graph and not delay_ar_enqueued:
- Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
- reducer._delay_all_reduce
- )
- ddp_weakref._static_graph_delay_allreduce_enqueued = True
- return (None, *grad_outputs)
- class _DDPJoinHook(JoinHook):
- def __init__(self, ddp, divide_by_initial_world_size):
- """Set config variables for internal usage."""
- assert isinstance(ddp, DistributedDataParallel), (
- "DDP join hook requires passing in a DistributedDataParallel "
- "instance as the state"
- )
- assert ddp.logger is not None
- ddp.logger._set_uneven_input_join()
- self.ddp = ddp
- self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
- super().__init__()
- def main_hook(self):
- """Shadow the DDP collective communication operations in the forward and backward passes."""
- ddp = self.ddp
- # Buckets are rebuilt only once during a training period
- ddp.reducer._rebuild_buckets()
- # Schedule a broadcast if we are syncing module buffers in the
- # forward pass
- # TODO: make DDP uneven inputs context manager support buffer
- # comm hook (https://github.com/pytorch/pytorch/issues/65436)
- ddp._check_and_sync_module_buffers()
- # Check if we need to sync in the backward pass
- should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
- is_joined_rank=True
- )
- # Forward parameter sync is disabled in the next iteration if we
- # are skipping gradient sync this iteration, so set
- # `require_forward_param_sync` accordingly
- ddp.require_forward_param_sync = should_sync_backwards
- if not should_sync_backwards:
- return
- # Schedule one allreduce per gradient bucket to match the backward
- # pass allreduce
- ddp._match_all_reduce_for_bwd_pass()
- # Check if we need to allreduce locally unused parameters
- if ddp.find_unused_parameters:
- ddp._match_unused_params_allreduce()
- # Rebuilt parameters are pushed only once during a training period
- ddp.reducer._push_all_rebuilt_params()
- def post_hook(self, is_last_joiner: bool):
- """Sync the final model to ensure that the model is the same across all processes."""
- self.ddp._sync_final_model(is_last_joiner)
- class DistributedDataParallel(Module, Joinable):
- r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.
- This container provides data parallelism by synchronizing gradients
- across each model replica. The devices to synchronize across are
- specified by the input ``process_group``, which is the entire world
- by default. Note that ``DistributedDataParallel`` does not chunk or
- otherwise shard the input across participating GPUs; the user is
- responsible for defining how to do so, for example through the use
- of a :class:`DistributedSampler`.
- See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
- The same constraints on input as in :class:`torch.nn.DataParallel` apply.
- Creation of this class requires that ``torch.distributed`` be already
- initialized, by calling :func:`torch.distributed.init_process_group`.
- ``DistributedDataParallel`` is proven to be significantly faster than
- :class:`torch.nn.DataParallel` for single-node multi-GPU data
- parallel training.
- To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
- up ``N`` processes, ensuring that each process exclusively works on a single
- GPU from 0 to N-1. This can be done by either setting
- ``CUDA_VISIBLE_DEVICES`` for every process or by calling:
- >>> # xdoctest: +SKIP("undefined variables")
- >>> torch.cuda.set_device(i)
- where ``i`` is from 0 to N-1. In each process, you should refer to the
- following to construct this module:
- >>> # xdoctest: +SKIP("undefined variables")
- >>> torch.distributed.init_process_group(
- >>> backend='nccl', world_size=N, init_method='...'
- >>> )
- >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
- In order to spawn up multiple processes per node, you can use either
- ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
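- For example, a minimal ``torch.multiprocessing.spawn`` launcher might look
- like the sketch below, where ``demo_fn`` is a hypothetical per-process entry
- point that performs the setup above and receives the process rank as its
- first argument:
- >>> # xdoctest: +SKIP("undefined variables")
- >>> import torch.multiprocessing as mp
- >>> mp.spawn(demo_fn, args=(N,), nprocs=N, join=True)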
- .. note::
- Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
- for a brief introduction to all features related to distributed training.
- .. note::
- ``DistributedDataParallel`` can be used in conjunction with
- :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
- per-rank optimizer states memory footprint. Please refer to
- `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
- for more details.
- .. note:: The ``nccl`` backend is currently the fastest and the recommended
- backend for use with GPUs. This applies to both single-node and
- multi-node distributed training.
- .. note:: This module also supports mixed-precision distributed training.
- This means that your model can have parameters of different types, such
- as a mix of ``fp16`` and ``fp32``; gradient reduction on these mixed
- types of parameters will work fine.
- .. note:: If you use ``torch.save`` on one process to checkpoint the module,
- and ``torch.load`` on some other processes to recover it, make sure that
- ``map_location`` is configured properly for every process. Without
- ``map_location``, ``torch.load`` would recover the module to the devices
- where the module was saved from.
- .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
- gradient will be ``M`` times smaller when compared to the same model
- trained on a single node with ``batch=M*N`` if the loss is summed (NOT
- averaged as usual) across instances in a batch (because the gradients
- between different nodes are averaged). You should take this into
- consideration when you want to obtain a mathematically equivalent
- training process compared to the local training counterpart. But in most
- cases, you can just treat a DistributedDataParallel wrapped model, a
- DataParallel wrapped model and an ordinary model on a single GPU as the
- same (E.g. using the same learning rate for equivalent batch size).
- .. note::
- Parameters are never broadcast between processes. The module performs
- an all-reduce step on gradients and assumes that they will be modified
- by the optimizer in all processes in the same way. Buffers
- (e.g. BatchNorm stats) are broadcast from the module in process of rank
- 0, to all other replicas in the system in every iteration.
- .. note::
- If you are using DistributedDataParallel in conjunction with the
- :ref:`distributed-rpc-framework`, you should always use
- :meth:`torch.distributed.autograd.backward` to compute gradients and
- :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
- parameters.
- Example::
- >>> # xdoctest: +SKIP("undefined variables")
- >>> import torch.distributed.autograd as dist_autograd
- >>> from torch.nn.parallel import DistributedDataParallel as DDP
- >>> import torch
- >>> from torch import optim
- >>> from torch.distributed.optim import DistributedOptimizer
- >>> import torch.distributed.rpc as rpc
- >>> from torch.distributed.rpc import RRef
- >>>
- >>> t1 = torch.rand((3, 3), requires_grad=True)
- >>> t2 = torch.rand((3, 3), requires_grad=True)
- >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
- >>> ddp_model = DDP(my_model)
- >>>
- >>> # Setup optimizer
- >>> optimizer_params = [rref]
- >>> for param in ddp_model.parameters():
- >>> optimizer_params.append(RRef(param))
- >>>
- >>> dist_optim = DistributedOptimizer(
- >>> optim.SGD,
- >>> optimizer_params,
- >>> lr=0.05,
- >>> )
- >>>
- >>> with dist_autograd.context() as context_id:
- >>> pred = ddp_model(rref.to_here())
- >>> loss = loss_func(pred, target)
- >>> dist_autograd.backward(context_id, [loss])
- >>> dist_optim.step(context_id)
- .. note::
- DistributedDataParallel currently offers limited support for gradient
- checkpointing with :meth:`torch.utils.checkpoint`.
- If the checkpoint is done with use_reentrant=False (recommended), DDP
- will work as expected without any limitations.
- If, however, the checkpoint is done with use_reentrant=True (the default),
- DDP will work as expected when there are no unused parameters in the model
- and each layer is checkpointed at most once (make sure you are not passing
- `find_unused_parameters=True` to DDP). We currently do not support the
- case where a layer is checkpointed multiple times, or when there are
- unused parameters in the checkpointed model.
- .. note::
- To let a non-DDP model load a state dict from a DDP model,
- :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
- needs to be applied to strip the prefix "module." in the DDP state dict before loading.
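- For example, a minimal sketch (the checkpoint path and ``non_ddp_model``
- are assumed to be defined by the caller):
- >>> # xdoctest: +SKIP("undefined variables")
- >>> from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
- >>> state_dict = torch.load("ddp_checkpoint.pt")
- >>> consume_prefix_in_state_dict_if_present(state_dict, "module.")
- >>> non_ddp_model.load_state_dict(state_dict)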
- .. warning::
- Constructor, forward method, and differentiation of the output (or a
- function of the output of this module) are distributed synchronization
- points. Take that into account in case different processes might be
- executing different code.
- .. warning::
- This module assumes all parameters are registered in the model by the
- time it is created. No parameters should be added nor removed later.
- Same applies to buffers.
- .. warning::
- This module assumes that all parameters are registered in the model of
- each distributed process in the same order. The module itself will
- conduct gradient ``allreduce`` following the reverse order of the
- registered parameters of the model. In other words, it is the user's
- responsibility to ensure that each distributed process has the exact
- same model and thus the exact same parameter registration order.
- .. warning::
- This module allows parameters with non-rowmajor-contiguous strides.
- For example, your model may contain some parameters whose
- :class:`torch.memory_format` is ``torch.contiguous_format``
- and others whose format is ``torch.channels_last``. However,
- corresponding parameters in different processes must have the
- same strides.
- .. warning::
- This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
- only work if gradients are to be accumulated in ``.grad`` attributes of
- parameters).
- .. warning::
- If you plan on using this module with a ``nccl`` backend or a ``gloo``
- backend (that uses Infiniband), together with a DataLoader that uses
- multiple workers, please change the multiprocessing start method to
- ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
- Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
- likely experience deadlocks if you don't change this setting.
- .. warning::
- You should never try to change your model's parameters after wrapping
- up your model with ``DistributedDataParallel``. At construction time,
- the ``DistributedDataParallel`` constructor registers additional
- gradient reduction functions on all of the model's parameters. If you
- change the model's parameters afterwards, the gradient reduction
- functions will no longer match the correct set of parameters.
- .. warning::
- Using ``DistributedDataParallel`` in conjunction with the
- :ref:`distributed-rpc-framework` is experimental and subject to change.
- Args:
- module (Module): module to be parallelized
- device_ids (list of int or torch.device): CUDA devices.
- 1) For single-device modules, ``device_ids`` can
- contain exactly one device id, which represents the only
- CUDA device where the input module corresponding to this process resides.
- Alternatively, ``device_ids`` can also be ``None``.
- 2) For multi-device modules and CPU modules,
- ``device_ids`` must be ``None``.
- When ``device_ids`` is ``None`` for both cases,
- both the input data for the forward pass and the actual module
- must be placed on the correct device.
- (default: ``None``)
- output_device (int or torch.device): Device location of output for
- single-device CUDA modules. For multi-device modules and
- CPU modules, it must be ``None``, and the module itself
- dictates the output location. (default: ``device_ids[0]``
- for single-device modules)
- broadcast_buffers (bool): Flag that enables syncing (broadcasting)
- buffers of the module at the beginning of the ``forward``
- function. (default: ``True``)
- process_group: The process group to be used for distributed data
- all-reduction. If ``None``, the default process group, which
- is created by :func:`torch.distributed.init_process_group`,
- will be used. (default: ``None``)
- bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
- multiple buckets so that gradient reduction of each
- bucket can potentially overlap with backward computation.
- :attr:`bucket_cap_mb` controls the bucket size in
- MebiBytes (MiB). If ``None``, a default size of 25 MiB
- will be used. (default: ``None``)
- find_unused_parameters (bool): Traverse the autograd graph from all
- tensors contained in the return value of the
- wrapped module's ``forward`` function. Parameters
- that don't receive gradients as part of this
- graph are preemptively marked as being ready to
- be reduced. In addition, parameters that may have
- been used in the wrapped module's ``forward``
- function but were not part of loss computation and
- thus would also not receive gradients are
- preemptively marked as ready to be reduced.
- (default: ``False``)
- check_reduction: This argument is deprecated.
- gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
- pointing to different offsets of ``allreduce`` communication
- buckets. This can reduce peak memory usage, where the
- saved memory size will be equal to the total gradients
- size. Moreover, it avoids the overhead of copying between
- gradients and ``allreduce`` communication buckets. When
- gradients are views, ``detach_()`` cannot be called on the
- gradients. If you hit such errors, please fix them by
- referring to the :meth:`~torch.optim.Optimizer.zero_grad`
- function in ``torch/optim/optimizer.py`` as a solution.
- Note that gradients will be views after first iteration, so
- the peak memory saving should be checked after first iteration.
- static_graph (bool): When set to ``True``, DDP knows the trained graph is
- static. Static graph means 1) The set of used and unused
- parameters will not change during the whole training loop; in
- this case, it does not matter whether users set
- ``find_unused_parameters = True`` or not. 2) How the graph is trained
- will not change during the whole training loop (meaning there is
- no control flow depending on iterations).
- When static_graph is set to ``True``, DDP will support cases that
- could not be supported in the past:
- 1) Reentrant backwards.
- 2) Activation checkpointing multiple times.
- 3) Activation checkpointing when the model has unused parameters.
- 4) There are model parameters that are outside of the forward function.
- 5) Potentially improved performance when there are unused parameters,
- as DDP will not search the graph in each iteration to detect unused
- parameters when static_graph is set to ``True``.
- To check whether you can set static_graph to ``True``, one way is to
- check the DDP logging data at the end of your previous model training;
- if ``ddp_logging_data.get("can_set_static_graph") == True``, you can
- mostly set ``static_graph = True`` as well.
- Example::
- >>> # xdoctest: +SKIP("undefined variables")
- >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
- >>> # Training loop
- >>> ...
- >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
- >>> static_graph = ddp_logging_data.get("can_set_static_graph")
- delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
- of named parameters whose all reduce will be delayed when the gradient of
- the parameter specified in ``param_to_hook_all_reduce`` is ready. Other
- arguments of DDP do not apply to named params specified in this argument
- as these named params will be ignored by the DDP reducer.
- param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
- of parameters specified in ``delay_all_reduce_named_params``.
- Attributes:
- module (Module): the module to be parallelized.
- Example::
- >>> # xdoctest: +SKIP("undefined variables")
- >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
- >>> net = torch.nn.parallel.DistributedDataParallel(model)
- """
- # used to track whether the given thread is inside ddp forward for torchdynamo purposes
- _active_ddp_module: Optional["DistributedDataParallel"] = None
- def __init__(
- self,
- module,
- device_ids=None,
- output_device=None,
- dim=0,
- broadcast_buffers=True,
- process_group=None,
- bucket_cap_mb=None,
- find_unused_parameters=False,
- check_reduction=False,
- gradient_as_bucket_view=False,
- static_graph=False,
- delay_all_reduce_named_params=None,
- param_to_hook_all_reduce=None,
- mixed_precision: Optional[_MixedPrecision] = None,
- device_mesh=None,
- ):
- super().__init__()
- Joinable.__init__(self)
- self.logger = None
- if bool(delay_all_reduce_named_params is not None) != bool(
- param_to_hook_all_reduce is not None
- ):
- self._log_and_throw(
- ValueError,
- "delay_all_reduce_named_params and param_to_hook_all_reduce "
- "need to be set at the same time.",
- )
- if process_group and device_mesh is not None:
- raise RuntimeError(
- "Cannot specify both process_group and device_mesh arguments."
- )
- elif process_group is None and device_mesh is None:
- self.process_group = _get_default_group()
- elif device_mesh is None:
- self.process_group = process_group
- else:
- if device_mesh.ndim != 1:
- raise RuntimeError(
- f"Only 1D device mesh is supported, but got {device_mesh}."
- )
- self.device_mesh = device_mesh
- self.process_group = device_mesh.get_group(mesh_dim=0)
- from torch.distributed.device_mesh import _mesh_resources
- if _mesh_resources.get_parent_mesh(device_mesh) is not None:
- # TODO: This is a temporary work around to enable DDP + TP.
- # We should do the logic in DDP so that the 2D implementation is
- # sound and the state_dict works out of the box.
- # This has to be done before check UninitializedParameter.
- from torch.distributed.tensor.parallel.ddp import (
- _pre_dp_module_transform,
- )
- _pre_dp_module_transform(module)
- self._delay_all_reduce_params = []
- if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
- self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
- else:
- self.parameters_to_ignore = set()
- if delay_all_reduce_named_params is not None:
- for name, param in delay_all_reduce_named_params:
- self.parameters_to_ignore.add(name)
- self._delay_all_reduce_params.append(param)
- self._module_parameters = [
- p
- for n, p in module.named_parameters()
- if n not in self.parameters_to_ignore
- ]
- if not any(p.requires_grad for p in self._module_parameters):
- if len(self._delay_all_reduce_params):
- logger.info("Delay the AllReduce of all parameters.")
- else:
- self._log_and_throw(
- RuntimeError,
- "DistributedDataParallel is not needed when a module "
- "doesn't have any parameter that requires a gradient.",
- )
- if device_ids is not None and len(device_ids) > 1:
- self._log_and_throw(
- ValueError,
- "device_ids can only be None or contain a single element.",
- )
- self.is_multi_device_module = (
- len({p.device for p in self._module_parameters}) > 1
- )
- distinct_device_types = {
- p.device.type for p in self._module_parameters if p.device is not None
- }
- if len(distinct_device_types) != 1:
- self._log_and_throw(
- ValueError,
- "DistributedDataParallel's input module must be on "
- f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
- )
- self.device_type = next(iter(distinct_device_types))
- if (
- device_ids is None
- or len(device_ids) == 0 # For backward compatibility.
- or self.device_type == "cpu"
- or self.is_multi_device_module
- ):
- if device_ids or output_device:
- self._log_and_throw(
- ValueError,
- "DistributedDataParallel device_ids and output_device arguments "
- "only work with single-device/multiple-device GPU modules or CPU modules, "
- f"but got device_ids {device_ids}, output_device {output_device}, "
- f"and module parameters {({p.device for p in self._module_parameters})}.",
- )
- self.device_ids = None
- self.output_device = None
- else:
- self.device_ids = [_get_device_index(x, True) for x in device_ids]
- if output_device is None:
- output_device = device_ids[0]
- self.output_device = _get_device_index(output_device, True)
- self.static_graph = False
- self.dim = dim
- self.module = module
- self.device = next(iter(self._module_parameters)).device
- self.broadcast_buffers = broadcast_buffers
- self.find_unused_parameters = find_unused_parameters
- self.require_backward_grad_sync = True
- self.require_forward_param_sync = True
- self.gradient_as_bucket_view = gradient_as_bucket_view
- self.mixed_precision = mixed_precision
- if self.mixed_precision is not None:
- logger.warning("Received mixed precision config %s", self.mixed_precision)
- if check_reduction:
- # This argument is no longer used since the reducer
- # will ensure reduction completes even if some parameters
- # do not receive gradients.
- warnings.warn(
- "The `check_reduction` argument in `DistributedDataParallel` "
- "module is deprecated. Please avoid using it.",
- FutureWarning,
- stacklevel=2,
- )
- # Check that a module does not have Uninitialized parameters
- for param in self._module_parameters:
- if isinstance(param, torch.nn.parameter.UninitializedParameter):
- self._log_and_throw(
- RuntimeError,
- "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
- "Run a dummy forward pass to correctly initialize the modules",
- )
- # used for intra-node param sync and inter-node sync as well
- self.broadcast_bucket_size = int(250 * 1024 * 1024)
- # reduction bucket size
- if bucket_cap_mb is None:
- # default case (bucket cap is 25 MiB)
- bucket_cap_mb = 25
- self.bucket_bytes_cap_default = True
- else:
- self.bucket_bytes_cap_default = False
- self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
- # Whether to perform input tensor CPU to GPU copies on a side-stream
- self.use_side_stream_for_tensor_copies = (
- os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
- )
- # Initialize gradient buffers and register all reduce hook
- self._delay_grad_buffer = None
- self._delay_grad_views: List[torch.Tensor] = []
- self._delay_all_reduce_all_params = False
- if len(self._delay_all_reduce_params) != 0:
- self._register_delay_all_reduce_hook(
- bucket_cap_mb=bucket_cap_mb,
- param_to_hook_all_reduce=param_to_hook_all_reduce,
- device_ids=device_ids,
- )
- if self._delay_all_reduce_all_params:
- return
- # Build parameters for reducer.
- parameters, expect_sparse_gradient = self._build_params_for_reducer()
- # Verify model equivalence.
- _verify_param_shape_across_processes(self.process_group, parameters)
- # Sync params and buffers. Ensures all DDP models start off at the same value.
- _sync_module_states(
- module=self.module,
- process_group=self.process_group,
- broadcast_bucket_size=self.broadcast_bucket_size,
- src=0,
- params_and_buffers_to_ignore=self.parameters_to_ignore,
- broadcast_buffers=self.broadcast_buffers,
- )
- # In debug mode, build a mapping of parameter index -> parameter.
- param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
- # Builds reducer.
- self._ddp_init_helper(
- parameters,
- expect_sparse_gradient,
- param_to_name_mapping,
- static_graph,
- )
- self._comm_hooks: List[Tuple[Callable, object]] = []
- if self.mixed_precision is not None:
- _setup_mixed_precision_params(self.mixed_precision, self.module)
- _cast_buffers(self.mixed_precision, self.module)
- # Stream used for async low precision copies.
- self._mp_stream = torch.cuda.Stream()
- self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated]
- # Add forward pre-hook to root module to kick off copies to lower
- # precision.
- self.module.register_forward_pre_hook(
- self._root_copy_hook, prepend=False, with_kwargs=True
- )
- # Add forward pre hook to all submodules to wait for copy events
- # before running computation.
- for module in self.module.modules():
- module.register_forward_pre_hook(
- self._module_wait_for_copy_hook,
- prepend=False,
- with_kwargs=True,
- )
- # Set up callbacks in backward to upcast and use full precision
- # params. TODO (rohan-varma): Make this compose with general
- # comm hooks and apply_optimizer_in_backward. Importing inline to
- # avoid circular import issue.
- from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
- _AllreduceUpcastHookState,
- _reducer_allreduce_and_upcast_hook,
- )
- upcast_hook_state = _AllreduceUpcastHookState(
- ddp_weakref=weakref.ref(self),
- upcast_stream=torch.cuda.Stream(),
- )
- self.register_comm_hook(
- upcast_hook_state,
- _reducer_allreduce_and_upcast_hook,
- )
- # Inform reducer of reduced precision param dtype for correctness
- # of type checks between gradient and bucket.
- self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined]
- self.mixed_precision.param_dtype
- )
- self._has_rebuilt_buckets = False
- if static_graph:
- self._set_static_graph()
- self._lazy_init_ran = False
- # Register the AccumulateGrad post hooks if optimize_ddp is
- # True. The hooks will be deregistered if compiled_autograd is not
- # enabled.
- self._accum_grad_hooks: List[RemovableHandle] = []
- optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
- self._use_python_reducer = optimize_ddp in (
- "python_reducer",
- "python_reducer_without_compiled_forward",
- )
- if self._use_python_reducer:
- torch._inductor.config._fuse_ddp_communication = True
- torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
- # Directly adding this to the trace rule will disturb the users
- # who are using DDPOptimizer.
- torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
- "torch.nn.parallel.distributed"
- )
- torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
- self._force_to_disable_cpp_reducer = (
- optimize_ddp == "python_reducer_without_compiled_forward"
- )
- if self._use_python_reducer:
- self._register_accum_grad_hook()
- # Whether or not DDPSink performs a clone.
- self._ddp_sink_clone = True
- def _register_accum_grad_hook(self):
- import torch.distributed._functional_collectives as fcol
- def compiled_accum_grad_hook(
- param,
- *,
- param_index: int,
- ):
- if not self.require_backward_grad_sync:
- return
- if param.grad is None:
- return
- if self._comm_hooks:
- for hook, state in self._comm_hooks:
- hook(state, (param.grad, param))
- else:
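- # No custom comm hook was registered: mimic the default DDP behavior
- # by pre-dividing the grad by the world size and all-reducing with
- # "sum" via functional collectives, then copying the result back
- # into param.grad.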
- gradient = param.grad / self.process_group.size()
- gradient = fcol.all_reduce(gradient, "sum", self.process_group)
- param.grad.copy_(gradient)
- for index, param in enumerate(self._module_parameters):
- if not param.requires_grad:
- continue
- self._accum_grad_hooks.append(
- param.register_post_accumulate_grad_hook(
- functools.partial(
- compiled_accum_grad_hook,
- param_index=index,
- )
- )
- )
- def _delayed_all_reduce_hook(self, grad):
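- # Fired when the gradient of ``param_to_hook_all_reduce`` is ready:
- # average the flattened delayed-gradient buffer across ranks with an
- # asynchronous all-reduce, and return the incoming grad unchanged.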
- world_size = dist.get_world_size(self.process_group)
- self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr]
- _ = dist.all_reduce(
- self._delay_grad_buffer, group=self.process_group, async_op=True
- )
- return grad
- def _register_delay_all_reduce_hook(
- self,
- bucket_cap_mb,
- param_to_hook_all_reduce,
- device_ids,
- ):
- # 1. Create gradient buffer
- device = torch.device("cpu") if device_ids is None else device_ids[0]
- self._delay_grad_buffer = torch.zeros(
- sum(p.numel() for p in self._delay_all_reduce_params),
- device=device,
- )
- # 2. Broadcast the parameters
- detached_params = [p.detach() for p in self._delay_all_reduce_params]
- dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)
- # 3. Hook all reduce to the specified parameter
- param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)
- # 4. Build tensor views for gradients
- offset = 0
- for param in self._delay_all_reduce_params:
- grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
- param.shape
- )
- self._delay_grad_views.append(grad_view)
- offset = offset + param.numel()
- # 5. Check whether the all reduce of all params requiring grad is delayed.
- for module_name, module in self.module.named_modules():
- for param_name, param in module.named_parameters(recurse=False):
- if param.requires_grad:
- full_name = f"{module_name}.{param_name}"
- if full_name not in self.parameters_to_ignore:
- # There is at least one param whose all reduce will not be delayed.
- # In this case, we should not set self._delay_all_reduce_all_params
- # to True.
- return
- self._delay_all_reduce_all_params = True
- def _setup_in_backward_optimizers(self):
- # Check if user has used apply_optim_in_backward to overlap optimizer
- # step + DDP backward. Current constraints:
- # 1. Only allreduce is supported at the moment, no custom communication.
- # 2. For DDP-managed parameters that have their optimizer run in
- # backward, their gradients are set to ``None``. If your use case
- # requires DDP parameters' grads not to be set to ``None`` after their
- # in-backward optimizer runs, please ping
- # https://github.com/pytorch/pytorch/issues/90052.
- # NOTE: we use self._module_parameters instead of .parameters() since
- # the former excludes ignored (non-DDP managed) parameters.
- if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
- torch._C._log_api_usage_once("ddp.optimizer_in_backward")
- # Remove hooks that apply_optim_in_backward had registered because
- # DDP customizes how optimizer is overlapped with backward due to
- # the allreduce.
- param_to_handle_map = (
- dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
- )
- for p in self._module_parameters:
- for handle in param_to_handle_map.get(p, []):
- handle.remove()
- # Need a weakref to DDP instance to run all_reduce (from reducer)
- # and get managed DDP parameters.
- ddp_weakref = weakref.ref(self)
- # Note: importing in function, otherwise this will cause a circular
- # import.
- from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
- _apply_optim_in_backward_hook,
- )
- self.register_comm_hook(
- ddp_weakref,
- _apply_optim_in_backward_hook(
- gradient_is_bucket_view=self.gradient_as_bucket_view
- ),
- )
- self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined]
- def _fire_reducer_autograd_hook(self, idx, *unused):
- """
- Fire the reducer's autograd hook to allreduce params in a Reducer bucket.
- Note that this is only used during mixed precision training, since the
- Reducer's hooks installed at construction time would not be called
- when we are working with the low precision parameters.
- """
- self.reducer._autograd_hook(idx) # type: ignore[attr-defined]
- def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
- """
- For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.
- When training with DDP mixed precision, this root pre-forward hook kicks
- off low precision copies on a separate stream and creates respective
- events to wait for them.
- """
- # Clear out the previous iteration's submodule-to-event mapping. This is
- # because we may have populated some events for modules that didn't end
- # up being used.
- self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated]
- with torch.cuda.stream(self._mp_stream):
- for submodule in self.module.modules():
- for param in submodule.parameters(recurse=False):
- # Do not cast DDP ignored parameters.
- if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
- continue
- _alloc_storage(param._mp_param, param.size())
- # copy() implicitly casts to low precision
- with torch.no_grad():
- param._mp_param.copy_(param.data)
- # TODO: when zero_grad(set_to_none=False) or in grad
- # accumulation case, accumulated grads can be in fp32
- # which can cause errors when running DDP backwards due
- # to mismatched incoming and accumulated gradient types.
- # So we manually cast the accumulated grad down for now,
- # in the future we may shift to FSDP style gradient
- # accumulation management where the accumulated gradient
- # is saved and .grad field is set to None, bypassing
- # this issue.
- if param.grad is not None:
- param.grad.data = param.grad.to(
- self.mixed_precision.param_dtype # type: ignore[union-attr]
- )
- param.data = param._mp_param
- copy_event = torch.cuda.Event()
- copy_event.record()
- self._submodule_to_event[submodule].append(copy_event)
- def _module_wait_for_copy_hook(
- self,
- module,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
- try:
- event = self._submodule_to_event[module].popleft()
- except IndexError:
- # copy event has already been waited on
- return
- event.wait(stream=torch.cuda.current_stream())
- for p in module.parameters(recurse=False):
- # Don't register hooks if param does not require grad
- if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
- continue
- # We need to register the autograd hook here instead of in DDP's ctor
- # since we're working with the low precision param. Register it
- # by obtaining the gradient accumulator.
- tmp = p.expand_as(p)
- grad_acc = tmp.grad_fn.next_functions[0][0]
- hook = grad_acc.register_hook(
- functools.partial(self._fire_reducer_autograd_hook, p._idx)
- )
- p._ddp_mp_hook_state = (grad_acc, hook)
- def _log_and_throw(self, err_type, err_msg):
- if self.logger is not None:
- self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
- raise err_type(err_msg)
- def _ddp_init_helper(
- self,
- parameters,
- expect_sparse_gradient,
- param_to_name_mapping,
- static_graph,
- ):
- """
- DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.
- Initialization helper function that does the following:
- (1) bucketing the parameters for reductions
- (2) resetting the bucketing states
- (3) registering the grad hooks
- (4) logging construction-time DDP logging data
- (5) passing a handle of DDP to the SyncBatchNorm layer
- """
- # Note that the parameter order is not necessarily the order in which the
- # parameters are used, especially in models with control flow.
- #
- # When the parameters are not presented in the real execution order, and a
- # model also happens to
- # 1) have other collective comm ops in its backward graph, and
- # 2) have unused parameters on a subset of ranks in the world,
- # bucketing could insert an ALL-REDUCE comm op too early on the ranks with
- # unused parameters, unexpectedly matching up with other collective comm
- # ops on other ranks.
- #
- # To handle this corner case, when the parameters are not in the real
- # execution order, we don't do bucketing, so only one ALL-REDUCE is
- # inserted after all the gradients of the whole graph are computed.
- #
- # Note that we only disable bucketing for the first iteration. After the
- # first iteration, it's OK to rebuild buckets, because "bucket rebuild"
- # bucketizes parameters based on their real execution order in the
- # backward graph.
- # Can remove this branching once #73732 is landed.
- if static_graph is True or self.find_unused_parameters is False:
- bucket_size_limits = [sys.maxsize]
- else:
- if self.bucket_bytes_cap_default:
- bucket_size_limits = [
- dist._DEFAULT_FIRST_BUCKET_BYTES,
- self.bucket_bytes_cap,
- ]
- else:
- bucket_size_limits = [self.bucket_bytes_cap]
- (
- bucket_indices,
- per_bucket_size_limits,
- ) = dist._compute_bucket_assignment_by_size(
- parameters,
- bucket_size_limits,
- expect_sparse_gradient,
- )
- # Remember the index for each parameter if we are in mixed precision, as
- # we need to pass the index to the Reducer's autograd hook via Python.
- if self.mixed_precision is not None:
- for i, p in enumerate(parameters):
- p._idx = i
- # Note: reverse list of buckets because we want to approximate the
- # order in which their gradients are produced, and assume they
- # are used in the forward pass in the order they are defined.
- self.reducer = dist.Reducer(
- parameters,
- list(reversed(bucket_indices)),
- list(reversed(per_bucket_size_limits)),
- self.process_group,
- expect_sparse_gradient,
- # The bucket size limit is specified in the constructor.
- # Additionally, we allow for a single small bucket for parameters
- # that are defined first, such that their gradients don't spill into
- # a much larger bucket, adding unnecessary latency after gradient
- # computation finishes. Experiments showed 1MB is a reasonable value.
- self.bucket_bytes_cap,
- self.find_unused_parameters,
- self.gradient_as_bucket_view,
- param_to_name_mapping,
- # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
- # bucket.
- dist._DEFAULT_FIRST_BUCKET_BYTES
- if self.bucket_bytes_cap_default
- else self.bucket_bytes_cap,
- )
- self.logger = dist.Logger(self.reducer)
- # Set as a weak reference to avoid reference cycle between
- # logger and reducer.
- self.reducer.set_logger(self.logger)
- has_sync_bn = False
- for submodule in self.module.modules():
- if isinstance(submodule, torch.nn.SyncBatchNorm):
- has_sync_bn = True
- break
- # Set logging data that can be got during construction time.
- self.logger.set_construction_data_and_log(
- self.module.__class__.__name__,
- [] if self.device_ids is None else self.device_ids,
- -1 if self.output_device is None else self.output_device,
- self.broadcast_buffers,
- has_sync_bn,
- static_graph,
- )
- # passing a handle to torch.nn.SyncBatchNorm layer
- self._passing_sync_batchnorm_handle(self.module)
- def __getstate__(self):
- self._check_default_group()
- attrs = copy.copy(self.__dict__)
- del attrs["process_group"]
- del attrs["reducer"]
- del attrs["logger"]
- return attrs
- def __setstate__(self, state):
- # If serializable, then the process group should be the default one
- self.process_group = _get_default_group()
- super().__setstate__(state)
- self.__dict__.setdefault("require_forward_param_sync", True)
- self.__dict__.setdefault("require_backward_grad_sync", True)
- parameters, expect_sparse_gradient = self._build_params_for_reducer()
- # In debug mode, build a mapping of parameter index -> parameter.
- param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
- # Builds reducer.
- self._ddp_init_helper(
- parameters,
- expect_sparse_gradient,
- param_to_name_mapping,
- self.static_graph,
- )
- if self.static_graph:
- self.reducer._set_static_graph()
- assert self.logger is not None
- self.logger._set_static_graph()
- def _build_params_for_reducer(self):
- # Build tuple of (module, parameter) for all parameters that require grads.
- modules_and_parameters = [
- (module, parameter)
- for module_name, module in self.module.named_modules()
- for parameter in [
- param
- # Note that we access module.named_parameters instead of
- # parameters(module). parameters(module) is only needed in the
- # single-process multi device case, where it accesses replicated
- # parameters through _former_parameters.
- for param_name, param in module.named_parameters(recurse=False)
- if param.requires_grad
- and f"{module_name}.{param_name}" not in self.parameters_to_ignore
- ]
- ]
- # Deduplicate any parameters that might be shared across child modules.
- memo = set()
- modules_and_parameters = [
- # "p not in memo" is the deduplication check.
- # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
- (m, p)
- for m, p in modules_and_parameters
- if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
- ]
- # Build list of parameters.
- parameters = [parameter for _, parameter in modules_and_parameters]
- # Checks if a module will produce a sparse gradient.
- def produces_sparse_gradient(module):
- if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
- return module.sparse
- return False
- # Build list of booleans indicating whether or not to expect sparse
- # gradients for the corresponding parameters.
- expect_sparse_gradient = [
- produces_sparse_gradient(module) for module, _ in modules_and_parameters
- ]
- self._assign_modules_buffers()
- return parameters, expect_sparse_gradient
- def _assign_modules_buffers(self):
- """
- Assign self.module.named_buffers to self.modules_buffers.
- Assigns module buffers to self.modules_buffers which are then used to
- broadcast across ranks when broadcast_buffers=True. Note that this
- must be called every time buffers need to be synced because buffers can
- be reassigned by user module,
- see https://github.com/pytorch/pytorch/issues/63916.
- """
- # Collect buffers for modules, filtering out buffers that should be ignored.
- named_module_buffers = [
- (buffer, buffer_name)
- for buffer_name, buffer in self.module.named_buffers()
- if buffer_name not in self.parameters_to_ignore
- ]
- self.modules_buffers = [
- buffer for (buffer, buffer_name) in named_module_buffers
- ]
- # Dict[str, tensor] representing module buffers not ignored by DDP.
- self.named_module_buffers = {
- buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
- }
- def _build_debug_param_to_name_mapping(self, parameters):
- param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
- param_set = set(parameters)
- param_index_to_param_fqn = {}
- for module_name, module in self.module.named_modules():
- for param_name, param in module.named_parameters(recurse=False):
- fqn = f"{module_name}.{param_name}"
- # Bypass ignored parameters since those are not reduced by DDP
- # to begin with.
- if fqn not in self.parameters_to_ignore and param.requires_grad:
- if param not in param_set:
- self._log_and_throw(
- ValueError,
- f"Param with name {fqn} found in module parameters, but not DDP parameters."
- " This indicates a bug in DDP, please report an issue to PyTorch.",
- )
- param_index = param_to_param_index[param]
- param_index_to_param_fqn[param_index] = fqn
- # Ensure we covered all parameters
- if len(param_set) != len(param_index_to_param_fqn):
- self._log_and_throw(
- ValueError,
- (
- "Expected param to name mapping to cover all parameters, but"
- f" got conflicting lengths: {len(param_set)} vs "
- f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
- ", please report an issue to PyTorch."
- ),
- )
- return param_index_to_param_fqn
- def _get_parameters(self, m, recurse=True):
- """Return a generator of module parameters."""
- def model_parameters(m):
- ps = (
- m._former_parameters.values()
- if hasattr(m, "_former_parameters")
- else m.parameters(recurse=False)
- )
- yield from ps
- for mod in m.modules() if recurse else [m]:
- yield from model_parameters(mod)
- def _check_default_group(self):
- pickle_not_supported = False
- try:
- if self.process_group != _get_default_group():
- pickle_not_supported = True
- except RuntimeError:
- pickle_not_supported = True
- if pickle_not_supported:
- self._log_and_throw(
- RuntimeError,
- "DDP Pickling/Unpickling are only supported "
- "when using DDP with the default process "
- "group. That is, when you have called "
- "init_process_group and have not passed "
- "process_group argument to DDP constructor",
- )
- @contextmanager
- def no_sync(self):
- r"""
- Context manager to disable gradient synchronizations across DDP processes.
- Within this context, gradients will be accumulated on module
- variables, which will later be synchronized in the first
- forward-backward pass after exiting the context.
- Example::
- >>> # xdoctest: +SKIP("undefined variables")
- >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
- >>> with ddp.no_sync():
- >>> for input in inputs:
- >>> ddp(input).backward() # no synchronization, accumulate grads
- >>> ddp(another_input).backward() # synchronize grads
- .. warning::
- The forward pass should be included inside the context manager, or
- else gradients will still be synchronized.
- """
- old_require_backward_grad_sync = self.require_backward_grad_sync
- self.require_backward_grad_sync = False
- try:
- yield
- finally:
- self.require_backward_grad_sync = old_require_backward_grad_sync
- @classmethod
- def _get_active_ddp_module(cls):
- """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
- return cls._active_ddp_module
- # Note: this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
- # for the 'module_to_run' underneath.
- # See torch._dynamo/eval_frame.py TorchPatcher.patch for more details.
- @contextmanager
- @torch._disable_dynamo(recursive=False)
- def _inside_ddp_forward(self):
- DistributedDataParallel._active_ddp_module = self
- try:
- yield
- finally:
- DistributedDataParallel._active_ddp_module = None
- def _run_ddp_forward(self, *inputs, **kwargs):
- if self._use_python_reducer:
- return self.module(*inputs, **kwargs) # type: ignore[index]
- else:
- with self._inside_ddp_forward():
- return self.module(*inputs, **kwargs) # type: ignore[index]
- def _clear_grad_buffer(self):
- # Making param.grad point to the grad buffer before backward relies on the
- # assumption that grad accumulation is done in place by the autograd engine.
- # For some edge cases, if the grad accumulation in the autograd engine is not
- # in place, then param.grad and the grad buffer become detached from each other.
- if self._delay_grad_buffer is not None:
- # We batch zero_grad for all params by resetting the whole grad
- # buffer when the grad of all params is set to None.
- all_param_grad_none = all(
- param.grad is None for param in self._delay_all_reduce_params
- )
- for index, param in enumerate(self._delay_all_reduce_params):
- if param.grad is None:
- param.grad = self._delay_grad_views[index]
- if not all_param_grad_none:
- param.grad.zero_()
- if all_param_grad_none:
- self._delay_grad_buffer.zero_()
- def _lazy_init(self):
- # Initialization for DDP that occurs after construction, but lazily
- # before the first forward pass.
- self._setup_in_backward_optimizers()
- self._lazy_init_ran = True
- def _should_disable_cpp_reducer(self) -> bool:
- return self._use_python_reducer and (
- torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
- )
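- # _pre_forward performs the per-iteration bookkeeping that must happen
- # before the wrapped module's forward: lazy initialization, logging and
- # reducer preparation, bucket rebuilding, optional pre-forward buffer sync,
- # join-mode notification, and moving/casting inputs to the DDP device/dtype.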
- def _pre_forward(self, *inputs, **kwargs):
- if self._should_disable_cpp_reducer():
- return inputs, kwargs
- # Disable the python reducer if compiled_autograd is not enabled.
- if self._accum_grad_hooks:
- for index, h in enumerate(self._accum_grad_hooks):
- h.remove()
- self._accum_grad_hooks.clear()
- if not self._lazy_init_ran and not torch._utils.is_compiling():
- self._lazy_init()
- if self._delay_all_reduce_all_params:
- return inputs, kwargs
- if torch.is_grad_enabled() and self.require_backward_grad_sync:
- assert self.logger is not None
- self.logger.set_runtime_stats_and_log()
- self.reducer.prepare_for_forward()
- # Notify the join context that this process has not joined, if
- # needed
- work = Join.notify_join_context(self)
- if work:
- self.reducer._set_forward_pass_work_handle(
- work, self._divide_by_initial_world_size # type: ignore[arg-type]
- )
- # Call _rebuild_buckets before the forward computation.
- # _rebuild_buckets may allocate new buckets before deallocating the old
- # buckets, so calling it here, before peak memory usage rises during the
- # forward computation, keeps peak memory lower.
- # Buckets should be rebuilt only once during the whole training period.
- if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
- logger.info("Reducer buckets have been rebuilt in this iteration.")
- self._has_rebuilt_buckets = True
- # Sync module buffers before the forward pass if the buffer hook (or the
- # default behavior) specified the pre-forward location.
- if self._check_sync_bufs_pre_fwd():
- self._sync_buffers()
- if self._join_config.enable:
- # Notify joined ranks whether they should sync in backwards pass or not.
- self._check_global_requires_backward_grad_sync(is_joined_rank=False)
- if self.device_ids:
- moved_inputs, moved_kwargs = _to_kwargs(
- inputs,
- kwargs,
- torch.device(self.device_type, self.device_ids[0]),
- self.use_side_stream_for_tensor_copies,
- )
- args, kwargs = moved_inputs[0], moved_kwargs[0]
- # Cast inputs to reduced precision if needed.
- if self.mixed_precision is not None:
- args, kwargs = _cast_forward_inputs(
- self.mixed_precision.param_dtype,
- *args,
- **kwargs,
- )
- return args, kwargs
- else:
- # Cast inputs to reduced precision if needed.
- # TODO (rohan-varma) test this codepath.
- if self.mixed_precision is not None:
- inputs, kwargs = _cast_forward_inputs(
- self.mixed_precision.param_dtype,
- *inputs,
- **kwargs,
- )
- return inputs, kwargs
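- # _post_forward runs after the wrapped module's forward: optional
- # post-forward buffer sync, reducer.prepare_for_backward, and routing
- # outputs through _DDPSink when unused-parameter detection or the first
- # static-graph iteration requires it.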
- def _post_forward(self, output):
- if self._should_disable_cpp_reducer():
- return output
- if self._delay_all_reduce_all_params:
- self._clear_grad_buffer()
- return output
- # Sync module buffers after the forward pass if the registered buffer
- # hook specified the post-forward location.
- if self._check_sync_bufs_post_fwd():
- self._sync_buffers()
- if torch.is_grad_enabled() and self.require_backward_grad_sync:
- self.require_forward_param_sync = True
- # We'll return the output object verbatim since it is a freeform
- # object. We do need to find any tensors in this object, though,
- # because we need to figure out which parameters were used during
- # this forward pass, in order to short-circuit reduction for any
- # unused parameters. This only applies if `find_unused_parameters`
- # is set.
- if self.find_unused_parameters and not self.static_graph:
- # Do not need to populate this for static graph.
- self.reducer.prepare_for_backward(list(_find_tensors(output)))
- else:
- self.reducer.prepare_for_backward([])
- else:
- self.require_forward_param_sync = False
- # TODO: DDPSink is currently enabled for unused parameter detection and
- # static graph training for first iteration.
- if (self.find_unused_parameters and not self.static_graph) or (
- self.static_graph and not self._static_graph_delay_allreduce_enqueued
- ):
- (
- output_tensor_list,
- treespec,
- output_is_rref,
- ) = _tree_flatten_with_rref(output)
- output_placeholders = [None for _ in range(len(output_tensor_list))]
- # Do not touch tensors that have no grad_fn, which can cause issues
- # such as https://github.com/pytorch/pytorch/issues/60733
- for i, output in enumerate(output_tensor_list):
- if torch.is_tensor(output) and output.grad_fn is None:
- output_placeholders[i] = output
- # When find_unused_parameters=True, make tensors that require grad
- # run through the DDPSink backward pass. When not all outputs are
- # used in the loss, the corresponding tensors receive an undefined
- # gradient, which the reducer then handles to ensure the param.grad
- # field is not touched and no error is raised.
- passthrough_tensor_list = _DDPSink.apply(
- weakref.ref(self),
- *output_tensor_list,
- )
- for i in range(len(output_placeholders)):
- if output_placeholders[i] is None:
- output_placeholders[i] = passthrough_tensor_list[i]
- # Reconstruct output data structure.
- output = _tree_unflatten_with_rref(
- output_placeholders, treespec, output_is_rref
- )
- # At the end of the forward pass, reset the grad buffer and grad views
- self._clear_grad_buffer()
- return output
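- # forward wraps the underlying module's forward between _pre_forward and
- # _post_forward, under a profiler record_function scope.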
- def forward(self, *inputs, **kwargs):
- with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
- inputs, kwargs = self._pre_forward(*inputs, **kwargs)
- output = (
- self.module.forward(*inputs, **kwargs)
- if self._delay_all_reduce_all_params
- else self._run_ddp_forward(*inputs, **kwargs)
- )
- return self._post_forward(output)
- def scatter(self, inputs, kwargs, device_ids):
- return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
- def to_kwargs(self, inputs, kwargs, device_id):
- # Kept for backward compatibility (BC).
- return _to_kwargs(
- inputs,
- kwargs,
- torch.device(self.device_type, device_id),
- self.use_side_stream_for_tensor_copies,
- )
- def gather(self, outputs, output_device):
- return gather(outputs, output_device, dim=self.dim)
- def train(self, mode=True):
- super().train(mode)
- return self
- # When running in join mode, schedules an allreduce to notify joined ranks
- # of whether backwards pass synchronization will run this iteration or not.
- def _check_global_requires_backward_grad_sync(self, is_joined_rank):
- if not is_joined_rank and self.require_backward_grad_sync:
- requires_sync_tensor = torch.ones(1, device=self.device)
- else:
- requires_sync_tensor = torch.zeros(1, device=self.device)
- work = dist.all_reduce(
- requires_sync_tensor, group=self.process_group, async_op=True
- )
- # (kwen2501) This if condition is a plain translation of previous
- # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()`
- # is not called and it doesn't care about the result. I am guessing
- # that it just wants to fire a matching all-reduce and does not want
- # the main stream to wait.
- if is_joined_rank:
- work.wait()
- should_sync_backwards = requires_sync_tensor.item() != 0
- return should_sync_backwards
- else:
- return None # Return value is not/should not be used.
- # When running in join mode, checks and performs sync of module buffers if
- # the models have buffers that should be synchronized in the forward pass.
- def _check_and_sync_module_buffers(self):
- if self._check_sync_bufs_pre_fwd():
- authoritative_rank = self._find_common_rank(self._distributed_rank, False)
- self._sync_module_buffers(authoritative_rank)
- # When running in join mode, agrees upon a common rank and broadcasts model
- # parameters from that rank to all other ranks.
- def _sync_final_model(self, is_last_joiner):
- # Agree upon the process that will be the authoritative model copy.
- # The current rank is a candidate for being the authoritative copy if
- # is_last_joiner=True. We break ties via picking the larger rank.
- self._authoritative_rank = self._find_common_rank(
- self._distributed_rank, is_last_joiner
- )
- _sync_module_states(
- module=self.module,
- process_group=self.process_group,
- broadcast_bucket_size=self.broadcast_bucket_size,
- src=self._authoritative_rank,
- params_and_buffers_to_ignore=self.parameters_to_ignore,
- broadcast_buffers=self.broadcast_buffers,
- )
- # Schedule comm ops to match those scheduled in the reducer's backward
- # pass.
- def _match_all_reduce_for_bwd_pass(self):
- comm_work = []
- # Schedule comm in the same order as Reducer schedules them, i.e.
- # the order of the buckets. Retrieving the bucket order from the reducer
- # ensures that we keep the same order in join mode, such as when bucket
- # order is rebuilt dynamically.
- # Returns grad_buckets in order, but real tensors are substituted with
- # zero tensors of the same shape.
- grad_buckets = self.reducer._get_zeros_like_grad_buckets()
- for grad_bucket in grad_buckets:
- # Joined processes contribute zero gradient. In the case that
- # divide_by_initial_world_size=True, we divide grads by the static
- # world size; if not, the dividing factor is reduced by the number
- # of joined processes.
- work = self.reducer._run_comm_hook(grad_bucket)
- comm_work.append(work)
- for work in comm_work:
- work.wait()
- # Allreduces the used parameter mapping across ranks.
- def _match_unused_params_allreduce(self):
- locally_used_param_map = self.reducer._get_local_used_map()
- self.process_group.allreduce(locally_used_param_map)
- def join(
- self,
- divide_by_initial_world_size: bool = True,
- enable: bool = True,
- throw_on_early_termination: bool = False,
- ):
- r"""
- Context manager for training with uneven inputs across processes in DDP.
- This context manager will keep track of already-joined DDP processes,
- and "shadow" the forward and backward passes by inserting collective
- communication operations to match with the ones created by non-joined
- DDP processes. This will ensure each collective call has a corresponding
- call by already-joined DDP processes, preventing hangs or errors that
- would otherwise happen when training with uneven inputs across
- processes. Alternatively, if the flag ``throw_on_early_termination`` is
- specified to be ``True``, all trainers will throw an error once one rank
- runs out of inputs, allowing these errors to be caught and handled
- according to application logic.
- Once all DDP processes have joined, the context manager will broadcast
- the model corresponding to the last joined process to all processes to
- ensure the model is the same across all processes
- (which is guaranteed by DDP).
- To use this to enable training with uneven inputs across processes,
- simply wrap this context manager around your training loop. No further
- modifications to the model or data loading are required.
- .. warning::
- If the model or training loop this context manager is wrapped around
- has additional distributed collective operations, such as
- ``SyncBatchNorm`` in the model's forward pass, then the flag
- ``throw_on_early_termination`` must be enabled. This is because this
- context manager is not aware of non-DDP collective communication.
- This flag will cause all ranks to throw when any one rank
- exhausts inputs, allowing these errors to be caught and recovered
- from across all ranks.
- Args:
- divide_by_initial_world_size (bool): If ``True``, will divide
- gradients by the initial ``world_size`` DDP training was launched
- with. If ``False``, will compute the effective world size
- (number of ranks that have not depleted their inputs yet) and
- divide gradients by that during allreduce. Set
- ``divide_by_initial_world_size=True`` to ensure every input
- sample, including the uneven inputs, has equal weight in terms of
- how much it contributes to the global gradient. This is
- achieved by always dividing the gradient by the initial
- ``world_size`` even when we encounter uneven inputs. If you set
- this to ``False``, we divide the gradient by the remaining
- number of nodes. This ensures parity with training on a smaller
- ``world_size`` although it also means the uneven inputs would
- contribute more towards the global gradient. Typically, you
- would want to set this to ``True`` for cases where the last few
- inputs of your training job are uneven. In extreme cases, where
- there is a large discrepancy in the number of inputs, setting
- this to ``False`` might provide better results.
- enable (bool): Whether to enable uneven input detection or not. Pass
- in ``enable=False`` to disable in cases where you know that
- inputs are even across participating processes. Default is
- ``True``.
- throw_on_early_termination (bool): Whether to throw an error
- or continue training when at least one rank has exhausted
- inputs. If ``True``, will throw upon the first rank reaching end
- of data. If ``False``, will continue training with a smaller
- effective world size until all ranks are joined. Note that if
- this flag is specified, then the flag
- ``divide_by_initial_world_size`` would be ignored. Default
- is ``False``.
- Example::
- >>> # xdoctest: +SKIP("Distributed")
- >>> import torch
- >>> import torch.distributed as dist
- >>> import os
- >>> import torch.multiprocessing as mp
- >>> import torch.nn as nn
- >>> # On each spawned worker
- >>> def worker(rank):
- >>> dist.init_process_group("nccl", rank=rank, world_size=2)
- >>> torch.cuda.set_device(rank)
- >>> model = nn.Linear(1, 1, bias=False).to(rank)
- >>> model = torch.nn.parallel.DistributedDataParallel(
- >>> model, device_ids=[rank], output_device=rank
- >>> )
- >>> # Rank 1 gets one more input than rank 0.
- >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
- >>> with model.join():
- >>> for _ in range(5):
- >>> for inp in inputs:
- >>> loss = model(inp).sum()
- >>> loss.backward()
- >>> # Without the join() API, the below synchronization will hang
- >>> # blocking for rank 1's allreduce to complete.
- >>> torch.cuda.synchronize(device=rank)
- """
- return Join(
- [self],
- enable,
- throw_on_early_termination,
- divide_by_initial_world_size=divide_by_initial_world_size,
- )
- def join_hook(
- self,
- **kwargs,
- ):
- r"""
- Return the DDP join hook, which enables training on uneven inputs by mirroring collective communications in the forward and backward passes.
- Arguments:
- kwargs (dict): a :class:`dict` containing any keyword arguments
- to modify the behavior of the join hook at run time; all
- :class:`Joinable` instances sharing the same join context
- manager are forwarded the same value for ``kwargs``.
- The hook supports the following keyword arguments:
- divide_by_initial_world_size (bool, optional):
- If ``True``, then gradients are divided by the initial world
- size that DDP was launched with.
- If ``False``, then gradients are divided by the effective world
- size (i.e. the number of non-joined processes), meaning that
- the uneven inputs contribute more toward the global gradient.
- Typically, this should be set to ``True`` if the degree of
- unevenness is small but can be set to ``False`` in extreme
- cases for possibly better results.
- Default is ``True``.
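- Example (an illustrative sketch; it assumes ``ddp`` and ``inputs``
- already exist and drives the hook through the generic ``Join``
- context manager, which forwards its keyword arguments to
- ``join_hook``)::
- >>> # xdoctest: +SKIP("Distributed")
- >>> from torch.distributed.algorithms.join import Join
- >>> with Join([ddp], divide_by_initial_world_size=False):
- >>> for inp in inputs:
- >>> ddp(inp).sum().backward()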
- """
- divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
- return _DDPJoinHook(
- self, divide_by_initial_world_size=divide_by_initial_world_size
- )
- @property
- def join_device(self):
- return self.device
- @property
- def join_process_group(self):
- return self.process_group
- def _register_buffer_comm_hook(
- self,
- state,
- hook: Callable,
- comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
- ):
- r"""
- Allow custom registration of hooks that define how buffers are synchronized across ranks.
- The hook takes in an optional state and is passed a Dict[str, Tensor]
- mapping buffer names to the buffers, and can run arbitrary reductions
- on the buffers as opposed to DDP's default broadcast from rank 0. This is useful,
- for example, if a counter needs to be summed or averaged across ranks every iteration.
- Args:
- state (Any): Optional state that is passed to the hook.
- hook (Callable): Callable with the following signature:
- ``hook(state: object, buffers: Dict[str, torch.Tensor]) -> Optional[List[torch.futures.Future[torch.Tensor]]]``
- comm_hook_location (_BufferCommHookLocation): Enum value indicating
- where to run the hook.
- _BufferCommHookLocation.PRE_FORWARD means that the
- hook will run _before_ the forward pass, and
- _BufferCommHookLocation.POST_FORWARD means that the
- hook will run _after_ the forward pass.
- NOTE: To maximize performance, users can return a
- List[torch.futures.Future] from their hook, and DDP will
- install and await these futures appropriately at the end of
- the backward pass. This will ensure all buffers are
- synchronized by the end of the backward pass. If this
- setting is used, it is recommended to pass
- comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
- which will trigger the hook after the forward pass.
- If _BufferCommHookLocation.PRE_FORWARD is used, users must
- ensure appropriate synchronization when manipulating GPU
- buffers in the forward pass.
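- Example (an illustrative sketch rather than a prescribed recipe; it
- assumes ``ddp`` is an already-constructed DistributedDataParallel
- instance and that the default process group has been initialized)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> def sum_buffers(state, named_buffers):
- >>> # Sum every buffer across ranks instead of broadcasting from rank 0,
- >>> # and return the async futures so DDP awaits them after backward.
- >>> futs = []
- >>> for buf in named_buffers.values():
- >>> work = torch.distributed.all_reduce(buf, async_op=True)
- >>> futs.append(work.get_future())
- >>> return futs
- >>> ddp._register_buffer_comm_hook(state=None, hook=sum_buffers)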
- """
- assert callable(hook)
- self.buffer_hook = _BufferCommHook(
- buffer_comm_hook=hook,
- buffer_comm_hook_state=state,
- buffer_comm_hook_location=comm_hook_location,
- )
- def register_comm_hook(self, state: object, hook: Callable):
- r"""
- Register communication hook for user-defined DDP aggregation of gradients across multiple workers.
- This hook would be very useful for researchers to try out new ideas. For
- example, this hook can be used to implement several algorithms like GossipGrad
- and gradient compression which involve different communication strategies for
- parameter syncs while running Distributed DataParallel training.
- Args:
- state (object): Passed to the hook to maintain any state information during the training process.
- Examples include error feedback in gradient compression,
- peers to communicate with next in GossipGrad, etc.
- It is locally stored by each worker
- and shared by all the gradient tensors on the worker.
- hook (Callable): Callable with the following signature:
- ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
- This function is called once the bucket is ready. The
- hook can perform whatever processing is needed and return
- a Future indicating completion of any async work (ex: allreduce).
- If the hook doesn't perform any communication, it still
- must return a completed Future. The Future should hold the
- new value of the grad bucket's tensors. Once a bucket is ready, the
- c10d reducer calls this hook, takes the tensor returned by the
- Future, and copies the grads into the individual parameters.
- Note that the future's return type must be a single tensor.
- We also provide an API called ``get_future`` to retrieve a
- Future associated with the completion of ``c10d.ProcessGroup.Work``.
- ``get_future`` is currently supported for NCCL and also supported for most
- operations on GLOO and MPI, except for peer-to-peer operations (send/recv).
- .. warning ::
- The grad bucket's tensors will not be predivided by world_size. The user is
- responsible for dividing by world_size for operations like allreduce.
- .. warning ::
- DDP communication hook can only be registered once and should be registered
- before calling backward.
- .. warning ::
- The Future object that the hook returns should contain a single tensor
- that has the same shape as the tensors inside the grad bucket.
- .. warning ::
- ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
- for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
- Example::
- Below is an example of a noop hook that returns the same tensor.
- >>> # xdoctest: +SKIP('undefined name')
- >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
- >>> fut = torch.futures.Future()
- >>> fut.set_result(bucket.buffer())
- >>> return fut
- >>> ddp.register_comm_hook(state=None, hook=noop)
- Example::
- Below is an example of a Parallel SGD algorithm where gradients are encoded before
- allreduce, and then decoded after allreduce.
- >>> # xdoctest: +SKIP('undefined name')
- >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
- >>> encoded_tensor = encode(bucket.buffer()) # encode gradients
- >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future()
- >>> # Define the then callback to decode.
- >>> def decode_callback(fut):
- >>> decoded_tensor = decode(fut.value()[0]) # decode gradients
- >>> return decoded_tensor
- >>> return fut.then(decode_callback)
- >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
- """
- self._check_comm_hook(hook)
- assert self.logger is not None
- self.logger._set_comm_hook_name(hook.__qualname__)
- self._comm_hooks.append((hook, state))
- dist._register_comm_hook(self.reducer, state, hook)
- def _register_builtin_comm_hook(self, comm_hook_type):
- r"""
- Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.
- The built-in hooks aim to provide efficient C++ implementations for certain hooks,
- which might not be as efficient if implemented in Python using a Python communication hook.
- Args:
- comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.
- .. warning ::
- DDP communication hook can only be registered once and should be registered
- before calling backward.
- Example::
- Below is an example of an FP16 compression where gradients are
- compressed into 16-bit floating-point numbers before allreduce, and
- then decompressed after allreduce.
- >>> # xdoctest: +SKIP('undefined name')
- >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
- """
- assert self.logger is not None
- self.logger._set_comm_hook_name(str(comm_hook_type))
- dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
- def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
- r"""
- Register an optimizer in DDP to optimize a parameter immediately after its gradient reduction.
- Registers an optimizer with DDP such that the optimization for a
- parameter will run immediately when that parameter's gradient is
- finished with reduction, instead of waiting for all parameters'
- gradients to finish reduction. This can result in a training speedup
- depending on your workload since the optimizer can run while gradient
- reduction for other parameters is still ongoing. In addition, this has
- the potential to reduce peak memory consumption during training, as it
- only needs to load the per-parameter optimizer states of a single
- parameter at a time, instead of loading all per-parameter optimizer
- states at once.
- Args:
- optim (Type): a ``torch.optim.Optimizer`` class to be registered
- as a fused optimizer.
- *args (Sequence[Any]): Arguments to forward to `optim`.
- optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
- to optimize, similar to `params` argument of traditional `torch.optim`
- Optimizers. If this is omitted, all DDP model parameters will be
- optimized.
- **kwargs (Dict[str, Any]): Keyword arguments to forward to `optim`.
- .. warning ::
- _register_fused_optim should only be called once on a DDP instance,
- and registering multiple fused optimizers for the same DDP model
- is not currently supported. Please ping
- https://github.com/pytorch/pytorch/issues/71595 if this is necessary
- for your use case.
- .. warning ::
- _register_fused_optim and register_comm_hook currently do not
- compose together, meaning that custom DDP communication hooks are
- not supported with overlapped optimizers. Please ping
- https://github.com/pytorch/pytorch/issues/71595 if this is necessary
- for your use case.
- .. warning ::
- Gradient accumulation and DDP `no_sync` are currently not supported
- with overlapped optimizer. Please ping
- https://github.com/pytorch/pytorch/issues/71595 if this is necessary
- for your use case.
- Example::
- >>> # xdoctest: +SKIP("No rendezvous handler")
- >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
- >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
- >>> lr = 1e-2
- >>> betas = (0.9, 0.99)
- >>> eps = 1e-6
- >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
- >>> # Example with subset of parameters
- >>> params_to_opt = [list(net.parameters())[0]]
- >>> net._register_fused_optim(
- ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
- ... )
- """
- # Note: importing in function, otherwise this will cause a circular
- # import as optimizer_overlap module needs to import DistributedDataParallel.
- from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim
- overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
- try:
- overlapped_optim.register_ddp(self)
- except NotImplementedError as e:
- raise RuntimeError(
- f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
- ) from e
- def _distributed_broadcast_coalesced(
- self, tensors, buffer_size, authoritative_rank=0
- ):
- dist._broadcast_coalesced(
- self.process_group, tensors, buffer_size, authoritative_rank
- )
- def _check_sync_bufs_post_fwd(self):
- return (
- self.will_sync_module_buffers()
- and hasattr(self, "buffer_hook")
- and self.buffer_hook.buffer_comm_hook_location
- == _BufferCommHookLocation.POST_FORWARD
- )
- def _check_sync_bufs_pre_fwd(self):
- return self.will_sync_module_buffers() and (
- not hasattr(self, "buffer_hook")
- or self.buffer_hook.buffer_comm_hook_location
- == _BufferCommHookLocation.PRE_FORWARD
- )
- def will_sync_module_buffers(self):
- return (
- self.require_forward_param_sync
- and self.broadcast_buffers
- and len(self.modules_buffers) > 0
- )
- def _find_common_rank(self, input_rank, rank_cond):
- # -1 indicates that this rank is not under consideration to be the
- # common_rank
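- # The MAX allreduce below then picks the largest eligible rank; e.g., with
- # four ranks where only ranks 1 and 3 satisfy rank_cond, every rank
- # receives 3.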
- rank_to_use = torch.tensor(
- [input_rank if rank_cond else -1],
- device=self.device,
- )
- dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
- if rank_to_use.item() == -1:
- self._log_and_throw(
- ValueError,
- "BUG! Expected rank_cond to be true for at least one process."
- " This indicates a bug in PyTorch, please report an issue.",
- )
- return rank_to_use.item()
- def _sync_buffers(self):
- with torch.no_grad():
- # module buffer sync
- # Synchronize buffers across processes.
- # If we are running DDP with the join manager, we have to agree
- # upon a rank to sync module buffers from, since rank 0 may
- # already have been joined and have stale module buffers.
- if self._join_config.enable:
- authoritative_rank = self._find_common_rank(
- self._distributed_rank, True
- )
- else:
- # The process with rank 0 is considered the authoritative copy.
- authoritative_rank = 0
- # Update self.modules_buffers in case any buffers were
- # reassigned.
- self._assign_modules_buffers()
- self._sync_module_buffers(authoritative_rank)
- def _sync_module_buffers(self, authoritative_rank):
- if not hasattr(self, "buffer_hook"):
- self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
- else:
- hook = self.buffer_hook.buffer_comm_hook
- state = self.buffer_hook.buffer_comm_hook_state
- futs = hook(state, self.named_module_buffers)
- if futs is not None:
- self.reducer._install_post_backward_futures(futs)
- def _default_broadcast_coalesced(
- self, bufs=None, bucket_size=None, authoritative_rank=0
- ):
- """
- Broadcast buffers from the authoritative rank (rank 0 by default) to the rest of the workers.
- If bufs or bucket_size is None, the default values self.modules_buffers
- and self.broadcast_bucket_size are used instead.
- """
- if bufs is None:
- bufs = self.modules_buffers
- if bucket_size is None:
- bucket_size = self.broadcast_bucket_size
- self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
- def _passing_sync_batchnorm_handle(self, module):
- for layer in module.modules():
- if isinstance(layer, torch.nn.modules.SyncBatchNorm):
- if self.device_type == "cpu":
- self._log_and_throw(
- ValueError,
- "SyncBatchNorm layers only work with GPU modules",
- )
- def _check_comm_hook(self, hook):
- if not callable(hook):
- self._log_and_throw(TypeError, "Communication hook must be callable.")
- sig = inspect.signature(hook)
- if (
- sig.parameters["bucket"].annotation != inspect._empty
- and sig.parameters["bucket"].annotation != dist.GradBucket
- ):
- self._log_and_throw(
- ValueError,
- "Communication hook: bucket annotation should be dist.GradBucket.",
- )
- if (
- sig.return_annotation != inspect._empty
- and sig.return_annotation != torch.futures.Future[torch.Tensor]
- ):
- self._log_and_throw(
- ValueError,
- "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
- )
- if hook.__name__ in [
- "bf16_compress_hook",
- "bf16_compress_wrapper_hook",
- ] and (
- (torch.version.cuda is None and torch.version.hip is None)
- or (
- torch.version.cuda is not None
- and int(torch.version.cuda.split(".")[0]) < 11
- )
- or not dist.is_available()
- or not dist.is_nccl_available()
- or torch.cuda.nccl.version() < (2, 10)
- ):
- self._log_and_throw(
- TypeError,
- "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.",
- )
- @property
- def _distributed_rank(self):
- return dist.get_rank(self.process_group)
- @staticmethod
- def _get_data_parallel_params(module, named_params=False):
- """Return a generator of parameters managed by a given DDP unit."""
- for param in (
- module.parameters() if not named_params else module.named_parameters()
- ):
- if not hasattr(param, "_ddp_ignored"):
- yield param
- @staticmethod
- def _set_params_and_buffers_to_ignore_for_model(
- module, params_and_buffers_to_ignore
- ):
- """
- Set parameters and buffers to be ignored by DDP.
- Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and
- similarly, {module_name}.{buffer_name} for buffers. For example:
- params_to_ignore = []
- # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
- for module_name, module in model.named_modules():
- for param_name, param in module.named_parameters(recurse=False):
- if should_ignore(param):
- # Create expected format
- fqn = f"{module_name}.{param_name}"
- params_to_ignore.append(fqn)
- torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
- model,
- params_to_ignore
- )
- """
- # This is a workaround to set parameters and buffers DDP should ignore
- # during synchronization. It will be removed when the API is finalized
- # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
- module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
- for name, param in module.named_parameters():
- if name in params_and_buffers_to_ignore:
- param._ddp_ignored = True
- for name, buffer in module.named_buffers():
- if name in params_and_buffers_to_ignore:
- buffer._ddp_ignored = True
- def _get_ddp_logging_data(self):
- r"""
- Return a dictionary of logging data for debugging and analysis.
- This interface can be called after DistributedDataParallel() is
- constructed. It returns a dictionary of logging data that can help
- with debugging and analysis. The logging data includes DistributedDataParallel
- constructor input parameters, some internal states of DistributedDataParallel,
- and performance metrics. Simply print the dictionary to see what
- these metrics are.
- This is a prototype interface and subject to change in the future.
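- Example (illustrative; assumes ``ddp`` is an already-constructed
- DistributedDataParallel instance)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> ddp_logging_data = ddp._get_ddp_logging_data()
- >>> print(ddp_logging_data) # inspect the available keys and metrics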
- """
- assert self.logger is not None
- ddp_logging_data = self.logger._get_ddp_logging_data()
- return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
- def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
- r"""
- Set the sample rate for collecting runtime stats.
- This interface allows users to set the sample_rate for collecting
- runtime stats. Runtime stats are recorded for the first 10
- iterations; after that, they are recorded once every "sample_rate"
- training iterations. By default, after the first 10 iterations,
- runtime stats are recorded once every
- "kDDPRuntimeLoggingSampleRate=100" training iterations.
- This is a prototype interface and subject to change in the future.
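- Example (illustrative; assumes ``ddp`` is an already-constructed
- DistributedDataParallel instance)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> # After the first 10 iterations, record runtime stats every 50 iterations.
- >>> ddp._set_ddp_runtime_logging_sample_rate(50)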
- """
- if sample_rate < 1:
- self._log_and_throw(
- ValueError,
- "DDP runtime logging sample rate should be equal or greater than 1",
- )
- self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
- def _set_static_graph(self):
- """
- Set static graph for DDP.
- It is recommended to set static graph in the DDP constructor, which will
- call this private API internally.
- """
- # If self.static_graph has been set, no need to set it again
- if self.static_graph:
- warnings.warn(
- "You've set static_graph to be True, no need to set it again."
- )
- return
- self.static_graph = True
- self._static_graph_delay_allreduce_enqueued = False
- self.reducer._set_static_graph()
- assert self.logger is not None
- self.logger._set_static_graph()
- if self.find_unused_parameters:
- warnings.warn(
- "You passed find_unused_parameters=true to DistributedDataParallel, "
- "`_set_static_graph` will detect unused parameters automatically, so "
- "you do not need to set find_unused_parameters=true, just be sure these "
- "unused parameters will not change during training loop while calling "
- "`_set_static_graph`."
- )
- def _remove_autograd_hooks(self):
- """Remove autograd hooks registered by the reducer on the model parameters."""
- self.reducer._remove_autograd_hooks()
- def _check_reducer_finalized(self):
- """
- Check if the reducer has processed all buckets and finalized the backward appropriately.
- It is useful to call this method after calling .backward() in your training loop
- in order to avoid subsequent hard-to-debug errors down the road caused by the
- reducer not finalizing the backward pass.
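- Example (an illustrative sketch; ``ddp`` and ``inp`` are assumed to
- already exist)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> loss = ddp(inp).sum()
- >>> loss.backward()
- >>> ddp._check_reducer_finalized() # surfaces an error here if backward did not finalize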
- """
- self.reducer._check_reducer_finalized()
- def _set_sparse_metadata(self, global_unique_ids):
- self.reducer._set_sparse_metadata(global_unique_ids)
- def _update_process_group(self, new_process_group):
- """
- Dynamically updates the process group for DDP so that we can shrink/expand the DDP
- world size without having to reinitialize DDP.
- NOTE: If you are using custom communication hooks via register_comm_hook,
- you need to update the process groups for those hooks separately.
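- Example (an illustrative sketch; ``remaining_ranks`` is a hypothetical
- list of the ranks that stay in the resized job)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> new_pg = torch.distributed.new_group(ranks=remaining_ranks)
- >>> ddp._update_process_group(new_pg)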
- """
- # Force a rebuild of buckets for a new process group. This ensures all ranks
- # are synchronized in terms of when they will rebuild buckets and also
- # re-evaluates previous assumptions of buckets given the world size might have
- # changed.
- self._has_rebuilt_buckets = False
- self.reducer._reset_state()
- if not _rank_not_in_group(new_process_group):
- self.process_group = new_process_group
- self.reducer._update_process_group(new_process_group)
- def _set_ddp_sink_clone(self, val: bool):
- """
- Set whether DDPSink should clone the output tensors.
- The default is True, since if the loss is modified in place, cloning
- avoids the "view was modified in-place" autograd error.
- However, cloning the tensors can add a significant memory and
- performance hit if the number and size of the tensors are large. As
- a result, this can be set to False if you are not modifying the
- loss in place.
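- Example (illustrative; assumes ``ddp`` is an already-constructed
- instance and that the loss is not modified in place)::
- >>> # xdoctest: +SKIP('undefined name')
- >>> ddp._set_ddp_sink_clone(False)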
- """
- self._ddp_sink_clone = val
|