_state_dict_utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import io
  4. import math
  5. from typing import (
  6. Any,
  7. Callable,
  8. cast,
  9. Dict,
  10. List,
  11. Mapping,
  12. MutableMapping,
  13. NamedTuple,
  14. Optional,
  15. Tuple,
  16. TYPE_CHECKING,
  17. Union,
  18. )
  19. import torch
  20. import torch.distributed as dist
  21. import torch.nn.functional as F
  22. from torch.distributed._functional_collectives import AsyncCollectiveTensor
  23. if dist.is_available() or TYPE_CHECKING:
  24. from torch.distributed import distributed_c10d
  25. from torch.distributed._shard.sharded_tensor import ShardedTensor
  26. from torch.distributed._tensor import distribute_tensor, DTensor, Replicate
  27. def _identity_func(
  28. obj: torch.Tensor,
  29. pg: Optional[dist.ProcessGroup],
  30. device: Optional[torch.device],
  31. companion_obj: Any,
  32. ) -> torch.Tensor:
  33. return obj
  34. def _all_gather_sharded_tensor(
  35. sharded_tensor: "ShardedTensor",
  36. pg: Optional[dist.ProcessGroup] = None,
  37. device: Optional[torch.device] = None,
  38. ) -> torch.Tensor:
  39. if pg is None:
  40. pg = distributed_c10d._get_default_group()
  41. world_size = dist.get_world_size(pg)
  42. shards = sharded_tensor.local_shards()
  43. dim_0_size = sharded_tensor.size()[0] # type: ignore[index]
  44. tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr]
  45. chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
  46. pg_device = (
  47. distributed_c10d._get_pg_default_device(pg) if device is None else device
  48. )
  49. if shards:
  50. local_tensor = shards[0].tensor.flatten()
  51. if local_tensor.device.type != pg_device.type:
  52. local_tensor = local_tensor.to(pg_device)
  53. num_padding = chunk_size - local_tensor.numel()
  54. if num_padding > 0:
  55. local_tensor = F.pad(local_tensor, [0, num_padding])
  56. else:
  57. local_tensor = torch.zeros(
  58. chunk_size, dtype=sharded_tensor.dtype, device=pg_device
  59. )
  60. tensor = torch.empty(
  61. chunk_size * world_size,
  62. dtype=local_tensor.dtype,
  63. device=pg_device,
  64. )
  65. dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
  66. tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
  67. return tensor
  68. class CompanionMismatch(Exception):
  69. ...
  70. def _iterate_state_dict(
  71. iter_object: Any,
  72. sharded_tensor_func: Callable,
  73. dtensor_func: Callable,
  74. tensor_func: Callable,
  75. *,
  76. pg: Optional[dist.ProcessGroup] = None,
  77. device: Optional[torch.device] = None,
  78. cpu_offload: bool = False,
  79. companion_obj: Any = None,
  80. ranks_only: Tuple[int, ...] = tuple(),
  81. type_check: bool = True,
  82. non_blocking: bool = True,
  83. ) -> Dict[str, Any]:
  84. """Iterate through the state dict, applying the given functions to each tensor type.
  85. Args:
  86. iter_object (Any): the target state_dict.
  87. sharded_tensor_func (Callable): the function to apply to ShardedTensor
  88. dtensor_func (Callable): the function to apply to DTensor
  89. tensor_func (Callable): the function to apply to Tensor
  90. pg (Optional[dist.ProcessGroup]): process group passed to tensor functions
  91. device (Optional[torch.device]): device passed to tensor functions
  92. cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored
  93. if a companion_obj is supplied.
  94. companion_obj (Any): A companion object to the state dict. If this object
  95. is supplied, we attempt to copy the tensor to the companion object.
  96. ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
  97. have the same state_dicts. Otherwise only ranks that in ``ranks_only``
  98. have the same state_dicts. Other ranks will get empty state_dicts.
  99. type_check (bool): check if the instance data type is a supported type
  100. that can be saved by DCP. The current supported data types are
  101. torch.Tensor, DTensor, int, float, str, list, dict, None.
  102. non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
  103. """
  104. # TODO: should we use pytree?
  105. cpu_device = torch.device("cpu")
  106. if isinstance(iter_object, ShardedTensor):
  107. ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
  108. elif isinstance(iter_object, DTensor):
  109. ret = dtensor_func(iter_object, pg, device, companion_obj)
  110. elif isinstance(iter_object, torch.Tensor):
  111. ret = tensor_func(iter_object, pg, device, companion_obj)
  112. elif (
  113. isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
  114. or iter_object is None
  115. ):
  116. ret = iter_object
  117. elif isinstance(iter_object, dict):
  118. if companion_obj is not None and (
  119. not isinstance(companion_obj, dict)
  120. or set(companion_obj.keys()) != set(iter_object.keys())
  121. ):
  122. msg = (
  123. ""
  124. if isinstance(companion_obj, dict)
  125. else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}"
  126. )
  127. raise CompanionMismatch(msg)
  128. ret = {
  129. key: _iterate_state_dict(
  130. value,
  131. sharded_tensor_func,
  132. dtensor_func,
  133. tensor_func,
  134. pg=pg,
  135. device=device,
  136. cpu_offload=cpu_offload,
  137. companion_obj=companion_obj[key] if companion_obj is not None else None,
  138. ranks_only=ranks_only,
  139. type_check=type_check,
  140. non_blocking=non_blocking,
  141. )
  142. for key, value in iter_object.items()
  143. }
  144. elif isinstance(iter_object, (list, tuple)):
  145. if companion_obj is not None and (
  146. not isinstance(companion_obj, (list, tuple))
  147. or len(companion_obj) != len(iter_object)
  148. ):
  149. raise CompanionMismatch
  150. ret = [
  151. _iterate_state_dict(
  152. v,
  153. sharded_tensor_func,
  154. dtensor_func,
  155. tensor_func,
  156. pg=pg,
  157. device=device,
  158. cpu_offload=cpu_offload,
  159. companion_obj=companion_obj[idx] if companion_obj is not None else None,
  160. ranks_only=ranks_only,
  161. type_check=type_check,
  162. non_blocking=non_blocking,
  163. )
  164. for idx, v in enumerate(iter_object)
  165. ]
  166. if isinstance(iter_object, tuple):
  167. ret = tuple(ret)
  168. elif not type_check:
  169. ret = copy.deepcopy(iter_object)
  170. else:
  171. raise ValueError(f"Unexpected value type {type(iter_object)}")
  172. if not ranks_only or dist.get_rank(pg) in ranks_only:
  173. if isinstance(ret, torch.Tensor):
  174. if cpu_offload and companion_obj is None:
  175. ret = ret.to(cpu_device)
  176. if companion_obj is not None:
  177. # TODO: support DTensor
  178. companion_obj.copy_(ret, non_blocking=non_blocking)
  179. ret = companion_obj
  180. else:
  181. ret = {} if isinstance(ret, dict) else None
  182. return ret
  183. def _gather_state_dict(
  184. state_dict: Dict[str, Any],
  185. *,
  186. pg: Optional[dist.ProcessGroup] = None,
  187. device: Optional[torch.device] = None,
  188. cpu_offload: bool = False,
  189. ranks_only: Tuple[int, ...] = tuple(),
  190. type_check: bool = True,
  191. ) -> Dict[str, Any]:
  192. """
  193. Given a state_dict, this API gathers all the ShardedTensors or DTensors in
  194. the state_dict.
  195. Args:
  196. state_dict (Dict[str, Any]): the target sharded state_dict.
  197. pg (Optional[dist.ProcessGroup]): the process group that is used to
  198. gather ShardedTensor. Note that gathering a DTensor will use
  199. the DeviceMesh. So this argument will be ignored when gathering a
  200. DTensor.
  201. device: (Optional[torch.device]): the device that is used to
  202. perform allgather for ShardedTensor. Note that gathering a DTensor
  203. will use the DeviceMesh. So this argument will be ignored when
  204. gathering a DTensor.
  205. cpu_offload (bool): whether to offload the tensors to CPU memory. The
  206. default value is False.
  207. ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
  208. have the same state_dicts. Otherwise only ranks that in ``ranks_only``
  209. have the same state_dicts. Other ranks will get empty state_dicts.
  210. type_check: (bool): check if the instance data type is a supported type
  211. that can be saved by DCP. The current supported data types are
  212. torch.Tensor, DTensor, int, float, str, list, dict, None.
  213. Returns:
  214. The gathered state dictionary.
  215. """
  216. def sharded_tensor_func(value, pg, device, companion_obj):
  217. # ShardedTensor does not seem to record the original device type.
  218. # So if the tensor is moved to CPU, we won't know the original type.
  219. # As a result, we have to rely on the user to tell us the correct one.
  220. cpu_device = torch.device("cpu")
  221. output_tensor = _all_gather_sharded_tensor(value, pg, device)
  222. local_shard_device = (
  223. value.local_shards()[0].tensor.device
  224. if value.local_shards()
  225. else cpu_device
  226. )
  227. if output_tensor.device != local_shard_device:
  228. value = output_tensor.to(local_shard_device)
  229. else:
  230. value = output_tensor
  231. return value
  232. def dtensor_func(value, pg, device, companion_obj):
  233. if value.device != value.device_mesh.device_type:
  234. value = value.to(value.device_mesh.device_type)
  235. # FSDP all_gather: [Shard(0)] -> [Replicate()]
  236. # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
  237. # 2D FSDP + TP all_gather:
  238. # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
  239. # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
  240. placements = [Replicate() for _ in value.placements]
  241. value = value.redistribute(
  242. device_mesh=value.device_mesh,
  243. placements=placements,
  244. )
  245. # Call `wait()` to force the tensor to be synchronous with respect
  246. # to the main stream.
  247. # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
  248. value = value.to_local()
  249. if isinstance(value, AsyncCollectiveTensor):
  250. value = value.wait()
  251. return value
  252. return _iterate_state_dict(
  253. state_dict,
  254. sharded_tensor_func,
  255. dtensor_func,
  256. _identity_func,
  257. pg=pg,
  258. device=device,
  259. cpu_offload=cpu_offload,
  260. ranks_only=ranks_only,
  261. type_check=type_check,
  262. )
  263. def _offload_state_dict_to_cpu(
  264. state_dict: Dict[str, Any],
  265. *,
  266. ranks_only: Tuple[int, ...] = tuple(),
  267. type_check: bool = True,
  268. ) -> Dict[str, Any]:
  269. """
  270. Given a state_dict, this API offload all the tensors to CPU memory.
  271. Args:
  272. state_dict (Dict[str, Any]): the target state_dict.
  273. pg (Optional[dist.ProcessGroup]): the process group that is used to
  274. gather ShardedTensor. Note that gathering a DTensor will use
  275. the DeviceMesh. So this argument will be ignored when gathering a
  276. DTensor.
  277. ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
  278. have the same state_dicts. Otherwise only ranks that in ``ranks_only``
  279. have the same state_dicts. Other ranks will get empty state_dicts.
  280. type_check: (bool): check if the instance data type is a supported type
  281. that can be saved by DCP. The current supported data types are
  282. torch.Tensor, DTensor, int, float, str, list, dict, None.
  283. Returns:
  284. The gathered state dictionary.
  285. """
  286. ret = _iterate_state_dict(
  287. state_dict,
  288. _identity_func,
  289. _identity_func,
  290. _identity_func,
  291. pg=None,
  292. device=None,
  293. cpu_offload=True,
  294. ranks_only=ranks_only,
  295. type_check=type_check,
  296. )
  297. return ret
  298. def _copy_state_dict(
  299. state_dict: Dict[str, Any],
  300. copy_state_dict: Dict[str, Any],
  301. non_blocking: bool = False,
  302. type_check: bool = True,
  303. ) -> Dict[str, Any]:
  304. """
  305. Copies all tensors in a given state dict into a different state_dict with the
  306. same structure. Additionally, a copied state dict with the same value references
  307. is returned. Editing the keys on this state dict will not affect the
  308. passed in copy_state_dict (but the value references are the same).
  309. .. warning::
  310. It is expected by this function that state_dict and copy_state_dict share
  311. the same structure and data types.
  312. .. warning::
  313. The current supported data types are
  314. torch.Tensor, DTensor, int, float, str, list, dict, None.
  315. Args:
  316. state_dict (Dict[str, Any]): the target state_dict.
  317. copy_state_dict (Dict[str, Any]):
  318. The state dict we are copying into. This state_dict must have exactly
  319. the same structure as the source `state_dict`.
  320. non_blocking: (bool): Whether copy ops should be performed asynchronously
  321. type_check (bool): check if the instance data type is a supported type
  322. that can be saved by DCP. The current supported data types are
  323. torch.Tensor, DTensor, int, float, str, list, dict, None.
  324. Returns:
  325. State Dict copy
  326. """
  327. return _iterate_state_dict(
  328. state_dict,
  329. _identity_func,
  330. _identity_func,
  331. _identity_func,
  332. pg=None,
  333. device=None,
  334. cpu_offload=False,
  335. ranks_only=tuple(),
  336. companion_obj=copy_state_dict,
  337. type_check=type_check,
  338. non_blocking=non_blocking,
  339. )
  340. def _create_cpu_state_dict(
  341. state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
  342. ) -> Dict[str, Any]:
  343. """
  344. Given a state_dict, create another state_dict with the same structure and elements.
  345. However, all tensors in the returned state_dict are new tensors on CPU. These
  346. tensors can be placed on pin_memory or share_memory based on the provided arguments.
  347. .. warning::
  348. Setting both `pin_memory` and `share_memory` to True significantly increases the
  349. latency of this method because of the nuances which require us to register memory
  350. as pinned directly as opposed to relying on the pin_memory cache allocator. This
  351. option should only be used for long lived tensors which are required to be shared.
  352. This is not the case as long as at least one of `pin_memory` or `share_memory` is
  353. set to False.
  354. """
  355. def tensor_func(
  356. obj: torch.Tensor,
  357. pg: Optional[dist.ProcessGroup],
  358. device: Optional[torch.device],
  359. _: Any,
  360. ) -> torch.Tensor:
  361. if len(obj.size()) == 0:
  362. return torch.tensor(0, dtype=obj.dtype)
  363. if share_memory:
  364. t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_()
  365. if pin_memory:
  366. succ = torch.cuda.cudart().cudaHostRegister(
  367. t.data_ptr(),
  368. t.numel() * t.element_size(),
  369. 1, # lines up with 'cudaHostRegisterPortable'
  370. )
  371. assert (
  372. succ == 0
  373. ), f"Pinning shared memory failed with error-code: {succ}"
  374. return t
  375. elif pin_memory:
  376. return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
  377. else:
  378. return torch.empty(*tuple(obj.size()), dtype=obj.dtype)
  379. ret = _iterate_state_dict(
  380. state_dict,
  381. _identity_func,
  382. _identity_func,
  383. tensor_func,
  384. pg=None,
  385. device=None,
  386. cpu_offload=False,
  387. ranks_only=tuple(),
  388. type_check=False,
  389. )
  390. return ret
  391. def _check_state_dict_similarity(
  392. state_dict: Dict[str, Any],
  393. compared_state_dict: Dict[str, Any],
  394. ) -> bool:
  395. """
  396. Given two state_dicts, check if the structures are the same. And
  397. if a [key, tensor] pair exist in one state_dict there must be
  398. the a corresponding pait, [key, other_tensor], in the other state_dict,
  399. where tensor and other_tensor have the same size and dtype.
  400. Return the check result.
  401. """
  402. def tensor_func(
  403. obj: torch.Tensor,
  404. pg: Optional[dist.ProcessGroup],
  405. device: Optional[torch.device],
  406. companion_obj: Any,
  407. ) -> torch.Tensor:
  408. if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
  409. raise CompanionMismatch
  410. return obj
  411. try:
  412. _iterate_state_dict(
  413. state_dict,
  414. _identity_func,
  415. _identity_func,
  416. tensor_func,
  417. pg=None,
  418. device=None,
  419. cpu_offload=False,
  420. ranks_only=tuple(),
  421. companion_obj=compared_state_dict,
  422. type_check=False,
  423. )
  424. except CompanionMismatch:
  425. return False
  426. return True
