| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- from __future__ import annotations
- import ssl
- import typing
- import trio
- from .._exceptions import (
- ConnectError,
- ConnectTimeout,
- ExceptionMapping,
- ReadError,
- ReadTimeout,
- WriteError,
- WriteTimeout,
- map_exceptions,
- )
- from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class TrioStream(AsyncNetworkStream):
    """An ``AsyncNetworkStream`` implemented on top of a ``trio.abc.Stream``.

    Trio exceptions raised by I/O are translated into this package's
    exception types (``ReadError``, ``WriteTimeout``, ...).

    The socket-related ``get_extra_info`` keys assume the wrapped stream is a
    ``trio.SocketStream``, possibly wrapped in one or more ``trio.SSLStream``
    layers (as produced by ``start_tls``).
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Receive up to ``max_bytes`` bytes from the stream.

        Args:
            max_bytes: upper bound on the number of bytes returned.
            timeout: seconds to wait for data; ``None`` waits indefinitely.

        Raises:
            ReadTimeout: if ``timeout`` elapses before data is received.
            ReadError: if the stream is broken or already closed.
        """
        # trio.fail_after() expects a number; infinity means "no deadline".
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
                return data

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send all of ``buffer`` over the stream.

        An empty ``buffer`` is a no-op and returns immediately.

        Args:
            buffer: bytes to send.
            timeout: seconds allowed for the whole send; ``None`` waits
                indefinitely.

        Raises:
            WriteTimeout: if ``timeout`` elapses before the send completes.
            WriteError: if the stream is broken or already closed.
        """
        if not buffer:
            return

        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped trio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this stream.

        Args:
            ssl_context: context used to configure the TLS layer.
            server_hostname: forwarded to ``trio.SSLStream``.
            timeout: seconds allowed for the handshake; ``None`` means no
                limit.

        Returns:
            A new ``TrioStream`` wrapping the ``trio.SSLStream``.

        Raises:
            ConnectTimeout: if the handshake exceeds ``timeout``.
            ConnectError: if the transport breaks during the handshake.

        If the handshake raises any ``Exception``, this stream is closed
        before the (mapped) error propagates.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            # HTTPS-compatible mode tolerates peers that close without a TLS
            # close_notify, which is common among HTTP servers.
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                # Don't leave the underlying transport open after a failed
                # handshake; bare `raise` keeps the original traceback intact.
                await self.aclose()
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return transport details for ``info``, or ``None`` for other keys.

        Supported keys:
            ``"ssl_object"``: the SSL object of the outermost
                ``trio.SSLStream`` (only when this stream is TLS-wrapped).
            ``"client_addr"``: ``getsockname()`` of the underlying socket.
            ``"server_addr"``: ``getpeername()`` of the underlying socket.
            ``"socket"``: the trio socket beneath any TLS layers.
            ``"is_readable"``: the result of that socket's ``is_readable()``.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
            # Tracked at https://github.com/python-trio/trio/issues/542
            return self._stream._ssl_object  # type: ignore[attr-defined]
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            return self._get_socket_stream().socket
        if info == "is_readable":
            return self._get_socket_stream().socket.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Unwrap any ``trio.SSLStream`` layers and return the socket stream.

        The ``assert`` narrows the type; it fails if the innermost stream is
        not a ``trio.SocketStream``.
        """
        stream = self._stream
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)
        return stream
class TrioBackend(AsyncNetworkBackend):
    """``AsyncNetworkBackend`` that opens connections using trio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to ``host``:``port``.

        Args:
            host: remote host, passed to ``trio.open_tcp_stream``.
            port: remote port.
            timeout: seconds allowed for connecting and applying socket
                options; ``None`` means no limit.
            local_address: passed to ``trio.open_tcp_stream`` as the local
                address to bind.
            socket_options: argument tuples forwarded to
                ``SocketStream.setsockopt`` after the connection is open.

        Returns:
            A ``TrioStream`` wrapping the connected socket stream.

        Raises:
            ConnectTimeout: if ``timeout`` elapses.
            ConnectError: on ``trio.BrokenResourceError`` or any ``OSError``,
                including a failing socket option (the socket is closed first).
        """
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
        if socket_options is None:
            socket_options = []  # pragma: no cover
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.SocketStream = await trio.open_tcp_stream(
                    host=host, port=port, local_address=local_address
                )
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at ``path``.

        Args:
            path: filesystem path of the socket, passed to
                ``trio.open_unix_socket``.
            timeout: seconds allowed for connecting and applying socket
                options; ``None`` means no limit.
            socket_options: argument tuples forwarded to
                ``SocketStream.setsockopt`` after the connection is open.

        Returns:
            A ``TrioStream`` wrapping the connected socket stream.

        Raises:
            ConnectTimeout: if ``timeout`` elapses.
            ConnectError: on ``trio.BrokenResourceError`` or any ``OSError``,
                including a failing socket option (the socket is closed first).
        """
        if socket_options is None:
            socket_options = []
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.SocketStream = await trio.open_unix_socket(path)
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    @staticmethod
    async def _apply_socket_options(
        stream: trio.SocketStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Apply each option to ``stream``; close ``stream`` if one fails.

        The caller never receives the stream when an option fails, so without
        this cleanup the freshly connected socket would be leaked.
        """
        try:
            for option in socket_options:
                stream.setsockopt(*option)  # pragma: no cover
        except BaseException:  # pragma: no cover
            # aclose_forcefully closes immediately, even inside a cancelled
            # scope, then the original error propagates.
            await trio.aclose_forcefully(stream)
            raise

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` seconds."""
        await trio.sleep(seconds)  # pragma: nocover
|