api.py

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from typing import Any, Callable, cast, Optional, Sequence, Tuple

import torch
import torch.distributed._tensor._dispatch as op_dispatch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor._collective_utils import mesh_broadcast
from torch.distributed._tensor._redistribute import (
    Redistribute,
    redistribute_local_tensor,
)
from torch.distributed._tensor._utils import compute_global_tensor_info
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Placement,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    OffsetBasedRNGTracker,
)
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh


__all__ = ["DTensor", "distribute_tensor", "distribute_module"]

aten = torch.ops.aten

# NOTE [Autograd interaction between torch.Tensor]
#
# The autograd functions defined below are used by the public facing APIs
# (i.e. from_local, to_local) to ensure that DTensor works together with
# torch.Tensor within the autograd engine. This allows DTensor to exist on
# part of the module hierarchy and still be able to calculate gradients
# across the torch.Tensor and DTensor boundary.
#
# As an example, suppose we have a module that consists of submodules A, B,
# and C; the execution flow would be like:
#
#  input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
#
# Suppose we only want to make Module B a sharded module with DTensor params,
# then we would need the following flow to work:
#
#  input(torch.Tensor) -> Module A
#       -> DTensor input -> Sharded Module B -> DTensor output
#           -> output (torch.Tensor) -> Module C -> output (torch.Tensor)
#
# We need the conversion from Module A's output to DTensor input, which is
# `from_local`, and the conversion from DTensor output back to torch.Tensor,
# which is `to_local`; thus these two functions must be autograd functions.
#
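# A minimal illustrative sketch of that boundary (not part of this module; the
# `mesh`, the modules, and the shapes below are assumptions for orientation only):
#
#   mesh = DeviceMesh("cuda", list(range(4)))
#   x = module_a(inp)                                # plain torch.Tensor
#   dx = DTensor.from_local(x, mesh, [Shard(0)])     # enter DTensor land
#   dy = sharded_module_b(dx)                        # ops run on DTensors
#   y = dy.to_local()                                # back to torch.Tensor
#   out = module_c(y)
#
# Gradients flow back through `to_local` and `from_local` because both are
# implemented as the autograd functions below.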


class _ToTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        input: "DTensor",
        grad_placements: Optional[Sequence[Placement]],
    ):
        ctx.dtensor_spec = input._spec
        ctx.grad_placements = grad_placements
        local_tensor = input._local_tensor

        # We need to return a fresh Tensor object here as autograd metadata
        # will be in-placed into it. So we don't want to pollute the Tensor
        # object stored in the _local_tensor of this DTensor.
        return local_tensor.view_as(local_tensor)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        dtensor_spec = ctx.dtensor_spec
        mesh = dtensor_spec.mesh
        grad_placements = ctx.grad_placements
        dtensor_meta = dtensor_spec.tensor_meta

        _, tensor_stride = compute_global_tensor_info(
            grad_output, mesh, dtensor_spec.placements
        )
        tensor_stride = tuple(tensor_stride)
        grad_placements = grad_placements or dtensor_spec.placements
        grad_spec = DTensorSpec(
            mesh,
            grad_placements,
            tensor_meta=TensorMeta(
                shape=dtensor_meta.shape,
                stride=tensor_stride,
                dtype=dtensor_meta.dtype,
            ),
        )

        return (
            DTensor(
                grad_output,
                grad_spec,
                requires_grad=grad_output.requires_grad,
            ),
            None,
        )


class _FromTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Tuple[Placement, ...],
        run_check: bool,
        shape: Optional[torch.Size] = None,
        stride: Optional[Tuple[int, ...]] = None,
    ) -> "DTensor":
        ctx.previous_placement = placements
        ctx.previous_device_mesh = device_mesh

        if shape and stride:
            tensor_shape, tensor_stride = shape, stride
        elif not shape and not stride:
            # if shape/stride are not provided, we assume each rank holds a
            # local tensor of the same shape, and use that together with the
            # placements to calculate the global shape/stride
            global_shape, global_stride = compute_global_tensor_info(
                input, device_mesh, placements
            )
            tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
        else:
            raise RuntimeError(
                f"Found shape:{shape}, stride:{stride}. "
                "Please pass both shape and stride at the same time."
            )

        if device_mesh.get_coordinate() is None:
            # if the global rank is not participating in the device mesh, we
            # simply set the local tensor to an empty tensor
            input = input.new_empty(0, requires_grad=input.requires_grad)
        elif run_check:
            # TODO: by default check tensor metas across ranks
            # TODO: See if we need to make this run_check logic
            # have a corresponding backward.
            for idx, placement in enumerate(placements):
                if placement.is_replicate():
                    # broadcast rank 0 tensor to all ranks
                    # only broadcast if run_check is True
                    input = input.contiguous()
                    mesh_broadcast(input, device_mesh, mesh_dim=idx)

        dist_spec = DTensorSpec(
            device_mesh,
            placements,
            tensor_meta=TensorMeta(
                tensor_shape,
                tensor_stride,
                input.dtype,
            ),
        )

        # We want a fresh Tensor object that shares memory with the input tensor
        dist_tensor = DTensor(
            input.view_as(input),
            dist_spec,
            # requires_grad of the dist tensor depends on if input
            # requires_grad or not
            requires_grad=input.requires_grad,
        )
        return dist_tensor

    @staticmethod
    def backward(ctx, grad_output: "DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # reshard to the placement used when creating the DTensor
        # so that the gradient layout matches, and we can return
        # local gradients directly
        if grad_output.placements != previous_placement:
            current_spec = grad_output._spec
            target_spec = DTensorSpec(
                previous_device_mesh,
                previous_placement,
                tensor_meta=grad_output._spec.tensor_meta,
            )
            local_tensor = grad_output._local_tensor
            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, is_backward=True
            )
            # TODO: return the redistributed local tensor directly without
            # differentiable backward. see if this makes sense for all cases.
            return output, None, None, None, None, None

        # TODO: backward is also differentiable now, add a test
        # to test higher level gradients.
        return grad_output.to_local(), None, None, None, None, None


class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    _local_tensor: torch.Tensor
    _spec: DTensorSpec
    __slots__ = ["_local_tensor", "_spec"]

    # class attribute that handles operator placement propagation
    # rules, keyed by aten op name, value is the propagation func
    _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()

    @staticmethod
    @torch._disable_dynamo
    def __new__(
        cls,
        local_tensor: torch.Tensor,
        spec: DTensorSpec,
        *,
        requires_grad: bool,
    ) -> "DTensor":
        """
        Construct a DTensor from a local tensor, device mesh, placements and
        other tensor properties (i.e. shape, requires_grad, strides, etc).

        Note: This is not a public API and it's only supposed to be used by the
        operator implementations and internals. If you want to construct a
        DTensor from a local tensor, consider using `DTensor.from_local`; if
        you want to construct a DTensor from a "global" tensor (where you
        already have the tensor initialized and want to shard it),
        consider using `distribute_tensor`.
        """
        if local_tensor.requires_grad and not requires_grad:
            warnings.warn(
                "To construct DTensor from torch.Tensor, it's recommended to "
                "use local_tensor.detach() and make requires_grad consistent."
            )

        # __new__ constructs the wrapper tensor from local_tensor and attaches
        # the placement spec; it does not do the actual distribution
        assert spec.tensor_meta is not None, "TensorMeta should not be None!"
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            spec.tensor_meta.shape,
            strides=spec.tensor_meta.stride,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
            requires_grad=requires_grad,
        )

        r._spec = spec
        r._local_tensor = local_tensor
        return r
    # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

    def __tensor_flatten__(self):
        """
        protocol to inform how to flatten a DTensor to local tensor
        for PT2 tracing
        """
        return ["_local_tensor"], (self._spec, self.requires_grad)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        assert (
            flatten_spec is not None
        ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
        local_tensor = inner_tensors["_local_tensor"]
        spec, requires_grad = flatten_spec
        unflatten_tensor_meta = TensorMeta(
            shape=outer_size,
            stride=outer_stride,
            dtype=spec.tensor_meta.dtype,
        )
        unflatten_spec = DTensorSpec(
            spec.mesh,
            spec.placements,
            tensor_meta=unflatten_tensor_meta,
        )
        return DTensor(
            local_tensor,
            unflatten_spec,
            requires_grad=requires_grad,
        )
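
    # Illustrative round trip of the flatten/unflatten protocol above (a sketch
    # of how tracing machinery could exercise it; callers normally never invoke
    # these dunders directly, and `dt` is an assumed DTensor instance):
    #
    #   attrs, ctx = dt.__tensor_flatten__()        # (["_local_tensor"], (spec, requires_grad))
    #   inner = {name: getattr(dt, name) for name in attrs}
    #   dt2 = DTensor.__tensor_unflatten__(inner, ctx, dt.shape, dt.stride())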
    def __coerce_tangent_metadata__(self):
        if not any(isinstance(p, Partial) for p in self.placements):
            return self
        placements = [
            Replicate() if isinstance(p, Partial) else p for p in self.placements
        ]
        return self.redistribute(device_mesh=self.device_mesh, placements=placements)

    def __coerce_same_metadata_as_tangent__(self, metadata_tensor):
        return self.redistribute(
            device_mesh=self.device_mesh,
            placements=metadata_tensor.placements,
        )

    @classmethod
    @torch._disable_dynamo
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        return DTensor._op_dispatcher.dispatch(
            func,
            args,
            kwargs or {},
        )
    @staticmethod
    def from_local(
        local_tensor: torch.Tensor,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        *,
        run_check: bool = True,
        shape: Optional[torch.Size] = None,
        stride: Optional[Tuple[int, ...]] = None,
    ) -> "DTensor":
        """
        Create a :class:`DTensor` from a local torch.Tensor on each rank
        according to the `device_mesh` and `placements` specified.

        Args:
            local_tensor (torch.Tensor): local torch.Tensor on each rank.
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                tensor, if not specified, must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the placements that
                describe how to place the local torch.Tensor on DeviceMesh, must
                have the same number of elements as `device_mesh.ndim`. If not
                specified, we will by default replicate the tensor across the
                `device_mesh` from the first rank of each dimension of the `device_mesh`.

        Keyword args:
            run_check (bool, optional): indicate whether to run checks across ranks
                to verify meta information and data. If there is a :class:`Replicate`
                in `placements`, the data on the first rank of that device mesh
                dimension will be broadcast to the other ranks.
            shape (torch.Size, optional): a list of ints which specifies the size of the
                DTensor built on top of `local_tensor`. Note this needs to be
                provided if the shape of `local_tensor` differs across ranks.
                If not provided, `shape` will be computed assuming the given distributed
                tensor is evenly sharded across ranks.
            stride (tuple, optional): a list of ints which specifies the stride of the
                DTensor. If not provided, `stride` will be computed assuming the given
                distributed tensor is evenly sharded across ranks.

        Returns:
            A :class:`DTensor` object

        .. note:: `from_local` is differentiable, the `requires_grad` of the created
            `DTensor` object will depend on if `local_tensor` requires_grad or not.
        """
        # if the shape/dtype is the same across ranks there is no need to run_check;
        # if not, we must allgather the metadata to check the size/dtype across ranks.
        # There should be no data communication unless there's a replication
        # strategy, where we broadcast the replicated data from the first rank
        # in the mesh dimension
        device_mesh = device_mesh or _mesh_resources.get_current_mesh()
        device_type = device_mesh.device_type

        # convert the local tensor to the desired device based on the device mesh's device_type
        if device_type != local_tensor.device.type and not local_tensor.is_meta:
            local_tensor = local_tensor.to(device_type)

        # set default placements to replicated if not specified
        if placements is None:
            placements = [Replicate() for _ in range(device_mesh.ndim)]
        else:
            placements = list(placements)
            for idx, placement in enumerate(placements):
                # normalize shard dim to be positive
                if placement.is_shard():
                    placement = cast(Shard, placement)
                    if placement.dim < 0:
                        placements[idx] = Shard(placement.dim + local_tensor.ndim)

        # `from_local` is differentiable, and the gradient of the dist tensor this function
        # created should flow back the gradients to the local_tensor, so we call an autograd
        # function to construct the dist tensor instead.
        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
            local_tensor,
            device_mesh,
            tuple(placements),
            run_check,
            shape,
            stride,
        )
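
    # Illustrative usage (a sketch, not part of the API surface; the 2-rank
    # mesh and the shapes are assumptions, and `init_device_mesh` comes from
    # torch.distributed.device_mesh):
    #
    #   mesh = init_device_mesh("cuda", (2,))
    #   local = torch.randn(4, 8)                        # each rank holds (4, 8)
    #   dt = DTensor.from_local(local, mesh, [Shard(0)])
    #   # dt.shape is (8, 8): the global size with the local shards stacked on dim 0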
    def to_local(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Get the local tensor of this DTensor on its current rank. For sharding it returns
        a local shard of the logical tensor view, for replication it returns the replica on
        its current rank.

        Keyword args:
            grad_placements (List[:class:`Placement`], optional): the placements that
                describe the future layout of any gradient of the tensor returned from this
                function.
                `to_local` converts DTensor to a local tensor and the returned local tensor
                might not be used with the original DTensor layout later in the code. This
                argument is a hint that the user can give to autograd in case the gradient
                layout of the returned tensor does not match the original DTensor layout.
                If not specified, we will assume the gradient layout remains the same
                as the original DTensor and use that for gradient computation.

        Returns:
            A :class:`torch.Tensor` or `AsyncCollectiveTensor` object. It represents the
            local tensor on its current rank.

        .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned
            will depend on if the `DTensor` requires_grad or not.
        """
        if not torch.is_grad_enabled():
            return self._local_tensor

        if grad_placements is not None and not isinstance(grad_placements, tuple):
            grad_placements = tuple(grad_placements)
        return _ToTorchTensor.apply(
            self, grad_placements
        )  # pyre-ignore[16]: autograd func
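
    # Illustrative usage (a sketch; builds on the hypothetical `dt` above):
    #
    #   shard = dt.to_local()        # the (4, 8) shard this rank owns
    #   loss = shard.sum()
    #   loss.backward()              # gradients flow back into the DTensor graph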
    def redistribute(
        self,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        *,
        async_op: bool = False,
    ) -> "DTensor":
        """
        `redistribute` performs the necessary collective operations to redistribute the current
        DTensor from its current placements to new placements, or from its current DeviceMesh
        to a new DeviceMesh. i.e. we can turn a sharded DTensor into a replicated DTensor by
        specifying a Replicate placement for each dimension of the DeviceMesh.

        Args:
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                DTensor, if not specified, must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the new placements that
                describe how to place the DTensor into the DeviceMesh, must
                have the same number of elements as `device_mesh.ndim`.

        Keyword args:
            async_op (bool, optional): whether to perform the DTensor redistribute operation
                asynchronously or not. Default: False

        Returns:
            A :class:`DTensor` object

        .. note:: `redistribute` is differentiable.
        """
        # NOTE: This redistribute API currently only supports out
        # of place redistribution, i.e. it always creates a new
        # DTensor object and leaves the original one unchanged.

        # if device_mesh is not specified, use the current device_mesh
        device_mesh = device_mesh or self.device_mesh
        # raise error if new placements are not specified
        if placements is None:
            raise RuntimeError("placements is needed for redistribute!")

        placements = list(placements)
        for i, placement in enumerate(placements):
            if placement.is_partial():
                raise RuntimeError(
                    "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
                )
            elif isinstance(placement, Shard) and placement.dim < 0:
                # normalize shard dim to be positive
                placements[i] = Shard(placement.dim + self.ndim)
        placements = tuple(placements)

        # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
        return Redistribute.apply(self, device_mesh, placements, async_op)
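
    # Illustrative usage (a sketch; `dt` is the assumed 1-D-mesh DTensor above):
    #
    #   replicated = dt.redistribute(placements=[Replicate()])   # all-gather across the mesh dim
    #   resharded = dt.redistribute(placements=[Shard(1)])       # reshard along tensor dim 1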
    def full_tensor(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Return the full tensor of this DTensor. It will perform the necessary collectives
        to gather the local tensors from the other ranks in its DeviceMesh and concatenate
        them together. It's syntactic sugar for the following code:

        `dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`

        Keyword args:
            grad_placements (List[:class:`Placement`], optional): the placements that
                describe the future layout of any gradient of the full tensor returned from
                this function.
                `full_tensor` converts DTensor to a full torch.Tensor and the returned
                torch.Tensor might not be used with the original replicated DTensor layout
                later in the code. This argument is a hint that the user can give to autograd
                in case the gradient layout of the returned tensor does not match the original
                replicated DTensor layout. If not specified, we will assume the gradient
                layout of the full tensor is replicated.

        Returns:
            A :class:`torch.Tensor` object that represents the full tensor of this DTensor.

        .. note:: `full_tensor` is differentiable.
        """
        redist_res = self.redistribute(
            placements=[Replicate()] * self.device_mesh.ndim, async_op=False
        )
        return _ToTorchTensor.apply(redist_res, grad_placements)
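
    # Illustrative usage (a sketch; `dt` as assumed above, sharded on dim 0
    # across a 2-rank mesh):
    #
    #   full = dt.full_tensor()      # (8, 8) on every rank after the all-gather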
    @property
    def device_mesh(self) -> DeviceMesh:
        """
        The :class:`DeviceMesh` attribute that is associated with this DTensor object.

        .. note:: device_mesh is a read-only property, it can not be set.
        """
        return self._spec.mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """
        The placements attribute of this DTensor that describes the layout of this
        DTensor on its DeviceMesh.

        .. note:: placements is a read-only property, it can not be set.
        """
        return self._spec.placements


def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a leaf torch.Tensor (i.e. nn.Parameter) to the ``device_mesh`` according
    to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be
    the same. If you want to construct a DTensor in the middle of the Autograd computation,
    please use ``DTensor.from_local`` instead.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use ``torch.chunk``
            semantics to shard the tensor and scatter the shards.
        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
            tensor, if not specified, must be called under a DeviceMesh context
            manager, default: None
        placements (List[:class:`Placement`], optional): the placements that
            describe how to place the tensor on DeviceMesh, must have the same
            number of elements as `device_mesh.ndim`. If not specified, we will
            by default replicate the tensor across the `device_mesh` from the
            first rank of each dimension of the `device_mesh`.

    Returns:
        A :class:`DTensor` or `XLAShardedTensor` object.

    Note:
        When initializing the DeviceMesh with the `xla` device_type, `distribute_tensor`
        returns `XLAShardedTensor` instead. See [link](https://github.com/pytorch/pytorch/issues/92909)
        for more details. The XLA integration is experimental and subject to change.
    """
    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")

    # get the default device mesh if there's nothing specified
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # call PyTorch/XLA SPMD for `xla` backend type device mesh.
            # This returns XLAShardedTensor
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_tensor,
            )

            return xla_distribute_tensor(
                tensor, device_mesh, placements
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    # instantiate an RNG tracker if we haven't yet. By default DTensor uses an
    # OffsetBasedRNGTracker to perform random operators.
    # TODO: the value assignment to a global variable is not the ideal solution,
    # we can replace it in the future.
    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
        random._rng_tracker = OffsetBasedRNGTracker(device_type)
    if not tensor.is_leaf:
        raise RuntimeError(
            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
        )

    # convert tensor to the corresponding device type if it's not in that device type
    if device_type != tensor.device.type and not tensor.is_meta:
        tensor = tensor.to(device_type)

    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(device_mesh.ndim)]

    if len(placements) != device_mesh.ndim:
        raise ValueError(
            f"`placements` must have the same length as `device_mesh.ndim`! "
            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
        )
    if isinstance(tensor, DTensor):
        # if the tensor is already a DTensor, we need to check:
        # 1. whether we can further shard this DTensor if the two device meshes
        #    belong to the same parent mesh and further sharding is possible.
        # 2. whether the device mesh and placements are the same
        if tensor.device_mesh != device_mesh:
            raise ValueError(
                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
                f"to a different device mesh {device_mesh}."
            )
        if tensor.placements != tuple(placements):
            raise ValueError(
                f"Cannot distribute a DTensor with placements {tensor.placements} "
                f"to a different placements {placements}. do you want to call "
                f"`redistribute` instead?"
            )
        return tensor

    local_tensor = tensor.detach()

    # distribute the tensor according to the placements.
    placements = list(placements)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            placement = cast(Shard, placement)
            if placement.dim < 0:
                # normalize shard placement dim
                placement = Shard(placement.dim + tensor.ndim)
                placements[idx] = placement
            local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
        elif placement.is_replicate():
            placement = cast(Replicate, placement)
            local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
        else:
            raise RuntimeError(
                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
            )
    placements = tuple(placements)

    assert local_tensor is not None, "distributing a tensor should not be None"
    # detach the local tensor passed to DTensor since after the construction
    # of DTensor, autograd would work on top of DTensor instead of the local tensor
    spec = DTensorSpec(
        mesh=device_mesh,
        placements=placements,
        tensor_meta=TensorMeta(
            shape=tensor.size(),
            stride=tensor.stride(),
            dtype=tensor.dtype,
        ),
    )
    return DTensor(
        local_tensor.requires_grad_(tensor.requires_grad),
        spec,
        requires_grad=tensor.requires_grad,
    )
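

# Illustrative usage of distribute_tensor (a sketch; the 2x2 mesh and the
# parameter shape are assumptions, and `init_device_mesh` comes from
# torch.distributed.device_mesh):
#
#   mesh = init_device_mesh("cuda", (2, 2))
#   weight = torch.randn(16, 16, requires_grad=True)
#   dweight = distribute_tensor(weight, mesh, [Shard(0), Replicate()])
#   # each rank now holds an (8, 16) local shard, replicated along the second mesh dim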


def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
    """
    This function exposes three functions (``partition_fn``, ``input_fn``, ``output_fn``)
    to control the Tensors inside the module:

    1. To perform sharding on the module before runtime execution by specifying the
       ``partition_fn`` (i.e. allow the user to convert Module parameters to :class:`DTensor`
       parameters according to the `partition_fn` specified).
    2. To control the inputs or outputs of the module during runtime execution by
       specifying the ``input_fn`` and ``output_fn`` (i.e. convert the input to
       :class:`DTensor`, convert the output back to torch.Tensor).

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
        partition_fn (Callable): the function to partition parameters (i.e. shard certain
            parameters across the `device_mesh`). If `partition_fn` is not specified,
            by default we replicate all module parameters of `module` across the mesh.
        input_fn (Callable): specify the input distribution, i.e. could control how the
            input of the module is sharded. `input_fn` will be installed as a module
            `forward_pre_hook` (pre forward hook).
        output_fn (Callable): specify the output distribution, i.e. could control how the
            output is sharded, or convert it back to torch.Tensor. `output_fn` will be
            installed as a module `forward_hook` (post forward hook).

    Returns:
        A module that contains parameters/buffers that are all `DTensor`s.

    Note:
        When initializing the DeviceMesh with the `xla` device_type, `distribute_module`
        returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See
        [link](https://github.com/pytorch/pytorch/issues/92909)
        for more details. The XLA integration is experimental and subject to change.
    """
    torch._C._log_api_usage_once("torch.dtensor.distribute_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # This function annotates all module parameters for auto-partitioning with
            # PyTorch/XLA SPMD or explicitly partitions to :class:`XLAShardedTensor` parameters
            # according to the `partition_fn` specified.
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_module,
            )

            return xla_distribute_module(
                module, device_mesh, partition_fn, input_fn, output_fn
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e
    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
        # This function loops over the immediate module parameters and
        # buffers, replicating all non-DTensor params/buffers to DTensor
        # parameters/buffers if they have not been partitioned in the
        # partition_fn. We can't easily use `module._apply` here
        # because we don't know what happened inside partition_fn as
        # the user could do anything, i.e. install hooks, and we want to
        # preserve those.
        full_replicate = [Replicate()] * mesh.ndim
        for key, param in m._parameters.items():
            if param is not None and not isinstance(param, DTensor):
                m.register_parameter(
                    key,
                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
                )
        for key, buffer in m._buffers.items():
            if buffer is not None and not isinstance(buffer, DTensor):
                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)

    if partition_fn is None:
        # if partition_fn is not specified, we by default replicate
        # all module params/buffers
        for name, submod in module.named_modules():
            replicate_module_params_buffers(submod, device_mesh)
    else:
        # apply partition_fn to submodules
        for name, submod in module.named_modules():
            partition_fn(name, submod, device_mesh)
            replicate_module_params_buffers(submod, device_mesh)
    # register input_fn as module forward pre hook
    if input_fn is not None:
        # check the input_fn signature
        num_args = len(inspect.signature(input_fn).parameters)
        if num_args == 2:
            # input_fn only takes in inputs and device mesh
            warnings.warn(
                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
        elif num_args == 3:
            # input_fn takes in module, inputs, device mesh
            module.register_forward_pre_hook(
                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
            )
        else:
            raise ValueError(
                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
            )
    # register output_fn as module forward hook
    if output_fn is not None:
        num_args = len(inspect.signature(output_fn).parameters)
        if num_args == 2:
            # output_fn only takes in outputs and device mesh
            warnings.warn(
                "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
                "please use output_fn that takes in (module, outputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
            )
        elif num_args == 3:
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
            )
        else:
            raise ValueError(
                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
            )

    return module
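

# Illustrative usage of distribute_module (a sketch; the mesh, the toy model
# `MyModel`, and the `shard_linear` partition_fn are assumptions for the
# example, not part of this file):
#
#   mesh = init_device_mesh("cuda", (4,))
#
#   def shard_linear(name, submod, mesh):
#       # shard every nn.Linear's parameters on dim 0, leave everything else
#       # to be replicated by the default path above
#       if isinstance(submod, nn.Linear):
#           for pname, param in submod.named_parameters(recurse=False):
#               dparam = nn.Parameter(distribute_tensor(param, mesh, [Shard(0)]))
#               submod.register_parameter(pname, dparam)
#
#   sharded = distribute_module(
#       MyModel(),
#       mesh,
#       partition_fn=shard_linear,
#       input_fn=lambda mod, inputs, mesh: DTensor.from_local(inputs[0], mesh, [Replicate()]),
#       output_fn=lambda mod, outputs, mesh: outputs.to_local(),
#   )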