debug.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import contextlib
  4. import dataclasses
  5. import functools
  6. import itertools
  7. import logging
  8. import os
  9. import os.path
  10. import pickle
  11. import pstats
  12. import shutil
  13. import subprocess
  14. from typing import Any, Dict, List, Optional
  15. from unittest.mock import patch
  16. import torch
  17. from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
  18. from torch import fx as fx
  19. from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
  20. from torch._dynamo.utils import get_debug_dir
  21. from torch.fx.graph_module import GraphModule
  22. from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
  23. from torch.fx.passes.tools_common import legalize_graph
  24. from torch.utils._pytree import tree_map
  25. from . import config, ir # noqa: F811, this is needed
  26. from .scheduler import (
  27. BaseSchedulerNode,
  28. FusedSchedulerNode,
  29. NopKernelSchedulerNode,
  30. OutputNode,
  31. SchedulerNode,
  32. )
  33. from .virtualized import V
  # Module-level logger, named after this module.
  log = logging.getLogger(__name__)

  # Alias used in signatures below for a list of scheduler nodes (elements are untyped).
  SchedulerNodeList = List[Any]
  # Pair of a buffer name and the number of original FX node names mapped to
  # that buffer; instances are built in get_node_name_to_buf_meta.
  BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
  # `dot` command line passed as `prog` to draw_graph when drawing the original FX graph.
  GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
  @functools.lru_cache(None)
  def has_dot() -> bool:
      """
      Return True if a ``dot`` executable can be found on the current PATH.

      The answer is cached for the lifetime of the process (``lru_cache``).

      ``shutil.which`` is used instead of spawning the external ``which``
      program: on systems without a ``which`` binary the spawn itself raises
      ``FileNotFoundError``, which is not a ``subprocess.SubprocessError`` and
      therefore escaped the previous ``except`` clause.
      """
      return shutil.which("dot") is not None
  def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None):
      """
      Draw the given scheduler nodes as a graph diagram.

      The nodes are converted with create_fx_from_snodes, each FX node that
      carries "fusion_meta" gets a "tensor_meta" entry derived from its group,
      and the resulting GraphModule is handed to draw_graph.

      Args:
          nodes: scheduler nodes to draw.
          print_graph: when true, also print the intermediate FX graph.
          fname: output name passed to draw_graph; defaults to
              get_graph_being_compiled().

      Returns None. Logs a warning and does nothing when has_dot() is false.
      """
      if not has_dot():
          log.warning("draw_buffers() requires `graphviz` package")
          return

      if fname is None:
          fname = get_graph_being_compiled()

      graph = create_fx_from_snodes(nodes)

      for fx_node in (n for n in graph.nodes if "fusion_meta" in n.meta):
          group = fx_node.meta["fusion_meta"].group
          if isinstance(group, tuple):
              # Keep only the second entry of the group, wrapped in a tuple
              # when it is a plain int.
              second = group[1]
              group = (second,) if isinstance(second, int) else second

          # gather meta data
          # NOTE(review): fx_node comes from an FX graph, so this isinstance
          # check against ir.ComputedBuffer looks like it never matches and
          # dtype stays None -- confirm before relying on it.
          dtype = fx_node.data.dtype if isinstance(fx_node, ir.ComputedBuffer) else None

          fx_node.meta["tensor_meta"] = TensorMetadata(
              group, dtype, None, None, None, None, None  # type: ignore[arg-type]
          )

      if print_graph:
          print(graph)

      gm = GraphModule({}, graph)
      legalize_graph(gm)
      gm.graph.lint()
      draw_graph(
          gm,
          fname,
          clear_meta=False,
          dot_graph_shape=config.trace.dot_graph_shape,
      )
  def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
      """
      Build an FX graph that mirrors a list of scheduler nodes.

      Every scheduler node becomes one call_function node whose target is a
      dummy function named "<node_type>: <fused kernel name>". Reads that do not
      refer to another node in ``snodes`` become placeholders inserted before
      the first call_function node. Nodes that have an OutputNode among their
      users form the graph output.

      Raises:
          RuntimeError: for a scheduler node that matches none of the known kinds.
      """

      def get_fake_func(name):
          # Dummy call target; only its __name__ is customised.
          def func1(*args):
              return 0

          func1.__name__ = name
          return func1

      def in_output(snode):
          # A fused node counts as an output if any of its members does.
          if isinstance(snode, FusedSchedulerNode):
              return any(in_output(x) for x in snode.snodes)
          return any(isinstance(user.node, OutputNode) for user in snode.users)

      def classify(snode):
          # Returns (node_type, group). Checked in this order; first match wins.
          if snode.is_extern():
              return "extern", "extern"
          if snode.is_template():
              return "template", "template"
          if isinstance(snode, NopKernelSchedulerNode):
              return "nop", "nop"
          if isinstance(snode, SchedulerNode):
              return "compute", snode.group
          if isinstance(snode, FusedSchedulerNode):
              return "fused", snode.group
          raise RuntimeError("Unknown node type")

      FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

      graph = torch.fx.Graph()
      buf_to_fx_node = {}
      outputs = []
      first_node = None

      # create call_function node for each Buffer and Kernel
      for snode in snodes:
          group: Any
          node_type, group = classify(snode)

          fused_name = torch._inductor.utils.get_fused_kernel_name(
              snode.get_nodes(), "original_aten"
          )
          node_func = get_fake_func(f"{node_type}: {fused_name}")
          kwargs = {"device": snode.get_device()} if hasattr(snode, "get_device") else {}
          fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)

          if in_output(snode):
              outputs.append(fx_node)

          name = snode.get_name()
          fx_node.name = name
          fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

          # Members of a fused node resolve to the fused node's FX node.
          if isinstance(snode, FusedSchedulerNode):
              for member in snode.snodes:
                  buf_to_fx_node[member.get_name()] = fx_node
          buf_to_fx_node[name] = fx_node

          if first_node is None:
              first_node = fx_node

      # create edges between nodes
      for snode in snodes:
          name = snode.get_name()
          reads = snode.read_writes.reads
          fx_node = buf_to_fx_node[name]

          arg_nodes = []
          for dep in reads:
              if dep.name not in buf_to_fx_node:
                  # Unknown producer: model it as a graph input.
                  with graph.inserting_before(first_node):
                      buf_to_fx_node[dep.name] = graph.placeholder(dep.name)
              arg_nodes.append(buf_to_fx_node[dep.name])

          fx_node.args = tuple(arg_nodes)

      graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
      return graph
  def update_orig_fx_node_name_to_buf_name(
      nodes: SchedulerNodeList,
      node_name_to_buf_name: Dict[str, str],
      parent_buf_name: Optional[str] = None,
      n_origins: int = 0,
  ):
      """
      Fill ``node_name_to_buf_name`` with origin-node-name -> buffer-name entries.

      Nodes whose get_nodes() returns more than one element are traversed
      recursively, and their children are recorded under the outermost name.
      An origin name that is already present keeps its existing buffer name.

      Args:
          nodes: scheduler nodes to walk; ``None`` is accepted and ignored.
          node_name_to_buf_name: mapping updated in place.
          parent_buf_name: name to record instead of each node's own name.
          n_origins: unused by this function.
      """
      if nodes is None:
          return

      for node in nodes:
          buf_name = node.get_name()
          # Name recorded for this node's origins: the enclosing name wins.
          recorded_name = parent_buf_name if parent_buf_name is not None else buf_name

          # for FusedSchedulerNode, traverse recursively into get_nodes()
          children = node.get_nodes()
          if children is not None and len(children) > 1:
              update_orig_fx_node_name_to_buf_name(
                  children, node_name_to_buf_name, recorded_name
              )
              continue
          assert len(children) == 1 and children[0] == node

          ir_node = node.node
          if ir_node is None or ir_node.origins is None:
              continue

          # when buf1 and buf2 both have origin=node1
          # we draw node1 according to buf1
          for origin in ir_node.origins:
              node_name_to_buf_name.setdefault(origin.name, recorded_name)
  def get_node_name_to_buf_meta(node_name_to_buf_name: Dict[str, str]):
      """
      Map each node name to a BufMeta(buffer name, number of node names that
      share that buffer).
      """
      # Keys of the input dict are unique, so counting values per buffer gives
      # the number of distinct node names mapped to it.
      names_per_buf = collections.Counter(node_name_to_buf_name.values())
      return {
          node_name: BufMeta(buf_name, names_per_buf[buf_name])
          for node_name, buf_name in node_name_to_buf_name.items()
      }
  def annotate_orig_fx_with_snodes(
      gm: torch.fx.GraphModule, snodes: SchedulerNodeList
  ) -> None:
      """
      Annotate nodes of ``gm`` with the scheduler buffer they were lowered into.

      For every node in ``gm.graph`` whose name appears among the origins of
      ``snodes``, ``node.meta["buf_meta"]`` is set to the corresponding BufMeta.
      Other nodes are left untouched. ``gm`` is modified in place.
      """
      node_name_to_buf_name: Dict[str, str] = {}
      update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)

      node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
      for node in gm.graph.nodes:
          # Values are BufMeta tuples, so None reliably means "no entry".
          buf_meta = node_name_to_buf_meta.get(node.name)
          if buf_meta is not None:
              node.meta["buf_meta"] = buf_meta
  @contextlib.contextmanager
  def enable_aot_logging():
      """
      Context manager that, when TORCH_COMPILE_DEBUG=1, logs AOT autograd output
      to a file for the duration of the block.

      With the environment variable unset (or not "1") this is a no-op.
      Otherwise it patches ``functorch.compile.config.debug_partitioner`` to
      True and attaches a DEBUG-level file handler writing to
      ``<debug dir>/torchinductor/aot_<graph name>_debug.log`` on the
      ``torch._functorch.aot_autograd`` logger. On exit the handler is removed
      and closed and the patch is undone, including when setup fails part-way.
      """
      compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

      import torch._functorch.aot_autograd

      aot_log = logging.getLogger(torch._functorch.aot_autograd.__name__)

      if not compile_debug:
          yield
          return

      # Everything registered on the stack is unwound in reverse order, both on
      # normal exit and when any of the setup steps below raises.
      with contextlib.ExitStack() as stack:
          # Enable all graphs to be logged to a file by setting the flags to True
          # and the log level of the file logger to DEBUG
          stack.enter_context(
              patch("functorch.compile.config.debug_partitioner", True)
          )

          path = os.path.join(get_debug_dir(), "torchinductor")
          os.makedirs(path, exist_ok=True)

          fh = logging.FileHandler(
              os.path.join(
                  path,
                  f"aot_{get_aot_graph_name()}_debug.log",
              )
          )
          # Close the underlying file once we are done; removing the handler
          # from the logger alone leaves the file open.
          stack.callback(fh.close)
          fh.setLevel(logging.DEBUG)
          fh.setFormatter(
              logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
          )
          aot_log.addHandler(fh)
          stack.callback(aot_log.removeHandler, fh)

          yield
  class DebugContext:
      """
      Context manager installed as the debug handler (via V.set_debug_handler)
      for the duration of a compilation.

      When config.trace.enabled is set it also creates a per-graph debug
      directory and optionally captures "torch._inductor" logging into
      debug.log / info.log inside it. Attributes that are not defined on this
      class are resolved through __getattr__ to DebugFormatter methods, or to a
      no-op when the matching config.trace flag is off.
      """

      # Shared counter used to pick the numeric suffix of new debug directories.
      _counter = itertools.count()

      @staticmethod
      def wrap(fn):
          """Return ``fn`` wrapped so each call runs inside a fresh DebugContext."""

          @functools.wraps(fn)
          def inner(*args, **kwargs):
              with DebugContext():
                  return fn(*args, **kwargs)

          return wrap_compiler_debug(inner, compiler_name="inductor")

      @staticmethod
      def create_debug_dir(folder_name: str) -> Optional[str]:
          """
          Create and return ``<debug dir>/torchinductor/<folder_name>.<n>`` for the
          first counter value ``n`` whose directory does not exist yet.
          """
          debug_dir = config.trace.debug_dir or get_debug_dir()
          for n in DebugContext._counter:
              dirname = os.path.join(
                  debug_dir,
                  "torchinductor",
                  f"{folder_name}.{n}",
              )
              # Create first and treat "already exists" as "taken": a separate
              # exists() check could race with another process creating the
              # same directory in between.
              try:
                  os.makedirs(dirname)
              except FileExistsError:
                  continue
              return dirname
          return None

      def __init__(self):
          self._prof = None
          self._path = None
          # Collects everything that has to be undone in __exit__.
          self._stack = contextlib.ExitStack()

      def copy(self, new_path: str):
          """
          Copy the debug directory to ``new_path`` (must end in ".debug"),
          replacing any existing directory. Does nothing when no debug directory
          was created; OSError during the copy is logged and ignored.
          """
          if not self._path:
              return
          assert new_path.endswith(".debug"), new_path
          from filelock import FileLock

          try:
              with FileLock(f"{new_path}.lock"):
                  if os.path.exists(new_path):
                      shutil.rmtree(new_path)
                  shutil.copytree(self._path, new_path)
          except OSError:
              log.warning(
                  "Failed to copy debug files from %s to %s", self._path, new_path
              )

      def fopen(self, filename: str, write_mode: str = "w", *args, **kwargs):
          """Open ``filename`` inside the debug directory; the caller closes it."""
          assert self._path
          return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

      @contextlib.contextmanager
      def fopen_context(self, filename: str, write_mode: str = "w", *args, **kwargs):
          """Context-manager variant of fopen() that closes the file on exit."""
          assert self._path
          with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
              yield f

      def filename(self, suffix: str):
          """Return the path of ``suffix`` inside the debug directory."""
          assert self._path
          return os.path.join(self._path, suffix)

      def upload_tar(self):
          """
          If config.trace.upload_tar is set, pack the debug directory into
          ``<dir>/<dirname>.tar.gz`` and pass that path to the callback.
          """
          if config.trace.upload_tar is not None:
              import tarfile

              assert self._path
              tar_file = os.path.join(
                  self._path, f"{os.path.basename(self._path)}.tar.gz"
              )
              with tarfile.open(tar_file, "w:gz") as tar:
                  tar.add(self._path, arcname=os.path.basename(self._path))
              config.trace.upload_tar(tar_file)

      def __enter__(self):
          try:
              if config.debug:
                  dynamo_log = logging.getLogger("torch._dynamo")
                  prev_level = dynamo_log.level
                  dynamo_log.setLevel(logging.DEBUG)
                  # Restore the previous level when the context exits.
                  self._stack.callback(dynamo_log.setLevel, prev_level)

              self._stack.enter_context(V.set_debug_handler(self))

              if not config.trace.enabled:
                  return

              self._path = self.create_debug_dir(get_aot_graph_name())

              if config.trace.debug_log:
                  self._setup_log_capture("debug.log", logging.DEBUG)
              if config.trace.info_log:
                  self._setup_log_capture("info.log", logging.INFO)
          except BaseException:
              # __exit__ is not called when __enter__ raises, so undo whatever
              # was already registered before propagating the error.
              self._stack.close()
              raise

      def _setup_log_capture(self, filename: str, level: int):
          """Mirror "torch._inductor" log records at ``level`` into ``filename``."""
          inductor_log = logging.getLogger("torch._inductor")
          fd = self._stack.enter_context(self.fopen(filename))
          ch = logging.StreamHandler(fd)
          ch.setLevel(level)
          ch.setFormatter(
              logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
          )
          inductor_log.addHandler(ch)
          inductor_log.setLevel(min(inductor_log.level, level))
          self._stack.callback(inductor_log.removeHandler, ch)

      def __exit__(self, exc_type, exc_val, exc_tb):
          try:
              if self._prof:
                  self._prof.disable()
                  self._save_profile_data()

              if self._path:
                  self.upload_tar()
                  log.warning(
                      "%s debug trace: %s", get_graph_being_compiled(), self._path
                  )
          finally:
              # Always unwind (log handlers, debug handler, log levels), even if
              # writing the profile or uploading the archive failed.
              self._stack.close()

      def _save_profile_data(self):
          """Write compile.prof and a human-readable compile.stats summary."""
          assert self._prof
          self._prof.dump_stats(self.filename("compile.prof"))
          with self.fopen("compile.stats") as fd:
              stats = pstats.Stats(self._prof, stream=fd)
              stats.strip_dirs()
              stats.sort_stats("cumtime")
              stats.print_stats(100)
              stats.sort_stats("tottime")
              stats.print_stats(100)

      def __getattr__(self, name):
          """
          Resolve unknown attributes to the DebugFormatter method of the same
          name when tracing and the ``config.trace.<name>`` flag are enabled.

          In every other case, including a failure while building the
          formatter, a callable that accepts anything and does nothing is
          returned, so callers can always invoke the result.
          """

          def ignored(*args, **kwargs):
              pass

          if config.trace.enabled and getattr(config.trace, name):
              try:
                  return getattr(DebugFormatter(self), name)
              except Exception:
                  log.warning("Ignoring exception in debug code", exc_info=True)
          return ignored
  class DebugFormatter:
      """
      Writes individual debug artifacts into the directory owned by ``handler``.

      ``handler`` must provide ``fopen``, ``fopen_context`` and ``filename``;
      all files are created through those.
      """

      def __init__(self, handler):
          self.fopen = handler.fopen
          self.fopen_context = handler.fopen_context
          self.filename = handler.filename
          self.handler = handler

      def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
          """Write fx_graph_runnable.py (repro script) and fx_graph_readable.py."""
          with self.fopen("fx_graph_runnable.py") as fd:
              save_graph_repro(fd, gm, inputs, "inductor")

          with self.fopen("fx_graph_readable.py") as fd:
              fd.write(gm.print_readable(print_output=False))

      def fx_graph_transformed(
          self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
      ):
          """Write the readable form of ``gm`` to fx_graph_transformed.py."""
          with self.fopen("fx_graph_transformed.py") as fd:
              fd.write(gm.print_readable(print_output=False))

      def ir_pre_fusion(self, nodes: SchedulerNodeList):
          """Dump ``nodes`` to ir_pre_fusion.txt."""
          self._write_ir("ir_pre_fusion.txt", nodes)

      def ir_post_fusion(self, nodes: SchedulerNodeList):
          """Dump ``nodes`` to ir_post_fusion.txt."""
          self._write_ir("ir_post_fusion.txt", nodes)

      def _write_ir(self, filename: str, nodes: SchedulerNodeList):
          """Write each node's debug_str() to ``filename``, separated by blank lines."""
          with self.fopen(filename) as fd:
              log.info("Writing debug ir to %s", fd.name)
              for node in nodes:
                  fd.write(node.debug_str())
                  fd.write("\n\n\n")

      def graph_diagram(self, nodes: SchedulerNodeList):
          """Draw the scheduler nodes to graph_diagram.svg."""
          draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

      def draw_orig_fx_graph(self, gm: torch.fx.GraphModule, nodes: SchedulerNodeList):
          """Annotate ``gm`` with buffer info and draw it to orig_fx_graph_diagram.svg."""
          annotate_orig_fx_with_snodes(gm, nodes)
          draw_graph(
              gm,
              fname=self.filename("orig_fx_graph_diagram.svg"),
              clear_meta=False,
              prog=GRAPHVIZ_COMMAND_SCALABLE,
              parse_stack_trace=True,
              dot_graph_shape=config.trace.dot_graph_shape,
          )

      def output_code(self, filename):
          """Copy the file at ``filename`` into the debug directory as output_code.py."""
          shutil.copy(filename, self.filename("output_code.py"))

      def log_autotuning_results(
          self,
          name: str,
          input_nodes: List[ir.IRNode],
          timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
          elapse: float,
          precompile_elapse: float,
      ):
          """
          Append one JSON object per autotuning choice to
          autotuning_result_json_list.txt.

          Each line contains the choice's ``info_dict()``, the operation name,
          CUDA device information, a description of every input node, the
          autotuning and precompile times, and the benchmark result.
          ``cuda_device_name`` is None when CUDA is not available.

          Args:
              name: operation name stored as "op_name".
              input_nodes: IR nodes described under "input_nodes".
              timings: mapping of choice caller to its benchmark time.
              elapse: stored as "autotuning_time".
              precompile_elapse: stored as "precompile_time".
          """
          import json

          from .ir import FixedLayout

          def static_offset(layout):
              # Integer offset if possible, else a size hint, else 0.
              try:
                  return int(layout.offset)
              except Exception:
                  pass
              try:
                  return V.graph.sizevars.size_hint(layout.offset, fallback=0)
              except Exception:
                  return 0

          def build_node_info(node: ir.IRNode):
              node_info = {
                  "name": getattr(node, "name", ""),
                  "type": type(node).__name__,
              }

              # Every property below is best effort: not all IR nodes support
              # all queries, and a failing query must not abort the debug dump.
              with contextlib.suppress(Exception):
                  layout = node.get_layout()
                  if isinstance(layout, FixedLayout):
                      # Re-create the layout with size hints in place of
                      # symbolic sizes before stringifying it.
                      static_layout = FixedLayout(
                          layout.device,
                          dtype=layout.dtype,
                          size=list(V.graph.sizevars.size_hints(layout.size)),
                          stride=list(V.graph.sizevars.size_hints(layout.stride)),
                          offset=static_offset(layout),
                      )
                      node_info["layout"] = str(static_layout)
                  else:
                      node_info["layout"] = str(layout)

              best_effort_fields = (
                  ("dtype", lambda: node.get_dtype()),
                  ("device", lambda: node.get_device()),
                  ("stride", lambda: V.graph.sizevars.size_hints(node.get_stride())),
                  ("size", lambda: V.graph.sizevars.size_hints(node.get_size())),
                  ("numel", lambda: V.graph.sizevars.size_hint(node.get_numel())),
              )
              for key, compute in best_effort_fields:
                  with contextlib.suppress(Exception):
                      node_info[key] = str(compute())

              if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                  node_info["data"] = build_node_info(node.data)
              return node_info

          # get_device_name() raises when no CUDA device is usable, which would
          # otherwise lose the whole log entry on CPU-only runs.
          cuda_device_name = (
              torch.cuda.get_device_name() if torch.cuda.is_available() else None
          )
          general_properties = {
              "op_name": name,
              "cuda_device_name": cuda_device_name,
              "cuda_device_count": torch.cuda.device_count(),
              "input_nodes": [build_node_info(node) for node in input_nodes],
              "autotuning_time": elapse,
              "precompile_time": precompile_elapse,
          }
          with self.fopen_context(
              "autotuning_result_json_list.txt", "at", encoding="utf-8"
          ) as fd:
              for caller, benchmark_time in timings.items():
                  info_dict = dict(caller.info_dict())
                  info_dict.update(general_properties)
                  info_dict["benchmark_result"] = benchmark_time
                  json.dump(info_dict, fd)
                  fd.write("\n")
  @dataclasses.dataclass
  class TensorMetadataHolder:
      """
      Picklable stand-in for a tensor.

      save_args_for_compile_fx_inner stores one of these in place of every
      tensor argument; load_args_and_run_compile_fx_inner turns it back into a
      randomly filled tensor with the recorded shape, stride, dtype and device.
      """

      # Result of _extract_tensor_metadata for the original tensor.
      tensor_metadata: TensorMetadata
      # Device the original tensor lived on.
      device: torch.device
  # Process-wide counter that numbers the saved argument files.
  save_args_cnt = itertools.count()


  def save_args_for_compile_fx_inner(*args, **kwargs):
      """
      This function is used to save arguments for a compile_fx_inner function call
      to the file system.  Later on one can replay the compile_fx_inner call
      with the saved arguments using load_args_and_run_compile_fx_inner.

      The arguments are pickled to
      ``/tmp/inductor_saved_args/compile_fx_inner_<n>.pkl`` with every tensor
      replaced by a TensorMetadataHolder. When DEBUG logging is enabled for this
      module, replay instructions are printed to stdout.
      """

      folder = "/tmp/inductor_saved_args"
      # exist_ok avoids the check-then-create race when several processes save
      # arguments at the same time.
      os.makedirs(folder, exist_ok=True)

      def handle_tensor(x):
          """
          Pickle FakeTensor will result in error:
          AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

          Convert all Tensor to metadata. This may also makes pickle faster.
          """
          if isinstance(x, torch.Tensor):
              return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
          return x

      args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

      fn_name = "compile_fx_inner"
      path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
      with open(path, "wb") as f:
          pickle.dump((args_to_save, kwargs_to_save), f)

      if log.isEnabledFor(logging.DEBUG):
          message = (
              f"\nArguments for a compile_fx_inner call is saved to {path}. To replay the call,\n"
              "run the following:\n"
              "from torch._inductor.debug import load_args_and_run_compile_fx_inner\n"
              f"load_args_and_run_compile_fx_inner({path!r})\n"
          )
          # call print rather than log.debug. log.debug will print message
          # prefix for each line which makes the code snippet harder to be
          # copied.
          # Not a big deal since the code is already been guarded by checking
          # the log level.
          print(message)
  def load_args_and_run_compile_fx_inner(path: str):
      """
      Replay a compile_fx_inner call from arguments saved by
      save_args_for_compile_fx_inner.

      Every TensorMetadataHolder in the saved arguments is replaced by a tensor
      created with rand_strided from the recorded shape, stride, dtype and
      device. The call runs under a FakeTensorMode with the "save_args" config
      option patched to False.

      Args:
          path: pickle file written by save_args_for_compile_fx_inner.

      Returns:
          Whatever compile_fx_inner returns for the restored arguments.
      """
      # Imported explicitly: nothing else in this module imports
      # torch._dynamo.testing, so attribute access on it is not guaranteed.
      from torch._dynamo.testing import rand_strided
      from torch._inductor.compile_fx import compile_fx_inner

      # SECURITY: unpickling can execute arbitrary code. Only pass files that
      # were produced locally by save_args_for_compile_fx_inner.
      with open(path, "rb") as f:
          args, kwargs = pickle.load(f)

      def handle_tensor(x):
          if isinstance(x, TensorMetadataHolder):
              return rand_strided(
                  x.tensor_metadata.shape,
                  x.tensor_metadata.stride,
                  x.tensor_metadata.dtype,
                  x.device,
              )
          return x

      fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
      with fake_mode, config.patch("save_args", False):
          args, kwargs = tree_map(handle_tensor, (args, kwargs))
          return compile_fx_inner(*args, **kwargs)
  545. return compile_fx_inner(*args, **kwargs)