# mypy: allow-untyped-defs
from typing import List, Optional

import torch
import torch.distributed.distributed_c10d as c10d

"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""
def _broadcast(input, src, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.broadcast(
        input,
        src,
        group_name,
    )
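

# Combine `input` across all ranks with `reduce_op` (e.g. "sum"); every rank
# receives the reduced result.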
def _all_reduce(input, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce(
        input,
        reduce_op,
        group_name,
    )
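

# Coalesced variant of _all_reduce: all-reduce a list of tensors in a single
# batched collective.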
def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce_coalesced(
        inputs,
        reduce_op,
        group_name,
    )
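

# Gather `input` from every rank and concatenate the shards along dim 0.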
def _all_gather_into_tensor(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor(
        input,
        group_size,
        group_name,
    )
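

# Coalesced variant of _all_gather_into_tensor; `input` here is a list of
# tensors gathered in a single batched collective.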
def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
        input,
        group_size,
        group_name,
    )
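

# Reduce `input` across ranks with `reduce_op`, then scatter the result so each
# rank receives one shard (split along dim 0).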
def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduce_op,
        group_size,
        group_name,
    )
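

# Coalesced variant of _reduce_scatter_tensor for a list of input tensors.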
def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
        inputs,
        reduce_op,
        group_size,
        group_name,
    )
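

# Exchange data between all ranks according to the given split sizes. When both
# split-size lists are None, `input` is split evenly along dim 0 across the
# group (dim 0 must then be divisible by group_size).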
def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [input.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_to_all_single(
        input,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
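

# Block until the collective that produced `tensor` has completed, then return
# the (now safe to read) tensor.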
def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return torch.ops._c10d_functional.wait_tensor(tensor)
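

# A minimal usage sketch (illustrative only; assumes torch.distributed is
# initialized and that the (ranks, tag) pair resolves to an existing process
# group, e.g. the default group with tag ""):
#
#   world_size = torch.distributed.get_world_size()
#   out = _all_reduce(t, "sum", "", list(range(world_size)), world_size)
#   out = _wait_tensor(out)  # synchronize before reading the result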