testclient.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. from __future__ import annotations
  2. import contextlib
  3. import inspect
  4. import io
  5. import json
  6. import math
  7. import queue
  8. import sys
  9. import typing
  10. from concurrent.futures import Future
  11. from functools import cached_property
  12. from types import GeneratorType
  13. from urllib.parse import unquote, urljoin
  14. import anyio
  15. import anyio.abc
  16. import anyio.from_thread
  17. from anyio.abc import ObjectReceiveStream, ObjectSendStream
  18. from anyio.streams.stapled import StapledObjectStream
  19. from starlette._utils import is_async_callable
  20. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  21. from starlette.websockets import WebSocketDisconnect
  22. if sys.version_info >= (3, 10): # pragma: no cover
  23. from typing import TypeGuard
  24. else: # pragma: no cover
  25. from typing_extensions import TypeGuard
  26. try:
  27. import httpx
  28. except ModuleNotFoundError: # pragma: no cover
  29. raise RuntimeError(
  30. "The starlette.testclient module requires the httpx package to be installed.\n"
  31. "You can install this with:\n"
  32. " $ pip install httpx\n"
  33. )
  34. _PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
  35. ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
  36. ASGI2App = typing.Callable[[Scope], ASGIInstance]
  37. ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
  38. _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
  39. def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
  40. if inspect.isclass(app):
  41. return hasattr(app, "__await__")
  42. return is_async_callable(app)
  43. class _WrapASGI2:
  44. """
  45. Provide an ASGI3 interface onto an ASGI2 app.
  46. """
  47. def __init__(self, app: ASGI2App) -> None:
  48. self.app = app
  49. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  50. instance = self.app(scope)
  51. await instance(receive, send)
  52. class _AsyncBackend(typing.TypedDict):
  53. backend: str
  54. backend_options: dict[str, typing.Any]
  55. class _Upgrade(Exception):
  56. def __init__(self, session: WebSocketTestSession) -> None:
  57. self.session = session
  58. class WebSocketDenialResponse( # type: ignore[misc]
  59. httpx.Response,
  60. WebSocketDisconnect,
  61. ):
  62. """
  63. A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
  64. `WebSocket` is closed before being accepted with a `send_denial_response()`.
  65. """
  66. class WebSocketTestSession:
  67. def __init__(
  68. self,
  69. app: ASGI3App,
  70. scope: Scope,
  71. portal_factory: _PortalFactoryType,
  72. ) -> None:
  73. self.app = app
  74. self.scope = scope
  75. self.accepted_subprotocol = None
  76. self.portal_factory = portal_factory
  77. self._receive_queue: queue.Queue[Message] = queue.Queue()
  78. self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
  79. self.extra_headers = None
  80. def __enter__(self) -> WebSocketTestSession:
  81. self.exit_stack = contextlib.ExitStack()
  82. self.portal = self.exit_stack.enter_context(self.portal_factory())
  83. try:
  84. _: Future[None] = self.portal.start_task_soon(self._run)
  85. self.send({"type": "websocket.connect"})
  86. message = self.receive()
  87. self._raise_on_close(message)
  88. except Exception:
  89. self.exit_stack.close()
  90. raise
  91. self.accepted_subprotocol = message.get("subprotocol", None)
  92. self.extra_headers = message.get("headers", None)
  93. return self
  94. @cached_property
  95. def should_close(self) -> anyio.Event:
  96. return anyio.Event()
  97. async def _notify_close(self) -> None:
  98. self.should_close.set()
  99. def __exit__(self, *args: typing.Any) -> None:
  100. try:
  101. self.close(1000)
  102. finally:
  103. self.portal.start_task_soon(self._notify_close)
  104. self.exit_stack.close()
  105. while not self._send_queue.empty():
  106. message = self._send_queue.get()
  107. if isinstance(message, BaseException):
  108. raise message
  109. async def _run(self) -> None:
  110. """
  111. The sub-thread in which the websocket session runs.
  112. """
  113. async def run_app(tg: anyio.abc.TaskGroup) -> None:
  114. try:
  115. await self.app(self.scope, self._asgi_receive, self._asgi_send)
  116. except anyio.get_cancelled_exc_class():
  117. ...
  118. except BaseException as exc:
  119. self._send_queue.put(exc)
  120. raise
  121. finally:
  122. tg.cancel_scope.cancel()
  123. async with anyio.create_task_group() as tg:
  124. tg.start_soon(run_app, tg)
  125. await self.should_close.wait()
  126. tg.cancel_scope.cancel()
  127. async def _asgi_receive(self) -> Message:
  128. while self._receive_queue.empty():
  129. self._queue_event = anyio.Event()
  130. await self._queue_event.wait()
  131. return self._receive_queue.get()
  132. async def _asgi_send(self, message: Message) -> None:
  133. self._send_queue.put(message)
  134. def _raise_on_close(self, message: Message) -> None:
  135. if message["type"] == "websocket.close":
  136. raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
  137. elif message["type"] == "websocket.http.response.start":
  138. status_code: int = message["status"]
  139. headers: list[tuple[bytes, bytes]] = message["headers"]
  140. body: list[bytes] = []
  141. while True:
  142. message = self.receive()
  143. assert message["type"] == "websocket.http.response.body"
  144. body.append(message["body"])
  145. if not message.get("more_body", False):
  146. break
  147. raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
  148. def send(self, message: Message) -> None:
  149. self._receive_queue.put(message)
  150. if hasattr(self, "_queue_event"):
  151. self.portal.start_task_soon(self._queue_event.set)
  152. def send_text(self, data: str) -> None:
  153. self.send({"type": "websocket.receive", "text": data})
  154. def send_bytes(self, data: bytes) -> None:
  155. self.send({"type": "websocket.receive", "bytes": data})
  156. def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
  157. text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
  158. if mode == "text":
  159. self.send({"type": "websocket.receive", "text": text})
  160. else:
  161. self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
  162. def close(self, code: int = 1000, reason: str | None = None) -> None:
  163. self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
  164. def receive(self) -> Message:
  165. message = self._send_queue.get()
  166. if isinstance(message, BaseException):
  167. raise message
  168. return message
  169. def receive_text(self) -> str:
  170. message = self.receive()
  171. self._raise_on_close(message)
  172. return typing.cast(str, message["text"])
  173. def receive_bytes(self) -> bytes:
  174. message = self.receive()
  175. self._raise_on_close(message)
  176. return typing.cast(bytes, message["bytes"])
  177. def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
  178. message = self.receive()
  179. self._raise_on_close(message)
  180. if mode == "text":
  181. text = message["text"]
  182. else:
  183. text = message["bytes"].decode("utf-8")
  184. return json.loads(text)
  185. class _TestClientTransport(httpx.BaseTransport):
  186. def __init__(
  187. self,
  188. app: ASGI3App,
  189. portal_factory: _PortalFactoryType,
  190. raise_server_exceptions: bool = True,
  191. root_path: str = "",
  192. *,
  193. client: tuple[str, int],
  194. app_state: dict[str, typing.Any],
  195. ) -> None:
  196. self.app = app
  197. self.raise_server_exceptions = raise_server_exceptions
  198. self.root_path = root_path
  199. self.portal_factory = portal_factory
  200. self.app_state = app_state
  201. self.client = client
  202. def handle_request(self, request: httpx.Request) -> httpx.Response:
  203. scheme = request.url.scheme
  204. netloc = request.url.netloc.decode(encoding="ascii")
  205. path = request.url.path
  206. raw_path = request.url.raw_path
  207. query = request.url.query.decode(encoding="ascii")
  208. default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
  209. if ":" in netloc:
  210. host, port_string = netloc.split(":", 1)
  211. port = int(port_string)
  212. else:
  213. host = netloc
  214. port = default_port
  215. # Include the 'host' header.
  216. if "host" in request.headers:
  217. headers: list[tuple[bytes, bytes]] = []
  218. elif port == default_port: # pragma: no cover
  219. headers = [(b"host", host.encode())]
  220. else: # pragma: no cover
  221. headers = [(b"host", (f"{host}:{port}").encode())]
  222. # Include other request headers.
  223. headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
  224. scope: dict[str, typing.Any]
  225. if scheme in {"ws", "wss"}:
  226. subprotocol = request.headers.get("sec-websocket-protocol", None)
  227. if subprotocol is None:
  228. subprotocols: typing.Sequence[str] = []
  229. else:
  230. subprotocols = [value.strip() for value in subprotocol.split(",")]
  231. scope = {
  232. "type": "websocket",
  233. "path": unquote(path),
  234. "raw_path": raw_path.split(b"?", 1)[0],
  235. "root_path": self.root_path,
  236. "scheme": scheme,
  237. "query_string": query.encode(),
  238. "headers": headers,
  239. "client": self.client,
  240. "server": [host, port],
  241. "subprotocols": subprotocols,
  242. "state": self.app_state.copy(),
  243. "extensions": {"websocket.http.response": {}},
  244. }
  245. session = WebSocketTestSession(self.app, scope, self.portal_factory)
  246. raise _Upgrade(session)
  247. scope = {
  248. "type": "http",
  249. "http_version": "1.1",
  250. "method": request.method,
  251. "path": unquote(path),
  252. "raw_path": raw_path.split(b"?", 1)[0],
  253. "root_path": self.root_path,
  254. "scheme": scheme,
  255. "query_string": query.encode(),
  256. "headers": headers,
  257. "client": self.client,
  258. "server": [host, port],
  259. "extensions": {"http.response.debug": {}},
  260. "state": self.app_state.copy(),
  261. }
  262. request_complete = False
  263. response_started = False
  264. response_complete: anyio.Event
  265. raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
  266. template = None
  267. context = None
  268. async def receive() -> Message:
  269. nonlocal request_complete
  270. if request_complete:
  271. if not response_complete.is_set():
  272. await response_complete.wait()
  273. return {"type": "http.disconnect"}
  274. body = request.read()
  275. if isinstance(body, str):
  276. body_bytes: bytes = body.encode("utf-8") # pragma: no cover
  277. elif body is None:
  278. body_bytes = b"" # pragma: no cover
  279. elif isinstance(body, GeneratorType):
  280. try: # pragma: no cover
  281. chunk = body.send(None)
  282. if isinstance(chunk, str):
  283. chunk = chunk.encode("utf-8")
  284. return {"type": "http.request", "body": chunk, "more_body": True}
  285. except StopIteration: # pragma: no cover
  286. request_complete = True
  287. return {"type": "http.request", "body": b""}
  288. else:
  289. body_bytes = body
  290. request_complete = True
  291. return {"type": "http.request", "body": body_bytes}
  292. async def send(message: Message) -> None:
  293. nonlocal raw_kwargs, response_started, template, context
  294. if message["type"] == "http.response.start":
  295. assert not response_started, 'Received multiple "http.response.start" messages.'
  296. raw_kwargs["status_code"] = message["status"]
  297. raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
  298. response_started = True
  299. elif message["type"] == "http.response.body":
  300. assert response_started, 'Received "http.response.body" without "http.response.start".'
  301. assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
  302. body = message.get("body", b"")
  303. more_body = message.get("more_body", False)
  304. if request.method != "HEAD":
  305. raw_kwargs["stream"].write(body)
  306. if not more_body:
  307. raw_kwargs["stream"].seek(0)
  308. response_complete.set()
  309. elif message["type"] == "http.response.debug":
  310. template = message["info"]["template"]
  311. context = message["info"]["context"]
  312. try:
  313. with self.portal_factory() as portal:
  314. response_complete = portal.call(anyio.Event)
  315. portal.call(self.app, scope, receive, send)
  316. except BaseException as exc:
  317. if self.raise_server_exceptions:
  318. raise exc
  319. if self.raise_server_exceptions:
  320. assert response_started, "TestClient did not receive any response."
  321. elif not response_started:
  322. raw_kwargs = {
  323. "status_code": 500,
  324. "headers": [],
  325. "stream": io.BytesIO(),
  326. }
  327. raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
  328. response = httpx.Response(**raw_kwargs, request=request)
  329. if template is not None:
  330. response.template = template # type: ignore[attr-defined]
  331. response.context = context # type: ignore[attr-defined]
  332. return response
  333. class TestClient(httpx.Client):
  334. __test__ = False
  335. task: Future[None]
  336. portal: anyio.abc.BlockingPortal | None = None
  337. def __init__(
  338. self,
  339. app: ASGIApp,
  340. base_url: str = "http://testserver",
  341. raise_server_exceptions: bool = True,
  342. root_path: str = "",
  343. backend: typing.Literal["asyncio", "trio"] = "asyncio",
  344. backend_options: dict[str, typing.Any] | None = None,
  345. cookies: httpx._types.CookieTypes | None = None,
  346. headers: dict[str, str] | None = None,
  347. follow_redirects: bool = True,
  348. client: tuple[str, int] = ("testclient", 50000),
  349. ) -> None:
  350. self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
  351. if _is_asgi3(app):
  352. asgi_app = app
  353. else:
  354. app = typing.cast(ASGI2App, app) # type: ignore[assignment]
  355. asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
  356. self.app = asgi_app
  357. self.app_state: dict[str, typing.Any] = {}
  358. transport = _TestClientTransport(
  359. self.app,
  360. portal_factory=self._portal_factory,
  361. raise_server_exceptions=raise_server_exceptions,
  362. root_path=root_path,
  363. app_state=self.app_state,
  364. client=client,
  365. )
  366. if headers is None:
  367. headers = {}
  368. headers.setdefault("user-agent", "testclient")
  369. super().__init__(
  370. base_url=base_url,
  371. headers=headers,
  372. transport=transport,
  373. follow_redirects=follow_redirects,
  374. cookies=cookies,
  375. )
  376. @contextlib.contextmanager
  377. def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
  378. if self.portal is not None:
  379. yield self.portal
  380. else:
  381. with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
  382. yield portal
  383. def request( # type: ignore[override]
  384. self,
  385. method: str,
  386. url: httpx._types.URLTypes,
  387. *,
  388. content: httpx._types.RequestContent | None = None,
  389. data: _RequestData | None = None,
  390. files: httpx._types.RequestFiles | None = None,
  391. json: typing.Any = None,
  392. params: httpx._types.QueryParamTypes | None = None,
  393. headers: httpx._types.HeaderTypes | None = None,
  394. cookies: httpx._types.CookieTypes | None = None,
  395. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  396. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  397. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  398. extensions: dict[str, typing.Any] | None = None,
  399. ) -> httpx.Response:
  400. url = self._merge_url(url)
  401. return super().request(
  402. method,
  403. url,
  404. content=content,
  405. data=data,
  406. files=files,
  407. json=json,
  408. params=params,
  409. headers=headers,
  410. cookies=cookies,
  411. auth=auth,
  412. follow_redirects=follow_redirects,
  413. timeout=timeout,
  414. extensions=extensions,
  415. )
  416. def get( # type: ignore[override]
  417. self,
  418. url: httpx._types.URLTypes,
  419. *,
  420. params: httpx._types.QueryParamTypes | None = None,
  421. headers: httpx._types.HeaderTypes | None = None,
  422. cookies: httpx._types.CookieTypes | None = None,
  423. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  424. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  425. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  426. extensions: dict[str, typing.Any] | None = None,
  427. ) -> httpx.Response:
  428. return super().get(
  429. url,
  430. params=params,
  431. headers=headers,
  432. cookies=cookies,
  433. auth=auth,
  434. follow_redirects=follow_redirects,
  435. timeout=timeout,
  436. extensions=extensions,
  437. )
  438. def options( # type: ignore[override]
  439. self,
  440. url: httpx._types.URLTypes,
  441. *,
  442. params: httpx._types.QueryParamTypes | None = None,
  443. headers: httpx._types.HeaderTypes | None = None,
  444. cookies: httpx._types.CookieTypes | None = None,
  445. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  446. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  447. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  448. extensions: dict[str, typing.Any] | None = None,
  449. ) -> httpx.Response:
  450. return super().options(
  451. url,
  452. params=params,
  453. headers=headers,
  454. cookies=cookies,
  455. auth=auth,
  456. follow_redirects=follow_redirects,
  457. timeout=timeout,
  458. extensions=extensions,
  459. )
  460. def head( # type: ignore[override]
  461. self,
  462. url: httpx._types.URLTypes,
  463. *,
  464. params: httpx._types.QueryParamTypes | None = None,
  465. headers: httpx._types.HeaderTypes | None = None,
  466. cookies: httpx._types.CookieTypes | None = None,
  467. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  468. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  469. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  470. extensions: dict[str, typing.Any] | None = None,
  471. ) -> httpx.Response:
  472. return super().head(
  473. url,
  474. params=params,
  475. headers=headers,
  476. cookies=cookies,
  477. auth=auth,
  478. follow_redirects=follow_redirects,
  479. timeout=timeout,
  480. extensions=extensions,
  481. )
  482. def post( # type: ignore[override]
  483. self,
  484. url: httpx._types.URLTypes,
  485. *,
  486. content: httpx._types.RequestContent | None = None,
  487. data: _RequestData | None = None,
  488. files: httpx._types.RequestFiles | None = None,
  489. json: typing.Any = None,
  490. params: httpx._types.QueryParamTypes | None = None,
  491. headers: httpx._types.HeaderTypes | None = None,
  492. cookies: httpx._types.CookieTypes | None = None,
  493. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  494. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  495. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  496. extensions: dict[str, typing.Any] | None = None,
  497. ) -> httpx.Response:
  498. return super().post(
  499. url,
  500. content=content,
  501. data=data,
  502. files=files,
  503. json=json,
  504. params=params,
  505. headers=headers,
  506. cookies=cookies,
  507. auth=auth,
  508. follow_redirects=follow_redirects,
  509. timeout=timeout,
  510. extensions=extensions,
  511. )
  512. def put( # type: ignore[override]
  513. self,
  514. url: httpx._types.URLTypes,
  515. *,
  516. content: httpx._types.RequestContent | None = None,
  517. data: _RequestData | None = None,
  518. files: httpx._types.RequestFiles | None = None,
  519. json: typing.Any = None,
  520. params: httpx._types.QueryParamTypes | None = None,
  521. headers: httpx._types.HeaderTypes | None = None,
  522. cookies: httpx._types.CookieTypes | None = None,
  523. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  524. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  525. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  526. extensions: dict[str, typing.Any] | None = None,
  527. ) -> httpx.Response:
  528. return super().put(
  529. url,
  530. content=content,
  531. data=data,
  532. files=files,
  533. json=json,
  534. params=params,
  535. headers=headers,
  536. cookies=cookies,
  537. auth=auth,
  538. follow_redirects=follow_redirects,
  539. timeout=timeout,
  540. extensions=extensions,
  541. )
  542. def patch( # type: ignore[override]
  543. self,
  544. url: httpx._types.URLTypes,
  545. *,
  546. content: httpx._types.RequestContent | None = None,
  547. data: _RequestData | None = None,
  548. files: httpx._types.RequestFiles | None = None,
  549. json: typing.Any = None,
  550. params: httpx._types.QueryParamTypes | None = None,
  551. headers: httpx._types.HeaderTypes | None = None,
  552. cookies: httpx._types.CookieTypes | None = None,
  553. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  554. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  555. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  556. extensions: dict[str, typing.Any] | None = None,
  557. ) -> httpx.Response:
  558. return super().patch(
  559. url,
  560. content=content,
  561. data=data,
  562. files=files,
  563. json=json,
  564. params=params,
  565. headers=headers,
  566. cookies=cookies,
  567. auth=auth,
  568. follow_redirects=follow_redirects,
  569. timeout=timeout,
  570. extensions=extensions,
  571. )
  572. def delete( # type: ignore[override]
  573. self,
  574. url: httpx._types.URLTypes,
  575. *,
  576. params: httpx._types.QueryParamTypes | None = None,
  577. headers: httpx._types.HeaderTypes | None = None,
  578. cookies: httpx._types.CookieTypes | None = None,
  579. auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  580. follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  581. timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
  582. extensions: dict[str, typing.Any] | None = None,
  583. ) -> httpx.Response:
  584. return super().delete(
  585. url,
  586. params=params,
  587. headers=headers,
  588. cookies=cookies,
  589. auth=auth,
  590. follow_redirects=follow_redirects,
  591. timeout=timeout,
  592. extensions=extensions,
  593. )
  594. def websocket_connect(
  595. self,
  596. url: str,
  597. subprotocols: typing.Sequence[str] | None = None,
  598. **kwargs: typing.Any,
  599. ) -> WebSocketTestSession:
  600. url = urljoin("ws://testserver", url)
  601. headers = kwargs.get("headers", {})
  602. headers.setdefault("connection", "upgrade")
  603. headers.setdefault("sec-websocket-key", "testserver==")
  604. headers.setdefault("sec-websocket-version", "13")
  605. if subprotocols is not None:
  606. headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
  607. kwargs["headers"] = headers
  608. try:
  609. super().request("GET", url, **kwargs)
  610. except _Upgrade as exc:
  611. session = exc.session
  612. else:
  613. raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
  614. return session
  615. def __enter__(self) -> TestClient:
  616. with contextlib.ExitStack() as stack:
  617. self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
  618. @stack.callback
  619. def reset_portal() -> None:
  620. self.portal = None
  621. send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
  622. receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
  623. send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
  624. receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
  625. send1, receive1 = anyio.create_memory_object_stream(math.inf)
  626. send2, receive2 = anyio.create_memory_object_stream(math.inf)
  627. self.stream_send = StapledObjectStream(send1, receive1)
  628. self.stream_receive = StapledObjectStream(send2, receive2)
  629. self.task = portal.start_task_soon(self.lifespan)
  630. portal.call(self.wait_startup)
  631. @stack.callback
  632. def wait_shutdown() -> None:
  633. portal.call(self.wait_shutdown)
  634. self.exit_stack = stack.pop_all()
  635. return self
  636. def __exit__(self, *args: typing.Any) -> None:
  637. self.exit_stack.close()
  638. async def lifespan(self) -> None:
  639. scope = {"type": "lifespan", "state": self.app_state}
  640. try:
  641. await self.app(scope, self.stream_receive.receive, self.stream_send.send)
  642. finally:
  643. await self.stream_send.send(None)
  644. async def wait_startup(self) -> None:
  645. await self.stream_receive.send({"type": "lifespan.startup"})
  646. async def receive() -> typing.Any:
  647. message = await self.stream_send.receive()
  648. if message is None:
  649. self.task.result()
  650. return message
  651. message = await receive()
  652. assert message["type"] in (
  653. "lifespan.startup.complete",
  654. "lifespan.startup.failed",
  655. )
  656. if message["type"] == "lifespan.startup.failed":
  657. await receive()
  658. async def wait_shutdown(self) -> None:
  659. async def receive() -> typing.Any:
  660. message = await self.stream_send.receive()
  661. if message is None:
  662. self.task.result()
  663. return message
  664. async with self.stream_send, self.stream_receive:
  665. await self.stream_receive.send({"type": "lifespan.shutdown"})
  666. message = await receive()
  667. assert message["type"] in (
  668. "lifespan.shutdown.complete",
  669. "lifespan.shutdown.failed",
  670. )
  671. if message["type"] == "lifespan.shutdown.failed":
  672. await receive()