trio.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from __future__ import annotations
  2. import ssl
  3. import typing
  4. import trio
  5. from .._exceptions import (
  6. ConnectError,
  7. ConnectTimeout,
  8. ExceptionMapping,
  9. ReadError,
  10. ReadTimeout,
  11. WriteError,
  12. WriteTimeout,
  13. map_exceptions,
  14. )
  15. from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
  16. class TrioStream(AsyncNetworkStream):
  17. def __init__(self, stream: trio.abc.Stream) -> None:
  18. self._stream = stream
  19. async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
  20. timeout_or_inf = float("inf") if timeout is None else timeout
  21. exc_map: ExceptionMapping = {
  22. trio.TooSlowError: ReadTimeout,
  23. trio.BrokenResourceError: ReadError,
  24. trio.ClosedResourceError: ReadError,
  25. }
  26. with map_exceptions(exc_map):
  27. with trio.fail_after(timeout_or_inf):
  28. data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
  29. return data
  30. async def write(self, buffer: bytes, timeout: float | None = None) -> None:
  31. if not buffer:
  32. return
  33. timeout_or_inf = float("inf") if timeout is None else timeout
  34. exc_map: ExceptionMapping = {
  35. trio.TooSlowError: WriteTimeout,
  36. trio.BrokenResourceError: WriteError,
  37. trio.ClosedResourceError: WriteError,
  38. }
  39. with map_exceptions(exc_map):
  40. with trio.fail_after(timeout_or_inf):
  41. await self._stream.send_all(data=buffer)
  42. async def aclose(self) -> None:
  43. await self._stream.aclose()
  44. async def start_tls(
  45. self,
  46. ssl_context: ssl.SSLContext,
  47. server_hostname: str | None = None,
  48. timeout: float | None = None,
  49. ) -> AsyncNetworkStream:
  50. timeout_or_inf = float("inf") if timeout is None else timeout
  51. exc_map: ExceptionMapping = {
  52. trio.TooSlowError: ConnectTimeout,
  53. trio.BrokenResourceError: ConnectError,
  54. }
  55. ssl_stream = trio.SSLStream(
  56. self._stream,
  57. ssl_context=ssl_context,
  58. server_hostname=server_hostname,
  59. https_compatible=True,
  60. server_side=False,
  61. )
  62. with map_exceptions(exc_map):
  63. try:
  64. with trio.fail_after(timeout_or_inf):
  65. await ssl_stream.do_handshake()
  66. except Exception as exc: # pragma: nocover
  67. await self.aclose()
  68. raise exc
  69. return TrioStream(ssl_stream)
  70. def get_extra_info(self, info: str) -> typing.Any:
  71. if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
  72. # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
  73. # Tracked at https://github.com/python-trio/trio/issues/542
  74. return self._stream._ssl_object # type: ignore[attr-defined]
  75. if info == "client_addr":
  76. return self._get_socket_stream().socket.getsockname()
  77. if info == "server_addr":
  78. return self._get_socket_stream().socket.getpeername()
  79. if info == "socket":
  80. stream = self._stream
  81. while isinstance(stream, trio.SSLStream):
  82. stream = stream.transport_stream
  83. assert isinstance(stream, trio.SocketStream)
  84. return stream.socket
  85. if info == "is_readable":
  86. socket = self.get_extra_info("socket")
  87. return socket.is_readable()
  88. return None
  89. def _get_socket_stream(self) -> trio.SocketStream:
  90. stream = self._stream
  91. while isinstance(stream, trio.SSLStream):
  92. stream = stream.transport_stream
  93. assert isinstance(stream, trio.SocketStream)
  94. return stream
  95. class TrioBackend(AsyncNetworkBackend):
  96. async def connect_tcp(
  97. self,
  98. host: str,
  99. port: int,
  100. timeout: float | None = None,
  101. local_address: str | None = None,
  102. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  103. ) -> AsyncNetworkStream:
  104. # By default for TCP sockets, trio enables TCP_NODELAY.
  105. # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
  106. if socket_options is None:
  107. socket_options = [] # pragma: no cover
  108. timeout_or_inf = float("inf") if timeout is None else timeout
  109. exc_map: ExceptionMapping = {
  110. trio.TooSlowError: ConnectTimeout,
  111. trio.BrokenResourceError: ConnectError,
  112. OSError: ConnectError,
  113. }
  114. with map_exceptions(exc_map):
  115. with trio.fail_after(timeout_or_inf):
  116. stream: trio.abc.Stream = await trio.open_tcp_stream(
  117. host=host, port=port, local_address=local_address
  118. )
  119. for option in socket_options:
  120. stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
  121. return TrioStream(stream)
  122. async def connect_unix_socket(
  123. self,
  124. path: str,
  125. timeout: float | None = None,
  126. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  127. ) -> AsyncNetworkStream: # pragma: nocover
  128. if socket_options is None:
  129. socket_options = []
  130. timeout_or_inf = float("inf") if timeout is None else timeout
  131. exc_map: ExceptionMapping = {
  132. trio.TooSlowError: ConnectTimeout,
  133. trio.BrokenResourceError: ConnectError,
  134. OSError: ConnectError,
  135. }
  136. with map_exceptions(exc_map):
  137. with trio.fail_after(timeout_or_inf):
  138. stream: trio.abc.Stream = await trio.open_unix_socket(path)
  139. for option in socket_options:
  140. stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
  141. return TrioStream(stream)
  142. async def sleep(self, seconds: float) -> None:
  143. await trio.sleep(seconds) # pragma: nocover