```python
# mypy: allow-untyped-defs
from types import TracebackType
from typing import List, Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path

# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
#   we cannot assume the linecache trick will work, cf.
#   https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
#   some code; file creation should happen lazily only at exception
#   time. Arguably, you *should* be willing to write out your
#   generated Python code to the file system, but in some situations
#   (esp. library code) it would violate user expectations to write
#   to the file system, so we try to avoid it. In particular, we'd
#   like to keep the files around so users can open the files
#   mentioned in the trace, but if a file is never shown to anyone,
#   we don't want it clogging up the filesystem.
#
#   If this is not a constraint for you, there is a substantially simpler
#   way to implement the functionality in this file: instead of using
#   eval/exec directly, just always write a Python file to the filesystem
#   and compile that.
#
# - You have control over a context where the compiled code will get
#   executed, so that we can interpose while the stack is unwinding
#   (otherwise, we have no way to interpose on the exception printing
#   process).
#
# There are two things you have to do to make use of the utilities here:
#
# - When you compile your source code, you must compile it with the
#   filename "<string>" (the default when exec/eval is given a string)
#   and save the source string in the globals the code executes with
#   (its frames' f_globals) under the magic name "__compile_source__".
#
# - Before running the compiled code, enter the
#   report_compile_source_on_error() context manager.
```
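A minimal, stdlib-only sketch of the first requirement above and of the lookup that `report_compile_source_on_error()` performs while the stack unwinds. `run_generated` and `find_compiled_sources` are hypothetical helpers for illustration. The sketch only locates the offending source line; it does not write temporary files or rebuild frames.

```python
def run_generated(source, name="main"):
    """Hypothetical helper: exec generated code the way the comments above require."""
    # The filename must be "<string>", since that is what the context manager checks for.
    code = compile(source, "<string>", "exec")
    # Keep the source text in the globals the generated functions will run with.
    globs = {"__compile_source__": source}
    exec(code, globs)
    return globs[name]


def find_compiled_sources(tb):
    """Hypothetical helper: (lineno, source line) for each traceback entry from such code."""
    found = []
    while tb is not None:
        frame = tb.tb_frame
        src = frame.f_globals.get("__compile_source__")
        if frame.f_code.co_filename == "<string>" and src is not None:
            found.append((tb.tb_lineno, src.splitlines()[tb.tb_lineno - 1]))
        tb = tb.tb_next
    return found


fn = run_generated("def main(x):\n    y = x + 1\n    return 1 / (y - y)\n")
try:
    fn(3)
except ZeroDivisionError as exc:
    hits = find_compiled_sources(exc.__traceback__)
# hits == [(3, "    return 1 / (y - y)")]
```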
```python
@contextlib.contextmanager
def report_compile_source_on_error():
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here? Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy. But overwriting
                # co_filename is not! You can't modify the code object
                # associated with a frame. You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do. But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a Frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it
                # TODO: This creates a temporary file for every frame, but we
                # technically only need one per distinct __compile_source__
                with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                    f.write(source)
                # Create a frame. Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                code = compile('__inspect_currentframe()', f.name, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way. What exactly
                    # do we need? We need enough information for
                    # traceback to be able to print the exception
                    # correctly. Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers. Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno. So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904
```
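The last step above relinks traceback entries into a new chain before re-raising. The sketch below isolates that idea with the standard library only. It rebuilds a chain with `types.TracebackType(tb_next, tb_frame, tb_lasti, tb_lineno)` and drops the entries that a predicate rejects. `rebuild_traceback` and its `keep` argument are hypothetical. Unlike the code above, it allocates fresh nodes rather than assigning `tb_next` in place, and it fabricates no frames.

```python
import traceback
import types


def rebuild_traceback(tb, keep):
    """Hypothetical helper: new traceback chain holding only entries whose frame passes keep()."""
    kept = []
    while tb is not None:  # tb_next runs from the outermost entry to the innermost
        if keep(tb.tb_frame):
            kept.append(tb)
        tb = tb.tb_next
    new_tb = None
    for entry in reversed(kept):  # link from the innermost entry outward
        new_tb = types.TracebackType(new_tb, entry.tb_frame, entry.tb_lasti, entry.tb_lineno)
    return new_tb


def parse_config():
    raise KeyError("missing key")


def wrapper():
    parse_config()


try:
    wrapper()
except KeyError as exc:
    # with_traceback() swaps in the new chain and returns the same exception object
    trimmed = exc.with_traceback(
        rebuild_traceback(exc.__traceback__, lambda f: f.f_code.co_name != "wrapper")
    )
    names = [fs.name for fs in traceback.extract_tb(trimmed.__traceback__)]
```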
```python
def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to the user."""
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))

    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    else:
        return fn[len(prefix) + 1:]
```
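The slicing arithmetic in `shorten_filename` can be checked in isolation. `shorten_relative_to` below is a hypothetical standalone mirror of that logic with `base` passed explicitly, and the paths are made up for illustration.

```python
import os.path


def shorten_relative_to(fn, base):
    """Hypothetical mirror of shorten_filename's commonpath slicing, with an explicit base."""
    try:
        # Longest common sub-path of fn and base, e.g. ".../torch"
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        # commonpath raises ValueError, e.g. when mixing absolute and relative paths
        return fn
    # Drop the prefix plus the path separator that follows it
    return fn[len(prefix) + 1:]


inside = shorten_relative_to("/opt/site-packages/torch/nn/functional.py", "/opt/site-packages/torch")
mixed = shorten_relative_to("relative/gen.py", "/opt/site-packages/torch")
# On POSIX: inside == "nn/functional.py", mixed == "relative/gen.py"
```

The slicing assumes `fn` lies under `base`. For an unrelated absolute path on POSIX, the common prefix can be just `/`, and the slice then drops the first character of the remaining path.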
```python
def format_frame(frame, *, base=None, line=False):
    """
    Format a FrameSummary in a short way, without printing full absolute path or code.

    The idea is the result fits on a single line.
    """
    extra_line = ""
    if line:
        extra_line = f"{frame.line} # "
    return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"


def format_traceback_short(tb):
    """Format a TracebackType in a short way, printing only the inner-most frame."""
    return format_frame(traceback.extract_tb(tb)[-1])
```
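`format_traceback_short` relies on `traceback.extract_tb()` listing entries outermost-first, so index `-1` is the frame that actually raised. The sketch below shows that with a hypothetical `short_location` helper. For simplicity it substitutes `os.path.basename` for `shorten_filename`.

```python
import os.path
import traceback


def short_location(tb):
    """Hypothetical simplification of format_traceback_short(): basename instead of shorten_filename()."""
    innermost = traceback.extract_tb(tb)[-1]  # FrameSummary of the frame that raised
    return f"{os.path.basename(innermost.filename)}:{innermost.lineno} in {innermost.name}"


def fail_fast():
    raise RuntimeError("boom")


try:
    fail_fast()
except RuntimeError as exc:
    loc = short_location(exc.__traceback__)
```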
```python
class CapturedTraceback:
    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        self.tb = None

    def summary(self):
        import torch._C._profiler

        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution. It returns a CapturedTraceback object that must be
        formatted specially, with format() or, for several at once, format_all().

        By default, this only reports Python backtraces (like extract_stack). You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide the extract() frame if we don't have script/cpp frames. If
            # we do have those frames, skipping is not supported, so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Format this CapturedTraceback into a list of strings equivalent to the
        output of traceback.format_list. Note that if this traceback contains
        C++ frames, it is better to use the batch formatting API format_all
        instead, to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format. Returns a list of lists of strings.
        """
        import torch._C._profiler

        # Cleaned-up tracebacks (tb is None) format to an empty list; the rest
        # are symbolized together in a single batch below.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch-symbolized result rather than summary(), which would
            # issue a separate symbolize_tracebacks call per traceback.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
```
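`CapturedTraceback` splits the work into a cheap capture step and a deferred formatting step. The sketch below shows that same split using only the standard library. It is a hypothetical analog named `DeferredStack`, not torch's implementation. It does not gather TorchScript/C++ frames, and it makes no claim to the 20x speedup. It records file, line, and name via `traceback.StackSummary.extract(..., lookup_lines=False)` and reads source text only when formatting. `sys._getframe` is a CPython implementation detail.

```python
import sys
import traceback


class DeferredStack:
    """Hypothetical stdlib analog of CapturedTraceback: capture now, render later."""

    __slots__ = ("_summary",)

    def __init__(self, summary):
        self._summary = summary

    @staticmethod
    def extract(skip=0):
        # Frame 0 is extract() itself; start at its caller, then skip `skip` more.
        frame = sys._getframe(1 + skip)
        # lookup_lines=False records file/line/name only; source text is read lazily.
        summary = traceback.StackSummary.extract(
            traceback.walk_stack(frame), lookup_lines=False
        )
        summary.reverse()  # outermost first, matching traceback.extract_stack()
        return DeferredStack(summary)

    def format(self):
        # Source lines are fetched here, when format_list reads each FrameSummary.
        return traceback.format_list(self._summary)


def handle_request():
    return DeferredStack.extract()


captured = handle_request()
lines = captured.format()
```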
```python
def _extract_symbolized_tb(tb, skip):
    """
    Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of
    pre-processed stack trace entries.
    """
    stack = traceback.StackSummary()
    for f in reversed(tb[skip:]):
        stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
    return stack
```
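To make the ordering in `_extract_symbolized_tb` explicit, the sketch below feeds it hypothetical symbolized frames. Each frame is a dict with the `filename`/`line`/`name` keys used above, where `line` is passed as `FrameSummary`'s line number. The frames are assumed to be ordered innermost-first, as the `skip + 1` in `extract()` suggests. Under that assumption, `skip` drops the innermost entries, and `reversed` then yields the usual outermost-first order.

```python
import traceback

# Hypothetical symbolized frames, assumed innermost-first.
symbolized = [
    {"filename": "<generated>", "line": 12, "name": "extract"},
    {"filename": "<generated>", "line": 5, "name": "log_step"},
    {"filename": "<generated>", "line": 1, "name": "<module>"},
]


def to_stack_summary(tb, skip):
    """Hypothetical copy of the skip-then-reverse steps above."""
    stack = traceback.StackSummary()
    # Drop the `skip` innermost frames, then emit outermost-first like extract_stack()
    for f in reversed(tb[skip:]):
        stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"]))
    return stack


summary = to_stack_summary(symbolized, skip=1)
names = [fs.name for fs in summary]  # ["<module>", "log_step"]
```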