anyio.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from __future__ import annotations
  2. import ssl
  3. import typing
  4. import anyio
  5. from .._exceptions import (
  6. ConnectError,
  7. ConnectTimeout,
  8. ReadError,
  9. ReadTimeout,
  10. WriteError,
  11. WriteTimeout,
  12. map_exceptions,
  13. )
  14. from .._utils import is_socket_readable
  15. from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
  16. class AnyIOStream(AsyncNetworkStream):
  17. def __init__(self, stream: anyio.abc.ByteStream) -> None:
  18. self._stream = stream
  19. async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
  20. exc_map = {
  21. TimeoutError: ReadTimeout,
  22. anyio.BrokenResourceError: ReadError,
  23. anyio.ClosedResourceError: ReadError,
  24. anyio.EndOfStream: ReadError,
  25. }
  26. with map_exceptions(exc_map):
  27. with anyio.fail_after(timeout):
  28. try:
  29. return await self._stream.receive(max_bytes=max_bytes)
  30. except anyio.EndOfStream: # pragma: nocover
  31. return b""
  32. async def write(self, buffer: bytes, timeout: float | None = None) -> None:
  33. if not buffer:
  34. return
  35. exc_map = {
  36. TimeoutError: WriteTimeout,
  37. anyio.BrokenResourceError: WriteError,
  38. anyio.ClosedResourceError: WriteError,
  39. }
  40. with map_exceptions(exc_map):
  41. with anyio.fail_after(timeout):
  42. await self._stream.send(item=buffer)
  43. async def aclose(self) -> None:
  44. await self._stream.aclose()
  45. async def start_tls(
  46. self,
  47. ssl_context: ssl.SSLContext,
  48. server_hostname: str | None = None,
  49. timeout: float | None = None,
  50. ) -> AsyncNetworkStream:
  51. exc_map = {
  52. TimeoutError: ConnectTimeout,
  53. anyio.BrokenResourceError: ConnectError,
  54. anyio.EndOfStream: ConnectError,
  55. ssl.SSLError: ConnectError,
  56. }
  57. with map_exceptions(exc_map):
  58. try:
  59. with anyio.fail_after(timeout):
  60. ssl_stream = await anyio.streams.tls.TLSStream.wrap(
  61. self._stream,
  62. ssl_context=ssl_context,
  63. hostname=server_hostname,
  64. standard_compatible=False,
  65. server_side=False,
  66. )
  67. except Exception as exc: # pragma: nocover
  68. await self.aclose()
  69. raise exc
  70. return AnyIOStream(ssl_stream)
  71. def get_extra_info(self, info: str) -> typing.Any:
  72. if info == "ssl_object":
  73. return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
  74. if info == "client_addr":
  75. return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
  76. if info == "server_addr":
  77. return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
  78. if info == "socket":
  79. return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
  80. if info == "is_readable":
  81. sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
  82. return is_socket_readable(sock)
  83. return None
  84. class AnyIOBackend(AsyncNetworkBackend):
  85. async def connect_tcp(
  86. self,
  87. host: str,
  88. port: int,
  89. timeout: float | None = None,
  90. local_address: str | None = None,
  91. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  92. ) -> AsyncNetworkStream: # pragma: nocover
  93. if socket_options is None:
  94. socket_options = []
  95. exc_map = {
  96. TimeoutError: ConnectTimeout,
  97. OSError: ConnectError,
  98. anyio.BrokenResourceError: ConnectError,
  99. }
  100. with map_exceptions(exc_map):
  101. with anyio.fail_after(timeout):
  102. stream: anyio.abc.ByteStream = await anyio.connect_tcp(
  103. remote_host=host,
  104. remote_port=port,
  105. local_host=local_address,
  106. )
  107. # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
  108. for option in socket_options:
  109. stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
  110. return AnyIOStream(stream)
  111. async def connect_unix_socket(
  112. self,
  113. path: str,
  114. timeout: float | None = None,
  115. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  116. ) -> AsyncNetworkStream: # pragma: nocover
  117. if socket_options is None:
  118. socket_options = []
  119. exc_map = {
  120. TimeoutError: ConnectTimeout,
  121. OSError: ConnectError,
  122. anyio.BrokenResourceError: ConnectError,
  123. }
  124. with map_exceptions(exc_map):
  125. with anyio.fail_after(timeout):
  126. stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
  127. for option in socket_options:
  128. stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
  129. return AnyIOStream(stream)
  130. async def sleep(self, seconds: float) -> None:
  131. await anyio.sleep(seconds) # pragma: nocover