class _TensorInfo(NamedTuple):
    """Size and dtype of a tensor, exchanged in place of the tensor's data.

    ``_broadcast_state_dict`` sends these records for non-scalar tensors and
    ``_broadcast_tensors`` uses them to allocate the receive buffers.
    """

    # Shape of the described tensor.
    size: torch.Size
    # Element type of the described tensor.
    dtype: torch.dtype
  430. def _broadcast_tensors(
  431. full_state_dict: Dict[str, Any],
  432. local_state_dict: Dict[str, Any],
  433. keys: List[str],
  434. device: torch.device,
  435. pg: Optional[dist.ProcessGroup] = None,
  436. ) -> None:
  437. tensors = []
  438. for key in keys:
  439. if dist.get_rank() == 0:
  440. full_state = full_state_dict[key]
  441. assert isinstance(full_state, torch.Tensor)
  442. full_tensor = full_state.detach().to(device)
  443. else:
  444. tensor_info = full_state_dict[key]
  445. full_tensor = torch.empty(
  446. size=tensor_info.size,
  447. device=device,
  448. dtype=tensor_info.dtype,
  449. )
  450. tensors.append(full_tensor)
  451. local_state = local_state_dict.get(key, None)
  452. if local_state is None:
  453. continue
  454. elif isinstance(local_state, DTensor):
  455. local_state_dict[key] = (local_state, full_tensor)
  456. else:
  457. local_state_dict[key] = full_tensor
  458. if pg is None:
  459. pg = dist.distributed_c10d._get_default_group()
  460. if len(tensors) > 1:
  461. dist._broadcast_coalesced(pg, tensors, 500, 0)
  462. else:
  463. dist.broadcast(tensors[0], src=0, group=pg)
  464. for key in keys:
  465. _local_state = local_state_dict.get(key, None)
  466. if _local_state is None or torch.is_tensor(_local_state):
  467. continue
  468. local_state = _local_state[0]
  469. full_tensor = _local_state[1]
  470. local_state_dict[key] = distribute_tensor(
  471. full_tensor, local_state.device_mesh, local_state.placements
  472. )
  473. def _broadcast_state_dict(
  474. full_state_dict: Dict[str, Any],
  475. local_state_dict: Dict[str, Any],
  476. device: torch.device,
  477. pg: Optional[dist.ProcessGroup] = None,
  478. strict: bool = False,
  479. ) -> None:
  480. # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`.
  481. # If strict is True, any keys in `local_state_dict` but not in `full_state_dict`
  482. # will be removed from `local_state_dict`.
  483. ret = {}
  484. if dist.get_rank() == 0:
  485. for key, value in full_state_dict.items():
  486. if not torch.is_tensor(value):
  487. ret[key] = value
  488. elif value.dim() == 0:
  489. ret[key] = value.cpu()
  490. else:
  491. ret[key] = _TensorInfo(value.size(), value.dtype)
  492. broadcast_list = [ret]
  493. dist.broadcast_object_list(broadcast_list, src=0, group=pg)
  494. ret = broadcast_list[0]
  495. # Gather values
  496. keys = []
  497. local_state_dict_keys = set(local_state_dict.keys())
  498. global_keys = set()
  499. for key, value in ret.items():
  500. global_keys.add(key)
  501. if not isinstance(value, _TensorInfo):
  502. if key in local_state_dict:
  503. local_state_dict[key] = value
  504. continue
  505. if dist.get_rank() == 0:
  506. ret[key] = full_state_dict[key]
  507. keys.append(key)
  508. # Broadcast every tensor to avoid OOM for now.
  509. if len(keys) >= 1:
  510. _broadcast_tensors(ret, local_state_dict, keys, device, pg)
  511. keys.clear()
  512. if strict:
  513. if missing_keys := (local_state_dict_keys - global_keys):
  514. for key in missing_keys:
  515. local_state_dict.pop(key)
  516. if keys:
  517. _broadcast_tensors(ret, local_state_dict, keys, device, pg)
