storage.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import abc
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Any, List, Optional, Union
  5. from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
  6. from torch.distributed.checkpoint.planner import (
  7. LoadPlan,
  8. LoadPlanner,
  9. SavePlan,
  10. SavePlanner,
  11. )
  12. from torch.futures import Future
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
  14. @dataclass(frozen=True)
  15. class WriteResult:
  16. index: MetadataIndex
  17. size_in_bytes: int
  18. storage_data: Any
  19. class StorageWriter(abc.ABC):
  20. """
  21. Interface used by ``save_state_dict`` to write to storage.
  22. One StorageWriter instance acts as both the coordinator and the follower
  23. in a distributed checkpoint. As part of initialization, each instance
  24. is told its role.
  25. A subclass should expect the following sequence of calls.
  26. 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
  27. 1) (all ranks) set_up_storage_writer()
  28. 2) (all ranks) prepare_local_plan()
  29. 3) (coordinator) prepare_global_plan()
  30. 4) (all ranks) write_data()
  31. 5) (coordinator) finish()
  32. """
  33. @abc.abstractmethod
  34. def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
  35. """
  36. Calls to indicates a brand new checkpoint write is going to happen.
  37. A checkpoint_id may be present if users set the checkpoint_id for
  38. this checkpoint write. The meaning of the checkpiont_id is
  39. storage-dependent. It can be a path to a folder/file or a key for
  40. a key-value storage.
  41. Args:
  42. checkpoint_id (Union[str, os.PathLike, None]):
  43. The ID of this checkpoint instance. The meaning of the checkpoint_id
  44. depends on the storage. It can be a path to a folder or to a file.
  45. It can also be a key if the storage is a key-value store.
  46. (Default: ``None``)
  47. """
  48. ...
  49. @abc.abstractmethod
  50. def set_up_storage_writer(self, is_coordinator: bool) -> None:
  51. """
  52. Initialize this instance.
  53. Args:
  54. is_coordinator (bool): Whether this instance is responsible for coordinating
  55. the checkpoint.
  56. """
  57. pass
  58. @abc.abstractmethod
  59. def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
  60. """
  61. Perform storage-specific local planning.
  62. While this method can produce a completely different plan, the recommended
  63. way is to store storage specific data in SavePlan::storage_data.
  64. Args:
  65. plan (SavePlan): The local plan from the ``SavePlanner`` in use.
  66. Returns:
  67. A transformed ``SavePlan`` after storage local planning
  68. """
  69. pass
  70. @abc.abstractmethod
  71. def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
  72. """
  73. Perform centralized planning of storage.
  74. This method is only called on the coordinator instance.
  75. While this method can produce a completely different plan, the preferred
  76. way is to store storage specific data in SavePlan::storage_data.
  77. Args:
  78. plans: A list of ``SavePlan`` instances, one for each rank.
  79. Returns:
  80. A list of transformed ``SavePlan`` after storage global planning
  81. """
  82. pass
  83. @abc.abstractmethod
  84. def write_data(
  85. self, plan: SavePlan, planner: SavePlanner
  86. ) -> Future[List[WriteResult]]:
  87. """
  88. Write all items from ``plan`` using ``planner`` to resolve the data.
  89. A subclass should call ``SavePlanner::resolve_data`` on each item
  90. from the plan to get access to the underlying object to write.
  91. Subclasses should lazily call `resolve_data` as it can allocate memory.
  92. In case of tensors, make following assumptions:
  93. - They might be on any device, including not matching the one on ``WriteItem::tensor_data``
  94. - They might be views or not contiguous. Only the projection needs to be saved.
  95. Args:
  96. plan (SavePlan): The save plan to execute.
  97. planner (SavePlanner): Planner object to be used to resolve items to data.
  98. Returns:
  99. A future that completes to a list of WriteResult
  100. """
  101. pass
  102. @abc.abstractmethod
  103. def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
  104. """
  105. Write the metadata and marks the current checkpoint as successful.
  106. The actual format/schema used for serializing `metadata` is an
  107. implementation detail. The only requirement is that it's recoverable
  108. in to the same object graph.
  109. Args:
  110. metadata (Metadata): metadata for the new checkpoint
  111. results: A list of WriteResults from all ranks.
  112. Returns:
  113. None
  114. """
  115. pass
  116. @classmethod
  117. @abc.abstractmethod
  118. def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
  119. """
  120. Check if the given checkpoint_id is supported by the stroage. This allow
  121. us to enable automatic storage selection.
  122. """
  123. ...
  124. def storage_meta(self) -> Optional[StorageMeta]:
  125. """
  126. Return the storage-specific metadata. This is used to store additional information
  127. in a checkpoint that can be useful for providing request-level observability. StorageMeta
  128. is passed to the ``SavePlanner`` during save calls. Returns None by default.
  129. TODO: provide an example
  130. """
  131. return None
  132. class StorageReader(abc.ABC):
  133. """
  134. Interface used by ``load_state_dict`` to read from storage.
  135. One StorageReader instance acts as both the coordinator and the follower
  136. in a distributed checkpoint. As part of initialization, each instance
  137. is told its role.
  138. A subclass should expected the following sequence of calls by ``load_state_dict``:
  139. 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
  140. 1) (all ranks) read_metadata()
  141. 2) (all ranks) set_up_storage_reader()
  142. 3) (all ranks) prepare_local_plan()
  143. 4) (coordinator) prepare_global_plan()
  144. 5) (all ranks) read_data()
  145. """
  146. @abc.abstractmethod
  147. def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
  148. """
  149. Calls to indicates a brand new checkpoint read is going to happen.
  150. A checkpoint_id may be present if users set the checkpoint_id for
  151. this checkpoint read. The meaning of the checkpiont_id is
  152. storage-dependent. It can be a path to a folder/file or a key for
  153. a key-value storage.
  154. Args:
  155. checkpoint_id (Union[str, os.PathLike, None]):
  156. The ID of this checkpoint instance. The meaning of the checkpoint_id
  157. depends on the storage. It can be a path to a folder or to a file.
  158. It can also be a key if the storage is more like a key-value store.
  159. (Default: ``None``)
  160. """
  161. ...
  162. @abc.abstractmethod
  163. def read_metadata(self) -> Metadata:
  164. """
  165. Read the checkpoint metadata.
  166. Returns:
  167. The metadata object associated with the checkpoint being loaded.
  168. """
  169. pass
  170. @abc.abstractmethod
  171. def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
  172. """
  173. Initialize this instance.
  174. Args:
  175. metadata (Metadata): The metadata schema to use.
  176. is_coordinator (bool): Whether this instance is responsible for coordinating
  177. the checkpoint.
  178. """
  179. pass
  180. @abc.abstractmethod
  181. def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
  182. """
  183. Perform storage-specific local planning.
  184. While this method can produce a completely different plan, the recommended
  185. way is to store storage specific data in LoadPlan::storage_data.
  186. Args:
  187. plan (LoadPlan): The local plan from the ``LoadPlan`` in use.
  188. Returns:
  189. A transformed ``LoadPlan`` after storage local planning
  190. """
  191. pass
  192. @abc.abstractmethod
  193. def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
  194. """
  195. Perform centralized planning of storage loading.
  196. This method is only called on the coordinator instance.
  197. While this method can produce a completely different plan, the preferred
  198. way is to store storage specific data in LoadPlan::storage_data.
  199. Args:
  200. plans: A list of ``LoadPlan`` instances, one for each rank.
  201. Returns:
  202. A list of transformed ``LoadPlan`` after storage global planning
  203. """
  204. pass
  205. @abc.abstractmethod
  206. def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
  207. """
  208. Read all items from ``plan`` using ``planner`` to resolve the data.
  209. A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
  210. object into the right place.
  211. A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
  212. tensors that in should load data into.
  213. It's the StorageLayer responsibility to properly schedule any cross device copies
  214. required.
  215. Args:
  216. plan (LoadPlan): The local plan to execute on
  217. planner (LoadPlanner): The planner object to use to resolve items.
  218. Returns:
  219. A future that completes once all reads are finished.
  220. """
  221. pass
  222. @classmethod
  223. @abc.abstractmethod
  224. def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
  225. """
  226. Check if the given checkpoint_id is supported by the stroage. This allow
  227. us to enable automatic storage selection.
  228. """
  229. ...