# _checkpointer.py

from concurrent.futures import Future
from typing import Any, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)
__all__: List[str] = []


class _Checkpointer:
  15. """This base class specefies a high level API for saving and loading
  16. distributed `state_dict` 's. It provides an abstraction over the low-level APIs
  17. provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
  18. :py:meth: `torch.distributed.state_dict_saver.save` and
  19. :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage
  20. readers and writers.
  21. .. warning::
  22. This feature is experimental and subject to removal/change.
  23. """
    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
  35. """Initializes the Checkpointer instance.
  36. Args:
  37. storage_writer: Instance of StorageWrite use to perform writes.
  38. storage_reader: StorageReader used to load data from.
  39. process_group: ProcessGroup to be used for cross-rank synchronization.
  40. coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
  41. no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
  42. loader_planner: Instance of LoadPlanner to use when loading.
  43. save_planner: Instance of SavePlanner to use when saving.
  44. """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`,
        utilizing the values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.async_save`,
        utilizing the values passed during initialization.

        Returns:
            Future: A future holding the resultant Metadata object from ``save``.
        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )
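
    # Illustrative note (not part of the original module): the Future returned
    # by ``async_save`` resolves to the Metadata of the finished write, so a
    # caller would typically hold onto it and block before depending on the
    # checkpoint, e.g.:
    #
    #     future = checkpointer.async_save(state_dict)
    #     ...                           # training continues while the write proceeds
    #     metadata = future.result()    # wait for the checkpoint to complete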

    def load(self, state_dict: Dict[str, Any]) -> None:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_loader.load`,
        utilizing the values passed during initialization.

        Note that loading happens in place: the passed-in ``state_dict`` is
        mutated rather than returned.
        """
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )
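

# A minimal usage sketch (illustrative, not part of the original module). It
# drives the save/load round trip through the filesystem storage layer shipped
# with torch.distributed.checkpoint. The checkpoint directory and tensor name
# are assumptions for the example, and ``no_dist=True`` keeps it single-process
# so no process group needs to be initialized.
if __name__ == "__main__":
    import torch
    from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

    checkpoint_dir = "/tmp/ckpt"  # hypothetical location
    checkpointer = _Checkpointer(
        storage_writer=FileSystemWriter(checkpoint_dir),
        storage_reader=FileSystemReader(checkpoint_dir),
        no_dist=True,
    )

    # save() returns the Metadata describing the checkpoint that was written.
    metadata = checkpointer.save({"weight": torch.rand(4, 4)})

    # load() restores in place into a pre-allocated state_dict.
    state_dict = {"weight": torch.empty(4, 4)}
    checkpointer.load(state_dict)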