exceptions.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from __future__ import annotations
  2. import typing
  3. from starlette._exception_handler import (
  4. ExceptionHandlers,
  5. StatusHandlers,
  6. wrap_app_handling_exceptions,
  7. )
  8. from starlette.exceptions import HTTPException, WebSocketException
  9. from starlette.requests import Request
  10. from starlette.responses import PlainTextResponse, Response
  11. from starlette.types import ASGIApp, Receive, Scope, Send
  12. from starlette.websockets import WebSocket
  13. class ExceptionMiddleware:
  14. def __init__(
  15. self,
  16. app: ASGIApp,
  17. handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
  18. debug: bool = False,
  19. ) -> None:
  20. self.app = app
  21. self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
  22. self._status_handlers: StatusHandlers = {}
  23. self._exception_handlers: ExceptionHandlers = {
  24. HTTPException: self.http_exception,
  25. WebSocketException: self.websocket_exception,
  26. }
  27. if handlers is not None: # pragma: no branch
  28. for key, value in handlers.items():
  29. self.add_exception_handler(key, value)
  30. def add_exception_handler(
  31. self,
  32. exc_class_or_status_code: int | type[Exception],
  33. handler: typing.Callable[[Request, Exception], Response],
  34. ) -> None:
  35. if isinstance(exc_class_or_status_code, int):
  36. self._status_handlers[exc_class_or_status_code] = handler
  37. else:
  38. assert issubclass(exc_class_or_status_code, Exception)
  39. self._exception_handlers[exc_class_or_status_code] = handler
  40. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  41. if scope["type"] not in ("http", "websocket"):
  42. await self.app(scope, receive, send)
  43. return
  44. scope["starlette.exception_handlers"] = (
  45. self._exception_handlers,
  46. self._status_handlers,
  47. )
  48. conn: Request | WebSocket
  49. if scope["type"] == "http":
  50. conn = Request(scope, receive, send)
  51. else:
  52. conn = WebSocket(scope, receive, send)
  53. await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  54. def http_exception(self, request: Request, exc: Exception) -> Response:
  55. assert isinstance(exc, HTTPException)
  56. if exc.status_code in {204, 304}:
  57. return Response(status_code=exc.status_code, headers=exc.headers)
  58. return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
  59. async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
  60. assert isinstance(exc, WebSocketException)
  61. await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover