from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._subclasses import FakeTensorMode
from torch.distributed._spmd.data_parallel import (
    DataParallelStyle,
    partition_data_parallel,
)
from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
from torch.fx import GraphModule


class ParallelMode(ABC):
    """
    Basic Parallel Mode interface. Each parallelism pattern should implement
    this interface to describe how to partition and compile the graph in the
    spmd compiler.
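
    A minimal sketch of a custom mode (hypothetical, shown only to illustrate
    the two hooks a subclass must provide) could look like::

        class NoOpMode(ParallelMode):
            def partition(self, gm, model, optimizer, params_and_buffers,
                          named_states, args, kwargs):
                return gm  # keep the single-device graph untouched

            def transform_and_compile(self, gm):
                return gm  # no extra transformation/optimization passes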
- """
- @abstractmethod
- def partition(
- self,
- gm: GraphModule,
- model: torch.nn.Module,
- optimizer: Optional[torch.optim.Optimizer],
- params_and_buffers: Dict[str, Any],
- named_states: Dict[str, Any],
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- ) -> GraphModule:
- """
- Partition a single device graph to a distributed graph.
- TODO(@wanchaol): some of these arguments are not necessary for
- partitioning, remove the unnecessary ones later.
- """
- raise NotImplementedError

    @abstractmethod
    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph
        transformation and optimization passes for each parallel mode.

        The returned result should be a compiled executable graph in
        the distributed environment.
        """
        # TODO: add more necessary arguments to this interface.
        raise NotImplementedError


class DataParallel(ParallelMode):
    """Data Parallelism mode."""

    def __init__(
        self,
        parallel_style: str = "replicate",
        *,
        input_batch_dim: int = 0,
        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
    ):
- """
- DataParallel Mode that partition the model and graph to data parallel style
- parallelism (i.e. DDP/FSDP/ZERO-3). It currently supports three different
- parallel styles: "replicate", "fully_shard", and "default". See
- :class:`DataParallelStyle` for more details.
- Args:
- parallel_style (str): parallel style to use. Currently supports
- "replicate", "fully_shard", and "default".
- Keyword args:
- input_batch_dim (int): the batch dimension of the input tensor.
- default: 0
- custom_passes (Callable[[GraphModule], GraphModule], optional):
- A custom callable that overrides the default graph transformation
- and optimization passes.
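
        Example (illustrative only; shows how the constructor is typically
        parameterized, not a full compile pipeline)::

            # DDP-like: parameters replicated across ranks
            dp_mode = DataParallel("replicate")

            # FSDP/ZeRO-3-like: parameters fully sharded, inputs batched on dim 0
            fsdp_mode = DataParallel("fully_shard", input_batch_dim=0)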
- """
        if parallel_style == "replicate":
            self.parallel_style = DataParallelStyle.REPLICATE
        elif parallel_style == "fully_shard":
            self.parallel_style = DataParallelStyle.FULLY_SHARD
        elif parallel_style == "default":
            self.parallel_style = DataParallelStyle.DEFAULT
        else:
            raise RuntimeError(f"Unknown parallel style: {parallel_style}")

        # TODO: what if the user passes in an incorrect `input_batch_dim`, how
        # should we detect that and do proper error handling?
        self.input_batch_dim = input_batch_dim

        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # TODO: figure out a way to avoid explicit "cuda" mesh.
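        # Build a flat 1-D DeviceMesh that spans every rank in the default process
        # group; the data parallel partitioner replicates or shards over this mesh.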
        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
        gm = partition_data_parallel(
            gm,
            model,
            optimizer,
            params_and_buffers,
            named_states,
            args,
            kwargs,
            mesh,
            self.parallel_style,
            self.input_batch_dim,
        )
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """Optimize a distributed graph with a set of optimization passes."""
        # TODO: add more necessary arguments to this interface.
        return self._gm_passes(gm)


class DTensorExpandMode(ParallelMode):
    """
    The DTensor Expand mode. It replicates the parameters and shards the inputs
    to represent DDP-like behavior; it is currently a transient mode before we
    move to the new data parallel expansion.
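
    A minimal illustrative construction (the surrounding trace/compile wiring is
    omitted and assumed to be driven by the spmd compiler)::

        fallback_mode = DTensorExpandMode()
        # or, with a user-supplied pass pipeline:
        fallback_mode = DTensorExpandMode(custom_passes=lambda gm: gm)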
- """

    def __init__(
        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
    ):
        self._placements_override: Dict[int, List[Placement]] = {}

        if custom_passes is not None:
            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
        else:
            # TODO: add a few default passes here.
            self._gm_passes = lambda gm: gm

    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
        shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
        # FIXME: allow other sharding schemas
        replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])

        inps, schemas = [], []
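
        # Parameters and buffers are always marked replicate across the mesh
        # (DDP-style).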
        for p in pytree.tree_leaves(params_and_buffers):
            assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
            inps.append(p)
            schemas.append(replicate_schema)
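
        # Optimizer states: tensor states are replicated as well; non-tensor states
        # get a dummy placeholder so positions still line up with graph placeholders.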
        for o in pytree.tree_leaves(named_states):
            if isinstance(o, torch.Tensor):
                inps.append(o)
                schemas.append(replicate_schema)
            else:
                inps.append(torch.empty(0))
                schemas.append(replicate_schema)
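
        # Model inputs: tensors are sharded along the batch dimension (dim 0) unless
        # an explicit placement override was registered for that exact tensor.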
        for a in flat_args:
            if isinstance(a, torch.Tensor):
                inps.append(a)
                if id(a) in self._placements_override:
                    schemas.append(
                        Schema(mesh=mesh, placements=self._placements_override[id(a)])
                    )
                else:
                    schemas.append(shard_schema)
            else:
                # Create dummy tensor and schema for non-tensor inputs for
                # the purpose of dtensor expansion. Non-tensor inputs are
                # guaranteed unused in dispatcher graphs produced by make_fx.
                # However, we still need to respect them so that tensor inputs
                # match with their placeholders.
                inps.append(torch.empty(0))
                schemas.append(shard_schema)
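
        # Materialize fake copies of the inputs so that dtensor expansion can
        # propagate shapes and placements without allocating real device memory.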
        with FakeTensorMode(allow_non_fake_inputs=True):
            fake_inps = [torch.empty_like(inp) for inp in inps]

        return _convert_to_distributed(
            gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
        )[0]

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        """
        Transform and compile a distributed graph with a set of graph transformation
        and optimization passes for the dtensor fallback parallel mode.
        """
        # TODO: move the transformation passes into this function.
        return self._gm_passes(gm)