sync.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. from __future__ import annotations
  2. import functools
  3. import socket
  4. import ssl
  5. import sys
  6. import typing
  7. from .._exceptions import (
  8. ConnectError,
  9. ConnectTimeout,
  10. ExceptionMapping,
  11. ReadError,
  12. ReadTimeout,
  13. WriteError,
  14. WriteTimeout,
  15. map_exceptions,
  16. )
  17. from .._utils import is_socket_readable
  18. from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
  19. class TLSinTLSStream(NetworkStream): # pragma: no cover
  20. """
  21. Because the standard `SSLContext.wrap_socket` method does
  22. not work for `SSLSocket` objects, we need this class
  23. to implement TLS stream using an underlying `SSLObject`
  24. instance in order to support TLS on top of TLS.
  25. """
  26. # Defined in RFC 8449
  27. TLS_RECORD_SIZE = 16384
  28. def __init__(
  29. self,
  30. sock: socket.socket,
  31. ssl_context: ssl.SSLContext,
  32. server_hostname: str | None = None,
  33. timeout: float | None = None,
  34. ):
  35. self._sock = sock
  36. self._incoming = ssl.MemoryBIO()
  37. self._outgoing = ssl.MemoryBIO()
  38. self.ssl_obj = ssl_context.wrap_bio(
  39. incoming=self._incoming,
  40. outgoing=self._outgoing,
  41. server_hostname=server_hostname,
  42. )
  43. self._sock.settimeout(timeout)
  44. self._perform_io(self.ssl_obj.do_handshake)
  45. def _perform_io(
  46. self,
  47. func: typing.Callable[..., typing.Any],
  48. ) -> typing.Any:
  49. ret = None
  50. while True:
  51. errno = None
  52. try:
  53. ret = func()
  54. except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
  55. errno = e.errno
  56. self._sock.sendall(self._outgoing.read())
  57. if errno == ssl.SSL_ERROR_WANT_READ:
  58. buf = self._sock.recv(self.TLS_RECORD_SIZE)
  59. if buf:
  60. self._incoming.write(buf)
  61. else:
  62. self._incoming.write_eof()
  63. if errno is None:
  64. return ret
  65. def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
  66. exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
  67. with map_exceptions(exc_map):
  68. self._sock.settimeout(timeout)
  69. return typing.cast(
  70. bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes))
  71. )
  72. def write(self, buffer: bytes, timeout: float | None = None) -> None:
  73. exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
  74. with map_exceptions(exc_map):
  75. self._sock.settimeout(timeout)
  76. while buffer:
  77. nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer))
  78. buffer = buffer[nsent:]
  79. def close(self) -> None:
  80. self._sock.close()
  81. def start_tls(
  82. self,
  83. ssl_context: ssl.SSLContext,
  84. server_hostname: str | None = None,
  85. timeout: float | None = None,
  86. ) -> NetworkStream:
  87. raise NotImplementedError()
  88. def get_extra_info(self, info: str) -> typing.Any:
  89. if info == "ssl_object":
  90. return self.ssl_obj
  91. if info == "client_addr":
  92. return self._sock.getsockname()
  93. if info == "server_addr":
  94. return self._sock.getpeername()
  95. if info == "socket":
  96. return self._sock
  97. if info == "is_readable":
  98. return is_socket_readable(self._sock)
  99. return None
  100. class SyncStream(NetworkStream):
  101. def __init__(self, sock: socket.socket) -> None:
  102. self._sock = sock
  103. def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
  104. exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
  105. with map_exceptions(exc_map):
  106. self._sock.settimeout(timeout)
  107. return self._sock.recv(max_bytes)
  108. def write(self, buffer: bytes, timeout: float | None = None) -> None:
  109. if not buffer:
  110. return
  111. exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
  112. with map_exceptions(exc_map):
  113. while buffer:
  114. self._sock.settimeout(timeout)
  115. n = self._sock.send(buffer)
  116. buffer = buffer[n:]
  117. def close(self) -> None:
  118. self._sock.close()
  119. def start_tls(
  120. self,
  121. ssl_context: ssl.SSLContext,
  122. server_hostname: str | None = None,
  123. timeout: float | None = None,
  124. ) -> NetworkStream:
  125. exc_map: ExceptionMapping = {
  126. socket.timeout: ConnectTimeout,
  127. OSError: ConnectError,
  128. }
  129. with map_exceptions(exc_map):
  130. try:
  131. if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover
  132. # If the underlying socket has already been upgraded
  133. # to the TLS layer (i.e. is an instance of SSLSocket),
  134. # we need some additional smarts to support TLS-in-TLS.
  135. return TLSinTLSStream(
  136. self._sock, ssl_context, server_hostname, timeout
  137. )
  138. else:
  139. self._sock.settimeout(timeout)
  140. sock = ssl_context.wrap_socket(
  141. self._sock, server_hostname=server_hostname
  142. )
  143. except Exception as exc: # pragma: nocover
  144. self.close()
  145. raise exc
  146. return SyncStream(sock)
  147. def get_extra_info(self, info: str) -> typing.Any:
  148. if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
  149. return self._sock._sslobj # type: ignore
  150. if info == "client_addr":
  151. return self._sock.getsockname()
  152. if info == "server_addr":
  153. return self._sock.getpeername()
  154. if info == "socket":
  155. return self._sock
  156. if info == "is_readable":
  157. return is_socket_readable(self._sock)
  158. return None
  159. class SyncBackend(NetworkBackend):
  160. def connect_tcp(
  161. self,
  162. host: str,
  163. port: int,
  164. timeout: float | None = None,
  165. local_address: str | None = None,
  166. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  167. ) -> NetworkStream:
  168. # Note that we automatically include `TCP_NODELAY`
  169. # in addition to any other custom socket options.
  170. if socket_options is None:
  171. socket_options = [] # pragma: no cover
  172. address = (host, port)
  173. source_address = None if local_address is None else (local_address, 0)
  174. exc_map: ExceptionMapping = {
  175. socket.timeout: ConnectTimeout,
  176. OSError: ConnectError,
  177. }
  178. with map_exceptions(exc_map):
  179. sock = socket.create_connection(
  180. address,
  181. timeout,
  182. source_address=source_address,
  183. )
  184. for option in socket_options:
  185. sock.setsockopt(*option) # pragma: no cover
  186. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  187. return SyncStream(sock)
  188. def connect_unix_socket(
  189. self,
  190. path: str,
  191. timeout: float | None = None,
  192. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  193. ) -> NetworkStream: # pragma: nocover
  194. if sys.platform == "win32":
  195. raise RuntimeError(
  196. "Attempted to connect to a UNIX socket on a Windows system."
  197. )
  198. if socket_options is None:
  199. socket_options = []
  200. exc_map: ExceptionMapping = {
  201. socket.timeout: ConnectTimeout,
  202. OSError: ConnectError,
  203. }
  204. with map_exceptions(exc_map):
  205. sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  206. for option in socket_options:
  207. sock.setsockopt(*option)
  208. sock.settimeout(timeout)
  209. sock.connect(path)
  210. return SyncStream(sock)