exc.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # mypy: allow-untyped-defs
  2. import os
  3. import textwrap
  4. from enum import auto, Enum
  5. from traceback import extract_stack, format_exc, format_list, StackSummary
  6. from typing import Any, cast, NoReturn, Optional
  7. import torch._guards
  8. from . import config
  9. from .utils import counters
  10. def exportdb_error_message(case_name):
  11. return (
  12. "For more information about this error, see: "
  13. + "https://pytorch.org/docs/main/generated/exportdb/index.html#"
  14. + case_name.replace("_", "-")
  15. )
  16. import logging
# Module-level logger; unimplemented_with_warning() uses it to emit a
# user-visible warning before falling back.
log = logging.getLogger(__name__)
# Artifact logger registered under the "graph_breaks" name.
# unimplemented_with_warning() writes the verbose WON'T CONVERT report here at
# DEBUG level. NOTE(review): enabling it presumably goes through TORCH_LOGS
# artifact selection — confirm.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
  19. class TorchDynamoException(RuntimeError):
  20. pass
  21. class InternalTorchDynamoError(TorchDynamoException):
  22. pass
  23. class RestartAnalysis(TorchDynamoException):
  24. restart_reason: str
  25. def __init__(self, *args, restart_reason=None):
  26. self.restart_reason = restart_reason
  27. super().__init__(*args)
  28. class SpeculationRestartAnalysis(RestartAnalysis):
  29. pass
  30. class UnspecializeRestartAnalysis(RestartAnalysis):
  31. pass
  32. class SkipFrame(TorchDynamoException):
  33. pass
  34. class TorchRuntimeError(TorchDynamoException):
  35. pass
  36. class InvalidBackend(TorchDynamoException):
  37. def __init__(self, name):
  38. super().__init__(
  39. f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
  40. )
  41. class ResetRequired(TorchDynamoException):
  42. def __init__(self):
  43. super().__init__(
  44. textwrap.dedent(
  45. """
  46. Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
  47. `torch.compile()` with a different backend compiler arguments.
  48. """
  49. )
  50. )
  51. class BackendCompilerFailed(TorchDynamoException):
  52. def __init__(self, backend_fn, inner_exception):
  53. self.backend_name = getattr(backend_fn, "__name__", "?")
  54. self.inner_exception = inner_exception
  55. msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
  56. super().__init__(msg)
  57. class Unsupported(TorchDynamoException):
  58. def __init__(self, msg):
  59. super().__init__(msg)
  60. self.real_stack = torch._guards.TracingContext.extract_stack()
  61. self.msg = msg
  62. self.category: Optional[str] = None
  63. self.add_to_stats()
  64. def remove_from_stats(self):
  65. assert self.category is not None
  66. counters[self.category][self.msg] -= 1
  67. if counters[self.category][self.msg] <= 0:
  68. del counters[self.category][self.msg]
  69. def add_to_stats(self, category="unimplemented"):
  70. self.category = category
  71. counters[category][self.msg] += 1
  72. class RecompileError(TorchDynamoException):
  73. pass
  74. class ArgsMismatchError(Unsupported):
  75. def __init__(self, msg):
  76. super().__init__(msg)
  77. class AttributeMutationError(Unsupported):
  78. def __init__(self, msg):
  79. super().__init__(msg)
  80. class CondOpArgsMismatchError(ArgsMismatchError):
  81. """
  82. Internal error from cond() due to arguments mismatch.
  83. """
  84. def __init__(self, msg):
  85. super().__init__(msg)
  86. class UserErrorType(Enum):
  87. DYNAMIC_CONTROL_FLOW = auto()
  88. ANTI_PATTERN = auto()
  89. STANDARD_LIBRARY = auto()
  90. CONSTRAINT_VIOLATION = auto()
  91. DYNAMIC_DIM = auto()
  92. INVALID_INPUT = auto()
  93. INVALID_OUTPUT = auto()
  94. class UserError(Unsupported):
  95. def __init__(self, error_type: UserErrorType, msg, case_name=None):
  96. """
  97. Type of errors that would be valid in Eager, but not supported in TorchDynamo.
  98. The error message should tell user about next actions.
  99. error_type: Type of user error
  100. msg: Actionable error message
  101. case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
  102. """
  103. if case_name is not None:
  104. assert isinstance(case_name, str)
  105. if msg.endswith("."):
  106. msg += " "
  107. else:
  108. msg += "\n"
  109. msg += exportdb_error_message(case_name)
  110. super().__init__(msg)
  111. self.error_type = error_type
  112. self.message = msg
  113. class UserStopIteration(TorchDynamoException):
  114. value: Optional[Any]
  115. # Reference `StopIteration_init` in CPython
  116. # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
  117. def __init__(self, *args, **kwargs):
  118. super().__init__("unhandled `raise StopIteration`")
  119. if len(args) > 0:
  120. self.value = args[0]
  121. else:
  122. self.value = None
  123. class UnsafeScriptObjectError(TorchDynamoException):
  124. pass
  125. class UncapturedHigherOrderOpError(TorchDynamoException):
  126. pass
  127. class IncorrectUsage(Exception):
  128. pass
  129. class ObservedException(TorchDynamoException):
  130. pass
