- # mypy: allow-untyped-defs
- import os
- import textwrap
- from enum import auto, Enum
- from traceback import extract_stack, format_exc, format_list, StackSummary
- from typing import Any, cast, NoReturn, Optional
- import torch._guards
- from . import config
- from .utils import counters
def exportdb_error_message(case_name):
    """Return a pointer to the exportdb docs entry for *case_name* (snake_case)."""
    anchor = case_name.replace("_", "-")
    return (
        "For more information about this error, see: "
        f"https://pytorch.org/docs/main/generated/exportdb/index.html#{anchor}"
    )
import logging

# Module logger plus a dedicated artifact logger: messages sent to
# `graph_breaks_log` are routed through torch's "graph_breaks" logging
# artifact so graph-break diagnostics can be toggled independently.
log = logging.getLogger(__name__)
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
class TorchDynamoException(RuntimeError):
    """Root of the Dynamo exception hierarchy; all Dynamo errors derive from it."""
class InternalTorchDynamoError(TorchDynamoException):
    """Marker for failures internal to TorchDynamo itself (no extra state)."""
class RestartAnalysis(TorchDynamoException):
    """Exception requesting that the current analysis be restarted.

    ``restart_reason`` is an optional human-readable tag describing why the
    restart was requested; subclasses distinguish the restart flavor.
    """

    # Default is None, so the annotation must be Optional.
    restart_reason: Optional[str]

    def __init__(self, *args, restart_reason: Optional[str] = None):
        self.restart_reason = restart_reason
        super().__init__(*args)
class SpeculationRestartAnalysis(RestartAnalysis):
    """RestartAnalysis variant tagging speculation-related restarts (no extra state)."""
class UnspecializeRestartAnalysis(RestartAnalysis):
    """RestartAnalysis variant tagging unspecialize-related restarts (no extra state)."""
class SkipFrame(TorchDynamoException):
    """Marker exception: abandon processing of the current frame (no extra state)."""
class TorchRuntimeError(TorchDynamoException):
    """TorchDynamoException marker subtype; carries no additional state."""
class InvalidBackend(TorchDynamoException):
    """Raised for a backend name that does not resolve to a known backend."""

    def __init__(self, name):
        # User-facing message deliberately points at list_backends().
        message = f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
        super().__init__(message)
