# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
    _EmptyStateDictLoadPlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future

__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire checkpoint
    on the coordinator rank, and then broadcasts and shards each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        The current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file"""
        # Metadata is built in planner.set_up_planner, since we are not actually
        # reading metadata from the disk
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank, and broadcasts it afterwards.
        This incurs a communication cost, but avoids having to load
        the entire checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # data is read in on the coordinator rank, and broadcast afterwards
        # this incurs a communication cost, but it avoids having to load
        # the entire checkpoint on each rank, hopefully preventing OOM issues
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(
                self.checkpoint_id, map_location="cpu", weights_only=False
            )
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                tensor = torch_state_dict[req.storage_index.fqn].cuda()
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatch sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)
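
# Note: read_data() moves the coordinator's copy of each tensor to CUDA before
# calling dist.broadcast, so the usage shown in the class docstring assumes an
# initialized process group whose backend supports CUDA tensors (typically
# NCCL) and a GPU visible to every rank. A minimal setup sketch (illustrative,
# run under torchrun or an equivalent launcher):
#
#   dist.init_process_group("nccl")
#   torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())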


class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed-in state dict,
    avoiding the need to read metadata from disk. This is useful when reading formats which don't have a
    metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        The current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Sets up the planner, extending default behavior by creating the Metadata object from the state dict"""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
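
# Illustrative sketch of the metadata this planner synthesizes (hypothetical
# state dict; the chunk layout comes from _create_chunk_list and may differ):
#
#   for {"w": torch.zeros(4, 4)} the planner records roughly
#   Metadata(state_dict_metadata={
#       "w": TensorStorageMetadata(
#           TensorProperties(dtype=torch.float32),
#           torch.Size([4, 4]),
#           _create_chunk_list(torch.zeros(4, 4)),
#       )
#   })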


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function converts it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    _load_state_dict(
        sd,
        storage_reader=FileSystemReader(dcp_checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
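
# Example usage (hypothetical paths; the source directory must contain a DCP
# checkpoint written by torch.distributed.checkpoint):
#
#   dcp_to_torch_save("checkpoint/", "model.pt")
#   sd = torch.load("model.pt", weights_only=False)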


def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    state_dict = torch.load(torch_save_path, weights_only=False)
    # we don't need stateful behavior here because the expectation is that anything
    # loaded by torch.load would not contain stateful objects.
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )
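
# Example usage (hypothetical paths; assumes a constructed `model`):
#
#   torch.save(model.state_dict(), "model.pt")
#   torch_save_to_dcp("model.pt", "checkpoint/")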


if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )

    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    else:
        raise ValueError(f"Unknown conversion mode: {args.mode}")
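
# Example invocations of the CLI above (illustrative paths):
#
#   python format_utils.py torch_to_dcp model.pt checkpoint/
#   python format_utils.py dcp_to_torch checkpoint/ model.pt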