_trace_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from contextlib import contextmanager
  4. from dataclasses import dataclass, field
  5. from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
  6. import torch
  7. import torch.nn as nn
  8. @dataclass
  9. class TracingConfig:
  10. """
  11. This represents a symbolic tracing configuration.
  12. Args:
  13. tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
  14. use for symbolic tracing. The default value is the native
  15. :class:`torch.fx.Tracer` constructed with default arguments.
  16. However, the user may want to pass a different value such as the
  17. ``HFTracer`` for models in the HuggingFace Transformers_ library.
  18. .. _Transformers: https://huggingface.co/docs/transformers/index
  19. concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
  20. should not be treated as ``torch.fx.Proxy`` when tracing the
  21. module ``forward()``. Passing ``concrete_args`` allows partially
  22. specializing the forward, e.g. to remove control flow or data
  23. structures. This ``concrete_args`` here is the same argument used
  24. in :meth:`~torch.fx.Tracer.trace`.
  25. """
  26. tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
  27. concrete_args: Optional[Dict[str, Any]] = None
  28. class _ParamUsageInfo(NamedTuple):
  29. """
  30. This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
  31. execution information. The ``dict`` maps modules to a list of these
  32. ``_ParamUsageInfo`` instances, where each instance represents a group of
  33. parameters used together.
  34. Specifically, for each module key in the ``dict``, each instance of this
  35. class represents either:
  36. (1) the module and some sublist of its ``named_parameters()`` used
  37. together in execution (see ``_patched_create_proxy()``), or
  38. (2) a submodule and all of ``submodule.named_parameters()`` (see
  39. ``_patched_call_module()``).
  40. Type (1) corresponds to directly using parameters in ops without calling
  41. ``forward()``, and type (2) corresponds to calling ``forward()``. The
  42. mapped-to lists in the ``dict`` follow the execution order.
  43. """
  44. module: nn.Module
  45. named_params: List[Tuple[str, nn.Parameter]]
  46. class _ExecutionInfo:
  47. """
  48. This represents the execution order information from the forward pass.
  49. Attributes:
  50. curr_module (nn.Module): Current module being traced.
  51. module_forward_order (List[nn.Module]): The modules in (pre-)forward
  52. order, i.e. the order in which their ``forward()`` methods are
  53. called. Each call to a module's ``forward()`` corresponds to one
  54. element in the list.
  55. module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
  56. Maps a module to a list of module execution infos. See
  57. :class:`_ParamUsageInfo` for details.
  58. param_forward_order (List[nn.Parameter]): The parameters in forward
  59. execution order, where only a parameter's first participation is
  60. included.
  61. visited_params (Set[nn.Parameter]): The parameters visited so far
  62. during the trace. This is only used during tracing for fast
  63. membership check. Invariant: The parameters in
  64. ``param_forward_order`` are exactly those in ``visited_params``.
  65. """
  66. def __init__(self, root_module: nn.Module) -> None:
  67. self.curr_module: nn.Module = root_module
  68. self.module_forward_order: List[nn.Module] = [root_module]
  69. self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
  70. root_module: []
  71. }
  72. self.param_forward_order: List[nn.Parameter] = []
  73. self.visited_params: Set[nn.Parameter] = set()
  74. class _ExecOrderTracer:
  75. def __init__(self) -> None:
  76. self.exec_info: Optional[_ExecutionInfo] = None
  77. @contextmanager
  78. def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
  79. self.exec_info = _ExecutionInfo(root_module)
  80. orig_call_module = tracer.call_module
  81. orig_create_proxy = tracer.create_proxy
  82. tracer.call_module = functools.partial(
  83. self._patched_call_module, orig_call_module, self.exec_info
  84. )
  85. fqn_to_param = dict(root_module.named_parameters())
  86. tracer.create_proxy = functools.partial(
  87. self._patched_create_proxy,
  88. orig_create_proxy,
  89. self.exec_info,
  90. fqn_to_param,
  91. )
  92. try:
  93. yield
  94. finally:
  95. tracer.call_module = orig_call_module
  96. tracer.create_proxy = orig_create_proxy
  97. def _patched_call_module(
  98. self,
  99. call_module: Callable,
  100. exec_info: _ExecutionInfo,
  101. # Below are the expected arguments to `call_module()`
  102. module: nn.Module,
  103. forward: Callable,
  104. args: Tuple[Any, ...],
  105. kwargs: Dict[str, Any],
  106. ) -> Any:
  107. """
  108. Overrides ``call_module`` to save execution information to
  109. ``exec_info``. Note that ``call_module`` is called during symbolic
  110. tracing for each non-root module.
  111. Args:
  112. call_module (Callable): Original ``call_module`` to override.
  113. exec_info (_ExecutionInfo): Used to record execution information.
  114. module (nn.Module): Module corresponding to this ``call_module``.
  115. forward (Callable): ``forward()`` method of ``module`` to be called
  116. for this ``call_module``.
  117. args (Tuple[Any, ...]): Positional arguments for ``forward``.
  118. kwargs (Dict[str, Any]): Keyword arguments for ``forward``.
  119. Returns:
  120. Same return value as ``call_module``.
  121. """
  122. exec_info.module_forward_order.append(module)
  123. named_params = list(module.named_parameters())
  124. curr_module = exec_info.curr_module
  125. if named_params:
  126. assert (
  127. curr_module in exec_info.module_to_param_usage_infos
  128. ), "The current module should have already been processed by a patched `call_module`"
  129. exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
  130. _ParamUsageInfo(module, named_params)
  131. )
  132. prev_curr_module = curr_module
  133. exec_info.curr_module = module
  134. exec_info.module_to_param_usage_infos[module] = []
  135. output = call_module(module, forward, args, kwargs)
  136. exec_info.curr_module = prev_curr_module
  137. return output
  138. def _patched_create_proxy(
  139. self,
  140. create_proxy: Callable,
  141. exec_info: _ExecutionInfo,
  142. fqn_to_param: Dict[str, nn.Parameter],
  143. # Below are the expected arguments to `create_proxy()`
  144. kind: str,
  145. target: torch.fx.node.Target,
  146. args: Tuple[Any, ...],
  147. kwargs: Dict[str, Any],
  148. name: Optional[str] = None,
  149. type_expr: Optional[Any] = None,
  150. proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
  151. ) -> torch.fx.Proxy:
  152. """
  153. Overrides ``create_proxy`` to save execution information to
  154. ``exec_info``. Note that ``create_proxy`` is called during symbolic
  155. tracing for each leaf function/method/module.
  156. Args:
  157. create_proxy (Callable): Original ``create_proxy`` to override.
  158. exec_info (_ExecutionInfo): Used to record execution information.
  159. fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
  160. root module's ``named_parameters()`` with FQN as key and
  161. parameter as value.
  162. kind (str): Kind of the target method ('call_function',
  163. 'call_method', 'get_attr', 'call_module', 'placeholder', or
  164. 'output'). See :class:`torch.fx.Graph` for details. This is
  165. passed to ``create_proxy``.
  166. target (torch.fx.node.Target): Contains the string name of the
  167. function/method/module. This is passed to ``create_proxy``.
  168. args (Tuple[Any, ...]): Positional arguments for the function/
  169. method/module. This is passed to ``create_proxy``.
  170. kwargs (Dict[str, Any]): Keyword arguments for the function/method/
  171. module. This is passed to ``create_proxy``
  172. name (Optional[str]): An optional string name for the ``Node``
  173. created in ``create_proxy``. This is passed to
  174. ``create_proxy``.
  175. type_expr (Optional[Any]): An optional type annotation representing
  176. the Python type that the output of the node has. This is passed
  177. to ``create_proxy``.
  178. proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
  179. An alternative proxy constructor used in ``create_proxy``. This
  180. is passed to ``create_proxy``.
  181. Returns:
  182. torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
  183. """
  184. proxy = create_proxy(
  185. kind, target, args, kwargs, name, type_expr, proxy_factory_fn
  186. )
  187. curr_module = exec_info.curr_module
  188. if kind in ("call_function", "call_method"):
  189. if args is not None:
  190. named_params: List[Tuple[str, nn.Parameter]] = []
  191. for arg in args:
  192. if (
  193. isinstance(arg, torch.fx.Proxy)
  194. and arg.node.target in fqn_to_param
  195. ):
  196. param = fqn_to_param[arg.node.target]
  197. named_params.append((arg.node.target, param))
  198. if param not in exec_info.visited_params:
  199. exec_info.visited_params.add(param)
  200. exec_info.param_forward_order.append(param)
  201. if named_params:
  202. exec_info.module_to_param_usage_infos[curr_module].append(
  203. _ParamUsageInfo(curr_module, named_params)
  204. )
  205. elif kind == "call_module":
  206. named_params = list(curr_module.named_parameters())
  207. if named_params:
  208. exec_info.module_to_param_usage_infos[curr_module].append(
  209. _ParamUsageInfo(curr_module, named_params)
  210. )
  211. for _, param in named_params:
  212. if param not in exec_info.visited_params:
  213. exec_info.visited_params.add(param)
  214. exec_info.param_forward_order.append(param)
  215. return proxy