_debug_utils.py

# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


logger = logging.getLogger(__name__)


class SimpleProfiler:
    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling multiple instances at "
            "the same time. "
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        # This cannot be combined with the DETAIL distributed log level,
        # as the profiling results would be very inaccurate.
        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
            logger.info("%s %s", msg, cls.results)
        cls.reset()
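

# Illustrative usage sketch: times a workload under nested profile types and then
# dumps the aggregated timings. ``work`` is a placeholder callable standing in for
# the code being measured, and ``dump_and_reset`` assumes an initialized process
# group (it checks the rank and the distributed debug level).
def _simple_profiler_usage_example(work) -> None:
    SimpleProfiler.reset()
    with SimpleProfiler.profile(SimpleProfiler.Type.ALL):
        # Nested regions must use distinct types; re-entering an active type
        # trips the assertion in ``profile``.
        with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
            work()
    # Logs only on rank 0 and only when the distributed debug level is INFO.
    SimpleProfiler.dump_and_reset("SimpleProfiler results: ")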


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    This is used for the composable fully_shard() code path. It returns:

    1. Sharded module tree info: each line represents a submodule name that contains
       the submodule's FQN and its submodule class name. If the submodule is sharded
       by `fully_shard`, the submodule name is suffixed with ' FULLY SHARDED'. Each
       increase in tree level adds 4 spaces of indentation before the printed name.
       The printed sharded module tree for a toy model looks like this:
            [CompositeModel] FULLY SHARDED
                l1[Linear]
                u1[UnitModule] FULLY SHARDED
                    u1.l1[Linear]
                    u1.seq[Sequential]
                        u1.seq.0[ReLU]
                        u1.seq.1[Linear]
                        u1.seq.2[ReLU]
                    u1.l2[Linear]
                u2[UnitModule] FULLY SHARDED
                    u2.l1[Linear]
                    u2.seq[Sequential]
                        u2.seq.0[ReLU]
                        u2.seq.1[Linear]
                        u2.seq.2[ReLU]
                    u2.l2[Linear]
                l2[Linear]
    2. A dict mapping from the concatenated module FQN and class name to a list of
       its managed original parameters' FQNs. For the toy sharded model above, the
       dict looks like this:
            {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
             'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
             'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
            }

    All FQNs are prefixed starting from ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed to
            composable `fully_shard()`).
    """

    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        trimed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handle = state._fully_sharded_module_to_handle.get(module, None)

        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        if handle:
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use a list so its value can be mutated in place by the recursive calls
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}

    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )
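

# Illustrative sketch of how the returned tree string and FQN mapping might be
# consumed, e.g. for rank-0 debug logging. This helper is hypothetical (not part
# of the FSDP API), and ``model`` is assumed to already be wrapped with the
# composable fully_shard() API.
def _log_sharded_module_tree(model: torch.nn.Module) -> None:
    tree_str, module_name_to_fqns = _get_sharded_module_tree_with_module_name_to_fqns(
        model
    )
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info("Sharded module tree:\n%s", tree_str)
        for module_name, fqns in module_name_to_fqns.items():
            logger.info("%s manages: %s", module_name, fqns)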