from concurrent.futures import Future
from typing import Any, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)

__all__: List[str] = []

class _Checkpointer:
    """This base class specifies a high level API for saving and loading
    distributed ``state_dict`` objects. It provides an abstraction over the low-level APIs
    provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
    :py:meth:`torch.distributed.checkpoint.state_dict_saver.save` and
    :py:meth:`torch.distributed.checkpoint.state_dict_loader.load` with the provided storage
    readers and writers.

    .. warning::
        This feature is experimental and subject to removal/change.
    """
    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
        """Initializes the Checkpointer instance.

        Args:
            storage_writer: Instance of StorageWriter used to perform writes.
            storage_reader: Instance of StorageReader used to load data from.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank used to coordinate the checkpoint. rank0 is used by default.
            no_dist: If ``True``, checkpointing will not be done in SPMD style (no cross-rank synchronization). (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner
    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`, utilizing the values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )
    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
        """
        Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.async_save`, utilizing the values passed during initialization.

        Returns:
            Future: A future holding the resultant Metadata object from ``save``.
        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )
    def load(self, state_dict: Dict[str, Any]) -> None:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_loader.load`, utilizing the values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )
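
if __name__ == "__main__":
    # Minimal single-process usage sketch (illustrative only). It assumes the
    # filesystem-backed storage classes FileSystemWriter/FileSystemReader from
    # torch.distributed.checkpoint; the checkpoint directory below is an
    # arbitrary example path, not something mandated by the API.
    import torch
    from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

    checkpoint_dir = "/tmp/checkpointer_example"  # hypothetical example path
    checkpointer = _Checkpointer(
        storage_writer=FileSystemWriter(checkpoint_dir),
        storage_reader=FileSystemReader(checkpoint_dir),
        no_dist=True,  # single process: skip cross-rank synchronization
    )

    # Save a state_dict; ``save`` returns the Metadata describing the checkpoint.
    metadata = checkpointer.save({"weight": torch.randn(4, 4)})
    print(sorted(metadata.state_dict_metadata.keys()))

    # Load in place into a pre-allocated state_dict with matching keys/shapes.
    # Without an initialized process group, the loader falls back to
    # single-process loading (it may emit a warning about that).
    target = {"weight": torch.empty(4, 4)}
    checkpointer.load(target)

    # In a distributed run, ``async_save`` instead returns a Future whose
    # ``result()`` yields the same Metadata object as ``save``.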