# mypy: allow-untyped-defs
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm

__all__ = ['Interpreter', 'Transformer']


@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (torch.nn.Module): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
        graph (Optional[Graph]): If passed, the interpreter will execute this
            graph instead of `module.graph`, using the provided `module`
            argument to satisfy any requests for state.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
        self.module = module
        self.submodules = dict(self.module.named_modules())
        if graph is not None:
            self.graph = graph
        else:
            self.graph = self.module.graph
        self.env : Dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        self.extra_traceback = True

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.graph.nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter.
            enable_io_processing (bool): If True, the inputs and outputs are first processed with the
                graph's ``process_inputs`` and ``process_outputs`` functions before being used.

        Returns:
            Any: The value returned from executing the Module
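
        Example:

            A sketch of partial evaluation via ``initial_env``. ``gm`` and the
            node name ``"add"`` below are illustrative, not part of this API::

                # Suppose some node's value was already computed elsewhere.
                add_node = next(n for n in gm.graph.nodes if n.name == "add")
                interp = Interpreter(gm)
                # ``run`` will skip ``add_node`` and reuse the supplied value.
                result = interp.run(torch.randn(3), initial_env={add_node: torch.zeros(3)})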
  106. """
  107. self.env = initial_env if initial_env is not None else {}
  108. # Positional function args are consumed left-to-right by
  109. # `placeholder` nodes. Use an iterator to keep track of
  110. # position and extract those values.
  111. if enable_io_processing:
  112. args = self.graph.process_inputs(*args)
  113. self.args_iter : Iterator[Any] = iter(args)
  114. pbar = tqdm(total=len(self.graph.nodes),
  115. desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
  116. initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)
  117. for node in self.graph.nodes:
  118. pbar.update(1)
  119. if node in self.env:
  120. # Short circuit if we have this value. This could
  121. # be used, for example, for partial evaluation
  122. # where the caller has pre-populated `env` with
  123. # values for a subset of the program.
  124. continue
  125. try:
  126. self.env[node] = self.run_node(node)
  127. except Exception as e:
  128. if self.extra_traceback:
  129. msg = f"While executing {node.format_node()}"
  130. msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
  131. msg += f"\nOriginal traceback:\n{node.stack_trace}"
  132. e.args = (msg,) + e.args[1:]
  133. if isinstance(e, KeyError):
  134. raise RuntimeError(*e.args) from e
  135. raise
  136. if self.garbage_collect_values:
  137. for to_delete in self.user_to_last_uses.get(node, []):
  138. del self.env[to_delete]
  139. if node.op == 'output':
  140. output_val = self.env[node]
  141. return self.graph.process_outputs(output_val) if enable_io_processing else output_val

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result. This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter. This ensures that input tensors are promptly deallocated.
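
        Example:

            A minimal usage sketch (``gm`` is an assumed ``GraphModule`` taking
            one tensor argument)::

                inputs = [torch.randn(3, 4)]
                out = Interpreter(gm).boxed_run(inputs)
                assert inputs == []  # the interpreter cleared the argument list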
  148. """
  149. args_iter = iter(args_list)
  150. env = {}
  151. for n in self.graph.nodes:
  152. if n.op == "placeholder":
  153. env[n] = next(args_iter)
  154. args_list.clear()
  155. return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``n.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.
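
        Example:

            If the traced callable gives a parameter a default value, this
            method falls back to that default when ``run`` was not passed a
            value for it. A sketch (``fn`` is illustrative)::

                def fn(x, y=1.0):
                    return x + y

                gm = torch.fx.symbolic_trace(fn)
                out = Interpreter(gm).run(torch.randn(3))  # ``y`` falls back to 1.0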
  193. """
  194. assert isinstance(target, str)
  195. if target.startswith('*'):
  196. # For a starred parameter e.g. `*args`, retrieve all
  197. # remaining values from the args list.
  198. return list(self.args_iter)
  199. else:
  200. try:
  201. return next(self.args_iter)
  202. except StopIteration as si:
  203. if len(args) > 0:
  204. return args[0]
  205. else:
  206. raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the method invocation
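
        Example:

            For a node produced by tracing ``x.add(y)``, ``target`` is the
            string ``'add'`` and ``args`` holds the concrete values of ``x``
            and ``y``, so the dispatch below amounts to (a sketch)::

                self_obj, *args_tail = args            # self_obj is x's value
                getattr(self_obj, 'add')(*args_tail)   # i.e. x_value.add(y_value)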
  251. """
  252. # args[0] is the `self` object for this method call
  253. self_obj, *args_tail = args
  254. # Execute the method and return the result
  255. assert isinstance(target, str)
  256. return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the module invocation
        """
        # Retrieve executed args and kwargs values from the environment,
        # then execute the module and return the result
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Return:
            Any: The value of the attribute.
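
        Example:

            A sketch of fetching a nested parameter by its dotted path (the
            module and attribute names are illustrative)::

                # Equivalent to ``interp.module.layer1.weight``
                weight = interp.fetch_attr('layer1.weight')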
  299. """
  300. target_atoms = target.split('.')
  301. attr_itr = self.module
  302. for i, atom in enumerate(target_atoms):
  303. if not hasattr(attr_itr, atom):
  304. raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
  305. attr_itr = getattr(attr_itr, atom)
  306. return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Return:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)


@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        self.new_graph = Graph()
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph
                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]

            def is_leaf_module(self, _, __) -> bool:
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.
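
        Example:

            A minimal sketch (``gm`` is an assumed ``GraphModule``; a
            ``Transformer`` with no overridden methods copies the graph
            unchanged)::

                new_gm = Transformer(gm).transform()
                # ``new_gm`` behaves the same as ``gm``.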
  428. """
  429. with fx_traceback.preserve_node_meta():
  430. result = super().run(enable_io_processing=False)
  431. if result is not None:
  432. def strip_proxy(a : Union[Argument, Proxy]) -> Any:
  433. return a.node if isinstance(a, Proxy) else a
  434. self.new_graph.output(map_aggregate(result, strip_proxy))
  435. return _make_graph_module(self.module, self.new_graph)
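

# Illustrative smoke test for the classes above (an editorial sketch, not part
# of the original module; names below are local to this block). Guarded so it
# only runs when the module is executed directly, e.g. via
# `python -m torch.fx.interpreter`.
if __name__ == "__main__":
    def _fn(x):
        return torch.sigmoid(x).neg()

    _gm = torch.fx.symbolic_trace(_fn)
    _inp = torch.randn(3, 4)

    # Plain interpretation reproduces the traced function.
    torch.testing.assert_close(Interpreter(_gm).run(_inp), _fn(_inp))

    # A Transformer with no overrides produces an equivalent module.
    torch.testing.assert_close(Transformer(_gm).transform()(_inp), _fn(_inp))

    # The boxed calling convention clears the argument list it is given.
    _boxed = [_inp.clone()]
    _res = Interpreter(_gm).boxed_run(_boxed)
    assert _boxed == []
    torch.testing.assert_close(_res, _fn(_inp))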