requests.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from __future__ import annotations
  2. import json
  3. import typing
  4. from http import cookies as http_cookies
  5. import anyio
  6. from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
  7. from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
  8. from starlette.exceptions import HTTPException
  9. from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
  10. from starlette.types import Message, Receive, Scope, Send
  11. if typing.TYPE_CHECKING:
  12. from multipart.multipart import parse_options_header
  13. from starlette.applications import Starlette
  14. from starlette.routing import Router
  15. else:
  16. try:
  17. try:
  18. from python_multipart.multipart import parse_options_header
  19. except ModuleNotFoundError: # pragma: no cover
  20. from multipart.multipart import parse_options_header
  21. except ModuleNotFoundError: # pragma: no cover
  22. parse_options_header = None
  23. SERVER_PUSH_HEADERS_TO_COPY = {
  24. "accept",
  25. "accept-encoding",
  26. "accept-language",
  27. "cache-control",
  28. "user-agent",
  29. }
  30. def cookie_parser(cookie_string: str) -> dict[str, str]:
  31. """
  32. This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
  33. It attempts to mimic browser cookie parsing behavior: browsers and web servers
  34. frequently disregard the spec (RFC 6265) when setting and reading cookies,
  35. so we attempt to suit the common scenarios here.
  36. This function has been adapted from Django 3.1.0.
  37. Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
  38. on an outdated spec and will fail on lots of input we want to support
  39. """
  40. cookie_dict: dict[str, str] = {}
  41. for chunk in cookie_string.split(";"):
  42. if "=" in chunk:
  43. key, val = chunk.split("=", 1)
  44. else:
  45. # Assume an empty name per
  46. # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
  47. key, val = "", chunk
  48. key, val = key.strip(), val.strip()
  49. if key or val:
  50. # unquote using Python's algorithm.
  51. cookie_dict[key] = http_cookies._unquote(val)
  52. return cookie_dict
  53. class ClientDisconnect(Exception):
  54. pass
  55. class HTTPConnection(typing.Mapping[str, typing.Any]):
  56. """
  57. A base class for incoming HTTP connections, that is used to provide
  58. any functionality that is common to both `Request` and `WebSocket`.
  59. """
  60. def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
  61. assert scope["type"] in ("http", "websocket")
  62. self.scope = scope
  63. def __getitem__(self, key: str) -> typing.Any:
  64. return self.scope[key]
  65. def __iter__(self) -> typing.Iterator[str]:
  66. return iter(self.scope)
  67. def __len__(self) -> int:
  68. return len(self.scope)
  69. # Don't use the `abc.Mapping.__eq__` implementation.
  70. # Connection instances should never be considered equal
  71. # unless `self is other`.
  72. __eq__ = object.__eq__
  73. __hash__ = object.__hash__
  74. @property
  75. def app(self) -> typing.Any:
  76. return self.scope["app"]
  77. @property
  78. def url(self) -> URL:
  79. if not hasattr(self, "_url"): # pragma: no branch
  80. self._url = URL(scope=self.scope)
  81. return self._url
  82. @property
  83. def base_url(self) -> URL:
  84. if not hasattr(self, "_base_url"):
  85. base_url_scope = dict(self.scope)
  86. # This is used by request.url_for, it might be used inside a Mount which
  87. # would have its own child scope with its own root_path, but the base URL
  88. # for url_for should still be the top level app root path.
  89. app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
  90. path = app_root_path
  91. if not path.endswith("/"):
  92. path += "/"
  93. base_url_scope["path"] = path
  94. base_url_scope["query_string"] = b""
  95. base_url_scope["root_path"] = app_root_path
  96. self._base_url = URL(scope=base_url_scope)
  97. return self._base_url
  98. @property
  99. def headers(self) -> Headers:
  100. if not hasattr(self, "_headers"):
  101. self._headers = Headers(scope=self.scope)
  102. return self._headers
  103. @property
  104. def query_params(self) -> QueryParams:
  105. if not hasattr(self, "_query_params"): # pragma: no branch
  106. self._query_params = QueryParams(self.scope["query_string"])
  107. return self._query_params
  108. @property
  109. def path_params(self) -> dict[str, typing.Any]:
  110. return self.scope.get("path_params", {})
  111. @property
  112. def cookies(self) -> dict[str, str]:
  113. if not hasattr(self, "_cookies"):
  114. cookies: dict[str, str] = {}
  115. cookie_header = self.headers.get("cookie")
  116. if cookie_header:
  117. cookies = cookie_parser(cookie_header)
  118. self._cookies = cookies
  119. return self._cookies
  120. @property
  121. def client(self) -> Address | None:
  122. # client is a 2 item tuple of (host, port), None if missing
  123. host_port = self.scope.get("client")
  124. if host_port is not None:
  125. return Address(*host_port)
  126. return None
  127. @property
  128. def session(self) -> dict[str, typing.Any]:
  129. assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
  130. return self.scope["session"] # type: ignore[no-any-return]
  131. @property
  132. def auth(self) -> typing.Any:
  133. assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
  134. return self.scope["auth"]
  135. @property
  136. def user(self) -> typing.Any:
  137. assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
  138. return self.scope["user"]
  139. @property
  140. def state(self) -> State:
  141. if not hasattr(self, "_state"):
  142. # Ensure 'state' has an empty dict if it's not already populated.
  143. self.scope.setdefault("state", {})
  144. # Create a state instance with a reference to the dict in which it should
  145. # store info
  146. self._state = State(self.scope["state"])
  147. return self._state
  148. def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
  149. url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
  150. if url_path_provider is None:
  151. raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
  152. url_path = url_path_provider.url_path_for(name, **path_params)
  153. return url_path.make_absolute_url(base_url=self.base_url)
  154. async def empty_receive() -> typing.NoReturn:
  155. raise RuntimeError("Receive channel has not been made available")
  156. async def empty_send(message: Message) -> typing.NoReturn:
  157. raise RuntimeError("Send channel has not been made available")
  158. class Request(HTTPConnection):
  159. _form: FormData | None
  160. def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
  161. super().__init__(scope)
  162. assert scope["type"] == "http"
  163. self._receive = receive
  164. self._send = send
  165. self._stream_consumed = False
  166. self._is_disconnected = False
  167. self._form = None
  168. @property
  169. def method(self) -> str:
  170. return typing.cast(str, self.scope["method"])
  171. @property
  172. def receive(self) -> Receive:
  173. return self._receive
  174. async def stream(self) -> typing.AsyncGenerator[bytes, None]:
  175. if hasattr(self, "_body"):
  176. yield self._body
  177. yield b""
  178. return
  179. if self._stream_consumed:
  180. raise RuntimeError("Stream consumed")
  181. while not self._stream_consumed:
  182. message = await self._receive()
  183. if message["type"] == "http.request":
  184. body = message.get("body", b"")
  185. if not message.get("more_body", False):
  186. self._stream_consumed = True
  187. if body:
  188. yield body
  189. elif message["type"] == "http.disconnect": # pragma: no branch
  190. self._is_disconnected = True
  191. raise ClientDisconnect()
  192. yield b""
  193. async def body(self) -> bytes:
  194. if not hasattr(self, "_body"):
  195. chunks: list[bytes] = []
  196. async for chunk in self.stream():
  197. chunks.append(chunk)
  198. self._body = b"".join(chunks)
  199. return self._body
  200. async def json(self) -> typing.Any:
  201. if not hasattr(self, "_json"): # pragma: no branch
  202. body = await self.body()
  203. self._json = json.loads(body)
  204. return self._json
  205. async def _get_form(
  206. self,
  207. *,
  208. max_files: int | float = 1000,
  209. max_fields: int | float = 1000,
  210. max_part_size: int = 1024 * 1024,
  211. ) -> FormData:
  212. if self._form is None: # pragma: no branch
  213. assert (
  214. parse_options_header is not None
  215. ), "The `python-multipart` library must be installed to use form parsing."
  216. content_type_header = self.headers.get("Content-Type")
  217. content_type: bytes
  218. content_type, _ = parse_options_header(content_type_header)
  219. if content_type == b"multipart/form-data":
  220. try:
  221. multipart_parser = MultiPartParser(
  222. self.headers,
  223. self.stream(),
  224. max_files=max_files,
  225. max_fields=max_fields,
  226. max_part_size=max_part_size,
  227. )
  228. self._form = await multipart_parser.parse()
  229. except MultiPartException as exc:
  230. if "app" in self.scope:
  231. raise HTTPException(status_code=400, detail=exc.message)
  232. raise exc
  233. elif content_type == b"application/x-www-form-urlencoded":
  234. form_parser = FormParser(self.headers, self.stream())
  235. self._form = await form_parser.parse()
  236. else:
  237. self._form = FormData()
  238. return self._form
  239. def form(
  240. self,
  241. *,
  242. max_files: int | float = 1000,
  243. max_fields: int | float = 1000,
  244. max_part_size: int = 1024 * 1024,
  245. ) -> AwaitableOrContextManager[FormData]:
  246. return AwaitableOrContextManagerWrapper(
  247. self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
  248. )
  249. async def close(self) -> None:
  250. if self._form is not None: # pragma: no branch
  251. await self._form.close()
  252. async def is_disconnected(self) -> bool:
  253. if not self._is_disconnected:
  254. message: Message = {}
  255. # If message isn't immediately available, move on
  256. with anyio.CancelScope() as cs:
  257. cs.cancel()
  258. message = await self._receive()
  259. if message.get("type") == "http.disconnect":
  260. self._is_disconnected = True
  261. return self._is_disconnected
  262. async def send_push_promise(self, path: str) -> None:
  263. if "http.response.push" in self.scope.get("extensions", {}):
  264. raw_headers: list[tuple[bytes, bytes]] = []
  265. for name in SERVER_PUSH_HEADERS_TO_COPY:
  266. for value in self.headers.getlist(name):
  267. raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
  268. await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})