# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
    _EmptyStateDictLoadPlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire checkpoint
    on the coordinator rank, then broadcasts and shards each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt",
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file"""
        # Metadata is built in planner.set_up_planner, since we are not actually
        # reading metadata from the disk
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank, and broadcasts it afterwards.
        This incurs a communication cost, but avoids having to load
        the entire checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # Data is read in on the coordinator rank and broadcast afterwards.
        # This incurs a communication cost, but it avoids having to load
        # the entire checkpoint on each rank, hopefully preventing OOM issues.
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(
                self.checkpoint_id, map_location="cpu", weights_only=False
            )
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                tensor = torch_state_dict[req.storage_index.fqn].cuda()
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatch sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)
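

# A minimal usage sketch (not part of this module): loading a plain torch.save
# checkpoint into a distributed model with this reader. ``MyModel`` and the
# checkpoint path are illustrative, and the process group is assumed to already
# be initialized (e.g. launched with torchrun + dist.init_process_group).
#
#   import torch.distributed.checkpoint as dcp
#
#   model = MyModel().cuda()  # hypothetical model
#   sd = {"model": model.state_dict()}
#   dcp.load(
#       sd,
#       storage_reader=BroadcastingTorchSaveReader(),
#       planner=DynamicMetaLoadPlanner(),
#       checkpoint_id="path_to_model.pt",
#   )
#   model.load_state_dict(sd["model"])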


class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed-in state dict,
    avoiding the need to read metadata from disk. This is useful when reading formats which don't have a
    metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt",
    >>> )
    """

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Sets up the planner, extending default behavior by creating the Metadata object from the state dict"""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
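

# A minimal sketch (not part of this module) of what set_up_planner builds for
# a toy state dict in a single-process run; the key name "weight" is
# illustrative.
#
#   planner = DynamicMetaLoadPlanner()
#   planner.set_up_planner({"weight": torch.zeros(4, 4)}, is_coordinator=True)
#   planner.metadata.state_dict_metadata["weight"]
#   # -> TensorStorageMetadata with size torch.Size([4, 4]) and a single chunk
#   #    covering the whole tensor, as produced by _create_chunk_list.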


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}

    _load_state_dict(
        sd,
        storage_reader=FileSystemReader(dcp_checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
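

# A minimal usage sketch (not part of this module): converting a DCP checkpoint
# directory into a single torch.save file; paths are illustrative.
#
#   dcp_to_torch_save("checkpoints/step_1000", "checkpoints/step_1000.pt")
#   sd = torch.load("checkpoints/step_1000.pt", weights_only=False)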


def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    state_dict = torch.load(torch_save_path, weights_only=False)
    # we don't need stateful behavior here because the expectation is anything loaded by
    # torch.load would not contain stateful objects.
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )
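

# A minimal usage sketch (not part of this module): converting a torch.save
# file into a DCP checkpoint directory; paths are illustrative.
#
#   torch_save_to_dcp("model.pt", "model_dcp/")
#   # "model_dcp/" now holds DCP data files plus a metadata file, and can be
#   # read back with FileSystemReader / dcp.load.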


if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )

    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    else:
        raise ValueError(f"Unknown conversion mode: {args.mode}")
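
# Example invocations of the CLI above (the exact module path used with -m is
# an assumption and may differ depending on where this file lives):
#
#   python -m torch.distributed.checkpoint.format_utils torch_to_dcp model.pt model_dcp/
#   python -m torch.distributed.checkpoint.format_utils dcp_to_torch model_dcp/ model.pt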