device_mesh.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. import math
  5. import threading
  6. from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
  7. import torch
  8. from torch.distributed import is_available
  9. from ..utils._typing_utils import not_none
  10. __all__ = ["init_device_mesh", "DeviceMesh"]
  11. if not is_available():
  12. import sys
  13. # We need to create the stubs when distributed is not available.
  14. # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
  15. # since it would try to import ``torch.distributed.device_mesh`` or
  16. # ``torch.distributed.init_device_mesh`` but cannot find them.
  17. class _DeviceMeshStub:
  18. pass
  19. def _init_device_mesh_stub():
  20. pass
  21. sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]
  22. sys.modules[
  23. "torch.distributed.device_mesh"
  24. ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined]
  25. else:
  26. from torch.distributed.distributed_c10d import (
  27. _find_pg_by_ranks_and_tag,
  28. _get_default_group,
  29. _get_group_tag,
  30. get_process_group_ranks,
  31. get_rank,
  32. get_world_size,
  33. init_process_group,
  34. is_initialized,
  35. new_group,
  36. ProcessGroup,
  37. )
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# only import numpy typing when type checking
# ``ArrayLike`` appears in this file only inside string annotations (e.g. the
# ``mesh`` parameter of ``DeviceMesh``), so numpy is never imported at runtime:
# ``TYPE_CHECKING`` is False when the program actually runs.
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        # If the import fails under a type checker, log a warning instead of raising.
        logger.warning(
            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
        )
  47. class _MeshEnv(threading.local):
  48. def __init__(self) -> None:
  49. self.mesh_stack: List[DeviceMesh] = []
  50. self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
  51. self.mesh_dim_group_options: Dict[
  52. int, Tuple[str, Optional[ProcessGroup.Options]]
  53. ] = {}
  54. def get_current_mesh(self) -> "DeviceMesh":
  55. if len(self.mesh_stack) == 0:
  56. raise RuntimeError("No device mesh is currently active!")
  57. return self.mesh_stack[-1]
  58. def create_child_mesh(
  59. self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str, ...]
  60. ) -> "DeviceMesh":
  61. # submesh_dims are the mesh dimension of the submesh in the parent mesh.
  62. submesh_dims = [
  63. not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name)
  64. for mesh_dim_name in submesh_dim_names
  65. ]
  66. submesh_dim_sizes = [
  67. parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims
  68. ]
  69. mesh_dims_remained = list(range(parent_mesh.mesh.ndim))
  70. for submesh_dim in submesh_dims:
  71. mesh_dims_remained.remove(submesh_dim)
  72. # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims]
  73. # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
  74. # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank.
  75. pg_ranks_by_dim = parent_mesh.mesh.permute(
  76. *mesh_dims_remained, *submesh_dims
  77. ).reshape(-1, *submesh_dim_sizes)
  78. cur_rank = parent_mesh.get_rank()
  79. for mesh_nd in pg_ranks_by_dim:
  80. submesh = DeviceMesh(
  81. parent_mesh.device_type,
  82. mesh_nd,
  83. mesh_dim_names=submesh_dim_names,
  84. _init_backend=False,
  85. )
  86. if cur_rank in mesh_nd:
  87. res_submesh = submesh
  88. res_submesh._parent_mesh = parent_mesh # type: ignore[possibly-undefined]
  89. res_submesh._dim_group_infos = [
  90. parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims # type: ignore[possibly-undefined]
  91. ]
  92. self.child_to_parent_mapping[res_submesh] = parent_mesh
  93. return res_submesh
  94. def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
  95. return self.child_to_parent_mapping.get(device_mesh, None)
  96. def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
  97. """
  98. Return the index of the mesh dim in the parent mesh.
  99. The device_mesh passed in needs to be sliced out from a parent mesh.
  100. """
  101. parent_mesh = self.get_parent_mesh(device_mesh)
  102. child_mesh_dim_names = device_mesh.mesh_dim_names
  103. if parent_mesh and child_mesh_dim_names:
  104. assert (
  105. len(child_mesh_dim_names) == 1
  106. ), "The child mesh can only be a 1D mesh."
  107. child_mesh_dim_name = child_mesh_dim_names[0]
  108. return self.get_mesh_dim_by_name(parent_mesh, child_mesh_dim_name)
  109. return None
  110. @staticmethod
  111. def num_devices_per_host(device_type: str) -> int:
  112. return _get_device_handle(device_type).device_count()
  113. @staticmethod
  114. def num_hosts(device_type: str) -> int:
  115. # ProcessGroup can't tell us this info so we have to infer it, assume
  116. # homogeneous hardware for now
  117. return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
  118. def get_mesh_dim_by_name(
  119. self, device_mesh: "DeviceMesh", mesh_dim_name: str
  120. ) -> int:
  121. if (
  122. device_mesh.mesh_dim_names is None
  123. or len(device_mesh.mesh_dim_names) == 0
  124. ):
  125. raise KeyError(
  126. "No `mesh_dim_names` found.",
  127. )
  128. if mesh_dim_name not in device_mesh.mesh_dim_names:
  129. raise KeyError(
  130. f"Mesh dimension '{mesh_dim_name}' does not exist.",
  131. f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}",
  132. )
  133. return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
  134. def _set_mesh_dim_group_options(
  135. self,
  136. dim: int,
  137. backend: str,
  138. pg_options: Optional[ProcessGroup.Options] = None,
  139. ) -> None:
  140. self.mesh_dim_group_options[dim] = (backend, pg_options)
  141. _mesh_resources: _MeshEnv = _MeshEnv()
  142. def _get_device_handle(device_type: str = "cuda"):
  143. """
  144. Get the module corresponding to the device_type which is cuda or cuda-like device.
  145. For example, when the device_type is cuda, the module `torch.cuda` is returned.
  146. Return None when there is no corresponding module for device_type, otherwise
  147. return the corresponding module.
  148. """
  149. return getattr(torch, device_type, None)
  150. class DeviceMesh:
  151. """
  152. DeviceMesh represents a mesh of devices, where layout of devices could be
  153. represented as a n-d dimension array, and each value of the n-d dimensional
  154. array is the global id of the default process group ranks.
  155. DeviceMesh could be used to describe the layout of devices across the cluster,
  156. and serves as a proxy for communication among the device lists within the cluster.
  157. DeviceMesh can be used as a context manager.
  158. .. note::
  159. DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
  160. is running on all processes/ranks in the cluster. Therefore, users need to make sure the
  161. `mesh` array (which describes the layout of devices) should be identical across all ranks.
  162. Inconsistent `mesh` will lead to silent hang.
  163. Args:
  164. device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
  165. mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
  166. of devices, where the IDs are global IDs of the default process group.
  167. Returns:
  168. DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
  169. The following program runs on each process/rank in an SPMD manner. In this example, we have 2
  170. hosts with 4 GPUs each.
  171. A reduction over the first dimension of mesh will reduce across
  172. columns (0, 4), .. and (3, 7), a reduction over the second dimension
  173. of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
  174. Example::
  175. >>> # xdoctest: +SKIP("no rank")
  176. >>> from torch.distributed.device_mesh import DeviceMesh
  177. >>>
  178. >>> # Initialize device mesh as (2, 4) to represent the topology
  179. >>> # of cross-host(dim 0), and within-host (dim 1).
  180. >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
  181. """
  182. device_type: str
  183. mesh: torch.Tensor
  184. mesh_dim_names: Optional[Tuple[str, ...]]
  185. def __init__(
  186. self,
  187. device_type: str,
  188. mesh: Union[torch.Tensor, "ArrayLike"],
  189. *,
  190. mesh_dim_names: Optional[Tuple[str, ...]] = None,
  191. _init_backend: bool = True,
  192. ) -> None:
  193. self.device_type = device_type
  194. if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
  195. raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
  196. self.mesh = (
  197. mesh.detach().to(dtype=torch.int)
  198. if isinstance(mesh, torch.Tensor)
  199. else torch.tensor(mesh, device="cpu", dtype=torch.int)
  200. )
  201. self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
  202. # private field to pre-generate DeviceMesh's hash
  203. self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
  204. self._parent_mesh: Optional[DeviceMesh] = None
  205. self._thread_id = threading.get_ident()
  206. # Skip process group initialization if xla device or init backend is False
  207. # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
  208. if device_type != "xla":
  209. # always try to create default (world) pg, even if it is not initialized
  210. # already. The world pg is used for device mesh identity (rank) on each
  211. # process (we need to know if the current global rank is in the mesh or not).
  212. if _init_backend:
  213. self._get_or_create_default_group()
  214. self._init_process_groups()
  215. # calculate the coordinates of the current global rank on the mesh
  216. rank_coords = (self.mesh == get_rank()).nonzero()
  217. assert rank_coords.size(0) in (0, 1)
  218. self._coordinate_on_dim: Optional[List[int]] = (
  219. rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
  220. )
  221. def _get_or_create_default_group(self):
  222. default_initialized = is_initialized()
  223. if not default_initialized:
  224. init_process_group()
  225. world_size = get_world_size()
  226. if self.mesh.numel() > world_size:
  227. raise RuntimeError(
  228. f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
  229. )
  230. device_handle = _get_device_handle(self.device_type)
  231. # TODO: if user want to pass pg_options, offer a way to do it
  232. if not default_initialized and device_handle:
  233. # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
  234. # NOTE: This device selection would only work for homogeneous hardware.
  235. num_devices_per_host = device_handle.device_count()
  236. if (
  237. world_size > num_devices_per_host
  238. and world_size % num_devices_per_host != 0
  239. ):
  240. raise RuntimeError(
  241. f"DeviceMesh only support homogeneous hardware, but found "
  242. f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
  243. )
  244. device_handle.set_device(get_rank() % num_devices_per_host)
  245. return _get_default_group()
  246. def _init_process_groups(self):
  247. # tag/ranks/group_name associated with each mesh dimension, each
  248. # mesh dimension should have one sub-group per rank
  249. #
  250. # TODO(yifu): remove tag and ranks once we fully migrate to native
  251. # functional collectives. See details in:
  252. # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
  253. dim_group_infos: List[Tuple[str, List[int], str]] = []
  254. if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
  255. # if the mesh is the same as world_pg, we just append the default
  256. # pg to the first dim groups, as new_group cannot have the exact
  257. # same ranks as world
  258. dim_group_infos.append(
  259. (
  260. _get_group_tag(_get_default_group()),
  261. list(range(get_world_size())),
  262. _get_default_group().group_name,
  263. )
  264. )
  265. else:
  266. # create sub pgs base on the mesh argument specified
  267. for dim in range(self.mesh.ndim):
  268. # swap the current dim to the last dim
  269. # then reshape to flatten out other dims
  270. pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
  271. -1, self.mesh.size(dim)
  272. )
  273. # multi-dim mesh, create subgroups by looping over the pg_ranks
  274. # for each dim and append the groups
  275. for dim_mesh in pg_ranks_by_dim:
  276. subgroup_ranks = dim_mesh.tolist()
  277. # Respect dim group options specified via _MeshEnv.set_dim_group_options().
  278. # Inherit from the parent group if no options are specified for the group.
  279. if dim in _mesh_resources.mesh_dim_group_options:
  280. (
  281. backend,
  282. pg_options,
  283. ) = _mesh_resources.mesh_dim_group_options[dim]
  284. else:
  285. backend, pg_options = None, None
  286. # We temporarily revert the re-use subgroup, since it breaks two internal tests.
  287. # Temporarily reverting to resolve test timeout while root-causing.
  288. # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
  289. dim_group = new_group(
  290. ranks=subgroup_ranks,
  291. backend=backend,
  292. pg_options=pg_options,
  293. )
  294. # only add to dim_groups if the current rank in the subgroup
  295. if self.get_rank() in subgroup_ranks:
  296. if len(dim_group_infos) > dim:
  297. raise RuntimeError(
  298. f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
  299. f"in {subgroup_ranks}!"
  300. )
  301. dim_group_infos.append(
  302. (
  303. _get_group_tag(not_none(dim_group)),
  304. subgroup_ranks,
  305. dim_group.group_name,
  306. )
  307. )
  308. self._dim_group_infos = dim_group_infos
  309. def __enter__(self) -> "DeviceMesh":
  310. # set this mesh as the current mesh in mesh env
  311. _mesh_resources.mesh_stack.append(self)
  312. return self
  313. # pyre-fixme[2]: Parameter must be annotated.
  314. def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
  315. # pop this mesh from mesh env
  316. _mesh_resources.mesh_stack.pop()
  317. def __repr__(self) -> str:
  318. device_mesh_repr = (
  319. f"DeviceMesh({self.mesh.tolist()})"
  320. if not self.mesh_dim_names
  321. else f"DeviceMesh({self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})"
  322. )
  323. return device_mesh_repr
  324. def __hash__(self):
  325. # lazily compute hash
  326. self._hash = getattr(self, "_hash", None)
  327. if not self._hash:
  328. self._hash = hash(
  329. (
  330. self._flatten_mesh_list,
  331. self.mesh.shape,
  332. self.device_type,
  333. self.mesh_dim_names,
  334. self._parent_mesh,
  335. self._thread_id,
  336. )
  337. )
  338. return self._hash
  339. def __eq__(self, other: object) -> bool:
  340. if not isinstance(other, DeviceMesh):
  341. return False
  342. if id(self) == id(other):
  343. return True
  344. else:
  345. return (
  346. self._flatten_mesh_list == other._flatten_mesh_list
  347. and self.mesh.shape == other.mesh.shape
  348. and self.device_type == other.device_type
  349. and self.mesh_dim_names == other.mesh_dim_names
  350. and self._parent_mesh == other._parent_mesh
  351. and self._thread_id == other._thread_id
  352. )
  353. def __getitem__(
  354. self, mesh_dim_names: Union[str, Tuple[str, ...]]
  355. ) -> "DeviceMesh":
  356. """
  357. Slice the current DeviceMesh based on the mesh_dim_name given to create a child
  358. DeviceMesh.
  359. Args:
  360. mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the
  361. mesh dimension of the parent DeviceMesh to create the child DeviceMesh for.
  362. Returns:
  363. A :class:`DeviceMesh` object
  364. The following program runs on each process/rank in an SPMD manner. In this example, we have 2
  365. hosts with 4 GPUs each.
  366. Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
  367. Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
  368. Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
  369. Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
  370. Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
  371. Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
  372. Example::
  373. >>> # xdoctest: +SKIP("no rank")
  374. >>> from torch.distributed.device_mesh import DeviceMesh
  375. >>>
  376. >>> # Initialize device mesh as (2, 4) to represent the topology
  377. >>> # of cross-host(dim 0), and within-host (dim 1).
  378. >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
  379. """
  380. if not self.mesh_dim_names:
  381. raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")
  382. mesh_dim_names = (
  383. (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
  384. )
  385. error_msg = (
  386. f"Invalid mesh_dim_name {mesh_dim_names} specified. "
  387. f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}."
  388. )
  389. if mesh_dim_names == self.mesh_dim_names:
  390. return self
  391. elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all(
  392. mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names
  393. ):
  394. raise KeyError(error_msg)
  395. # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
  396. # of the current DeviceMesh.
  397. else:
  398. outermost_dim_name = mesh_dim_names[0]
  399. outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
  400. for i, j in zip(
  401. mesh_dim_names,
  402. self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
  403. ):
  404. if i != j:
  405. raise KeyError(error_msg)
  406. submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
  407. return submesh
  408. def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
  409. """
  410. Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
  411. DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.
  412. Args:
  413. mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
  414. of the mesh dimension. Default is None.
  415. Returns:
  416. A :class:`ProcessGroup` object.
  417. """
  418. if not hasattr(self, "_dim_group_infos"):
  419. raise RuntimeError("DeviceMesh process groups not initialized!")
  420. if self.mesh.ndim > 1 and mesh_dim is None:
  421. raise RuntimeError(
  422. f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
  423. "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
  424. "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
  425. "please use `get_all_groups()` instead.",
  426. )
  427. if self.mesh.ndim == 1 and mesh_dim is None:
  428. mesh_dim = 0
  429. else:
  430. mesh_dim = (
  431. _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
  432. if isinstance(mesh_dim, str)
  433. else mesh_dim
  434. )
  435. return not_none(
  436. _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index]
  437. )
  438. def get_all_groups(self) -> List[ProcessGroup]:
  439. """
  440. Returns a list of ProcessGroups for all mesh dimensions.
  441. Returns:
  442. A list of :class:`ProcessGroup` object.
  443. """
  444. return [self.get_group(i) for i in range(self.mesh.ndim)]
  445. @staticmethod
  446. def from_group(
  447. group: Union[ProcessGroup, List[ProcessGroup]],
  448. device_type: str,
  449. mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
  450. *,
  451. mesh_dim_names: Optional[Tuple[str, ...]] = None,
  452. ) -> "DeviceMesh":
  453. """
  454. Contstructs a :class:`DeviceMesh` with ``device_type`` from an
  455. existing :class:`ProcessGroup`.
  456. The constructed device mesh has number of dimensions equal to the
  457. number of groups passed. If more than one group is passed, then the
  458. ``mesh`` argument is required.
  459. """
  460. if isinstance(group, ProcessGroup):
  461. group_ranks = get_process_group_ranks(group)
  462. if (
  463. isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks
  464. ) or (mesh is not None and mesh != group_ranks):
  465. raise ValueError(
  466. f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}"
  467. )
  468. mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int)
  469. device_mesh = DeviceMesh(
  470. device_type,
  471. mesh,
  472. mesh_dim_names=mesh_dim_names,
  473. _init_backend=False,
  474. )
  475. device_mesh._dim_group_infos = [
  476. (_get_group_tag(group), group_ranks, group.group_name)
  477. ]
  478. return device_mesh
  479. groups = list(group)
  480. if len(groups) == 0:
  481. raise ValueError("Expects at least one ProcessGroup to be passed")
  482. if mesh is None:
  483. raise ValueError("Must pass mesh if passing multiple ProcessGroups")
  484. mesh = (
  485. mesh.detach().to(dtype=torch.int, device="cpu")
  486. if isinstance(mesh, torch.Tensor)
  487. else torch.tensor(mesh, device="cpu", dtype=torch.int)
  488. )
  489. if mesh.ndim != len(groups):
  490. raise ValueError(
  491. "Expects mesh with ndim equal to number of ProcessGroups but got "
  492. f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups"
  493. )
  494. device_mesh = DeviceMesh(
  495. device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
  496. )
  497. device_mesh._dim_group_infos = [
  498. (
  499. _get_group_tag(group),
  500. get_process_group_ranks(group),
  501. group.group_name,
  502. )
  503. for group in groups
  504. ]
  505. return device_mesh
  506. def size(self, mesh_dim: Optional[int] = None) -> int:
  507. return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
  508. @property
  509. def ndim(self) -> int:
  510. return self.mesh.ndim
  511. @property
  512. def shape(self) -> Tuple[int, ...]:
  513. return tuple(self.mesh.shape)
  514. def get_rank(self) -> int:
  515. """
  516. Returns the current global rank.
  517. """
  518. return get_rank()
  519. def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
  520. """
  521. Returns the local rank of the given mesh_dim of the DeviceMesh.
  522. Args:
  523. mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
  524. of the mesh dimension. Default is None.
  525. Returns:
  526. An integer denotes the local rank.
  527. The following program runs on each process/rank in an SPMD manner. In this example, we have 2
  528. hosts with 4 GPUs each.
  529. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
  530. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
  531. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
  532. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
  533. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
  534. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
  535. Example::
  536. >>> # xdoctest: +SKIP("no rank")
  537. >>> from torch.distributed.device_mesh import DeviceMesh
  538. >>>
  539. >>> # Initialize device mesh as (2, 4) to represent the topology
  540. >>> # of cross-host(dim 0), and within-host (dim 1).
  541. >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
  542. """
  543. if self.ndim > 1 and mesh_dim is None:
  544. raise RuntimeError(
  545. f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
  546. "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
  547. )
  548. elif mesh_dim is None:
  549. mesh_dim = 0
  550. mesh_dim_group = not_none(self.get_group(mesh_dim))
  551. assert isinstance(
  552. mesh_dim_group, ProcessGroup
  553. ), "We expect ProcessGroup before calling `get_rank`!"
  554. return not_none(get_rank(mesh_dim_group))
  555. def get_coordinate(self) -> Optional[List[int]]:
  556. """
  557. Return the relative indices of this rank relative to all
  558. dimensions of the mesh. If this rank is not part of the mesh, return None.
  559. """
  560. return self._coordinate_on_dim if self._coordinate_on_dim else None
  561. def init_device_mesh(
  562. device_type: str,
  563. mesh_shape: Tuple[int, ...],
  564. *,
  565. mesh_dim_names: Optional[Tuple[str, ...]] = None,
  566. ) -> DeviceMesh:
  567. """
  568. Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
  569. This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
  570. If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
  571. .. note::
  572. `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
  573. runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
  574. describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
  575. .. note::
  576. If no process group is found, init_device_mesh will initialize distributed process group/groups
  577. required for distributed communications behind the scene.
  578. Args:
  579. device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
  580. Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
  581. mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
  582. describing the layout of devices.
  583. mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
  584. of the multi-dimensional array describing the layout of devices. Its length must match the length
  585. of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
  586. Returns:
  587. DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
  588. Example::
  589. >>> # xdoctest: +SKIP("no rank")
  590. >>> from torch.distributed.device_mesh import init_device_mesh
  591. >>>
  592. >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
  593. >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
  594. """
  595. if mesh_dim_names is not None:
  596. if len(set(mesh_dim_names)) != len(mesh_dim_names):
  597. raise RuntimeError(
  598. "Each mesh_dim_name must be unique.",
  599. f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
  600. )
  601. if len(mesh_shape) != len(mesh_dim_names):
  602. raise RuntimeError(
  603. "mesh_shape and mesh_dim_names should have same length!",
  604. f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
  605. )
  606. # assume valid device types are all letters
  607. if device_type and not device_type.isalpha():
  608. raise RuntimeError(
  609. f"Device type with GPU index is not supported but got {device_type}. ",
  610. "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
  611. )
  612. # Always initialize the mesh's tensor on CPU, regardless of what the
  613. # external device type has been set to be (e.g. meta)
  614. with torch.device("cpu"):
  615. mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
  616. device_mesh = DeviceMesh(
  617. device_type=device_type,
  618. mesh=mesh,
  619. mesh_dim_names=mesh_dim_names,
  620. )
  621. return device_mesh