- # mypy: allow-untyped-defs
- import sys
- import warnings
- from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
- import torch
- import torch.distributed as dist
- import torch.distributed.distributed_c10d as c10d
- from torch.distributed.device_mesh import DeviceMesh
- from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
- from . import _functional_collectives_impl as fun_col_impl
- try:
- from torch.utils._cxx_pytree import tree_map_only
- except ImportError:
- from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
- if torch._running_with_deploy():
- def is_torchdynamo_compiling():
- """Can't import torchdynamo in torchdeploy builds currently."""
- return False
- else:
- try:
- from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
- except Exception:
- warnings.warn(
- "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
- )
- def is_torchdynamo_compiling():
- return False
- """
- New traceable, functional collectives.
- RFC: https://github.com/pytorch/pytorch/issues/93173
- compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
- eager: execute these 'functional' ops, which return AsyncCollectiveTensor subclasses that
- automatically call .wait() on the underlying/hidden async 'work' object only when fed to
- a downstream op.
- Issues:
- * Where should these ops live? Putting them in existing torch.distributed files broke `import torch`.
- * Proper support for eager requires in-place ops. We should explore having them as an option for the API.
- """
- """
- Functional collectives are asynchronous only and we perform implicit stream synchronization
- on behalf of the user.
- We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us observe the
- first usage of the tensor and insert the cross-stream sync at the right place.
- The above are the easy bits; the hard one is how we match the Work object returned by
- c10d with the tensor that AsyncCollectiveTensor wraps. We allocate the tensor inside the collective
- op implementation (see the ``clone()`` call in ``_all_reduce``) and it is then handled by the
- dispatcher, which might call other implementations that are allowed to change the returned
- tensor - even return a tensor with a different shape (see ``torch.vmap``).
- This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
- allocated by our implementations, and that makes pairing the AsyncCollectiveTensor with the original
- tensor a lot harder. This pairing is needed so we can look up the Work object to use.
- Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
- identity is not stable across dispatch, the op caller would end up with a different Tensor
- instance that would not match any in the dictionary.
- With Tensor identity out of the question, we decided to use the tensor data pointer, which
- should be stable across all the Tensor changes done during dispatch.
- We have a dictionary of tensor::data_ptr -> Work that we insert into right after we call into c10d.
- We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait().
- Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
- can clean up stale entries in the dictionary.
- To eliminate the possibility of races, we have a global version counter that is used by the finalizer.
- As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
- """
- """
- Functional collectives can accept any of these types to describe the ranks participating in collectives.
- The different types will be desugared to a canonical format
- """
- RANK_TYPES = Union[
- List[int],
- List[List[int]],
- dist.ProcessGroup,
- DeviceMesh,
- Tuple["dist._tensor.DeviceMesh", int],
- str,
- ]
- """
- User facing APIs for functional collectives
- -------------------------------------------
- These APIs are called by user code and are expected to work both in eager execution and under compilation,
- but there are significant differences to how the two modes are implemented underneath.
- Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
- just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
- and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
- if sufficient subclass support is added in dynamo.
- Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
- Here's how it works under torch.compile/dynamo:
- all_reduce(...)
- |--> _expand_group(...) - desugars processgroup into canonical/traceable format
- |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper
- |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed
- And under eager execution:
- all_reduce(...)
- |--> _expand_group(...) - same as above, but less critical for eager
- |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace
- |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor,
- which issues wait_tensor() at the time of first use
- """
- def wait_tensor(tensor):
- """
- Wait on a tensor returned by the collectives ops.
- Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
- """
- return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
- def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
- """
- Broadcasts the tensor to all processes in the given process group.
- Args:
- src (int): Source rank
- group (ProcessGroup or List[int]): The process group to work on.
- tag (str, optional): A unique identifier for the collective. Default: empty string
- """
- group_name = _resolve_group_name(group, tag)
- tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
- return _maybe_wrap_tensor(tensor)
- def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
- """
- Reduces the tensor data across all machines in such a way that all get
- the final result.
- The input tensor is left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- group_name = _resolve_group_name(group, tag)
- tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
- return _maybe_wrap_tensor(tensor)
- def all_gather_tensor(
- self: torch.Tensor,
- gather_dim: int,
- group: RANK_TYPES,
- tag: str = "",
- ):
- """
- Gather tensor data from all machines and concatenate it along ``gather_dim``.
- Note that it currently only supports gather_dim = 0.
- The input tensor is left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- assert self.is_contiguous()
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- tensor = torch.ops._c10d_functional.all_gather_into_tensor(
- self, group_size, group_name
- )
- res = _maybe_wrap_tensor(tensor)
- # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
- if gather_dim != 0:
- # torch.cat accesses the data, so we already need to wait here: first wait,
- # then chunk + cat, which avoids going through the ACT dispatching logic again
- if isinstance(res, AsyncCollectiveTensor):
- res = res.wait() # type: ignore[attr-defined]
- res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
- return res
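- # Shape sketch, assuming group_size == 4 and a contiguous local shard of shape [2, 32]:
- #
- #     out = all_gather_tensor(shard, gather_dim=0, group=list(range(4)))  # shape [8, 32]
- #     out = all_gather_tensor(shard, gather_dim=1, group=list(range(4)))  # shape [2, 128],
- #                                                                         # already waited (chunk + cat above)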
- def all_gather_tensor_autograd(
- self: torch.Tensor,
- gather_dim: int,
- group: RANK_TYPES,
- tag: str = "",
- ):
- """
- Gather tensor data from all machines and concatenate it along ``gather_dim``.
- Note that it currently only supports gather_dim = 0.
- This function is the same as all_gather_tensor but will propagate the
- backwards gradient across workers.
- See all_gather_tensor for more details on usage.
- """
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
- self, group_size, group_name
- )
- res = _FromTorchTensor.apply(tensor)
- # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
- if gather_dim != 0:
- # torch.cat accesses the data, so we already need to wait here: first wait,
- # then chunk + cat, which avoids going through the ACT dispatching logic again
- if isinstance(res, AsyncCollectiveTensor):
- res = res.wait() # type: ignore[attr-defined]
- res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
- return res
- def reduce_scatter_tensor(
- self: torch.Tensor,
- reduceOp: str,
- scatter_dim: int,
- group: RANK_TYPES,
- tag: str = "",
- ):
- """
- Reduces the tensor data across all machines in such a way that all get
- the final result, then scatters the results to the corresponding ranks.
- The input tensor is left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- assert (
- self.size(scatter_dim) % group_size == 0
- ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
- if scatter_dim != 0:
- tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
- self = torch.cat(tensor_list)
- tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
- self,
- reduceOp.lower(),
- group_size,
- group_name, # type: ignore[possibly-undefined]
- )
- res = _maybe_wrap_tensor(tensor)
- return res
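- # Shape sketch, assuming group_size == 4 and a local input ``t`` of shape [8, 32]:
- #
- #     out = reduce_scatter_tensor(t, "sum", scatter_dim=0, group=list(range(4)))
- #     # each rank receives its reduced shard of shape [2, 32]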
- def reduce_scatter_tensor_autograd(
- self: torch.Tensor,
- reduceOp: str,
- scatter_dim: int,
- group: RANK_TYPES,
- tag: str = "",
- ):
- """
- Reduces the tensor data across all machines in such a way that all get
- the final result, then scatters the results to the corresponding ranks.
- This function is the same as reduce_scatter_tensor but will propagate the
- backwards gradient across workers.
- Currently only the "sum" reduceOp is supported.
- See reduce_scatter_tensor for more details on usage.
- """
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- assert (
- self.size(scatter_dim) % group_size == 0
- ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
- if scatter_dim != 0:
- tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
- self = torch.cat(tensor_list)
- tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
- self,
- reduceOp.lower(),
- group_size,
- group_name, # type: ignore[possibly-undefined]
- )
- res = _FromTorchTensor.apply(tensor)
- return res
- def all_reduce_coalesced(
- self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
- ) -> List[torch.Tensor]:
- """
- Reduces a list of tensors across all machines in such a way that all get
- the final result.
- All tensors in the input list are left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- group_name = _resolve_group_name(group, tag)
- tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined]
- self,
- reduceOp.lower(),
- group_name,
- )
- return list(map(_maybe_wrap_tensor, tensor_list))
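- # Usage sketch, assuming an initialized default process group and CUDA tensors:
- #
- #     grads = [torch.ones(4, device="cuda"), torch.ones(8, device="cuda")]
- #     reduced = all_reduce_coalesced(grads, "avg", group=list(range(dist.get_world_size())))
- #     # one fused collective instead of two; each result may be an AsyncCollectiveTensor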
- def all_gather_into_tensor_coalesced(
- self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
- ) -> List[torch.Tensor]:
- """
- Gather a list of tensors from all machines.
- Note that it currently only supports gather_dim = 0.
- The input tensors are left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined]
- self,
- group_size,
- group_name,
- )
- return list(map(_maybe_wrap_tensor, tensor_list))
- def reduce_scatter_tensor_coalesced(
- inputs: List[torch.Tensor],
- reduceOp: str,
- scatter_dim: List[int],
- group: RANK_TYPES,
- tag: str = "",
- ) -> List[torch.Tensor]:
- """
- Reduces a list of tensors across all machines in such a way that all get
- the final result, then scatters the results to the corresponding ranks.
- The input tensors are left unmodified.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- assert len(scatter_dim) == len(inputs)
- for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
- assert (
- tensor.size(dim) % group_size == 0
- ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
- if dim != 0:
- tensor_list = torch.chunk(tensor, group_size, dim=dim)
- inputs[idx] = torch.cat(tensor_list)
- tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined]
- inputs,
- reduceOp.lower(),
- group_size,
- group_name, # type: ignore[possibly-undefined]
- )
- return list(map(_maybe_wrap_tensor, tensor_list))
- # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
- # Today, this maps 1:1 with "aten ops that are views".
- def _is_view_op(tgt):
- assert isinstance(tgt, torch._ops.OpOverload)
- schema = tgt._schema
- if len(schema.arguments) > 0:
- first_arg = schema.arguments[0]
- # check if op is a view
- return first_arg.alias_info is not None and not first_arg.alias_info.is_write
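- # For example, based on the aten schemas (view outputs are non-mutating aliases of the
- # first argument):
- #
- #     _is_view_op(torch.ops.aten.view.default)  # True: first arg aliases, not written to
- #     _is_view_op(torch.ops.aten.add.Tensor)    # False: first arg has no alias_info
- #     _is_view_op(torch.ops.aten.add_.Tensor)   # False: aliased but is_write is True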
- def all_to_all_single(
- self: torch.Tensor,
- output_split_sizes: Optional[List[int]],
- input_split_sizes: Optional[List[int]],
- group: RANK_TYPES,
- tag: str = "",
- ) -> torch.Tensor:
- """
- Each process splits its input tensor and then scatters the splits
- to all processes in the group. It then concatenates the tensors received from all
- the processes in the group and returns a single output tensor.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
- :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
- that information and perform collective algebraic optimization. Use other forms of input for that.
- """
- if output_split_sizes is not None:
- assert all(
- isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
- ), output_split_sizes
- if input_split_sizes is not None:
- assert all(
- isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
- ), input_split_sizes
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- if output_split_sizes is None or input_split_sizes is None:
- assert output_split_sizes is None and input_split_sizes is None, (
- "output_split_sizes and input_split_sizes must either be "
- "specified together or both set to None"
- )
- output_split_sizes = [self.shape[0] // group_size] * group_size
- input_split_sizes = output_split_sizes
- tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
- self,
- output_split_sizes,
- input_split_sizes,
- group_name,
- )
- return _maybe_wrap_tensor(tensor)
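- # Split-size sketch, assuming world_size == 2 and a local input ``t`` of shape [4, 8]:
- #
- #     out = all_to_all_single(t, output_split_sizes=[2, 2], input_split_sizes=[2, 2],
- #                             group=list(range(2)))
- #     # rows 0:2 of the local input go to rank 0 and rows 2:4 go to rank 1; the received
- #     # chunks are concatenated along dim 0, so ``out`` also has shape [4, 8]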
- def all_to_all_single_autograd(
- self: torch.Tensor,
- output_split_sizes: Optional[List[int]],
- input_split_sizes: Optional[List[int]],
- group: RANK_TYPES,
- tag: str = "",
- ) -> torch.Tensor:
- """
- Same as all_to_all_single but supports autograd.
- """
- if output_split_sizes is not None:
- assert all(
- isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
- ), output_split_sizes
- if input_split_sizes is not None:
- assert all(
- isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
- ), input_split_sizes
- group_name = _resolve_group_name(group, tag)
- group_size = c10d._get_group_size_by_name(group_name)
- if output_split_sizes is None or input_split_sizes is None:
- assert output_split_sizes is None and input_split_sizes is None, (
- "output_split_sizes and input_split_sizes must either be "
- "specified together or both set to None"
- )
- output_split_sizes = [self.shape[0] // group_size] * group_size
- input_split_sizes = output_split_sizes
- tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
- self,
- output_split_sizes,
- input_split_sizes,
- group_name,
- )
- return _FromTorchTensor.apply(tensor)
- def permute_tensor(
- self: torch.Tensor,
- src_dst: List[int],
- group: RANK_TYPES,
- tag: str = "",
- ) -> torch.Tensor:
- """
- Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
- be defined such that src_dst[m] == n means m sends to n.
- Group can be one of:
- List[int]: ranks participating in the collective.
- List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
- ProcessGroup: Will perform a collective using the ranks and tag of the PG.
- DeviceMesh: Do a SPMD collective over all ranks of the mesh
- (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
- """
- t, rankset, group_size = _expand_group(group, tag)
- local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
- output_split_sizes = [0] * group_size
- input_split_sizes = [0] * group_size
- for src, dst in enumerate(src_dst):
- if src == dist.get_rank(local_pg):
- input_split_sizes[dst] = self.numel()
- if dst == dist.get_rank(local_pg):
- output_split_sizes[src] = self.numel()
- return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
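- # For example, with 4 ranks and ``src_dst = [1, 2, 3, 0]`` (rank m sends to src_dst[m]):
- #
- #     shifted = permute_tensor(t, [1, 2, 3, 0], group=list(range(4)))
- #     # rank 0's tensor lands on rank 1, rank 1's on rank 2, ..., and rank 3's wraps to rank 0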
- class AsyncCollectiveTensor(torch.Tensor):
- r"""
- A Tensor wrapper subclass that is used to trigger a call to wait
- prior to first use of the underlying tensor.
- Use it inside functional collective PyTorch wrappers like the following:
- def functional_collective(self, group, tag):
- group_name = _resolve_group_name(group, tag)
- tensor = torch.ops._c10d_functional.{collective}(self, group_name)
- return _maybe_wrap_tensor(tensor)
- """
- elem: torch.Tensor
- completed: bool
- __slots__ = ["elem", "completed"]
- @staticmethod
- def __new__(cls, elem: torch.Tensor):
- r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
- cls,
- elem.size(),
- strides=elem.stride(),
- storage_offset=elem.storage_offset(),
- dtype=elem.dtype,
- layout=elem.layout,
- device=elem.device,
- requires_grad=elem.requires_grad,
- )
- r.elem = elem
- r.completed = False
- return r
- def __tensor_flatten__(self):
- return ["elem"], None
- def tolist(self):
- return self.trigger_wait().tolist()
- @staticmethod
- def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
- assert meta is None
- elem = inner_tensors["elem"]
- return AsyncCollectiveTensor(elem)
- def __repr__(self):
- return f"AsyncCollectiveTensor({self.trigger_wait()})"
- def trigger_wait(self):
- if not self.completed:
- out = wait_tensor(self.elem)
- self.completed = True
- return out
- else:
- return self.elem
- def wait(self) -> torch.Tensor:
- return wait_tensor(self.elem)
- def _get_acs_underlying_tensor(self):
- """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
- return self.elem
- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- if func == torch.ops.aten.view.default:
- # Fast-path aten.view, since a lot of view-related ops go to aten.view
- # eventually; this avoids the pytree slowdown
- res = func(args[0].elem, args[1])
- wrapper_res = AsyncCollectiveTensor(res)
- return wrapper_res
- is_view_op = _is_view_op(func)
- def unwrap(e: AsyncCollectiveTensor):
- # wait_tensor is idempotent and will do the stream sync only once
- if not is_view_op:
- return e.trigger_wait()
- return e.elem
- def wrap(e: torch.Tensor):
- # re-wrap view-op outputs so the sync stays deferred until first real use
- assert not isinstance(e, AsyncCollectiveTensor)
- res = AsyncCollectiveTensor(e)
- return res
- unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
- unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
- # we don't wrap the result as it doesn't need to be waited on.
- out = func(*unwrapped_args, **unwrapped_kwargs)
- # View ops don't require a sync, so we should re-wrap the outputs.
- if is_view_op:
- out = tree_map_only(torch.Tensor, wrap, out)
- return out
- def numpy(self):
- return self.wait().numpy()
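- # Dispatch behavior sketch, where ``act`` is the AsyncCollectiveTensor returned by a
- # collective above:
- #
- #     v = act.view(2, -1)  # view op: output is re-wrapped, no sync issued yet
- #     s = act + 1          # non-view op: trigger_wait() syncs, result is a plain Tensor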
- """
- Utils and infrastructure for tracing support
- """
- def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
- """
- _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.
- By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
- torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
- """
- # had to define this hack _inside_ expand_group to avoid
- # graph_break [('torch.* op returned non-Tensor int
- # caused by `cast_*` functions being treated as 'torch.*' ops (iiuc)
- if TYPE_CHECKING:
- def cast_listlistint(x):
- return cast(List[List[int]], x)
- def cast_listint(x):
- return cast(List[int], x)
- else:
- # fake cast op for use at runtime since dynamo doesn't support real cast
- # also, dynamo didn't like encountering 'typing' objects ()
- # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
- def cast_listlistint(x):
- return x
- def cast_listint(x):
- return x
- rankset: List[int]
- if isinstance(group, list):
- if isinstance(group[0], list):
- nested_list = cast_listlistint(group)
- rankset = []
- group_size = -1
- for rs in nested_list:
- rankset.extend(rs)
- if group_size != -1 and group_size != len(rs):
- raise ValueError(
- f"group sizes must be identical found {group_size} and {len(rs)}"
- )
- group_size = len(rs)
- else:
- rankset = cast_listint(group)
- group_size = len(rankset)
- elif isinstance(group, dist.ProcessGroup):
- rankset = dist.get_process_group_ranks(group)
- group_size = len(rankset)
- tag = tag or c10d._get_group_tag(group)
- elif isinstance(group, DeviceMesh):
- assert (
- group.ndim == 1
- ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
- # TODO: it should run collective in the whole mesh instead of dim 0
- tag, rankset, _ = group._dim_group_infos[0]
- group_size = len(rankset)
- elif isinstance(group, tuple):
- if (
- len(group) == 2
- and isinstance(group[0], DeviceMesh)
- and isinstance(group[1], int)
- ):
- dmesh = group[0]
- dim = group[1]
- tag, rankset, _ = dmesh._dim_group_infos[dim]
- group_size = len(rankset)
- else:
- raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
- else:
- raise ValueError(
- "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
- )
- return (tag, rankset, group_size)
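- # Desugaring sketch (tags shown are illustrative; for a ProcessGroup/DeviceMesh the tag
- # comes from the group itself, and ``mesh_2d`` is a hypothetical 2D DeviceMesh):
- #
- #     _expand_group([0, 1, 2, 3])      # -> ("", [0, 1, 2, 3], 4)
- #     _expand_group([[0, 1], [2, 3]])  # -> ("", [0, 1, 2, 3], 2), MPMD over 2-rank groups
- #     _expand_group((mesh_2d, 1))      # -> (tag, ranks, size) for dim 1 of the mesh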
- def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
- """
- Given group in RANK_TYPES, return the group name.
- """
- # `tag` will be deprecated. See details in:
- # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
- if isinstance(group, dist.ProcessGroup):
- return group.group_name
- elif isinstance(group, str):
- return group
- elif isinstance(group, DeviceMesh):
- assert (
- group.ndim == 1
- ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
- return group._dim_group_infos[0][2]
- elif isinstance(group, tuple):
- if (
- len(group) == 2
- and isinstance(group[0], DeviceMesh)
- and isinstance(group[1], int)
- ):
- dmesh = group[0]
- dim = group[1]
- return dmesh._dim_group_infos[dim][2]
- else:
- raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
- elif isinstance(group, list):
- if not is_torchdynamo_compiling():
- warnings.warn(
- "The combination of ranks + tag as process group "
- "identifier has been deprecated. Please switch to "
- "using ProcessGroup, DeviceMesh, or group name instead.",
- FutureWarning,
- stacklevel=3,
- )
- return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
- else:
- raise ValueError(f"Unsupported group type: {type(group)}, {group}")
- class _FromTorchTensor(torch.autograd.Function):
- """
- _FromTorchTensor allows autograd to propagate from a normal Tensor to an
- AsyncCollectiveTensor.
- """
- @staticmethod
- def forward( # type: ignore[override]
- ctx, # pyre-ignore[2]: Parameter must be annotated.
- input: torch.Tensor,
- ) -> torch.Tensor:
- return _maybe_wrap_tensor(input)
- @staticmethod
- def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override]
- return grad_output
- def _are_we_tracing() -> bool:
- if is_torchdynamo_compiling():
- return True
- # If functionalization is turned on, we are almost definitely compiling/tracing.
- # (In particular, AOTAutograd traces a model once with functionalization on
- # but proxy tracing turned off, so this is how we detect it).
- if (
- torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
- is not None
- ):
- return True
- mode = get_innermost_proxy_mode()
- if mode is None:
- return False
- return mode.tracer is not None
- def _maybe_wrap_tensor(self) -> torch.Tensor:
- if _are_we_tracing():
- return wait_tensor(self)
- res = AsyncCollectiveTensor(self)
- return cast(torch.Tensor, res)
- def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
- def mk_out_tensor(shard):
- out_size = list(shard.size())
- out_size[0] *= group_size
- out_tensor = shard.new_empty(out_size)
- return out_tensor
- return [mk_out_tensor(t) for t in self]
- # We now register meta kernels to deal with tracing
- def _broadcast_meta(self, *args):
- return torch.empty_like(self)
- def _all_reduce_meta(self, *args):
- return torch.empty_like(self)
- def _wait_tensor_meta(self, *args):
- return torch.empty_like(self)
- def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
- out_size = list(shard.size())
- out_size[0] *= group_size
- return shard.new_empty(out_size)
- def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
- out_size = list(input.size())
- out_size[0] //= group_size
- return input.new_empty(out_size)
- def _all_reduce_coalesced_meta(self, *args):
- return [torch.empty_like(t) for t in self]
- def _all_reduce__meta(inp, *args):
- return inp
- def _broadcast__meta(inp, *args):
- return inp
- def _all_reduce_coalesced__meta(inputs, *args):
- return inputs
- def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
- def mk_out_tensor(input):
- out_size = list(input.size())
- out_size[0] //= group_size
- out_tensor = input.new_empty(out_size)
- return out_tensor
- return [mk_out_tensor(t) for t in inputs]
- # NB: We often say all_to_all has dynamic output size, but this is not
- # technically true: instead, what typically happens is you manually
- # communicate the output_split_sizes ahead of time (which is dynamic),
- # but then you pass those sizes explicitly, and the all to all itself
- # isn't dynamic, it just follows the specified output splits
- def _all_to_all_single_meta(
- input, output_split_sizes, input_split_sizes, *args, **kwargs
- ):
- if output_split_sizes is None:
- return input.new_empty(input.size())
- else:
- for s in output_split_sizes:
- torch._check_is_size(s)
- out_size = list(input.size())
- out_size[0] = sum(output_split_sizes)
- return input.new_empty(out_size)
- def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
- shape = list(input.size())
- shape[0] *= group_size
- return input.new_empty(shape)
- def _all_gather_into_tensor_native_meta(input, group_size, group_name):
- shape = list(input.size())
- shape[0] *= group_size
- return input.new_empty(shape)
- def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
- return [
- _all_gather_into_tensor_native_meta(input, group_size, group_name)
- for input in inputs
- ]
- def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
- shape = list(inp.size())
- shape[0] //= group_size
- return inp.new_empty(shape)
- def _reduce_scatter_tensor_coalesced_native_meta(
- inputs, reduce_op, group_size, group_name
- ):
- return [
- _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
- for inp in inputs
- ]
- if not torch._running_with_deploy():
- # Library MUST be defined at module scope or it doesn't work
- # Creating a "DEF" Library always crashes torch::deploy so we create our
- # Library instances here guarded against running inside it
- lib_impl = torch.library.Library("_c10d_functional", "IMPL")
- lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
- lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
- lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
- lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
- lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
- lib_impl.impl(
- "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
- )
- lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
- lib_impl.impl(
- "all_gather_into_tensor_coalesced",
- _all_gather_into_tensor_coalesced_native_meta,
- "Meta",
- )
- lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
- lib_impl.impl(
- "reduce_scatter_tensor_coalesced",
- _reduce_scatter_tensor_coalesced_native_meta,
- "Meta",
- )
- lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
- lib_impl.impl("broadcast", _broadcast_meta, "Meta")
- lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
- # Register legacy ops for backward compatibility
- # TODO(yifu): remove these in functional collective beta release
- legacy_lib = torch.library.Library("c10d_functional", "DEF")
- legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
- ops_defs = [
- "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
- "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
- "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
- "wait_tensor(Tensor self) -> Tensor",
- "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
- "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
- "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
- "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
- "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950
- ]
- my_module = sys.modules[__name__]
- for op_def in ops_defs:
- op_name = op_def[0 : op_def.index("(")]
- backend_impl = getattr(fun_col_impl, f"_{op_name}")
- legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
- legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
- else:
- warnings.warn(
- "PyTorch Distributed functional collectives do not work with torch::deploy."
- )
- """
- Dynamo remappings allow seamless translation from non-functional collectives of a supportable form into
- functional collective calls followed by in-place copy ops, allowing them to be traced into a functional graph.
- We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via
- the mapping dict below.
- These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
- """
- def all_gather_tensor_inplace(
- output_tensor: torch.Tensor,
- input_tensor: torch.Tensor,
- group, # TODO add a type,
- async_op: bool = False,
- tag: str = "",
- gather_dim: int = 0,
- ):
- assert (
- not async_op
- ), "Can't remap async version of inplace op to functional collective"
- group = group or dist.group.WORLD
- assert group is not None
- return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
- def reduce_scatter_tensor_inplace(
- output: torch.Tensor,
- input: torch.Tensor,
- op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok?
- group=None, # TODO add a type
- async_op: bool = False,
- scatter_dim: int = 0,
- tag: str = "",
- ):
- assert (
- not async_op
- ), "Can't remap async version of inplace op to functional collective"
- group = group or dist.group.WORLD
- assert group is not None
- return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
- REDUCE_OP_TO_STR = {
- dist.ReduceOp.SUM: "sum",
- dist.ReduceOp.AVG: "avg",
- dist.ReduceOp.PRODUCT: "product",
- dist.ReduceOp.MIN: "min",
- dist.ReduceOp.MAX: "max",
- dist.ReduceOp.BAND: "band",
- dist.ReduceOp.BOR: "bor",
- dist.ReduceOp.BXOR: "bxor",
- }
- def all_reduce_inplace(
- tensor: torch.Tensor,
- op: str = "sum",
- group=None,
- async_op: bool = False,
- tag: str = "",
- ):
- assert (
- not async_op
- ), "Can't remap async version of inplace op to functional collective"
- group = group or dist.group.WORLD
- assert group is not None
- return tensor.copy_(all_reduce(tensor, op, group, tag))
- def all_to_all_inplace(
- output: torch.Tensor,
- input: torch.Tensor,
- output_split_sizes=None,
- input_split_sizes=None,
- group=None,
- async_op=False,
- tag: str = "",
- ):
- assert (
- not async_op
- ), "Can't remap async version of inplace op to functional collective"
- group = group or dist.group.WORLD
- assert group is not None
- return output.copy_(
- all_to_all_single(
- input,
- output_split_sizes,
- input_split_sizes,
- group,
- tag,
- )
- )
- def all_gather_inplace(
- tensor_list: List[torch.Tensor],
- tensor: torch.Tensor,
- group=None,
- async_op=False,
- tag: str = "",
- ):
- assert (
- not async_op
- ), "Can't remap async version of inplace op to functional collective"
- assert all(
- t.size(0) == tensor.size(0) for t in tensor_list
- ), "Remapping variable size all_gather is not yet supported"
- group = group or dist.group.WORLD
- assert group is not None
- output = all_gather_tensor(tensor, 0, group, tag)
- # Use aten.slice instead of aten.split because the latter causes
- # tensor.size(0) to be unnecessarily baked in when it's a SymInt.
- output_splits = []
- offset = 0
- for t in tensor_list:
- output_splits.append(output[offset : offset + t.size(0)])
- offset += t.size(0)
- for dst, src in zip(tensor_list, output_splits):
- dst.copy_(src)
- return tensor_list
- from torch.distributed.distributed_c10d import (
- _all_gather_base as legacy_all_gather_base,
- _reduce_scatter_base as legacy_reduce_scatter_base,
- all_gather as legacy_all_gather,
- all_gather_into_tensor as legacy_allgather,
- all_reduce as legacy_allreduce,
- all_to_all_single as legacy_all_to_all_single,
- reduce_scatter_tensor as legacy_reducescatter,
- )
- # This dict contains the functions that dynamo is allowed to remap.
- # Functions in this dict should accept the same args/kwargs 1:1 as their mapping.
- traceable_collective_remaps = {
- legacy_allgather: all_gather_tensor_inplace,
- legacy_reducescatter: reduce_scatter_tensor_inplace,
- legacy_allreduce: all_reduce_inplace,
- legacy_all_to_all_single: all_to_all_inplace,
- legacy_all_gather: all_gather_inplace,
- legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
- legacy_all_gather_base: all_gather_tensor_inplace,
- }
|