datastructures.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. from __future__ import annotations
  2. import typing
  3. from shlex import shlex
  4. from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
  5. from starlette.concurrency import run_in_threadpool
  6. from starlette.types import Scope
  7. class Address(typing.NamedTuple):
  8. host: str
  9. port: int
  10. _KeyType = typing.TypeVar("_KeyType")
  11. # Mapping keys are invariant but their values are covariant since
  12. # you can only read them
  13. # that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
  14. _CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
  15. class URL:
  16. def __init__(
  17. self,
  18. url: str = "",
  19. scope: Scope | None = None,
  20. **components: typing.Any,
  21. ) -> None:
  22. if scope is not None:
  23. assert not url, 'Cannot set both "url" and "scope".'
  24. assert not components, 'Cannot set both "scope" and "**components".'
  25. scheme = scope.get("scheme", "http")
  26. server = scope.get("server", None)
  27. path = scope["path"]
  28. query_string = scope.get("query_string", b"")
  29. host_header = None
  30. for key, value in scope["headers"]:
  31. if key == b"host":
  32. host_header = value.decode("latin-1")
  33. break
  34. if host_header is not None:
  35. url = f"{scheme}://{host_header}{path}"
  36. elif server is None:
  37. url = path
  38. else:
  39. host, port = server
  40. default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
  41. if port == default_port:
  42. url = f"{scheme}://{host}{path}"
  43. else:
  44. url = f"{scheme}://{host}:{port}{path}"
  45. if query_string:
  46. url += "?" + query_string.decode()
  47. elif components:
  48. assert not url, 'Cannot set both "url" and "**components".'
  49. url = URL("").replace(**components).components.geturl()
  50. self._url = url
  51. @property
  52. def components(self) -> SplitResult:
  53. if not hasattr(self, "_components"):
  54. self._components = urlsplit(self._url)
  55. return self._components
  56. @property
  57. def scheme(self) -> str:
  58. return self.components.scheme
  59. @property
  60. def netloc(self) -> str:
  61. return self.components.netloc
  62. @property
  63. def path(self) -> str:
  64. return self.components.path
  65. @property
  66. def query(self) -> str:
  67. return self.components.query
  68. @property
  69. def fragment(self) -> str:
  70. return self.components.fragment
  71. @property
  72. def username(self) -> None | str:
  73. return self.components.username
  74. @property
  75. def password(self) -> None | str:
  76. return self.components.password
  77. @property
  78. def hostname(self) -> None | str:
  79. return self.components.hostname
  80. @property
  81. def port(self) -> int | None:
  82. return self.components.port
  83. @property
  84. def is_secure(self) -> bool:
  85. return self.scheme in ("https", "wss")
  86. def replace(self, **kwargs: typing.Any) -> URL:
  87. if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
  88. hostname = kwargs.pop("hostname", None)
  89. port = kwargs.pop("port", self.port)
  90. username = kwargs.pop("username", self.username)
  91. password = kwargs.pop("password", self.password)
  92. if hostname is None:
  93. netloc = self.netloc
  94. _, _, hostname = netloc.rpartition("@")
  95. if hostname[-1] != "]":
  96. hostname = hostname.rsplit(":", 1)[0]
  97. netloc = hostname
  98. if port is not None:
  99. netloc += f":{port}"
  100. if username is not None:
  101. userpass = username
  102. if password is not None:
  103. userpass += f":{password}"
  104. netloc = f"{userpass}@{netloc}"
  105. kwargs["netloc"] = netloc
  106. components = self.components._replace(**kwargs)
  107. return self.__class__(components.geturl())
  108. def include_query_params(self, **kwargs: typing.Any) -> URL:
  109. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  110. params.update({str(key): str(value) for key, value in kwargs.items()})
  111. query = urlencode(params.multi_items())
  112. return self.replace(query=query)
  113. def replace_query_params(self, **kwargs: typing.Any) -> URL:
  114. query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
  115. return self.replace(query=query)
  116. def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
  117. if isinstance(keys, str):
  118. keys = [keys]
  119. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  120. for key in keys:
  121. params.pop(key, None)
  122. query = urlencode(params.multi_items())
  123. return self.replace(query=query)
  124. def __eq__(self, other: typing.Any) -> bool:
  125. return str(self) == str(other)
  126. def __str__(self) -> str:
  127. return self._url
  128. def __repr__(self) -> str:
  129. url = str(self)
  130. if self.password:
  131. url = str(self.replace(password="********"))
  132. return f"{self.__class__.__name__}({repr(url)})"
  133. class URLPath(str):
  134. """
  135. A URL path string that may also hold an associated protocol and/or host.
  136. Used by the routing to return `url_path_for` matches.
  137. """
  138. def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
  139. assert protocol in ("http", "websocket", "")
  140. return str.__new__(cls, path)
  141. def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
  142. self.protocol = protocol
  143. self.host = host
  144. def make_absolute_url(self, base_url: str | URL) -> URL:
  145. if isinstance(base_url, str):
  146. base_url = URL(base_url)
  147. if self.protocol:
  148. scheme = {
  149. "http": {True: "https", False: "http"},
  150. "websocket": {True: "wss", False: "ws"},
  151. }[self.protocol][base_url.is_secure]
  152. else:
  153. scheme = base_url.scheme
  154. netloc = self.host or base_url.netloc
  155. path = base_url.path.rstrip("/") + str(self)
  156. return URL(scheme=scheme, netloc=netloc, path=path)
  157. class Secret:
  158. """
  159. Holds a string value that should not be revealed in tracebacks etc.
  160. You should cast the value to `str` at the point it is required.
  161. """
  162. def __init__(self, value: str):
  163. self._value = value
  164. def __repr__(self) -> str:
  165. class_name = self.__class__.__name__
  166. return f"{class_name}('**********')"
  167. def __str__(self) -> str:
  168. return self._value
  169. def __bool__(self) -> bool:
  170. return bool(self._value)
  171. class CommaSeparatedStrings(typing.Sequence[str]):
  172. def __init__(self, value: str | typing.Sequence[str]):
  173. if isinstance(value, str):
  174. splitter = shlex(value, posix=True)
  175. splitter.whitespace = ","
  176. splitter.whitespace_split = True
  177. self._items = [item.strip() for item in splitter]
  178. else:
  179. self._items = list(value)
  180. def __len__(self) -> int:
  181. return len(self._items)
  182. def __getitem__(self, index: int | slice) -> typing.Any:
  183. return self._items[index]
  184. def __iter__(self) -> typing.Iterator[str]:
  185. return iter(self._items)
  186. def __repr__(self) -> str:
  187. class_name = self.__class__.__name__
  188. items = [item for item in self]
  189. return f"{class_name}({items!r})"
  190. def __str__(self) -> str:
  191. return ", ".join(repr(item) for item in self)
  192. class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
  193. _dict: dict[_KeyType, _CovariantValueType]
  194. def __init__(
  195. self,
  196. *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
  197. | typing.Mapping[_KeyType, _CovariantValueType]
  198. | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
  199. **kwargs: typing.Any,
  200. ) -> None:
  201. assert len(args) < 2, "Too many arguments."
  202. value: typing.Any = args[0] if args else []
  203. if kwargs:
  204. value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
  205. if not value:
  206. _items: list[tuple[typing.Any, typing.Any]] = []
  207. elif hasattr(value, "multi_items"):
  208. value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
  209. _items = list(value.multi_items())
  210. elif hasattr(value, "items"):
  211. value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
  212. _items = list(value.items())
  213. else:
  214. value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
  215. _items = list(value)
  216. self._dict = {k: v for k, v in _items}
  217. self._list = _items
  218. def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
  219. return [item_value for item_key, item_value in self._list if item_key == key]
  220. def keys(self) -> typing.KeysView[_KeyType]:
  221. return self._dict.keys()
  222. def values(self) -> typing.ValuesView[_CovariantValueType]:
  223. return self._dict.values()
  224. def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
  225. return self._dict.items()
  226. def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
  227. return list(self._list)
  228. def __getitem__(self, key: _KeyType) -> _CovariantValueType:
  229. return self._dict[key]
  230. def __contains__(self, key: typing.Any) -> bool:
  231. return key in self._dict
  232. def __iter__(self) -> typing.Iterator[_KeyType]:
  233. return iter(self.keys())
  234. def __len__(self) -> int:
  235. return len(self._dict)
  236. def __eq__(self, other: typing.Any) -> bool:
  237. if not isinstance(other, self.__class__):
  238. return False
  239. return sorted(self._list) == sorted(other._list)
  240. def __repr__(self) -> str:
  241. class_name = self.__class__.__name__
  242. items = self.multi_items()
  243. return f"{class_name}({items!r})"
  244. class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
  245. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  246. self.setlist(key, [value])
  247. def __delitem__(self, key: typing.Any) -> None:
  248. self._list = [(k, v) for k, v in self._list if k != key]
  249. del self._dict[key]
  250. def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  251. self._list = [(k, v) for k, v in self._list if k != key]
  252. return self._dict.pop(key, default)
  253. def popitem(self) -> tuple[typing.Any, typing.Any]:
  254. key, value = self._dict.popitem()
  255. self._list = [(k, v) for k, v in self._list if k != key]
  256. return key, value
  257. def poplist(self, key: typing.Any) -> list[typing.Any]:
  258. values = [v for k, v in self._list if k == key]
  259. self.pop(key)
  260. return values
  261. def clear(self) -> None:
  262. self._dict.clear()
  263. self._list.clear()
  264. def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  265. if key not in self:
  266. self._dict[key] = default
  267. self._list.append((key, default))
  268. return self[key]
  269. def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
  270. if not values:
  271. self.pop(key, None)
  272. else:
  273. existing_items = [(k, v) for (k, v) in self._list if k != key]
  274. self._list = existing_items + [(key, value) for value in values]
  275. self._dict[key] = values[-1]
  276. def append(self, key: typing.Any, value: typing.Any) -> None:
  277. self._list.append((key, value))
  278. self._dict[key] = value
  279. def update(
  280. self,
  281. *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
  282. **kwargs: typing.Any,
  283. ) -> None:
  284. value = MultiDict(*args, **kwargs)
  285. existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
  286. self._list = existing_items + value.multi_items()
  287. self._dict.update(value)
  288. class QueryParams(ImmutableMultiDict[str, str]):
  289. """
  290. An immutable multidict.
  291. """
  292. def __init__(
  293. self,
  294. *args: ImmutableMultiDict[typing.Any, typing.Any]
  295. | typing.Mapping[typing.Any, typing.Any]
  296. | list[tuple[typing.Any, typing.Any]]
  297. | str
  298. | bytes,
  299. **kwargs: typing.Any,
  300. ) -> None:
  301. assert len(args) < 2, "Too many arguments."
  302. value = args[0] if args else []
  303. if isinstance(value, str):
  304. super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
  305. elif isinstance(value, bytes):
  306. super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
  307. else:
  308. super().__init__(*args, **kwargs) # type: ignore[arg-type]
  309. self._list = [(str(k), str(v)) for k, v in self._list]
  310. self._dict = {str(k): str(v) for k, v in self._dict.items()}
  311. def __str__(self) -> str:
  312. return urlencode(self._list)
  313. def __repr__(self) -> str:
  314. class_name = self.__class__.__name__
  315. query_string = str(self)
  316. return f"{class_name}({query_string!r})"
  317. class UploadFile:
  318. """
  319. An uploaded file included as part of the request data.
  320. """
  321. def __init__(
  322. self,
  323. file: typing.BinaryIO,
  324. *,
  325. size: int | None = None,
  326. filename: str | None = None,
  327. headers: Headers | None = None,
  328. ) -> None:
  329. self.filename = filename
  330. self.file = file
  331. self.size = size
  332. self.headers = headers or Headers()
  333. @property
  334. def content_type(self) -> str | None:
  335. return self.headers.get("content-type", None)
  336. @property
  337. def _in_memory(self) -> bool:
  338. # check for SpooledTemporaryFile._rolled
  339. rolled_to_disk = getattr(self.file, "_rolled", True)
  340. return not rolled_to_disk
  341. async def write(self, data: bytes) -> None:
  342. if self.size is not None:
  343. self.size += len(data)
  344. if self._in_memory:
  345. self.file.write(data)
  346. else:
  347. await run_in_threadpool(self.file.write, data)
  348. async def read(self, size: int = -1) -> bytes:
  349. if self._in_memory:
  350. return self.file.read(size)
  351. return await run_in_threadpool(self.file.read, size)
  352. async def seek(self, offset: int) -> None:
  353. if self._in_memory:
  354. self.file.seek(offset)
  355. else:
  356. await run_in_threadpool(self.file.seek, offset)
  357. async def close(self) -> None:
  358. if self._in_memory:
  359. self.file.close()
  360. else:
  361. await run_in_threadpool(self.file.close)
  362. def __repr__(self) -> str:
  363. return (
  364. f"{self.__class__.__name__}("
  365. f"filename={self.filename!r}, "
  366. f"size={self.size!r}, "
  367. f"headers={self.headers!r})"
  368. )
  369. class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
  370. """
  371. An immutable multidict, containing both file uploads and text input.
  372. """
  373. def __init__(
  374. self,
  375. *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
  376. **kwargs: str | UploadFile,
  377. ) -> None:
  378. super().__init__(*args, **kwargs)
  379. async def close(self) -> None:
  380. for key, value in self.multi_items():
  381. if isinstance(value, UploadFile):
  382. await value.close()
  383. class Headers(typing.Mapping[str, str]):
  384. """
  385. An immutable, case-insensitive multidict.
  386. """
  387. def __init__(
  388. self,
  389. headers: typing.Mapping[str, str] | None = None,
  390. raw: list[tuple[bytes, bytes]] | None = None,
  391. scope: typing.MutableMapping[str, typing.Any] | None = None,
  392. ) -> None:
  393. self._list: list[tuple[bytes, bytes]] = []
  394. if headers is not None:
  395. assert raw is None, 'Cannot set both "headers" and "raw".'
  396. assert scope is None, 'Cannot set both "headers" and "scope".'
  397. self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
  398. elif raw is not None:
  399. assert scope is None, 'Cannot set both "raw" and "scope".'
  400. self._list = raw
  401. elif scope is not None:
  402. # scope["headers"] isn't necessarily a list
  403. # it might be a tuple or other iterable
  404. self._list = scope["headers"] = list(scope["headers"])
  405. @property
  406. def raw(self) -> list[tuple[bytes, bytes]]:
  407. return list(self._list)
  408. def keys(self) -> list[str]: # type: ignore[override]
  409. return [key.decode("latin-1") for key, value in self._list]
  410. def values(self) -> list[str]: # type: ignore[override]
  411. return [value.decode("latin-1") for key, value in self._list]
  412. def items(self) -> list[tuple[str, str]]: # type: ignore[override]
  413. return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
  414. def getlist(self, key: str) -> list[str]:
  415. get_header_key = key.lower().encode("latin-1")
  416. return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
  417. def mutablecopy(self) -> MutableHeaders:
  418. return MutableHeaders(raw=self._list[:])
  419. def __getitem__(self, key: str) -> str:
  420. get_header_key = key.lower().encode("latin-1")
  421. for header_key, header_value in self._list:
  422. if header_key == get_header_key:
  423. return header_value.decode("latin-1")
  424. raise KeyError(key)
  425. def __contains__(self, key: typing.Any) -> bool:
  426. get_header_key = key.lower().encode("latin-1")
  427. for header_key, header_value in self._list:
  428. if header_key == get_header_key:
  429. return True
  430. return False
  431. def __iter__(self) -> typing.Iterator[typing.Any]:
  432. return iter(self.keys())
  433. def __len__(self) -> int:
  434. return len(self._list)
  435. def __eq__(self, other: typing.Any) -> bool:
  436. if not isinstance(other, Headers):
  437. return False
  438. return sorted(self._list) == sorted(other._list)
  439. def __repr__(self) -> str:
  440. class_name = self.__class__.__name__
  441. as_dict = dict(self.items())
  442. if len(as_dict) == len(self):
  443. return f"{class_name}({as_dict!r})"
  444. return f"{class_name}(raw={self.raw!r})"
  445. class MutableHeaders(Headers):
  446. def __setitem__(self, key: str, value: str) -> None:
  447. """
  448. Set the header `key` to `value`, removing any duplicate entries.
  449. Retains insertion order.
  450. """
  451. set_key = key.lower().encode("latin-1")
  452. set_value = value.encode("latin-1")
  453. found_indexes: list[int] = []
  454. for idx, (item_key, item_value) in enumerate(self._list):
  455. if item_key == set_key:
  456. found_indexes.append(idx)
  457. for idx in reversed(found_indexes[1:]):
  458. del self._list[idx]
  459. if found_indexes:
  460. idx = found_indexes[0]
  461. self._list[idx] = (set_key, set_value)
  462. else:
  463. self._list.append((set_key, set_value))
  464. def __delitem__(self, key: str) -> None:
  465. """
  466. Remove the header `key`.
  467. """
  468. del_key = key.lower().encode("latin-1")
  469. pop_indexes: list[int] = []
  470. for idx, (item_key, item_value) in enumerate(self._list):
  471. if item_key == del_key:
  472. pop_indexes.append(idx)
  473. for idx in reversed(pop_indexes):
  474. del self._list[idx]
  475. def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
  476. if not isinstance(other, typing.Mapping):
  477. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  478. self.update(other)
  479. return self
  480. def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
  481. if not isinstance(other, typing.Mapping):
  482. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  483. new = self.mutablecopy()
  484. new.update(other)
  485. return new
  486. @property
  487. def raw(self) -> list[tuple[bytes, bytes]]:
  488. return self._list
  489. def setdefault(self, key: str, value: str) -> str:
  490. """
  491. If the header `key` does not exist, then set it to `value`.
  492. Returns the header value.
  493. """
  494. set_key = key.lower().encode("latin-1")
  495. set_value = value.encode("latin-1")
  496. for idx, (item_key, item_value) in enumerate(self._list):
  497. if item_key == set_key:
  498. return item_value.decode("latin-1")
  499. self._list.append((set_key, set_value))
  500. return value
  501. def update(self, other: typing.Mapping[str, str]) -> None:
  502. for key, val in other.items():
  503. self[key] = val
  504. def append(self, key: str, value: str) -> None:
  505. """
  506. Append a header, preserving any duplicate entries.
  507. """
  508. append_key = key.lower().encode("latin-1")
  509. append_value = value.encode("latin-1")
  510. self._list.append((append_key, append_value))
  511. def add_vary_header(self, vary: str) -> None:
  512. existing = self.get("vary")
  513. if existing is not None:
  514. vary = ", ".join([existing, vary])
  515. self["vary"] = vary
  516. class State:
  517. """
  518. An object that can be used to store arbitrary state.
  519. Used for `request.state` and `app.state`.
  520. """
  521. _state: dict[str, typing.Any]
  522. def __init__(self, state: dict[str, typing.Any] | None = None):
  523. if state is None:
  524. state = {}
  525. super().__setattr__("_state", state)
  526. def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
  527. self._state[key] = value
  528. def __getattr__(self, key: typing.Any) -> typing.Any:
  529. try:
  530. return self._state[key]
  531. except KeyError:
  532. message = "'{}' object has no attribute '{}'"
  533. raise AttributeError(message.format(self.__class__.__name__, key))
  534. def __delattr__(self, key: typing.Any) -> None:
  535. del self._state[key]