# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


logger = logging.getLogger(__name__)


class SimpleProfiler:
    """
    A minimal class-level profiler that accumulates wall-clock time per
    profile type across ``profile()`` invocations.
    """

    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling the same type "
            "concurrently."
        )
        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        # Profiling cannot be combined with the DETAIL distributed debug
        # level, as the extra logging would make the measured times very
        # inaccurate.
        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
            logger.info("%s %s", msg, cls.results)
        cls.reset()
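
# A minimal usage sketch for SimpleProfiler (an illustration, not part of this
# module's API surface): it assumes a default process group has already been
# initialized, e.g. via dist.init_process_group(), and that the distributed
# debug level is INFO so that dump_and_reset() logs on rank 0.
# `output_tensor` and `input_tensor` are placeholder tensors.
#
#     with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
#         dist.all_gather_into_tensor(output_tensor, input_tensor)
#     SimpleProfiler.dump_and_reset("FSDP all-gather timing: ")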


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    This is used for the composable fully_shard() code path. It returns:

    1. Sharded module tree info: each line represents a submodule name and
       contains the submodule's FQN and its class name. If the submodule is
       sharded by ``fully_shard``, the submodule name is suffixed with
       ' FULLY SHARDED'. Each increased tree level adds 4 spaces of
       indentation before the printed name. The printed sharded module tree
       info for a toy model looks like this:

        [CompositeModel] FULLY SHARDED
            l1[Linear]
            u1[UnitModule] FULLY SHARDED
                u1.l1[Linear]
                u1.seq[Sequential]
                    u1.seq.0[ReLU]
                    u1.seq.1[Linear]
                    u1.seq.2[ReLU]
                u1.l2[Linear]
            u2[UnitModule] FULLY SHARDED
                u2.l1[Linear]
                u2.seq[Sequential]
                    u2.seq.0[ReLU]
                    u2.seq.1[Linear]
                    u2.seq.2[ReLU]
                u2.l2[Linear]
            l2[Linear]

    2. A dict mapping from the concatenated module FQN and class name to a
       list of its managed original parameters' FQNs. For the toy sharded
       model above, the dict looks like this:

       {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
        'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight',
                           'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
        'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight',
                           'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']}

    All FQNs are prefixed starting from ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed
                                 to composable ``fully_shard()``).
    """

    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        # Strip the trailing "." from the prefix to form the module's FQN.
        trimmed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimmed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handle = state._fully_sharded_module_to_handle.get(module, None)

        # Mark the module as fully sharded in the printed tree if it has a
        # flat-parameter handle.
        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        if handle:
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use a list so the recursive module_fn can mutate the value in place.
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )
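
# A minimal usage sketch (illustrative only): `MyModel` is a placeholder
# nn.Module, and `fully_shard` refers to the composable API, assumed here to
# be importable from torch.distributed._composable.
#
#     from torch.distributed._composable import fully_shard
#
#     model = MyModel()
#     fully_shard(model)
#     tree_info, name_to_fqns = _get_sharded_module_tree_with_module_name_to_fqns(
#         model
#     )
#     print(tree_info)      # indented sharded module tree
#     print(name_to_fqns)   # 'FQN[Class]' -> managed original parameter FQNs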