| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679 |
- from __future__ import annotations
- import typing
- from shlex import shlex
- from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
- from starlette.concurrency import run_in_threadpool
- from starlette.types import Scope
class Address(typing.NamedTuple):
    """
    A named ``(host, port)`` pair.

    Being a NamedTuple, it can be read by field name (``addr.host``) or
    unpacked like a plain 2-tuple (``host, port = addr``).
    """

    host: str
    port: int
# Type variable for the keys of the multidict classes below (invariant, as
# mapping keys always are).
_KeyType = typing.TypeVar("_KeyType")
# Mapping keys are invariant, but values can be covariant because a read-only
# mapping only ever hands values out and never accepts them:
# you cannot do `Mapping[str, Animal]()["fido"] = Dog()`.
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
class URL:
    """
    A URL string that is split into its components lazily, on first access.

    Construct it from exactly one of:
      * a URL string;
      * an ASGI ``scope``, built from its scheme, server, path, query string
        and ``host`` header;
      * keyword ``**components``, applied to an empty URL through :meth:`replace`.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: typing.Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first "host" header, when present, takes precedence over "server".
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Leave the port out only when it is the scheme's well-known
                # default. Schemes not listed here have no default, so their
                # port is kept instead of the lookup raising KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The ``urlsplit`` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts any ``SplitResult`` field (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``). It also accepts ``username``, ``password``,
        ``hostname`` and ``port``. When any of those four is given, ``netloc``
        is rebuilt from them, and components that are not given keep their
        current values.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Keep the current host. Drop any "userinfo@" prefix. Then drop a
                # ":port" suffix, unless the netloc ends with the "]" of an IPv6
                # literal. endswith() (rather than hostname[-1]) keeps an empty
                # netloc, e.g. on a relative URL, from raising IndexError.
                _, _, hostname = self.netloc.rpartition("@")
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> URL:
        """
        Return a copy with the given query params set. Existing values for
        those keys are replaced; other params are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a copy whose query string consists only of the given params."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
        """Return a copy with every value of the given key(s) removed from the query."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compares string forms, so a URL also equals a plain str with the same text.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask any password so it does not leak into logs or debug output.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
class URLPath(str):
    """
    A path string (as returned by routing's `url_path_for`) that can also
    carry the protocol ("http", "websocket" or "") and host it belongs to.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto ``base_url``.

        The host defaults to the base URL's netloc. When a protocol is set, the
        scheme is its plain or secure variant, matching whether the base is secure.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # Index 0 is the insecure scheme, index 1 the secure one.
            scheme_pairs = {"http": ("http", "https"), "websocket": ("ws", "wss")}
            scheme = scheme_pairs[self.protocol][base.is_secure]

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
class Secret:
    """
    Wrap a sensitive string so that ``repr()`` never shows it, keeping it out
    of tracebacks and similar output.

    Convert with ``str()`` only at the point where the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness follows the wrapped value, so an empty secret is falsy.
        return bool(self._value)
class CommaSeparatedStrings(typing.Sequence[str]):
    """
    An immutable sequence of strings, typically parsed from a comma-separated value.

    A ``str`` argument is split on commas using shell-style quoting, so
    ``"'a,b', c"`` becomes ``['a,b', 'c']``, and each item is stripped. Any
    other sequence is copied into a list as-is.
    """

    def __init__(self, value: str | typing.Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","  # commas, not spaces, separate items
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that may hold several values for the same key.

    Ordinary mapping access (``d[key]``, ``keys()``, ``values()``, ``items()``,
    ``len()``) sees only the last value stored for each key. ``getlist()`` and
    ``multi_items()`` expose every ``(key, value)`` pair in insertion order.
    """

    _dict: dict[_KeyType, _CovariantValueType]  # last value per key
    _list: list[tuple[_KeyType, _CovariantValueType]]  # every pair, in order

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | typing.Mapping[_KeyType, _CovariantValueType]
        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Build from at most one positional source and optional keyword items.

        The positional source may be another multidict (all of its pairs are
        kept), a mapping, or an iterable of ``(key, value)`` pairs. Keyword
        items are appended after the positional ones.
        """
        assert len(args) < 2, "Too many arguments."
        value: typing.Any = args[0] if args else []

        if kwargs:
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[typing.Any, typing.Any]] = []
        elif hasattr(value, "multi_items"):
            value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so _dict holds the last value per key.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
        """Return every value stored for ``key``, in insertion order (empty if none)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a new list of all ``(key, value)`` pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        """
        Compare all ``(key, value)`` pairs, ignoring their order.

        Objects that are not instances of this class compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # The pairs cannot be ordered: keys of mixed types, or several
            # non-comparable values such as UploadFile objects under one key.
            # Fall back to an order-insensitive match that only needs "==".
            if len(self._list) != len(other._list):
                return False
            unmatched = list(other._list)
            for item in self._list:
                for index, candidate in enumerate(unmatched):
                    if candidate == item:
                        del unmatched[index]
                        break
                else:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    ``_list`` holds every ``(key, value)`` pair in order and ``_dict`` holds
    the last value per key. Every mutator below updates both.
    """

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assigning replaces *all* existing values for the key with this one.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(name, val) for name, val in self._list if name != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove every value for ``key``; return the last one, or ``default``."""
        self._list = [(name, val) for name, val in self._list if name != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[typing.Any, typing.Any]:
        """Remove an arbitrary key and all its values; return ``(key, last_value)``."""
        popped_key, last_value = self._dict.popitem()
        self._list = [(name, val) for name, val in self._list if name != popped_key]
        return popped_key, last_value

    def poplist(self, key: typing.Any) -> list[typing.Any]:
        """Remove ``key`` and return every value it had, in insertion order."""
        removed = [val for name, val in self._list if name == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Add ``key`` with ``default`` unless it is already present; return its value."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
        """
        Replace every value for ``key`` with ``values``, appended at the end.
        An empty ``values`` removes the key.
        """
        if not values:
            self.pop(key, None)
            return
        others = [(name, val) for (name, val) in self._list if name != key]
        self._list = others + [(key, val) for val in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for ``key``, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge in new items. Every key present in the incoming data loses all of
        its existing values, and the incoming pairs are appended at the end.
        """
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(name, val) for (name, val) in self._list if name not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query parameters; every key and value is a ``str``.

    ``str()`` re-encodes the parameters as a query string.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[typing.Any, typing.Any]
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]]
        | str
        | bytes,
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."
        source = args[0] if args else []

        # A raw query string (bytes are decoded as latin-1 first) is parsed,
        # keeping blank values. Any other source goes to ImmutableMultiDict as-is.
        if isinstance(source, bytes):
            source = source.decode("latin-1")
        if isinstance(source, str):
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Coerce keys and values to str in both internal views.
        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
