base.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from __future__ import annotations
  2. import typing
  3. import anyio
  4. from anyio.abc import ObjectReceiveStream, ObjectSendStream
  5. from starlette._utils import collapse_excgroups
  6. from starlette.requests import ClientDisconnect, Request
  7. from starlette.responses import AsyncContentStream, Response
  8. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  9. RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
  10. DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
  11. T = typing.TypeVar("T")
  12. class _CachedRequest(Request):
  13. """
  14. If the user calls Request.body() from their dispatch function
  15. we cache the entire request body in memory and pass that to downstream middlewares,
  16. but if they call Request.stream() then all we do is send an
  17. empty body so that downstream things don't hang forever.
  18. """
  19. def __init__(self, scope: Scope, receive: Receive):
  20. super().__init__(scope, receive)
  21. self._wrapped_rcv_disconnected = False
  22. self._wrapped_rcv_consumed = False
  23. self._wrapped_rc_stream = self.stream()
  24. async def wrapped_receive(self) -> Message:
  25. # wrapped_rcv state 1: disconnected
  26. if self._wrapped_rcv_disconnected:
  27. # we've already sent a disconnect to the downstream app
  28. # we don't need to wait to get another one
  29. # (although most ASGI servers will just keep sending it)
  30. return {"type": "http.disconnect"}
  31. # wrapped_rcv state 1: consumed but not yet disconnected
  32. if self._wrapped_rcv_consumed:
  33. # since the downstream app has consumed us all that is left
  34. # is to send it a disconnect
  35. if self._is_disconnected:
  36. # the middleware has already seen the disconnect
  37. # since we know the client is disconnected no need to wait
  38. # for the message
  39. self._wrapped_rcv_disconnected = True
  40. return {"type": "http.disconnect"}
  41. # we don't know yet if the client is disconnected or not
  42. # so we'll wait until we get that message
  43. msg = await self.receive()
  44. if msg["type"] != "http.disconnect": # pragma: no cover
  45. # at this point a disconnect is all that we should be receiving
  46. # if we get something else, things went wrong somewhere
  47. raise RuntimeError(f"Unexpected message received: {msg['type']}")
  48. self._wrapped_rcv_disconnected = True
  49. return msg
  50. # wrapped_rcv state 3: not yet consumed
  51. if getattr(self, "_body", None) is not None:
  52. # body() was called, we return it even if the client disconnected
  53. self._wrapped_rcv_consumed = True
  54. return {
  55. "type": "http.request",
  56. "body": self._body,
  57. "more_body": False,
  58. }
  59. elif self._stream_consumed:
  60. # stream() was called to completion
  61. # return an empty body so that downstream apps don't hang
  62. # waiting for a disconnect
  63. self._wrapped_rcv_consumed = True
  64. return {
  65. "type": "http.request",
  66. "body": b"",
  67. "more_body": False,
  68. }
  69. else:
  70. # body() was never called and stream() wasn't consumed
  71. try:
  72. stream = self.stream()
  73. chunk = await stream.__anext__()
  74. self._wrapped_rcv_consumed = self._stream_consumed
  75. return {
  76. "type": "http.request",
  77. "body": chunk,
  78. "more_body": not self._stream_consumed,
  79. }
  80. except ClientDisconnect:
  81. self._wrapped_rcv_disconnected = True
  82. return {"type": "http.disconnect"}
  83. class BaseHTTPMiddleware:
  84. def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
  85. self.app = app
  86. self.dispatch_func = self.dispatch if dispatch is None else dispatch
  87. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  88. if scope["type"] != "http":
  89. await self.app(scope, receive, send)
  90. return
  91. request = _CachedRequest(scope, receive)
  92. wrapped_receive = request.wrapped_receive
  93. response_sent = anyio.Event()
  94. async def call_next(request: Request) -> Response:
  95. app_exc: Exception | None = None
  96. send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
  97. recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
  98. send_stream, recv_stream = anyio.create_memory_object_stream()
  99. async def receive_or_disconnect() -> Message:
  100. if response_sent.is_set():
  101. return {"type": "http.disconnect"}
  102. async with anyio.create_task_group() as task_group:
  103. async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
  104. result = await func()
  105. task_group.cancel_scope.cancel()
  106. return result
  107. task_group.start_soon(wrap, response_sent.wait)
  108. message = await wrap(wrapped_receive)
  109. if response_sent.is_set():
  110. return {"type": "http.disconnect"}
  111. return message
  112. async def close_recv_stream_on_response_sent() -> None:
  113. await response_sent.wait()
  114. recv_stream.close()
  115. async def send_no_error(message: Message) -> None:
  116. try:
  117. await send_stream.send(message)
  118. except anyio.BrokenResourceError:
  119. # recv_stream has been closed, i.e. response_sent has been set.
  120. return
  121. async def coro() -> None:
  122. nonlocal app_exc
  123. async with send_stream:
  124. try:
  125. await self.app(scope, receive_or_disconnect, send_no_error)
  126. except Exception as exc:
  127. app_exc = exc
  128. task_group.start_soon(close_recv_stream_on_response_sent)
  129. task_group.start_soon(coro)
  130. try:
  131. message = await recv_stream.receive()
  132. info = message.get("info", None)
  133. if message["type"] == "http.response.debug" and info is not None:
  134. message = await recv_stream.receive()
  135. except anyio.EndOfStream:
  136. if app_exc is not None:
  137. raise app_exc
  138. raise RuntimeError("No response returned.")
  139. assert message["type"] == "http.response.start"
  140. async def body_stream() -> typing.AsyncGenerator[bytes, None]:
  141. async with recv_stream:
  142. async for message in recv_stream:
  143. assert message["type"] == "http.response.body"
  144. body = message.get("body", b"")
  145. if body:
  146. yield body
  147. if not message.get("more_body", False):
  148. break
  149. if app_exc is not None:
  150. raise app_exc
  151. response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
  152. response.raw_headers = message["headers"]
  153. return response
  154. with collapse_excgroups():
  155. async with anyio.create_task_group() as task_group:
  156. response = await self.dispatch_func(request, call_next)
  157. await response(scope, wrapped_receive, send)
  158. response_sent.set()
  159. async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
  160. raise NotImplementedError() # pragma: no cover
  161. class _StreamingResponse(Response):
  162. def __init__(
  163. self,
  164. content: AsyncContentStream,
  165. status_code: int = 200,
  166. headers: typing.Mapping[str, str] | None = None,
  167. media_type: str | None = None,
  168. info: typing.Mapping[str, typing.Any] | None = None,
  169. ) -> None:
  170. self.info = info
  171. self.body_iterator = content
  172. self.status_code = status_code
  173. self.media_type = media_type
  174. self.init_headers(headers)
  175. self.background = None
  176. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  177. if self.info is not None:
  178. await send({"type": "http.response.debug", "info": self.info})
  179. await send(
  180. {
  181. "type": "http.response.start",
  182. "status": self.status_code,
  183. "headers": self.raw_headers,
  184. }
  185. )
  186. async for chunk in self.body_iterator:
  187. await send({"type": "http.response.body", "body": chunk, "more_body": True})
  188. await send({"type": "http.response.body", "body": b"", "more_body": False})
  189. if self.background:
  190. await self.background()