# connection.py (8.3 KB)
  1. from __future__ import annotations
  2. import itertools
  3. import logging
  4. import ssl
  5. import types
  6. import typing
  7. from .._backends.auto import AutoBackend
  8. from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
  9. from .._exceptions import ConnectError, ConnectTimeout
  10. from .._models import Origin, Request, Response
  11. from .._ssl import default_ssl_context
  12. from .._synchronization import AsyncLock
  13. from .._trace import Trace
  14. from .http11 import AsyncHTTP11Connection
  15. from .interfaces import AsyncConnectionInterface
  16. RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
  17. logger = logging.getLogger("httpcore.connection")
  18. def exponential_backoff(factor: float) -> typing.Iterator[float]:
  19. """
  20. Generate a geometric sequence that has a ratio of 2 and starts with 0.
  21. For example:
  22. - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
  23. - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
  24. """
  25. yield 0
  26. for n in itertools.count():
  27. yield factor * 2**n
  28. class AsyncHTTPConnection(AsyncConnectionInterface):
  29. def __init__(
  30. self,
  31. origin: Origin,
  32. ssl_context: ssl.SSLContext | None = None,
  33. keepalive_expiry: float | None = None,
  34. http1: bool = True,
  35. http2: bool = False,
  36. retries: int = 0,
  37. local_address: str | None = None,
  38. uds: str | None = None,
  39. network_backend: AsyncNetworkBackend | None = None,
  40. socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
  41. ) -> None:
  42. self._origin = origin
  43. self._ssl_context = ssl_context
  44. self._keepalive_expiry = keepalive_expiry
  45. self._http1 = http1
  46. self._http2 = http2
  47. self._retries = retries
  48. self._local_address = local_address
  49. self._uds = uds
  50. self._network_backend: AsyncNetworkBackend = (
  51. AutoBackend() if network_backend is None else network_backend
  52. )
  53. self._connection: AsyncConnectionInterface | None = None
  54. self._connect_failed: bool = False
  55. self._request_lock = AsyncLock()
  56. self._socket_options = socket_options
  57. async def handle_async_request(self, request: Request) -> Response:
  58. if not self.can_handle_request(request.url.origin):
  59. raise RuntimeError(
  60. f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
  61. )
  62. try:
  63. async with self._request_lock:
  64. if self._connection is None:
  65. stream = await self._connect(request)
  66. ssl_object = stream.get_extra_info("ssl_object")
  67. http2_negotiated = (
  68. ssl_object is not None
  69. and ssl_object.selected_alpn_protocol() == "h2"
  70. )
  71. if http2_negotiated or (self._http2 and not self._http1):
  72. from .http2 import AsyncHTTP2Connection
  73. self._connection = AsyncHTTP2Connection(
  74. origin=self._origin,
  75. stream=stream,
  76. keepalive_expiry=self._keepalive_expiry,
  77. )
  78. else:
  79. self._connection = AsyncHTTP11Connection(
  80. origin=self._origin,
  81. stream=stream,
  82. keepalive_expiry=self._keepalive_expiry,
  83. )
  84. except BaseException as exc:
  85. self._connect_failed = True
  86. raise exc
  87. return await self._connection.handle_async_request(request)
  88. async def _connect(self, request: Request) -> AsyncNetworkStream:
  89. timeouts = request.extensions.get("timeout", {})
  90. sni_hostname = request.extensions.get("sni_hostname", None)
  91. timeout = timeouts.get("connect", None)
  92. retries_left = self._retries
  93. delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
  94. while True:
  95. try:
  96. if self._uds is None:
  97. kwargs = {
  98. "host": self._origin.host.decode("ascii"),
  99. "port": self._origin.port,
  100. "local_address": self._local_address,
  101. "timeout": timeout,
  102. "socket_options": self._socket_options,
  103. }
  104. async with Trace("connect_tcp", logger, request, kwargs) as trace:
  105. stream = await self._network_backend.connect_tcp(**kwargs)
  106. trace.return_value = stream
  107. else:
  108. kwargs = {
  109. "path": self._uds,
  110. "timeout": timeout,
  111. "socket_options": self._socket_options,
  112. }
  113. async with Trace(
  114. "connect_unix_socket", logger, request, kwargs
  115. ) as trace:
  116. stream = await self._network_backend.connect_unix_socket(
  117. **kwargs
  118. )
  119. trace.return_value = stream
  120. if self._origin.scheme in (b"https", b"wss"):
  121. ssl_context = (
  122. default_ssl_context()
  123. if self._ssl_context is None
  124. else self._ssl_context
  125. )
  126. alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
  127. ssl_context.set_alpn_protocols(alpn_protocols)
  128. kwargs = {
  129. "ssl_context": ssl_context,
  130. "server_hostname": sni_hostname
  131. or self._origin.host.decode("ascii"),
  132. "timeout": timeout,
  133. }
  134. async with Trace("start_tls", logger, request, kwargs) as trace:
  135. stream = await stream.start_tls(**kwargs)
  136. trace.return_value = stream
  137. return stream
  138. except (ConnectError, ConnectTimeout):
  139. if retries_left <= 0:
  140. raise
  141. retries_left -= 1
  142. delay = next(delays)
  143. async with Trace("retry", logger, request, kwargs) as trace:
  144. await self._network_backend.sleep(delay)
  145. def can_handle_request(self, origin: Origin) -> bool:
  146. return origin == self._origin
  147. async def aclose(self) -> None:
  148. if self._connection is not None:
  149. async with Trace("close", logger, None, {}):
  150. await self._connection.aclose()
  151. def is_available(self) -> bool:
  152. if self._connection is None:
  153. # If HTTP/2 support is enabled, and the resulting connection could
  154. # end up as HTTP/2 then we should indicate the connection as being
  155. # available to service multiple requests.
  156. return (
  157. self._http2
  158. and (self._origin.scheme == b"https" or not self._http1)
  159. and not self._connect_failed
  160. )
  161. return self._connection.is_available()
  162. def has_expired(self) -> bool:
  163. if self._connection is None:
  164. return self._connect_failed
  165. return self._connection.has_expired()
  166. def is_idle(self) -> bool:
  167. if self._connection is None:
  168. return self._connect_failed
  169. return self._connection.is_idle()
  170. def is_closed(self) -> bool:
  171. if self._connection is None:
  172. return self._connect_failed
  173. return self._connection.is_closed()
  174. def info(self) -> str:
  175. if self._connection is None:
  176. return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
  177. return self._connection.info()
  178. def __repr__(self) -> str:
  179. return f"<{self.__class__.__name__} [{self.info()}]>"
  180. # These context managers are not used in the standard flow, but are
  181. # useful for testing or working with connection instances directly.
  182. async def __aenter__(self) -> AsyncHTTPConnection:
  183. return self
  184. async def __aexit__(
  185. self,
  186. exc_type: type[BaseException] | None = None,
  187. exc_value: BaseException | None = None,
  188. traceback: types.TracebackType | None = None,
  189. ) -> None:
  190. await self.aclose()