| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762 |
- from __future__ import annotations
- import contextlib
- import inspect
- import io
- import json
- import math
- import queue
- import sys
- import typing
- from concurrent.futures import Future
- from functools import cached_property
- from types import GeneratorType
- from urllib.parse import unquote, urljoin
- import anyio
- import anyio.abc
- import anyio.from_thread
- from anyio.abc import ObjectReceiveStream, ObjectSendStream
- from anyio.streams.stapled import StapledObjectStream
- from starlette._utils import is_async_callable
- from starlette.types import ASGIApp, Message, Receive, Scope, Send
- from starlette.websockets import WebSocketDisconnect
- if sys.version_info >= (3, 10): # pragma: no cover
- from typing import TypeGuard
- else: # pragma: no cover
- from typing_extensions import TypeGuard
- try:
- import httpx
- except ModuleNotFoundError: # pragma: no cover
- raise RuntimeError(
- "The starlette.testclient module requires the httpx package to be installed.\n"
- "You can install this with:\n"
- " $ pip install httpx\n"
- )
# A zero-argument callable returning a context manager that yields a blocking portal.
_PortalFactoryType = typing.Callable[
    [],
    typing.ContextManager[anyio.abc.BlockingPortal],
]

# Legacy ASGI2 ("double callable") shape: app(scope) returns an instance that is
# then awaited as instance(receive, send).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# ASGI3 ("single callable") shape: app(scope, receive, send).
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

# Type of the `data=` argument accepted by `TestClient.request` and its verb helpers.
_RequestData = typing.Mapping[
    str,
    typing.Union[str, typing.Iterable[str], bytes],
]
- def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
- if inspect.isclass(app):
- return hasattr(app, "__await__")
- return is_async_callable(app)
- class _WrapASGI2:
- """
- Provide an ASGI3 interface onto an ASGI2 app.
- """
- def __init__(self, app: ASGI2App) -> None:
- self.app = app
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- instance = self.app(scope)
- await instance(receive, send)
class _AsyncBackend(typing.TypedDict):
    """
    Keyword arguments forwarded as `**kwargs` to `anyio.from_thread.start_blocking_portal`.
    """

    # Name of the anyio backend, e.g. "asyncio" or "trio".
    backend: str
    # Extra options passed through to that backend.
    backend_options: dict[str, typing.Any]
class _Upgrade(Exception):
    """
    Control-flow exception carrying a `WebSocketTestSession` out of the transport.

    `_TestClientTransport.handle_request` raises it for ws/wss requests, and
    `TestClient.websocket_connect` catches it to obtain the session.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        # The not-yet-entered session to hand back to the caller.
        self.session = session
class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.

    Being an `httpx.Response` as well, it exposes the denial's status code,
    headers and (fully buffered) body like an ordinary HTTP response.
    """