class UploadFile:
    """
    A file uploaded as part of the request data, wrapping a binary file object.

    The async I/O methods call the file directly while it is still held in
    memory. Otherwise they run the call through ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

    @property
    def content_type(self) -> str | None:
        """The ``content-type`` header value, or None if absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # SpooledTemporaryFile exposes ``_rolled`` (False until it spills to disk).
        # A file object without that attribute is treated as not in memory.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: bytes) -> None:
        """Write ``data``, adding its length to ``size`` when size is tracked."""
        if self.size is not None:
            self.size += len(data)
        if self._in_memory:
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes (everything when ``size`` is -1)."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict of form fields whose values are either plain
    strings or ``UploadFile`` instances.
    """

    def __init__(
        self,
        *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every ``UploadFile`` value, one after another in insertion order."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict of headers.

    Entries are stored as a list of ``(name, value)`` byte pairs and decoded
    as latin-1 on access. Lookups lowercase the requested name. Names are
    lowercased when built from a ``headers`` mapping; ``raw`` and ``scope``
    lists are used as given, presumably already lowercase. ``keys()``,
    ``values()``, ``items()``, iteration and ``len()`` all include duplicate
    entries.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: typing.MutableMapping[str, typing.Any] | None = None,
    ) -> None:
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), val.encode("latin-1"))
                for name, val in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] may be a tuple or another iterable. Store a list
            # back into the scope so this object and the scope share one list.
            shared = list(scope["headers"])
            scope["headers"] = shared
            self._list = shared

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [val.decode("latin-1") for _, val in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(name.decode("latin-1"), val.decode("latin-1")) for name, val in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value for the header ``key``, in order."""
        wanted = key.lower().encode("latin-1")
        return [val.decode("latin-1") for name, val in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Returns the first matching entry when the header is repeated.
        wanted = key.lower().encode("latin-1")
        for name, val in self._list:
            if name == wanted:
                return val.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict would hide duplicates; fall back to the raw pairs when there are any.
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"
class MutableHeaders(Headers):
    """
    A mutable, case-insensitive header multidict.

    All mutations happen in place on the underlying list, and ``raw`` returns
    that live list rather than a copy.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header ``key`` to ``value``. The first existing entry is
        overwritten in place, so its position is kept. Any later duplicates are
        removed. The header is appended when it is not present.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")
        positions = [i for i, (name, _) in enumerate(self._list) if name == set_key]
        # Delete from the back so the earlier indexes stay valid.
        for i in reversed(positions[1:]):
            del self._list[i]
        if positions:
            self._list[positions[0]] = (set_key, set_value)
        else:
            self._list.append((set_key, set_value))

    def __delitem__(self, key: str) -> None:
        """Remove every entry for the header ``key``. A missing key is not an error."""
        del_key = key.lower().encode("latin-1")
        doomed = [i for i, (name, _) in enumerate(self._list) if name == del_key]
        for i in reversed(doomed):
            del self._list[i]

    def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        Return the first existing value of header ``key``. If the header is
        absent, append ``value`` and return it.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")
        for name, existing in self._list:
            if name == set_key:
                return existing.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Set each header from ``other`` via ``__setitem__``, replacing duplicates."""
        for name, val in other.items():
            self[name] = val

    def append(self, key: str, value: str) -> None:
        """Append a header entry, keeping any existing entries with the same name."""
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the ``vary`` header, comma-joined after any existing value."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join([existing, vary])
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Reading, setting and deleting
    attributes all go through a backing ``dict``.
    """

    _state: dict[str, typing.Any]

    def __init__(self, state: dict[str, typing.Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would try to store "_state" inside
        # the backing dict before that dict has been attached.
        super().__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        # Only called when normal attribute lookup fails. Fetch the backing dict
        # from the instance __dict__ rather than via ``self._state``. On an
        # instance created without __init__ (as copy.copy/deepcopy and
        # unpickling do), ``self._state`` would re-enter __getattr__ forever
        # and raise RecursionError.
        state = self.__dict__.get("_state", {})
        try:
            return state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key)) from None

    def __delattr__(self, key: typing.Any) -> None:
        # Raises KeyError (not AttributeError) when ``key`` is not set.
        del self._state[key]
|