partial_lower.py

# This file is copied from Meta internal repo and is not synced with the
# internal version. Once the internal version is fully mature, we should
# upstream again and retire the internal version. @yifuwang
import logging
import operator
from typing import Callable, List, Optional, Set, Tuple

import torch
from functorch import make_fx
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table

MIN_ATEN_OPS_TO_LOWER = 10

logger: logging.Logger = logging.getLogger(__name__)

def _create_subgraph_module(
    inputs: List[torch.fx.Node], body: List[torch.fx.Node], outputs: List[torch.fx.Node]
) -> torch.fx.GraphModule:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    node_to_subgraph_node = {}
    for idx, inp in enumerate(inputs):
        subgraph_inp = subgraph.placeholder(name=f"arg_{idx}")
        subgraph_inp.meta = inp.meta
        node_to_subgraph_node[inp] = subgraph_inp

    for node in body:
        subgraph_node = subgraph.node_copy(
            node, arg_transform=lambda x: node_to_subgraph_node[x]
        )
        node_to_subgraph_node[node] = subgraph_node

    subgraph.output(result=tuple(node_to_subgraph_node[x] for x in outputs))
    subgraph.eliminate_dead_code()
    subgraph.lint()
    return torch.fx.GraphModule(root={}, graph=subgraph)

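
# Illustrative sketch (not part of the original file): one way
# _create_subgraph_module might be exercised on a symbolically traced toy
# function. `toy` and `_example_create_subgraph` are hypothetical names used
# for demonstration only.
def _example_create_subgraph() -> torch.fx.GraphModule:
    def toy(x: torch.Tensor) -> torch.Tensor:
        return (x + 1) * 2

    traced = torch.fx.symbolic_trace(toy)
    nodes = list(traced.graph.nodes)
    inputs = [n for n in nodes if n.op == "placeholder"]
    body = [n for n in nodes if n.op == "call_function"]
    # Use the last computed node as the sole subgraph output.
    return _create_subgraph_module(inputs, body, [body[-1]])
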
def _is_container_node(node: torch.fx.Node) -> bool:
    if any(user.target == operator.getitem for user in node.users):
        assert all(user.target == operator.getitem for user in node.users), (
            "Malformed graph: a container node is used as input for non-getitem nodes."
            "\nNode: {fmt_node}\nUsers: {fmt_users}".format(
                fmt_node=node.format_node(),
                fmt_users="\n".join(u.format_node() for u in node.users),
            )
        )
        return True
    return False

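
# Illustrative sketch (assumption, not in the original file): a node whose
# users are all operator.getitem calls, e.g. the result of torch.split, is
# what _is_container_node treats as a "container".
def _example_container_node() -> None:
    def toy(x: torch.Tensor) -> torch.Tensor:
        parts = torch.split(x, 2)
        return parts[0] + parts[1]

    traced = torch.fx.symbolic_trace(toy)
    # The split node's users are the two getitem nodes, so it is a container.
    split_node = next(
        n
        for n in traced.graph.nodes
        if any(user.target == operator.getitem for user in n.users)
    )
    assert _is_container_node(split_node)
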
def _lower_subgraph_nodes(
    gm: torch.fx.GraphModule,
    subgraph_name: str,
    subgraph_nodes: List[torch.fx.Node],
    dumper: Callable[[str], str],
) -> None:
    prologue: List[torch.fx.Node] = []
    inputs: List[torch.fx.Node] = []
    body: List[torch.fx.Node] = []
    visible: Set[torch.fx.Node] = set()

    # Inductor requires all graph inputs to be tensors. When adding a container
    # node as a subgraph input, add its descendant getitem nodes to the subgraph
    # prologue and add its leaf getitem nodes to the subgraph inputs.
    def add_input(arg: torch.fx.Node) -> None:
        stack = [arg]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                # We should only prepone nodes within subgraph_nodes
                prologue.extend(user for user in node.users if user in subgraph_nodes)
                stack.extend(node.users)
            else:
                if node not in visible:
                    inputs.append(node)
                    visible.add(node)

    for node in subgraph_nodes:
        if node.op == "get_attr":
            # Prepone get_attr to avoid having to copy
            # the attribute to the subgraph module.
            inputs.append(node)
            visible.add(node)
            continue
        for arg in node.all_input_nodes:
            if arg not in visible:
                add_input(arg)
        if node not in prologue:
            body.append(node)
            visible.add(node)

    outputs: List[torch.fx.Node] = []

    # Inductor requires all graph outputs to be tensors. When adding a container
    # node as a subgraph output, add its descendant getitem nodes to the subgraph
    # body and add its leaf getitem nodes to the subgraph outputs.
    def add_output(output: torch.fx.Node) -> None:
        stack = [output]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                body.extend(node.users)
                stack.extend(node.users)
            elif not all(user in visible for user in node.users):
                if node not in outputs:
                    outputs.append(node)

    for node in body:
        if not all(user in visible for user in node.users):
            add_output(node)

    assert len(inputs) == len(set(inputs))
    assert len(outputs) == len(set(outputs))

    subgraph_module = _create_subgraph_module(inputs, body, outputs)
    readable_tag = dumper(str(subgraph_module.graph))
    setattr(gm, subgraph_name, _InductorModule(subgraph_module))

    insertion_point = subgraph_nodes[-1].next
    for node in prologue:
        insertion_point.prepend(node)

    with gm.graph.inserting_before(insertion_point):
        # Insert subgraph call
        subgraph_call = gm.graph.create_node(
            op="call_module",
            target=subgraph_name,
            args=tuple(inputs),
            kwargs={"tag": readable_tag},
        )
        # Replace parent graph nodes with their corresponding subgraph outputs
        for idx, output in enumerate(outputs):
            new_output = gm.graph.create_node(
                op="call_function",
                target=operator.getitem,
                args=(subgraph_call, idx),
            )
            new_output.meta = output.meta
            output.replace_all_uses_with(new_output)

    # Erase lowered nodes from the parent graph
    for node in reversed(body + outputs):
        if len(node.users) == 0:
            gm.graph.erase_node(node)

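
# Illustrative sketch (assumption, not in the original file): lowering every
# call_function node of a small make_fx trace as one subgraph. The trace
# contains a container node (aten split) whose getitem leaves are routed as
# described in the comments above. `toy` and the fixed "demo" label are
# hypothetical.
def _example_lower_subgraph() -> torch.fx.GraphModule:
    def toy(x: torch.Tensor) -> torch.Tensor:
        parts = torch.split(torch.relu(x), 2)
        return parts[0] + parts[1]

    gm = make_fx(toy)(torch.randn(4))
    lowerable = [n for n in gm.graph.nodes if n.op == "call_function"]
    _lower_subgraph_nodes(gm, "subgraph_0", lowerable, lambda s: "demo")
    gm.recompile()
    return gm
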
class _InductorModule(torch.nn.Module):
    def __init__(self, gm: torch.fx.GraphModule) -> None:
        super().__init__()
        self.gm = gm
        self.compiled: Optional[
            Callable[[List[torch.Tensor]], List[torch.Tensor]]
        ] = None

    def forward(self, *args: torch.Tensor, tag: str) -> List[torch.Tensor]:
        if self.compiled is None:
            inductor_decompositions = select_decomp_table()
            # TODO: figure out why turning on cudagraphs causes exceptions.
            decomp_gm = make_fx(self.gm, decomposition_table=inductor_decompositions)(
                *args
            )
            logger.info("Lowering subgraph (%s) to Inductor...", tag)
            self.compiled = compile_fx_inner(
                decomp_gm,
                list(args),
                cudagraphs=False,
            )
            logger.info("Completed lowering subgraph (%s) to Inductor", tag)
        with torch.profiler.record_function(tag):
            assert self.compiled is not None
            return self.compiled(list(args))

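
# Illustrative sketch (assumption, not in the original file): _InductorModule
# compiles lazily on the first call and reuses the compiled callable after
# that. Running this requires a working Inductor setup; shapes are arbitrary.
def _example_inductor_module() -> List[torch.Tensor]:
    gm = make_fx(lambda x: torch.relu(x) + 1)(torch.randn(4))
    mod = _InductorModule(gm)
    # First call triggers compile_fx_inner; subsequent calls hit self.compiled.
    return mod(torch.randn(4), tag="example_subgraph")
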
def _is_inductor_compatible(node: torch.fx.Node) -> Tuple[bool, str]:
    # `has_tag` is not supported yet
    # if has_tag(node, "non_lowerable"):
    if node.target in (
        torch.ops.aten._fused_adam_.default,
        torch.ops.aten._fused_adam.default,
        torch.ops.aten._foreach_add_.Scalar,
        torch.ops.aten._foreach_add.Scalar,
    ):
        return False, "fused adam is not supported yet"

    # TODO(yifu): apparently having a meta kernel is not a necessary
    # condition for Inductor compatibility. We should refine the check.
    # Sneaking this one in for now to support comm_fusion_with_cat.
    if node.target == torch.ops.aten.flatten.using_ints:
        return True, ""

    if isinstance(node.target, torch._ops.OpOverload):
        if not node.target.has_kernel_for_dispatch_key(torch._C.DispatchKey.Meta):
            return False, f"{node.target} doesn't have a meta kernel registered"
    return True, ""

def _subgraph_predicate(nodes: List[torch.fx.Node]) -> bool:
    num_aten_ops = len([n for n in nodes if str(n.target).startswith("aten.")])
    return num_aten_ops >= MIN_ATEN_OPS_TO_LOWER

def partial_lower(
    gm: torch.fx.GraphModule,
    node_predicate: Callable[[torch.fx.Node], bool] = lambda x: True,
    subgraph_predicate: Callable[[List[torch.fx.Node]], bool] = lambda x: True,
    dumper: Callable[[str], str] = lambda x: "subgraph",
) -> torch.fx.GraphModule:
    """
    Lower Inductor-compatible portions of the graph module to Inductor.

    Args:
        gm: the graph module to partially lower.
        node_predicate: user predicate for determining whether to consider a node for
            lowering.
        subgraph_predicate: user predicate for determining whether to consider a list of
            candidate nodes for lowering.
        dumper: a callback for dumping subgraphs for human digestion. For example, it
            can be a function that writes to disk/blob storage and returns the
            path/handle. The returned path/handle for each subgraph will be made
            available in the subgraph call node in the parent graph, as well as the
            label of the profiler block for the subgraph.
    """
    nodes_per_subgraph: List[List[torch.fx.Node]] = [[]]
    ptr = next(iter(gm.graph.nodes))

    def _node_predicate(node: torch.fx.Node) -> Tuple[bool, str]:
        should_lower, reason = _is_inductor_compatible(node)
        if not should_lower:
            return should_lower, reason
        if not node_predicate(node):
            return False, "user predicate"
        return True, ""

    while ptr.op != "output":
        if ptr.op == "placeholder":
            ptr = ptr.next
            continue
        should_lower, reason = _node_predicate(ptr)
        if should_lower:
            nodes_per_subgraph[-1].append(ptr)
        else:
            if len(nodes_per_subgraph[-1]) > 0:
                logger.warning(
                    "partial_lower: graph break at %s. Reason: %s", str(ptr), reason
                )
            nodes_per_subgraph.append([])
        ptr = ptr.next

    nodes_per_subgraph = [
        nodes
        for nodes in nodes_per_subgraph
        if subgraph_predicate(nodes) and _subgraph_predicate(nodes)
    ]

    for idx, subgraph_nodes in enumerate(nodes_per_subgraph):
        subgraph_name = f"subgraph_{idx}"
        _lower_subgraph_nodes(gm, subgraph_name, subgraph_nodes, dumper)

    gm.graph.lint()
    gm.recompile()
    return gm

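
# Illustrative end-to-end sketch (assumption, not in the original file): trace
# with make_fx so node targets are aten ops, then lower eligible runs. The
# loop emits enough aten ops to clear MIN_ATEN_OPS_TO_LOWER; the dumper here
# just returns a fixed label instead of writing the subgraph anywhere.
def _example_partial_lower() -> torch.fx.GraphModule:
    def toy(x: torch.Tensor) -> torch.Tensor:
        for _ in range(12):
            x = x * 1.5 + 1.0
        return x

    gm = make_fx(toy)(torch.randn(8))
    return partial_lower(gm, dumper=lambda graph_str: "toy_subgraph")
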