_functional_collectives_impl.py

# mypy: allow-untyped-defs
from typing import List, Optional

import torch
import torch.distributed.distributed_c10d as c10d


"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""


def _broadcast(input, src, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.broadcast(
        input,
        src,
        group_name,
    )
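# A minimal usage sketch (hypothetical values; assumes an initialized
# 2-rank process group resolvable from these ranks and the empty tag):
#   out = _broadcast(t, src=0, tag="", ranks=[0, 1], group_size=2)
#   out = _wait_tensor(out)  # the native op is async; wait before use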


def _all_reduce(input, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce(
        input,
        reduce_op,
        group_name,
    )
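# reduce_op is a string naming the reduction, e.g. "sum" or "avg". A sketch
# (hypothetical values, same process-group assumption as above):
#   out = _wait_tensor(_all_reduce(t, "sum", tag="", ranks=[0, 1], group_size=2))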


def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce_coalesced(
        inputs,
        reduce_op,
        group_name,
    )
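# The coalesced variant takes a list of tensors and performs one fused
# collective, returning a list with one (async) result per input tensor.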


def _all_gather_into_tensor(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor(
        input,
        group_size,
        group_name,
    )
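# group_size is forwarded so the native op can size its output: the gathered
# result concatenates every rank's contribution along dim 0, so a [4, 8]
# input with group_size=2 yields an [8, 8] output.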


def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
        input,
        group_size,
        group_name,
    )
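# Despite the singular parameter name, `input` here is a list of tensors;
# the fused native op returns one gathered (async) tensor per input.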


def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduce_op,
        group_size,
        group_name,
    )
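# Reduce-scatter is the shape-wise inverse of all-gather: dim 0 of `input`
# must be divisible by group_size, and each rank receives one reduced shard
# of size input.shape[0] // group_size along dim 0.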


def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
        inputs,
        reduce_op,
        group_size,
        group_name,
    )
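# Combines the two fused/sharded behaviors above: a single call
# reduce-scatters each tensor in `inputs`, yielding one shard per input.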


def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [input.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_to_all_single(
        input,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
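# When no split sizes are given, the exchange is even: with group_size=4 and
# input.shape[0] == 8, both split lists default to [2, 2, 2, 2], i.e. each
# rank sends and receives a 2-row slice to/from every peer.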


def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return torch.ops._c10d_functional.wait_tensor(tensor)
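# wait_tensor waits for the pending collective that produced `tensor` to
# complete and returns the tensor itself; callers should wait on every
# result from the ops above before reading its values.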