state_dict_loader.py

# mypy: allow-untyped-defs
import os
import warnings
from typing import Any, cast, Dict, Optional, Set, Union

from typing_extensions import deprecated

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.stateful import Stateful

from ._storage_utils import _storage_setup
from .default_planner import DefaultLoadPlanner
from .planner import LoadPlan, LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile


__all__ = ["load_state_dict", "load"]


@deprecated(
    "`load_state_dict` is deprecated and will be removed in future versions. "
    "Please use `load` instead.",
    category=FutureWarning,
)
def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    """This method is deprecated. Please switch to 'load'."""
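    # Migration sketch (illustrative, not part of the public docs): a call such as
    #   load_state_dict(state_dict, storage_reader=reader, planner=planner)
    # can generally be replaced with
    #   load(state_dict, storage_reader=reader, planner=planner)
    # `load` infers single-process mode from the process-group state and always
    # treats rank 0 as the coordinator, so `no_dist` and `coordinator_rank` are
    # no longer passed explicitly.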
    storage_reader.reset()
    with _profile():
        # TODO: test returning `load` here instead.
        return _load_state_dict(
            state_dict,
            storage_reader,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )


@_dcp_method_logger(log_exceptions=True)
@_api_bc_check
def load(
    state_dict: Dict[str, Any],
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> None:
    """
    Load a distributed ``state_dict`` in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested `state_dict`. When loading :class:`ShardedTensor`
    or :class:`DTensor` instances, each rank only reads data for their local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    load will first call ``state_dict`` before attempting deserialization, followed by
    ``load_state_dict`` once the deserialization is complete.

    .. warning::
        All tensors in ``state_dict`` must be allocated on their
        destination device *prior to* calling this function.

        All non-tensor data is loaded using `torch.load()` and modified in place
        on state_dict.

    .. warning::
        Users must call `load_state_dict` on the root module to ensure load
        post-processing and non-tensor data properly propagate.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to DTensor
        or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[LoadPlanner]):
            Instance of LoadPlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        None.

    Examples
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
        >>> torch.distributed.checkpoint.load(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_reader,
        >>> )
        >>> # module.load_state_dict() might have customized steps to flush the
        >>> # state_dict; call it to ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note::
        ``load`` uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that each
        rank has an individual GPU, via ``torch.cuda.set_device()``.
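
        A minimal sketch of that setup (assuming one GPU per local rank, an
        already-initialized NCCL process group, and the ``model_state_dict`` /
        ``fs_storage_reader`` objects from the example above):

        >>> # xdoctest: +SKIP
        >>> torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        >>> torch.distributed.checkpoint.load(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_reader,
        >>> )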
    """
    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    with _profile():
        storage_reader = cast(
            StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
        )

        if no_dist:
            keys = list(state_dict.keys())
        else:
            keys = _all_gather_keys(state_dict, process_group)
            if keys != sorted(state_dict.keys()):
                warnings.warn(
                    "Detected mismatched keys in state dict after all gather!"
                    " This behavior is unsupported and may cause errors."
                )

        # Snapshot Stateful objects via their own state_dict() so DCP reads into
        # plain containers; everything else is passed through unchanged.
        stateful_sd = {}
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            stateful_sd[key] = (
                elem.state_dict() if isinstance(elem, Stateful) else elem
            )

        _load_state_dict(
            state_dict=stateful_sd,
            storage_reader=storage_reader,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )

        # Hand the loaded values back: Stateful objects restore themselves via
        # load_state_dict(), and the caller's state_dict is updated in place.
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            if isinstance(elem, Stateful):
                elem.load_state_dict(stateful_sd[key])
            state_dict[key] = stateful_sd[key]


def _load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id

    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        # Each rank builds a plan for the reads it needs from the checkpoint.
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        # The coordinator merges all local plans into a single global plan.
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)

    @_dcp_method_logger(**ckpt_kwargs)
    def read_data():
        # Execute this rank's portion of the central plan and wait for all reads.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)


def _load_state_dict_from_keys(
    keys: Optional[Union[Set[str], str]] = None,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Load only the specified keys from the checkpoint; if no keys are specified, the
    entire checkpoint is loaded. Note that this method completely loads the checkpoint
    into the current process and is not distributed.

    .. warning::
        All non-tensor data is loaded using `torch.load()`.

    .. note::
        As opposed to the usual pattern, this function does not take a state dict as input
        and does not load in place. Instead, a new state dict is directly initialized and read
        from file.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to DTensor
        or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        keys (Optional[Union[Set[str], str]]):
            Loads any key specified in this set. If no keys are specified, the entire checkpoint
            is loaded.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        State dict from specified keys
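
    Example (illustrative sketch; assumes a filesystem checkpoint under
    ``/checkpoint/1`` that contains an ``"optimizer"`` entry):

        >>> # xdoctest: +SKIP
        >>> optim_sd = _load_state_dict_from_keys(
        >>>     keys={"optimizer"},
        >>>     checkpoint_id="/checkpoint/1",
        >>> )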
    """
    torch._C._log_api_usage_once(
        "torch.distributed.checkpoint._load_state_dict_from_keys"
    )

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    storage_reader = cast(
        StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
    )

    if isinstance(keys, str):
        keys = {keys}

    sd: Dict[str, Any] = {}
    _load_state_dict(
        state_dict=sd,
        storage_reader=storage_reader,
        process_group=process_group,
        no_dist=no_dist,
        planner=_EmptyStateDictLoadPlanner(keys=keys or set()),
    )

    return sd