_traverse.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from typing import (
  3. Callable,
  4. cast,
  5. Collection,
  6. List,
  7. Mapping,
  8. MutableMapping,
  9. Optional,
  10. Tuple,
  11. TypeVar,
  12. Union,
  13. )
  14. import torch
  15. from torch.distributed._shard.sharded_tensor.api import ShardedTensor
  16. from torch.distributed._tensor import DTensor
  17. from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
# One step of an object path: ``traverse_state_dict`` emits ``str`` steps for
# mapping keys and ``int`` steps for list positions.
PATH_ITEM = Union[str, int]
# A full path from the root of a state dict down to a value.
OBJ_PATH = Tuple[PATH_ITEM, ...]
# Type variable tying ``get_element``'s ``default_value`` to its return type.
T = TypeVar("T")
# Values stored in a state dict are typed as plain ``object``.
STATE_DICT_ITEM = object
# Static type used (via ``cast``) for the container currently being walked in
# ``set_element`` / ``get_element``.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
# Public names exported by this module.
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
  24. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
  25. return isinstance(value, torch.Tensor)
  26. # TODO: update docstring for traverse.py
  27. def traverse_state_dict(
  28. state_dict: STATE_DICT_TYPE,
  29. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  30. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  31. ) -> None:
  32. """
  33. Invoke ``visitor`` for each value recursively in ``state_dict``.
  34. Mapping, list, and tuple will be flattened and other value types are treated
  35. as the terminal values and will invoke ``visitor``.
  36. Mapping is treated as non terminal node and will be flattened.
  37. List and tuple, on the other hand, will not be flattened unless containing other
  38. mapping containers or tensors.
  39. """
  40. # a value is terminal if it has no other containers values inside it
  41. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  42. values: Collection[STATE_DICT_ITEM]
  43. if isinstance(value, Mapping):
  44. return False
  45. elif isinstance(value, list):
  46. values = value
  47. else:
  48. return True
  49. for entry in values:
  50. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  51. return False
  52. if keep_traversing is not None and keep_traversing(entry):
  53. return False
  54. return True
  55. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  56. if isinstance(value, Mapping):
  57. for k, v in value.items():
  58. _traverse_obj(path + (str(k),), v)
  59. elif _is_terminal(value):
  60. visitor(path, value)
  61. elif isinstance(value, (list, tuple)):
  62. for i, v in enumerate(value):
  63. _traverse_obj(path + (i,), v)
  64. for key, value in state_dict.items():
  65. _traverse_obj((str(key),), value)
  66. def set_element(
  67. root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
  68. ) -> None:
  69. """Set ``value`` in ``root_dict`` along the ``path`` object path."""
  70. cur_container = cast(CONTAINER_TYPE, root_dict)
  71. def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
  72. while len(lst) <= idx:
  73. lst.append(None)
  74. for i in range(1, len(path)):
  75. prev_key = path[i - 1]
  76. key = path[i]
  77. def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
  78. if isinstance(cur_container, Mapping):
  79. cur_container = cast(
  80. CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
  81. )
  82. else:
  83. extend_list(cur_container, prev_key)
  84. if cur_container[prev_key] is None:
  85. cur_container[prev_key] = def_val
  86. cur_container = cur_container[prev_key]
  87. key = path[-1]
  88. if type(key) == int:
  89. extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
  90. cur_container[key] = value
  91. def get_element(
  92. root_dict: STATE_DICT_TYPE,
  93. path: OBJ_PATH,
  94. default_value: Optional[T] = None,
  95. ) -> Optional[T]:
  96. """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found."""
  97. cur_value = cast(CONTAINER_TYPE, root_dict)
  98. for part in path:
  99. if type(part) is int:
  100. if not isinstance(cur_value, list) or len(cur_value) < part:
  101. return default_value
  102. elif not isinstance(cur_value, Mapping) or part not in cur_value:
  103. return default_value
  104. cur_value = cast(CONTAINER_TYPE, cur_value[part])
  105. return cast(Optional[T], cur_value)
  106. def _print_nested(
  107. value: STATE_DICT_ITEM,
  108. prefix: str = "",
  109. print_fun: Callable[[str], None] = print,
  110. ) -> None:
  111. if type(value) is ShardedTensor:
  112. print_fun(f"{prefix} ShardedTensor size: {value.size()}")
  113. for shard in value.local_shards():
  114. _print_nested(
  115. shard.tensor,
  116. f"{shard.metadata.shard_offsets} ",
  117. print_fun=print_fun,
  118. )
  119. elif type(value) is (DTensor):
  120. print_fun(f"{prefix} DistributedTensor size: {value.size()}")
  121. # TODO: add local offset for _local_tensor in print_nested.
  122. _print_nested(
  123. value._local_tensor,
  124. print_fun=print_fun,
  125. )
  126. elif isinstance(value, torch.Tensor):
  127. print_fun(f"{prefix} Tensor size: {value.size()}")
  128. else:
  129. print_fun(f"{prefix} Type: {type(value)}")
  130. def print_tensor(
  131. path: OBJ_PATH,
  132. value: STATE_DICT_ITEM,
  133. print_fun: Callable[[str], None] = print,
  134. ) -> None:
  135. """
  136. Use this callback with traverse_state_dict to print its content.
  137. By default the content is printed using the builtin ``print`` but this can
  138. be change by passing a different ``print_fun` callable.
  139. """
  140. _print_nested(value, prefix=str(path), print_fun=print_fun)