_traceback.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # mypy: allow-untyped-defs
  2. from types import TracebackType
  3. from typing import List, Optional
  4. import tempfile
  5. import traceback
  6. import contextlib
  7. import inspect
  8. import os.path
  9. # This file contains utilities for ensuring dynamically compile()'d
  10. # code fragments display their line numbers in backtraces.
  11. #
  12. # The constraints:
  13. #
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, cf.
#   https://stackoverflow.com/q/50515651/23845 )
  17. #
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily only at exception
#   time.  Arguably, you *should* be willing to write out your
#   generated Python code to the file system, but in some situations
#   (esp. library code) it would violate user expectations to write
#   to the file system, so we try to avoid it.  In particular, we'd
#   like to keep the files around, so users can open up the files
#   mentioned in the trace; but when a file would never be shown to
#   the user (no exception occurs), we don't want to clog up the
#   filesystem by creating it.
  27. #
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement this functionality: instead of using
#   eval/exec directly, just always write a Python file to the filesystem
#   and compile that.
  32. #
  33. # - You have control over a context where the compiled code will get
  34. # executed, so that we can interpose while the stack is unwinding
  35. # (otherwise, we have no way to interpose on the exception printing
  36. # process.)
  37. #
  38. # There are two things you have to do to make use of the utilities here:
  39. #
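# - When you compile your source code, you must compile it with the
#   filename "<string>" (the default when a source string is passed
#   directly to exec()/eval()) and save its string source in its
#   f_globals under the magic name "__compile_source__"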
  42. #
  43. # - Before running the compiled code, enter the
  44. # report_compile_source_on_error() context manager.
  45. @contextlib.contextmanager
  46. def report_compile_source_on_error():
  47. try:
  48. yield
  49. except Exception as exc:
  50. tb = exc.__traceback__
  51. # Walk the traceback, looking for frames that have
  52. # source attached
  53. stack = []
  54. while tb is not None:
  55. filename = tb.tb_frame.f_code.co_filename
  56. source = tb.tb_frame.f_globals.get("__compile_source__")
  57. if filename == "<string>" and source is not None:
  58. # What black magic are we doing here? Intuitively, what
  59. # we would like to do is overwrite the co_filename on any
  60. # frames that were generated from exec/eval so that they
  61. # point to a temporary file that has the actual line
  62. # information, so Python's default error printer can print
  63. # useful line information on it.
  64. #
  65. # Writing out the temporary file is easy. But overwriting
  66. # co_filename is not! You can't modify the code object
  67. # associated with a frame. You can, however, reconstruct
  68. # a traceback with entirely new frames from scratch, so that's
  69. # what we do. But there's another problem, which is how to
  70. # make the frame?
  71. #
  72. # The black magic is we make a frankenstein frame and code
  73. # object which resembles the original frame/code enough so
  74. # that it will print properly under traceback and the default
  75. # error printer, but IT IS NOT THE ORIGINAL FRAME (you
  76. # couldn't, e.g., execute its code with different variables
  77. # and expect it to work.)
  78. # Don't delete the temporary file so the user can inspect it
  79. # TODO: This creates a temporary file for every frame, but we
  80. # technically only need one per distinct __compile_source__
  81. with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
  82. f.write(source)
  83. # Create a frame. Python doesn't let you construct
  84. # FrameType directly, so just make one with compile
  85. frame = tb.tb_frame
  86. code = compile('__inspect_currentframe()', f.name, 'eval')
  87. code = code.replace(co_name=frame.f_code.co_name)
  88. # Python 3.11 only
  89. if hasattr(frame.f_code, 'co_linetable'):
  90. # We can't copy ALL of the metadata over, because you
  91. # can cause Python to segfault this way. What exactly
  92. # do we need? We need enough information for
  93. # traceback to be able to print the exception
  94. # correctly. Code reading Lib/traceback.py reveals
  95. # that traceback calls code.co_positions() in order to
  96. # get the augmented line/col numbers. Objects/codeobject.c,
  97. # specifically _PyCode_InitAddressRange, reveals that
  98. # this iterator is initialized from co_linetable and
  99. # co_firstfileno. So copy these we must!
  100. code = code.replace( # type: ignore[call-arg]
  101. co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
  102. co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
  103. )
  104. fake_frame = eval(
  105. code,
  106. frame.f_globals,
  107. {
  108. **frame.f_locals,
  109. '__inspect_currentframe': inspect.currentframe
  110. }
  111. )
  112. fake_tb = TracebackType(
  113. None, fake_frame, tb.tb_lasti, tb.tb_lineno
  114. )
  115. stack.append(fake_tb)
  116. else:
  117. stack.append(tb)
  118. tb = tb.tb_next
  119. # Reconstruct the linked list
  120. tb_next = None
  121. for tb in reversed(stack):
  122. tb.tb_next = tb_next
  123. tb_next = tb
  124. raise exc.with_traceback(tb_next) # noqa: B904
  125. def shorten_filename(fn, *, base=None):
  126. """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user."""
  127. if base is None:
  128. base = os.path.dirname(os.path.dirname(__file__))
  129. # Truncate torch/foo.py to foo.py
  130. try:
  131. prefix = os.path.commonpath([fn, base])
  132. except ValueError:
  133. return fn
  134. else:
  135. return fn[len(prefix) + 1:]
  136. def format_frame(frame, *, base=None, line=False):
  137. """
  138. Format a FrameSummary in a short way, without printing full absolute path or code.
  139. The idea is the result fits on a single line.
  140. """
  141. extra_line = ""
  142. if line:
  143. extra_line = f"{frame.line} # "
  144. return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
  145. def format_traceback_short(tb):
  146. """Format a TracebackType in a short way, printing only the inner-most frame."""
  147. return format_frame(traceback.extract_tb(tb)[-1])
  148. class CapturedTraceback:
  149. __slots__ = ['tb', 'skip']
  150. def __init__(self, tb, skip=0):
  151. self.tb = tb
  152. self.skip = skip
  153. def cleanup(self):
  154. self.tb = None
  155. def summary(self):
  156. import torch._C._profiler
  157. if self.tb is None:
  158. # TODO: Maybe indicate that the traceback was elided?
  159. return traceback.StackSummary()
  160. return _extract_symbolized_tb(
  161. torch._C._profiler.symbolize_tracebacks([self.tb])[0],
  162. self.skip
  163. )
  164. def __getstate__(self):
  165. return (None, {
  166. 'tb': None, # TB is not pickleable
  167. 'skip': self.skip,
  168. })
  169. @staticmethod
  170. def extract(*, script=False, cpp=False, skip=0):
  171. """
  172. Like traceback.extract_stack(), but faster (approximately 20x faster); it
  173. is fast enough that you can unconditionally log stacks this way as part of
  174. normal execution. It returns a torch._C._profiler.CapturedTraceback
  175. object that must be formatted specially with format_captured_tb.
  176. By default, this only reports Python backtraces (like extract_stack). You
  177. can set the script/cpp kwargs to also turn on TorchScript/C++ trace
  178. reporting.
  179. """
  180. import torch._C._profiler
  181. if script or cpp:
  182. assert skip == 0, "skip with script/cpp NYI"
  183. return CapturedTraceback(
  184. torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
  185. # Elide extract() frame if we don't have script/cpp frames. If
  186. # we do have those frames, it doesn't work so force zero.
  187. 0 if script or cpp else skip + 1
  188. )
  189. def format(self):
  190. """
  191. Formats a single torch._C._profiler.CapturedTraceback into a list of
  192. strings equivalent to the output of traceback.format_list. Note that if
  193. pass it CapturedTraceback with C++ traces, it is better not to use this
  194. function and use the batch formatting API format_captured_tbs to amortize
  195. the cost of symbolization
  196. """
  197. return traceback.format_list(self.summary())
  198. @staticmethod
  199. def format_all(tbs):
  200. """
  201. Bulk version of CapturedTraceback.format. Returns a list of list of strings.
  202. """
  203. import torch._C._profiler
  204. # Directly populate tracebacks that already have cached summaries
  205. rs: List[Optional[List[str]]] = []
  206. delayed_idxs = []
  207. for i, tb in enumerate(tbs):
  208. if tb.tb is None:
  209. rs.append([])
  210. else:
  211. rs.append(None)
  212. delayed_idxs.append(i)
  213. stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
  214. for i, stb in zip(delayed_idxs, stbs):
  215. rs[i] = traceback.format_list(tbs[i].summary())
  216. return rs
  217. def _extract_symbolized_tb(tb, skip):
  218. """
  219. Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of
  220. pre-processed stack trace entries.
  221. """
  222. stack = traceback.StackSummary()
  223. for f in reversed(tb[skip:]):
  224. stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
  225. return stack