http2.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. from __future__ import annotations
  2. import enum
  3. import logging
  4. import time
  5. import types
  6. import typing
  7. import h2.config
  8. import h2.connection
  9. import h2.events
  10. import h2.exceptions
  11. import h2.settings
  12. from .._backends.base import NetworkStream
  13. from .._exceptions import (
  14. ConnectionNotAvailable,
  15. LocalProtocolError,
  16. RemoteProtocolError,
  17. )
  18. from .._models import Origin, Request, Response
  19. from .._synchronization import Lock, Semaphore, ShieldCancellation
  20. from .._trace import Trace
  21. from .interfaces import ConnectionInterface
# Module-level logger; it is the logger handed to every `Trace(...)` call below.
logger = logging.getLogger("httpcore.http2")
  23. def has_body_headers(request: Request) -> bool:
  24. return any(
  25. k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
  26. for k, v in request.headers
  27. )
  28. class HTTPConnectionState(enum.IntEnum):
  29. ACTIVE = 1
  30. IDLE = 2
  31. CLOSED = 3
  32. class HTTP2Connection(ConnectionInterface):
  33. READ_NUM_BYTES = 64 * 1024
  34. CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
  35. def __init__(
  36. self,
  37. origin: Origin,
  38. stream: NetworkStream,
  39. keepalive_expiry: float | None = None,
  40. ):
  41. self._origin = origin
  42. self._network_stream = stream
  43. self._keepalive_expiry: float | None = keepalive_expiry
  44. self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
  45. self._state = HTTPConnectionState.IDLE
  46. self._expire_at: float | None = None
  47. self._request_count = 0
  48. self._init_lock = Lock()
  49. self._state_lock = Lock()
  50. self._read_lock = Lock()
  51. self._write_lock = Lock()
  52. self._sent_connection_init = False
  53. self._used_all_stream_ids = False
  54. self._connection_error = False
  55. # Mapping from stream ID to response stream events.
  56. self._events: dict[
  57. int,
  58. list[
  59. h2.events.ResponseReceived
  60. | h2.events.DataReceived
  61. | h2.events.StreamEnded
  62. | h2.events.StreamReset,
  63. ],
  64. ] = {}
  65. # Connection terminated events are stored as state since
  66. # we need to handle them for all streams.
  67. self._connection_terminated: h2.events.ConnectionTerminated | None = None
  68. self._read_exception: Exception | None = None
  69. self._write_exception: Exception | None = None
  70. def handle_request(self, request: Request) -> Response:
  71. if not self.can_handle_request(request.url.origin):
  72. # This cannot occur in normal operation, since the connection pool
  73. # will only send requests on connections that handle them.
  74. # It's in place simply for resilience as a guard against incorrect
  75. # usage, for anyone working directly with httpcore connections.
  76. raise RuntimeError(
  77. f"Attempted to send request to {request.url.origin} on connection "
  78. f"to {self._origin}"
  79. )
  80. with self._state_lock:
  81. if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
  82. self._request_count += 1
  83. self._expire_at = None
  84. self._state = HTTPConnectionState.ACTIVE
  85. else:
  86. raise ConnectionNotAvailable()
  87. with self._init_lock:
  88. if not self._sent_connection_init:
  89. try:
  90. sci_kwargs = {"request": request}
  91. with Trace(
  92. "send_connection_init", logger, request, sci_kwargs
  93. ):
  94. self._send_connection_init(**sci_kwargs)
  95. except BaseException as exc:
  96. with ShieldCancellation():
  97. self.close()
  98. raise exc
  99. self._sent_connection_init = True
  100. # Initially start with just 1 until the remote server provides
  101. # its max_concurrent_streams value
  102. self._max_streams = 1
  103. local_settings_max_streams = (
  104. self._h2_state.local_settings.max_concurrent_streams
  105. )
  106. self._max_streams_semaphore = Semaphore(local_settings_max_streams)
  107. for _ in range(local_settings_max_streams - self._max_streams):
  108. self._max_streams_semaphore.acquire()
  109. self._max_streams_semaphore.acquire()
  110. try:
  111. stream_id = self._h2_state.get_next_available_stream_id()
  112. self._events[stream_id] = []
  113. except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
  114. self._used_all_stream_ids = True
  115. self._request_count -= 1
  116. raise ConnectionNotAvailable()
  117. try:
  118. kwargs = {"request": request, "stream_id": stream_id}
  119. with Trace("send_request_headers", logger, request, kwargs):
  120. self._send_request_headers(request=request, stream_id=stream_id)
  121. with Trace("send_request_body", logger, request, kwargs):
  122. self._send_request_body(request=request, stream_id=stream_id)
  123. with Trace(
  124. "receive_response_headers", logger, request, kwargs
  125. ) as trace:
  126. status, headers = self._receive_response(
  127. request=request, stream_id=stream_id
  128. )
  129. trace.return_value = (status, headers)
  130. return Response(
  131. status=status,
  132. headers=headers,
  133. content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
  134. extensions={
  135. "http_version": b"HTTP/2",
  136. "network_stream": self._network_stream,
  137. "stream_id": stream_id,
  138. },
  139. )
  140. except BaseException as exc: # noqa: PIE786
  141. with ShieldCancellation():
  142. kwargs = {"stream_id": stream_id}
  143. with Trace("response_closed", logger, request, kwargs):
  144. self._response_closed(stream_id=stream_id)
  145. if isinstance(exc, h2.exceptions.ProtocolError):
  146. # One case where h2 can raise a protocol error is when a
  147. # closed frame has been seen by the state machine.
  148. #
  149. # This happens when one stream is reading, and encounters
  150. # a GOAWAY event. Other flows of control may then raise
  151. # a protocol error at any point they interact with the 'h2_state'.
  152. #
  153. # In this case we'll have stored the event, and should raise
  154. # it as a RemoteProtocolError.
  155. if self._connection_terminated: # pragma: nocover
  156. raise RemoteProtocolError(self._connection_terminated)
  157. # If h2 raises a protocol error in some other state then we
  158. # must somehow have made a protocol violation.
  159. raise LocalProtocolError(exc) # pragma: nocover
  160. raise exc
  161. def _send_connection_init(self, request: Request) -> None:
  162. """
  163. The HTTP/2 connection requires some initial setup before we can start
  164. using individual request/response streams on it.
  165. """
  166. # Need to set these manually here instead of manipulating via
  167. # __setitem__() otherwise the H2Connection will emit SettingsUpdate
  168. # frames in addition to sending the undesired defaults.
  169. self._h2_state.local_settings = h2.settings.Settings(
  170. client=True,
  171. initial_values={
  172. # Disable PUSH_PROMISE frames from the server since we don't do anything
  173. # with them for now. Maybe when we support caching?
  174. h2.settings.SettingCodes.ENABLE_PUSH: 0,
  175. # These two are taken from h2 for safe defaults
  176. h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
  177. h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
  178. },
  179. )
  180. # Some websites (*cough* Yahoo *cough*) balk at this setting being
  181. # present in the initial handshake since it's not defined in the original
  182. # RFC despite the RFC mandating ignoring settings you don't know about.
  183. del self._h2_state.local_settings[
  184. h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
  185. ]
  186. self._h2_state.initiate_connection()
  187. self._h2_state.increment_flow_control_window(2**24)
  188. self._write_outgoing_data(request)
  189. # Sending the request...
  190. def _send_request_headers(self, request: Request, stream_id: int) -> None:
  191. """
  192. Send the request headers to a given stream ID.
  193. """
  194. end_stream = not has_body_headers(request)
  195. # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
  196. # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
  197. # HTTP/1.1 style headers, and map them appropriately if we end up on
  198. # an HTTP/2 connection.
  199. authority = [v for k, v in request.headers if k.lower() == b"host"][0]
  200. headers = [
  201. (b":method", request.method),
  202. (b":authority", authority),
  203. (b":scheme", request.url.scheme),
  204. (b":path", request.url.target),
  205. ] + [
  206. (k.lower(), v)
  207. for k, v in request.headers
  208. if k.lower()
  209. not in (
  210. b"host",
  211. b"transfer-encoding",
  212. )
  213. ]
  214. self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
  215. self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
  216. self._write_outgoing_data(request)
  217. def _send_request_body(self, request: Request, stream_id: int) -> None:
  218. """
  219. Iterate over the request body sending it to a given stream ID.
  220. """
  221. if not has_body_headers(request):
  222. return
  223. assert isinstance(request.stream, typing.Iterable)
  224. for data in request.stream:
  225. self._send_stream_data(request, stream_id, data)
  226. self._send_end_stream(request, stream_id)
  227. def _send_stream_data(
  228. self, request: Request, stream_id: int, data: bytes
  229. ) -> None:
  230. """
  231. Send a single chunk of data in one or more data frames.
  232. """
  233. while data:
  234. max_flow = self._wait_for_outgoing_flow(request, stream_id)
  235. chunk_size = min(len(data), max_flow)
  236. chunk, data = data[:chunk_size], data[chunk_size:]
  237. self._h2_state.send_data(stream_id, chunk)
  238. self._write_outgoing_data(request)
    def _send_end_stream(self, request: Request, stream_id: int) -> None:
        """
        Send an empty data frame on a given stream ID with the END_STREAM flag set.

        `request` is only used to look up the write timeout.
        """
        self._h2_state.end_stream(stream_id)
        # Flush whatever h2 queued for the call above to the network.
        self._write_outgoing_data(request)
  245. # Receiving the response...
  246. def _receive_response(
  247. self, request: Request, stream_id: int
  248. ) -> tuple[int, list[tuple[bytes, bytes]]]:
  249. """
  250. Return the response status code and headers for a given stream ID.
  251. """
  252. while True:
  253. event = self._receive_stream_event(request, stream_id)
  254. if isinstance(event, h2.events.ResponseReceived):
  255. break
  256. status_code = 200
  257. headers = []
  258. assert event.headers is not None
  259. for k, v in event.headers:
  260. if k == b":status":
  261. status_code = int(v.decode("ascii", errors="ignore"))
  262. elif not k.startswith(b":"):
  263. headers.append((k, v))
  264. return (status_code, headers)
  265. def _receive_response_body(
  266. self, request: Request, stream_id: int
  267. ) -> typing.Iterator[bytes]:
  268. """
  269. Iterator that returns the bytes of the response body for a given stream ID.
  270. """
  271. while True:
  272. event = self._receive_stream_event(request, stream_id)
  273. if isinstance(event, h2.events.DataReceived):
  274. assert event.flow_controlled_length is not None
  275. assert event.data is not None
  276. amount = event.flow_controlled_length
  277. self._h2_state.acknowledge_received_data(amount, stream_id)
  278. self._write_outgoing_data(request)
  279. yield event.data
  280. elif isinstance(event, h2.events.StreamEnded):
  281. break
  282. def _receive_stream_event(
  283. self, request: Request, stream_id: int
  284. ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
  285. """
  286. Return the next available event for a given stream ID.
  287. Will read more data from the network if required.
  288. """
  289. while not self._events.get(stream_id):
  290. self._receive_events(request, stream_id)
  291. event = self._events[stream_id].pop(0)
  292. if isinstance(event, h2.events.StreamReset):
  293. raise RemoteProtocolError(event)
  294. return event
    def _receive_events(
        self, request: Request, stream_id: int | None = None
    ) -> None:
        """
        Read from the network at most once and dispatch the resulting events,
        then flush any outgoing data.

        When `stream_id` is given and that stream already has queued events,
        the read is skipped. When `stream_id` is None a read always happens.

        Raises:
            ConnectionNotAvailable: a connection-terminated event is stored
                and `stream_id` is greater than its `last_stream_id`.
            RemoteProtocolError: a connection-terminated event is stored,
                in every other case.
        """
        with self._read_lock:
            if self._connection_terminated is not None:
                last_stream_id = self._connection_terminated.last_stream_id
                # NOTE(review): a `last_stream_id` of 0 is falsy here, so it
                # takes the RemoteProtocolError path below - confirm intended.
                if stream_id and last_stream_id and stream_id > last_stream_id:
                    self._request_count -= 1
                    raise ConnectionNotAvailable()
                raise RemoteProtocolError(self._connection_terminated)

            # This conditional is a bit icky. We don't want to block reading if we've
            # actually got an event to return for a given stream. We need to do that
            # check *within* the atomic read lock. Though it also needs to be optional,
            # because when we call it from `_wait_for_outgoing_flow` we *do* want to
            # block until we have available flow control, even when we have events
            # pending for the stream ID we're attempting to send on.
            if stream_id is None or not self._events.get(stream_id):
                events = self._read_incoming_data(request)
                for event in events:
                    if isinstance(event, h2.events.RemoteSettingsChanged):
                        with Trace(
                            "receive_remote_settings", logger, request
                        ) as trace:
                            self._receive_remote_settings_change(event)
                            trace.return_value = event

                    elif isinstance(
                        event,
                        (
                            h2.events.ResponseReceived,
                            h2.events.DataReceived,
                            h2.events.StreamEnded,
                            h2.events.StreamReset,
                        ),
                    ):
                        # Events for streams we are no longer tracking are dropped.
                        if event.stream_id in self._events:
                            self._events[event.stream_id].append(event)

                    elif isinstance(event, h2.events.ConnectionTerminated):
                        # Stored so that every later call raises (see top).
                        self._connection_terminated = event

        # Done outside the read lock; the write path takes its own lock.
        self._write_outgoing_data(request)
  338. def _receive_remote_settings_change(
  339. self, event: h2.events.RemoteSettingsChanged
  340. ) -> None:
  341. max_concurrent_streams = event.changed_settings.get(
  342. h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
  343. )
  344. if max_concurrent_streams:
  345. new_max_streams = min(
  346. max_concurrent_streams.new_value,
  347. self._h2_state.local_settings.max_concurrent_streams,
  348. )
  349. if new_max_streams and new_max_streams != self._max_streams:
  350. while new_max_streams > self._max_streams:
  351. self._max_streams_semaphore.release()
  352. self._max_streams += 1
  353. while new_max_streams < self._max_streams:
  354. self._max_streams_semaphore.acquire()
  355. self._max_streams -= 1
  356. def _response_closed(self, stream_id: int) -> None:
  357. self._max_streams_semaphore.release()
  358. del self._events[stream_id]
  359. with self._state_lock:
  360. if self._connection_terminated and not self._events:
  361. self.close()
  362. elif self._state == HTTPConnectionState.ACTIVE and not self._events:
  363. self._state = HTTPConnectionState.IDLE
  364. if self._keepalive_expiry is not None:
  365. now = time.monotonic()
  366. self._expire_at = now + self._keepalive_expiry
  367. if self._used_all_stream_ids: # pragma: nocover
  368. self.close()
    def close(self) -> None:
        """
        Close the h2 connection, mark this connection CLOSED and close the
        underlying network stream.
        """
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._h2_state.close_connection()
        self._state = HTTPConnectionState.CLOSED
        self._network_stream.close()
  375. # Wrappers around network read/write operations...
  376. def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
  377. timeouts = request.extensions.get("timeout", {})
  378. timeout = timeouts.get("read", None)
  379. if self._read_exception is not None:
  380. raise self._read_exception # pragma: nocover
  381. try:
  382. data = self._network_stream.read(self.READ_NUM_BYTES, timeout)
  383. if data == b"":
  384. raise RemoteProtocolError("Server disconnected")
  385. except Exception as exc:
  386. # If we get a network error we should:
  387. #
  388. # 1. Save the exception and just raise it immediately on any future reads.
  389. # (For example, this means that a single read timeout or disconnect will
  390. # immediately close all pending streams. Without requiring multiple
  391. # sequential timeouts.)
  392. # 2. Mark the connection as errored, so that we don't accept any other
  393. # incoming requests.
  394. self._read_exception = exc
  395. self._connection_error = True
  396. raise exc
  397. events: list[h2.events.Event] = self._h2_state.receive_data(data)
  398. return events
  399. def _write_outgoing_data(self, request: Request) -> None:
  400. timeouts = request.extensions.get("timeout", {})
  401. timeout = timeouts.get("write", None)
  402. with self._write_lock:
  403. data_to_send = self._h2_state.data_to_send()
  404. if self._write_exception is not None:
  405. raise self._write_exception # pragma: nocover
  406. try:
  407. self._network_stream.write(data_to_send, timeout)
  408. except Exception as exc: # pragma: nocover
  409. # If we get a network error we should:
  410. #
  411. # 1. Save the exception and just raise it immediately on any future write.
  412. # (For example, this means that a single write timeout or disconnect will
  413. # immediately close all pending streams. Without requiring multiple
  414. # sequential timeouts.)
  415. # 2. Mark the connection as errored, so that we don't accept any other
  416. # incoming requests.
  417. self._write_exception = exc
  418. self._connection_error = True
  419. raise exc
  420. # Flow control...
  421. def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
  422. """
  423. Returns the maximum allowable outgoing flow for a given stream.
  424. If the allowable flow is zero, then waits on the network until
  425. WindowUpdated frames have increased the flow rate.
  426. https://tools.ietf.org/html/rfc7540#section-6.9
  427. """
  428. local_flow: int = self._h2_state.local_flow_control_window(stream_id)
  429. max_frame_size: int = self._h2_state.max_outbound_frame_size
  430. flow = min(local_flow, max_frame_size)
  431. while flow == 0:
  432. self._receive_events(request)
  433. local_flow = self._h2_state.local_flow_control_window(stream_id)
  434. max_frame_size = self._h2_state.max_outbound_frame_size
  435. flow = min(local_flow, max_frame_size)
  436. return flow
  437. # Interface for connection pooling...
    def can_handle_request(self, origin: Origin) -> bool:
        """Return True when `origin` equals the origin this connection was created for."""
        return origin == self._origin
  440. def is_available(self) -> bool:
  441. return (
  442. self._state != HTTPConnectionState.CLOSED
  443. and not self._connection_error
  444. and not self._used_all_stream_ids
  445. and not (
  446. self._h2_state.state_machine.state
  447. == h2.connection.ConnectionState.CLOSED
  448. )
  449. )
  450. def has_expired(self) -> bool:
  451. now = time.monotonic()
  452. return self._expire_at is not None and now > self._expire_at
    def is_idle(self) -> bool:
        """Return True when the connection state is IDLE."""
        return self._state == HTTPConnectionState.IDLE
    def is_closed(self) -> bool:
        """Return True when the connection state is CLOSED."""
        return self._state == HTTPConnectionState.CLOSED
  457. def info(self) -> str:
  458. origin = str(self._origin)
  459. return (
  460. f"{origin!r}, HTTP/2, {self._state.name}, "
  461. f"Request Count: {self._request_count}"
  462. )
  463. def __repr__(self) -> str:
  464. class_name = self.__class__.__name__
  465. origin = str(self._origin)
  466. return (
  467. f"<{class_name} [{origin!r}, {self._state.name}, "
  468. f"Request Count: {self._request_count}]>"
  469. )
  470. # These context managers are not used in the standard flow, but are
  471. # useful for testing or working with connection instances directly.
    def __enter__(self) -> HTTP2Connection:
        """Enter a `with` block; returns the connection itself."""
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        """
        Leave a `with` block by closing the connection.

        The exception arguments are not inspected, and nothing is returned,
        so an exception raised inside the block is not suppressed.
        """
        self.close()
  481. class HTTP2ConnectionByteStream:
  482. def __init__(
  483. self, connection: HTTP2Connection, request: Request, stream_id: int
  484. ) -> None:
  485. self._connection = connection
  486. self._request = request
  487. self._stream_id = stream_id
  488. self._closed = False
  489. def __iter__(self) -> typing.Iterator[bytes]:
  490. kwargs = {"request": self._request, "stream_id": self._stream_id}
  491. try:
  492. with Trace("receive_response_body", logger, self._request, kwargs):
  493. for chunk in self._connection._receive_response_body(
  494. request=self._request, stream_id=self._stream_id
  495. ):
  496. yield chunk
  497. except BaseException as exc:
  498. # If we get an exception while streaming the response,
  499. # we want to close the response (and possibly the connection)
  500. # before raising that exception.
  501. with ShieldCancellation():
  502. self.close()
  503. raise exc
  504. def close(self) -> None:
  505. if not self._closed:
  506. self._closed = True
  507. kwargs = {"stream_id": self._stream_id}
  508. with Trace("response_closed", logger, self._request, kwargs):
  509. self._connection._response_closed(stream_id=self._stream_id)