# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed._tensor.placement_types import DTensorSpec, Shard
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh


_rng_tracker: Optional["_RNGStateTracker"] = None


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
    """Checks if the current device of `device_mesh` supports DTensor's random APIs.
    Currently DTensor random APIs only support cuda/cuda-like devices. We suggest
    users call this API to test availability before using our random APIs.

    Args:
        device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the
            random ops APIs are supported.

    Returns:
        A bool value. True if `device_mesh` supports DTensor Random APIs; False otherwise.

    .. warning::
        Currently we only support correct RNG on cuda/cuda-like devices.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if device_handle and hasattr(device_handle, "set_rng_state"):
        return True
    else:
        # TODO: Logs way too much
        warnings.warn(
            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh"
        )
        return False
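
# A minimal usage sketch (not part of the module's API surface): guard DTensor
# random seeding with is_rng_supported_mesh. The 1-D mesh shape below is an
# illustrative assumption, not a requirement.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
#   if is_rng_supported_mesh(mesh):
#       manual_seed(42, mesh)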


def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
    """Sets the seed for generating random numbers for the calling rank.

    Args:
        seed (int): The desired seed.
        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.

    Returns:
        None

    .. warning::
        When calling this function, :func:`manual_seed` must be called from all ranks of the
        default `ProcessGroup` even if some ranks may not be a part of the `device_mesh`,
        with the same `seed` value.
        If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
        `manual_seed` will not set its GPU device's generator seed.
        Current implementation only supports a GPU device mesh.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if not device_handle:
        raise NotImplementedError(
            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
        )

    # allgather the seed over the default PG and verify all ranks agree on it
    object_list = [seed] * dist.get_world_size()
    dist.all_gather_object(object_list, seed)
    for rank, obj in enumerate(object_list):
        if seed != int(obj):
            raise RuntimeError(
                f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {obj}!",
            )
    # instantiate an RNG tracker if one does not exist yet. By default DTensor
    # uses an OffsetBasedRNGTracker to perform random operators.
    global _rng_tracker
    if not _rng_tracker:
        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)

    # the current rank is in mesh
    if device_mesh.get_coordinate() is not None:
        if isinstance(_rng_tracker, TensorParallelRNGTracker):
            _rng_tracker._manual_seed(device_mesh, seed)
        elif isinstance(_rng_tracker, OffsetBasedRNGTracker):
            _rng_tracker._manual_seed(seed)
        else:
            raise RuntimeError(
                f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}"
            )
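
# A minimal sketch of the seeding contract described in the docstring above:
# every rank of the default ProcessGroup calls manual_seed with the same value,
# even ranks that are not part of the mesh. The 2x2 mesh shape is an assumption
# made only for illustration.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   mesh = init_device_mesh("cuda", (2, 2))
#   manual_seed(1234, mesh)  # same seed on every rank, or a RuntimeError is raised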


class _RNGStateTracker:
    """
    _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object)
    in a dict, mapping from a corresponding tag to each state tensor. It also provides
    a set of convenient utility methods to help access/modify the state tensors. The most
    important interface is _distribute_region which will be used when DTensor executes
    a random op (an operator that calls RNG).
    """

    def __init__(self, device_type: str = "cuda"):
        self._device_type = device_type
        self._device_handle = _get_device_handle(device_type)
        if not (self._device_handle and self._device_handle.is_available()):
            raise RuntimeError(
                f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
            )

        self._states: Dict[str, Tensor] = {}
        self._devices = [self._device_handle.current_device()]
        self._use_distribute_region = True

    @property
    def rng_states(self) -> Dict[str, Tensor]:
        return self._states

    @property
    def distribute_region_enabled(self) -> bool:
        return self._use_distribute_region

    @distribute_region_enabled.setter
    def distribute_region_enabled(self, value) -> None:
        self._use_distribute_region = value

    def rng_state_is_sync(self, name) -> bool:
        return name in self.rng_states

    def get_seed(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64)
        return int(seed_tensor.item())

    def set_seed(self, name: str, seed: int) -> None:
        seed_tensor = torch.tensor([seed]).view(torch.uint8)
        offset_tensor = torch.tensor([0]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
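
    # Layout note (a sketch of what set_seed/get_seed assume): each state tensor
    # is a 16-byte uint8 tensor, where bytes [0:8] are an int64 view of the seed
    # and bytes [8:16] hold the generator offset. For example, on a host with a
    # cuda-like device available:
    #
    #   tracker = _RNGStateTracker("cuda")
    #   tracker.set_seed("parallel-rng", 42)
    #   assert tracker.get_seed("parallel-rng") == 42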

    def _distribute_region(self, spec: DTensorSpec):
        pass


class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of `_RNGStateTracker` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators.
    """

    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # synchronize RNG state using rank 0's current one
        rng_state = self._device_handle.get_rng_state().to(device_type)
        dist.broadcast(rng_state, 0)
        self.rng_states["parallel-rng"] = rng_state.to("cpu")

    def _manual_seed(self, parallel_seed: int) -> None:
        self.set_seed("parallel-rng", parallel_seed)

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("parallel-rng"):
            raise RuntimeError(
                "OffsetBasedRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            old_offset = self.get_offset("parallel-rng")
            self._set_pre_op_offset(spec)
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(spec, old_offset)
        else:
            yield
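
    # A sketch of how a DTensor random op is expected to use this context
    # manager (the spec/local-tensor names below are illustrative, not the exact
    # call sites inside DTensor's op dispatch):
    #
    #   assert _rng_tracker is not None
    #   with _rng_tracker._distribute_region(dtensor_spec):
    #       local_tensor.uniform_()  # each rank draws its own slice of the stream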

    def get_offset(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64)
        return int(offset_tensor.item())

    def set_offset(self, name: str, offset: int) -> None:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        seed_tensor = (self.rng_states[name])[0:8]
        offset_tensor = torch.tensor([offset]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

    def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
        """Set the starting RNG offset for current device's local shard before actual
        op execution. The pre_op_offset value should start from the current RNG offset
        and increment by the size of local shard until it reaches the size of the whole
        DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset
        will be the same.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we prepare the offset for running random ops.

        Returns:
            None

        .. warning::
            Note that the current implementation does not consider DTensor's contiguity.

        Example:
            take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides rank coordinate is shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas:
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) each hold one.
            That being said, the local shards on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have rank coordinate and shard coordinate, we can calculate on each rank
            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
            of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
            (x, y, z) is simply (z, x) by taking (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 hold the shard of coord (0, 0);
            rank 1 and rank 3 hold the shard of coord (1, 0);
            rank 4 and rank 6 hold the shard of coord (0, 1);
            rank 5 and rank 7 hold the shard of coord (1, 1);

            The last value to calculate before obtaining the starting offset is the shard linear index.
            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
        """
        dtensor_shape = spec.shape
        mesh = spec.mesh
        dim_map = spec.dim_map

        # Compute shard coordinate:
        # The coordinate on each tensor dim is a tuple (idx, range)
        # If a DTensor is partitioned on its dim i into n shards, and the current rank
        # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
        coordinate = mesh.get_coordinate()
        assert coordinate is not None
        shard_coord = [
            coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map
        ]
        shard_size = [
            mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map
        ]

        # compute shard linear index
        shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size)

        # compute starting offset using the first shard's size
        local_size_on_rank_0 = list(dtensor_shape)
        for idx, placement in enumerate(spec.placements):
            if isinstance(placement, Shard):
                mesh_dim_size = mesh.size(idx)
                shard_dim = placement.dim
                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
                    dtensor_shape[shard_dim],
                    mesh_dim_size,
                    0,
                    return_offset=False,
                )[0]

        from torch.distributed._tensor.ops.utils import prod

        local_size = prod(local_size_on_rank_0)

        # get current RNG offset
        current_offset = self.get_offset("parallel-rng")

        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
        self.set_offset("parallel-rng", current_offset + offset_incr)
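
    # A worked instance of the docstring example above (illustrative numbers only):
    # for the [8, 16] DTensor on the 2x2x2 mesh with dim_map [2, 0], rank 5 has
    # mesh coordinate (1, 0, 1), so shard_coord = [1, 1], shard_size = [2, 2],
    # shard_linear_idx = 1 * 2 + 1 = 3, and the local shard has 4 * 8 = 32 elements.
    # The pre-op offset therefore becomes
    # current_offset + (3 * 32 + 3) // 4 * 4 = current_offset + 96.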

    def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
        """Sets the RNG to a synchronized state after running the local random op. Every
        rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is
        the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor
        random ops.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we post-process the offset for running random ops.

        Returns:
            None
        """
        dtensor_shape = spec.shape

        from torch.distributed._tensor.ops.utils import prod

        numel = prod(dtensor_shape)
        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        numel = (numel + 3) // 4 * 4
        self.set_offset("parallel-rng", old_offset + numel)
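
    # Continuing the same illustrative example: the [8, 16] DTensor has
    # numel = 128, already a multiple of 4, so every rank leaves the region with
    # offset old_offset + 128, keeping the per-rank generators in sync.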

    def _calc_shard_linear_idx(
        self, shard_coord: List[int], shard_size: List[int]
    ) -> int:
        # compute shard linear index
        shard_linear_idx = 0
        shard_coord_stride = 1
        for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
            shard_linear_idx += idx * shard_coord_stride
            shard_coord_stride *= size

        return shard_linear_idx
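
    # The linear index is row-major over the shard grid. A hypothetical check
    # against the 2x2 grid from the docstring example:
    #
    #   tracker._calc_shard_linear_idx([0, 1], [2, 2])  # -> 1
    #   tracker._calc_shard_linear_idx([1, 0], [2, 2])  # -> 2
    #   tracker._calc_shard_linear_idx([1, 1], [2, 2])  # -> 3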


class TensorParallelRNGTracker(_RNGStateTracker):
    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # copy the default RNG state
        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()

    def _manual_seed(
        self,
        tp_mesh: DeviceMesh,
        base_seed: int = 1234,
    ):
        tensor_parallel_rank = tp_mesh.get_local_rank()
        # this magic number 2718 comes from Megatron's code
        # (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163)
        MegatronMagicNum = 2718
        tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank
        self.set_seed("tensor-parallel-rng", tensor_parallel_seed)
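
    # With the default base_seed of 1234, tensor parallel ranks 0..3 would be
    # seeded with 3952, 3953, 3954 and 3955 (1234 + 2718 + tp_rank), so each TP
    # rank draws different random numbers while remaining deterministic.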

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the tensor parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("tensor-parallel-rng"):
            raise RuntimeError(
                "TensorParallelRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(
                    self.rng_states["tensor-parallel-rng"]
                )
                try:
                    yield
                finally:
                    self.rng_states[
                        "tensor-parallel-rng"
                    ] = self._device_handle.get_rng_state()
        else:
            yield