graph_optimization.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. # mypy: allow-untyped-defs
  2. # Owner(s): ["oncall: distributed"]
  3. import collections
  4. import itertools
  5. import logging
  6. import operator
  7. import tempfile
  8. import time
  9. from dataclasses import dataclass, field
  10. from functools import wraps
  11. from typing import (
  12. Any,
  13. Callable,
  14. cast,
  15. DefaultDict,
  16. Dict,
  17. Iterable,
  18. List,
  19. Optional,
  20. Set,
  21. Tuple,
  22. Union,
  23. )
  24. import torch
  25. import torch.fx as fx
  26. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  27. from torch.distributed._spmd.graph_utils import (
  28. CommType,
  29. dump_graphs_to_files,
  30. find_node,
  31. get_output,
  32. OP,
  33. )
  34. from torch.distributed._spmd.iter_graph_module import IterGraphModule
  35. from torch.fx.passes.shape_prop import TensorMetadata
  36. from torch.utils import _pytree as pytree
  37. from torch.utils._pytree import tree_flatten, tree_unflatten
  38. logger: logging.Logger = logging.getLogger("graph_optimization")
  39. aten = torch.ops.aten
  40. fake_tensor_mode = FakeTensorMode()
  41. _optimized_func: Set[str] = set()
  42. # The key is the target pass and the value is the prerequisites of the pass.
  43. _prerequisite_sets: DefaultDict[str, Set[str]] = collections.defaultdict(set)
  44. # The key is the target pass and the value is the passes that must applied before
  45. # the key.
  46. _apply_before_sets: DefaultDict[str, Set[str]] = collections.defaultdict(set)
  47. _dump_graph_folder: str = ""
  48. def enable_graph_optimization_dump(folder: str = ""):
  49. global _dump_graph_folder
  50. if not folder:
  51. folder = tempfile.mkdtemp()
  52. _dump_graph_folder = folder
  53. # TODO(@fegin): Support multiple runs of graph optimization
  54. # TODO(@fegin): With this design, circular imports will happen when a pass
  55. # developer accidentally create a pass dependency cycle. As a result, we need to
  56. # break this file into a finer granularity to avoid incorrect circular import.
  57. def graph_optimization_pass(
  58. prerequisites: Iterable[Callable],
  59. apply_after: Iterable[Callable],
  60. ) -> Callable:
  61. """Define the contract of a graph optimization pass.
  62. All the passes should be wrapped with this decorator.
  63. `prerequisites` is used to annotate the prerequisite passes of the this pass.
  64. `apply_after` means that this wrapped pass must be applied after the passes
  65. in `apply_after`. The difference between `prerequisites` and `apply_after`
  66. is that all the passes in `prerequisites` must be applied to the graph and
  67. must be applifed before the wrapped pass while the passes `apply_after` are
  68. optional. But if a pass in `apply_after` is applied to the graph, it has to
  69. be done before the wrapped pass.
  70. Optimizer pass developers are required to add these fields accordingly and
  71. users need to follow the restrictions to avoid the assert.
  72. Current design has one limitation: users can only apply the optimizations
  73. once. In some cases, we may need to run multiple the same optimization
  74. multiple time, e.g., optimization passes -> profiling the result -> apply
  75. optimization passes with the profiling result again. This limitation will be
  76. addressed limitation in the future.
  77. Args:
  78. prerequisites (Iterable[Callable]): the list of string to the names of
  79. passes which are the prerequisites of this pass.
  80. apply_after (Iterable[Callable]): the list of string to the names of
  81. passes that can not be applied after the wrapped pass.
  82. """
  83. def inner(func: Callable) -> Callable:
  84. def make_key(func: Callable) -> str:
  85. return f"{func.__module__}.{func.__name__}"
  86. func_key = make_key(func)
  87. _prerequisite_sets[func_key] = {make_key(f) for f in prerequisites}
  88. for apply_after_pass in apply_after:
  89. _apply_before_sets[make_key(apply_after_pass)].add(func_key)
  90. @wraps(func)
  91. def pass_wrapper(
  92. gm: Union[fx.GraphModule, IterGraphModule], *args: Any, **kwargs: Any
  93. ) -> None:
  94. begin = time.time()
  95. assert isinstance(gm, (fx.GraphModule, IterGraphModule)), (
  96. "The first argument of the pass must be either "
  97. "fx.GraphModule or IterGraphModule."
  98. )
  99. assert func_key not in _optimized_func, f"Cannot apply {func_key} twice."
  100. invalid_passes = _apply_before_sets[func_key].intersection(_optimized_func)
  101. assert (
  102. not invalid_passes
  103. ), f"{invalid_passes} must be applied after {func_key}."
  104. assert _prerequisite_sets[func_key].issubset(_optimized_func), (
  105. f"{_prerequisite_sets[func_key] - _optimized_func} are the "
  106. f"prerequisites of {func_key} but are not applified. "
  107. f"Applied passes are {_optimized_func}."
  108. )
  109. func(gm, *args, **kwargs)
  110. gm.graph.lint()
  111. gm.graph.eliminate_dead_code()
  112. gm.recompile()
  113. _optimized_func.add(func_key)
  114. prefix = f"after_{func.__name__}"
  115. if _dump_graph_folder:
  116. if isinstance(gm, IterGraphModule):
  117. dump_graphs_to_files(
  118. {
  119. f"{prefix}_setup_gm": gm.setup_gm,
  120. f"{prefix}_main_gm": gm.main_gm,
  121. f"{prefix}_cleanup_gm": gm.cleanup_gm,
  122. },
  123. _dump_graph_folder,
  124. )
  125. else:
  126. dump_graphs_to_files({prefix: gm}, _dump_graph_folder)
  127. logger.info("Spent %f seconds applying %s", time.time() - begin, func_key)
  128. return pass_wrapper
  129. return inner
  130. @dataclass(unsafe_hash=True)
  131. class CommBlock:
  132. shape: Optional[torch.Size]
  133. node_list: List[fx.Node]
  134. inputs: List[fx.Node]
  135. wait_nodes: List[fx.Node]
  136. comm_node: fx.Node
  137. outputs: Set[fx.Node]
  138. def get_comm_block(comm_node: fx.Node) -> CommBlock:
  139. """Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).
  140. Args:
  141. comm_node(fx.Node): The target communication/collective node.
  142. Returns:
  143. The CommBlock that encapsulates the related nodes (e.g., wait_node) of
  144. the given comm_node.
  145. """
  146. # We choose 5 to prevent some accidents that cause infinite loop. But
  147. # with functional collective, the distance is 1.
  148. MAX_WAIT_DISTANCE = 5
  149. node_list = []
  150. wait_nodes = []
  151. inputs = pytree.arg_tree_leaves(*comm_node.args, **comm_node.kwargs)
  152. input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
  153. distance = 0
  154. wait_prefixes = ("wait_comm", "wait_tensor")
  155. non_end_users_nodes = ("split", "reshape", "getitem", "detach", "alias")
  156. nodes = collections.deque([comm_node, None])
  157. while nodes and distance < 5:
  158. node = nodes.popleft()
  159. if node is None:
  160. distance += 1
  161. if nodes:
  162. nodes.append(None)
  163. continue
  164. node_list.append(node)
  165. if node.name.startswith(wait_prefixes):
  166. wait_nodes.append(node)
  167. else:
  168. for child in node.users:
  169. if isinstance(child, fx.Node):
  170. nodes.append(child)
  171. if not wait_nodes:
  172. raise RuntimeError(
  173. "The wait nodes are too far away from the comm node {comm_node}."
  174. )
  175. # Identify all the outputs of this collective block.
  176. outputs: Set[fx.Node] = set()
  177. nodes = collections.deque(wait_nodes)
  178. while nodes:
  179. node = nodes.popleft()
  180. assert node is not None
  181. for user in node.users:
  182. if isinstance(user, fx.Node) and user.name.startswith(non_end_users_nodes):
  183. nodes.append(user)
  184. node_list.append(user)
  185. else:
  186. outputs.add(node)
  187. break
  188. # TODO: populate all the tensor metadata and remove the default.
  189. tensor_meta = input_nodes[0].meta.get("tensor_meta", None)
  190. return CommBlock(
  191. # TODO: support symbolic shapes
  192. shape=torch.Size(int(s) for s in tensor_meta.shape) if tensor_meta else None,
  193. node_list=node_list,
  194. wait_nodes=wait_nodes,
  195. comm_node=comm_node,
  196. inputs=input_nodes,
  197. outputs=outputs,
  198. )
  199. def get_all_comm_blocks(
  200. gm: IterGraphModule, comm_ops: Union[Tuple[str, ...], str]
  201. ) -> List[CommBlock]:
  202. return [
  203. get_comm_block(node)
  204. for node in gm.graph.nodes
  205. if node.name.startswith(comm_ops)
  206. ]
  207. def _create_meta_val(
  208. fake_tensor_mode: FakeTensorMode,
  209. val: FakeTensor,
  210. ) -> FakeTensor:
  211. # TODO: fix the memory_format
  212. return FakeTensor(
  213. fake_tensor_mode,
  214. torch.empty(
  215. val.shape,
  216. dtype=val.dtype,
  217. device="meta",
  218. requires_grad=val.requires_grad,
  219. ),
  220. val.device,
  221. )
  222. def _create_meta_tensor_meta(
  223. fake_tensor_mode: FakeTensorMode,
  224. val: FakeTensor,
  225. ) -> TensorMetadata:
  226. return TensorMetadata(
  227. shape=val.shape,
  228. dtype=val.dtype,
  229. requires_grad=val.requires_grad,
  230. stride=val.stride, # type: ignore[arg-type]
  231. # TODO: fix these value
  232. memory_format=None,
  233. is_quantized=False,
  234. qparams={},
  235. )
  236. def _call_function(
  237. gm: IterGraphModule,
  238. fake_tensor_mode: FakeTensorMode,
  239. meta_val: Optional[FakeTensor],
  240. function: Any,
  241. *args: Any,
  242. **kwargs: Any,
  243. ) -> fx.Node:
  244. node = gm.graph.call_function(function, args, kwargs)
  245. if meta_val is None:
  246. flat_args, spec = tree_flatten((args, kwargs))
  247. new_flat_args = []
  248. memory_format = None
  249. for arg in flat_args:
  250. if not isinstance(arg, fx.Node):
  251. new_flat_args.append(arg)
  252. continue
  253. val = arg.meta["val"]
  254. new_flat_args.append(_create_meta_val(fake_tensor_mode, val))
  255. fake_args, fake_kwargs = tree_unflatten(new_flat_args, spec)
  256. new_meta_val = function(*fake_args, **fake_kwargs)
  257. else:
  258. new_meta_val = meta_val
  259. node.meta["val"] = new_meta_val
  260. node.meta["tensor_meta"] = _create_meta_tensor_meta(fake_tensor_mode, new_meta_val)
  261. return node
