_functional_collectives.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147
  1. # mypy: allow-untyped-defs
  2. import sys
  3. import warnings
  4. from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
  5. import torch
  6. import torch.distributed as dist
  7. import torch.distributed.distributed_c10d as c10d
  8. from torch.distributed.device_mesh import DeviceMesh
  9. from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
  10. from . import _functional_collectives_impl as fun_col_impl
  11. try:
  12. from torch.utils._cxx_pytree import tree_map_only
  13. except ImportError:
  14. from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
  15. if torch._running_with_deploy():
  16. def is_torchdynamo_compiling():
  17. """Can't import torchdynamo in torchdeploy builds currently."""
  18. return False
  19. else:
  20. try:
  21. from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
  22. except Exception:
  23. warnings.warn(
  24. "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
  25. )
  26. def is_torchdynamo_compiling():
  27. return False
  28. """
  29. New traceable, functional collectives.
  30. RFC: https://github.com/pytorch/pytorch/issues/93173
  31. compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
  32. eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
  33. automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
  34. a downstream op.
  35. Issues:
  36. * Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
  37. * Proper support for eager requires inplace ops. We should explore having it as an option for the API.
  38. """
  39. """
  40. Functional collectives are asynchronous only and we perform implicit stream synchronization
  41. on behalf of the user.
  42. We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
  43. first usage of the tensor and insert cross stream sync at the right place.
  44. The above are the easy bits, the hard one is how we match the Work object returned by
  45. c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
  46. op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
  47. dispatcher which might call other implementations that are allowed to change the returned
  48. tensor - even return a tensor with a different shape (see ``torch.vmap``).
  49. This means the caller of our ops receives a Tensor that is not guaranteed to be the same
allocated by our implementations, and that makes pairing the AsyncTensor to the original
  51. tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
  52. Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
  53. identity is not stable across dispatch, the op caller would end up with a different Tensor
  54. instance that would not match any in the dictionary.
With Tensor identity out of the question, we decided to use the tensor data pointer, which
  56. should be stable across all the Tensor changes done during dispatch.
  57. We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
  58. We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
  59. Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
  60. can clean up stale entries in the dictionary.
  61. To eliminate the possibility of races we have a global version counter that is used by the finalizer.
  62. As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
  63. """
  64. """
  65. Functional collectives can accept any of these types to describe the ranks participating in collectives.
  66. The different types will be desugared to a canonical format
  67. """
  68. RANK_TYPES = Union[
  69. List[int],
  70. List[List[int]],
  71. dist.ProcessGroup,
  72. DeviceMesh,
  73. Tuple["dist._tensor.DeviceMesh", int],
  74. str,
  75. ]
  76. """
  77. User facing APIs for functional collectives
  78. -------------------------------------------
  79. These apis are called by user code and expected to work both in eager execution and compilation,
  80. but there are significant differences to how the two modes are implemented underneath.
  81. Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
  82. just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
  83. and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
  84. if sufficient subclass support is added in dynamo.
  85. Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
  86. Here's how it works under torch.compile/dynamo:
all_reduce(...)
|--> _resolve_group_name(...) - resolves the group argument into a traceable group-name string
|--> torch.ops._c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper
|--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed
And under eager execution:
all_reduce(...)
|--> _resolve_group_name(...) - same as above, but less critical for eager
|--> torch.ops._c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace
|--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor,
which issues wait_tensor() at the time of first use
  97. """
  98. def wait_tensor(tensor):
  99. """
  100. Wait on a tensor returned by the collectives ops.
  101. Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
  102. """
  103. return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
  104. def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
  105. """
  106. Broadcasts the tensor to all processes in the given process group.
  107. Args:
  108. src (int): Source rank
  109. group (ProcessGroup or List[int]): The process group to work on.
  110. tag (str, optional): A unique identifier for the collective. Default: empty string
  111. """
  112. group_name = _resolve_group_name(group, tag)
  113. tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
  114. return _maybe_wrap_tensor(tensor)
  115. def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
  116. """
  117. Reduces the tensor data across all machines in such a way that all get
  118. the final result.
  119. The input tensor is left unmodified.
  120. Group can be one of:
  121. List[int]: ranks participating in the collective.
  122. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  123. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  124. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  125. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  126. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  127. that information and perform collective algebraic optimization. Use other forms of input for that.
  128. """
  129. group_name = _resolve_group_name(group, tag)
  130. tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
  131. return _maybe_wrap_tensor(tensor)
  132. def all_gather_tensor(
  133. self: torch.Tensor,
  134. gather_dim: int,
  135. group: RANK_TYPES,
  136. tag: str = "",
  137. ):
  138. """
  139. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  140. Note that it currently only supports gather_dim = 0.
  141. The input tensor is left unmodified.
  142. Group can be one of:
  143. List[int]: ranks participating in the collective.
  144. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  145. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  146. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  147. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  148. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  149. that information and perform collective algebraic optimization. Use other forms of input for that.
  150. """
  151. assert self.is_contiguous()
  152. group_name = _resolve_group_name(group, tag)
  153. group_size = c10d._get_group_size_by_name(group_name)
  154. tensor = torch.ops._c10d_functional.all_gather_into_tensor(
  155. self, group_size, group_name
  156. )
  157. res = _maybe_wrap_tensor(tensor)
  158. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  159. if gather_dim != 0:
  160. # torch.cat access the data so we already need to wait here, first do wait
  161. # and then chunk + cat avoid us going through ACT dispatching logic again
  162. if isinstance(res, AsyncCollectiveTensor):
  163. res = res.wait() # type: ignore[attr-defined]
  164. res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  165. return res
  166. def all_gather_tensor_autograd(
  167. self: torch.Tensor,
  168. gather_dim: int,
  169. group: RANK_TYPES,
  170. tag: str = "",
  171. ):
  172. """
  173. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  174. Note that it currently only supports gather_dim = 0.
  175. This function is the same as all_gather_tensor but will propagate the
  176. backwards gradient across workers.
  177. See all_gather_tensor for more details on usage.
  178. """
  179. group_name = _resolve_group_name(group, tag)
  180. group_size = c10d._get_group_size_by_name(group_name)
  181. tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
  182. self, group_size, group_name
  183. )
  184. res = _FromTorchTensor.apply(tensor)
  185. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  186. if gather_dim != 0:
  187. # torch.cat access the data so we already need to wait here, first do wait
  188. # and then chunk + cat avoid us going through ACT dispatching logic again
  189. if isinstance(res, AsyncCollectiveTensor):
  190. res = res.wait() # type: ignore[attr-defined]
  191. res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  192. return res
  193. def reduce_scatter_tensor(
  194. self: torch.Tensor,
  195. reduceOp: str,
  196. scatter_dim: int,
  197. group: RANK_TYPES,
  198. tag: str = "",
  199. ):
  200. """
  201. Reduces the tensor data across all machines in such a way that all get
  202. the final result, then scatter the results to corresponding ranks.
  203. The input tensor is left unmodified.
  204. Group can be one of:
  205. List[int]: ranks participating in the collective.
  206. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  207. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  208. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  209. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  210. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  211. that information and perform collective algebraic optimization. Use other forms of input for that.
  212. """
  213. group_name = _resolve_group_name(group, tag)
  214. group_size = c10d._get_group_size_by_name(group_name)
  215. assert (
  216. self.size(scatter_dim) % group_size == 0
  217. ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
  218. if scatter_dim != 0:
  219. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  220. self = torch.cat(tensor_list)
  221. tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
  222. self,
  223. reduceOp.lower(),
  224. group_size,
  225. group_name, # type: ignore[possibly-undefined]
  226. )
  227. res = _maybe_wrap_tensor(tensor)
  228. return res
  229. def reduce_scatter_tensor_autograd(
  230. self: torch.Tensor,
  231. reduceOp: str,
  232. scatter_dim: int,
  233. group: RANK_TYPES,
  234. tag: str = "",
  235. ):
  236. """
  237. Reduces the tensor data across all machines in such a way that all get
  238. the final result, then scatter the results to corresponding ranks.
  239. This function is the same as reduce_scatter_tensor but will propagate the
  240. backwards gradient across workers.
  241. Currently only the "sum" reduceOp is supported.
  242. See reduce_scatter_tensor for more details on usage.
  243. """
  244. group_name = _resolve_group_name(group, tag)
  245. group_size = c10d._get_group_size_by_name(group_name)
  246. assert (
  247. self.size(scatter_dim) % group_size == 0
  248. ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
  249. if scatter_dim != 0:
  250. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  251. self = torch.cat(tensor_list)
  252. tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
  253. self,
  254. reduceOp.lower(),
  255. group_size,
  256. group_name, # type: ignore[possibly-undefined]
  257. )
  258. res = _FromTorchTensor.apply(tensor)
  259. return res
  260. def all_reduce_coalesced(
  261. self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
  262. ) -> List[torch.Tensor]:
  263. """
  264. Reduces a list of tensors across all machines in such a way that all get
  265. the final result.
  266. The all tensors in the input list are left unmodified.
  267. Group can be one of:
  268. List[int]: ranks participating in the collective.
  269. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  270. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  271. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  272. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  273. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  274. that information and perform collective algebraic optimization. Use other forms of input for that.
  275. """
  276. group_name = _resolve_group_name(group, tag)
  277. tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined]
  278. self,
  279. reduceOp.lower(),
  280. group_name,
  281. )
  282. return list(map(_maybe_wrap_tensor, tensor_list))
  283. def all_gather_into_tensor_coalesced(
  284. self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
  285. ) -> List[torch.Tensor]:
  286. """
  287. Gather a list of tensors across from all machines.
  288. Note that it currently only supports gather_dim = 0.
  289. The input tensor is left unmodified.
  290. Group can be one of:
  291. List[int]: ranks participating in the collective.
  292. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  293. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  294. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  295. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  296. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  297. that information and perform collective algebraic optimization. Use other forms of input for that.
  298. """
  299. group_name = _resolve_group_name(group, tag)
  300. group_size = c10d._get_group_size_by_name(group_name)
  301. tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined]
  302. self,
  303. group_size,
  304. group_name,
  305. )
  306. return list(map(_maybe_wrap_tensor, tensor_list))
  307. def reduce_scatter_tensor_coalesced(
  308. inputs: List[torch.Tensor],
  309. reduceOp: str,
  310. scatter_dim: List[int],
  311. group: RANK_TYPES,
  312. tag: str = "",
  313. ) -> List[torch.Tensor]:
  314. """
  315. Reduces a list of tensors across all machines in such a way that all get
  316. the final result, then scatter the results to corresponding ranks.
  317. The input tensors are left unmodified.
  318. Group can be one of:
  319. List[int]: ranks participating in the collective.
  320. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  321. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  322. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  323. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  324. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  325. that information and perform collective algebraic optimization. Use other forms of input for that.
  326. """
  327. group_name = _resolve_group_name(group, tag)
  328. group_size = c10d._get_group_size_by_name(group_name)
  329. assert len(scatter_dim) == len(inputs)
  330. for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
  331. assert (
  332. tensor.size(dim) % group_size == 0
  333. ), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
  334. if dim != 0:
  335. tensor_list = torch.chunk(tensor, group_size, dim=dim)
  336. inputs[idx] = torch.cat(tensor_list)
  337. tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined]
  338. inputs,
  339. reduceOp.lower(),
  340. group_size,
  341. group_name, # type: ignore[possibly-undefined]
  342. )
  343. return list(map(_maybe_wrap_tensor, tensor_list))
  344. # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
  345. # Today, this maps 1:1 with "aten ops that are views".
  346. def _is_view_op(tgt):
  347. assert isinstance(tgt, torch._ops.OpOverload)
  348. schema = tgt._schema
  349. if len(schema.arguments) > 0:
  350. first_arg = schema.arguments[0]
  351. # check if op is a view
  352. return first_arg.alias_info is not None and not first_arg.alias_info.is_write
  353. def all_to_all_single(
  354. self: torch.Tensor,
  355. output_split_sizes: Optional[List[int]],
  356. input_split_sizes: Optional[List[int]],
  357. group: RANK_TYPES,
  358. tag: str = "",
  359. ) -> torch.Tensor:
  360. """
  361. Each process splits input tensor and then scatters the split list
  362. to all processes in a group. Then concatenate the received tensors from all
  363. the processes in the group and return single output tensor.
  364. Group can be one of:
  365. List[int]: ranks participating in the collective.
  366. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  367. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  368. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  369. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  370. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  371. that information and perform collective algebraic optimization. Use other forms of input for that.
  372. """
  373. if output_split_sizes is not None:
  374. assert all(
  375. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  376. ), output_split_sizes
  377. if input_split_sizes is not None:
  378. assert all(
  379. isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
  380. ), input_split_sizes
  381. group_name = _resolve_group_name(group, tag)
  382. group_size = c10d._get_group_size_by_name(group_name)
  383. if output_split_sizes is None or input_split_sizes is None:
  384. assert output_split_sizes is None and input_split_sizes is None, (
  385. "output_split_sizes and input_split_sizes must either be "
  386. "specified together or both set to None"
  387. )
  388. output_split_sizes = [self.shape[0] // group_size] * group_size
  389. input_split_sizes = output_split_sizes
  390. tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
  391. self,
  392. output_split_sizes,
  393. input_split_sizes,
  394. group_name,
  395. )
  396. return _maybe_wrap_tensor(tensor)
  397. def all_to_all_single_autograd(
  398. self: torch.Tensor,
  399. output_split_sizes: Optional[List[int]],
  400. input_split_sizes: Optional[List[int]],
  401. group: RANK_TYPES,
  402. tag: str = "",
  403. ) -> torch.Tensor:
  404. """
  405. Same as all_to_all_single but supports autograd.
  406. """
  407. if output_split_sizes is not None:
  408. assert all(
  409. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  410. ), output_split_sizes
  411. if input_split_sizes is not None:
  412. assert all(
  413. isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
  414. ), input_split_sizes
  415. group_name = _resolve_group_name(group, tag)
  416. group_size = c10d._get_group_size_by_name(group_name)
  417. if output_split_sizes is None or input_split_sizes is None:
  418. assert output_split_sizes is None and input_split_sizes is None, (
  419. "output_split_sizes and input_split_sizes must either be "
  420. "specified together or both set to None"
  421. )
  422. output_split_sizes = [self.shape[0] // group_size] * group_size
  423. input_split_sizes = output_split_sizes
  424. tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
  425. self,
  426. output_split_sizes,
  427. input_split_sizes,
  428. group_name,
  429. )
  430. return _FromTorchTensor.apply(tensor)
  431. def permute_tensor(
  432. self: torch.Tensor,
  433. src_dst: List[int],
  434. group: RANK_TYPES,
  435. tag: str = "",
  436. ) -> torch.Tensor:
  437. """
  438. Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
  439. be defined such that src_dst[m] == n means m sends to n.
  440. Group can be one of:
  441. List[int]: ranks participating in the collective.
  442. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  443. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  444. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  445. (DeviceMesh, int): Do a MPMD collective over one
  446. """
  447. t, rankset, group_size = _expand_group(group, tag)
  448. local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
  449. output_split_sizes = [0] * group_size
  450. input_split_sizes = [0] * group_size
  451. for src, dst in enumerate(src_dst):
  452. if src == dist.get_rank(local_pg):
  453. input_split_sizes[dst] = self.numel()
  454. if dst == dist.get_rank(local_pg):
  455. output_split_sizes[src] = self.numel()
  456. return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.

    The wrapper has the same size/stride/dtype/device metadata as the wrapped
    tensor (``elem``). Any non-view op dispatched on it first waits on ``elem``
    and then returns plain (unwrapped) results. View ops skip the wait and
    re-wrap their outputs, so the wait stays deferred.

    Use it inside functional collective pytorch wrappers like the following:

        def functional_collective(self, group, tag):
            group_name = _resolve_group_name(group, tag)
            tensor = torch.ops._c10d_functional.{collective}(self, ..., group_name)
            return _maybe_wrap_tensor(tensor)
    """

    # The wrapped tensor returned by the collective op.
    elem: torch.Tensor
    # Set to True once trigger_wait() has issued wait_tensor on `elem`.
    completed: bool

    __slots__ = ["elem", "completed"]

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # Create a wrapper tensor that mirrors elem's metadata; the real data
        # lives in `elem`.
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        # Tensor-subclass flatten protocol: the only inner tensor is `elem`,
        # and there is no extra metadata.
        return ["elem"], None

    def tolist(self):
        # Waits (at most once) before converting to a Python list.
        return self.trigger_wait().tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Inverse of __tensor_flatten__; outer_size/outer_stride are not used.
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        # Note: printing the tensor triggers the wait as a side effect.
        return f"AsyncCollectiveTensor({self.trigger_wait()})"

    def trigger_wait(self):
        """Wait on `elem` the first time this is called; afterwards return `elem` directly."""
        if not self.completed:
            out = wait_tensor(self.elem)
            self.completed = True
            return out
        else:
            return self.elem

    def wait(self) -> torch.Tensor:
        """Call wait_tensor on `elem` unconditionally.

        Unlike trigger_wait(), this does not set `completed`.
        """
        return wait_tensor(self.elem)

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast handle aten.view as a lot of view related op goes to aten.view
            # eventually, this avoids pytree slowdown.
            # NOTE(review): assumes aten.view arrives positionally as
            # (self, size) with self an AsyncCollectiveTensor; kwargs are ignored.
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # Non-view ops need the data, so wait first. trigger_wait() only
            # issues wait_tensor once per wrapper. View ops take `elem`
            # without waiting.
            if not is_view_op:
                return e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # Re-wrap a view op's output so its wait stays deferred.
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            return res

        # NOTE(review): if kwargs were ever None here, `**unwrapped_kwargs`
        # would raise a TypeError; presumably the dispatcher always passes a dict.
        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we re-wrap their outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        # NOTE(review): uses wait() rather than trigger_wait(), so `completed`
        # is not updated here (unlike tolist()).
        return self.wait().numpy()
  537. """
  538. Utils and infrastructure for tracing support
  539. """
  540. def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
  541. """
  542. _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.
  543. By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
  544. torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
  545. """
  546. # had to define this hack _inside_ expand_group to avoid
  547. # graph_break [('torch.* op returned non-Tensor int
  548. # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc)
  549. if TYPE_CHECKING:
  550. def cast_listlistint(x):
  551. return cast(List[List[int]], x)
  552. def cast_listint(x):
  553. return cast(List[int], x)
  554. else:
  555. # fake cast op for use at runtime since dynamo doesn't support real cast
  556. # also, dynamo didn't like encountering 'typing' objects ()
  557. # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
  558. def cast_listlistint(x):
  559. return x
  560. def cast_listint(x):
  561. return x
  562. rankset: List[int]
  563. if isinstance(group, list):
  564. if isinstance(group[0], list):
  565. nested_list = cast_listlistint(group)
  566. rankset = []
  567. group_size = -1
  568. for rs in nested_list:
  569. rankset.extend(rs)
  570. if group_size != -1 and group_size != len(rs):
  571. raise ValueError(
  572. f"group sizes must be identical found {group_size} and {len(rs)}"
  573. )
  574. group_size = len(rs)
  575. else:
  576. rankset = cast_listint(group)
  577. group_size = len(rankset)
  578. elif isinstance(group, dist.ProcessGroup):
  579. rankset = dist.get_process_group_ranks(group)
  580. group_size = len(rankset)
  581. tag = tag or c10d._get_group_tag(group)
  582. elif isinstance(group, DeviceMesh):
  583. assert (
  584. group.ndim == 1
  585. ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
  586. # TODO: it should run collective in the whole mesh instead of dim 0
  587. tag, rankset, _ = group._dim_group_infos[0]
  588. group_size = len(rankset)
  589. elif isinstance(group, tuple):
  590. if (
  591. len(group) == 2
  592. and isinstance(group[0], DeviceMesh)
  593. and isinstance(group[1], int)
  594. ):
  595. dmesh = group[0]
  596. dim = group[1]
  597. tag, rankset, _ = dmesh._dim_group_infos[dim]
  598. group_size = len(rankset)
  599. else:
  600. raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
  601. else:
  602. raise ValueError(
  603. "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
  604. )
  605. return (tag, rankset, group_size)
  606. def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
  607. """
  608. Given group in RANK_TYPES, return the group name.
  609. """
  610. # `tag` will be deprecated. See details in:
  611. # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
  612. if isinstance(group, dist.ProcessGroup):
  613. return group.group_name
  614. elif isinstance(group, str):
  615. return group
  616. elif isinstance(group, DeviceMesh):
  617. assert (
  618. group.ndim == 1
  619. ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
  620. return group._dim_group_infos[0][2]
  621. elif isinstance(group, tuple):
  622. if (
  623. len(group) == 2
  624. and isinstance(group[0], DeviceMesh)
  625. and isinstance(group[1], int)
  626. ):
  627. dmesh = group[0]
  628. dim = group[1]
  629. return dmesh._dim_group_infos[dim][2]
  630. else:
  631. raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
  632. elif isinstance(group, list):
  633. if not is_torchdynamo_compiling():
  634. warnings.warn(
  635. "The combination of ranks + tag as process group "
  636. "identifier has been deprecated. Please switch to "
  637. "using ProcessGroup, DeviceMesh, or group name instead.",
  638. FutureWarning,
  639. stacklevel=3,
  640. )
  641. return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
  642. else:
  643. raise ValueError(f"Unsupported group type: {type(group)}, {group}")
  644. class _FromTorchTensor(torch.autograd.Function):
  645. """
  646. _FromTorchTensor allows autograd to propagate from a normal Tensor to an
  647. AsyncCollectiveTensor.
  648. """
  649. @staticmethod
  650. def forward( # type: ignore[override]
  651. ctx, # pyre-ignore[2]: Parameter must be annotated.
  652. input: torch.Tensor,
  653. ) -> torch.Tensor:
  654. return _maybe_wrap_tensor(input)
  655. @staticmethod
  656. def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override]
  657. return grad_output
  658. def _are_we_tracing() -> bool:
  659. if is_torchdynamo_compiling():
  660. return True
  661. # If functionalization is turned on, we are almost definitely compiling/tracing.
  662. # (In particular, AOTAutograd traces a model once with functionalization on
  663. # but proxy tracing turned of, so this is how we detect it).
  664. if (
  665. torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
  666. is not None
  667. ):
  668. return True
  669. mode = get_innermost_proxy_mode()
  670. if mode is None:
  671. return False
  672. return mode.tracer is not None
  673. def _maybe_wrap_tensor(self) -> torch.Tensor:
  674. if _are_we_tracing():
  675. return wait_tensor(self)
  676. res = AsyncCollectiveTensor(self)
  677. return cast(torch.Tensor, res)
  678. def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
  679. def mk_out_tensor(shard):
  680. out_size = list(shard.size())
  681. out_size[0] *= group_size
  682. out_tensor = shard.new_empty(out_size)
  683. return out_tensor
  684. return [mk_out_tensor(t) for t in self]
  685. # We now register meta kernels to deal with tracing
  686. def _broadcast_meta(self, *args):
  687. return torch.empty_like(self)
  688. def _all_reduce_meta(self, *args):
  689. return torch.empty_like(self)
  690. def _wait_tensor_meta(self, *args):
  691. return torch.empty_like(self)
  692. def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
  693. out_size = list(shard.size())
  694. out_size[0] *= group_size
  695. return shard.new_empty(out_size)
  696. def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
  697. out_size = list(input.size())
  698. out_size[0] //= group_size
  699. return input.new_empty(out_size)
  700. def _all_reduce_coalesced_meta(self, *args):
  701. return [torch.empty_like(t) for t in self]
  702. def _all_reduce__meta(inp, *args):
  703. return inp
  704. def _broadcast__meta(inp, *args):
  705. return inp
  706. def _all_reduce_coalesced__meta(inputs, *args):
  707. return inputs
  708. def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
  709. def mk_out_tensor(input):
  710. out_size = list(input.size())
  711. out_size[0] //= group_size
  712. out_tensor = input.new_empty(out_size)
  713. return out_tensor
  714. return [mk_out_tensor(t) for t in inputs]
  715. # NB: We often say all_to_all has dynamic output size, but this is not
  716. # technically true: instead, what typically happens is you manually
  717. # communicate the output_split_sizes ahead of time (which is dynamic),
  718. # but then you pass those sizes explicitly, and the all to all itself
  719. # isn't dynamic, it just follows the specified output splits
  720. def _all_to_all_single_meta(
  721. input, output_split_sizes, input_split_sizes, *args, **kwargs
  722. ):
  723. if output_split_sizes is None:
  724. return input.new_empty(input.size())
  725. else:
  726. for s in output_split_sizes:
  727. torch._check_is_size(s)
  728. out_size = list(input.size())
  729. out_size[0] = sum(output_split_sizes)
  730. return input.new_empty(out_size)
  731. def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
  732. shape = list(input.size())
  733. shape[0] *= group_size
  734. return input.new_empty(shape)
  735. def _all_gather_into_tensor_native_meta(input, group_size, group_name):
  736. shape = list(input.size())
  737. shape[0] *= group_size
  738. return input.new_empty(shape)
  739. def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
  740. return [
  741. _all_gather_into_tensor_native_meta(input, group_size, group_name)
  742. for input in inputs
  743. ]
  744. def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
  745. shape = list(inp.size())
  746. shape[0] //= group_size
  747. return inp.new_empty(shape)
  748. def _reduce_scatter_tensor_coalesced_native_meta(
  749. inputs, reduce_op, group_size, group_name
  750. ):
  751. return [
  752. _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
  753. for inp in inputs
  754. ]
# Module-level op registration. Outside torch::deploy this attaches the meta
# kernels defined above and defines the legacy `c10d_functional` ops; under
# torch::deploy it only warns.
if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work
    # Creating a "DEF" Library always crashes torch::deploy so we create our
    # Library instances here guarded against running inside it
    # NOTE(review): presumably these Library objects must stay referenced as
    # module globals for the registrations to remain live; confirm before
    # renaming or removing them.
    lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    # Register the meta kernels above under the "Meta" dispatch key for ops in
    # the `_c10d_functional` namespace. This Library is opened in "IMPL" mode,
    # so the op schemas are defined elsewhere.
    lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
    )
    lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
    lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
    # Register legacy ops for backward compatibility
    # TODO(yifu): remove these in functional collective beta release
    # The legacy namespace is both defined ("DEF") and implemented ("IMPL") here.
    legacy_lib = torch.library.Library("c10d_functional", "DEF")
    legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    # Schemas of the legacy ops; each op name is the text before "(".
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]
    # NOTE(review): `my_module` is not used within this chunk; presumably kept
    # for other references. Verify before removing.
    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        # Each legacy op is backed by `fun_col_impl._<op_name>`.
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )
"""
Dynamo remappings allow seamless translation from non-functional collectives of a supportable form into
functional collective calls followed by in-place copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it with the corresponding op via
the mapping dict below.

