# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    # torch.chunk() may return fewer than world_size chunks, leaving the trailing
    # ranks without a local shard.
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    # The dim-0 offset of each chunk is the cumulative sum of the preceding chunk sizes.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
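

# Illustrative sketch (added for this edit, not original code): how
# _create_chunk_sharded_tensor might be driven after dist.init_process_group().
# The tensor shape, the use of the default process group, and the helper name
# are assumptions for illustration only.
def _example_chunk_sharded_tensor() -> ShardedTensor:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Hypothetical full tensor; each rank ends up holding at most one dim-0 chunk.
    full = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    return _create_chunk_sharded_tensor(
        full,
        rank=rank,
        world_size=world_size,
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=distributed_c10d._get_default_group(),
    )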


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the
    # current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
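

# Illustrative sketch (added, not original code): sharding a replicated tensor over a
# 1-D DeviceMesh, the FSDP case described above. The CPU mesh, tensor shape, and helper
# name are assumptions for illustration only.
def _example_chunk_dtensor() -> DTensor:
    mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
    full = torch.randn(8, 4)
    return _create_chunk_dtensor(full, dist.get_rank(), mesh)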


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert parent_mesh is None

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
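

# Illustrative sketch (added, not original code): round-tripping the DTensor from the
# sketch above back to the full, unsharded local tensor on every rank. parent_mesh is
# None because the mesh is not sliced out of a larger parent mesh.
def _example_all_gather_dtensor() -> torch.Tensor:
    dtensor = _example_chunk_dtensor()
    return _all_gather_dtensor(dtensor, parent_mesh=None)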