# mypy: allow-untyped-defs
import contextlib
import functools
import gc
import warnings
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import (
    _broadcast_state_dict,
    _flatten_state_dict,
    _gather_state_dict,
    _offload_state_dict_to_cpu,
    _unflatten_state_dict,
)
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP_WRAPPED_MODULE,
)
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils._pytree import tree_map_only


__all__ = [
    "FQNS_T",
    "PrimitiveType",
    "ValueType",
    "DictValueType",
    "ListDictValueType",
    "OptimizerStateType",
    "StateDictOptions",
    "get_model_state_dict",
    "get_optimizer_state_dict",
    "get_state_dict",
    "set_model_state_dict",
    "set_optimizer_state_dict",
    "set_state_dict",
]


_FLAT_PARAM = "_flat_param"
_PG = "param_groups"
_PARAMS = "params"
_STATE = "state"

FQNS_T = Set[str]
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]
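
# Illustrative shape of ``OptimizerStateType`` (a sketch for orientation only; the
# exact per-parameter state keys such as "step"/"exp_avg" depend on the optimizer):
#
#     {
#         "state": {
#             "layer1.weight": {"step": 10, "exp_avg": Tensor, ...},
#             ...
#         },
#         "param_groups": [
#             {"lr": 1e-3, ..., "params": ["layer1.weight", ...]},
#         ],
#     }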


_patched_state_dict: Set[Callable] = set()


@contextlib.contextmanager
def _gc_context():
    is_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if is_enabled:
            gc.enable()


@dataclass
class StateDictOptions:
    """
    This dataclass specifies how get_state_dict/set_state_dict will work.

    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor and DTensor
      will be in the returned state_dict.

    - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
      ``full_state_dict`` is also true, then only rank0 will get the
      state_dict and all other ranks will get empty state_dict.

    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters -- those whose ``requires_grad`` is False.
      The default value is False.

    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option
      indicates whether to keep the submodule prefixes in the state_dict keys.
      For example, if the submodule is ``module.pretrain`` and the full FQN of
      the parameter is ``pretrain.layer1.weight``. When this option
      is True, the parameter's key in the returned state_dict will be
      ``pretrain.layer1.weight``. If the option is False, the key will be
      ``layer1.weight``.
      Note that if ``keep_submodule_prefixes`` is False, there may be conflicting
      FQNs, hence there should be only one submodule in ``submodules``.

    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict().

    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a
      full state_dict and will broadcast the tensors in the state_dict/
      optim_state_dict one by one to other ranks. Other ranks will receive
      the tensors and shard them according to the local shards in the model and
      optimizer. ``full_state_dict`` must be set to True when using this option.
      This option currently only supports DTensor, not the legacy ShardedTensor.
    """

    full_state_dict: bool = False
    cpu_offload: bool = False
    ignore_frozen_params: bool = False
    keep_submodule_prefixes: bool = True
    strict: bool = True
    broadcast_from_rank0: bool = False
    flatten_optimizer_state_dict: bool = False
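
# Illustrative usage of ``StateDictOptions`` (a sketch, not executed here): gathering
# a full, CPU-offloaded state_dict, e.g. for saving a non-sharded checkpoint from
# rank0, typically looks like
#
#     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
#     model_sd = get_model_state_dict(model, options=options)
#     optim_sd = get_optimizer_state_dict(model, optim, options=options)
#
# With ``cpu_offload=True`` and ``full_state_dict=True``, only rank0 receives the
# gathered tensors; other ranks get empty dicts (see the docstring above).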


@dataclass
class _StateDictInfo(StateDictOptions):
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    submodule_prefixes: Set[str] = field(default_factory=set)
    handle_model: bool = True
    handle_optim: bool = True
    fsdp_context: Callable = contextlib.nullcontext
    fsdp_modules: List[nn.Module] = field(default_factory=list)


@functools.lru_cache(maxsize=None)
def _get_fqns(
    model: nn.Module,
    name: str,
    skip_ddp_prefix: bool = True,
    skip_compiler_prefix: bool = True,
) -> FQNS_T:
    """
    This API is used to convert the name of a parameter to the FQNs. For FSDP
    without `use_orig_params`, the name of FlatParameter can be mapped to
    multiple original parameters. As a result, the return type of this function
    is `Set[str]`.

    Args:
        model (nn.Module): the root model.
        name (str): the name of the parameter or buffer.
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
        skip_compiler_prefix (bool): whether to skip the compiler's `_orig_mod` prefix.

    Returns:
        The canonical FQNs based on the model traversal.
    """
    # Remove the checkpoint prefix, if it exists.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    if "." not in name:
        return {name}

    obj_names = name.split(".")
    fqn_obj_names = []
    curr_obj = model
    for i, curr_obj_name in enumerate(obj_names):
        if isinstance(curr_obj, DDP):
            assert curr_obj_name == "module"
            curr_obj = curr_obj.module
            if not skip_ddp_prefix:
                fqn_obj_names.append(curr_obj_name)
        elif isinstance(curr_obj, FSDP):
            if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:
                prefix = ".".join(fqn_obj_names)
                flat_param = getattr(curr_obj, _FLAT_PARAM)
                if prefix:
                    prefix = f"{prefix}."
                return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
            if curr_obj_name != FSDP_WRAPPED_MODULE:
                fqn_obj_names.append(curr_obj_name)
                curr_obj = getattr(curr_obj, curr_obj_name)
        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
            assert curr_obj_name == "_orig_mod"
            curr_obj = curr_obj._orig_mod
            if not skip_compiler_prefix:
                fqn_obj_names.append(curr_obj_name)
        else:
            fqn_obj_names.append(curr_obj_name)
            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                if i != len(obj_names) - 1:
                    raise RuntimeError("Expect `_extra_state` to be the last obj name")
            else:
                curr_obj = getattr(curr_obj, curr_obj_name)

    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}
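
# Illustrative mapping performed by ``_get_fqns`` (assuming a DDP-wrapped and
# torch.compile'd model; the exact prefixes that get stripped depend on the
# wrappers that are actually applied):
#
#     _get_fqns(model, "module._orig_mod.layer1.weight")
#     # -> {"layer1.weight"}
#
# For FSDP without ``use_orig_params``, a single flat-parameter name can map to
# multiple original FQNs, which is why the return type is a set.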


class _EXTRA_STATE:
    pass


def _iterate_valid_model_state(model):
    visited_modules: Set[nn.Module] = set()

    def recurse(module: nn.Module, curr_fqn: str) -> Generator:
        visited_modules.add(module)

        curr_fqn = f"{curr_fqn}." if curr_fqn else ""
        for name, submodule in module.named_children():
            if submodule in visited_modules:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield from recurse(submodule, new_fqn)

        for name, obj in chain(
            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
        ):
            if name in module._non_persistent_buffers_set:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield new_fqn, obj

        if (
            getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
            != nn.Module.get_extra_state
        ):
            new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
            yield new_fqn, _EXTRA_STATE()

    yield from recurse(model, "")


def _verify_options(
    model: nn.Module,
    optims: Tuple[torch.optim.Optimizer, ...],
    optim_only: bool,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
    """
    Verify the model and options passed by the user and generate the
    corresponding _StateDictInfo.
    """
    if submodules:
        warnings.warn(
            "Getting submodules only model/optim state_dict is deprecated and "
            "will be removed in 2.5. This feature can be achieved by manually "
            "filtering out the state_dict returned from get_state_dict.",
            FutureWarning,
        )
    if optim_only and not optims:
        raise RuntimeError(
            "Optimizers are not passed in but optim_only is set to True."
        )

    options = options or StateDictOptions()

    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    for name, param in _iterate_valid_model_state(model):
        if isinstance(param, _EXTRA_STATE):
            continue

        fqns = _get_fqns(model, name)
        fqn = fqn_param_mapping.get(param, None)
        if fqn is not None:
            cast(Set[str], fqn_param_mapping[param]).update(fqns)
            shared_params_mapping[param] = fqn_param_mapping[param]
        else:
            # We need to do copy as _get_fqns is lru_cached
            fqn_param_mapping[param] = fqns.copy()
        for fqn in fqns:
            if not isinstance(param, _EXTRA_STATE):
                fqn_param_mapping[fqn] = param

    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = cast(torch.Tensor, param_)

    submodule_prefixes: Set[str] = set()
    if submodules:
        submodules = set(submodules)
        for name, module in model.named_modules():
            if module not in submodules:
                continue
            fqns = _get_fqns(model, name)
            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
            submodule_prefixes.update(f"{fqn}." for fqn in fqns)

    if options.broadcast_from_rank0 and not options.full_state_dict:
        raise ValueError(
            "full_state_dict must be True when broadcast_from_rank0 is True."
        )

    fsdp_modules = FSDP.fsdp_modules(model)
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
    fsdp_context: Callable
    if fsdp_modules:
        # The FSDP APIs only work if at least one FSDP instance exists.
        if options.full_state_dict:
            state_dict_config = FullStateDictConfig(
                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
                rank0_only=(options.cpu_offload or options.broadcast_from_rank0),
            )
            state_dict_type = StateDictType.FULL_STATE_DICT
        else:
            state_dict_config = ShardedStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            state_dict_type = StateDictType.SHARDED_STATE_DICT

        @contextlib.contextmanager
        def fsdp_state_dict_type_without_warning(
            module,
            state_dict_type,
            state_dict_config,
            optim_state_dict_config,
        ):
            with warnings.catch_warnings():
                with FSDP.state_dict_type(
                    module=module,
                    state_dict_type=state_dict_type,
                    state_dict_config=state_dict_config,
                    optim_state_dict_config=optim_state_dict_config,
                ):
                    yield

        fsdp_context = functools.partial(
            fsdp_state_dict_type_without_warning,
            module=model,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config,
            optim_state_dict_config=optim_state_dict_config,
        )
    else:
        fsdp_context = contextlib.nullcontext

    return _StateDictInfo(
        **asdict(options),
        fqn_param_mapping=fqn_param_mapping,
        shared_params_mapping=shared_params_mapping,
        submodule_prefixes=submodule_prefixes,
        fsdp_context=fsdp_context,
        fsdp_modules=cast(List[nn.Module], fsdp_modules),
        handle_model=not optim_only,
        handle_optim=(len(optims) > 0),
    )


def _verify_state_dict(
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    for module in info.fsdp_modules:
        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
        assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."

    # Verify if the model_state_dict and optim_state_dict are valid. This API
    # should give the users an explicit error message to debug or report.
    if (
        info.handle_model
        and not model_state_dict
        and not info.submodule_prefixes
        and not info.ignore_frozen_params
        and not (info.cpu_offload and info.full_state_dict)
        and info.strict
        and not info.broadcast_from_rank0
    ):
        raise RuntimeError(
            "The option indicates that model state_dict is required to save "
            "or load, but model state_dict is empty. "
            f"rank = {dist.get_rank()}."
        )

    if info.handle_optim:
        if (
            not optim_state_dict
            and not (info.cpu_offload and info.full_state_dict)
            and (not info.broadcast_from_rank0)
        ):
            raise RuntimeError(
                "The option indicates that optimizer state_dict is required to "
                f"save or load, but optim state_dict is empty. {optim_state_dict}"
            )

    for key in model_state_dict.keys():
        if _FLAT_PARAM in key:
            raise RuntimeError(
                f"{key} contains {_FLAT_PARAM}. This can happen if the model "
                "is not the root module."
            )


def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
    call = getattr(obj, api)
    if call in _patched_state_dict:
        call = functools.partial(getattr(obj.__class__, api), self=obj)
    return call


def _maybe_full_or_cpu_state_dict(
    state_dict: Dict[str, Any], info: _StateDictInfo
) -> Dict[str, Any]:
    if info.full_state_dict:
        ranks_only = (
            tuple()
            if (not info.cpu_offload or not torch.distributed.is_initialized())
            else (0,)
        )
        return _gather_state_dict(
            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
        )
    elif info.cpu_offload:
        return _offload_state_dict_to_cpu(state_dict)
    else:
        return state_dict


def _get_model_state_dict(
    model: nn.Module, info: _StateDictInfo
) -> Dict[str, ValueType]:
    if not info.handle_model:
        return {}

    with info.fsdp_context():
        state_dict = _state_dict_fn(model, "state_dict")()

    for key in list(state_dict.keys()):
        fqns = _get_fqns(model, key)
        assert len(fqns) == 1, (key, fqns)
        fqn = next(iter(fqns))
        if fqn != key:
            # As we only support FSDP, DDP, and TP, the only cases are
            # wrapper-based DDP and compiler. Verify if the assumption
            # is correct.
            def verify(key, fqn) -> bool:
                if len(fqn) >= len(key):
                    return False
                fqn_split = fqn.split(".")
                key_split = key.split(".")
                fqn_idx = 0
                for key_idx, key_name in enumerate(key_split):
                    if key_name == fqn_split[fqn_idx]:
                        fqn_idx += 1
                        if fqn_idx == len(fqn_split):
                            return key_idx == len(key_split) - 1
                    elif key_name in ("module", "_orig_mod"):
                        continue
                    else:
                        return False
                return True

            if not verify(key, fqn):
                raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
            state_dict[fqn] = state_dict.pop(key)

    if info.submodule_prefixes:
        new_state_dict: Dict[str, ValueType] = {}
        # TODO: make this faster.
        for fqn in state_dict.keys():
            for prefix in info.submodule_prefixes:
                if not fqn.startswith(prefix):
                    continue
                if info.keep_submodule_prefixes:
                    new_state_dict[fqn] = state_dict[fqn]
                else:
                    new_fqn = fqn[len(prefix) :]
                    new_state_dict[new_fqn] = state_dict[fqn]
        state_dict = new_state_dict

    if info.ignore_frozen_params:
        for key, param in model.named_parameters():
            if param.requires_grad:
                continue
            fqns = _get_fqns(model, key)
            for fqn in fqns:
                state_dict.pop(fqn)

    for key, p in list(state_dict.items()):
        if torch.is_tensor(p) and p.is_meta:
            state_dict.pop(key)

    return _maybe_full_or_cpu_state_dict(state_dict, info)


def _load_model_state_dict(
    model: nn.Module,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> _IncompatibleKeys:
    if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
        return _IncompatibleKeys({}, {})

    local_state_dict = {}
    for key, value in _iterate_valid_model_state(model):
        fqns = _get_fqns(model, key)
        fqns_with_prefix = _get_fqns(
            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
        )

        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
            if (
                not info.broadcast_from_rank0 or dist.get_rank() == 0
            ) and fqn != fqn_with_prefix:
                state_dict[fqn_with_prefix] = state_dict.pop(fqn)
            local_state_dict[fqn_with_prefix] = value

    if info.broadcast_from_rank0:
        device = None
        for key, value in local_state_dict.items():
            if torch.is_tensor(value) and value.dim() > 0:
                if device is None:
                    device = value.device
                else:
                    assert device == value.device
        assert device is not None
        _broadcast_state_dict(
            state_dict, local_state_dict, device=device, strict=info.strict
        )
        for fqn, local_state in local_state_dict.items():
            state_dict[fqn] = local_state

    with info.fsdp_context():
        return cast(
            _IncompatibleKeys,
            _state_dict_fn(model, "load_state_dict")(
                state_dict=state_dict, strict=info.strict
            ),
        )


def _init_optim_state(optim: torch.optim.Optimizer) -> None:
    """
    Initialize optim states by calling step() with zero grads.
    """
    if optim.state:
        # The optimizer state is initialized.
        return

    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.grad is not None:
                raise RuntimeError(
                    "state_dict can only be used if the optimizer "
                    "states are initialized (usually after one step() with "
                    "gradients) or gradients are None. For the latter case, "
                    "state_dict will fake the gradients as zero "
                    "to initialize the optimizer states. However, the "
                    "gradients are not None."
                )
            if param.requires_grad:
                param.grad = torch.zeros_like(param)

    # Some optimizers will update parameters regardless of grads due to lr, so
    # set lr to zero when calling `step()`.
    lrs = []
    for param_group in optim.param_groups:
        if "lr" in param_group:
            lrs.append(param_group["lr"])
            param_group["lr"] = 0.0
    optim.step(closure=None)
    # Whether to recover the "lr" should not matter too much as we will
    # restore the values from the checkpoint later.
    for param_group in optim.param_groups:
        if "lr" in param_group:
            param_group["lr"] = lrs.pop(0)
    optim.zero_grad(set_to_none=True)
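
# Note (illustrative): for a freshly constructed optimizer such as Adam,
# ``_init_optim_state`` materializes the per-parameter state (e.g. "step",
# "exp_avg", "exp_avg_sq") by running one ``step()`` with zero gradients and the
# learning rate temporarily set to 0.0, so that a subsequent ``state_dict()`` or
# ``load_state_dict()`` sees fully populated entries without (for standard
# optimizers) changing the parameter values.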


def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
    """
    This API flattens the optimizer state_dict to support optimizer resharding for
    MPMD, e.g., pipeline parallelism.

    Without the API, the original optimizer state_dict looks like:
    {
        "state": {
            "layer1.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
            "layer2.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
        },
        "param_groups": [
            {
                "lr": 0.0,
                "betas": (0.9, 0.95), ...,
                "params": ["layer1.weight", "layer2.weight"]
            }
        ]
    }

    With this API, the optimizer state_dict looks like:
    {
        "state.layer1.weight.step": 10,
        "state.layer2.weight.step": 10,
        "state.layer1.weight.exp_avg": SomeTensor,
        "state.layer2.weight.exp_avg": SomeTensor,
        "state.layer1.weight.exp_avg_sq": SomeTensor,
        "state.layer2.weight.exp_avg_sq": SomeTensor,
        "param_groups.layer1.weight.lr": 0.0,
        "param_groups.layer2.weight.lr": 0.0,
        "param_groups.layer1.weight.betas": (0.9, 0.95),
        "param_groups.layer2.weight.betas": (0.9, 0.95),
    }

    Note that if any of the values is a container, like the betas in the example,
    this API won't flatten it.
    """

    def _raise_if_type_not_supported(v):
        if not isinstance(v, (torch.Tensor, int, float)):
            raise NotImplementedError(
                "Flattening optimizer state_dict only supports "
                "tensor, int, float states now. "
                f"Type is {type(v)}."
            )

    ret: Dict[str, ValueType] = {}
    for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
        for k, v in cast(DictValueType, state).items():
            _raise_if_type_not_supported(v)
            ret[f"{_STATE}.{fqn}.{k}"] = v

    for param_group in cast(ListDictValueType, state_dict[_PG]):
        fqns = param_group.pop(_PARAMS)
        for fqn in cast(List[str], fqns):
            for k, v in param_group.items():
                ret[f"{_PG}.{fqn}.{k}"] = v

    return ret


def _unflatten_optim_state_dict(
    optim: torch.optim.Optimizer,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    This API unflattens the state_dict generated by _flatten_optim_state_dict().
    See the docstring of _flatten_optim_state_dict() for more detail.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)  # typing
                params.append(fqn)
                if not param.requires_grad:
                    continue
                state[fqn] = {}
                for state_name in optim.state[param].keys():
                    cast(DictValueType, state[fqn])[state_name] = state_dict[
                        f"{_STATE}.{fqn}.{state_name}"
                    ]

        first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
        for k in param_group.keys():
            if k == _PARAMS:
                continue
            value = state_dict[f"{_PG}.{first_param_fqn}.{k}"]
            if k not in pg_state[-1]:
                pg_state[-1][k] = value
            elif pg_state[-1][k] != value:
                raise RuntimeError(
                    "All the parameters in the same parameter group should have "
                    f"the same saved param_group value. But {first_param_fqn}.{k} "
                    f"is {value} while other(s) is {pg_state[-1][k]}."
                )

    return return_osd


def _get_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    info: _StateDictInfo,
) -> OptimizerStateType:
    if not info.handle_optim:
        return {}

    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}
    for optim in optimizers:
        _init_optim_state(optim)
        osd = _state_dict_fn(optim, "state_dict")()
        if info.fsdp_modules:
            with info.fsdp_context():
                osd = FSDP.optim_state_dict(model, optim, osd)

            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            # There are no easy ways to do this conversion systematically.
            # We can only use a string replacement without correctness check.
            if not osd:
                continue
            for k in list(osd[_STATE].keys()):
                if "_orig_mod" in k:
                    osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k)
            for g in osd[_PG]:
                params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]]
                g[_PARAMS] = params
        else:
            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))
            param_pid_mapping = dict(zip(params, range(len(params))))
            fqn_pid_mapping = {}
            for key, param in model.named_parameters():
                fqns = _get_fqns(model, key)
                assert len(fqns) == 1
                fqn = next(iter(fqns))
                if param not in param_pid_mapping:
                    continue
                pid = param_pid_mapping[param]
                fqn_pid_mapping[fqn] = pid
                fqn_pid_mapping[pid] = fqn

            for key in list(osd[_STATE].keys()):
                fqn = fqn_pid_mapping[key]
                osd[_STATE][fqn] = osd[_STATE].pop(key)

            for group in osd[_PG]:
                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]

        if not osd:
            continue

        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])
        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])

    if info.flatten_optimizer_state_dict:
        optim_state_dict = cast(
            OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
        )

    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)


def _split_optim_state_dict(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    Extract the corresponding optim state_dict from ``optim_state_dict`` for
    ``optim`` and return the result optim state_dict.

    Args:
        model (nn.Module): the root model.
        optim (torch.optim.Optimizer): the optimizer.
        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
            contains the optim state_dict of ``optim``.
        info (_StateDictInfo): state dict information.

    Returns:
        The optim state_dict of ``optim``.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
    pg_mapping: Dict[int, int] = {}

    if all(
        isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
    ):
        return optim_state_dict

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                if fqn in info.shared_params_mapping:
                    in_params = False
                    for loaded_param_group in cast(
                        ListDictValueType, optim_state_dict[_PG]
                    ):
                        if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                            in_params = True
                            break
                else:
                    in_params = True
                if not in_params:
                    continue

                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)
                params.append(fqn)
                if param.requires_grad:
                    state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
                for loaded_param_group in cast(
                    ListDictValueType, optim_state_dict[_PG]
                ):
                    if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                        pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1

    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
        idx = pg_mapping.get(id(param_group), -1)
        if idx == -1:
            continue
        for key, value in param_group.items():
            if key == _PARAMS:
                continue
            # TODO: check if value is the same if exists.
            pg_state[idx][key] = value

    return return_osd


def _load_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    if not info.handle_optim:
        return

    for optim in optimizers:
        _init_optim_state(optim)
        if state_dict:
            if _STATE in state_dict:
                optim_state_dict = _split_optim_state_dict(
                    model, optim, state_dict, info
                )
            else:
                optim_state_dict = _unflatten_optim_state_dict(
                    optim, cast(Dict[str, ValueType], state_dict), info
                )
        else:
            optim_state_dict = {}
        if info.fsdp_modules:
            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            for original_fqn, _ in model.named_parameters():
                fqns = _get_fqns(model, original_fqn)
                fqns_with_compiler = _get_fqns(
                    model, original_fqn, skip_compiler_prefix=False
                )
                if fqns == fqns_with_compiler:
                    continue

                assert len(fqns) == 1
                fqn = fqns.pop()
                fqn_with_compiler = fqns_with_compiler.pop()
                for g in optim_state_dict[_PG]:
                    val = cast(Dict[str, Any], g)
                    params = [
                        key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
                    ]
                    val[_PARAMS] = params
                osd_state = cast(DictValueType, optim_state_dict[_STATE])
                for k in list(osd_state.keys()):
                    if fqn in k:
                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)

            with info.fsdp_context():
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    model, optim, optim_state_dict
                )
        elif info.broadcast_from_rank0:
            info.full_state_dict = False
            local_state_dict = _get_optim_state_dict(model, (optim,), info)
            info.full_state_dict = True
            device = None

            def _device(t):
                if t.dim() > 0:
                    nonlocal device
                    if device is None:
                        device = t.device
                    elif device != t.device:
                        raise ValueError("Device mismatch")
                return t

            _ = tree_map_only(torch.Tensor, _device, local_state_dict)
            assert device is not None
            flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
            flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
            _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)
            # This handles the case where the optimizer state_dict being loaded
            # contains parameters that the local optimizer does not have: any
            # flattened key that exists only in ``flatten_osd`` is copied into the
            # local state_dict/mapping, so the optimizer may end up with
            # additional parameters after loading.
            for optim_key in flatten_osd.keys():
                if optim_key not in flatten_local_osd:
                    assert optim_key in osd_mapping

                    flatten_local_osd[optim_key] = flatten_osd[optim_key]
                    local_osd_mapping[optim_key] = osd_mapping[optim_key]
            optim_state_dict = _unflatten_state_dict(
                flatten_local_osd, local_osd_mapping
            )

        # Note that we do not have to convert the FQN back to param id here if
        # the order in optim.param_groups[idx][_PARAMS] is the same as the one in
        # optim_state_dict[_PG][idx][_PARAMS].
        _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)


def get_model_state_dict(
    model: nn.Module,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Dict[str, ValueType]:
    """
    Return the model state_dict of ``model``.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``model``.

    :rtype: typing.Dict[str, ValueType]
    """
    with _gc_context():
        info = _verify_options(
            model,
            tuple(),
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        model_state_dict = _get_model_state_dict(model, info)
        _verify_state_dict(model_state_dict, {}, info)
        return model_state_dict


def get_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> OptimizerStateType:
    """
    Return the combined state_dict for optimizers.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``optimizers``.

    :rtype: OptimizerStateType
    """
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model,
            optimizers,
            optim_only=True,
            submodules=submodules,
            options=options,
        )
        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
        _verify_state_dict({}, optim_state_dict, info)
        return optim_state_dict


def get_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
    """
    Return the model state_dict and optimizers state_dict.

    ``get_state_dict`` can process any module that is parallelized by PyTorch
    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
    combination of these parallelisms. The main functions of ``get_state_dict``
    are: 1.) returning a model and optimizer state_dict that can be resharded
    with a different number of trainers and/or different parallelisms.
    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
    these APIs.
    3.) sanity checking the result state_dict.

    The keys of the result state dictionary are the canonical FQNs (Fully
    Qualified Names). A canonical FQN refers to the FQN based on a parameter's
    position in an nn.Module hierarchy. More specifically, a canonical FQN of a
    parameter is the FQN returned by ``module.named_parameters()`` or
    ``module.named_buffers()`` when the module is not distributed by any
    parallelisms. Since the optimizer internally uses parameter IDs to represent
    a parameter, there will be a conversion from the parameter IDs to the
    canonical FQNs when calling this API.

    ``get_state_dict`` can also process a module that is not parallelized. In
    such a case, ``get_state_dict`` only performs one function -- converting the
    optimizer parameter IDs to the canonical FQNs.

    Example:
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        ``Tuple`` that contains the model state_dict and the optimizer state_dict.

    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
    """
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model,
            optimizers,
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        model_state_dict = _get_model_state_dict(model, info)
        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
        _verify_state_dict(model_state_dict, optim_state_dict, info)
        return model_state_dict, optim_state_dict


def _unflatten_model_state_dict(
    model: nn.Module,
    state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
) -> Dict[str, ValueType]:
    if not state_dict:
        return {}

    if isinstance(next(iter(state_dict.keys())), nn.Module):
        warnings.warn(
            "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]`` "
            "is deprecated and will be removed in 2.5. If you need this "
            "feature, please preprocess the model_state_dict to achieve the "
            "same functionality.",
            FutureWarning,
        )
        cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
        new_state_dict: Dict[str, ValueType] = {}
        for submodule, sub_state_dict in cast_state_dict.items():
            for name, m in model.named_modules():
                if m != submodule:
                    continue

                fqns = _get_fqns(model, name)
                assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
                prefix = f"{next(iter(fqns))}."
                new_state_dict.update(
                    {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
                )
        return new_state_dict
    else:
        return cast(Dict[str, ValueType], state_dict)


def set_model_state_dict(
    model: nn.Module,
    model_state_dict: Dict[str, ValueType],
    *,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict.

    The counterpart of ``get_model_state_dict`` to set the state_dict to the
    model. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        model_state_dict: (Dict[str, ValueType]):
            the model state_dict to load. If the key of the ``model_state_dict``
            is nn.Module, the key is a submodule of ``model`` and the value should
            be the state_dict of the submodule. When loading the state_dict,
            the prefix of the submodule will be prepended to the keys in its
            state_dict.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    :type model_state_dict: typing.Dict[str, ValueType]
    """
    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with _gc_context():
        info = _verify_options(model, tuple(), optim_only=False, options=options)

        _verify_state_dict(model_state_dict, {}, info)
        return _load_model_state_dict(model, model_state_dict, info)
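
# Illustrative sketch of ``broadcast_from_rank0`` loading (assumes the process group
# is initialized and ``full_sd`` was materialized on rank0 only, e.g. via torch.load):
#
#     options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
#     set_model_state_dict(
#         model,
#         full_sd if dist.get_rank() == 0 else {},
#         options=options,
#     )
#
# Rank0 supplies the full state_dict; the tensors are broadcast and re-sharded
# according to the local (DTensor) shards in ``model``.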


def set_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    optim_state_dict: OptimizerStateType,
    *,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Load the optimizers state_dict.

    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
    optimizers. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        None

    :type optim_state_dict: OptimizerStateType
    """
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(model, optimizers, optim_only=True, options=options)

        _verify_state_dict({}, optim_state_dict, info)
        _load_optim_state_dict(model, optimizers, optim_state_dict, info)


def set_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict and optimizers state_dict.

    The counterpart of ``get_state_dict`` to set the state_dict to the model and
    optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not
    have to be returned by ``get_state_dict`` but must meet the following
    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
    the canonical FQNs.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
            the model state_dict to load. If the key of the ``model_state_dict``
            is nn.Module, the key is a submodule of ``model`` and the value should
            be the state_dict of the submodule. When loading the state_dict,
            the prefix of the submodule will be prepended to the keys in its
            state_dict.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.

    :type model_state_dict: typing.Dict[str, ValueType]
    :type optim_state_dict: OptimizerStateType
    """
    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model, optimizers, optim_only=not model_state_dict, options=options
        )

        _verify_state_dict(model_state_dict, optim_state_dict, info)
        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
        return _load_model_state_dict(model, model_state_dict, info)
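
# Illustrative end-to-end round trip with torch.distributed.checkpoint (a sketch;
# ``dcp`` and ``CHECKPOINT_DIR`` are assumptions of this example, not requirements
# of the API):
#
#     import torch.distributed.checkpoint as dcp
#
#     # save
#     model_sd, optim_sd = get_state_dict(model, optimizers=optim)
#     dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=CHECKPOINT_DIR)
#
#     # load
#     model_sd, optim_sd = get_state_dict(model, optimizers=optim)
#     dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=CHECKPOINT_DIR)
#     set_state_dict(
#         model,
#         optimizers=optim,
#         model_state_dict=model_sd,
#         optim_state_dict=optim_sd,
#     )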


# TODO: correct the state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_model_state_dict(
    model: nn.Module,
    *,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
    be partial functions that call ``get_state_dict`` and ``set_state_dict``.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict

        model = FSDP(model)
        _patch_model_state_dict(model)

    Args:
        model (nn.Module): the nn.Module for the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """
    _state_dict_call = functools.partial(
        get_model_state_dict,
        model=model,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    model.state_dict = state_dict_call

    _load_state_dict_call = functools.partial(
        set_model_state_dict,
        model=model,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(model_state_dict=state_dict)

    model.load_state_dict = load_state_dict_call

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)


# TODO: correct the load_state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_optimizer_state_dict(
    model: nn.Module,
    *,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
    be partial functions that call ``get_state_dict`` and ``set_state_dict``.

    Note that if there are multiple optimizers, all of the optimizers will be patched.
    So users only need to call state_dict() on one of the optimizers to get the
    full result.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict

        model = FSDP(model)
        _patch_optimizer_state_dict(model, optimizers=(optim,))

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Tuple[torch.optim.Optimizer, ...]): the optimizers to patch.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """
    _state_dict_call = functools.partial(
        get_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    _load_state_dict_call = functools.partial(
        set_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(optim_state_dict=state_dict)

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)
    optimizers = (
        (optimizers,)
        if isinstance(optimizers, torch.optim.Optimizer)
        else tuple(optimizers)
    )
    for optim in optimizers:
        optim.state_dict = state_dict_call
        optim.load_state_dict = load_state_dict_call