# These APIs are from torch.distributed.checkpoint.
# TODO: We should consolidate the code here, as not all modules can depend on
# DCP.
# One step of a path into a nested state_dict: a mapping key or a sequence index.
PATH_ITEM = Union[str, int]
# Full path from the top of a state_dict down to one value.
OBJ_PATH = Tuple[PATH_ITEM, ...]
# Maps a flattened key to the object path it was derived from.
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
STATE_DICT_TYPE = Dict[str, Any]
# Mutable container addressed by a path item while rebuilding a nested state_dict.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]
  526. def _traverse_state_dict(
  527. state_dict: STATE_DICT_TYPE,
  528. visitor: Callable[[OBJ_PATH, Any], None],
  529. ) -> None:
  530. """
  531. Invoke ``visitor`` for each value recursively in ``state_dict``.
  532. Mapping, list, and tuple will be flattened and other value types are treated
  533. as the terminal values and will invoke ``visitor``.
  534. """
  535. def _traverse_obj(path: OBJ_PATH, value: Any) -> None:
  536. if isinstance(value, Mapping):
  537. for k, v in value.items():
  538. _traverse_obj(path + (str(k),), v)
  539. elif isinstance(value, (list, tuple)):
  540. for i, v in enumerate(value):
  541. _traverse_obj(path + (i,), v)
  542. else:
  543. visitor(path, value)
  544. for key, value in state_dict.items():
  545. _traverse_obj((str(key),), value)
  546. def _flatten_state_dict(
  547. state_dict: STATE_DICT_TYPE,
  548. ) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
  549. """
  550. Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.
  551. Use ``unflatten_state_dict`` to revert this process.
  552. Returns:
  553. A tuple with the flatten state_dict and a mapping from original to new state_dict.
  554. N.B. The new keys are derived from the object paths, joined by dot.
  555. For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.
  556. """
  557. flattened: STATE_DICT_TYPE = {}
  558. mappings: FLATTEN_MAPPING = {}
  559. def flat_copy(path: OBJ_PATH, value: Any) -> None:
  560. new_fqn = ".".join(map(str, path))
  561. if new_fqn in flattened:
  562. raise ValueError(f"duplicated flatten key {new_fqn}")
  563. flattened[new_fqn] = value
  564. mappings[new_fqn] = path
  565. _traverse_state_dict(state_dict, flat_copy)
  566. return flattened, mappings
  567. def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None:
  568. """Set ``value`` in ``root_dict`` along the ``path`` object path."""
  569. cur_container = cast(CONTAINER_TYPE, root_dict)
  570. def extend_list(lst: List[Any], idx: int) -> None:
  571. while len(lst) <= idx:
  572. lst.append(None)
  573. for i in range(1, len(path)):
  574. prev_key = path[i - 1]
  575. key = path[i]
  576. def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else []
  577. if isinstance(cur_container, Mapping):
  578. cur_container = cast(
  579. CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
  580. )
  581. else:
  582. extend_list(cur_container, prev_key)
  583. if cur_container[prev_key] is None:
  584. cur_container[prev_key] = def_val
  585. cur_container = cur_container[prev_key]
  586. key = path[-1]
  587. if type(key) == int:
  588. extend_list(cast(List[Any], cur_container), key)
  589. cur_container[key] = value
  590. def _unflatten_state_dict(
  591. state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
  592. ) -> STATE_DICT_TYPE:
  593. """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
  594. nested: STATE_DICT_TYPE = {}
  595. for key, value in state_dict.items():
  596. _set_element(nested, mapping[key], value)
  597. return nested