# _shard_utils.py
# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"
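

# Illustrative examples (not in the original file); with 8 devices per node:
#
#   >>> _get_remote_device_str(10, "cuda", 8)
#   'rank:10/cuda:2'
#   >>> _get_remote_device_str(3, "cpu", 8)
#   'rank:3/cpu'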


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        # torch.chunk gives every chunk except possibly the last ceil(dim0 / world_size)
        # rows, so this rank's chunk starts at rank * ceil(dim0 / world_size).
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
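

# Illustrative usage sketch (not part of the original module). It assumes the default
# process group has already been initialized (e.g. via dist.init_process_group) and
# that every rank holds the same full tensor.
def _example_chunk_sharded_tensor(full_tensor: torch.Tensor) -> ShardedTensor:
    pg = distributed_c10d._get_default_group()
    return _create_chunk_sharded_tensor(
        full_tensor,
        rank=dist.get_rank(pg),
        world_size=dist.get_world_size(pg),
        # NOTE: torch.cuda.device_count() is used here as a stand-in for the number
        # of devices per node.
        num_devices_per_node=torch.cuda.device_count(),
        pg=pg,
    )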


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
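

# Illustrative usage sketch (not part of the original module). It assumes an
# already-initialized default process group and builds a 1-D DeviceMesh over all
# ranks, which matches the plain FSDP case where the placements end up as [Shard(0)].
def _example_chunk_dtensor(full_tensor: torch.Tensor) -> DTensor:
    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
    return _create_chunk_dtensor(full_tensor, rank=dist.get_rank(), device_mesh=mesh)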


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert parent_mesh is None

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
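

# Illustrative round-trip sketch (not part of the original module): shard a full
# tensor into a DTensor with _create_chunk_dtensor, then all-gather it back with
# _all_gather_dtensor. On every rank the result should match the original tensor.
def _example_round_trip(
    full_tensor: torch.Tensor, device_mesh: DeviceMesh
) -> torch.Tensor:
    dtensor = _create_chunk_dtensor(
        full_tensor, rank=dist.get_rank(), device_mesh=device_mesh
    )
    return _all_gather_dtensor(dtensor, parent_mesh=None)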