| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- from __future__ import annotations
- import itertools
- import logging
- import ssl
- import types
- import typing
- from .._backends.sync import SyncBackend
- from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
- from .._exceptions import ConnectError, ConnectTimeout
- from .._models import Origin, Request, Response
- from .._ssl import default_ssl_context
- from .._synchronization import Lock
- from .._trace import Trace
- from .http11 import HTTP11Connection
- from .interfaces import ConnectionInterface
# Base delay, in seconds, used for connection retries. Fed into
# `exponential_backoff()` it yields the schedule 0s, 0.5s, 1s, 2s, 4s, ...
RETRIES_BACKOFF_FACTOR = 0.5

# Module logger handed to every `Trace` emitted by connection handling.
logger = logging.getLogger(name="httpcore.connection")
def exponential_backoff(factor: float) -> typing.Iterator[float]:
    """
    Yield an endless backoff schedule: 0 first, then `factor` doubling each step.

    After the leading 0, the k-th value (k >= 1) is `factor * 2 ** (k - 1)`.
    For example:

    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
    """
    # The very first attempt is retried immediately.
    yield 0
    power = 0
    while True:
        yield factor * 2**power
        power += 1
class HTTPConnection(ConnectionInterface):
    """
    A single connection to one origin, established lazily on first use.

    The transport (TCP, or a Unix domain socket when `uds` is given,
    upgraded to TLS for `https` / `wss` origins) is opened by the first call
    to `handle_request()`. Requests are then delegated to either an
    `HTTP2Connection` or an `HTTP11Connection`, depending on the ALPN result
    and the `http1` / `http2` flags.
    """

    def __init__(
        self,
        origin: Origin,
        ssl_context: ssl.SSLContext | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        Args:
            origin: The origin this connection is allowed to serve.
            ssl_context: TLS context for `https` / `wss` origins; when None,
                `default_ssl_context()` is used at connect time.
            keepalive_expiry: Passed through to the protocol-level connection.
            http1: Whether HTTP/1.1 may be used.
            http2: Whether HTTP/2 may be used (offered via ALPN over TLS, or
                used directly when `http1` is False).
            retries: Number of times to retry a connect attempt that fails
                with `ConnectError` / `ConnectTimeout`.
            local_address: Local address passed to `connect_tcp`.
            uds: Unix domain socket path; when set it is used instead of TCP.
            network_backend: Backend that opens streams; defaults to
                `SyncBackend()`.
            socket_options: Socket options passed to the backend connect call.
        """
        self._origin = origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        self._network_backend: NetworkBackend = (
            SyncBackend() if network_backend is None else network_backend
        )
        # Protocol-level connection; None until the first request connects.
        self._connection: ConnectionInterface | None = None
        # Set once establishing the connection has raised.
        self._connect_failed: bool = False
        self._request_lock = Lock()
        self._socket_options = socket_options

    def handle_request(self, request: Request) -> Response:
        """
        Send `request`, establishing the connection first if necessary.

        Raises:
            RuntimeError: If the request's origin differs from this
                connection's origin.

        Any exception raised while connecting or while setting up the
        protocol connection marks this connection as failed and is re-raised.
        """
        if not self.can_handle_request(request.url.origin):
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
            )

        try:
            with self._request_lock:
                if self._connection is None:
                    stream = self._connect(request)
                    try:
                        self._connection = self._create_connection(stream)
                    except BaseException:
                        # Don't abandon the freshly opened transport if the
                        # protocol layer could not be set up (e.g. importing
                        # `.http2` fails).
                        stream.close()
                        raise
        except BaseException:
            self._connect_failed = True
            raise

        return self._connection.handle_request(request)

    def _create_connection(self, stream: NetworkStream) -> ConnectionInterface:
        """Wrap an established `stream` in an HTTP/2 or HTTP/1.1 connection."""
        ssl_object = stream.get_extra_info("ssl_object")
        http2_negotiated = (
            ssl_object is not None
            and ssl_object.selected_alpn_protocol() == "h2"
        )
        # HTTP/2 is used when ALPN selected it, or when it is the only
        # protocol enabled (e.g. prior knowledge over plain TCP).
        if http2_negotiated or (self._http2 and not self._http1):
            # Imported lazily, presumably so the HTTP/2 module is only loaded
            # when an HTTP/2 connection is actually needed.
            from .http2 import HTTP2Connection

            return HTTP2Connection(
                origin=self._origin,
                stream=stream,
                keepalive_expiry=self._keepalive_expiry,
            )
        return HTTP11Connection(
            origin=self._origin,
            stream=stream,
            keepalive_expiry=self._keepalive_expiry,
        )

    def _connect(self, request: Request) -> NetworkStream:
        """
        Open the network stream for this connection, retrying on failure.

        Connects over TCP (or to the Unix domain socket when `uds` was given),
        then starts TLS for `https` / `wss` origins. A `ConnectError` or
        `ConnectTimeout` is retried up to `self._retries` times, sleeping per
        `exponential_backoff()` between attempts; once retries are exhausted
        the error is re-raised.
        """
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    kwargs = {
                        "host": self._origin.host.decode("ascii"),
                        "port": self._origin.port,
                        "local_address": self._local_address,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace("connect_tcp", logger, request, kwargs) as trace:
                        stream = self._network_backend.connect_tcp(**kwargs)
                        trace.return_value = stream
                else:
                    kwargs = {
                        "path": self._uds,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace(
                        "connect_unix_socket", logger, request, kwargs
                    ) as trace:
                        stream = self._network_backend.connect_unix_socket(
                            **kwargs
                        )
                        trace.return_value = stream

                if self._origin.scheme in (b"https", b"wss"):
                    ssl_context = (
                        default_ssl_context()
                        if self._ssl_context is None
                        else self._ssl_context
                    )
                    # Only offer "h2" via ALPN when HTTP/2 is enabled.
                    # NOTE: this mutates a caller-supplied `ssl_context`.
                    alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                    ssl_context.set_alpn_protocols(alpn_protocols)

                    kwargs = {
                        "ssl_context": ssl_context,
                        # The "sni_hostname" request extension, when set,
                        # overrides the origin host for SNI.
                        "server_hostname": sni_hostname
                        or self._origin.host.decode("ascii"),
                        "timeout": timeout,
                    }
                    # NOTE(review): if start_tls fails, closing the plain
                    # stream is left to the network backend — confirm.
                    with Trace("start_tls", logger, request, kwargs) as trace:
                        stream = stream.start_tls(**kwargs)
                        trace.return_value = stream
                return stream
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    raise
                retries_left -= 1
                delay = next(delays)
                # `kwargs` is the dict built by whichever step failed last
                # (connect or TLS), so the retry trace reports those arguments.
                with Trace("retry", logger, request, kwargs):
                    self._network_backend.sleep(delay)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True if `origin` is exactly this connection's origin."""
        return origin == self._origin

    def close(self) -> None:
        """Close the protocol-level connection, if one was established."""
        if self._connection is not None:
            with Trace("close", logger, None, {}):
                self._connection.close()

    def is_available(self) -> bool:
        """Return whether this connection can accept another request now."""
        if self._connection is None:
            # If HTTP/2 support is enabled, and the resulting connection could
            # end up as HTTP/2 then we should indicate the connection as being
            # available to service multiple requests.
            return (
                self._http2
                and (self._origin.scheme == b"https" or not self._http1)
                and not self._connect_failed
            )
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Before connecting, True only if the connect attempt failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Before connecting, True only if the connect attempt failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Before connecting, True only if the connect attempt failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:
        """Return a short human-readable status string."""
        if self._connection is None:
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> HTTPConnection:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
|