class ResetRequired(TorchDynamoException):
    """Raised when a backend change is detected without an intervening reset."""

    def __init__(self):
        explanation = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
            `torch.compile()` with a different backend compiler arguments.
            """
        )
        super().__init__(explanation)
class BackendCompilerFailed(TorchDynamoException):
    """Wraps an exception raised by the backend compiler.

    Attributes:
        backend_name: best-effort ``__name__`` of the failing backend callable
            ("?" when absent).
        inner_exception: the original exception from the backend.
    """

    def __init__(self, backend_fn, inner_exception):
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        super().__init__(
            f"backend={self.backend_name!r} raised:\n"
            f"{type(inner_exception).__name__}: {inner_exception}"
        )
class Unsupported(TorchDynamoException):
    """Raised for constructs Dynamo does not support.

    Construction captures the user stack (via TracingContext) and registers
    the message in the global ``counters`` stats under a category (default
    "unimplemented"); ``remove_from_stats`` undoes that registration.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.msg = msg
        self.category: Optional[str] = None
        self.add_to_stats()

    def remove_from_stats(self):
        """Decrement this message's counter; drop the entry once it hits zero."""
        assert self.category is not None
        bucket = counters[self.category]
        bucket[self.msg] -= 1
        if bucket[self.msg] <= 0:
            del bucket[self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Record this message under *category* in the global counters."""
        self.category = category
        counters[category][self.msg] += 1
class RecompileError(TorchDynamoException):
    """TorchDynamoException marker subtype; carries no additional state."""
class ArgsMismatchError(Unsupported):
    """Unsupported-variant raised for call-argument mismatches."""

    # The previous explicit __init__ only delegated to super() with the same
    # signature; the inherited Unsupported.__init__ is used directly instead.
class AttributeMutationError(Unsupported):
    """Unsupported-variant raised for unsupported attribute mutations."""

    # The previous explicit __init__ only delegated to super() with the same
    # signature; the inherited Unsupported.__init__ is used directly instead.
class CondOpArgsMismatchError(ArgsMismatchError):
    """Internal error from cond() due to arguments mismatch."""

    # The previous explicit __init__ only delegated to super() with the same
    # signature; the inherited constructor is used directly instead.
class UserErrorType(Enum):
    """Categories of user errors reportable through ``UserError``.

    Member order is load-bearing: ``auto()`` assigns values 1..7 in
    declaration order, so do not reorder.
    """

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
class UserError(Unsupported):
    def __init__(self, error_type: UserErrorType, msg, case_name=None):
        """
        Type of errors that would be valid in Eager, but not supported in TorchDynamo.
        The error message should tell user about next actions.

        error_type: Type of user error
        msg: Actionable error message
        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Attach the exportdb pointer: on the same line after a trailing
            # period, otherwise on a new line.
            separator = " " if msg.endswith(".") else "\n"
            msg = msg + separator + exportdb_error_message(case_name)
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg
class UserStopIteration(TorchDynamoException):
    """Signals an unhandled `raise StopIteration` observed during tracing."""

    # Mirrors StopIteration.value; reference `StopIteration_init` in CPython:
    # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
    value: Optional[Any]

    def __init__(self, *args, **kwargs):
        super().__init__("unhandled `raise StopIteration`")
        # Like CPython: first positional arg (if any) becomes `value`.
        self.value = args[0] if args else None
class UnsafeScriptObjectError(TorchDynamoException):
    """TorchDynamoException marker subtype; carries no additional state."""
class UncapturedHigherOrderOpError(TorchDynamoException):
    """TorchDynamoException marker subtype; carries no additional state."""
class IncorrectUsage(Exception):
    """Plain Exception marker (deliberately NOT a TorchDynamoException)."""
class ObservedException(TorchDynamoException):
    """TorchDynamoException marker subtype; carries no additional state."""
# Fake-tensor exceptions for which it is acceptable to fall back to
# eager / graph-break instead of failing hard.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
    """Graph-break via ``unimplemented``, but loudly.

    ``unimplemented`` itself is silent; this helper first logs the verbose
    break reason to the graph_breaks artifact logger and warns the user, for
    failures (e.g. a fake-tensor exception from an AOT Autograd backend) where
    falling back to eager is acceptable but should not happen silently.
    """
    verbose_reason = format_error_msg_verbose(e, code)
    graph_breaks_log.debug("%s", verbose_reason)
    log.warning(msg)
    unimplemented(msg, from_exc=e)
# Sentinel distinguishing "no originating exception" from an explicit
# `from_exc=None`.
_NOTHING = object()


def unimplemented(msg: str, *, from_exc: Any = _NOTHING) -> NoReturn:
    """Abort handling of the current construct by raising ``Unsupported``.

    When ``from_exc`` is supplied it becomes the ``__cause__`` of the raised
    exception.
    """
    # Debug hook: setting the BREAK env var to a message makes the matching
    # call site trip this assert so it can be caught in a debugger.
    assert msg != os.environ.get("BREAK", False)
    if from_exc is not _NOTHING:
        raise Unsupported(msg) from from_exc
    raise Unsupported(msg)
def warning(msg: str) -> None:
    """Tally a warning occurrence in the global counters.

    Also honors the BREAK env-var debug hook (same mechanism as
    ``unimplemented``).
    """
    counters["warnings"][msg] += 1
    assert msg != os.environ.get("BREAK", False)
# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
    """Wrapper whose repr equals its str, sidestepping KeyError's arg repr."""

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self) -> str:
        # Identical to __str__ on purpose (see note above about KeyError).
        return str(self)
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append Dynamo debugging context to *exc* in place.

    Adds, when available: the innermost user frame plus the user-code stack,
    replay-record instructions, a verbosity hint, minifier pointers taken from
    ``exc.inner_exception``, and (unless exporting or already suppressing) the
    ``suppress_errors`` escape hatch.  Mutates ``exc.args`` rather than
    returning anything; KeyError gets its message wrapped in ``KeyErrorMsg``.

    Fix: dropped the redundant function-local ``import traceback`` — the
    module already imports ``format_list`` from ``traceback`` at the top.
    """
    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    real_stack = get_real_stack(exc)
    if real_stack is not None and len(real_stack) > 0:
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(format_list(real_stack))}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'

    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        if hasattr(exc.inner_exception, "buck_command"):
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {exc.inner_exception.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    if not config.suppress_errors and not export:
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            "    import torch._dynamo\n"
            "    torch._dynamo.config.suppress_errors = True\n"
        )

    old_msg = "" if len(exc.args) == 0 else str(exc.args[0])

    if isinstance(exc, KeyError):
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]
def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
    """Return the stack stored on ``exc.real_stack``, or None when absent.

    When *frame* is given, filtered frames from the current call site are
    prepended ahead of the stored stack.
    """
    real_stack = getattr(exc, "real_stack", None)
    if real_stack is None:
        return None

    # NB: it's possible for real_stack to be []; we still attempt to
    # report a stack anyway because the stack_above_dynamo may still
    # be useful for debugging

    if frame is None:
        stack_above_dynamo = []
    else:
        # NB: frame is PyInterpreterFrame on Python 3.11 and later,
        # not a TRUE frame object. You can't actually feed it
        # to traceback because it doesn't have enough information.
        # To solve this problem, we technically should just materialize
        # the frame, the same way _PyFrame_GetFrameObject would do
        # (but we cannot actually do this, because this populates
        # frame_obj field, which default eval frame doesn't like).
        #
        # Fortunately, in this case, we can hack it: there's no need
        # to actually use the truly top frame, we can just extract
        # from where we are right now and rely on filter_stack to
        # get rid of all the dynamo frames. For ease of testing
        # we apply this behavior to ALL Python versions
        stack_above_dynamo = filter_stack(extract_stack())

    return cast(StackSummary, stack_above_dynamo + real_stack)
def filter_stack(stack):
    """Trim *stack* to user frames only.

    Stops at the first frame from dynamo's convert_frame (everything after it
    is dynamo internals) and drops eval_frame / `torch._dynamo.optimize(`
    wrapper frames encountered before that point.
    """
    user_frames = []
    for summary in stack:
        if "convert_frame" in summary.filename:
            break
        is_dynamo_wrapper = (
            "eval_frame" in summary.filename
            or "torch._dynamo.optimize(" in summary.line
        )
        if not is_dynamo_wrapper:
            user_frames.append(summary)
    return user_frames
def format_error_msg_verbose(
    exc: Exception, code, record_filename=None, frame=None
) -> str:
    """Build the verbose "WON'T CONVERT" report for *exc*.

    Contains the in-flight traceback plus, when *exc* carries a user stack,
    the user code that was being processed.
    """
    parts = [
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n",
        "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n",
        format_exc(),
    ]
    real_stack = get_real_stack(exc, frame)
    if real_stack is not None:
        parts.append(
            "\n"
            + "=" * 10
            + " The above exception occurred while processing the following code "
            + "=" * 10
            + "\n\n"
        )
        parts.append("".join(format_list(real_stack)))
        parts.append("\n")
        parts.append("=" * 10)
    return "".join(parts)
def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
    """Return the "WON'T CONVERT" message for *exc*.

    Uses the verbose report when ``config.verbose`` is set, otherwise a short
    one-liner plus the current traceback.
    """
    # Fix: removed the dead `msg = os.linesep * 2` initializer — both branches
    # below unconditionally reassign `msg`.
    if config.verbose:
        msg = format_error_msg_verbose(exc, code, record_filename, frame)
    else:
        msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc()}"

    return msg
|