parallel_mode.py

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._subclasses import FakeTensorMode
from torch.distributed._spmd.data_parallel import (
    DataParallelStyle,
    partition_data_parallel,
)
from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
from torch.fx import GraphModule


class ParallelMode(ABC):
    """
    Basic Parallel Mode interface. Each parallelism pattern should implement
    this interface to describe how to partition and compile the graph in the
    SPMD compiler.
    """

    @abstractmethod
    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        """
        Partition a single-device graph into a distributed graph.

        TODO(@wanchaol): some of these arguments are not necessary for
        partitioning, remove the unnecessary ones later.
        """
        raise NotImplementedError

    @abstractmethod
    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for each parallel mode.

        The returned result should be a compiled, executable graph in
        the distributed environment.
        """
        # TODO: add more necessary arguments to this interface.
        raise NotImplementedError


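# Illustrative sketch (an assumption added for documentation, not part of the
# original module): the smallest possible ParallelMode implementation. It shows
# what the interface above expects: `partition` must return a (here unchanged)
# GraphModule, and `transform_and_compile` runs the mode's pass pipeline. A
# real mode, like `DataParallel` below, rewrites the graph into a distributed
# one instead of passing it through.
class _ExampleNoOpMode(ParallelMode):
    """Hypothetical no-op mode used purely as an interface example."""

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # Leave the single-device graph untouched.
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        # No optimization passes applied.
        return gm

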
class DataParallel(ParallelMode):
    """Data Parallelism mode."""

    def __init__(
        self,
        parallel_style: str = "replicate",
        *,
        input_batch_dim: int = 0,
        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
    ):
        """
        DataParallel mode that partitions the model and graph into data parallel
        style parallelism (i.e. DDP/FSDP/ZeRO-3). It currently supports three
        different parallel styles: "replicate", "fully_shard", and "default".
        See :class:`DataParallelStyle` for more details.

        Args:
            parallel_style (str): parallel style to use. Currently supports
                "replicate", "fully_shard", and "default".

        Keyword args:
            input_batch_dim (int): the batch dimension of the input tensor.
                default: 0
            custom_passes (Callable[[GraphModule], GraphModule], optional):
                A custom callable that overrides the default graph transformation
                and optimization passes.
        """
        if parallel_style == "replicate":
            self.parallel_style = DataParallelStyle.REPLICATE
        elif parallel_style == "fully_shard":
            self.parallel_style = DataParallelStyle.FULLY_SHARD
        elif parallel_style == "default":
            self.parallel_style = DataParallelStyle.DEFAULT
        else:
            raise RuntimeError(f"Unknown parallel style: {parallel_style}")

        # TODO: what if the user passes in an incorrect `input_batch_dim`? How
        # should we detect that and do proper error handling?
        self.input_batch_dim = input_batch_dim

        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # TODO: figure out a way to avoid explicit "cuda" mesh.
        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

        gm = partition_data_parallel(
            gm,
            model,
            optimizer,
            params_and_buffers,
            named_states,
            args,
            kwargs,
            mesh,
            self.parallel_style,
            self.input_batch_dim,
        )
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """optimize a distributed graph with a set of optimization passes"""
        # TODO: add more necessary arguments to this interface.
        return self._gm_passes(gm)


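# Illustrative construction sketch (an assumption added for documentation; the
# helper below is hypothetical and is not used by the SPMD compiler itself).
# It shows the three supported parallel styles and how a custom pass pipeline
# can replace the default (identity) graph transformation passes.
def _example_data_parallel_modes() -> List[DataParallel]:
    return [
        # DDP-like: parameters replicated, inputs sharded along the batch dim.
        DataParallel("replicate"),
        # FSDP/ZeRO-3-like: parameters and optimizer states fully sharded.
        DataParallel("fully_shard", input_batch_dim=0),
        # Let the partitioner decide per parameter; a custom callable overrides
        # the default no-op passes.
        DataParallel("default", custom_passes=lambda gm: gm),
    ]

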
class DTensorExpandMode(ParallelMode):
    """
    The DTensor expand mode. It replicates the parameters and shards the
    inputs to represent DDP-like behavior; it's currently a transient mode
    before we move to the new data parallel expansion.
    """

    def __init__(
        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
    ):
        # Maps id(input tensor) -> placements, letting callers override the
        # default Shard(0) placement for specific inputs.
        self._placements_override: Dict[int, List[Placement]] = {}
        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
        shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
        # FIXME: allow other sharding schemas
        replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])

        inps, schemas = [], []

        # Parameters and buffers are replicated across the mesh.
        for p in pytree.tree_leaves(params_and_buffers):
            assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
            inps.append(p)
            schemas.append(replicate_schema)

        # Optimizer states are replicated as well; non-tensor states get a
        # dummy tensor so inputs stay aligned with the graph placeholders.
        for o in pytree.tree_leaves(named_states):
            if isinstance(o, torch.Tensor):
                inps.append(o)
                schemas.append(replicate_schema)
            else:
                inps.append(torch.empty(0))
                schemas.append(replicate_schema)

        # Actual inputs are sharded along dim 0 unless a placement override was
        # registered for that tensor.
        for a in flat_args:
            if isinstance(a, torch.Tensor):
                inps.append(a)
                if id(a) in self._placements_override:
                    schemas.append(
                        Schema(mesh=mesh, placements=self._placements_override[id(a)])
                    )
                else:
                    schemas.append(shard_schema)
            else:
                # Create a dummy tensor and schema for non-tensor inputs for
                # the purpose of dtensor expansion. Non-tensor inputs are
                # guaranteed unused in dispatcher graphs produced by make_fx.
                # However, we still need to respect them so that tensor inputs
                # match with their placeholders.
                inps.append(torch.empty(0))
                schemas.append(shard_schema)

        with FakeTensorMode(allow_non_fake_inputs=True):
            fake_inps = [torch.empty_like(inp) for inp in inps]

        return _convert_to_distributed(
            gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
        )[0]

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph transformation
        and optimization passes for the dtensor fallback parallel mode.
        """
        # TODO: move the transformation passes to this function
        return self._gm_passes(gm)
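

# Hedged usage sketch (an assumption added for documentation): how a
# ParallelMode is expected to be handed to the experimental SPMD frontend.
# The `compile` entry point in `torch.distributed._spmd.api` and its
# `parallel_mode` keyword are assumptions and may differ across PyTorch
# versions; treat this as illustration rather than a tested code path.
def _example_spmd_compile_usage() -> None:
    from torch.distributed._spmd.api import compile  # assumed entry point

    @compile(parallel_mode=DataParallel(parallel_style="fully_shard"))
    def train_step(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        # One forward/backward/optimizer step traced as a single graph and
        # partitioned by the chosen parallel mode.
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        return loss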