_init_utils.py

# mypy: allow-untyped-defs
import collections
import itertools
import os
import warnings
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _FSDPState,
    _get_module_fsdp_state,
    _is_fsdp_flattened,
    _named_parameters_with_duplicates,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    _FSDP_USE_FULL_PREC_IN_EVAL,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]

# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = [
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
]
NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
)


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.

@no_type_check
def _init_process_group_state(
    state: _FSDPState,
    process_group: ProcessGroupType,
    sharding_strategy: ShardingStrategy,
    policy: Optional[_Policy],
    device_mesh: Optional[DeviceMesh] = None,
) -> _FSDPState:
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Cannot pass both process_group and device_mesh at the "
            "same time. Please just pass only one of them."
        )
    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
    if is_hybrid_strategy:
        if process_group is None and policy is None and device_mesh is None:
            # Raise an error here: since this is manual wrapping with no
            # process group passed in, there is no way to ensure that all
            # wrapped FSDP instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} "
                "requires explicit specification of process group or device_mesh."
            )
        else:
            state = _init_process_group_state_for_hybrid_shard(
                state, process_group, device_mesh
            )
    else:
        if device_mesh:
            state._device_mesh = device_mesh
            state.process_group = device_mesh.get_group(mesh_dim=0)
        else:
            state.process_group = (
                process_group if process_group is not None else _get_default_group()
            )

    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    data_parallel_world_size = state.world_size
    if is_hybrid_strategy:
        data_parallel_world_size *= state._inter_node_pg.size()
    state._gradient_predivide_factor = (
        default_hooks.DefaultState._get_gradient_predivide_factor(
            data_parallel_world_size
        )
    )
    state._gradient_postdivide_factor = (
        data_parallel_world_size / state._gradient_predivide_factor
    )
    return state

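# Illustrative sketch (not part of this module): the two accepted forms of
# process-group specification handled above, shown as hypothetical FSDP
# constructor calls. The model and mesh shape below are placeholders.
#
#   # Option 1: a single process group (or None for the default group) with a
#   # non-hybrid strategy such as FULL_SHARD:
#   # fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
#
#   # Option 2: a 2D DeviceMesh for HYBRID_SHARD, where dim 0 is the
#   # inter-node (replication) dimension and dim 1 is the intra-node
#   # (sharding) dimension, matching the mesh handling above:
#   # mesh = init_device_mesh("cuda", (num_nodes, gpus_per_node))
#   # fsdp_model = FSDP(
#   #     model,
#   #     sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#   #     device_mesh=mesh,
#   # )
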
@no_type_check
def _init_process_group_state_for_hybrid_shard(
    state: _FSDPState,
    process_group: ProcessGroupType,
    device_mesh: DeviceMesh,
) -> _FSDPState:
    if device_mesh:
        if _is_valid_hybrid_shard_device_mesh(device_mesh):
            state._device_mesh = device_mesh
            # We currently only allow _inter_node_pg to be the outermost
            # dimension and the process_group (intra-node) to be the
            # innermost dimension.
            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
            state.process_group = device_mesh.get_group(mesh_dim=1)
        else:
            raise ValueError(
                f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}"
            )
    elif process_group is None:
        default_group = _get_default_group()
        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
            default_group, state._device_handle.device_count()
        )
        # Sharding happens within the intra-node group.
        state.process_group = intra_node_group
        # Save the inter-node group for the gradient all-reduce.
        state._inter_node_pg = inter_node_group
    else:
        # Check the type and assign state.process_group and state._inter_node_pg.
        if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in the intra-node group and the
            # inter-node group, as documented.
            state.process_group, state._inter_node_pg = process_group
        else:
            raise ValueError(
                "Expected process_group to be passed in as either None or "
                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
            )
    # Create state for the inter-node all-reduce.
    state._inter_node_state = _get_default_comm_hook_state(
        process_group=state._inter_node_pg,
    )
    return state

@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
    return (
        isinstance(process_group, tuple)
        and len(process_group) == 2
        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
    )


@no_type_check
def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2


@no_type_check
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
    """
    Return a process group across the current node.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return an intra-node subgroup across
    [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank.
    For example, rank 3 would get [0, 1, ..., 7].
    """
    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
    return intra_node_subgroup


@no_type_check
def _init_inter_node_process_group(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> dist.ProcessGroup:
    """
    Return an inter-node process group where each contained rank has the same local rank.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth
    depending on the process's rank. For example, rank 1 would get [1, 9], rank 5
    would get [5, 13].
    """
    # the inter-node pg that is returned
    inter_node_pg = None
    sharding_backend = dist.get_backend(global_process_group)
    world_size = dist.get_world_size(global_process_group)
    # Assuming fully homogeneous setup
    num_nodes = world_size // num_devices_per_node
    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
    for local_rank in range(num_devices_per_node):
        ranks_for_inter_group = [
            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
        ]
        # every rank always needs to call dist.new_group
        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
        if local_rank == my_local_rank:
            inter_node_pg = grp

    assert (
        inter_node_pg is not None
    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
    return inter_node_pg

def _init_intra_and_inter_node_groups(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
    """
    Initialize intra- and inter-node process groups and return the ones corresponding to this process's rank.

    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
    ``_HYBRID_SHARD_ZERO2`` in FSDP.
    This function assumes each node has an equal number of CUDA-enabled devices.

    Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra- and inter-node process group.
    """
    return (
        _init_intra_node_process_group(num_devices_per_node),
        _init_inter_node_process_group(global_process_group, num_devices_per_node),
    )

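# Illustrative sketch (not part of this module): on a hypothetical 2-node job
# with 8 GPUs per node, the helper above yields the groups used by hybrid
# sharding -- parameters are sharded within a node and replicated across nodes.
#
#   # intra_pg, inter_pg = _init_intra_and_inter_node_groups(
#   #     _get_default_group(), num_devices_per_node=8
#   # )
#   # e.g. global rank 11 -> intra_pg over ranks [8..15], inter_pg over [3, 11]
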
@no_type_check
def _init_ignored_module_states(
    state: _FSDPState,
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]],
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> _FSDPState:
    if ignored_modules is not None and ignored_states is not None:
        raise ValueError(
            "Cannot pass both ignored_modules and ignored_states at the "
            "same time. Please just pass ignored_states."
        )
    ignored_parameters = None
    passed_as_ignored_states = ignored_states is not None
    if passed_as_ignored_states:
        ignored_states_list = list(ignored_states)
        _check_ignored_states(ignored_states_list, True)
    else:
        ignored_states_list = []
        _check_ignored_states(
            list(ignored_modules) if ignored_modules is not None else [], False
        )
    if len(ignored_states_list) > 0:
        if isinstance(ignored_states_list[0], nn.Parameter):
            ignored_parameters = ignored_states_list
        else:
            ignored_modules = ignored_states_list
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
    state._ignored_params = _get_ignored_params(
        module,
        state._ignored_modules,
        ignored_parameters,
    )
    state._ignored_buffer_names = _get_ignored_buffer_names(
        module,
        state._ignored_modules,
    )
    # TODO: FSDP's contract for buffers is not well-defined. They are
    # implicitly ignored for most functionality since they are not sharded;
    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
    # precision). We should formalize this contract and decide if we need to
    # compute and store `_ignored_buffers`.
    return state

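# Illustrative sketch (not part of this module): the `ignored_modules` and
# `ignored_states` arguments processed above are mutually exclusive. The
# `frozen_block` attribute below is a hypothetical submodule.
#
#   # fsdp_model = FSDP(model, ignored_modules=[model.frozen_block])
#   # or, roughly equivalently in terms of parameters:
#   # fsdp_model = FSDP(model, ignored_states=list(model.frozen_block.parameters()))
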
def _check_ignored_states(
    ignored_states: List[Any], passed_as_ignored_states: bool
) -> None:
    """
    Check that the ignored states are uniformly parameters or uniformly modules.

    We may remove this check in the future if we permit mixing.
    """
    if len(ignored_states) == 0:
        return
    if passed_as_ignored_states:
        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
        if not all_params and not all_modules:
            # Sort for consistent ordering for unit test regex matching
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_states expects all nn.Parameter or all nn.Module list "
                f"elements but got types {sorted_types}"
            )
    else:
        if not all(isinstance(state, nn.Module) for state in ignored_states):
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_modules expects nn.Module list elements but got "
                f"types {sorted_types}"
            )

@no_type_check
def _init_device_handle(
    state: _FSDPState,
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
    """
    Determine the device handle used for initializing FSDP.

    If a device is specified by ``device_id``, then this returns the device
    handle corresponding to that device type. Otherwise, if the module is
    already on a non-CPU device, then the device type is that non-CPU device
    type. If the module is on CPU or meta, then the device type is the current
    CUDA device.

    This method should be called after the ignored parameters have been
    determined, as the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules with different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        determined_device = determined_device or torch.device(
            "cuda", torch.cuda.current_device()
        )
    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state

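# Illustrative sketch (not part of this module): the handle resolved above is
# keyed off the parameters' device type (or `device_id`), so a module already
# on an accelerator picks that accelerator's handle, while a CPU/meta module
# falls back to the current CUDA device. `MyModule` below is hypothetical.
#
#   # model = MyModule().to("cuda:0")
#   # fsdp_model = FSDP(model)  # handle resolved from the "cuda" params
#   # meta_model = MyModule().to("meta")
#   # fsdp_model = FSDP(meta_model, device_id=torch.cuda.current_device())
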
@no_type_check
def _init_buffer_state(
    state: _FSDPState,
    module: nn.Module,
) -> _FSDPState:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from clean fully-qualified buffer name (starting from
    # `module`) to its original dtype for restoring that dtype during model
    # checkpointing when buffer mixed precision is enabled. The names should
    # be clean since the casting happens in a `summon_full_params()` context.
    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
    for buffer_name, buffer in module.named_buffers():
        buffer_name = clean_tensor_name(buffer_name)
        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state

@no_type_check
def _init_core_state(
    state: _FSDPState,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the strategy to `NO_SHARD` for a world size of 1 since the two
    # are currently functionally equivalent. This may change if/when we
    # integrate FSDP with MoE.
    if state.world_size == 1:
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            warnings.warn(
                "FSDP is switching to use `NO_SHARD` instead of "
                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
                "the world size is 1."
            )
        sharding_strategy = ShardingStrategy.NO_SHARD
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        warnings.warn(
            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
            "please use `DistributedDataParallel` instead.",
            FutureWarning,
            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
            # level 3 is from the true caller
            stacklevel=3,
        )
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    if mixed_precision is not None:
        torch._C._log_api_usage_once(
            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
        )
    state._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = exec_order_utils._ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state._unshard_event = None
    # Mapping from fully sharded module to the handles it is responsible to
    # unshard and reshard (see [Note: Fully Sharded Module])
    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handle`
    _handle: FlatParamHandle = None
    state._handle = _handle
    params: List[FlatParameter] = []
    state.params = params
    return state

@no_type_check
def _init_runtime_state(
    state: _FSDPState,
) -> _FSDPState:
    _root_pre_forward_handles: List[RemovableHandle] = []
    state._root_pre_forward_handles = _root_pre_forward_handles
    _pre_forward_handles: List[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: List[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    state._sync_gradients = True
    state._comm_hook = None
    state._comm_hook_state = None
    # Used to prevent running the pre-backward hook multiple times
    return state


@no_type_check
def _init_prefetching_state(
    state: _FSDPState,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _FSDPState:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


@no_type_check
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
    # TODO: We need to add an additional check once we support FSDP + PiPPy.
    # This check is currently sufficient, since we only support FSDP + TP.
    if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
        state._fsdp_extension = DTensorExtensions(state._device_handle)
    else:
        # We need to explicitly set _fsdp_extension to None.
        # Otherwise, we will run into an infinite recursion when getting the attribute.
        state._fsdp_extension = None
    return state


@no_type_check
def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
    state._state_dict_type = StateDictType.FULL_STATE_DICT
    state_dict_config: StateDictConfig = FullStateDictConfig()
    state._optim_state_dict_config = FullOptimStateDictConfig()
    state._state_dict_config = state_dict_config
    unshard_params_ctx: Dict[nn.Module, Generator] = {}
    state._unshard_params_ctx = unshard_params_ctx
    return state

@no_type_check
def _init_param_handle_from_module(
    state: _FSDPState,
    fully_sharded_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _FSDPState:
    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
        fully_sharded_module, state._ignored_params, state._ignored_modules
    )
    # Materialize the module if needed
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        _materialize_with_param_init_fn(
            fully_sharded_module, param_init_fn, state._ignored_modules
        )
    elif is_meta_module:
        _materialize_meta_module(
            fully_sharded_module, device_id, state._ignored_modules
        )
    elif is_torchdistX_deferred_init:
        deferred_init.materialize_module(
            fully_sharded_module,
            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
            and submodule not in state._ignored_modules,
        )

    ignored_buffers = {
        buffer
        for ignored_module in state._ignored_modules
        for buffer in ignored_module.buffers()
    }

    _move_module_to_device(
        fully_sharded_module,
        state._ignored_params,
        ignored_buffers,
        device_from_device_id,
    )
    state.compute_device = _get_compute_device(
        fully_sharded_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
    )

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    if sync_module_states:
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            _sync_module_params_and_buffers(
                fully_sharded_module, managed_params, state._inter_node_pg
            )
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
    return state

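# Illustrative sketch (not part of this module): with `sync_module_states=True`,
# the broadcasts above make rank 0's parameters and buffers authoritative, so
# it is enough to load a checkpoint on rank 0 before wrapping. The checkpoint
# path and model below are hypothetical.
#
#   # if dist.get_rank() == 0:
#   #     model.load_state_dict(torch.load("checkpoint.pt"))
#   # fsdp_model = FSDP(
#   #     model,
#   #     device_id=torch.cuda.current_device(),
#   #     sync_module_states=True,
#   # )
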
@no_type_check
def _init_param_handle_from_params(
    state: _FSDPState,
    params: List[nn.Parameter],
    fully_sharded_module: nn.Module,
):
    if len(params) == 0:
        return
    handle = FlatParamHandle(
        params,
        fully_sharded_module,
        state.compute_device,
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
        state.process_group,
        state._use_orig_params,
        fsdp_extension=state._fsdp_extension,
    )
    handle.shard()
    assert not state._handle
    state.params.append(handle.flat_param)
    state._handle = handle
    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)

def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
    """
    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.

    Return the modules contained in their module subtrees as a :class:`set`.
    Nested FSDP instances are excluded, but their already-computed ignored
    modules are included.

    ``_ignored_modules`` represents the argument passed by the user to FSDP.
    """
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = (
            set(_ignored_modules) if _ignored_modules is not None else set()
        )
    except TypeError as e:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if _get_module_fsdp_state(module):
            # TODO: We may relax this by taking the FSDP instance's wrapped
            # module to provide more flexibility to the user.
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Treat modules that cannot compose with `fully_shard` as ignored modules,
    # meaning that their subtrees are ignored
    for module in root_module.modules():
        if not traversal_utils._composable(module):
            ignored_root_modules.add(module)
    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_modules")
            ignored_modules.update(optional_fsdp_state._ignored_modules)
    return ignored_modules

def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> Set[torch.nn.Parameter]:
    """
    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.

    :class:`FlatParameter` s are excluded from the result.
    """
    all_ignored_params: Set[torch.nn.Parameter] = set()

    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    }
    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Always include nested FSDP modules' ignored parameters
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_params")
            all_ignored_params.update(optional_fsdp_state._ignored_params)

    return all_ignored_params


def _get_ignored_buffer_names(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
) -> Set[str]:
    """Return the cleaned buffer FQNs in ``ignored_modules``."""
    all_ignored_buffer_names: Set[str] = set()

    buffers_in_ignored_modules = {
        buffer for m in ignored_modules for buffer in m.buffers()
    }
    all_ignored_buffer_names.update(
        {
            clean_tensor_name(buffer_name)
            for buffer_name, buffer in root_module.named_buffers()
            if buffer in buffers_in_ignored_modules
        }
    )

    # Always include nested FSDP modules' ignored buffer names
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)

    return all_ignored_buffer_names

def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    }

def _check_single_device_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> None:
    """
    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.

    Thus, after this method, the module must be either fully on the CPU or
    fully on a non-CPU device.
    """
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    # We allow module to be partially on CPU and partially on GPU if device_id is not
    # None, since the device_id arg will result in the CPU portion being moved to
    # GPU. This is useful in cases where part of the module may be parallelized
    # by another algorithm and may already be on GPU. We'd like to enforce device_id
    # to not be None, otherwise we'd flatten parameters in a mixed module which is
    # not supported.
    if len(devices) == 2 and torch.device("cpu") in devices:
        if device_id is None:
            raise RuntimeError(
                "To support a module with both CPU and GPU params, "
                "please pass in device_id argument."
            )
    elif len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )

def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
) -> Optional[torch.device]:
    """
    Return a ``torch.device`` for the specified ``device_id``.

    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device == torch.device("cuda"):
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {torch.cuda.current_device()}. "
            "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device("cuda", torch.cuda.current_device())
    return device

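# Illustrative sketch (not part of this module): `device_id` may be an int, a
# fully specified device, or a bare "cuda" device; the helper above normalizes
# the bare form to the current device (with a warning). Assuming the current
# device is cuda:1, the hypothetical calls below resolve to the same device.
#
#   # torch.cuda.set_device(1)
#   # FSDP(model, device_id=1)
#   # FSDP(model, device_id=torch.device("cuda", 1))
#   # FSDP(model, device_id=torch.device("cuda"))  # warns, resolves to cuda:1
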
def _need_to_materialize_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_modules: Set[nn.Module],
) -> Tuple[bool, bool]:
    """
    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

    At most one of the returned bools can be ``True``. If either is ``True``,
    then ``module`` needs to be materialized.
    """
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    # TODO: We need to establish a contract for FSDP and buffers. For now, we
    # skip checking for meta buffers from ignored modules. We should consider
    # refactoring the initialization holistically to avoid so many traversals.
    for submodule in module.modules():
        if submodule in ignored_modules:
            continue
        for buf in submodule.buffers(recurse=False):
            is_meta_module |= buf.is_meta
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    return is_meta_module, is_torchdistX_deferred_init

def _materialize_with_param_init_fn(
    root_module: nn.Module,
    param_init_fn: Callable[[nn.Module], None],
    ignored_modules: Set[nn.Module],
) -> None:
    if not callable(param_init_fn):
        raise ValueError(
            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
        )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    for module in modules_to_materialize:
        param_init_fn(module)

def _materialize_meta_module(
    root_module: nn.Module,
    device_from_device_id: Optional[torch.device],
    ignored_modules: Set[nn.Module],
):
    # Run default meta device initialization
    materialization_device = device_from_device_id or torch.device(
        torch.cuda.current_device()
    )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    try:
        # Assume that each module's `reset_parameters()` only initializes its
        # own parameters and not those of its children
        with torch.no_grad():
            for module in modules_to_materialize:
                # As a contract to the user, only call `reset_parameters()` if
                # the module has directly managed parameters/buffers
                module_state_iter = itertools.chain(
                    module.parameters(recurse=False), module.buffers(recurse=False)
                )
                has_module_states = len(list(module_state_iter)) > 0
                if has_module_states:
                    module.to_empty(device=materialization_device, recurse=False)
                    module.reset_parameters()  # type: ignore[operator]
    except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of "
            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
        )
        raise e

def _get_modules_to_materialize(
    root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
    # Run BFS to collect the modules to materialize via `reset_parameters()`,
    # stopping at any module with FSDP already applied or at ignored modules.
    modules_to_materialize: List[nn.Module] = []
    queue = collections.deque([root_module])
    visited_modules: Set[nn.Module] = {root_module}
    while queue:
        module = queue.popleft()
        modules_to_materialize.append(module)
        for child_module in module.children():
            if (
                child_module not in visited_modules
                and _get_module_fsdp_state(child_module) is None
                and child_module not in ignored_modules
            ):
                visited_modules.add(child_module)
                queue.append(child_module)
    return modules_to_materialize

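# Illustrative sketch (not part of this module): materialization of a
# meta-device module visits modules in the BFS order computed above. The
# `MyTransformer` model and the `param_init_fn` below are hypothetical
# user-provided pieces.
#
#   # with torch.device("meta"):
#   #     model = MyTransformer()
#   # def param_init_fn(module: nn.Module) -> None:
#   #     # Called once per module in BFS order; initializes only that module.
#   #     module.to_empty(device=torch.device("cuda"), recurse=False)
#   # fsdp_model = FSDP(model, param_init_fn=param_init_fn)
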
def _move_module_to_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_buffers: Set[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move ``module`` depending on ``device_from_device_id`` and its current device.

    This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
      ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
      ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: Deque[nn.Module] = collections.deque()
        queue.append(module)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            # NOTE: We include a check to only move parameters/buffers that are
            # on CPU device. If they are on a CUDA device different from the
            # one specified by `device_id`, then this does NOT move them. This
            # is so that we can raise an error in `_get_compute_device()`.
            params.extend(
                param
                for param in curr_module.parameters(recurse=False)
                if param.device == cpu_device
            )
            buffers.extend(
                buffer
                for buffer in curr_module.buffers(recurse=False)
                if buffer.device == cpu_device
            )
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        params_to_move = [p for p in params if p not in ignored_params]
        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
        return
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device == cpu_device:
        _warn_cpu_init()

def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move states to the specified device.

    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
        _warn_cpu_init()


def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )

def _get_compute_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
) -> torch.device:
    """
    Determine and return this FSDP instance's compute device.

    If a device is specified by ``device_id``, then this returns that device.
    Otherwise, if the module is already on a non-CPU device, then the compute
    device is that non-CPU device. If the module is on CPU, then the compute
    device is the current device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be the meta device. For now, the compute device
    is always a CUDA GPU device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type != "cpu":
        compute_device = param.device  # Determined by model param placement
    else:
        if device_from_device_id is not None and device_from_device_id.type != "cuda":
            compute_device = device_from_device_id  # Determined by custom backend
        else:
            compute_device = torch.device("cuda", torch.cuda.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device

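# Illustrative sketch (not part of this module): the compute device chosen
# above must agree with `device_id` when both are available; a mismatch
# raises. `MyModule` below is a hypothetical model.
#
#   # model = MyModule().to("cuda:0")
#   # FSDP(model)                                      # compute device = cuda:0
#   # FSDP(model, device_id=torch.device("cuda", 1))   # raises ValueError
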
# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: List[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
    been set.
    """
    module_states: List[torch.Tensor] = []
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            detached_buffer = buffer.detach()
            if is_traceable_wrapper_subclass(detached_buffer):
                # NOTE: Here we assume no nested subclasses, i.e. at most one
                # level of subclass in both the model's buffers and params
                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
                module_states.extend(inner_buffers)
            else:
                module_states.append(detached_buffer)

    for param in params:
        detached_param = param.detach()
        if is_traceable_wrapper_subclass(detached_param):
            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
            inner_params = [getattr(detached_param, attr) for attr in attrs]
            module_states.extend(inner_params)
        else:
            module_states.append(detached_param)

    _check_module_states_for_sync_module_states(module_states)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _check_module_states_for_sync_module_states(
    module_states: List[torch.Tensor],
) -> None:
    if module_states and any(
        tensor.device == torch.device("cpu") for tensor in module_states
    ):
        raise ValueError(
            "The module has CPU parameters or buffers when `sync_module_states=True`, "
            "which requires them to be on GPU. Please specify the `device_id` argument "
            "or move the module to GPU before passing it to FSDP."
        )

def _get_orig_params(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Return an iterator over the original parameters in ``module``.

    The iterator does not return the parameters in ``ignored_params``, any
    ``FlatParameter`` s (which may be present due to nested FSDP wrapping), or
    any original parameters already flattened (only relevant when
    ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass

def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Check that original parameters in ``fsdp_module`` have been flattened.

    The flattened parameters are made invisible to ``named_parameters()`` for
    the module hierarchy rooted at ``fsdp_module``. This should be called as a
    sanity check after flattening the wrapped module's parameters.
    """
    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )

def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)
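

# Illustrative sketch (not part of this module): the default communication
# hook chosen above is all-reduce for `NO_SHARD` (DDP-like replication) and
# reduce-scatter for every sharded strategy.
#
#   # hook = _get_default_comm_hook(ShardingStrategy.FULL_SHARD)
#   # assert hook is default_hooks.reduce_scatter_hook
#   # hook_state = _get_default_comm_hook_state(_get_default_group())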