endpoints.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import annotations
  2. import json
  3. import typing
  4. from starlette import status
  5. from starlette._utils import is_async_callable
  6. from starlette.concurrency import run_in_threadpool
  7. from starlette.exceptions import HTTPException
  8. from starlette.requests import Request
  9. from starlette.responses import PlainTextResponse, Response
  10. from starlette.types import Message, Receive, Scope, Send
  11. from starlette.websockets import WebSocket
  class HTTPEndpoint:
      """Class-based ASGI HTTP view that dispatches on the request method.

      An instance is created from the ASGI ``scope``/``receive``/``send``
      triple and then awaited; awaiting it runs :meth:`dispatch`. Subclasses
      add handlers as lowercase method names (``get``, ``post``, ...), each
      taking the ``Request`` and returning a response object. A handler
      attribute that is explicitly set to ``None`` counts as not implemented.
      """

      def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
          """Store the ASGI triple and record which HTTP methods are served.

          Raises:
              AssertionError: if ``scope["type"]`` is not ``"http"``.
          """
          assert scope["type"] == "http"
          self.scope = scope
          self.receive = receive
          self.send = send
          # Built with the same lookup dispatch() uses, so the ``Allow`` header
          # lists exactly the methods that would be served -- including HEAD
          # when it is answered by the ``get`` handler.
          self._allowed_methods = [
              method
              for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
              if self._resolve_handler(method) is not None
          ]

      def __await__(self) -> typing.Generator[typing.Any, None, None]:
          """Make the endpoint awaitable by delegating to :meth:`dispatch`."""
          return self.dispatch().__await__()

      def _resolve_handler(self, method: str) -> typing.Callable[[Request], typing.Any] | None:
          """Return the handler serving ``method``, or ``None`` if there is none.

          The handler is the attribute named after the lowercased method. A
          HEAD request falls back to ``get`` only when the class has no
          ``head`` attribute at all; ``head = None`` therefore disables HEAD
          rather than re-enabling the fallback.
          """
          if method == "HEAD" and not hasattr(self, "head"):
              return getattr(self, "get", None)
          return getattr(self, method.lower(), None)

      async def dispatch(self) -> None:
          """Resolve the handler for the request method, call it, send the result.

          Coroutine handlers are awaited directly; any other callable is passed
          to ``run_in_threadpool``. The returned response is then invoked as an
          ASGI application with the stored ``scope``/``receive``/``send``.
          """
          request = Request(self.scope, receive=self.receive)
          handler = self._resolve_handler(request.method)
          if handler is None:
              # Covers both a missing attribute and one disabled with ``None``.
              handler = self.method_not_allowed
          if is_async_callable(handler):
              response = await handler(request)
          else:
              response = await run_in_threadpool(handler, request)
          await response(self.scope, self.receive, self.send)

      async def method_not_allowed(self, request: Request) -> Response:
          """Produce the 405 result, advertising served methods via ``Allow``.

          Raises:
              HTTPException: with status 405 when ``"app"`` is present in the
                  scope, so the application's exception handling builds the
                  response.

          Returns:
              A plain-text 405 response when ``"app"`` is not in the scope.
          """
          # If we're running inside a starlette application then raise an
          # exception, so that the configurable exception handler can deal with
          # returning the response. For plain ASGI apps, just return the response.
          headers = {"Allow": ", ".join(self._allowed_methods)}
          if "app" in self.scope:
              raise HTTPException(status_code=405, headers=headers)
          return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
  class WebSocketEndpoint:
      """Class-based ASGI websocket view with overridable lifecycle hooks.

      An instance is created from the ASGI ``scope``/``receive``/``send``
      triple and then awaited; awaiting it runs :meth:`dispatch`, which calls
      :meth:`on_connect`, then :meth:`on_receive` for each decoded message, and
      always finishes with :meth:`on_disconnect`.
      """

      # How incoming payloads are decoded before reaching on_receive().
      encoding: str | None = None  # May be "text", "bytes", or "json".

      def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
          """Store the ASGI triple.

          Raises:
              AssertionError: if ``scope["type"]`` is not ``"websocket"``.
          """
          assert scope["type"] == "websocket"
          self.scope = scope
          self.receive = receive
          self.send = send

      def __await__(self) -> typing.Generator[typing.Any, None, None]:
          """Make the endpoint awaitable by delegating to :meth:`dispatch`."""
          return self.dispatch().__await__()

      async def dispatch(self) -> None:
          """Run the connection: connect hook, receive loop, disconnect hook.

          ``on_disconnect`` is always called. It receives the code carried by
          the disconnect message (1000 when that code is missing or falsy), or
          1011 when an exception escaped the receive loop; that exception is
          re-raised after the close code has been recorded.
          """
          websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
          await self.on_connect(websocket)

          close_code = status.WS_1000_NORMAL_CLOSURE

          try:
              while True:
                  message = await websocket.receive()
                  if message["type"] == "websocket.receive":
                      data = await self.decode(websocket, message)
                      await self.on_receive(websocket, data)
                  elif message["type"] == "websocket.disconnect":  # pragma: no branch
                      close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
                      break
          except Exception:
              close_code = status.WS_1011_INTERNAL_ERROR
              raise
          finally:
              await self.on_disconnect(websocket, close_code)

      async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
          """Extract the payload of a receive message according to ``encoding``.

          A ``"text"`` or ``"bytes"`` key whose value is ``None`` is treated
          the same as a missing key, and an empty text frame (``""``) is a
          valid text payload.

          Returns:
              ``str`` for ``"text"``, ``bytes`` for ``"bytes"``, the parsed
              document for ``"json"``, and whichever of text/bytes is present
              when ``encoding`` is ``None``.

          Raises:
              RuntimeError: the payload does not match ``encoding`` or is not
                  valid JSON; the websocket is closed with code 1003 first.
              AssertionError: ``encoding`` is not one of the supported values.
          """
          text = message.get("text")

          if self.encoding == "text":
              if text is None:
                  await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                  raise RuntimeError("Expected text websocket messages, but got bytes")
              return text

          if self.encoding == "bytes":
              payload = message.get("bytes")
              if payload is None:
                  await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                  raise RuntimeError("Expected bytes websocket messages, but got text")
              return payload

          if self.encoding == "json":
              if text is None:
                  # JSON delivered in a binary frame is decoded as UTF-8.
                  text = message["bytes"].decode("utf-8")
              try:
                  return json.loads(text)
              except json.decoder.JSONDecodeError as exc:
                  await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                  raise RuntimeError("Malformed JSON data received.") from exc

          assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
          # Compare against None, not truthiness: "" must be returned as text
          # rather than falling through to a "bytes" key that may be absent.
          return text if text is not None else message["bytes"]

      async def on_connect(self, websocket: WebSocket) -> None:
          """Override to handle an incoming websocket connection"""
          await websocket.accept()

      async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
          """Override to handle an incoming websocket message"""

      async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
          """Override to handle a disconnecting websocket"""