fx_utils.py

# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import sympy

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V


# Check whether the pattern (nn.Module, F.function / torch.Tensor method) is
# matched, i.e. a call_module node whose only user is the given
# call_function/call_method node.
# Works for length-2 patterns with one module and one function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure node.args[0] output is only used by current node.
    if len(node.args[0].users) > 1:
        return False
    return True
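

# A minimal usage sketch, assuming a symbolically traced GraphModule and the
# hypothetical pattern (nn.BatchNorm2d, F.relu); neither the pattern nor the
# helper name below is prescribed by this module.
def _example_find_bn_relu(gm: torch.fx.GraphModule):
    import torch.nn.functional as F

    modules = dict(gm.named_modules())
    pattern = (torch.nn.BatchNorm2d, F.relu)
    # Collect every F.relu node whose sole input is a call_module node running
    # a BatchNorm2d instance.
    return [
        node
        for node in gm.graph.nodes
        if matches_module_function_pattern(pattern, node, modules)
    ]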


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash each node
    together with its target and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for those nodes. Then, we
    continue to recursively recompute the fake tensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph):
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        processed = set()
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False
            if new.device != old.device:
                return False
            if get_storage(new) == get_storage(old):
                return True
            # This is the case where it returns a completely fresh storage
            # that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # Nodes for which this returns True have a node.target that can be
            # called under fake mode; that does not work for inductor
            # lowerings. We check whether node.target is an aten operator or
            # operator.getitem, which is used when an op returns multiple
            # tensors.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        to_process = set()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue

            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1
            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))
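

# A minimal sketch of the intended flow, assuming the graph was produced under
# the active V.fake_mode (so fake tensor propagation can be rerun); the graph
# mutation shown is hypothetical and the helper name is illustrative only.
def _example_incremental_update(graph: torch.fx.Graph):
    updater = FakeTensorUpdater(graph)  # snapshot hashes of the current nodes

    # ... mutate `graph` here, e.g. swap one aten op for an equivalent one ...

    # Re-propagates fake tensors only for nodes whose hash changed, then
    # follows their users until the fake tensor metadata stops changing.
    updater.incremental_update()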


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])
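

# A small sketch of how storage identity can be used, e.g. to ask whether two
# nodes' fake tensors alias the same underlying buffer; the helper name is
# illustrative, not part of the original API.
def _example_nodes_share_storage(a: torch.fx.Node, b: torch.fx.Node) -> bool:
    storage_a = get_node_storage(a)
    storage_b = get_node_storage(b)
    # None means the node has no tensor "val" metadata (or its fake tensor has
    # no storage), in which case we cannot claim aliasing.
    return storage_a is not None and storage_a == storage_b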


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes don't have a
    fake tensor (i.e. no "val" metadata); in that case the returned args and
    kwargs still contain raw fx.Node objects.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs
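

# A brief sketch of the usual calling pattern (mirroring incremental_update
# above), assuming `node` is a call_function node with a callable target; the
# helper name is illustrative only.
def _example_recompute_val(node: torch.fx.Node):
    is_valid, args, kwargs = get_fake_args_kwargs(node)
    if not is_valid:
        # At least one input node has no fake tensor metadata yet; skip it.
        return None
    with V.fake_mode:
        # Re-run the op on fake tensors to get fresh output metadata.
        return node.target(*args, **kwargs)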


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume node isn't realized
    return False
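

# A minimal sketch, assuming a GraphModule that is about to be lowered by
# inductor; counting realized call_function nodes like this is purely
# illustrative and not something this module does itself.
def _example_count_realized(gm: torch.fx.GraphModule) -> int:
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and is_node_realized(node)
    )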