- import abc
- import os
- from dataclasses import dataclass
- from typing import Any, List, Optional, Union
- from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
- from torch.distributed.checkpoint.planner import (
- LoadPlan,
- LoadPlanner,
- SavePlan,
- SavePlanner,
- )
- from torch.futures import Future
- __all__ = ["WriteResult", "StorageWriter", "StorageReader"]
- @dataclass(frozen=True)
- class WriteResult:
- index: MetadataIndex
- size_in_bytes: int
- storage_data: Any
- class StorageWriter(abc.ABC):
- """
- Interface used by ``save_state_dict`` to write to storage.
- One StorageWriter instance acts as both the coordinator and the follower
- in a distributed checkpoint. As part of initialization, each instance
- is told its role.
- A subclass should expect the following sequence of calls.
- 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
- 1) (all ranks) set_up_storage_writer()
- 2) (all ranks) prepare_local_plan()
- 3) (coordinator) prepare_global_plan()
- 4) (all ranks) write_data()
- 5) (coordinator) finish()
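- A rough, illustrative sketch of a subclass is given after this class definition.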
- """
- @abc.abstractmethod
- def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
- """
- Called to indicate that a brand new checkpoint write is going to happen.
- A checkpoint_id may be present if users set the checkpoint_id for
- this checkpoint write. The meaning of the checkpoint_id is
- storage-dependent. It can be a path to a folder/file or a key for
- a key-value storage.
- Args:
- checkpoint_id (Union[str, os.PathLike, None]):
- The ID of this checkpoint instance. The meaning of the checkpoint_id
- depends on the storage. It can be a path to a folder or to a file.
- It can also be a key if the storage is a key-value store.
- (Default: ``None``)
- """
- ...
- @abc.abstractmethod
- def set_up_storage_writer(self, is_coordinator: bool) -> None:
- """
- Initialize this instance.
- Args:
- is_coordinator (bool): Whether this instance is responsible for coordinating
- the checkpoint.
- """
- pass
- @abc.abstractmethod
- def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
- """
- Perform storage-specific local planning.
- While this method can produce a completely different plan, the recommended
- way is to store storage specific data in SavePlan::storage_data.
- Args:
- plan (SavePlan): The local plan from the ``SavePlanner`` in use.
- Returns:
- A transformed ``SavePlan`` after storage local planning
- """
- pass
- @abc.abstractmethod
- def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
- """
- Perform centralized planning of storage.
- This method is only called on the coordinator instance.
- While this method can produce a completely different plan, the preferred
- way is to store storage specific data in SavePlan::storage_data.
- Args:
- plans: A list of ``SavePlan`` instances, one for each rank.
- Returns:
- A list of transformed ``SavePlan`` after storage global planning
- """
- pass
- @abc.abstractmethod
- def write_data(
- self, plan: SavePlan, planner: SavePlanner
- ) -> Future[List[WriteResult]]:
- """
- Write all items from ``plan`` using ``planner`` to resolve the data.
- A subclass should call ``SavePlanner::resolve_data`` on each item
- from the plan to get access to the underlying object to write.
- Subclasses should lazily call `resolve_data` as it can allocate memory.
- In the case of tensors, make the following assumptions:
- They might be on any device, including a device that does not match the one described by ``WriteItem::tensor_data``
- - They might be views or not contiguous. Only the projection needs to be saved.
- Args:
- plan (SavePlan): The save plan to execute.
- planner (SavePlanner): Planner object to be used to resolve items to data.
- Returns:
- A future that completes to a list of WriteResult
- """
- pass
- @abc.abstractmethod
- def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
- """
- Write the metadata and mark the current checkpoint as successful.
- The actual format/schema used for serializing `metadata` is an
- implementation detail. The only requirement is that it's recoverable
- into the same object graph.
- Args:
- metadata (Metadata): metadata for the new checkpoint
- results: A list of WriteResults from all ranks.
- Returns:
- None
- """
- pass
- @classmethod
- @abc.abstractmethod
- def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
- """
- Check if the given checkpoint_id is supported by the storage. This allows
- us to enable automatic storage selection.
- """
- ...
- def storage_meta(self) -> Optional[StorageMeta]:
- """
- Return the storage-specific metadata. This is used to store additional information
- in a checkpoint that can be useful for providing request-level observability. StorageMeta
- is passed to the ``SavePlanner`` during save calls. Returns None by default.
- TODO: provide an example
- """
- return None
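- # What follows is a minimal, illustrative sketch of a ``StorageWriter`` subclass.
- # The name ``_InMemoryStorageWriter`` and its attributes are hypothetical and not
- # part of this module; error handling and resharding concerns are omitted.
- class _InMemoryStorageWriter(StorageWriter):
-     """Toy writer that keeps all written data in process memory."""
-     def __init__(self) -> None:
-         self.buffers: dict = {}
-         self.metadata: Optional[Metadata] = None
-         self.is_coordinator = False
-     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
-         # A new checkpoint write is starting; drop state from any previous write.
-         self.buffers.clear()
-         self.metadata = None
-     def set_up_storage_writer(self, is_coordinator: bool) -> None:
-         self.is_coordinator = is_coordinator
-     def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
-         # No storage-specific adjustments; return the local plan unchanged.
-         return plan
-     def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
-         # Called on the coordinator only; the per-rank plans pass through as-is.
-         return plans
-     def write_data(
-         self, plan: SavePlan, planner: SavePlanner
-     ) -> Future[List[WriteResult]]:
-         results = []
-         for item in plan.items:
-             # Resolve lazily, one item at a time, since resolution may allocate memory.
-             data = planner.resolve_data(item)
-             if hasattr(data, "detach"):
-                 # Tensors may be views, non-contiguous, or on another device;
-                 # store a contiguous CPU copy of the projection.
-                 data = data.detach().cpu().clone()
-                 size = data.numel() * data.element_size()
-             else:
-                 # Otherwise ``resolve_data`` returned a BytesIO-like object.
-                 size = data.getbuffer().nbytes
-             self.buffers[item.index] = data
-             results.append(
-                 WriteResult(index=item.index, size_in_bytes=size, storage_data=None)
-             )
-         fut: Future[List[WriteResult]] = Future()
-         fut.set_result(results)
-         return fut
-     def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
-         # Persisting the metadata is what marks the checkpoint as complete.
-         self.metadata = metadata
-     @classmethod
-     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
-         # This toy writer ignores checkpoint_id, so any value is acceptable.
-         return True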
- class StorageReader(abc.ABC):
- """
- Interface used by ``load_state_dict`` to read from storage.
- One StorageReader instance acts as both the coordinator and the follower
- in a distributed checkpoint. As part of initialization, each instance
- is told its role.
- A subclass should expect the following sequence of calls by ``load_state_dict``:
- 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
- 1) (all ranks) read_metadata()
- 2) (all ranks) set_up_storage_reader()
- 3) (all ranks) prepare_local_plan()
- 4) (coordinator) prepare_global_plan()
- 5) (all ranks) read_data()
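- A rough, illustrative sketch of a subclass is given at the end of this module.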
- """
- @abc.abstractmethod
- def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
- """
- Called to indicate that a brand new checkpoint read is going to happen.
- A checkpoint_id may be present if users set the checkpoint_id for
- this checkpoint read. The meaning of the checkpoint_id is
- storage-dependent. It can be a path to a folder/file or a key for
- a key-value storage.
- Args:
- checkpoint_id (Union[str, os.PathLike, None]):
- The ID of this checkpoint instance. The meaning of the checkpoint_id
- depends on the storage. It can be a path to a folder or to a file.
- It can also be a key if the storage is more like a key-value store.
- (Default: ``None``)
- """
- ...
- @abc.abstractmethod
- def read_metadata(self) -> Metadata:
- """
- Read the checkpoint metadata.
- Returns:
- The metadata object associated with the checkpoint being loaded.
- """
- pass
- @abc.abstractmethod
- def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
- """
- Initialize this instance.
- Args:
- metadata (Metadata): The metadata schema to use.
- is_coordinator (bool): Whether this instance is responsible for coordinating
- the checkpoint.
- """
- pass
- @abc.abstractmethod
- def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
- """
- Perform storage-specific local planning.
- While this method can produce a completely different plan, the recommended
- way is to store storage specific data in LoadPlan::storage_data.
- Args:
- plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.
- Returns:
- A transformed ``LoadPlan`` after storage local planning
- """
- pass
- @abc.abstractmethod
- def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
- """
- Perform centralized planning of storage loading.
- This method is only called on the coordinator instance.
- While this method can produce a completely different plan, the preferred
- way is to store storage specific data in LoadPlan::storage_data.
- Args:
- plans: A list of ``LoadPlan`` instances, one for each rank.
- Returns:
- A list of transformed ``LoadPlan`` after storage global planning
- """
- pass
- @abc.abstractmethod
- def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
- """
- Read all items from ``plan`` using ``planner`` to resolve the data.
- A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
- object into the right place.
- A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
- tensors that it should load data into.
- It is the storage layer's responsibility to properly schedule any cross-device
- copies that are required.
- Args:
- plan (LoadPlan): The local plan to execute on
- planner (LoadPlanner): The planner object to use to resolve items.
- Returns:
- A future that completes once all reads are finished.
- """
- pass
- @classmethod
- @abc.abstractmethod
- def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
- """
- Check if the given checkpoint_id is supported by the storage. This allows
- us to enable automatic storage selection.
- """
- ...
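- # A matching, equally illustrative sketch of a ``StorageReader`` subclass. The name
- # ``_InMemoryStorageReader`` is hypothetical; it assumes it is handed the ``buffers``
- # and ``metadata`` produced by the writer sketch above and performs no resharding.
- class _InMemoryStorageReader(StorageReader):
-     """Toy reader over the in-memory buffers written by the toy writer."""
-     def __init__(self, buffers: dict, metadata: Metadata) -> None:
-         self.buffers = buffers
-         self.metadata = metadata
-         self.is_coordinator = False
-     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
-         # Nothing cached to invalidate; a real reader might re-open the checkpoint here.
-         pass
-     def read_metadata(self) -> Metadata:
-         return self.metadata
-     def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
-         self.is_coordinator = is_coordinator
-     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
-         return plan
-     def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
-         return plans
-     def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
-         for item in plan.items:
-             stored = self.buffers[item.storage_index]
-             if hasattr(stored, "detach"):
-                 # Copy into the destination tensor the planner resolves for this item,
-                 # then hand it back so the planner can finalize it.
-                 tensor = planner.resolve_tensor(item)
-                 tensor.copy_(stored)
-                 planner.commit_tensor(item, tensor)
-             else:
-                 stored.seek(0)
-                 planner.load_bytes(item, stored)
-         fut: Future[None] = Future()
-         fut.set_result(None)
-         return fut
-     @classmethod
-     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
-         return True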
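- # How these interfaces are typically consumed (a sketch, assuming a single-process
- # run and a writable ``/tmp/ckpt`` directory). ``FileSystemWriter`` and
- # ``FileSystemReader`` are the built-in filesystem implementations of the interfaces
- # defined above; the example is kept as a comment to avoid a circular import of the
- # ``torch.distributed.checkpoint`` package from this module.
- #
- #     import torch
- #     import torch.distributed.checkpoint as dcp
- #     from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
- #
- #     state_dict = {"weight": torch.rand(4, 4)}
- #     dcp.save(state_dict, storage_writer=FileSystemWriter("/tmp/ckpt"))
- #     dcp.load(state_dict, storage_reader=FileSystemReader("/tmp/ckpt"))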