import abc
import io
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
    TensorProperties,
)


__all__ = [
    "WriteItemType",
    "LoadItemType",
    "TensorWriteData",
    "WriteItem",
    "ReadItem",
    "SavePlan",
    "LoadPlan",
    "SavePlanner",
    "LoadPlanner",
]


class WriteItemType(Enum):
    TENSOR = auto()
    SHARD = auto()
    BYTE_IO = auto()


class LoadItemType(Enum):
    TENSOR = auto()
    BYTE_IO = auto()


@dataclass(frozen=True)
class TensorWriteData:
    chunk: ChunkStorageMetadata
    properties: TensorProperties
    size: torch.Size


@dataclass(frozen=True)
class WriteItem:
    """Dataclass which holds information about what needs to be written to storage."""

    index: MetadataIndex
    type: WriteItemType

    # Value present if it's a tensor write
    tensor_data: Optional[TensorWriteData] = None

    def tensor_storage_size(self) -> Optional[int]:
        """
        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.

        Returns:
            Optional[int] storage size, in bytes of underlying tensor if any.
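
        Example (illustrative only; the tensor shape, dtype, and FQN below are hypothetical):

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # A 4 x 8 float32 tensor: 32 elements * 4 bytes each = 128 bytes.
        >>> item = WriteItem(
        >>>     index=MetadataIndex("model.weight"),
        >>>     type=WriteItemType.TENSOR,
        >>>     tensor_data=TensorWriteData(
        >>>         chunk=ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 8])),
        >>>         properties=TensorProperties(dtype=torch.float32),
        >>>         size=torch.Size([4, 8]),
        >>>     ),
        >>> )
        >>> item.tensor_storage_size()
        128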
  52. """
  53. if self.tensor_data is None:
  54. return None
  55. numels = reduce(operator.mul, self.tensor_data.size, 1)
  56. dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
  57. return numels * dtype_size


@dataclass(frozen=True)
class ReadItem:
    # Read Item
    type: LoadItemType

    # Index into the state_dict
    dest_index: MetadataIndex
    # Offsets into destination tensor
    dest_offsets: torch.Size

    # Index into the checkpoint
    storage_index: MetadataIndex
    # Offset into the checkpoint data
    storage_offsets: torch.Size

    # Size of the hypercube to copy
    lengths: torch.Size
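
# Illustrative example (the FQN, offsets, and lengths below are hypothetical): a ReadItem
# that copies a 2 x 4 block out of the checkpoint chunk identified by ``storage_index``,
# starting at offset (0, 4) within that chunk, into the destination tensor identified by
# ``dest_index`` starting at offset (2, 0):
#
#     ReadItem(
#         type=LoadItemType.TENSOR,
#         dest_index=MetadataIndex("model.weight"),
#         dest_offsets=torch.Size([2, 0]),
#         storage_index=MetadataIndex("model.weight"),
#         storage_offsets=torch.Size([0, 4]),
#         lengths=torch.Size([2, 4]),
#     )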


@dataclass(frozen=True)
class SavePlan:
    items: List[WriteItem]
    storage_data: Any = None
    planner_data: Any = None


@dataclass
class LoadPlan:
    items: List[ReadItem]
    storage_data: Any = None
    planner_data: Any = None


class SavePlanner(abc.ABC):
    """
    Abstract class defining the protocol used by save_state_dict to plan the save process.

    SavePlanners are stateful objects that can be used to customize the whole save process.

    SavePlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during save_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of a checkpoint save.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `SavePlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the SavePlan from all ranks and makes any global decisions.

    4) finish_plan - called on all ranks.
        This gives each rank a chance to adjust to global planning decisions.

    5) resolve_data - called multiple times on each rank.
        Looks up a value in the `state_dict` for the storage layer to write.

    Users are recommended to extend DefaultSavePlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are 3 usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the save process as it
    doesn't require understanding the intricacies of how SavePlan works:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultSavePlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         storage_meta: Optional[StorageMeta],
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         # prefix all keys with `foo_`
    >>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

    Modifying the local plan and the lookup in tandem. This is useful when fine control over
    how data is persisted is needed:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class FP16Planner(DefaultSavePlanner):
    >>>     def create_local_plan(self):
    >>>         plan = super().create_local_plan()
    >>>         for p in plan.items:
    >>>             if p.tensor_data is not None:
    >>>                 p.tensor_data.properties.dtype = torch.float16
    >>>         return plan
    >>>
    >>>     def resolve_data(self, write_item):
    >>>         item = super().resolve_data(write_item)
    >>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

    Using the global planning step to make central decisions that can't be made individually by each rank:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> from itertools import islice
    >>> from dataclasses import replace
    >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
    >>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
    >>>     # This sample doesn't handle ShardedTensors
    >>>     def create_global_plan(self, all_plans):
    >>>         def chunk(it, size):
    >>>             it = iter(it)
    >>>             return list(iter(lambda: tuple(islice(it, size)), ()))
    >>>         all_plans = [
    >>>             replace(plan, items=items) for plan, items in
    >>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
    >>>         ]
    >>>         return super().create_global_plan(all_plans)

    Finally, some planners need to save additional metadata in the checkpoint; this is
    accomplished by having each rank contribute their data items in the local plan and
    the global planner aggregate them:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class SaveExtraDataPlanner(DefaultSavePlanner):
    >>>     def create_local_plan(self) -> SavePlan:
    >>>         plan = super().create_local_plan()
    >>>         return replace(plan, planner_data="per-rank-data")
    >>>
    >>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
    >>>         global_plan, metadata = super().create_global_plan(all_plans)
    >>>         merged_data = [p.planner_data for p in global_plan]
    >>>         metadata = replace(metadata, planner_data=merged_data)
    >>>         return global_plan, metadata
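
    The following is a rough, simplified sketch of how a checkpoint driver might exercise
    this protocol end to end. It is illustrative only; ``gather_to_coordinator`` and
    ``scatter_from_coordinator`` stand in for whatever collectives the real driver uses:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> planner.set_up_planner(state_dict, storage_meta, is_coordinator)
    >>> local_plan = planner.create_local_plan()
    >>> all_plans = gather_to_coordinator(local_plan)
    >>> if is_coordinator:
    >>>     all_plans, metadata = planner.create_global_plan(all_plans)
    >>> final_plan = planner.finish_plan(scatter_from_coordinator(all_plans))
    >>> for write_item in final_plan.items:
    >>>     data = planner.resolve_data(write_item)  # tensor or io.BytesIO handed to storage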
  157. """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this planner to save ``state_dict``.

        Implementations should save those values as they won't be provided later in the save process.

        This is called on all ranks.
        """
        pass

    @abc.abstractmethod
    def create_local_plan(self) -> SavePlan:
        """
        Compute the save plan for the current rank.

        This will be aggregated and passed to create_global_plan.
        Planner-specific data can be passed through SavePlan::planner_data.

        This is called on all ranks.
        """
        pass

    @abc.abstractmethod
    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        """
        Compute the global checkpoint plan and return the local plan of each rank.

        This is called on the coordinator rank only.
        """
        pass

    @abc.abstractmethod
    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """
        Merge the plan created by `create_local_plan` and the result of `create_global_plan`.

        This is called on all ranks.
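
        A minimal pass-through sketch (assuming the globally adjusted plan can be used
        as-is on this rank; the subclass shown is hypothetical):

        >>> # xdoctest: +SKIP("undefined vars")
        >>> class PassthroughPlanner(DefaultSavePlanner):
        >>>     def finish_plan(self, new_plan):
        >>>         return new_plan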
  194. """
  195. pass
  196. @abc.abstractmethod
  197. def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
  198. """
  199. Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
  200. Lookup the object associated with ``write_item`` in ``state_dict`` and apply any
  201. transformation (such as serialization) prior to the storage layer consuming it.
  202. Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
  203. This method should be idempotent and thread-save. StorageWriter implementations
  204. are free to call it as frequently as they need.
  205. Any transformation that allocates memory should be lazily done when his method
  206. is called in order to reduce peak memory required by checkpointing.
  207. When returning tensors, they can be on any device or format, they can be views too.
  208. It's the storage layer responsibility to figure out how to save them.
  209. """
  210. pass


class LoadPlanner:
    """
    Abstract class defining the protocol used by load_state_dict to plan the load process.

    LoadPlanners are stateful objects that can be used to customize the whole load process.

    LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during load_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of loading a checkpoint.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `LoadPlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the LoadPlan from all ranks and makes any global decisions.

    4) load_bytes - called multiple times on each rank.
        This is called once per non-tensor value in state_dict.

    5) resolve_tensor and commit_tensor - called multiple times on each rank.
        They are called in pairs for each Tensor value in state_dict.

    Users are recommended to extend DefaultLoadPlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are two usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the load process as it
    doesn't require understanding the intricacies of how LoadPlan works. We need to keep
    a reference to the original state_dict, because loading happens in place, so we need
    to be able to write the results back into it:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultLoadPlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         metadata: Metadata,
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         self.original_state_dict = state_dict
    >>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
    >>>
    >>>         if self.flatten_sharded_tensors:
    >>>             state_dict = _flatten_sharded_tensors(state_dict)
    >>>
    >>>         if self.flatten_state_dict:
    >>>             state_dict, self.mappings = flatten_state_dict(state_dict)
    >>>
    >>>         self.state_dict = state_dict
    >>>         self.metadata = metadata
    >>>         self.is_coordinator = is_coordinator
    >>>
    >>>     def load_bytes(self, read_item, value):
    >>>         # Remove the "foo_" prefix
    >>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

    Modifying resolve_tensor and commit_tensor to handle load-time transformations:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class MetaModelMaterialize(DefaultLoadPlanner):
    >>>     def resolve_tensor(self, read_item):
    >>>         tensor = super().resolve_tensor(read_item)
    >>>         return torch.empty_like(tensor, device="cpu")
    >>>
    >>>     def commit_tensor(self, read_item, tensor):
    >>>         self.state_dict[read_item.dest_index.fqn] = tensor
    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this instance to load data into ``state_dict``.

        N.B. This is called on every rank.
        """
        pass

    @abc.abstractmethod
    def create_local_plan(self) -> LoadPlan:
        """
        Create a LoadPlan based on state_dict and metadata provided by set_up_planner.

        N.B. This is called on every rank.
        """
        pass

    @abc.abstractmethod
    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """
        Compute the global load plan and return plans for each rank.

        N.B. This is called on the coordinator rank only.
        """
        pass

    @abc.abstractmethod
    def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
        """Accept the plan from coordinator and return final LoadPlan."""
        pass

    @abc.abstractmethod
    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """
        Load the item described by ``read_item`` and ``value``.

        This method is expected to modify the underlying state_dict in place.

        The contents of ``value`` are defined by the SavePlanner used to produce
        the checkpoint being loaded.
        """
        pass

    def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO:
        """
        Return the BytesIO to be used by the StorageReader to load `read_item`.

        The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents.
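
        A minimal sketch (assuming the planner's ``state_dict`` already stores an
        ``io.BytesIO`` under the FQN referenced by ``read_item``; the subclass shown
        is hypothetical):

        >>> # xdoctest: +SKIP("undefined vars")
        >>> class AliasingBytesPlanner(DefaultLoadPlanner):
        >>>     def resolve_bytes(self, read_item):
        >>>         return self.state_dict[read_item.dest_index.fqn]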
  312. """
  313. raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented")
  314. @abc.abstractmethod
  315. def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
  316. """
  317. Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`.
  318. The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents.
  319. If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
  320. back to the one in state_dict.
  321. """
  322. pass
  323. @abc.abstractmethod
  324. def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
  325. """
  326. Call once the StorageReader finished loading data into ``tensor``.
  327. The provided tensor is the same one returned by the call to ``resolve_tensor``.
  328. This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to
  329. copying it back to the one in the state_dict.
  330. The contents of tensor will follow its device synchronization model.
  331. """
  332. pass


class _Checkpointable:
    """
    Interface for checkpointable objects.

    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
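
    A rough sketch (illustrative only; the wrapper class and its single-chunk layout are
    hypothetical) of a tensor wrapper implementing this interface for one unsharded tensor:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class _WholeTensorCheckpointable(_Checkpointable):
    >>>     def _create_write_items(self, fqn, object):
    >>>         chunk = self._create_chunk_list(object)[0]
    >>>         return [
    >>>             WriteItem(
    >>>                 index=MetadataIndex(fqn, chunk.offsets),
    >>>                 type=WriteItemType.TENSOR,
    >>>                 tensor_data=TensorWriteData(
    >>>                     chunk=chunk,
    >>>                     properties=TensorProperties(dtype=object.dtype),
    >>>                     size=object.size(),
    >>>                 ),
    >>>             )
    >>>         ]
    >>>
    >>>     def _create_chunk_list(self, tensor):
    >>>         # One chunk covering the whole tensor, starting at offset zero.
    >>>         return [ChunkStorageMetadata(offsets=torch.Size([0] * tensor.dim()), sizes=tensor.size())]
    >>>
    >>>     def _get_tensor_shard(self, tensor, index):
    >>>         # Single chunk, so the shard is the whole tensor.
    >>>         return tensor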
  337. """
  338. @abc.abstractmethod
  339. def _create_write_items(self, fqn: str, object: Any) -> List[WriteItem]:
  340. """
  341. Return a list of WriteItems based on object's contents.
  342. """
  343. raise NotImplementedError(
  344. "_Checkpointable._create_write_items is not implemented"
  345. )
  346. @abc.abstractmethod
  347. def _create_chunk_list(self, tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
  348. """
  349. Return a list of `ChunkStorageMetadata` based on object's contents.
  350. """
  351. raise NotImplementedError(
  352. "_Checkpointable._create_chunk_list is not implemented"
  353. )
  354. @abc.abstractmethod
  355. def _get_tensor_shard(
  356. self, tensor: torch.Tensor, index: MetadataIndex
  357. ) -> torch.Tensor:
  358. """
  359. Return a 'torch.Tensor' shard based on 'MetadataIndex'.
  360. """
  361. raise NotImplementedError(
  362. "_Checkpointable._get_tensor_shard is not implemented"
  363. )