def _scatter_wait_result(
    gm: IterGraphModule,
    fused_comm_block: CommBlock,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
) -> None:
    """Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem.

    The fused wait node's output is split with ``aten.split`` into chunks of
    ``shape.numel()`` elements, one per original CommBlock. Each chunk is
    reshaped back to its block's shape and replaces every use of that block's
    original wait node. Users of the original wait nodes whose pre-fusion
    index is below ``last_wait_node_idx`` (and, transitively, their users) are
    moved after ``last_split_reshape_node``. Finally,
    ``gm.graph.eliminate_dead_code()`` is run.

    Args:
        gm: module whose graph is rewritten in place.
        fused_comm_block: the CommBlock of the fused comm/wait nodes.
        comm_blocks: the original CommBlocks, in the order their data was
            concatenated into the fused communication.
        node_indices: node -> position map computed before fusion. Nodes
            created afterwards are not in it.
    """
    # Largest pre-fusion index among the nodes that precede the fused comm
    # node (nodes absent from node_indices, i.e. new ones, are ignored).
    last_wait_node_idx = 0
    for node in gm.graph.nodes:
        if node == fused_comm_block.comm_node:
            break
        last_wait_node_idx = max(
            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
        )

    fused_comm_node = fused_comm_block.comm_node
    fused_wait_node = fused_comm_block.wait_nodes[0]

    with gm.graph.inserting_after(fused_wait_node):
        split_node = gm.graph.call_function(
            aten.split,
            (
                fused_wait_node,
                # TODO(@fegin): support symbolic shapes
                [int(cast(torch.Size, cb.shape).numel()) for cb in comm_blocks],
            ),
        )

    # Scatter the split result.
    need_sort_nodes = []
    last_split_reshape_node = split_node
    with gm.graph.inserting_after(split_node):
        for idx, comm_block in enumerate(comm_blocks):
            # Some users of the original allreduce and wait are scheduled
            # before the fused allreduce. We must move these users to a
            # correct topological sort order -- right after the last fused
            # allreduce result, the `last_split_reshape_node` variable.
            orig_wait = comm_block.wait_nodes[0]
            nodes = collections.deque(list(orig_wait.users))
            while nodes:
                user_node = nodes.popleft()
                if not isinstance(user_node, fx.Node):
                    continue
                # NOTE(review): no visited set -- a node reachable through
                # several paths is appended (and expanded) more than once.
                if node_indices[user_node] < last_wait_node_idx:
                    need_sort_nodes.append(user_node)
                    nodes.extend(list(user_node.users))

            # getitem(split, idx) -> reshape to the original block's shape.
            split_idx_node = gm.graph.call_function(operator.getitem, (split_node, idx))
            with gm.graph.inserting_after(split_idx_node):
                wait_output_node = gm.graph.call_function(
                    aten.reshape, (split_idx_node, comm_block.shape)
                )
            gm.graph.node_replace_all_uses_with(orig_wait, wait_output_node)

        # NOTE(review): every getitem above is inserted right after split_node,
        # so the last-created reshape may sit *first* in graph order; confirm
        # that moving need_sort_nodes after it keeps a valid topological order.
        if last_split_reshape_node == split_node:
            last_split_reshape_node = wait_output_node  # type: ignore[possibly-undefined]

    need_sort_nodes = sorted(need_sort_nodes, key=lambda node: node_indices[node])
    gm.graph.move_after(need_sort_nodes, last_split_reshape_node)

    gm.graph.eliminate_dead_code()
def _fuse_with_cat(
    gm: IterGraphModule,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
) -> CommBlock:
    """Fuse the CommBlocks using concat given a list of CommBlock (only allreduce).

    Each block's first input (unwrapped from a ``clone`` when present) is
    flattened, and the flattened tensors are concatenated. A new comm node and
    a new wait node are created as copies of the *last* block's comm/wait
    nodes, with their first argument replaced by the concatenation (resp. the
    new comm node). These new nodes are moved right after the latest input.
    The result is then scattered back to the original users through
    ``_scatter_wait_result``.

    Args:
        gm: module whose graph is rewritten in place.
        comm_blocks: the CommBlocks to fuse, in concatenation order.
        node_indices: node -> position map computed before fusion.

    Returns:
        The CommBlock describing the fused comm/wait nodes.
    """
    # Find the last input node.
    last_input_node = comm_blocks[0].inputs[0]
    last_input_index = -1
    all_input_nodes = []
    for comm_block in comm_blocks:
        input_node = comm_block.inputs[0]
        # If the input node is a clone, this is CommTensor based implementation.
        if input_node.name.startswith("clone"):
            input_node = cast(fx.Node, input_node.args[0])
        all_input_nodes.append(input_node)
        index = node_indices[input_node]
        if index >= last_input_index:
            # Indices are unique per node: equal indices would mean two
            # blocks share the same input node.
            assert index != last_input_index
            last_input_node = input_node
            last_input_index = index

    # Flatten all the inputs right after the last input is ready.
    with gm.graph.inserting_after(last_input_node):
        cat_inputs = []
        for input_node in all_input_nodes:
            cat_inputs.append(
                _call_function(
                    gm, fake_tensor_mode, None, aten.flatten.using_ints, input_node
                )
            )

    with gm.graph.inserting_after(cat_inputs[0]):
        cat_node = _call_function(gm, fake_tensor_mode, None, aten.cat, cat_inputs)

    # Create a new Comm node.
    last_comm = comm_blocks[-1]
    last_comm_node = last_comm.comm_node
    last_wait_node = last_comm.wait_nodes[0]
    with gm.graph.inserting_after(cat_node):
        # Reuse the last comm node's args/kwargs, swapping the first leaf for
        # the concatenated tensor.
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = cat_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = _call_function(
            gm,
            fake_tensor_mode,
            cat_node.meta["val"],
            last_comm_node.target,
            *args,
            **kwargs,
        )

    # Create a new Wait node.
    with gm.graph.inserting_after(fused_comm_node):
        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
        flatten_args[0] = fused_comm_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_wait_node = _call_function(
            gm,
            fake_tensor_mode,
            cat_node.meta["val"],
            last_wait_node.target,
            *args,
            **kwargs,
        )

    # Move the fused_comm_node and its args to right after the source node
    nodes_to_move = cat_inputs + [cat_node, fused_comm_node, fused_wait_node]
    gm.graph.move_after(nodes_to_move, last_input_node)

    tensor_meta = cat_node.meta.get("tensor_meta")
    fused_comm_block = CommBlock(
        shape=tensor_meta.shape,  # type: ignore[union-attr]
        node_list=[fused_comm_node, fused_wait_node],
        wait_nodes=[fused_wait_node],
        comm_node=fused_comm_node,
        inputs=[cat_node],
        outputs={fused_wait_node},
    )

    _scatter_wait_result(gm, fused_comm_block, comm_blocks, node_indices)

    return fused_comm_block
  391. def _expedite_comm_ops(gm: IterGraphModule, comm_blocks: List[CommBlock]) -> None:
  392. node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
  393. for comm_block in comm_blocks:
  394. last_input = comm_block.comm_node
  395. last_input_idx = -1
  396. for input in comm_block.inputs:
  397. input_idx = node_indices[input]
  398. if input_idx > last_input_idx:
  399. last_input = input
  400. last_input_idx = input_idx
  401. gm.graph.node_append(last_input, comm_block.comm_node)
  402. @graph_optimization_pass(
  403. prerequisites=[],
  404. apply_after=[],
  405. )
  406. def comm_fusion_with_concat(
  407. gm: IterGraphModule,
  408. bucket_size_mb: int,
  409. ) -> None:
  410. """Run fuse communication with concat.
  411. This implementation uses concat to concat the bucketed gradients.
  412. """
  413. comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
  414. # First ensure the allreduce are scheduled immediately right after the gradients.
  415. _expedite_comm_ops(gm, comm_blocks)
  416. # Get the comm_blocks based on the new order.
  417. comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
  418. node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
  419. bucket_size = 1 * 1024**2
  420. bucket_cap_size = bucket_size_mb * 1024**2
  421. begin = end = curr_size = 0
  422. while end < len(comm_blocks):
  423. # TODO: determine the dtype
  424. curr_size += cast(torch.Size, comm_blocks[end].shape).numel() * 4
  425. end += 1
  426. if curr_size < bucket_size:
  427. continue
  428. _fuse_with_cat(gm, comm_blocks[begin:end], node_indices)
  429. bucket_size = bucket_cap_size
  430. begin = end
  431. curr_size = 0
  432. else:
  433. if begin < len(comm_blocks):
  434. _fuse_with_cat(gm, comm_blocks[begin:end], node_indices)
  435. @graph_optimization_pass(
  436. prerequisites=[comm_fusion_with_concat],
  437. apply_after=[],
  438. )
  439. def schedule_comm_wait(gm: IterGraphModule) -> None:
  440. """Delay the execution of wait tensors of allreduce until its first user."""
  441. comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
  442. # Find all the end users.
  443. allreduce_users: Set[fx.Node] = set()
  444. for allreduce in comm_blocks:
  445. for output in allreduce.outputs:
  446. allreduce_users.update(output.users)
  447. node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
  448. for allreduce in comm_blocks:
  449. # Find the earliest users.
  450. assert (
  451. len(allreduce.outputs) >= 1
  452. ), f"Found a allreduce that has zero outputs/users -- {allreduce}."
  453. # Initialize the target_node to be the first user of the first output.
  454. target_node = next(iter(next(iter(allreduce.outputs)).users))
  455. target_node_index = 2**31
  456. for user in (user for output in allreduce.outputs for user in output.users):
  457. index = node_indices[user]
  458. if index < target_node_index:
  459. target_node = user
  460. target_node_index = index
  461. # Move wait nodes and all the subsequent output nodes before the
  462. # earliest user.
  463. wait_idx = -1
  464. for wait_idx, node in enumerate(allreduce.node_list):
  465. if node == allreduce.wait_nodes[0]:
  466. break
  467. assert wait_idx >= 0
  468. gm.graph.move_before(allreduce.node_list[wait_idx:], target_node)
  469. @graph_optimization_pass(
  470. prerequisites=[],
  471. apply_after=[],
  472. )
  473. def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
  474. """Erase the orphant copy_ that generated when tracing optimizer.
  475. Two reasons why we could not simply use the DCE of fx.Graph.
  476. 1. fx.Graph treats copy_ as a side-effect node and does not erase it.
  477. 2. Users may want to preserve some orphan `copy_` that is not from the
  478. optimizer.
  479. If the second reason does not hold, this pass can be rewritten as using
  480. DCE from fx.Graph (with the overwrite to the side-effect node list).
  481. """
  482. MAX_COPY_DISTANCE = 5
  483. remove_candidates: Set[fx.Node] = set()
  484. for node in reversed(gm.graph.nodes):
  485. if node.users:
  486. continue
  487. if node.op != OP.CALL_FUNCTION or node.target != aten.copy_.default:
  488. continue
  489. copy_ancestors: Set[fx.Node] = set()
  490. nodes = collections.deque([node, None])
  491. distance = 0
  492. should_remove = False
  493. while nodes and distance < MAX_COPY_DISTANCE:
  494. visiting = nodes.popleft()
  495. if visiting is None:
  496. distance += 1
  497. if nodes:
  498. nodes.append(None)
  499. continue
  500. copy_ancestors.add(visiting)
  501. if visiting.op == OP.CALL_FUNCTION and str(visiting.target).startswith(
  502. ("aten._foreach_", "aten._fused_")
  503. ):
  504. should_remove = True
  505. parents = pytree.arg_tree_leaves(*visiting.args, **visiting.kwargs)
  506. for parent in parents:
  507. if isinstance(parent, fx.Node):
  508. nodes.append(parent)
  509. if should_remove:
  510. # We add all ancestors to the list and it is okay as not all of
  511. # them will be erased -- only those nodes with zero users will be
  512. # erased.
  513. remove_candidates.update(copy_ancestors)
  514. for node in reversed(gm.graph.nodes):
  515. if node.users:
  516. continue
  517. if node not in remove_candidates:
  518. continue
  519. gm.graph.erase_node(node)
  520. # The args list of fused_adam function. We don't care about kwargs.
  521. AdamArgs = collections.namedtuple(
  522. "AdamArgs",
  523. ["params", "grads", "exp_avgs", "exp_avg_sqs", "max_exp_avg_sqs", "state_steps"],
  524. )
  525. # TODO(fegin): Have a template class for all Block class.
  526. @dataclass(unsafe_hash=True)
  527. class FusedAdamBlock:
  528. optim_node: fx.Node
  529. generate_output: bool
  530. # The output list of the copy nodes. The order follows the argument order.
  531. param_outputs: List[fx.Node] = field(default_factory=list)
  532. grad_outputs: List[fx.Node] = field(default_factory=list)
  533. exp_avgs_outputs: List[fx.Node] = field(default_factory=list)
  534. exp_avg_sqs_outputs: List[fx.Node] = field(default_factory=list)
  535. # TODO(fegin): populate/generate the max_exp_avg_sqs if exists
  536. max_exp_avg_sqs: List[fx.Node] = field(default_factory=list)
  537. def generate_outputs(self):
  538. # Iterate all the args and generate the corresponding output lists.
  539. # Assuming the corrsesponding output nodes are not created yet.
  540. def _generate_outputs(arg_idx, output_list):
  541. graph = self.optim_node.graph
  542. with graph.inserting_after(self.optim_node):
  543. optim_getitem = graph.call_function(
  544. operator.getitem, (self.optim_node, arg_idx)
  545. )
  546. for i, arg in enumerate(self.optim_node.args[arg_idx]):
  547. with graph.inserting_after(optim_getitem):
  548. updated_arg = graph.call_function(
  549. operator.getitem, (optim_getitem, i)
  550. )
  551. with graph.inserting_after(updated_arg):
  552. output_copy = graph.call_function(aten.copy_, (arg, updated_arg))
  553. output_list.append(output_copy)
  554. _generate_outputs(0, self.param_outputs)
  555. # Do not generate gradient out list as it is not used.
  556. _generate_outputs(2, self.exp_avgs_outputs)
  557. _generate_outputs(3, self.exp_avg_sqs_outputs)
  558. def populate_outputs(self):
  559. # Populate the existing output lists from the graph.
  560. def _populate_outputs(args_idx, output_list):
  561. optim_getitem = self.optim_node
  562. for user in self.optim_node.users:
  563. assert (
  564. user.target == operator.getitem
  565. ), f"The user of {self.optim_node} is not getitem."
  566. if user.args[1] == args_idx:
  567. optim_getitem = user
  568. break
  569. assert (
  570. optim_getitem != self.optim_node
  571. ), f"Cannot find the getitem node for {self.optim_node}"
  572. output_list.extend(
  573. [self.optim_node] * len(cast(List[fx.Node], self.optim_node.args[0]))
  574. )
  575. for updated_arg in optim_getitem.users:
  576. assert (
  577. updated_arg.target == operator.getitem
  578. ), f"Unexpected node target {updated_arg.target}."
  579. idx = updated_arg.args[1]
  580. output_copy = next(iter(updated_arg.users))
  581. assert str(output_copy.target).startswith(
  582. "aten.copy_"
  583. ), f"Unexpected node target {output_copy.target}."
  584. output_list[idx] = output_copy
  585. for i, output in enumerate(output_list):
  586. assert output != self.optim_node, f"{i}th output is not replaced."
  587. assert output_list, f"The output for {self.optim_node} is empty."
  588. _populate_outputs(0, self.param_outputs)
  589. _populate_outputs(2, self.exp_avgs_outputs)
  590. _populate_outputs(3, self.exp_avg_sqs_outputs)
  591. def __post_init__(self):
  592. if self.param_outputs:
  593. return
  594. if self.generate_output:
  595. self.generate_outputs()
  596. else:
  597. self.populate_outputs()
  598. @dataclass(unsafe_hash=True)
  599. class ForeachAddBlock:
  600. add_node: fx.Node
  601. generate_output: bool
  602. # The output list of the copy nodes. The order follows the argument order.
  603. outputs: List[fx.Node] = field(default_factory=list)
  604. def generate_outputs(self):
  605. # Iterate all the args and generate the corresponding output lists
  606. # Assuming the corrsesponding output nodes are not created yet.
  607. graph = self.add_node.graph
  608. for i, arg in enumerate(cast(Tuple[Any, ...], self.add_node.args[0])):
  609. with graph.inserting_after(self.add_node):
  610. updated_arg = graph.call_function(operator.getitem, (self.add_node, i))
  611. with graph.inserting_after(updated_arg):
  612. output_copy = graph.call_function(aten.copy_, (arg, updated_arg))
  613. self.outputs.append(output_copy)
  614. assert self.outputs, f"The output for {self.add_node} is empty."
  615. def populate_outputs(self):
  616. # Populate the existing output lists from the graph.
  617. self.outputs = [
  618. self.add_node for _ in cast(Tuple[Any, ...], self.add_node.args[0])
  619. ]
  620. for updated_arg in self.add_node.users:
  621. assert (
  622. updated_arg.target == operator.getitem
  623. ), f"Unexpected node target {updated_arg.target}"
  624. idx = cast(int, updated_arg.args[1])
  625. output_copy = next(iter(updated_arg.users))
  626. assert str(output_copy.target).startswith(
  627. "aten.copy_"
  628. ), f"The execpted output node is different, {str(output_copy.target)}"
  629. self.outputs[idx] = output_copy
  630. for i, output in enumerate(self.outputs):
  631. assert output != self.add_node, f"{i}th output is not replaced."
  632. def __post_init__(self):
  633. if self.outputs:
  634. return
  635. if self.generate_output:
  636. self.generate_outputs()
  637. else:
  638. self.populate_outputs()
  639. @dataclass(unsafe_hash=True)
  640. class FusedOptimizerBlock:
  641. step: ForeachAddBlock
  642. optim: FusedAdamBlock
  643. def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
  644. """Given a fused optimizer node and return the FusedOptimizerBlock."""
  645. MAX_STEP_DISTANCE = 5
  646. # Find the step (foreach_add)
  647. nodes = collections.deque([optim_node, None])
  648. step_node = optim_node
  649. distance = 0
  650. while nodes and distance < MAX_STEP_DISTANCE:
  651. node = nodes.popleft()
  652. if node is None:
  653. distance += 1
  654. if nodes:
  655. nodes.append(None)
  656. continue
  657. elif node.op == OP.CALL_FUNCTION and str(node.target).startswith(
  658. "aten._foreach_add"
  659. ):
  660. step_node = node
  661. break
  662. else:
  663. nodes.extend(
  664. a
  665. for a in pytree.arg_tree_leaves(*node.args, **node.kwargs)
  666. if isinstance(a, fx.Node)
  667. )
  668. if step_node == optim_node:
  669. raise RuntimeError(
  670. "Cannot find step node (foreach_add) for the optimizer node "
  671. f"{optim_node} with {MAX_STEP_DISTANCE} BFS distance. "
  672. "The API design does not match the tracing graph."
  673. )
  674. step = ForeachAddBlock(step_node, generate_output=False)
  675. optim = FusedAdamBlock(optim_node, generate_output=False)
  676. return FusedOptimizerBlock(step, optim)
  677. def get_all_fused_optimizer_blocks(
  678. gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
  679. ) -> List[FusedOptimizerBlock]:
  680. """Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
  681. return [
  682. get_fused_optimizer_block(node)
  683. for node in gm.graph.nodes
  684. if node.name.startswith(optim_ops)
  685. ]
def _split_fused_adam(
    gm: IterGraphModule,
    orig_optim_block: FusedOptimizerBlock,
    split_gradients: Set[fx.Node],
) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
    """Split the `orig_optim_block` into two FusedOptimizerBlock.

    The first one will be the optimizer that optimizes `split_gradients`. The
    second one is used to optimize the remaining gradients.

    For each of the two groups a new ``aten._foreach_add.Scalar`` step node and
    a new ``aten._fused_adam.default`` node are created. The graph output node
    is rewired to the new blocks' outputs. The original copy_ nodes are then
    erased and dead-code elimination is run, so the original step and
    optimizer are removed from ``gm``'s graph (the graph is modified in place).

    Args:
        gm: The graph module whose graph is rewritten in place.
        orig_optim_block: The step + fused Adam block to split.
        split_gradients: Gradient nodes routed to the first returned block.

    Returns:
        A ``(split_block, remaining_block)`` pair: the block optimizing
        ``split_gradients`` first, the block optimizing all other gradients
        second.

    Raises:
        ValueError: If either resulting optimizer would receive no gradients
            (or no step tensors).
        AssertionError: If a ``state_steps`` argument is not an ``aten.copy_``
            node whose second argument is a ``getitem`` node, or if the new
            step outputs do not line up with the split ``state_steps``.
    """
    orig_optim_args = AdamArgs(*orig_optim_block.optim.optim_node.args)
    # One empty AdamArgs per group: index 0 collects the arguments belonging to
    # `split_gradients`, index 1 collects the arguments of all other gradients.
    optim_args = (AdamArgs([], [], [], [], [], []), AdamArgs([], [], [], [], [], []))
    # The only hint we can use to split the optimizer is the order/indices.
    # orig_optim_indices[g]: positions, in the original optimizer argument
    # lists, of the gradients assigned to group g.
    # orig_step_indices[g]: positions, in the original step node's input list,
    # of the step tensors used by group g.
    orig_optim_indices: Tuple[List[int], List[int]] = ([], [])
    orig_step_indices: Tuple[List[int], List[int]] = ([], [])
    for idx, gradient in enumerate(orig_optim_args.grads):
        group_idx = 0 if gradient in split_gradients else 1
        orig_optim_indices[group_idx].append(idx)
        # Get the argument for idx-th gradient from orig_optim_args
        for orig_arg, optim_arg in zip(orig_optim_args, optim_args[group_idx]):
            # Only add the argument to the list if the original argument list
            # is not empty. If the original argument list is empty, the new
            # one must be an empty list as well.
            if orig_arg:
                optim_arg.append(orig_arg[idx])
        # If argument order of step is the same as optimizer, nothing has to be
        # done. However, it is risky to rely on this assumption so we populate
        # the orig_step_indices.
        # The state step just appended is expected to look like
        # `aten.copy_(..., getitem(..., step_idx))`; `step_idx` is read from the
        # getitem node. NOTE(review): assumes the original state_steps list is
        # non-empty, otherwise `[-1]` raises IndexError — confirm for all callers.
        orig_step_output = optim_args[group_idx].state_steps[-1]
        assert str(orig_step_output.target).startswith(
            "aten.copy_"
        ), f"The copy output is {orig_step_output.target}, expect aten.copy_"
        orig_step_getitem = orig_step_output.args[1]
        assert "getitem" in str(
            orig_step_getitem.target
        ), f"The copy getitem is {orig_step_getitem.target}, expect operator.getitem"
        orig_step_idx = orig_step_getitem.args[1]
        orig_step_indices[group_idx].append(orig_step_idx)
    # Both groups must end up with at least one step index and one gradient.
    if not all(l for l in (orig_step_indices + orig_optim_indices)):
        raise ValueError("At least one split optimizer does not have input.")
    output = get_output(gm.graph)
    results: List[FusedOptimizerBlock] = []
    # Flatten the output node's (args, kwargs) so individual entries can be
    # swapped from old nodes to new ones and later rebuilt with `spec`.
    flatten_output_args, spec = tree_flatten((output.args, output.kwargs))
    # Map each node to every flattened position at which it appears in the
    # graph output.
    flatten_output_args_indices: DefaultDict[
        fx.Node, Set[int]
    ] = collections.defaultdict(set)
    for idx, output_arg in enumerate(flatten_output_args):
        if isinstance(output_arg, fx.Node):
            flatten_output_args_indices[output_arg].add(idx)

    def replace_flatten_output_args(orig_node: fx.Node, new_node: fx.Node) -> None:
        # Point every flattened output slot that held `orig_node` at `new_node`.
        for idx in flatten_output_args_indices[orig_node]:
            flatten_output_args[idx] = new_node

    # Create the new step and optim nodes and blocks.
    for group_idx in range(2):
        step_args: List[fx.Node] = []
        orig_step_outputs: List[fx.Node] = []
        # We have to create the new step node and block first because it is used
        # for the new optim node as the input.
        with gm.graph.inserting_after(orig_optim_block.optim.optim_node):
            for idx in orig_step_indices[group_idx]:
                step_args.append(
                    cast(Tuple[fx.Node, ...], orig_optim_block.step.add_node.args[0])[
                        idx
                    ]
                )
                orig_step_outputs.append(orig_optim_block.step.outputs[idx])
            step = gm.graph.call_function(
                aten._foreach_add.Scalar,
                (step_args, 1),
            )
        step_block = ForeachAddBlock(step, generate_output=True)
        for i, step_output in enumerate(step_block.outputs):
            # Replace the original step output in the graph output node with
            # the new one.
            orig_step_output = orig_step_outputs[i]
            replace_flatten_output_args(orig_step_output, step_output)
            # Also need to replace the step output used for the new optimizer.
            assert optim_args[group_idx].state_steps[i] == orig_step_output, (
                f"The expected step output node mismatched, {orig_step_output} "
                f"{optim_args[group_idx].state_steps[i]}"
            )
            optim_args[group_idx].state_steps[i] = step_output
        # Insert the optimizer node after the first step output because its
        # topo sort order is the last.
        with gm.graph.inserting_after(step_block.outputs[0]):
            optim = gm.graph.call_function(
                aten._fused_adam.default,
                optim_args[group_idx],
                orig_optim_block.optim.optim_node.kwargs,
            )
        optim_block = FusedAdamBlock(optim, generate_output=True)
        # Redirect the graph outputs that referred to the original optimizer's
        # param / exp_avg / exp_avg_sq results to the new optimizer's results.
        for curr_idx, orig_idx in enumerate(orig_optim_indices[group_idx]):
            list_names = ("param_outputs", "exp_avgs_outputs", "exp_avg_sqs_outputs")
            for name in list_names:
                orig_list = getattr(orig_optim_block.optim, name)
                curr_list = getattr(optim_block, name)
                replace_flatten_output_args(orig_list[orig_idx], curr_list[curr_idx])
        results.append(FusedOptimizerBlock(step_block, optim_block))
    # Optimizer is used as the output of the train_step. Therefore, we have to
    # update the output node of the graph.
    output_args, output_kwargs = tree_unflatten(flatten_output_args, spec)
    gm.graph.node_set_args(output, output_args)
    gm.graph.node_set_kwargs(output, output_kwargs)
    # Remove the original copy_ nodes explicitly; they are not removed by
    # dead-code elimination.
    for copy_output in itertools.chain(
        orig_optim_block.optim.param_outputs,
        orig_optim_block.optim.exp_avgs_outputs,
        orig_optim_block.optim.exp_avg_sqs_outputs,
    ):
        gm.graph.erase_node(copy_output)
    # Call DCE once to get rid of the old optimizer. By doing so, we will be
    # able to erase the copy_ nodes of step later.
    gm.graph.eliminate_dead_code()
    for copy_output in orig_optim_block.step.outputs:
        gm.graph.erase_node(copy_output)
    # This is not required but calling this for consistency.
    gm.graph.eliminate_dead_code()
    return results[0], results[1]
  804. def split_fused_optimizer(
  805. gm: IterGraphModule,
  806. optim_block: FusedOptimizerBlock,
  807. split_gradients: Set[fx.Node],
  808. ) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
  809. if not split_gradients:
  810. raise ValueError("The given split_gradients is empty.")
  811. if str(optim_block.optim.optim_node.target).startswith("aten._fused_adam"):
  812. return _split_fused_adam(gm, optim_block, split_gradients)
  813. else:
  814. raise NotImplementedError("Only fused_adam is supported now")
# TODO(fegin): This API only supports fused Adam now. It should be extended to
# support the foreach optimizer implementation as well.
  817. @graph_optimization_pass(
  818. prerequisites=[remove_copy_from_optimizer],
  819. apply_after=[schedule_comm_wait],
  820. )
  821. def iter_move_grads_and_optimizers(
  822. gm: IterGraphModule,
  823. target_comm_node: str,
  824. target_dest_node: str,
  825. ) -> None:
  826. """Extract a comm block and split out a new optimizer and step for it.
  827. This subgraph is then moved to the forward graph.
  828. """
  829. for comm_block in get_all_comm_blocks(gm, "all_reduce"):
  830. if comm_block.comm_node.name == target_comm_node:
  831. break
  832. else:
  833. raise ValueError(f"Cannot find {target_comm_node}")
  834. optim_blocks = get_all_fused_optimizer_blocks(gm, "_fused_adam")
  835. for optim_block in optim_blocks:
  836. optim_args = AdamArgs(*optim_block.optim.optim_node.args)
  837. one_output = next(iter(comm_block.outputs))
  838. if one_output in optim_args.grads:
  839. break
  840. else:
  841. raise ValueError(f"{target_comm_node} is not used by any fused optimizer.")
  842. move_optim, _ = split_fused_optimizer(gm, optim_block, comm_block.outputs)
  843. move_nodes = find_all_descendants(
  844. gm, [comm_block.comm_node, move_optim.step.add_node]
  845. )
  846. stop_node = find_node(gm.graph, lambda n: n.name == target_dest_node)[0]
  847. gm.graph.move_to_next_iter_before(move_nodes, stop_node)
  848. def find_all_descendants(
  849. gm: IterGraphModule,
  850. parent_nodes: List[fx.Node],
  851. ) -> List[fx.Node]:
  852. """Identify the list of nodes to move during FX graph transformation."""
  853. assert len(parent_nodes) > 0, "No parent nodes are given."
  854. output = get_output(gm.graph)
  855. dq_parent_nodes = collections.deque(parent_nodes)
  856. move_node_set = set()
  857. while dq_parent_nodes:
  858. node = dq_parent_nodes.popleft()
  859. move_node_set.add(node)
  860. dq_parent_nodes += [
  861. u for u in node.users if isinstance(u, fx.Node) and u != output
  862. ]
  863. move_nodes = [node for node in gm.graph.nodes if node in move_node_set]
  864. return move_nodes