# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

import contextlib
import os
import re
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Optional,
    Tuple,
    Type,
    Union,
)
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_param_group import (
    FSDPParamGroup,
    RegisterPostBackwardFunction,
)
from torch.distributed._tensor import distribute_tensor, DTensor, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import TrainingState
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    run_subtests,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms
from torch.utils._triton import has_triton

class FSDPInitMode(Enum):
    # No FSDP wrapping
    NO_FSDP = auto()
    # FSDP recursive wrapping
    RECURSIVE = auto()
    # TODO: FSDP non-recursive wrapping
    # NONRECURSIVE = auto()


class CUDAInitMode(Enum):
    # Move model to CUDA before passing to the FSDP constructor
    CUDA_BEFORE = auto()
    # Move model to CUDA after passing to the FSDP constructor
    CUDA_AFTER = auto()
    # Keep on CPU
    CUDA_NEVER = auto()

class FSDPTestModel(nn.Module, ABC):
    """This defines the interface expected from all models used commonly for
    FSDP unit tests."""

    @abstractmethod
    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
        """Returns an input for the model as a tuple."""
        ...

    @abstractmethod
    def get_loss(self, input, output) -> torch.Tensor:
        """Returns the loss given the input and output."""
        ...

    @abstractmethod
    def run_backward(self, loss) -> None:
        """Runs the backward pass (e.g. including ``loss.backward()``)."""
        ...

    @staticmethod
    @abstractmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        """Initializes an instance of this model."""
        ...

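# NOTE: Illustrative sketch (not part of the original file): a minimal concrete
# `FSDPTestModel` showing how the interface above is typically satisfied. The
# class name `_ExampleLinearModel` and its dimensions are hypothetical.
class _ExampleLinearModel(FSDPTestModel):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(x)

    def get_input(self, device):
        # Deterministic-enough toy input on the requested device
        return (torch.randn(4, 8, device=device),)

    def get_loss(self, input, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        return _ExampleLinearModel()
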
def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    All-gathers module states across ranks and calls ``assert_fn`` on each pair
    of corresponding states from rank 0 and a nonzero rank. For example, if
    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
    states are equal across ranks.
    """
    # Include names for debugging convenience
    named_module_states = [
        (param_name, param.detach().cpu())
        for param_name, param in model.named_parameters()
    ]
    named_module_states += [
        (buffer_name, buffer.detach().cpu())
        for buffer_name, buffer in model.named_buffers()
    ]
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    rank0_states = olist[0]
    assert rank0_states is not None  # mypy
    for state in olist[1:]:
        assert state is not None  # mypy
        for (_, p1), (_, p2) in zip(rank0_states, state):
            assert_fn(p1, p2)

def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
    summon_full=True,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
    with ctx:
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()

def _get_state_dict(model, cpu_offload=False, half=False):
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()
    return model.state_dict()


def subtest_name(test_name_mapping, *args):
    return "_".join(
        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
    )

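# NOTE: Illustrative sketch (not part of the original file): `subtest_name`
# joins the mapped names of the subtest parameters, using "none" for `None`
# values. The mapping and parameter values below are hypothetical.
def _example_subtest_name() -> str:
    test_name_mapping = {
        str(CPUOffload(offload_params=True)): "offload_true",
        "0.1": "lr_0_1",
    }
    # Returns "offload_true_none_lr_0_1"
    return subtest_name(test_name_mapping, CPUOffload(offload_params=True), None, 0.1)
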
def _broadcast_state_dict(rank, state_dict):
    # For non-FSDP roots, some parts of the model state on rank 0 may
    # not be on CPU, so we move everything to CPU to avoid issues like:
    # https://github.com/pytorch/pytorch/issues/77113.
    for param_name, param in state_dict.items():
        if param.device != torch.device("cpu"):
            state_dict[param_name] = param.cpu()
    olist = [state_dict if rank == 0 else None]
    dist.broadcast_object_list(olist)
    state_dict = olist[0]
    # Ensure that the state is on CUDA
    for param_name in state_dict.keys():
        state_dict[param_name] = state_dict[param_name].cuda()
    return state_dict


def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns the full unsharded parameters of ``model``. Any FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate to
            ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))

def _maybe_cuda(model: nn.Module, move_to_cuda: bool):
    return model.cuda() if move_to_cuda else model


def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size

    def allreduce(self, *args, **kwargs):
        dist_wait = mock.Mock()

        def get_future():
            future: torch.futures.Future = torch.futures.Future()
            future.set_result(1)
            return future

        dist_wait.get_future = get_future
        return dist_wait

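# NOTE: Illustrative sketch (not part of the original file): `DummyProcessGroup`
# can stand in where only `rank()`/`size()` are consulted, e.g. constructing a
# test model without initializing a real process group. The helper name is
# hypothetical; `TransformerWithSharedParams` is defined below and only reads
# the group's rank and size in its constructor.
def _example_dummy_pg_model() -> nn.Module:
    dummy_pg = DummyProcessGroup(rank=0, size=2)
    return TransformerWithSharedParams(
        dummy_pg, CUDAInitMode.CUDA_NEVER, add_bn=False, deterministic=True
    )
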
class TransformerWithSharedParams(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            self = self.cuda()
        if deterministic:
            self.eval()

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        add_bn: bool = True,
    ) -> Union[nn.Module, FSDP]:
        """
        Initializes a :class:`TransformerWithSharedParams` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps with
                top-level FSDP. By default, the top-level FSDP uses the
                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
                different auto wrap policy may be specified via
                ``fsdp_kwargs``.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            add_bn (bool): Whether to include batch norm in the model.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            if isinstance(group, tuple):
                pg = group[0]
            else:
                pg = group
            return TransformerWithSharedParams(
                pg, cuda_init_mode, add_bn, deterministic
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Default to the `ModuleWrapPolicy`
            if "auto_wrap_policy" not in fsdp_kwargs:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
            else:
                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")

            if (
                "sharding_strategy" in fsdp_kwargs
                and fsdp_kwargs["sharding_strategy"]
                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
                and not isinstance(group, tuple)
            ):
                fsdp_pg = None
            else:
                fsdp_pg = group

            if isinstance(group, tuple):
                tformer_pg = group[0]
            else:
                tformer_pg = group

            m = TransformerWithSharedParams(
                tformer_pg, cuda_init_mode, add_bn, deterministic
            )
            fsdp_model = FSDP(
                m,
                fsdp_pg,
                auto_wrap_policy=auto_wrap_policy,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

    def get_ignored_modules(self):
        return [self.transformer]

class NestedWrappedModule(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
            _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
        )

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ) -> nn.Module:
        """
        Initializes a :class:`NestedWrappedModule` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP but not the top-level module. The model may
                later be wrapped with a top-level FSDP external to this method
                if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return NestedWrappedModule(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = NestedWrappedModule(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
        policy.
        """
        model = super(
            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
        ).init(
            group=group,
            fsdp_init_mode=FSDPInitMode.NO_FSDP,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
        )
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            fsdp_kwargs = fsdp_kwargs or {}
            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model

class NonUniformReqGradNWM(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super(NestedWrappedModule, self).__init__()
        # This `__init__` only differs from `NestedWrappedModule.__init__` in that
        # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential`
        # container. This arrangement results in all elements of the last two parameters
        # residing on a single rank. Freezing all parameters except those two allows us
        # to verify that `ShardedGradScaler` accommodates situations where some ranks
        # have no (non-zero sized) parameter shards.
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_cuda(nn.Linear(16, 4), move_to_cuda),
                    _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
                ),
            ),
        )

    @staticmethod
    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
        for n, p in model.named_parameters():
            if not re.match(req_grad_mask, n):
                p.requires_grad_(False)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential`
        container to enable the desired non-uniform ``requires_grad``
        ``use_orig_params=True`` tests. For both ``RECURSIVE`` and ``NO_FSDP``
        init modes, freezes all parameters except the last two to validate
        ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in
        FSDP ``use_orig_params=True`` mode.
        """
        # The parameters that should remain unfrozen are in `module.2.1`. The regex
        # pattern below matches the relevant parameter names both with and without
        # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present.
        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            ddp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
            return ddp_model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            if fsdp_kwargs is None:
                fsdp_kwargs = {}
            fsdp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

class ModuleWithDelay(FSDPTestModel):
    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
    after computing the loss and/or before the gradient reduction."""

    def __init__(
        self,
        module: nn.Module,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
    ):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        return self.module.get_input(device)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
        return loss

    def run_backward(self, loss):
        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                torch.cuda._sleep(
                    int(self.delay_before_reduction_ms * get_cycles_per_ms())
                )
            return orig_reduce_scatter(*args, **kwargs)

        with mock.patch(
            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
        ):
            self.module.run_backward(loss)

    @staticmethod
    def init(
        module_class: Type[FSDPTestModel],
        *model_args: Any,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
        **model_kwargs: Any,
    ):
        """
        Args:
            module_class (Type[FSDPTestModel]): Wrapped module class to which
                to add delays.
            model_args: Positional arguments forwarded to the ``module_class``
                ``init()``.
            delay_after_loss_ms (int): Delay after computing the loss/before
                the optimizer step (in ms).
            delay_before_reduction_ms (int): Delay before reduce-scattering
                gradients (in ms).
            model_kwargs: Keyword arguments forwarded to the ``module_class``
                ``init()``.
        """
        return ModuleWithDelay(
            module_class.init(*model_args, **model_kwargs),
            delay_after_loss_ms,
            delay_before_reduction_ms,
        )


class NestedWrappedModuleWithDelay(ModuleWithDelay):
    @staticmethod
    def init(  # type: ignore[override]
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_after_loss_ms: int = 0,
        delay_before_reduction_ms: int = 0,
    ):
        return ModuleWithDelay.init(
            NestedWrappedModule,
            group=group,
            fsdp_init_mode=fsdp_init_mode,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
            delay_after_loss_ms=delay_after_loss_ms,
            delay_before_reduction_ms=delay_before_reduction_ms,
        )

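# NOTE: Illustrative sketch (not part of the original file): constructing a
# delayed nested-wrapped model inside a test. It assumes the default process
# group has already been initialized (e.g. inside an `FSDPTest` method); the
# helper name and delay values are hypothetical.
def _example_init_delayed_model() -> nn.Module:
    return NestedWrappedModuleWithDelay.init(
        dist.distributed_c10d._get_default_group(),
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        delay_after_loss_ms=250,
        delay_before_reduction_ms=250,
    )
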
class DummyDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

class MixtureOfExperts(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)
        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
        )

    def forward(self, x):
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FSDP):
                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard

                def _delayed_reshard(*args, **kwargs):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_reshard(*args, **kwargs)

                # This patch covers any `import torch..._reshard` uses.
                with mock.patch(
                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
                ):
                    return self.module(x)
        return self.module(x)

    def run_backward(self, loss):
        loss.backward()
        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    if p.grad is not None:
                        p.grad.div_(self.world_size)
                        torch.distributed.all_reduce(p.grad, group=self.group)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_before_free_ms: int = 0,
    ):
        """
        Initializes a :class:`MixtureOfExperts` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP, including the expert and shared layers, but
                not the top-level module. The model may later be wrapped with a
                top-level FSDP external to this method if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            delay_before_free_ms (int): Delay before resharding expert
                parameters in the forward pass (in ms).
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return MixtureOfExperts(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = MixtureOfExperts(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        device: Optional[torch.device] = None,
        *,
        bias: bool = True,
        with_buffer: bool = False,
        dim_multiplier: int = 4,
    ):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
        if with_buffer:
            self.register_buffer("buffer", torch.randn((dim,), device=device))
        else:
            self.buffer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.in_proj(x)
        z = F.relu(z)
        z = self.out_proj(z)
        z = F.relu(z)
        if self.buffer is not None:
            z = z + self.buffer
        return z

    def reset_parameters(self):
        if self.buffer is not None:
            torch.nn.init.normal_(self.buffer)


class MLPStack(nn.Sequential):
    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
        modules: List[nn.Module] = [
            # Use multiplier of 3 to exercise uneven case
            MLP(mlp_dim, dim_multiplier=3),
            MLP(mlp_dim),
            MLP(mlp_dim, dim_multiplier=3),
        ]
        if with_seq_parallel:
            modules.append(nn.LayerNorm(mlp_dim, bias=False))
        super().__init__(*modules)
        self.with_seq_parallel = with_seq_parallel

    def parallelize(
        self,
        tp_mesh: DeviceMesh,
        dp_mesh: DeviceMesh,
        use_activation_checkpointing: bool,
        **fsdp_kwargs,
    ) -> "MLPStack":
        parallelize_plan = {
            # Pass `use_local_output=False` to keep as DTensor to preserve
            # uneven activation dims
            "0.in_proj": ColwiseParallel(use_local_output=False),
            "0.out_proj": RowwiseParallel(use_local_output=False),
            "1.in_proj": ColwiseParallel(use_local_output=False),
            "1.out_proj": RowwiseParallel(use_local_output=False),
            "2.in_proj": ColwiseParallel(use_local_output=False),
            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
            if self.with_seq_parallel
            else RowwiseParallel(),
        }
        if self.with_seq_parallel:
            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
        for module in self:
            if isinstance(module, nn.LayerNorm):
                continue
            if use_activation_checkpointing:
                checkpoint(module)
            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
        return self

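# NOTE: Illustrative sketch (not part of the original file): composing TP and
# FSDP on an `MLPStack` via a 2D device mesh. It assumes 4 GPUs and an
# initialized NCCL process group; the mesh shape, dim names, and MLP dimension
# are illustrative only.
def _example_parallelize_mlp_stack() -> "MLPStack":
    from torch.distributed.device_mesh import init_device_mesh

    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    model = MLPStack(mlp_dim=16).cuda()
    # Tensor parallel across the "tp" dim, FSDP (fully_shard) across "dp"
    return model.parallelize(
        mesh_2d["tp"],
        mesh_2d["dp"],
        use_activation_checkpointing=False,
    )
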
class DoubleLinear(nn.Module):
    """
    This can be used for returning multiple outputs from a module
    (``use_second_linear=True``) or for having an unused module (``False``).
    """

    def __init__(self, dim: int, use_second_linear: bool = True):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.use_second_linear = use_second_linear

    def forward(
        self, x: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        if self.use_second_linear:
            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
        return self.relu(self.lin1(x))

# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
# when using multi-threaded process group), then we want:
# (1) a barrier immediately after reading the original value to ensure that all
# threads see the same original value
# (2) a barrier immediately before restoring the original value to ensure that
# all threads use the patched value inside the context
@contextlib.contextmanager
def patch_all_gather(new_all_gather_into_tensor: Callable):
    orig_all_gather = dist.all_gather_into_tensor
    dist.barrier()
    dist.all_gather_into_tensor = new_all_gather_into_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.all_gather_into_tensor = orig_all_gather

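# NOTE: Illustrative sketch (not part of the original file): using
# `patch_all_gather` to count the all-gather calls issued during one forward
# pass. `model` is assumed to already be FSDP-sharded and `inp` to come from
# the surrounding test; the helper name is hypothetical.
def _example_count_all_gathers(model: nn.Module, inp: torch.Tensor) -> int:
    orig_all_gather = dist.all_gather_into_tensor
    count = 0

    def counting_all_gather(*args, **kwargs):
        nonlocal count
        count += 1
        return orig_all_gather(*args, **kwargs)

    with patch_all_gather(counting_all_gather):
        model(inp)
    return count
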
@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.barrier()
    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.reduce_scatter_tensor = orig_reduce_scatter


@contextlib.contextmanager
def patch_all_reduce(new_all_reduce: Callable):
    orig_all_reduce = dist.all_reduce
    dist.barrier()
    dist.all_reduce = new_all_reduce
    try:
        yield
    finally:
        dist.barrier()
        dist.all_reduce = orig_all_reduce


@no_type_check
@contextlib.contextmanager
def patch_unshard(new_unshard: Callable):
    orig_unshard = FSDPParamGroup.unshard
    dist.barrier()
    FSDPParamGroup.unshard = new_unshard
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.unshard = orig_unshard


@no_type_check
@contextlib.contextmanager
def patch_post_backward(new_post_backward: Callable):
    orig_post_backward = FSDPParamGroup.post_backward
    dist.barrier()
    FSDPParamGroup.post_backward = new_post_backward
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.post_backward = orig_post_backward


@no_type_check
@contextlib.contextmanager
def patch_register_post_backward_hook_backward(new_backward: Callable):
    orig_backward = RegisterPostBackwardFunction.backward
    dist.barrier()
    RegisterPostBackwardFunction.backward = new_backward
    try:
        yield
    finally:
        dist.barrier()
        RegisterPostBackwardFunction.backward = orig_backward

def reduce_scatter_with_assert(
    cls,
    orig_reduce_scatter: Callable,
    assert_fn: Callable,  # `assert_fn(output: Tensor)`
    *args: Any,
    **kwargs: Any,
):
    if len(args) > 0:
        output = args[0]
    elif "output" in kwargs:
        output = kwargs["output"]
    else:
        raise AssertionError(
            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
        )
    assert_fn(output)
    return orig_reduce_scatter(*args, **kwargs)

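# NOTE: Illustrative sketch (not part of the original file): binding
# `reduce_scatter_with_assert` with `functools.partial` so that every
# reduce-scatter issued under `patch_reduce_scatter` runs `assert_fn` on its
# output. `cls` is assumed to be the unit test instance, `model` to be
# FSDP-sharded, and the dtype check is only an example assertion.
def _example_assert_on_reduce_scatter(cls, model: nn.Module, inp: torch.Tensor):
    import functools

    def assert_fn(output: torch.Tensor):
        cls.assertEqual(output.dtype, torch.float32)

    orig_reduce_scatter = dist.reduce_scatter_tensor
    patched = functools.partial(
        reduce_scatter_with_assert, cls, orig_reduce_scatter, assert_fn
    )
    with patch_reduce_scatter(patched):
        model(inp).sum().backward()
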
def check_sharded_parity(
    cls,  # unit test class
    replicated_module: nn.Module,
    sharded_module: nn.Module,
    prefixes_to_ignore: Tuple[str, ...] = (),
):
    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
        replicated_module.named_parameters(), sharded_module.named_parameters()
    ):
        clean_sharded_name = sharded_name
        for prefix in prefixes_to_ignore:
            clean_sharded_name = clean_sharded_name.replace(prefix, "")
        cls.assertEqual(replicated_name, clean_sharded_name)
        cls.assertIsInstance(sharded_param, DTensor)
        assert isinstance(sharded_param, DTensor)  # mypy
        mesh, placements = sharded_param.device_mesh, sharded_param.placements
        if tuple(placements) == (Shard(0), Shard(0)):
            raise AssertionError(
                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
                "so we cannot check for equality using it"
            )
        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
        if replicated_param.grad is None:
            cls.assertIsNone(sharded_param.grad)
            continue
        cls.assertIsNotNone(sharded_param.grad)
        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
        cls.assertIsInstance(sharded_param.grad, DTensor)
        assert isinstance(sharded_param.grad, DTensor)  # mypy
        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())

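# NOTE: Illustrative sketch (not part of the original file): a typical driver
# for `check_sharded_parity`, comparing a replicated reference model against
# its `fully_shard`-ed copy after one backward. `self` is the test case; the
# dims and the averaging of the reference gradients mirror common usage but
# are assumptions here, and exact parity relies on deterministic reduction.
def _example_sharded_parity_step(self):
    torch.manual_seed(42)
    replicated = MLP(8, device=torch.device("cuda"))
    sharded = deepcopy(replicated)
    fully_shard(sharded)
    inp = torch.randn(4, 8, device="cuda")
    replicated(inp).sum().backward()
    sharded(inp).sum().backward()
    # Average the replicated gradients to match FSDP's averaged gradients
    for p in replicated.parameters():
        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
    check_sharded_parity(self, replicated, sharded)
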
class FSDPTestMultiThread(MultiThreadedTestCase):
    @property
    def world_size(self):
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    def perThreadSetUp(self):
        torch._dynamo.reset()

    def perThreadTearDown(self):
        torch._dynamo.reset()

class FSDPTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
        # which can cause unit test flakiness:
        # https://github.com/pytorch/pytorch/issues/90848
        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
        self._spawn_processes()

    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 8) if torch.cuda.is_available() else 4

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @property
    def init_method(self):
        return f"{FILE_SCHEMA}{self.file_name}"

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        print(f"dist init r={self.rank}, world={self.world_size}")
        # Specify the gloo backend to make `init_process_group()` succeed;
        # the actual tests are skipped if there are not enough GPUs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=backend,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
            raise
        device_ids = None
        if torch.cuda.is_available() and torch.cuda.device_count():
            device_id = self.rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            device_ids = [device_id]
        # Execute barrier prior to running test to ensure that every process
        # has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        dist.barrier(device_ids=device_ids)
        torch._dynamo.reset()
        self.run_test(test_name, pipe)
        torch._dynamo.reset()
        dist.barrier(device_ids=device_ids)
        dist.destroy_process_group()

    def _train_for_several_steps(
        self,
        model: nn.Module,
        num_steps: int,
        autocast: bool,
        lr: float = 0.01,
        fsdp_cpu_offload: Optional[CPUOffload] = None,
        save_model: bool = False,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
    ):
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
        model_device = next(model.parameters()).device
        if sharded_grad_scaler_kwargs is None:
            sharded_grad_scaler_kwargs = {}
        sharded_grad_scaler = ShardedGradScaler(
            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
        )
        # Use SGD with momentum instead of Adam since Adam is scale invariant,
        # which makes it a poor fit for these parity tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs are always on CUDA, regardless of CPU offloading or
                # the model device
                input = model.module.get_input(torch.device("cuda"))
                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
                    if isinstance(input, torch.Tensor):
                        input = input.half()
                    else:
                        input = tuple(x.half() for x in input)
                output = model(*input)
                # Post-forward, if CPU offloading, model params should be on CPU
                if (
                    cpu_offload_params
                    and isinstance(model, FSDP)
                    # If not resharding after forward, the parameters are still
                    # exposed as unsharded views into the GPU flat parameter
                    and model.sharding_strategy
                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    for p in model.parameters():
                        # Params should always be on CPU
                        self.assertEqual(p.device, torch.device("cpu"))
                loss = model.module.get_loss(input, output).to(model_device)
            loss = sharded_grad_scaler.scale(loss)
            if not mixed_precision and not use_pure_fp16:
                assert (
                    loss.dtype == torch.float32
                ), "loss data type should be float32, as the original \
                    parameter data type is float32."
            else:
                if use_pure_fp16:
                    self.assertEqual(loss.dtype, torch.float16)
                # FSDP loss is fp16, DDP AMP loss is fp32
                elif isinstance(model, FSDP):
                    assert mixed_precision is not None  # mypy
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            # Post-backward, if CPU offloading, model params should be on CPU
            if cpu_offload_params and isinstance(model, FSDP):
                for p in model.parameters():
                    # Params should always be on CPU
                    self.assertEqual(p.device, torch.device("cpu"))
            # Unscale the gradients and step
            sharded_grad_scaler.step(optim)
            # Update the scale factor
            sharded_grad_scaler.update()
            # If `save_model`, simulate a save + load
            if save_model:
                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                # Zero the params; if saving/loading the state dict did not
                # work properly, this would break the parity test with DDP
                _zero_model(model)
                model.load_state_dict(state_dict)
        if isinstance(model, FSDP):
            model._assert_state(TrainingState.IDLE)
        return loss.detach()  # type: ignore[possibly-undefined]

    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        forward_prefetch: bool = False,
        use_orig_params: bool = False,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        init_kwargs: Optional[Dict[str, Any]] = None,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert (
            fsdp_init_mode != FSDPInitMode.NO_FSDP
        ), "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update(
            {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "sharding_strategy": sharding_strategy,
                "mixed_precision": mixed_precision,
                "forward_prefetch": forward_prefetch,
                "use_orig_params": use_orig_params,
            }
        )
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = (
            offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        )
        expects_cpu_device = (
            offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        )
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (
            self.assertRaisesRegex(
                RuntimeError,
                "An FSDP-managed module with parameter CPU offloading enabled "
                "has parameters on cuda",
            )
            if expects_device_error
            else nullcontext()
        )
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        # Do not check dtype since the reference DDP loss may not be the same
        # dtype as the FSDP loss in the case of mixed precision
        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        # TODO: Disable checking the parameters for pure FP16 due to floating
        # point inaccuracy. Note that this means that the backward pass is not
        # checked: https://github.com/pytorch/pytorch/issues/90784
        if mixed_precision is None and not use_pure_fp16:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )

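    # NOTE: Illustrative sketch (not part of the original file): how a unit
    # test in this suite typically drives `_test_fsdp_parity`, sweeping subtest
    # configs via `run_subtests`. The method name and the swept values are
    # hypothetical.
    def _example_test_nested_module_parity(self):
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=False),
                    CPUOffload(offload_params=True),
                ]
            },
            self._test_fsdp_parity,
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
        )
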
def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
    def fully_shard_with_compiled_compute(*args, **kwargs):
        torch.distributed._composable.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
        if compile_compute_on_module is None or isinstance(
            args[0], compile_compute_on_module
        ):
            args[0].compile()

    class FullyShardMode(Enum):
        EAGER = auto()
        COMPILED_COMPUTE = auto()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            original_fully_shard = torch.distributed._composable.fsdp.fully_shard
            for mode in FullyShardMode:
                if mode != FullyShardMode.EAGER and not has_triton():
                    warnings.warn("Inductor on GPU needs Triton and recent GPU arch")
                    continue
                # Barrier to ensure all threads read the same original values
                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
                original_compile_threads = torch._inductor.config.compile_threads
                torch.distributed.barrier()

                if mode == FullyShardMode.EAGER:
                    fully_shard_patch = original_fully_shard
                elif mode == FullyShardMode.COMPILED_COMPUTE:
                    torch._dynamo.config.skip_fsdp_hooks = True
                    torch._inductor.config.compile_threads = 1
                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
                else:
                    raise NotImplementedError(
                        f"Need to implement FullyShardMode={mode}"
                    )

                # fully_shard is imported as a global
                # through `from ... import fully_shard`
                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
                func(*args, **kwargs)
                # Let other threads finish using the patched function before
                # this thread restores the originals
                torch.distributed.barrier()
                func.__globals__[original_fully_shard.__name__] = original_fully_shard
                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
                torch._inductor.config.compile_threads = original_compile_threads

        return wrapper

    return decorator

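# NOTE: Illustrative sketch (not part of the original file): applying the
# `test_compiled_fsdp` decorator so the decorated test runs once per
# `FullyShardMode` (eager, then compiled compute on each `MLP` submodule).
# The class, method, dims, and input are hypothetical; it assumes CUDA and an
# initialized process group provided by `FSDPTest`.
class _ExampleCompiledFSDPTest(FSDPTest):
    @test_compiled_fsdp(compile_compute_on_module=MLP)
    def _example_test_compiled_mlp(self):
        model = MLP(8, device=torch.device("cuda"))
        # The decorator swaps this module-level `fully_shard` per mode
        fully_shard(model)
        model(torch.randn(4, 8, device="cuda")).sum().backward()
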
class SkipModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.lin(x)


class NestedLinear(nn.Module):
    def __init__(self, fsdp_wrap):
        super().__init__()
        if fsdp_wrap:
            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
        else:
            self.nested_linear = nn.Linear(10, 10, bias=False).cuda()

    def forward(self, x):
        return self.nested_linear(x)


class SkipModel(nn.Module):
    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_skip(x)
        x = self.nested_linear(x)
        return x