_nested_dict.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from typing import Dict, Tuple
  3. from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
  4. from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
  5. """
  6. TODO:
  7. Need to add ability to handle tuple, OrderedDict, NamedTuple.
  8. Update mappings from dict to a class.
  9. Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.
  10. """
  11. FLATTEN_MAPPING = Dict[str, OBJ_PATH]
  12. # TODO: Update Docstring for nested_dict.py
  13. def flatten_state_dict(
  14. state_dict: STATE_DICT_TYPE,
  15. ) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
  16. """
  17. Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.
  18. Use ``unflatten_state_dict`` to revert this process.
  19. Returns:
  20. A tuple with the flatten state_dict and a mapping from original to new state_dict.
  21. N.B. The new keys are derived from the object paths, joined by dot.
  22. For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.
  23. """
  24. flattened: STATE_DICT_TYPE = {}
  25. mappings: FLATTEN_MAPPING = {}
  26. def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  27. new_fqn = ".".join(map(str, path))
  28. if new_fqn in flattened:
  29. raise ValueError(f"duplicated flatten key {new_fqn}")
  30. flattened[new_fqn] = value
  31. mappings[new_fqn] = path
  32. traverse_state_dict(state_dict, flat_copy)
  33. return flattened, mappings
  34. def unflatten_state_dict(
  35. state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
  36. ) -> STATE_DICT_TYPE:
  37. """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
  38. nested: STATE_DICT_TYPE = {}
  39. for key, value in state_dict.items():
  40. set_element(nested, mapping[key], value)
  41. return nested