# pass_base.py
# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    PropagateUnbackedSymInts,
)
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree


__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]


Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
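
# Python-level symbolic ops that ExportInterpreter routes to ``call_sym``
# instead of ``call_operator``.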
_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBaseDeprecatedDoNotUse(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """

    @staticmethod
    def _create_dummy_node_metadata():
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
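
    # ExportTracer owns the new graph being built: the pass's ``_fx`` emits
    # proxies through it, and it mirrors get_attr tensors and submodules into
    # a fresh root module.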
    class ExportTracer(PythonKeyTracer):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
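
    # ExportInterpreter walks the original graph node by node and forwards
    # each node to the matching ``call_*`` hook on the callback pass, which in
    # turn emits nodes into the new graph via ``_fx``.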
    class ExportInterpreter(fx.Interpreter):
        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {"_operator", "math"}:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = PropagateUnbackedSymInts(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: typing.Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
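        # Run the op twice: once on the unwrapped (fake) data through
        # self.interpreter to compute the actual result, and once through the
        # tracer to record a proxy node in the new graph.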
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)
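
        # Fall back to reconstructing inputs from placeholder metadata: prefer
        # the recorded "val" fake tensor, then rebuild a FakeTensor from
        # "tensor_meta", and allow None only for unused placeholders.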
        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        self.tracer.set_metadata(arg_proxy.node, arg)
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
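        # Re-trace both branches as submodules so the pass recurses into the
        # branch graphs before emitting the cond node.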
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
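        # Trace the body once on a single unstacked slice of the mapped args;
        # the emitted map_impl node still receives the full stacked inputs.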
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
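        # Swap in a fresh tracer/interpreter for this submodule, run it, then
        # restore the previous pair so nested calls (e.g. cond/map branches)
        # compose correctly.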
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter

        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)
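
        # Reuse the inputs' shared FakeTensorMode when one exists; otherwise
        # fall back to a fresh mode and skip the python dispatcher.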
        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor modes detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result
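

# A minimal end-to-end sketch (hypothetical names; ``ReplaceAddWithMul`` is the
# example subclass sketched in the class body above, and the input must be a
# graph module whose placeholders carry "val"/"tensor_meta" metadata, e.g. one
# produced by torch.export or make_fx):
#
#     from torch.fx.passes.infra.pass_manager import PassManager
#
#     pass_manager = PassManager(passes=[ReplaceAddWithMul()])
#     new_gm = pass_manager(exported_graph_module).graph_module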