planner_helpers.py

# mypy: allow-untyped-defs
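"""Planner helpers for ``torch.distributed.checkpoint``.

These routines translate tensors, ``DTensor``s, ``ShardedTensor``s, and opaque
byte objects into the ``WriteItem``/``ReadItem`` records consumed by the
save/load planners, and materialize meta tensors before loading.
"""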

from typing import Any, cast, List

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.checkpoint.planner import _Checkpointable
from torch.utils._pytree import tree_map_only

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)

__all__: List[str] = ["create_read_items_for_chunk_list"]


def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties
    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )
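# Note on the helper above: compute_local_shape_and_global_offset maps
# (global shape, device mesh, placements) to this rank's shard size and its
# offset within the global tensor, which become the WriteItem's chunk.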


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )
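# The ``bytes`` argument above is not inspected: BYTE_IO write items carry no
# tensor metadata, so only the FQN-based index and the item type are recorded.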


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str): The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # This is a naive quadratic algorithm that can be optimized later.
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items
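# A minimal usage sketch for the function above (FQN, sizes, and offsets are
# hypothetical):
#
#     saved = TensorStorageMetadata(
#         properties=TensorProperties(dtype=torch.float32),
#         size=torch.Size([8]),
#         chunks=[
#             ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
#             ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
#         ],
#     )
#     local = [ChunkStorageMetadata(offsets=torch.Size([2]), sizes=torch.Size([4]))]
#     # The local chunk overlaps both saved chunks, so two ReadItems result:
#     # elements [2, 4) come from the first saved chunk, [4, 6) from the second.
#     items = create_read_items_for_chunk_list("linear.weight", saved, local)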


def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    if isinstance(object, _Checkpointable):
        return object._create_write_items(fqn, object)
    elif isinstance(object, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(object.to_local(), _Checkpointable):
            return object.to_local()._create_write_items(fqn, object)  # type: ignore[arg-type]
        return [_create_write_items_for_dtensor(fqn, object)]
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]
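# Dispatch order above matters: a _Checkpointable object (or a DTensor whose
# local tensor is _Checkpointable) produces its own write items before the
# generic DTensor/ShardedTensor/Tensor paths run; anything else falls back to
# byte IO.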


def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
    if isinstance(tensor, _Checkpointable):
        local_chunks = tensor._create_chunk_list(tensor)
    elif isinstance(tensor, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(tensor.to_local(), _Checkpointable):
            return tensor.to_local()._create_chunk_list(tensor)  # type: ignore[arg-type]
        local_chunks = [_create_chunk_from_dtensor(tensor)]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported type, expected one of [Tensor, DTensor, ShardedTensor], "
            f"but got {type(tensor)}"
        )

    return local_chunks
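# For a plain tensor this is a single chunk covering the whole tensor, e.g.
# (sketch) _create_chunk_list(torch.zeros(4, 4)) returns
# [ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 4]))].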


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                f"expected BytesStorageMetadata but found {type(md)}"
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]
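# For byte IO the offsets and length are zero: byte objects are restored whole,
# so these fields appear to act only as placeholders for the storage layer.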


def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    # tree_map_only_ (the in-place pytree variant) applies the function only
    # for its side effects and discards the returned tensors, so the
    # out-of-place tree_map_only is used and the materialized values are
    # written back into the state_dict.
    materialized = tree_map_only(torch.Tensor, _init_meta_tensor, state_dict)
    for key in state_dict.keys():
        state_dict[key] = materialized[key]


def _init_meta_tensor(value: Any) -> Any:
    """
    Materialize a meta-device torch.Tensor/DTensor on the default process-group
    device; values that are not on the meta device are returned unchanged.
    """
    device = getattr(value, "device", None)
    # DCP performs the initialization if the value is a meta tensor/DTensor.
    if device == torch.device("meta"):
        device_type = dist.distributed_c10d._get_pg_default_device().type
        device = cast(torch.device, _get_device_module(device_type).current_device())
        if isinstance(value, DTensor):
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since the DTensor
            # might be sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        elif isinstance(value, torch.Tensor):
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
    else:
        return value
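# Typical use of _init_state_dict (a sketch; ``MyModel`` is a hypothetical
# module): a state_dict built under the meta device can be materialized in
# place before loading a checkpoint into it:
#
#     with torch.device("meta"):
#         model = MyModel()
#     state_dict = model.state_dict()
#     _init_state_dict(state_dict)  # meta tensors replaced by empty real ones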