_exception_handler.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from __future__ import annotations
  2. import typing
  3. from starlette._utils import is_async_callable
  4. from starlette.concurrency import run_in_threadpool
  5. from starlette.exceptions import HTTPException
  6. from starlette.requests import Request
  7. from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
  8. from starlette.websockets import WebSocket
  9. ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
  10. StatusHandlers = typing.Dict[int, ExceptionHandler]
  11. def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
  12. for cls in type(exc).__mro__:
  13. if cls in exc_handlers:
  14. return exc_handlers[cls]
  15. return None
  16. def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
  17. exception_handlers: ExceptionHandlers
  18. status_handlers: StatusHandlers
  19. try:
  20. exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
  21. except KeyError:
  22. exception_handlers, status_handlers = {}, {}
  23. async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
  24. response_started = False
  25. async def sender(message: Message) -> None:
  26. nonlocal response_started
  27. if message["type"] == "http.response.start":
  28. response_started = True
  29. await send(message)
  30. try:
  31. await app(scope, receive, sender)
  32. except Exception as exc:
  33. handler = None
  34. if isinstance(exc, HTTPException):
  35. handler = status_handlers.get(exc.status_code)
  36. if handler is None:
  37. handler = _lookup_exception_handler(exception_handlers, exc)
  38. if handler is None:
  39. raise exc
  40. if response_started:
  41. raise RuntimeError("Caught handled exception, but response already started.") from exc
  42. if is_async_callable(handler):
  43. response = await handler(conn, exc)
  44. else:
  45. response = await run_in_threadpool(handler, conn, exc) # type: ignore
  46. if response is not None:
  47. await response(scope, receive, sender)
  48. return wrapped_app