# debugging.py

# mypy: ignore-errors
import dataclasses
import functools
from importlib import import_module
from typing import Any, List, Optional

import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile

from .common import aot_autograd
from .registry import register_debug_backend as register_backend

"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@register_backend
def eager(gm, fake_tensor_inputs):
    return gm.forward


@register_backend
def eager_noexcept(gm, fake_tensor_inputs):
    # This backend is intended to check that dynamo-generated GraphModules
    # do not cause errors.
    def inner(*args):
        try:
            return gm(*args)
        except Exception as e:
            raise torch._dynamo.exc.TorchDynamoException(
                "Unexpected exception when running generated GraphModule"
            ) from e

    return inner


@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs):
    from torch.fx.experimental.proxy_tensor import make_fx

    def runnable_gm(*args):
        return torch.fx.Interpreter(gm).run(*args)

    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()

    return pre_dispatch_gm


@register_backend
def eager_debug(gm, fake_tensor_inputs):
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    # We could add more debugging bits here.
    # Right now, this backend can be used to check for and error on
    # custom dispatcher ops that have incorrect schemas.
    def inner(*args):
        with SchemaCheckMode():
            return torch.fx.Interpreter(gm).run(*args)

    return inner
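

# Usage sketch for the debugging backends above (assuming the standard
# torch.compile entry point; the backend names come from @register_backend):
#
#     opt_fn = torch.compile(fn, backend="eager")          # run the captured graph as-is
#     opt_fn = torch.compile(fn, backend="eager_noexcept") # flag errors from the generated GraphModule
#     opt_fn = torch.compile(fn, backend="eager_debug")    # run under SchemaCheckMode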


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


# Use a boxed call so inputs can be discarded once they are no longer needed.
def boxed_nop(fx_g, example_inputs):
    def run(args):
        return torch.fx.Interpreter(fx_g).boxed_run(args)

    run._boxed_call = True
    return run
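

# Sketch of the boxed calling convention assumed here: a compiled function marked
# with `_boxed_call = True` receives its inputs as a single list rather than *args,
# so Interpreter.boxed_run can clear that list and release each input tensor as
# soon as it has been consumed.
#
#     compiled = boxed_nop(fx_graph, example_inputs)
#     out = compiled([t1, t2])   # note: one list argument, which may be emptied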


# Useful for debugging: aot_eager runs the AOT Autograd pipeline with a no-op
# compiler.
aot_eager = aot_autograd(
    fw_compiler=boxed_nop,
    partition_fn=min_cut_rematerialization_partition,
    keep_inference_input_mutations=True,
)
register_backend(name="aot_eager", compiler_fn=aot_eager)

aot_eager_default_partitioner = aot_autograd(
    fw_compiler=boxed_nop, keep_inference_input_mutations=True
)
register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)
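

# Usage sketch (assuming the standard torch.compile entry point):
#
#     opt_fn = torch.compile(fn, backend="aot_eager")
#     opt_fn = torch.compile(fn, backend="aot_eager_default_partitioner")
#
# If plain "eager" works but "aot_eager" fails, the problem is likely in AOT
# Autograd tracing/partitioning rather than in Dynamo's graph capture.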


# Uses TorchInductor's decompositions and partitioner, but replaces the Inductor
# compiler with a no-op, to help isolate aot vs inductor problems (i.e. Inductor
# codegen errors vs aot_eager errors).
def aot_eager_decomp_partition(gm, fake_tensor_inputs):
    with functorch_config.patch(unlift_effect_tokens=True):
        return aot_autograd(
            # these are taken from memory_efficient_fusion()
            fw_compiler=boxed_nop,
            bw_compiler=boxed_nop,
            # NB: lambda here is to delay import of inductor
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
        )(gm, fake_tensor_inputs)


register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
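

# Debugging-ladder sketch based on the comments above: enable progressively more
# of the stack to localize a failure (backend names as registered in this file;
# "inductor" is the full compiler backend):
#
#     torch.compile(fn, backend="eager")                      # Dynamo capture only
#     torch.compile(fn, backend="aot_eager")                  # + AOT Autograd
#     torch.compile(fn, backend="aot_eager_decomp_partition") # + Inductor decomps/partitioner
#     torch.compile(fn, backend="inductor")                   # full Inductor codegen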

# AOT Autograd with the TorchScript backend, using the default partitioner.
# aot_ts can be used with both NNC and nvFuser by selecting the relevant fuser
# via torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)
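
# Usage sketch (assuming the standard torch.jit.fuser context manager):
#
#     with torch.jit.fuser("fuser1"):   # NNC; "fuser2" selects nvFuser where available
#         opt_fn = torch.compile(fn, backend="aot_ts")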


# These buggy backends are used for inducing bugs so that we can test
# our repro extraction / minifier scripts


class ReluCompileError(Exception):
    pass


class TestingOnlyCompileError(Exception):
    pass


@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            raise ReluCompileError
    return gm


@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch._assert
            node.args = (False, "ReluRuntimeError")
    gm.recompile()
    return gm


@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch.add
            node.args = (node.args[0], 1)
    gm.recompile()
    return gm


@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Require at least one non-trivial thing in the graph,
    # see https://github.com/pytorch/pytorch/issues/102898
    for node in gm.graph.nodes:
        if node.op == "call_function":
            break
    else:
        return gm
    for t in example_inputs:
        if not t.is_leaf:
            raise TestingOnlyCompileError
    return gm
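

# Sketch of how these TESTING_ONLY backends can be exercised (hedged; the exact
# minifier entry points vary by PyTorch version). Compiling with one of them
# forces a failure that the repro/minifier tooling, e.g. driven by the
# TORCHDYNAMO_REPRO_AFTER environment variable, can then be pointed at:
#
#     opt_fn = torch.compile(fn, backend="relu_compile_error_TESTING_ONLY")
#     opt_fn(torch.randn(8))   # compilation raises ReluCompileError if the graph calls torch.relu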


@dataclasses.dataclass
class ExplainOutput:
    """
    This is the output of :func:`torch._dynamo.explain()`.
    There is no reason to create this class directly.
    """

    graphs: List[torch.fx.GraphModule]
    graph_count: int
    graph_break_count: int
    break_reasons: List[
        Any
    ]  # Type is GraphCompileReason but doesn't matter for this purpose
    op_count: int
    ops_per_graph: Optional[List[torch.fx.Node]] = None
    out_guards: Optional[List[_guards.Guard]] = None
    compile_times: Optional[str] = None

    def __str__(self):
        output = f"Graph Count: {self.graph_count}\n"
        output += f"Graph Break Count: {self.graph_break_count}\n"
        output += f"Op Count: {self.op_count}\n"
        output += "Break Reasons:\n"
        for idx, break_reason in enumerate(self.break_reasons):
            output += f"  Break Reason {idx + 1}:\n"
            output += f"    Reason: {break_reason.reason}\n"
            output += "    User Stack:\n"
            for frame_summary in break_reason.user_stack:
                output += f"      {frame_summary}\n"
        if self.ops_per_graph is not None:
            output += "Ops per Graph:\n"
            for idx, ops in enumerate(self.ops_per_graph):
                output += f"  Ops {idx + 1}:\n"
                for op in ops:
                    output += f"    {op}\n"
        if self.out_guards is not None:
            output += "Out Guards:\n"
            for i, guard in enumerate(self.out_guards):
                output += f"  Guard {i + 1}:\n"
                output += f"    {str(guard)}"
        if self.compile_times is not None:
            output += f"Compile Times: {self.compile_times}\n"
        return output
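

# Sketch of obtaining an ExplainOutput (hedged; the exact explain() calling
# convention varies across PyTorch versions; newer releases use the
# explain(fn)(*args) form):
#
#     explanation = torch._dynamo.explain(fn)(torch.randn(8))
#     print(explanation)   # uses the __str__ report defined above
#     print(explanation.graph_break_count, explanation.break_reasons)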


def _explain_graph_detail(
    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
    """
    This function is a utility which processes a torch.fx.GraphModule and
    accumulates information about its ops, graph breaks, and other details. It
    is intended to be used by the ExplainWithBackend class and
    `torch._dynamo.explain()` to provide details from Dynamo's graph capture.

    Parameters:
        gm (torch.fx.GraphModule): The GraphModule to be processed.
        graphs (list): A list that accumulates all the GraphModules processed.
        op_count (int): The total count of operations in all GraphModules processed so far.
        ops_per_graph (list): A list that accumulates the operations of each GraphModule.
        break_reasons (list): A list that accumulates the reasons for graph breaks in each GraphModule.

    Returns:
        tuple: The processed GraphModule, followed by the updated graphs list,
               operation count, ops-per-graph list, and break-reasons list (in that order).
    """
    graphs.append(gm)
    ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
    op_count += len(ops)
    ops_per_graph.append(ops)
    if gm.compile_subgraph_reason.graph_break:
        break_reasons.append(gm.compile_subgraph_reason)

    return gm, graphs, op_count, ops_per_graph, break_reasons


class ExplainWithBackend:
    """
    This class is intended to be used as a backend for `torch.compile`. It is
    composable with other backends. When used in this way, it accumulates
    information about graph breaks, ops, and other info and provides a string
    representation summarizing this information.

    Attributes:
        backend (str): The name of the backend to use for optimization.
        graphs (list): A list of the graphs captured by TorchDynamo.
        op_count (int): The total number of operations in all optimized graphs.
        break_reasons (list): A list of graph break reasons with stack traces.

    Example Usage:
        def fn(x):
            x = torch.sigmoid(x)
            return x

        torch._dynamo.reset()
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        result = optimized_fn(torch.randn(5))
        print(eb.output())
    """

    def __init__(self, backend):
        from .registry import lookup_backend

        self.backend = lookup_backend(backend)
        self.graphs = []
        self.op_count = 0
        self.break_reasons = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
            gm, self.graphs, self.op_count, [], self.break_reasons
        )
        return self.backend(gm, example_inputs)

    def output(self) -> ExplainOutput:
        graph_count = len(self.graphs)
        output = ExplainOutput(
            self.graphs,
            graph_count,
            graph_count - 1,
            self.break_reasons,
            self.op_count,
        )
        return output
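

# Sketch of programmatic use of the summary (field names as defined on
# ExplainOutput above; eb is an ExplainWithBackend instance that has already
# been run through torch.compile, as in the Example Usage docstring):
#
#     summary = eb.output()
#     print(summary.graph_count, summary.graph_break_count)
#     for reason in summary.break_reasons:
#         print(reason.reason)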