loss.py

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional, Tuple

import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
from torch.distributed._tensor.ops.math_ops import (
    _skip_dim,
    Reduction,
    replicate_reduction_dims,
)
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta
from torch.distributed.device_mesh import DeviceMesh


aten = torch.ops.aten


__all__ = ["loss_parallel"]


@contextlib.contextmanager
def loss_parallel():
    """
    A context manager that enables loss parallelism, where efficient parallelized loss computation
    can be performed when the input is sharded on the class dimension. Currently only the cross-entropy
    loss is supported.

    Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
    :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
    The corresponding ``backward()`` call, if any, also needs to happen under this context manager.

    Args:
        input (:class:`DTensor`):
            Input logits. Assumed to be sharded on the class dimension.
        target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            Must be ground truth class indices (class probabilities currently not supported).
            Assumed to be replicated across the ``DeviceMesh``.
        weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
            If given, assumed to be replicated across the ``DeviceMesh``.
        label_smoothing:
            Currently not supported.

    Returns:
        A replicated :class:`DTensor`.

    Example:
        A sharded DTensor is manually created here to showcase the usage.
        In practice, it is usually the output of a TP module.

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import loss_parallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> device_mesh = init_device_mesh("cuda", (8,))
        >>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
        >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
        >>> target = torch.randint(16, (4,), device="cuda")
        >>> with loss_parallel():
        >>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
        >>>     loss.backward()
        >>> ...
    """
    _enable_custom_loss_ops()

    yield

    _disable_custom_loss_ops()


# Currently only needs to support one dimensional DeviceMesh; in general return
# the mesh_dim with placements[mesh_dim].is_shard(dim)
def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
    if not len(placements) == 1:
        raise ValueError(
            "Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
        )
    if not placements[0].is_shard(dim):
        raise ValueError(
            f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
        )
    return 0
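

# Wrap a plain torch.Tensor into a DTensor with the given placements; if the
# input is already a DTensor, verify that its placements match the expected ones.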
def _cast_to_dtensor(
    tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
) -> DTensor:
    if isinstance(tensor, DTensor):
        if tensor.placements == placements:
            return tensor
        else:
            raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
    elif isinstance(tensor, torch.Tensor):
        return DTensor.from_local(
            tensor, device_mesh=mesh, placements=placements, run_check=False
        )
    else:
        raise TypeError(f"Unsupported type {type(tensor)}")
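

# Run DTensor's sharding propagator on the op to obtain the output TensorMeta
# (shape, stride, dtype) that the handlers below attach to their results.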
def _propagate_tensor_meta(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> TensorMeta:
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
        op_info.schema
    )
    if isinstance(tensor_meta, TensorMeta):
        return tensor_meta
    elif isinstance(tensor_meta, tuple):
        return tensor_meta[0]
    else:
        raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")


# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
# with all_reduce manually inserted to perform distributed computation.
def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        x_max = funcol.all_reduce(
            x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim)
        )
        shifted = x - x_max
    shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True)
    shifted_sumexp = funcol.all_reduce(
        shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim)
    )
    shifted_logsumexp = torch.log(shifted_sumexp)
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result
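

# Handler for aten._log_softmax.default: computes log_softmax on the local shard
# (with the all_reduces over the sharded class dimension happening inside
# _log_softmax) and wraps the result into a DTensor with the input's sharding.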
def _log_softmax_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    dim = cast(int, args[1])
    half_to_float = cast(bool, args[2])

    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)

    output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs)

    res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim)

    res_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        res,
        res_spec,
        requires_grad=res.requires_grad,
    )


# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the
# _log_softmax_backward_handler does not actually do any computation.
def _log_softmax_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    input_dtype = cast(torch.dtype, args[3])
    return grad_output.to(input_dtype)


# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward,
# with customized communication inserted to perform distributed computation.
def _nll_loss_forward(
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    local_weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    channel_dim_size: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tuple[Tensor, Tensor]:
    n_dims = x.dim()
    channel_dim = 1
    if n_dims < 2:
        channel_dim = 0

    def _weight_view(weight: Tensor) -> Tensor:
        if n_dims > 1:
            shape = [
                1,
            ] * n_dims
            shape[channel_dim] = weight.shape[0]
            w = weight.view(shape)
        else:
            w = weight
        return w

    if weight is not None:
        w = _weight_view(weight)
        assert local_weight is not None
        local_w = _weight_view(local_weight)
        x = x * local_w
    safe_target = torch.where(target != ignore_index, target, 0)
    safe_target_ = safe_target.unsqueeze(channel_dim)

    # The following code block is a distributed version of
    # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
    safe_target_partial_ = partial_placement._partition_value(
        safe_target_, mesh, mesh_dim
    )
    result_partial = torch.gather(x, channel_dim, safe_target_partial_)
    # an all_reduce happens here
    result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim)
    result = -result_reduced.squeeze(channel_dim)

    result = torch.where(target != ignore_index, result, 0)

    if reduction == Reduction.NONE.value and n_dims > 1:
        total_weight = x.new_full((), 0.0)
        return result, total_weight

    if weight is not None:
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = w.expand(new_shape)
        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
        wsum = torch.where(target != ignore_index, wsum, 0)
        total_weight = wsum.sum()
    else:
        total_weight = (target != ignore_index).sum().to(x)

    # NOTE: this is correct only on 1D DeviceMesh; o/w additional
    # all-reduce on result and total_weight is needed
    if reduction == Reduction.SUM.value:
        result = result.sum()
    elif reduction == Reduction.MEAN.value:
        result = result.sum() / total_weight

    return result, total_weight
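

# Handler for aten.nll_loss_forward.default / aten.nll_loss2d_forward.default:
# normalizes target and weight into DTensors with the expected placements and
# calls _nll_loss_forward on the local shards.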
def _nll_loss_forward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    target = args[1]
    weight = args[2]
    reduction = cast(int, args[3])
    ignore_index = cast(int, args[4])

    channel_dim = 1 if x.dim() >= 2 else 0
    channel_dim_size = x.shape[channel_dim]
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # Check user input: if target and weight are not DTensors, convert them to DTensors;
    # if they are DTensors, check that they have the desired placements.
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    local_weight = None
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
        # For local computation, both (replicated) weight and (sharded) local_weight
        # are needed in _nll_loss_forward(). local_weight is generated here using
        # DTensor API, without incurring any communication.
        sharded_placements = [
            Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)
        ]
        local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor
        assert local_weight.shape[0] == x._local_tensor.shape[channel_dim]

    if reduction == Reduction.NONE.value:
        output_placements = target_placements
    else:
        output_placements = all_replicate_placements

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[1], args[2] = target, weight
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result, total_weight = _nll_loss_forward(
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        local_weight,
        reduction,
        ignore_index,
        channel_dim_size,
        spec.mesh,
        mesh_dim,
    )
    out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)

    return (
        DTensor(
            result,
            out_spec,
            requires_grad=result.requires_grad,
        ),
        total_weight,
    )


# NOTE: The backward computation of cross_entropy goes through two steps:
# backward for nll_loss and then backward for log_softmax. In loss parallel,
# the two steps are fused into the following function (called by _nll_loss_backward_handler)
# to avoid communication when target contains class indices not class probabilities.
# Also note that the _log_softmax_backward_handler does not perform computation.
# The implementation resembles _nll_loss_backward and _log_softmax_backward_data
# from torch._decomp.decomposition.
def _nll_loss_and_log_softmax_backward(
    grad_output: Tensor,
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
    channel_dim_size: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tensor:
    channel_dim = 0 if x.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(x)

    # The following code block is a distributed version of
    # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
    safe_target = safe_target.squeeze(channel_dim).flatten()
    masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
    # only update grad_input to -1 if not masked
    assert partial_placement.mask_buffer.data is not None
    grad_update = partial_placement.mask_buffer.data.float() - 1.0
    arange_1d = torch.arange(
        masked_safe_target.shape[0], device=masked_safe_target.device
    )
    # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default;
    # the last case is for aten.nll_loss2d_backward.default.
    if x.dim() == 1:
        grad_input[masked_safe_target] = grad_update
    elif x.dim() == 2:
        grad_input[arange_1d, masked_safe_target] = grad_update
    else:
        grad_input_t = grad_input.transpose(channel_dim, -1)
        intermediate_shape = grad_input_t.shape
        grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
        grad_input_2d[arange_1d, masked_safe_target] = grad_update
        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(x.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        # In order for fused computation to work, the following line is rewritten.
        # grad_output = grad_output * weight
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = weight.expand(new_shape)
        w_target = torch.gather(w, channel_dim, target)
        grad_output = grad_output * w_target

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
    # here we perform backward computation for log_softmax altogether to avoid the
    # otherwise extra all_gather communication.
    # return grad_input * grad_output
    return (grad_input + torch.exp(x)) * grad_output
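

# Handler for aten.nll_loss_backward.default / aten.nll_loss2d_backward.default:
# performs the fused nll_loss + log_softmax backward on the local shards and
# returns a DTensor sharded the same way as the forward input.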
def _nll_loss_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    x = cast(DTensor, args[1])
    target = args[2]
    weight = args[3]
    reduction = cast(int, args[4])
    ignore_index = cast(int, args[5])
    total_weight = cast(Tensor, args[6])

    channel_dim = 1 if x.dim() >= 2 else 0
    channel_dim_size = x.shape[channel_dim]
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # if target and weight are not DTensors, convert them to DTensors
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[2], args[3] = target, weight
    args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result = _nll_loss_and_log_softmax_backward(
        grad_output._local_tensor,
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        reduction,
        ignore_index,
        total_weight,
        channel_dim_size,
        spec.mesh,
        mesh_dim,
    )
    # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
    out_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        result,
        out_spec,
        requires_grad=result.requires_grad,
    )
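

# Dispatch table registered by loss_parallel(): while the context manager is
# active, DTensor dispatch routes these aten ops to the custom handlers above.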
customized_loss_ops = {
    aten._log_softmax.default: _log_softmax_handler,
    aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
    aten.nll_loss_forward.default: _nll_loss_forward_handler,
    aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
    aten.nll_loss_backward.default: _nll_loss_backward_handler,
    aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
}


def _enable_custom_loss_ops():
    DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)


def _disable_custom_loss_ops():
    for custom_op in customized_loss_ops:
        DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)