class WebSocketTestSession:
    """
    Synchronous client side of a WebSocket connection to an in-process ASGI app.

    The app runs as a task inside a blocking portal. Two thread-safe queues connect
    the caller's thread with the app:

    * `_receive_queue` holds client -> app messages (returned by the app's `receive`).
    * `_send_queue` holds app -> client messages (whatever the app passed to `send`),
      or an exception the app raised, so it can be re-raised on the caller's side.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the app, perform the connect handshake and return the session.

        Raises `WebSocketDisconnect` / `WebSocketDenialResponse` when the app closes
        or denies the connection instead of accepting it, or re-raises any exception
        the app raised. The portal is closed before the error propagates.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: Future[None] = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        # Created lazily on first access; set to make `_run` stop the app.
        return anyio.Event()

    async def _notify_close(self) -> None:
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """
        Disconnect with code 1000, stop the app task and close the portal.

        Afterwards any exception the app left in `_send_queue` is re-raised.
        """
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        Run the app inside the portal's event loop until it finishes or
        `should_close` is set, whichever happens first.

        Exceptions other than cancellation are pushed onto `_send_queue` (so the
        caller's `receive()` re-raises them) and then propagated.
        """

        async def run_app(tg: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                ...
            except BaseException as exc:
                self._send_queue.put(exc)
                raise
            finally:
                tg.cancel_scope.cancel()

        async with anyio.create_task_group() as tg:
            tg.start_soon(run_app, tg)
            await self.should_close.wait()
            tg.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        # The app's `receive`: wait on a fresh event until the caller queues a message.
        while self._receive_queue.empty():
            self._queue_event = anyio.Event()
            await self._queue_event.wait()
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        # The app's `send`: hand the message to the caller's thread.
        self._send_queue.put(message)

    def _wake_receiver(self) -> None:
        """
        Wake `_asgi_receive` if it is waiting. Executed on the portal's event loop.

        The event is looked up here, on the loop thread, rather than on the
        caller's thread: `_asgi_receive` creates a new event between its
        `empty()` check and `wait()`, so an event captured by the caller could be
        stale (or not yet exist) and the wake-up would be lost.
        """
        event = getattr(self, "_queue_event", None)
        if event is not None:
            event.set()

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if `message` ends the connection instead of carrying data.

        "websocket.close" raises `WebSocketDisconnect`. "websocket.http.response.start"
        consumes the following body messages and raises `WebSocketDenialResponse`.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            headers: list[tuple[bytes, bytes]] = message["headers"]
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app and wake its `receive` if waiting."""
        self._receive_queue.put(message)
        # Always schedule the wake-up; `_wake_receiver` decides on the loop thread
        # whether there is a waiter (see its docstring for the race this avoids).
        self.portal.start_task_soon(self._wake_receiver)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
        """Send `data` as compact JSON, in a text frame or as UTF-8 bytes."""
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """
        Block until the app sends a message and return it.

        Re-raises an exception the app raised instead of returning it.
        """
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
        """Receive a message and decode its text (or UTF-8 bytes) payload as JSON."""
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
class _TestClientTransport(httpx.BaseTransport):
    """
    An httpx transport that dispatches requests directly to an in-process ASGI app.

    HTTP requests are run through a blocking portal and converted into an
    `httpx.Response`. ws/wss requests produce a `WebSocketTestSession`, which is
    handed back to the caller by raising `_Upgrade`.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run `request` against the app and return the resulting response.

        Raises `_Upgrade` for ws/wss URLs. With `raise_server_exceptions` set, an
        exception raised by the app propagates; otherwise a response that never
        started becomes an empty 500 response.
        """
        scheme = request.url.scheme
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # Use httpx's parsed host/port instead of splitting the netloc on ":", which
        # breaks for bracketed IPv6 hosts such as "[::1]:8000". httpx drops a port
        # equal to the scheme's default, so a missing port means the default one.
        host = request.url.raw_host.decode(encoding="ascii")
        port = request.url.port if request.url.port is not None else default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # httpx's netloc is "host" or "host:port" (default port omitted),
            # i.e. exactly the value a Host header carries.
            headers = [(b"host", request.url.netloc)]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            # The app's `receive`: deliver the request body, then block until the
            # response is complete and report a disconnect.
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            # The app's `send`: record status/headers, buffer the body (skipped for
            # HEAD), and capture template debug info when the app emits it.
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns everything written regardless of the stream position,
        # so a body that was streamed but never finished (e.g. the app crashed
        # mid-response) is not silently replaced by b"".
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
class TestClient(httpx.Client):
    """
    A synchronous `httpx.Client` whose requests go straight to an ASGI app.

    Requests are executed in-process by `_TestClientTransport`. When used as a
    context manager, the client keeps a single blocking portal open for its
    lifetime and drives the app's lifespan protocol (startup on enter, shutdown on
    exit). Outside a `with` block, each request or WebSocket session opens its own
    portal.
    """

    __test__ = False  # tell pytest not to collect this class as a test case
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Copy into a fresh, case-insensitive header collection. The caller's dict is
        # never mutated, and a user-supplied "User-Agent" in any casing is kept
        # instead of being duplicated by a lowercase default.
        default_headers = httpx.Headers(headers)
        default_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=default_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        # Reuse the client-wide portal while inside `with TestClient(...)`;
        # otherwise start a temporary portal for the duration of the caller's block.
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a request, resolving `url` against the client's base URL first."""
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    # The verb helpers below only narrow the accepted arguments (notably `data`,
    # typed as `_RequestData`) and forward everything to `httpx.Client`.

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """
        Create a WebSocket session for `url` (resolved against "ws://testserver").

        Handshake headers are added unless already present, and `subprotocols` is
        sent as a comma-separated `sec-websocket-protocol` header. Other `kwargs`
        are forwarded to `httpx.Client.request`. The returned session is not
        connected yet; enter it as a context manager to perform the handshake.
        """
        url = urljoin("ws://testserver", url)
        # Build a new case-insensitive collection: the caller's mapping stays
        # untouched, `headers=None` is accepted, and a user-supplied header in any
        # casing (e.g. "Connection") is not duplicated by the defaults below.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """
        Open the client-wide portal, start the lifespan task and wait for startup.

        If startup fails, the partially built stack unwinds: `self.portal` is reset
        and the portal is closed before the error propagates.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            # stream_send carries messages FROM the app (it is the app's `send`);
            # stream_receive carries messages TO the app (it is the app's `receive`).
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        # Runs the callbacks registered in __enter__ in reverse order: lifespan
        # shutdown, then `self.portal` reset, then portal teardown.
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app's lifespan scope; push `None` as an end-of-task sentinel."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """
        Send "lifespan.startup" and wait for the app's startup reply.

        When the lifespan task has already ended (sentinel `None`), `task.result()`
        re-raises whatever exception that task finished with.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            # Wait for the task to finish so its exception (if any) is surfaced.
            await receive()

    async def wait_shutdown(self) -> None:
        """Send "lifespan.shutdown", wait for the reply, then close both streams."""

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        async with self.stream_send, self.stream_receive:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()
|