
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union
from fnmatch import fnmatch

import torch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import (
    ParallelStyle,
)

__all__ = [
    "parallelize_module",
]

def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or its sub-modules based on a ``parallelize_plan``. The plan contains
    :class:`ParallelStyle` objects, which indicate how the user wants the module or sub-module
    to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D
    :class:`DeviceMesh`, slice it into a 1-D sub-DeviceMesh first and pass that to this API
    (e.g. ``device_mesh["tp"]``); see the second example below.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a single
            :class:`ParallelStyle` object, which describes how to prepare input/output for
            Tensor Parallelism, or a dict mapping module FQNs to their corresponding
            :class:`ParallelStyle` objects.

    Return:
        The parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>
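        >>> # A second sketch, for a 2-D mesh: parallelize_module only accepts a 1-D
        >>> # DeviceMesh, so slice out the 1-D "tp" sub-mesh first. The 2 ("dp") x 4 ("tp")
        >>> # layout and dimension names here are illustrative, not required.
        >>> mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
        >>> tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh along the "tp" dimension
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>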

    .. note:: For complex module architectures like Attention or MLP layers, we recommend composing
        different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
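
    A minimal sketch of such a composed plan (the sub-module names ``attention.wq``,
    ``attention.wo``, ``mlp.w1`` and ``mlp.w2`` are hypothetical and depend on your model)::

        >>> # xdoctest: +SKIP("distributed")
        >>> plan = {
        ...     "attention.wq": ColwiseParallel(),
        ...     "attention.wo": RowwiseParallel(),
        ...     "mlp.w1": ColwiseParallel(),
        ...     "mlp.w2": RowwiseParallel(),
        ... }
        >>> m = parallelize_module(m, tp_mesh, plan)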
  60. """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False
    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            # note: "".split(".") yields [""], so validate the raw path directly
            if module_path == "":
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            path_splits = module_path.split(".")
            while path_splits:
                atom = path_splits.pop(0)
                # each atom may contain fnmatch-style wildcards (e.g. "*"), so a
                # single plan entry can match multiple children at this level
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom), module.named_children()
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(path_splits)  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )