# api.py
  1. # mypy: allow-untyped-defs
  2. from abc import ABC, abstractmethod
  3. from contextlib import contextmanager, nullcontext
  4. from copy import copy
  5. from dataclasses import dataclass
  6. from functools import partial, wraps
  7. from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
  8. import torch
  9. import torch.distributed as dist
  10. # We need to import _functional_collectives to trigger op registration
  11. import torch.distributed._functional_collectives
  12. import torch.nn as nn
  13. import torch.utils._pytree as pytree
  14. from functorch import make_fx
  15. from torch import fx
  16. from torch._decomp.decompositions import native_layer_norm_backward
  17. from torch._subclasses.fake_tensor import FakeTensorMode
  18. from torch.distributed._spmd.data_parallel import gradients_tagging
  19. from torch.distributed._spmd.parallel_mode import (
  20. DataParallel,
  21. DTensorExpandMode,
  22. ParallelMode,
  23. )
  24. from torch.distributed._tensor import Placement
  25. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
  26. from torch.nn.utils import stateless
  27. from torch.nn.utils._named_member_accessor import NamedMemberAccessor
  28. class Override(ABC):
  29. r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
  30. This is useful when any part of the model is not traceable or if you prefer
  31. to not trace it due to any reason. More specifically, users can implement
  32. :meth:`torch.distributed._spmd.Override.replacement` to replace an original
  33. submodule with the return new submodule. The new submodule contains
  34. operations that users preferred to be traced, which simply be a dummy
  35. placeholder operator. After tracing, users can implement
  36. :meth:`torch.distributed._spmd.Override.transform` to transform the traced
  37. graph, where the dummy placeholder operator serves as an anchor to insert
  38. new sub-graphs.
  39. """
  40. @abstractmethod
  41. def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
  42. r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
  43. argument in the model.
  44. This helps if ``orig_submodule`` is not traceable or should not be traced.
  45. Args:
  46. fqn (str): fully quantified name of the submodule.
  47. orig_submodule (class:`nn.Module`): original submodule instance to replace.
  48. Returns:
  49. A new :class:`nn.Module` instance to replace the original one.
  50. """
  51. pass
  52. @abstractmethod
  53. def transform(
  54. self,
  55. gm: fx.GraphModule,
  56. flat_state: List[torch.Tensor],
  57. ) -> fx.GraphModule:
  58. r"""
  59. Given a DTensor-expanded graph and sharding schema for every node,
  60. conduct additional transformation for the sub-graph from the :class:`nn.Module`
  61. returned by :meth:`torch.distributed._spmd.Override.replacement` if
  62. necessary.
  63. Args:
  64. gm (:class:`fx.Graph`): a DTensor-expanded graph.
  65. flat_state (List[str, :class:`Tensor`]): a reference to the list of
  66. flattened state. The elements in ``flat_state`` map to the first
  67. ``len(flat_state)`` placeholders in the graph. The transformation
  68. can add state to or remove state from ``flat_state`` as long as
  69. it keeps ``flat_state`` and the placeholders consistent.
  70. Returns:
  71. The :class:`fx.Graph` after transformation.
  72. """
  73. pass
  74. class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
  75. # pyre-ignore[3]
  76. def process_inputs(self, *args: Any) -> Any:
  77. return args
  78. # pyre-ignore[2, 3]
  79. def gen_fn_def(self, free_vars, maybe_return_annotation):
  80. return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation)
  81. def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  82. """Move the responsibility of flattening the input arguments from the graph module to the caller.
  83. Example:
  84. output = gm(my_struct)
  85. gm = gm(to_caller_flattened_graph_module)
  86. output = gm(*pytree.flatten(my_struct)[0])
  87. """
  88. # pyre-ignore[16]
  89. gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
  90. pytree_info=_PyTreeInfo(
  91. # pyre-ignore[6]
  92. orig_args=None, # type: ignore[arg-type]
  93. # pyre-ignore[6]
  94. in_spec=None, # type: ignore[arg-type]
  95. # pyre-ignore[16]
  96. out_spec=gm._graph._codegen.pytree_info.out_spec,
  97. )
  98. )
  99. gm.recompile()
  100. return gm
  101. # Use a dtensor expand mode for now to preserve the old behavior
  102. # and avoid breaking existing code
  103. dtensor_expand_mode = DTensorExpandMode()
  104. def _override_placements(t: torch.Tensor, placements: List[Placement]):
  105. global dtensor_expand_mode
  106. dtensor_expand_mode._placements_override[id(t)] = placements
  107. @contextmanager
  108. def _rematerialize_optimizer(
  109. opt: torch.optim.Optimizer,
  110. named_states: Dict[str, Any],
  111. params: Dict[str, nn.Parameter],
  112. ):
  113. assert opt is not None
  114. # update opt.state with proxy tensors
  115. orig_states = copy(opt.state)
  116. for n in named_states:
  117. # opt.state's key type is string, but optimizer uses Parameter as keys
  118. opt.state[params[n]] = named_states[n] # type: ignore[index]
  119. # FIXME: support multiple parameter groups
  120. param_group = opt.param_groups[0]
  121. orig_params = param_group["params"]
  122. param_group["params"] = params.values()
  123. try:
  124. yield
  125. finally:
  126. param_group["params"] = orig_params
  127. opt.state = orig_states
# Shorthand for the ATen operator namespace; used by the decomposition helpers
# and the decomposition table in this module.
aten = torch.ops.aten  # pyre-ignore
  129. @contextmanager
  130. def _enable_compile():
  131. # The return value of torch._utils.is_compiling changes optimizer behavior.
  132. # We need that function to return True to include optimizer in the graph.
  133. # See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
  134. def f_true():
  135. return True
  136. orig_is_compiling_code = torch._utils.is_compiling.__code__
  137. torch._utils.is_compiling.__code__ = f_true.__code__
  138. try:
  139. yield
  140. finally:
  141. torch._utils.is_compiling.__code__ = orig_is_compiling_code
  142. def _foreach_add_decomp(self, other, alpha=1):
  143. self_updated = aten._foreach_add.List(self, other, alpha=alpha)
  144. for s, s_u in zip(self, self_updated):
  145. s.copy_(s_u)
  146. def _foreach_unaop_decomp(op, self):
  147. self_updated = op(self)
  148. for s, s_u in zip(self, self_updated):
  149. s.copy_(s_u)
  150. def _foreach_binop_list_decomp(op, self, other):
  151. self_updated = op(self, other)
  152. for s, s_u in zip(self, self_updated):
  153. s.copy_(s_u)
  154. def _foreach_binop_scalar_decomp(op, self, scalar=1):
  155. self_updated = op(self, scalar)
  156. for s, s_u in zip(self, self_updated):
  157. s.copy_(s_u)
  158. def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
  159. self_updated = op(self, tensor1, tensor2, scalar)
  160. for s, s_u in zip(self, self_updated):
  161. s.copy_(s_u)
  162. def _fused_adam_decomp(
  163. self,
  164. grads,
  165. exp_avgs,
  166. exp_avg_sqs,
  167. max_exp_avg_sqs,
  168. state_steps,
  169. *,
  170. lr=1,
  171. beta1=1,
  172. beta2=1,
  173. weight_decay=1,
  174. eps=1,
  175. amsgrad=True,
  176. maximize=True,
  177. grad_scale=None,
  178. found_inf=None,
  179. ):
  180. orig_tuple = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
  181. updated_tuple = aten._fused_adam.default(
  182. self,
  183. grads,
  184. exp_avgs,
  185. exp_avg_sqs,
  186. max_exp_avg_sqs,
  187. state_steps,
  188. lr=lr,
  189. beta1=beta1,
  190. beta2=beta2,
  191. weight_decay=weight_decay,
  192. eps=eps,
  193. amsgrad=amsgrad,
  194. maximize=maximize,
  195. grad_scale=grad_scale,
  196. found_inf=found_inf,
  197. )
  198. for idx, (orig, updated) in enumerate(zip(orig_tuple, updated_tuple)):
  199. if idx == 1:
  200. # skip gradient copying as we don't need to copy gradients back
  201. continue
  202. for o, u in zip(orig, updated):
  203. o.copy_(u)
# Decomposition table handed to ``make_fx`` in ``_compile``. Each in-place
# foreach / fused-optimizer overload is mapped to a helper that calls the
# matching out-of-place overload and copies the results back into the inputs.
SPMD_DECOMP_TABLE = {
    aten._foreach_add_.List: _foreach_add_decomp,
    aten._foreach_add_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_add.Scalar
    ),
    aten._foreach_addcdiv_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcdiv.Scalar
    ),
    aten._foreach_addcmul_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcmul.Scalar
    ),
    aten._foreach_div_.List: partial(
        _foreach_binop_list_decomp, aten._foreach_div.List
    ),
    aten._foreach_mul_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_mul.Scalar
    ),
    aten._foreach_div_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_div.Scalar
    ),
    aten._foreach_neg_.default: partial(
        _foreach_unaop_decomp, aten._foreach_neg.default
    ),
    aten._foreach_reciprocal_.default: partial(
        _foreach_unaop_decomp, aten._foreach_reciprocal.default
    ),
    aten._foreach_sqrt_.default: partial(
        _foreach_unaop_decomp, aten._foreach_sqrt.default
    ),
    aten._foreach_sub_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_sub.Scalar
    ),
    aten._fused_adam_.default: _fused_adam_decomp,
    # Uses the decomposition imported from torch._decomp.decompositions.
    aten.native_layer_norm_backward.default: native_layer_norm_backward,
}
# Operator overloads that ``_dedup_collectives`` treats as candidates for
# de-duplication when they appear more than once with the same arguments.
DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
    torch.ops._c10d_functional.all_reduce.default,
    torch.ops._c10d_functional.wait_tensor.default,
}
  243. def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:
  244. args_to_node: Dict[Tuple[Any, ...], fx.Node] = {}
  245. for node in gm.graph.nodes:
  246. # replace all args with the results from the first unique comm op
  247. args = pytree.arg_tree_leaves(*node.args)
  248. if node.target in DEDUP_TARGETS:
  249. args_key = (node.target, *args)
  250. unique_node = args_to_node.get(args_key, None)
  251. if unique_node is None:
  252. # first time seeing this combination, remember it
  253. args_to_node[args_key] = node
  254. else:
  255. # the current node is a duplicate, replace it
  256. node.replace_all_uses_with(unique_node)
  257. gm.graph.erase_node(node)
  258. gm.recompile()
  259. return gm
@dataclass
class _CompiledResult:
    """Artifacts produced by ``_compile`` and cached by the ``compile`` wrapper."""

    # Graph module returned at the end of ``_compile``.
    gm: fx.GraphModule
    # The single nn.Module found among the compiled callable's arguments.
    mod: nn.Module
    # The single optimizer found among the arguments, or None if there is none.
    opt: Optional[torch.optim.Optimizer]
    # Flattened leaves of [parameters and buffers, optimizer states]; the
    # wrapper passes these to ``gm`` ahead of the flattened call arguments.
    flat_state: List[torch.Tensor]
  266. def _compile(
  267. func: Callable,
  268. module_override: Optional[List[Override]],
  269. parallel_mode: ParallelMode,
  270. *args: Any,
  271. **kwargs: Any,
  272. ) -> _CompiledResult:
  273. # 1. Extract nn.Module and Optimizer from args and kwargs
  274. # FIXME(@mrshenli): support multiple nn.Module instances
  275. # FIXME(@mrshenli): support multiple Optiimzer instances
  276. # FIXME(@mrshenli): need to broadcast model to sync parameters
  277. mod, opt = None, None
  278. for arg in pytree.arg_tree_leaves(*args, **kwargs):
  279. if isinstance(arg, nn.Module):
  280. assert mod is None, "Only support single nn.Module for now"
  281. mod = arg
  282. if isinstance(arg, torch.optim.Optimizer):
  283. assert opt is None, "Only support single Optimizer for now"
  284. opt = arg
  285. assert mod is not None, "Couldn't find nn.Module instances from the arguments."
  286. # 2. Override target submodules (e.g., MoE) with dummy replacements
  287. if module_override:
  288. accessor = NamedMemberAccessor(mod)
  289. def swap(fqn_prefix: str, module: torch.nn.Module) -> None:
  290. for override in module_override: # type: ignore[union-attr]
  291. for name, child in module.named_children():
  292. if len(name) == 0:
  293. continue
  294. fqn = fqn_prefix + "." + name if fqn_prefix != "" else name
  295. new_child = override.replacement(fqn, child)
  296. if id(new_child) == id(child):
  297. swap(fqn, new_child)
  298. else:
  299. accessor.swap_submodule(fqn, new_child)
  300. swap("", mod)
  301. # 3. Trace statelss version of the train_step
  302. params = dict(mod.named_parameters(remove_duplicate=False))
  303. buffers = dict(mod.named_buffers(remove_duplicate=False))
  304. named_states = {}
  305. if opt is not None:
  306. # Pass named_states instead of opt.state to stateless_func, because
  307. # the later uses nn.Parameter as key. During tracing, we need to
  308. # make sure optimizers can find the states using proxy tensors.
  309. for n, p in params.items():
  310. if p in opt.state:
  311. # opt.state's key type is string, but optimizer uses
  312. # Parameter as keys
  313. named_states[n] = opt.state[p] # type: ignore[index]
  314. is_data_parallel_mode = isinstance(parallel_mode, DataParallel)
  315. # Lift states and parameters as function arguments so that make_fx
  316. # can trace operations applied to them.
  317. def stateless_func(func, params, buffers, named_states, args, kwargs):
  318. with stateless._reparametrize_module(
  319. mod, {**params, **buffers}
  320. ), _rematerialize_optimizer(
  321. opt, named_states, params
  322. ) if opt else nullcontext():
  323. # For DataParallel mode, install hooks first to tag the gradients
  324. with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
  325. ret = func(*args, **kwargs)
  326. # make sure updated parameters are returned
  327. return ret, list(mod.parameters()), list(named_states.values()) # type: ignore[union-attr]
  328. # FIXME: Using symbolic tracing to work around in DTensor expand mode.
  329. # Otherwise it hits shape mismatch error, as we use local inputs to
  330. # trace local graph and use DTensor to expand operators, where
  331. # DTensor's shape is the global shape.
  332. tracing_mode = "fake" if is_data_parallel_mode else "symbolic"
  333. if is_data_parallel_mode:
  334. fake_mode = FakeTensorMode()
  335. data_parallel_mode = cast(DataParallel, parallel_mode)
  336. def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
  337. # since compilation happens in the first iteration and we
  338. # receives mini-batch input, convert them to full batch
  339. # fake tensor input first for data parallel sharding
  340. # propagations
  341. fake_arg = fake_mode.from_tensor(arg)
  342. arg_dims = [1] * arg.ndim
  343. # expand the tensor to full batch size on its batch dim
  344. arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
  345. return fake_arg.repeat(arg_dims)
  346. args = pytree.tree_map_only(
  347. torch.Tensor,
  348. _get_full_batch_arg,
  349. args,
  350. )
  351. kwargs = pytree.tree_map_only(
  352. torch.Tensor,
  353. _get_full_batch_arg,
  354. kwargs,
  355. )
  356. with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
  357. # FIXME(@mrshenli): functionalization does not work for our use
  358. # case yet. Use explicit decompositions for foreach ops.
  359. # Remove this when the following issue is addressed.
  360. # Issue: https://github.com/pytorch/pytorch/issues/97852
  361. gm = make_fx(
  362. partial(stateless_func, func),
  363. tracing_mode=tracing_mode,
  364. decomposition_table=SPMD_DECOMP_TABLE,
  365. _allow_non_fake_inputs=False,
  366. )(params, buffers, named_states, args, kwargs)
  367. params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
  368. **params,
  369. **buffers,
  370. }
  371. # 4. parallel mode to expand a single device graph to a distributed graph
  372. gm = parallel_mode.partition(
  373. gm,
  374. mod,
  375. opt,
  376. params_and_buffers,
  377. named_states,
  378. args,
  379. kwargs,
  380. )
  381. # 5. Move the responsibility of flattening the input arguments from the
  382. # graph module to the caller. This serves two purposes:
  383. # - Transformations that add/remove state need to manipulate a state
  384. # container that maintains the state tensors in the same order as they
  385. # appear in graph placeholders.
  386. # - Reduced runtime cost. The state container is only flattened once upfront.
  387. flat_state = pytree.tree_leaves([params_and_buffers, named_states])
  388. gm = _to_caller_flattened_graph_module(gm)
  389. # 6. dedup comm operators.
  390. # The duplication could come from DTensor args and kwargs redistribution.
  391. # Suppose one operator produces a Partial gradient tensor and model
  392. # parameters are replicated. In this case, every optimizer operation using
  393. # that Partial gradient tensor would trigger an allreduce. This is becuase
  394. # DTensor only has local information on individual tensor/operator, which is
  395. # not sufficient to detect duplications in the graph. This situation can
  396. # also happen when inserting FSDP allgather if a parameter is used multiple
  397. # times in the forward method.
  398. # TODO(@mrshenli): @yifuwang has a suggestion of conducting expansion and
  399. # dedup at tracer-level to avoid multiple graph passes.
  400. gm = _dedup_collectives(gm)
  401. # 7. Replace previously inserted dummy ones with real graphs.
  402. if module_override:
  403. for override in module_override:
  404. gm = override.transform(gm, flat_state)
  405. return _CompiledResult(gm, mod, opt, flat_state)
# Key under which the ``compile`` wrapper caches its compiled result in its
# own ``__dict__``.
# Note that the Python convention of __dict__ requires the key to be str.
# TODO: ensure the key is unique.
COMPILED_OBJECT_KEY = "_compiled_obj"
  409. def compile(
  410. module_override: Optional[List[Override]] = None,
  411. gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
  412. parallel_mode: Optional[ParallelMode] = None,
  413. ):
  414. r"""Compile and optimize a callable, which can be a train step within a training loop.
  415. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
  416. instances from the input arguments and trace operations applied to their
  417. parameters and states.
  418. Args:
  419. module_override (Optional[List[Override]]): a list of Override instances
  420. that will be applied to the module in order. The :class:`Override`
  421. objects provide :class:`nn.Module` replacements during tracing and a
  422. graph transformation function after tracing. (Default: ``None``)
  423. gm_transformation (Optional[Callable[fx.GraphModule, fx.GraphModule]]):
  424. a callback that will be called after the original callable is
  425. compiled and distributed (usually after the first iteration) to
  426. transform the compiled GraphModule into a new optimized one.
  427. parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
  428. that specifies how to parallelize the callable. Each ParallelMode
  429. would have its own strategy to partition the model and the captured
  430. graph (Default: ``None``)
  431. """
  432. def inner(func: Callable):
  433. @wraps(func)
  434. def wrapper(*args, **kwargs):
  435. last_train_step = kwargs.pop("last_train_step", False) if kwargs else False
  436. first_iter = False
  437. # Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
  438. # ``wrapper`` is the one that users will get.
  439. compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
  440. if compiled_obj is None:
  441. first_iter = True
  442. global dtensor_expand_mode
  443. mode: ParallelMode = (
  444. dtensor_expand_mode if parallel_mode is None else parallel_mode
  445. )
  446. compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
  447. wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj
  448. flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
  449. *args, **kwargs
  450. )
  451. with torch.no_grad():
  452. # N.B.: we don't need autograd as backward has already been
  453. # captured in the graph.
  454. if first_iter and gm_transformation:
  455. # TODO: SPMD should provid a default and configurable
  456. # transformation.
  457. compiled_obj.gm = gm_transformation(compiled_obj.gm)
  458. if not last_train_step:
  459. output = compiled_obj.gm(*flat_inps)[0]
  460. else:
  461. # This is the last train step. Call IterGraphModule.forward()
  462. # with the `last_iter` argument and catch the exception in
  463. # case the compiled_obj is not wrapped with IterGraphModule.
  464. try:
  465. output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[
  466. 0
  467. ]
  468. except TypeError as e:
  469. if "last_iter" not in str(e):
  470. raise e
  471. output = compiled_obj.gm(*flat_inps)[0]
  472. return output
  473. return wrapper
  474. return inner