# Fake-tensor exceptions for which it is acceptable to fall back to eager or
# graph-break instead of failing hard. Kept as a tuple so it can be used
# directly in an `except` clause or an isinstance() check.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
  138. def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
  139. # This function calls unimplemented internally and eventually graph breaks
  140. # or falls to eager. unimplemented itself does not print any user warnings,
  141. # i.e., its very silent. This helper function is intended when an error is
  142. # encountered in the torch.compile stack which is worth showing as warning
  143. # to the user. For example, if AOT Autograd backend fails with a fake tensor
  144. # exception, its ok to fallback to eager but not silently. Here, we can use
  145. # this function to log the message and the stack trace.
  146. graph_break_msg = format_error_msg_verbose(e, code)
  147. graph_breaks_log.debug("%s", graph_break_msg)
  148. log.warning(msg)
  149. unimplemented(msg, from_exc=e)
# Private sentinel meaning "no from_exc argument was given" to unimplemented().
# It is distinct from None so callers can still pass from_exc=None explicitly
# (which becomes `raise ... from None`).
_NOTHING = object()
  151. def unimplemented(msg: str, *, from_exc: Any = _NOTHING) -> NoReturn:
  152. assert msg != os.environ.get("BREAK", False)
  153. if from_exc is not _NOTHING:
  154. raise Unsupported(msg) from from_exc
  155. raise Unsupported(msg)
  156. def warning(msg: str) -> None:
  157. counters["warnings"][msg] += 1
  158. assert msg != os.environ.get("BREAK", False)
  159. # KeyError has special handling for its args
  160. # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
  161. class KeyErrorMsg:
  162. def __init__(self, value):
  163. self.value = value
  164. def __str__(self):
  165. return str(self.value)
  166. def __repr__(self) -> str:
  167. return self.__str__()
  168. def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
  169. import traceback
  170. exc.innermost_user_frame_summary = None # type: ignore[attr-defined]
  171. real_stack = get_real_stack(exc)
  172. if real_stack is not None and len(real_stack) > 0:
  173. exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined]
  174. msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}"
  175. if config.replay_record_enabled and hasattr(exc, "record_filename"):
  176. msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
  177. torch._dynamo.replay('{exc.record_filename}').\n"
  178. if not config.verbose and hasattr(exc, "real_stack"):
  179. msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'
  180. if hasattr(exc, "inner_exception") and hasattr(
  181. exc.inner_exception, "minifier_path"
  182. ):
  183. if hasattr(exc.inner_exception, "buck_command"):
  184. msg += (
  185. f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
  186. f"this buck command to find the smallest traced graph "
  187. f"which reproduces this error: {exc.inner_exception.buck_command}\n"
  188. )
  189. else:
  190. msg += (
  191. f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
  192. "this script to find the smallest traced graph which reproduces this error.\n"
  193. )
  194. if not config.suppress_errors and not export:
  195. msg += (
  196. "\n\n"
  197. "You can suppress this exception and fall back to eager by setting:\n"
  198. " import torch._dynamo\n"
  199. " torch._dynamo.config.suppress_errors = True\n"
  200. )
  201. old_msg = "" if len(exc.args) == 0 else str(exc.args[0])
  202. if isinstance(exc, KeyError):
  203. exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
  204. else:
  205. new_msg = old_msg + msg
  206. exc.args = (new_msg,) + exc.args[1:]
  207. def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
  208. real_stack = getattr(exc, "real_stack", None)
  209. if real_stack is None:
  210. return None
  211. # NB: it's possible for real_stack to be []; we still attempt to
  212. # report a stack anyway because the stack_above_dynamo may still
  213. # be useful for debugging
  214. stack_above_dynamo = []
  215. if frame is not None:
  216. # NB: frame is PyInterpreterFrame on Python 3.11 and later,
  217. # not a TRUE frame object. You can't actually feed it
  218. # to traceback because it doesn't have enough information.
  219. # To solve this problem, we technically should just materialize
  220. # the frame, the same way _PyFrame_GetFrameObject would do
  221. # (but we cannot actually do this, because this populates
  222. # frame_obj field, which default eval frame doesn't like).
  223. #
  224. # Fortunately, in this case, we can hack it: there's no need
  225. # to actually use the truly top frame, we can just extract
  226. # from where we are right now and rely on filter_stack to
  227. # get rid of all the dynamo frames. For ease of testing
  228. # we apply this behavior to ALL Python versions
  229. stack_above_dynamo = filter_stack(extract_stack())
  230. return cast(StackSummary, stack_above_dynamo + real_stack)
  231. # filter out all frames after entering dynamo
  232. def filter_stack(stack):
  233. user_stack = []
  234. for frame in stack:
  235. if "convert_frame" in frame.filename:
  236. break
  237. if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line:
  238. continue
  239. user_stack.append(frame)
  240. return user_stack
  241. def format_error_msg_verbose(
  242. exc: Exception, code, record_filename=None, frame=None
  243. ) -> str:
  244. msg = (
  245. f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n"
  246. )
  247. msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
  248. msg += format_exc()
  249. real_stack = get_real_stack(exc, frame)
  250. if real_stack is not None:
  251. msg += (
  252. "\n"
  253. + "=" * 10
  254. + " The above exception occurred while processing the following code "
  255. + "=" * 10
  256. + "\n\n"
  257. )
  258. msg += "".join(format_list(real_stack))
  259. msg += "\n"
  260. msg += "=" * 10
  261. return msg
  262. def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
  263. msg = os.linesep * 2
  264. if config.verbose:
  265. msg = format_error_msg_verbose(exc, code, record_filename, frame)
  266. else:
  267. msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
  268. line {code.co_firstlineno} \ndue to: \n{format_exc()}"
  269. return msg