These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are trying to remap from.
"""
  815. def all_gather_tensor_inplace(
  816. output_tensor: torch.Tensor,
  817. input_tensor: torch.Tensor,
  818. group, # TODO add a type,
  819. async_op: bool = False,
  820. tag: str = "",
  821. gather_dim: int = 0,
  822. ):
  823. assert (
  824. not async_op
  825. ), "Can't remap async version of inplace op to functional collective"
  826. group = group or dist.group.WORLD
  827. assert group is not None
  828. return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
  829. def reduce_scatter_tensor_inplace(
  830. output: torch.Tensor,
  831. input: torch.Tensor,
  832. op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok?
  833. group=None, # TODO add a type
  834. async_op: bool = False,
  835. scatter_dim: int = 0,
  836. tag: str = "",
  837. ):
  838. assert (
  839. not async_op
  840. ), "Can't remap async version of inplace op to functional collective"
  841. group = group or dist.group.WORLD
  842. assert group is not None
  843. return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
# Lookup from c10d `ReduceOp` values to lowercase reduction names.
# NOTE(review): presumably used to turn a `ReduceOp` into the string form the
# `op: str` parameters of the remaps in this module expect (e.g. "sum");
# confirm against callers.
REDUCE_OP_TO_STR = {
    dist.ReduceOp.SUM: "sum",
    dist.ReduceOp.AVG: "avg",
    dist.ReduceOp.PRODUCT: "product",
    dist.ReduceOp.MIN: "min",
    dist.ReduceOp.MAX: "max",
    dist.ReduceOp.BAND: "band",
    dist.ReduceOp.BOR: "bor",
    dist.ReduceOp.BXOR: "bxor",
}
  854. def all_reduce_inplace(
  855. tensor: torch.Tensor,
  856. op: str = "sum",
  857. group=None,
  858. async_op: bool = False,
  859. tag: str = "",
  860. ):
  861. assert (
  862. not async_op
  863. ), "Can't remap async version of inplace op to functional collective"
  864. group = group or dist.group.WORLD
  865. assert group is not None
  866. return tensor.copy_(all_reduce(tensor, op, group, tag))
  867. def all_to_all_inplace(
  868. output: torch.Tensor,
  869. input: torch.Tensor,
  870. output_split_sizes=None,
  871. input_split_sizes=None,
  872. group=None,
  873. async_op=False,
  874. tag: str = "",
  875. ):
  876. assert (
  877. not async_op
  878. ), "Can't remap async version of inplace op to functional collective"
  879. group = group or dist.group.WORLD
  880. assert group is not None
  881. return output.copy_(
  882. all_to_all_single(
  883. input,
  884. output_split_sizes,
  885. input_split_sizes,
  886. group,
  887. tag,
  888. )
  889. )
  890. def all_gather_inplace(
  891. tensor_list: List[torch.Tensor],
  892. tensor: torch.Tensor,
  893. group=None,
  894. async_op=False,
  895. tag: str = "",
  896. ):
  897. assert (
  898. not async_op
  899. ), "Can't remap async version of inplace op to functional collective"
  900. assert all(
  901. t.size(0) == tensor.size(0) for t in tensor_list
  902. ), "Remapping variable size all_gather is not yet supported"
  903. group = group or dist.group.WORLD
  904. assert group is not None
  905. output = all_gather_tensor(tensor, 0, group, tag)
  906. # Use aten.slice instead of aten.split because the latter causes
  907. # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
  908. output_splits = []
  909. offset = 0
  910. for t in tensor_list:
  911. output_splits.append(output[offset : offset + t.size(0)])
  912. offset += t.size(0)
  913. for dst, src in zip(tensor_list, output_splits):
  914. dst.copy_(src)
  915. return tensor_list
  916. from torch.distributed.distributed_c10d import (
  917. _all_gather_base as legacy_all_gather_base,
  918. _reduce_scatter_base as legacy_reduce_scatter_base,
  919. all_gather as legacy_all_gather,
  920. all_gather_into_tensor as legacy_allgather,
  921. all_reduce as legacy_allreduce,
  922. all_to_all_single as legacy_all_to_all_single,
  923. reduce_scatter_tensor as legacy_reducescatter,
  924. )
# Maps legacy (in-place, non-functional) c10d collectives to the functional
# remappings defined above; dynamo is allowed to replace a call to a key with
# a call to its value.
# Each value must accept the same args/kwargs 1:1 as the function it maps from.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}