| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- from __future__ import annotations
- import ssl
- import typing
- import anyio
- from .._exceptions import (
- ConnectError,
- ConnectTimeout,
- ReadError,
- ReadTimeout,
- WriteError,
- WriteTimeout,
- map_exceptions,
- )
- from .._utils import is_socket_readable
- from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class AnyIOStream(AsyncNetworkStream):
    """Async network stream that delegates to a wrapped anyio byte stream.

    Exceptions raised by anyio (and by ``ssl`` during the TLS upgrade) are
    routed through ``map_exceptions`` so that callers see this package's
    Read*/Write*/Connect* exception types instead.
    """

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        # Every operation below is forwarded to this anyio stream.
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Receive at most ``max_bytes`` bytes from the stream.

        Args:
            max_bytes: upper bound on the number of bytes returned.
            timeout: seconds to wait (``None`` waits indefinitely).

        Returns:
            The received bytes, or ``b""`` when anyio reports end-of-stream.

        Raises:
            ReadTimeout: mapped from the ``TimeoutError`` of ``fail_after``.
            ReadError: mapped from broken/closed-resource errors.
        """
        read_errors = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
            anyio.EndOfStream: ReadError,
        }
        # `with A, B:` nests exactly like `with A: with B:`.
        with map_exceptions(read_errors), anyio.fail_after(timeout):
            try:
                return await self._stream.receive(max_bytes=max_bytes)
            except anyio.EndOfStream:  # pragma: nocover
                # The peer closed cleanly: report it as an empty read.
                return b""

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send ``buffer`` over the stream; an empty buffer sends nothing.

        Raises:
            WriteTimeout: mapped from the ``TimeoutError`` of ``fail_after``.
            WriteError: mapped from broken/closed-resource errors.
        """
        if buffer:
            write_errors = {
                TimeoutError: WriteTimeout,
                anyio.BrokenResourceError: WriteError,
                anyio.ClosedResourceError: WriteError,
            }
            with map_exceptions(write_errors), anyio.fail_after(timeout):
                await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the wrapped anyio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this stream.

        Args:
            ssl_context: context used for the handshake.
            server_hostname: hostname passed to anyio's ``TLSStream.wrap``.
            timeout: seconds allowed for the handshake (``None`` = no limit).

        Returns:
            A new ``AnyIOStream`` wrapping the TLS stream.

        Raises:
            ConnectTimeout: mapped from the ``TimeoutError`` of ``fail_after``.
            ConnectError: mapped from broken-resource, end-of-stream and
                ``ssl.SSLError`` failures.
        """
        handshake_errors = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            anyio.EndOfStream: ConnectError,
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(handshake_errors):
            try:
                with anyio.fail_after(timeout):
                    tls_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        # NOTE(review): per anyio docs, False skips the TLS
                        # closing handshake on close — confirm this is intended.
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception:  # pragma: nocover
                # Close the underlying connection before propagating the error.
                await self.aclose()
                raise
        return AnyIOStream(tls_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Look up connection metadata by name.

        Supported names: ``"ssl_object"``, ``"client_addr"``,
        ``"server_addr"``, ``"socket"`` (each read from the stream's typed
        attributes, ``None`` when absent) and ``"is_readable"`` (the raw
        socket passed to ``is_socket_readable``). Any other name yields
        ``None``.
        """
        if info == "is_readable":
            raw = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            return is_socket_readable(raw)

        attribute_for_name = {
            "ssl_object": anyio.streams.tls.TLSAttribute.ssl_object,
            "client_addr": anyio.abc.SocketAttribute.local_address,
            "server_addr": anyio.abc.SocketAttribute.remote_address,
            "socket": anyio.abc.SocketAttribute.raw_socket,
        }
        attribute = attribute_for_name.get(info)
        if attribute is None:
            return None
        return self._stream.extra(attribute, None)
class AnyIOBackend(AsyncNetworkBackend):
    """Network backend that opens connections with anyio.

    Connected streams are wrapped in ``AnyIOStream``. Failures while
    connecting are routed through ``map_exceptions`` to ``ConnectTimeout`` /
    ``ConnectError``. This also covers failures while applying socket options.
    """

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a TCP connection to ``host:port``.

        Args:
            host: remote host name or address.
            port: remote port.
            timeout: seconds allowed to establish the connection
                (``None`` = no limit).
            local_address: optional local host to bind before connecting.
            socket_options: argument tuples, each passed as
                ``setsockopt(*option)`` on the connected socket.

        Raises:
            ConnectTimeout: mapped from the ``TimeoutError`` of ``fail_after``.
            ConnectError: mapped from ``OSError`` / ``BrokenResourceError``
                raised while connecting or while applying socket options.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            # TimeoutError subclasses OSError, so keep it ahead of OSError in
            # case map_exceptions matches entries in order.
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
            # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
            # Options are applied inside the exception mapping so a failing
            # setsockopt is reported like any other connect failure.
            await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at ``path``.

        Args:
            path: filesystem path of the socket.
            timeout: seconds allowed to establish the connection
                (``None`` = no limit).
            socket_options: argument tuples, each passed as
                ``setsockopt(*option)`` on the connected socket.

        Raises:
            ConnectTimeout: mapped from the ``TimeoutError`` of ``fail_after``.
            ConnectError: mapped from ``OSError`` / ``BrokenResourceError``
                raised while connecting or while applying socket options.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
            await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:  # pragma: nocover
        """Apply each option via ``setsockopt`` on the stream's raw socket.

        The raw socket is obtained through anyio's public
        ``SocketAttribute.raw_socket`` typed attribute. It is looked up only
        when there is at least one option. If anything fails, the freshly
        opened stream is closed before the exception propagates, so the
        connection is not leaked.
        """
        try:
            options = list(socket_options)
            if options:
                sock = stream.extra(anyio.abc.SocketAttribute.raw_socket)
                for option in options:
                    sock.setsockopt(*option)
        except Exception:
            await stream.aclose()
            raise

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` seconds via anyio."""
        await anyio.sleep(seconds)  # pragma: nocover
|