_asyncio.py 87 KB


  1. from __future__ import annotations
  2. import array
  3. import asyncio
  4. import concurrent.futures
  5. import math
  6. import os
  7. import socket
  8. import sys
  9. import threading
  10. import weakref
  11. from asyncio import (
  12. AbstractEventLoop,
  13. CancelledError,
  14. all_tasks,
  15. create_task,
  16. current_task,
  17. get_running_loop,
  18. sleep,
  19. )
  20. from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
  21. from collections import OrderedDict, deque
  22. from collections.abc import AsyncIterator, Iterable
  23. from concurrent.futures import Future
  24. from contextlib import suppress
  25. from contextvars import Context, copy_context
  26. from dataclasses import dataclass
  27. from functools import partial, wraps
  28. from inspect import (
  29. CORO_RUNNING,
  30. CORO_SUSPENDED,
  31. getcoroutinestate,
  32. iscoroutine,
  33. )
  34. from io import IOBase
  35. from os import PathLike
  36. from queue import Queue
  37. from signal import Signals
  38. from socket import AddressFamily, SocketKind
  39. from threading import Thread
  40. from types import TracebackType
  41. from typing import (
  42. IO,
  43. Any,
  44. AsyncGenerator,
  45. Awaitable,
  46. Callable,
  47. Collection,
  48. ContextManager,
  49. Coroutine,
  50. Optional,
  51. Sequence,
  52. Tuple,
  53. TypeVar,
  54. cast,
  55. )
  56. from weakref import WeakKeyDictionary
  57. import sniffio
  58. from .. import (
  59. CapacityLimiterStatistics,
  60. EventStatistics,
  61. LockStatistics,
  62. TaskInfo,
  63. abc,
  64. )
  65. from .._core._eventloop import claim_worker_thread, threadlocals
  66. from .._core._exceptions import (
  67. BrokenResourceError,
  68. BusyResourceError,
  69. ClosedResourceError,
  70. EndOfStream,
  71. WouldBlock,
  72. iterate_exceptions,
  73. )
  74. from .._core._sockets import convert_ipv6_sockaddr
  75. from .._core._streams import create_memory_object_stream
  76. from .._core._synchronization import (
  77. CapacityLimiter as BaseCapacityLimiter,
  78. )
  79. from .._core._synchronization import Event as BaseEvent
  80. from .._core._synchronization import Lock as BaseLock
  81. from .._core._synchronization import (
  82. ResourceGuard,
  83. SemaphoreStatistics,
  84. )
  85. from .._core._synchronization import Semaphore as BaseSemaphore
  86. from .._core._tasks import CancelScope as BaseCancelScope
  87. from ..abc import (
  88. AsyncBackend,
  89. IPSockAddrType,
  90. SocketListener,
  91. UDPPacketType,
  92. UNIXDatagramPacketType,
  93. )
  94. from ..abc._eventloop import StrOrBytesPath
  95. from ..lowlevel import RunVar
  96. from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
  97. if sys.version_info >= (3, 10):
  98. from typing import ParamSpec
  99. else:
  100. from typing_extensions import ParamSpec
if sys.version_info >= (3, 11):
    from asyncio import Runner
    from typing import TypeVarTuple, Unpack
else:
    # Python < 3.11: asyncio.Runner is not available there, so a copy of the CPython
    # 3.11 implementation (plus the private helpers it relies on) is defined below.
    import contextvars
    import enum
    import signal
    from asyncio import coroutines, events, exceptions, tasks

    # BaseExceptionGroup is a builtin on 3.11+; older versions need the backport
    from exceptiongroup import BaseExceptionGroup
    from typing_extensions import TypeVarTuple, Unpack

    class _State(enum.Enum):
        """Lifecycle states of :class:`Runner`."""

        CREATED = "created"
        INITIALIZED = "initialized"
        CLOSED = "closed"

    class Runner:
        """
        Backport of :class:`asyncio.Runner` for Python versions older than 3.11.

        The event loop is created lazily (by ``__enter__``, :meth:`get_loop` or
        :meth:`run`) and torn down by :meth:`close`; a closed runner cannot be
        reused.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            self._context = None
            # Number of SIGINTs received during the current run() call
            self._interrupt_count = 0
            # Whether this runner installed its loop via events.set_event_loop()
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop."""
            # No-op unless a loop has been created and not closed yet
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # Event loops of older Python versions may lack
                # shutdown_default_executor(); use the local fallback then
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: contextvars.Context | None = None,
        ) -> T_Retval:
            """
            Run a coroutine inside the embedded event loop.

            :param coro: the coroutine to run as the main task
            :param context: context to create the task in; defaults to the context
                copied when the runner was initialized
            :return: the return value of ``coro``
            :raises ValueError: if ``coro`` is not a coroutine
            :raises RuntimeError: if an event loop is already running in this
                thread, or if the runner has been closed
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Only take over SIGINT in the main thread, and only if nobody replaced
            # the default handler
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # NOTE(review): Task.uncancel() only exists on Python 3.11+,
                    # while this backport runs on older versions; with stock tasks
                    # `uncancel` is None here, so the CancelledError below is
                    # re-raised instead of KeyboardInterrupt — confirm intended.
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Restore the default handler, but only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            # Create and configure the event loop on first use
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            # Default context for tasks created by run()
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(
            self, signum: int, frame: Any, main_task: asyncio.Task
        ) -> None:
            # SIGINT handler installed by run(): the first interrupt cancels the
            # main task; later ones (or one after the task finished) raise
            # KeyboardInterrupt directly
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()

    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """
        Cancel every task of ``loop`` and run the loop until they have finished.

        Tasks that ended with an exception other than cancellation are reported to
        the loop's exception handler.
        """
        to_cancel = tasks.all_tasks(loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": task.exception(),
                        "task": task,
                    }
                )

    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """Schedule the shutdown of the default executor."""
        # Fallback for loops without shutdown_default_executor(); relies on the
        # private loop attributes _default_executor and _executor_shutdown_called

        def _do_shutdown(future: asyncio.futures.Future) -> None:
            # Runs in a helper thread so the blocking shutdown(wait=True) does not
            # block the event loop; the outcome is reported back thread-safely
            try:
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
                loop.call_soon_threadsafe(future.set_result, None)
            except Exception as ex:
                loop.call_soon_threadsafe(future.set_exception, ex)

        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return
        future = loop.create_future()
        thread = threading.Thread(target=_do_shutdown, args=(future,))
        thread.start()
        try:
            await future
        finally:
            thread.join()
# Type variables for the generic signatures in this module
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")
# Cached root task; read and set by find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
  271. def find_root_task() -> asyncio.Task:
  272. root_task = _root_task.get(None)
  273. if root_task is not None and not root_task.done():
  274. return root_task
  275. # Look for a task that has been started via run_until_complete()
  276. for task in all_tasks():
  277. if task._callbacks and not task.done():
  278. callbacks = [cb for cb, context in task._callbacks]
  279. for cb in callbacks:
  280. if (
  281. cb is _run_until_complete_cb
  282. or getattr(cb, "__module__", None) == "uvloop.loop"
  283. ):
  284. _root_task.set(task)
  285. return task
  286. # Look up the topmost task in the AnyIO task tree, if possible
  287. task = cast(asyncio.Task, current_task())
  288. state = _task_states.get(task)
  289. if state:
  290. cancel_scope = state.cancel_scope
  291. while cancel_scope and cancel_scope._parent_scope is not None:
  292. cancel_scope = cancel_scope._parent_scope
  293. if cancel_scope is not None:
  294. return cast(asyncio.Task, cancel_scope._host_task)
  295. return task
  296. def get_callable_name(func: Callable) -> str:
  297. module = getattr(func, "__module__", None)
  298. qualname = getattr(func, "__qualname__", None)
  299. return ".".join([x for x in (module, qualname) if x])
#
# Event loop
#

# Per-event-loop storage whose loop keys are held weakly, so an entry disappears
# once its loop is garbage collected. NOTE(review): no reader is visible in this
# part of the file; presumably backs RunVar storage — confirm.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
  304. def _task_started(task: asyncio.Task) -> bool:
  305. """Return ``True`` if the task has been started and has not finished."""
  306. try:
  307. return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
  308. except AttributeError:
  309. # task coro is async_genenerator_asend https://bugs.python.org/issue37771
  310. raise Exception(f"Cannot determine if task {task} has started or not") from None
#
# Timeouts and cancellation
#


class CancelScope(BaseCancelScope):
    """
    asyncio implementation of a cancel scope.

    The scope is entered in a *host task*. ``_tasks`` holds the tasks governed
    directly by this scope (the host task while the scope is active, plus tasks
    that task groups add to it) and ``_child_scopes`` holds the nested scopes.
    Cancelling the scope calls ``Task.cancel()`` on eligible tasks here and in
    non-shielded child scopes, rescheduling itself with ``call_soon()`` while any
    task is still eligible. On exit, a ``CancelledError`` attributable to this
    scope is swallowed.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Allocate via object.__new__() directly, bypassing the base class's __new__
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Pending deadline check scheduled by _timeout()
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending retry of _deliver_cancellation(), if any
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls made on behalf of this scope
        self._cancel_calls: int = 0
        # Host task's cancelling() count when the scope was entered (3.11+ only)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # No bookkeeping for this task yet: create it with this scope as current
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest under the task's current scope and take the host task over
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        # Cancel immediately if the deadline already passed, else schedule a check
        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Deactivate the scope and hand the host task back to the parent scope.

        :return: ``True`` if a ``CancelledError`` attributed to this scope was
            caught (suppressing the exception), ``False`` if the scope was
            cancelled but the exception is not attributed to it, ``None`` otherwise
        :raises RuntimeError: if the scope is not active, is exited in a different
            task, or is not the host task's current scope
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Give the host task back to the parent scope and make the parent current
        self._tasks.remove(self._host_task)
        if self._parent_scope is not None:
            self._parent_scope._child_scopes.remove(self)
            self._parent_scope._tasks.add(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the closest directly cancelled parent
        # scope if this one was shielded
        self._restart_cancellation_in_parent()

        if self._cancel_called and exc_val is not None:
            # iterate_exceptions() presumably yields the leaf exceptions of an
            # exception group; stop at the first CancelledError this scope owns
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Decide whether ``cancelled_exc`` was caused by this scope, undoing this
        scope's cancel requests on the host task where the runtime supports it.

        :return: ``True`` if the cancellation should be swallowed by this scope
        """
        # Python < 3.9 (no cancel messages) or no host task: assume it is ours
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            # Python 3.11+: one uncancel() per cancel() issued by this scope; once
            # the count is back at the value seen on entry, the cancellation was
            # entirely ours
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        # Otherwise recognize the cancel message set by _deliver_cancellation()
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        # Cancel now if the deadline has passed, otherwise re-check at the deadline;
        # an infinite deadline needs no check at all
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks whose asyncio-internal _must_cancel flag is set
            # (presumably a cancel request is already pending for them)
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            should_retry = True
            # Never cancel the calling task from here; other tasks only once they
            # have started running (the host task always qualifies)
            if task is not current and (task is self._host_task or _task_started(task)):
                # Skip tasks whose awaited future has already completed
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    # Counted on the originating scope so it can undo the request
                    # in _uncancel()
                    origin._cancel_calls += 1
                    if sys.version_info >= (3, 9):
                        task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    else:
                        task.cancel()

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.
        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if that scope has no delivery retry pending already
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def _parent_cancelled(self) -> bool:
        # Check whether any parent has been cancelled
        # (the search stops at the first shielded scope)
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """
        Cancel this scope; repeated calls have no further effect.

        If the scope has not been entered yet, delivery starts in ``__enter__``.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        # Replace any scheduled deadline check; re-arm it only while the scope is
        # active and not cancelled yet
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        # Dropping the shield lets a pending cancellation of a parent scope reach
        # this scope's tasks
        if self._shield != value:
            self._shield = value
            if not value:
                self._restart_cancellation_in_parent()
  512. #
  513. # Task states
  514. #
  515. class TaskState:
  516. """
  517. Encapsulates auxiliary task information that cannot be added to the Task instance
  518. itself because there are no guarantees about its implementation.
  519. """
  520. __slots__ = "parent_id", "cancel_scope", "__weakref__"
  521. def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
  522. self.parent_id = parent_id
  523. self.cancel_scope = cancel_scope
  524. _task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
  525. #
  526. # Task groups
  527. #
  528. class _AsyncioTaskStatus(abc.TaskStatus):
  529. def __init__(self, future: asyncio.Future, parent_id: int):
  530. self._future = future
  531. self._parent_id = parent_id
  532. def started(self, value: T_contra | None = None) -> None:
  533. try:
  534. self._future.set_result(value)
  535. except asyncio.InvalidStateError:
  536. if not self._future.cancelled():
  537. raise RuntimeError(
  538. "called 'started' twice on the same task status"
  539. ) from None
  540. task = cast(asyncio.Task, current_task())
  541. _task_states[task].parent_id = self._parent_id
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of a task group.

    Child tasks run in the group's own cancel scope. When a child ends with an
    exception, the exception is recorded (unless it is a ``CancelledError``) and the
    group's scope is cancelled unless a parent scope already is. On exit the group
    waits for every child and raises the recorded exceptions as a
    ``BaseExceptionGroup``.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # Whether new tasks may be spawned (checked by _spawn())
        self._active = False
        # Exceptions recorded from child tasks and from the group's own body
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish, then propagate failures.

        :raises BaseExceptionGroup: if any child task, or the body of the
            ``async with`` block, raised an exception other than ``CancelledError``
        :raises CancelledError: if the host task was cancelled natively while
            waiting for the children (see the comments below)
        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The body failed (or was cancelled): cancel the remaining children
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self._tasks:
            try:
                await asyncio.wait(self._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` in this group's cancel scope.

        :param name: task name; defaults to the dotted name of ``func``
        :param task_status_future: when given (by :meth:`start`), ``func`` also
            receives a ``task_status`` keyword argument whose ``started()``
            resolves this future; the future fails if the child exits first
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done-callback: detach the finished task from the group and its scope,
            # then route its outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Walk back to the earliest CancelledError in the __context__ chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._parent_cancelled():
                        self.cancel_scope.cancel()
                else:
                    # Failed before calling started(): hand the error to start()
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called, the spawning task counts as the parent;
            # the task status object then switches it to the group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start a child task and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
#
# Threads
#

# (return value, exception) pair type. NOTE(review): presumably exactly one side is
# meaningful per result; its users are not visible here — confirm.
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
  680. class WorkerThread(Thread):
  681. MAX_IDLE_TIME = 10 # seconds
  682. def __init__(
  683. self,
  684. root_task: asyncio.Task,
  685. workers: set[WorkerThread],
  686. idle_workers: deque[WorkerThread],
  687. ):
  688. super().__init__(name="AnyIO worker thread")
  689. self.root_task = root_task
  690. self.workers = workers
  691. self.idle_workers = idle_workers
  692. self.loop = root_task._loop
  693. self.queue: Queue[
  694. tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
  695. ] = Queue(2)
  696. self.idle_since = AsyncIOBackend.current_time()
  697. self.stopping = False
  698. def _report_result(
  699. self, future: asyncio.Future, result: Any, exc: BaseException | None
  700. ) -> None:
  701. self.idle_since = AsyncIOBackend.current_time()
  702. if not self.stopping:
  703. self.idle_workers.append(self)
  704. if not future.cancelled():
  705. if exc is not None:
  706. if isinstance(exc, StopIteration):
  707. new_exc = RuntimeError("coroutine raised StopIteration")
  708. new_exc.__cause__ = exc
  709. exc = new_exc
  710. future.set_exception(exc)
  711. else:
  712. future.set_result(result)
  713. def run(self) -> None:
  714. with claim_worker_thread(AsyncIOBackend, self.loop):
  715. while True:
  716. item = self.queue.get()
  717. if item is None:
  718. # Shutdown command received
  719. return
  720. context, func, args, future, cancel_scope = item
  721. if not future.cancelled():
  722. result = None
  723. exception: BaseException | None = None
  724. threadlocals.current_cancel_scope = cancel_scope
  725. try:
  726. result = context.run(func, *args)
  727. except BaseException as exc:
  728. exception = exc
  729. finally:
  730. del threadlocals.current_cancel_scope
  731. if not self.loop.is_closed():
  732. self.loop.call_soon_threadsafe(
  733. self._report_result, future, result, exception
  734. )
  735. self.queue.task_done()
  736. def stop(self, f: asyncio.Task | None = None) -> None:
  737. self.stopping = True
  738. self.queue.put_nowait(None)
  739. self.workers.discard(self)
  740. try:
  741. self.idle_workers.remove(self)
  742. except ValueError:
  743. pass
# Run-scoped worker thread pools: the idle workers (a deque) and the set of all
# workers. NOTE(review): presumably these are the ``idle_workers``/``workers``
# collections passed to WorkerThread — confirm at the call site.
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
  748. class BlockingPortal(abc.BlockingPortal):
  749. def __new__(cls) -> BlockingPortal:
  750. return object.__new__(cls)
  751. def __init__(self) -> None:
  752. super().__init__()
  753. self._loop = get_running_loop()
  754. def _spawn_task_from_thread(
  755. self,
  756. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  757. args: tuple[Unpack[PosArgsT]],
  758. kwargs: dict[str, Any],
  759. name: object,
  760. future: Future[T_Retval],
  761. ) -> None:
  762. AsyncIOBackend.run_sync_from_thread(
  763. partial(self._task_group.start_soon, name=name),
  764. (self._call_func, func, args, kwargs, future),
  765. self._loop,
  766. )
  767. #
  768. # Subprocesses
  769. #
  770. @dataclass(eq=False)
  771. class StreamReaderWrapper(abc.ByteReceiveStream):
  772. _stream: asyncio.StreamReader
  773. async def receive(self, max_bytes: int = 65536) -> bytes:
  774. data = await self._stream.read(max_bytes)
  775. if data:
  776. return data
  777. else:
  778. raise EndOfStream
  779. async def aclose(self) -> None:
  780. self._stream.set_exception(ClosedResourceError())
  781. await AsyncIOBackend.checkpoint()
  782. @dataclass(eq=False)
  783. class StreamWriterWrapper(abc.ByteSendStream):
  784. _stream: asyncio.StreamWriter
  785. async def send(self, item: bytes) -> None:
  786. self._stream.write(item)
  787. await self._stream.drain()
  788. async def aclose(self) -> None:
  789. self._stream.close()
  790. await AsyncIOBackend.checkpoint()
  791. @dataclass(eq=False)
  792. class Process(abc.Process):
  793. _process: asyncio.subprocess.Process
  794. _stdin: StreamWriterWrapper | None
  795. _stdout: StreamReaderWrapper | None
  796. _stderr: StreamReaderWrapper | None
  797. async def aclose(self) -> None:
  798. with CancelScope(shield=True):
  799. if self._stdin:
  800. await self._stdin.aclose()
  801. if self._stdout:
  802. await self._stdout.aclose()
  803. if self._stderr:
  804. await self._stderr.aclose()
  805. try:
  806. await self.wait()
  807. except BaseException:
  808. self.kill()
  809. with CancelScope(shield=True):
  810. await self.wait()
  811. raise
  812. async def wait(self) -> int:
  813. return await self._process.wait()
  814. def terminate(self) -> None:
  815. self._process.terminate()
  816. def kill(self) -> None:
  817. self._process.kill()
  818. def send_signal(self, signal: int) -> None:
  819. self._process.send_signal(signal)
  820. @property
  821. def pid(self) -> int:
  822. return self._process.pid
  823. @property
  824. def returncode(self) -> int | None:
  825. return self._process.returncode
  826. @property
  827. def stdin(self) -> abc.ByteSendStream | None:
  828. return self._stdin
  829. @property
  830. def stdout(self) -> abc.ByteReceiveStream | None:
  831. return self._stdout
  832. @property
  833. def stderr(self) -> abc.ByteReceiveStream | None:
  834. return self._stderr
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop.

    Works synchronously (no awaiting). The trailing ``_task`` argument is unused;
    presumably this is installed as a task done-callback via
    ``partial(..., workers)`` — confirm at the call site.
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    # Child watchers are only looked up before Python 3.12
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this skips processes that are still running and acts only on
        # those that already exited, which looks inverted relative to the docstring.
        # Confirm the intent before changing it: killing a live process and removing
        # its child handler here could leave a later wait() on it unresolved.
        if process.returncode is None:
            continue

        process._stdin._stream._transport.close()  # type: ignore[union-attr]
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
  856. async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
  857. """
  858. Shuts down worker processes belonging to this event loop.
  859. NOTE: this only works when the event loop was started using asyncio.run() or
  860. anyio.run().
  861. """
  862. process: abc.Process
  863. try:
  864. await sleep(math.inf)
  865. except asyncio.CancelledError:
  866. for process in workers:
  867. if process.returncode is None:
  868. process.kill()
  869. for process in workers:
  870. await process.aclose()
  871. #
  872. # Sockets and networking
  873. #
  874. class StreamProtocol(asyncio.Protocol):
  875. read_queue: deque[bytes]
  876. read_event: asyncio.Event
  877. write_event: asyncio.Event
  878. exception: Exception | None = None
  879. is_at_eof: bool = False
  880. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  881. self.read_queue = deque()
  882. self.read_event = asyncio.Event()
  883. self.write_event = asyncio.Event()
  884. self.write_event.set()
  885. cast(asyncio.Transport, transport).set_write_buffer_limits(0)
  886. def connection_lost(self, exc: Exception | None) -> None:
  887. if exc:
  888. self.exception = BrokenResourceError()
  889. self.exception.__cause__ = exc
  890. self.read_event.set()
  891. self.write_event.set()
  892. def data_received(self, data: bytes) -> None:
  893. # ProactorEventloop sometimes sends bytearray instead of bytes
  894. self.read_queue.append(bytes(data))
  895. self.read_event.set()
  896. def eof_received(self) -> bool | None:
  897. self.is_at_eof = True
  898. self.read_event.set()
  899. return True
  900. def pause_writing(self) -> None:
  901. self.write_event = asyncio.Event()
  902. def resume_writing(self) -> None:
  903. self.write_event.set()
  904. class DatagramProtocol(asyncio.DatagramProtocol):
  905. read_queue: deque[tuple[bytes, IPSockAddrType]]
  906. read_event: asyncio.Event
  907. write_event: asyncio.Event
  908. exception: Exception | None = None
  909. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  910. self.read_queue = deque(maxlen=100) # arbitrary value
  911. self.read_event = asyncio.Event()
  912. self.write_event = asyncio.Event()
  913. self.write_event.set()
  914. def connection_lost(self, exc: Exception | None) -> None:
  915. self.read_event.set()
  916. self.write_event.set()
  917. def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
  918. addr = convert_ipv6_sockaddr(addr)
  919. self.read_queue.append((data, addr))
  920. self.read_event.set()
  921. def error_received(self, exc: Exception) -> None:
  922. self.exception = exc
  923. def pause_writing(self) -> None:
  924. self.write_event.clear()
  925. def resume_writing(self) -> None:
  926. self.write_event.set()
class SocketStream(abc.SocketStream):
    """Byte stream over an asyncio transport paired with a StreamProtocol.

    Incoming data is read from the protocol's ``read_queue``. The transport
    is resumed only while a ``receive()`` call is waiting for data.
    """

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        # ResourceGuard is defined elsewhere — presumably it rejects concurrent
        # receive()/send() calls on the same stream (NOTE(review): confirm)
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set only by aclose(); distinguishes local close from remote EOF/error
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Return the next chunk of received data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if this stream was closed with aclose()
        :raises BrokenResourceError: (the protocol's stored exception) if the
            connection was lost with an error
        :raises EndOfStream: when no data is left after EOF/closure
        """
        with self._receive_guard:
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
                and not self._protocol.is_at_eof
            ):
                # Nothing buffered: let the transport read while we wait, then
                # pause reading again once data (or EOF/loss) is signalled
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the transport and wait until writing may continue.

        :raises ClosedResourceError: if this stream was closed with aclose()
        :raises BrokenResourceError: if the connection was lost, or the
            transport raised RuntimeError while closing
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            # Wait until the protocol signals that writing may proceed
            # (write_event is replaced with an unset one in pause_writing())
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the write side; OSError from write_eof() is ignored."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the stream; a no-op if the transport is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            # Give the event loop one iteration after close() before aborting
            # NOTE(review): presumably lets close-related callbacks run first
            await sleep(0)
            self._transport.abort()
  997. class _RawSocketMixin:
  998. _receive_future: asyncio.Future | None = None
  999. _send_future: asyncio.Future | None = None
  1000. _closing = False
  1001. def __init__(self, raw_socket: socket.socket):
  1002. self.__raw_socket = raw_socket
  1003. self._receive_guard = ResourceGuard("reading from")
  1004. self._send_guard = ResourceGuard("writing to")
  1005. @property
  1006. def _raw_socket(self) -> socket.socket:
  1007. return self.__raw_socket
  1008. def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1009. def callback(f: object) -> None:
  1010. del self._receive_future
  1011. loop.remove_reader(self.__raw_socket)
  1012. f = self._receive_future = asyncio.Future()
  1013. loop.add_reader(self.__raw_socket, f.set_result, None)
  1014. f.add_done_callback(callback)
  1015. return f
  1016. def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1017. def callback(f: object) -> None:
  1018. del self._send_future
  1019. loop.remove_writer(self.__raw_socket)
  1020. f = self._send_future = asyncio.Future()
  1021. loop.add_writer(self.__raw_socket, f.set_result, None)
  1022. f.add_done_callback(callback)
  1023. return f
  1024. async def aclose(self) -> None:
  1025. if not self._closing:
  1026. self._closing = True
  1027. if self.__raw_socket.fileno() != -1:
  1028. self.__raw_socket.close()
  1029. if self._receive_future:
  1030. self._receive_future.set_result(None)
  1031. if self._send_future:
  1032. self._send_future.set_result(None)
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX stream socket driven by non-blocking raw socket calls.

    Each I/O method retries its socket call after awaiting readiness whenever
    the call raises BlockingIOError. Other OSErrors become ClosedResourceError
    (if aclose() was called) or BrokenResourceError.
    """

    async def send_eof(self) -> None:
        """Shut down the write side of the socket (SHUT_WR)."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Return up to ``max_bytes`` bytes.

        :raises EndOfStream: when recv() returns no data
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """Send all of ``item``, looping over partial sends."""
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            # memoryview slicing avoids copying the remaining bytes each round
            view = memoryview(item)
            while view:
                try:
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """Receive a message of up to ``msglen`` bytes plus up to ``maxfds`` fds.

        :raises ValueError: if ``msglen`` is negative or ``maxfds`` < 1
        :raises EndOfStream: when neither data nor ancillary data arrived
        :raises RuntimeError: if ancillary data other than SCM_RIGHTS arrives
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        # C int array; its itemsize sizes the ancillary buffer below
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Drop any trailing partial int (e.g. from a truncated control message)
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """Send ``message`` along with file descriptors via SCM_RIGHTS.

        ``fds`` items may be ints or IOBase objects (their fileno() is used).
        NOTE(review): items of any other type are silently skipped.

        :raises ValueError: if ``message`` or ``fds`` is empty
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # The descriptors travel as one SCM_RIGHTS control message
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
class TCPSocketListener(abc.SocketListener):
    """Accepts TCP connections and wraps them as SocketStream objects."""

    # CancelScope active while accept() is waiting, so aclose() can cancel it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        # Must be constructed inside a running loop (get_running_loop())
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """Accept one connection and return it as a SocketStream.

        TCP_NODELAY is enabled on the accepted socket.

        :raises ClosedResourceError: if the listener is (or becomes) closed
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # Cancellation caused by aclose() is reported as closure
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        # Disable Nagle's algorithm on the accepted connection
        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling any accept() in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield once so the cancelled accept() can unwind before the socket
            # is closed (NOTE(review): presumed intent)
            await sleep(0)

        self._raw_socket.close()
  1182. class UNIXSocketListener(abc.SocketListener):
  1183. def __init__(self, raw_socket: socket.socket):
  1184. self.__raw_socket = raw_socket
  1185. self._loop = get_running_loop()
  1186. self._accept_guard = ResourceGuard("accepting connections from")
  1187. self._closed = False
  1188. async def accept(self) -> abc.SocketStream:
  1189. await AsyncIOBackend.checkpoint()
  1190. with self._accept_guard:
  1191. while True:
  1192. try:
  1193. client_sock, _ = self.__raw_socket.accept()
  1194. client_sock.setblocking(False)
  1195. return UNIXSocketStream(client_sock)
  1196. except BlockingIOError:
  1197. f: asyncio.Future = asyncio.Future()
  1198. self._loop.add_reader(self.__raw_socket, f.set_result, None)
  1199. f.add_done_callback(
  1200. lambda _: self._loop.remove_reader(self.__raw_socket)
  1201. )
  1202. await f
  1203. except OSError as exc:
  1204. if self._closed:
  1205. raise ClosedResourceError from None
  1206. else:
  1207. raise BrokenResourceError from exc
  1208. async def aclose(self) -> None:
  1209. self._closed = True
  1210. self.__raw_socket.close()
  1211. @property
  1212. def _raw_socket(self) -> socket.socket:
  1213. return self.__raw_socket
  1214. class UDPSocket(abc.UDPSocket):
  1215. def __init__(
  1216. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1217. ):
  1218. self._transport = transport
  1219. self._protocol = protocol
  1220. self._receive_guard = ResourceGuard("reading from")
  1221. self._send_guard = ResourceGuard("writing to")
  1222. self._closed = False
  1223. @property
  1224. def _raw_socket(self) -> socket.socket:
  1225. return self._transport.get_extra_info("socket")
  1226. async def aclose(self) -> None:
  1227. if not self._transport.is_closing():
  1228. self._closed = True
  1229. self._transport.close()
  1230. async def receive(self) -> tuple[bytes, IPSockAddrType]:
  1231. with self._receive_guard:
  1232. await AsyncIOBackend.checkpoint()
  1233. # If the buffer is empty, ask for more data
  1234. if not self._protocol.read_queue and not self._transport.is_closing():
  1235. self._protocol.read_event.clear()
  1236. await self._protocol.read_event.wait()
  1237. try:
  1238. return self._protocol.read_queue.popleft()
  1239. except IndexError:
  1240. if self._closed:
  1241. raise ClosedResourceError from None
  1242. else:
  1243. raise BrokenResourceError from None
  1244. async def send(self, item: UDPPacketType) -> None:
  1245. with self._send_guard:
  1246. await AsyncIOBackend.checkpoint()
  1247. await self._protocol.write_event.wait()
  1248. if self._closed:
  1249. raise ClosedResourceError
  1250. elif self._transport.is_closing():
  1251. raise BrokenResourceError
  1252. else:
  1253. self._transport.sendto(*item)
  1254. class ConnectedUDPSocket(abc.ConnectedUDPSocket):
  1255. def __init__(
  1256. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1257. ):
  1258. self._transport = transport
  1259. self._protocol = protocol
  1260. self._receive_guard = ResourceGuard("reading from")
  1261. self._send_guard = ResourceGuard("writing to")
  1262. self._closed = False
  1263. @property
  1264. def _raw_socket(self) -> socket.socket:
  1265. return self._transport.get_extra_info("socket")
  1266. async def aclose(self) -> None:
  1267. if not self._transport.is_closing():
  1268. self._closed = True
  1269. self._transport.close()
  1270. async def receive(self) -> bytes:
  1271. with self._receive_guard:
  1272. await AsyncIOBackend.checkpoint()
  1273. # If the buffer is empty, ask for more data
  1274. if not self._protocol.read_queue and not self._transport.is_closing():
  1275. self._protocol.read_event.clear()
  1276. await self._protocol.read_event.wait()
  1277. try:
  1278. packet = self._protocol.read_queue.popleft()
  1279. except IndexError:
  1280. if self._closed:
  1281. raise ClosedResourceError from None
  1282. else:
  1283. raise BrokenResourceError from None
  1284. return packet[0]
  1285. async def send(self, item: bytes) -> None:
  1286. with self._send_guard:
  1287. await AsyncIOBackend.checkpoint()
  1288. await self._protocol.write_event.wait()
  1289. if self._closed:
  1290. raise ClosedResourceError
  1291. elif self._transport.is_closing():
  1292. raise BrokenResourceError
  1293. else:
  1294. self._transport.sendto(item)
  1295. class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
  1296. async def receive(self) -> UNIXDatagramPacketType:
  1297. loop = get_running_loop()
  1298. await AsyncIOBackend.checkpoint()
  1299. with self._receive_guard:
  1300. while True:
  1301. try:
  1302. data = self._raw_socket.recvfrom(65536)
  1303. except BlockingIOError:
  1304. await self._wait_until_readable(loop)
  1305. except OSError as exc:
  1306. if self._closing:
  1307. raise ClosedResourceError from None
  1308. else:
  1309. raise BrokenResourceError from exc
  1310. else:
  1311. return data
  1312. async def send(self, item: UNIXDatagramPacketType) -> None:
  1313. loop = get_running_loop()
  1314. await AsyncIOBackend.checkpoint()
  1315. with self._send_guard:
  1316. while True:
  1317. try:
  1318. self._raw_socket.sendto(*item)
  1319. except BlockingIOError:
  1320. await self._wait_until_writable(loop)
  1321. except OSError as exc:
  1322. if self._closing:
  1323. raise ClosedResourceError from None
  1324. else:
  1325. raise BrokenResourceError from exc
  1326. else:
  1327. return
  1328. class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
  1329. async def receive(self) -> bytes:
  1330. loop = get_running_loop()
  1331. await AsyncIOBackend.checkpoint()
  1332. with self._receive_guard:
  1333. while True:
  1334. try:
  1335. data = self._raw_socket.recv(65536)
  1336. except BlockingIOError:
  1337. await self._wait_until_readable(loop)
  1338. except OSError as exc:
  1339. if self._closing:
  1340. raise ClosedResourceError from None
  1341. else:
  1342. raise BrokenResourceError from exc
  1343. else:
  1344. return data
  1345. async def send(self, item: bytes) -> None:
  1346. loop = get_running_loop()
  1347. await AsyncIOBackend.checkpoint()
  1348. with self._send_guard:
  1349. while True:
  1350. try:
  1351. self._raw_socket.send(item)
  1352. except BlockingIOError:
  1353. await self._wait_until_writable(loop)
  1354. except OSError as exc:
  1355. if self._closing:
  1356. raise ClosedResourceError from None
  1357. else:
  1358. raise BrokenResourceError from exc
  1359. else:
  1360. return
# Run-scoped registries (RunVar) mapping some key to an asyncio.Event for
# pending read/write readiness waits. Their users are not visible in this
# chunk — NOTE(review): presumably keyed by socket/fd and used by the
# "wait until readable/writable" helpers; confirm.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
  1363. #
  1364. # Synchronization
  1365. #
  1366. class Event(BaseEvent):
  1367. def __new__(cls) -> Event:
  1368. return object.__new__(cls)
  1369. def __init__(self) -> None:
  1370. self._event = asyncio.Event()
  1371. def set(self) -> None:
  1372. self._event.set()
  1373. def is_set(self) -> bool:
  1374. return self._event.is_set()
  1375. async def wait(self) -> None:
  1376. if self.is_set():
  1377. await AsyncIOBackend.checkpoint()
  1378. else:
  1379. await self._event.wait()
  1380. def statistics(self) -> EventStatistics:
  1381. return EventStatistics(len(self._event._waiters))
class Lock(BaseLock):
    """asyncio lock that hands ownership to waiting tasks in FIFO order.

    With ``fast_acquire=True``, an uncontended acquire() skips the extra
    cancel-shielded yield to the event loop.
    """

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        self._fast_acquire = fast_acquire
        # Task currently holding the lock, or None when free
        self._owner_task: asyncio.Task | None = None
        # Blocked acquirers in arrival order; release() makes the first
        # non-cancelled one the owner and then completes its future
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """Acquire the lock, waiting behind earlier acquirers if needed.

        :raises RuntimeError: if the current task already holds the lock
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = task

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        if self._owner_task == task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            # release() may already have made this task the owner before the
            # cancellation was delivered; pass the lock on in that case
            if self._owner_task is task:
                self.release()

            raise

        self._waiters.remove(item)

    def acquire_nowait(self) -> None:
        """Acquire the lock without waiting.

        :raises RuntimeError: if the current task already holds the lock
        :raises WouldBlock: if the lock is held or others are queued
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = task
            return

        if self._owner_task is task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        return self._owner_task is not None

    def release(self) -> None:
        """Release the lock, handing it to the first non-cancelled waiter.

        :raises RuntimeError: if the current task is not the owner
        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        for task, fut in self._waiters:
            if not fut.cancelled():
                # Ownership transfers before the waiter resumes
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def statistics(self) -> LockStatistics:
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
class Semaphore(BaseSemaphore):
    """asyncio counting semaphore that hands releases to waiters in FIFO order.

    With ``fast_acquire=True``, an uncontended acquire() skips the extra
    cancel-shielded yield to the event loop.
    """

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._fast_acquire = fast_acquire
        # Futures of blocked acquirers, oldest first
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Decrement the semaphore, waiting in FIFO order if it is exhausted."""
        if self._value > 0 and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # Not queued anymore: release() already handed us the slot
                # before the cancellation arrived, so give it back
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """Decrement the semaphore without waiting.

        :raises WouldBlock: if the value is 0
        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """Wake the first non-cancelled waiter, or increment the value.

        :raises ValueError: if the value is already at ``max_value``
        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                # Safe despite iterating: we return immediately after removal
                self._waiters.remove(fut)
                return

        self._value += 1

    @property
    def value(self) -> int:
        return self._value

    @property
    def max_value(self) -> int | None:
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        return SemaphoreStatistics(len(self._waiters))
  1503. class CapacityLimiter(BaseCapacityLimiter):
  1504. _total_tokens: float = 0
  1505. def __new__(cls, total_tokens: float) -> CapacityLimiter:
  1506. return object.__new__(cls)
  1507. def __init__(self, total_tokens: float):
  1508. self._borrowers: set[Any] = set()
  1509. self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
  1510. self.total_tokens = total_tokens
  1511. async def __aenter__(self) -> None:
  1512. await self.acquire()
  1513. async def __aexit__(
  1514. self,
  1515. exc_type: type[BaseException] | None,
  1516. exc_val: BaseException | None,
  1517. exc_tb: TracebackType | None,
  1518. ) -> None:
  1519. self.release()
  1520. @property
  1521. def total_tokens(self) -> float:
  1522. return self._total_tokens
  1523. @total_tokens.setter
  1524. def total_tokens(self, value: float) -> None:
  1525. if not isinstance(value, int) and not math.isinf(value):
  1526. raise TypeError("total_tokens must be an int or math.inf")
  1527. if value < 1:
  1528. raise ValueError("total_tokens must be >= 1")
  1529. waiters_to_notify = max(value - self._total_tokens, 0)
  1530. self._total_tokens = value
  1531. # Notify waiting tasks that they have acquired the limiter
  1532. while self._wait_queue and waiters_to_notify:
  1533. event = self._wait_queue.popitem(last=False)[1]
  1534. event.set()
  1535. waiters_to_notify -= 1
  1536. @property
  1537. def borrowed_tokens(self) -> int:
  1538. return len(self._borrowers)
  1539. @property
  1540. def available_tokens(self) -> float:
  1541. return self._total_tokens - len(self._borrowers)
  1542. def acquire_nowait(self) -> None:
  1543. self.acquire_on_behalf_of_nowait(current_task())
  1544. def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
  1545. if borrower in self._borrowers:
  1546. raise RuntimeError(
  1547. "this borrower is already holding one of this CapacityLimiter's "
  1548. "tokens"
  1549. )
  1550. if self._wait_queue or len(self._borrowers) >= self._total_tokens:
  1551. raise WouldBlock
  1552. self._borrowers.add(borrower)
  1553. async def acquire(self) -> None:
  1554. return await self.acquire_on_behalf_of(current_task())
  1555. async def acquire_on_behalf_of(self, borrower: object) -> None:
  1556. await AsyncIOBackend.checkpoint_if_cancelled()
  1557. try:
  1558. self.acquire_on_behalf_of_nowait(borrower)
  1559. except WouldBlock:
  1560. event = asyncio.Event()
  1561. self._wait_queue[borrower] = event
  1562. try:
  1563. await event.wait()
  1564. except BaseException:
  1565. self._wait_queue.pop(borrower, None)
  1566. raise
  1567. self._borrowers.add(borrower)
  1568. else:
  1569. try:
  1570. await AsyncIOBackend.cancel_shielded_checkpoint()
  1571. except BaseException:
  1572. self.release()
  1573. raise
  1574. def release(self) -> None:
  1575. self.release_on_behalf_of(current_task())
  1576. def release_on_behalf_of(self, borrower: object) -> None:
  1577. try:
  1578. self._borrowers.remove(borrower)
  1579. except KeyError:
  1580. raise RuntimeError(
  1581. "this borrower isn't holding any of this CapacityLimiter's tokens"
  1582. ) from None
  1583. # Notify the next task in line if this limiter has free capacity now
  1584. if self._wait_queue and len(self._borrowers) < self._total_tokens:
  1585. event = self._wait_queue.popitem(last=False)[1]
  1586. event.set()
  1587. def statistics(self) -> CapacityLimiterStatistics:
  1588. return CapacityLimiterStatistics(
  1589. self.borrowed_tokens,
  1590. self.total_tokens,
  1591. tuple(self._borrowers),
  1592. len(self._wait_queue),
  1593. )
# Run-scoped default CapacityLimiter. Its users are not visible in this chunk —
# NOTE(review): presumably the limiter used for worker threads when no explicit
# limiter is given; confirm.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
  1595. #
  1596. # Operating system signals
  1597. #
  1598. class _SignalReceiver:
  1599. def __init__(self, signals: tuple[Signals, ...]):
  1600. self._signals = signals
  1601. self._loop = get_running_loop()
  1602. self._signal_queue: deque[Signals] = deque()
  1603. self._future: asyncio.Future = asyncio.Future()
  1604. self._handled_signals: set[Signals] = set()
  1605. def _deliver(self, signum: Signals) -> None:
  1606. self._signal_queue.append(signum)
  1607. if not self._future.done():
  1608. self._future.set_result(None)
  1609. def __enter__(self) -> _SignalReceiver:
  1610. for sig in set(self._signals):
  1611. self._loop.add_signal_handler(sig, self._deliver, sig)
  1612. self._handled_signals.add(sig)
  1613. return self
  1614. def __exit__(
  1615. self,
  1616. exc_type: type[BaseException] | None,
  1617. exc_val: BaseException | None,
  1618. exc_tb: TracebackType | None,
  1619. ) -> bool | None:
  1620. for sig in self._handled_signals:
  1621. self._loop.remove_signal_handler(sig)
  1622. return None
  1623. def __aiter__(self) -> _SignalReceiver:
  1624. return self
  1625. async def __anext__(self) -> Signals:
  1626. await AsyncIOBackend.checkpoint()
  1627. if not self._signal_queue:
  1628. self._future = asyncio.Future()
  1629. await self._future
  1630. return self._signal_queue.popleft()
  1631. #
  1632. # Testing and debugging
  1633. #
  1634. class AsyncIOTaskInfo(TaskInfo):
  1635. def __init__(self, task: asyncio.Task):
  1636. task_state = _task_states.get(task)
  1637. if task_state is None:
  1638. parent_id = None
  1639. else:
  1640. parent_id = task_state.parent_id
  1641. super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
  1642. self._task = weakref.ref(task)
  1643. def has_pending_cancellation(self) -> bool:
  1644. if not (task := self._task()):
  1645. # If the task isn't around anymore, it won't have a pending cancellation
  1646. return False
  1647. if sys.version_info >= (3, 11):
  1648. if task.cancelling():
  1649. return True
  1650. elif (
  1651. isinstance(task._fut_waiter, asyncio.Future)
  1652. and task._fut_waiter.cancelled()
  1653. ):
  1654. return True
  1655. if task_state := _task_states.get(task):
  1656. if cancel_scope := task_state.cancel_scope:
  1657. return cancel_scope.cancel_called or (
  1658. not cancel_scope.shield and cancel_scope._parent_cancelled()
  1659. )
  1660. return False
  1661. class TestRunner(abc.TestRunner):
  1662. _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]
  1663. def __init__(
  1664. self,
  1665. *,
  1666. debug: bool | None = None,
  1667. use_uvloop: bool = False,
  1668. loop_factory: Callable[[], AbstractEventLoop] | None = None,
  1669. ) -> None:
  1670. if use_uvloop and loop_factory is None:
  1671. import uvloop
  1672. loop_factory = uvloop.new_event_loop
  1673. self._runner = Runner(debug=debug, loop_factory=loop_factory)
  1674. self._exceptions: list[BaseException] = []
  1675. self._runner_task: asyncio.Task | None = None
  1676. def __enter__(self) -> TestRunner:
  1677. self._runner.__enter__()
  1678. self.get_loop().set_exception_handler(self._exception_handler)
  1679. return self
  1680. def __exit__(
  1681. self,
  1682. exc_type: type[BaseException] | None,
  1683. exc_val: BaseException | None,
  1684. exc_tb: TracebackType | None,
  1685. ) -> None:
  1686. self._runner.__exit__(exc_type, exc_val, exc_tb)
  1687. def get_loop(self) -> AbstractEventLoop:
  1688. return self._runner.get_loop()
  1689. def _exception_handler(
  1690. self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
  1691. ) -> None:
  1692. if isinstance(context.get("exception"), Exception):
  1693. self._exceptions.append(context["exception"])
  1694. else:
  1695. loop.default_exception_handler(context)
  1696. def _raise_async_exceptions(self) -> None:
  1697. # Re-raise any exceptions raised in asynchronous callbacks
  1698. if self._exceptions:
  1699. exceptions, self._exceptions = self._exceptions, []
  1700. if len(exceptions) == 1:
  1701. raise exceptions[0]
  1702. elif exceptions:
  1703. raise BaseExceptionGroup(
  1704. "Multiple exceptions occurred in asynchronous callbacks", exceptions
  1705. )
async def _run_tests_and_fixtures(
    self,
    receive_stream: MemoryObjectReceiveStream[
        tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
    ],
) -> None:
    """Run queued awaitables one at a time inside the long-lived runner task.

    Each received item is an ``(awaitable, future)`` pair. The awaitable's
    outcome is copied into the future unless the future was already
    cancelled. ``Exception`` and pytest ``OutcomeException`` failures are
    reported only through the future. CancelledError and any other
    BaseException are re-raised, which ends this loop.
    """
    # Imported lazily; NOTE(review): presumably because pytest is only
    # guaranteed to be importable when this runner is driven by pytest
    from _pytest.outcomes import OutcomeException

    # Both stream ends are used as context managers for the lifetime of the loop
    with receive_stream, self._send_stream:
        async for coro, future in receive_stream:
            try:
                retval = await coro
            except CancelledError as exc:
                if not future.cancelled():
                    future.cancel(*exc.args)

                raise
            except BaseException as exc:
                if not future.cancelled():
                    future.set_exception(exc)

                if not isinstance(exc, (Exception, OutcomeException)):
                    raise
            else:
                if not future.cancelled():
                    future.set_result(retval)
  1729. async def _call_in_runner_task(
  1730. self,
  1731. func: Callable[P, Awaitable[T_Retval]],
  1732. *args: P.args,
  1733. **kwargs: P.kwargs,
  1734. ) -> T_Retval:
  1735. if not self._runner_task:
  1736. self._send_stream, receive_stream = create_memory_object_stream[
  1737. Tuple[Awaitable[Any], asyncio.Future]
  1738. ](1)
  1739. self._runner_task = self.get_loop().create_task(
  1740. self._run_tests_and_fixtures(receive_stream)
  1741. )
  1742. coro = func(*args, **kwargs)
  1743. future: asyncio.Future[T_Retval] = self.get_loop().create_future()
  1744. self._send_stream.send_nowait((coro, future))
  1745. return await future
  1746. def run_asyncgen_fixture(
  1747. self,
  1748. fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
  1749. kwargs: dict[str, Any],
  1750. ) -> Iterable[T_Retval]:
  1751. asyncgen = fixture_func(**kwargs)
  1752. fixturevalue: T_Retval = self.get_loop().run_until_complete(
  1753. self._call_in_runner_task(asyncgen.asend, None)
  1754. )
  1755. self._raise_async_exceptions()
  1756. yield fixturevalue
  1757. try:
  1758. self.get_loop().run_until_complete(
  1759. self._call_in_runner_task(asyncgen.asend, None)
  1760. )
  1761. except StopAsyncIteration:
  1762. self._raise_async_exceptions()
  1763. else:
  1764. self.get_loop().run_until_complete(asyncgen.aclose())
  1765. raise RuntimeError("Async generator fixture did not stop")
  1766. def run_fixture(
  1767. self,
  1768. fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
  1769. kwargs: dict[str, Any],
  1770. ) -> T_Retval:
  1771. retval = self.get_loop().run_until_complete(
  1772. self._call_in_runner_task(fixture_func, **kwargs)
  1773. )
  1774. self._raise_async_exceptions()
  1775. return retval
  1776. def run_test(
  1777. self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
  1778. ) -> None:
  1779. try:
  1780. self.get_loop().run_until_complete(
  1781. self._call_in_runner_task(test_func, **kwargs)
  1782. )
  1783. except Exception as exc:
  1784. self._exceptions.append(exc)
  1785. self._raise_async_exceptions()
  1786. class AsyncIOBackend(AsyncBackend):
  1787. @classmethod
  1788. def run(
  1789. cls,
  1790. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  1791. args: tuple[Unpack[PosArgsT]],
  1792. kwargs: dict[str, Any],
  1793. options: dict[str, Any],
  1794. ) -> T_Retval:
  1795. @wraps(func)
  1796. async def wrapper() -> T_Retval:
  1797. task = cast(asyncio.Task, current_task())
  1798. task.set_name(get_callable_name(func))
  1799. _task_states[task] = TaskState(None, None)
  1800. try:
  1801. return await func(*args)
  1802. finally:
  1803. del _task_states[task]
  1804. debug = options.get("debug", None)
  1805. loop_factory = options.get("loop_factory", None)
  1806. if loop_factory is None and options.get("use_uvloop", False):
  1807. import uvloop
  1808. loop_factory = uvloop.new_event_loop
  1809. with Runner(debug=debug, loop_factory=loop_factory) as runner:
  1810. return runner.run(wrapper())
  1811. @classmethod
  1812. def current_token(cls) -> object:
  1813. return get_running_loop()
  1814. @classmethod
  1815. def current_time(cls) -> float:
  1816. return get_running_loop().time()
  1817. @classmethod
  1818. def cancelled_exception_class(cls) -> type[BaseException]:
  1819. return CancelledError
  1820. @classmethod
  1821. async def checkpoint(cls) -> None:
  1822. await sleep(0)
  1823. @classmethod
  1824. async def checkpoint_if_cancelled(cls) -> None:
  1825. task = current_task()
  1826. if task is None:
  1827. return
  1828. try:
  1829. cancel_scope = _task_states[task].cancel_scope
  1830. except KeyError:
  1831. return
  1832. while cancel_scope:
  1833. if cancel_scope.cancel_called:
  1834. await sleep(0)
  1835. elif cancel_scope.shield:
  1836. break
  1837. else:
  1838. cancel_scope = cancel_scope._parent_scope
  1839. @classmethod
  1840. async def cancel_shielded_checkpoint(cls) -> None:
  1841. with CancelScope(shield=True):
  1842. await sleep(0)
  1843. @classmethod
  1844. async def sleep(cls, delay: float) -> None:
  1845. await sleep(delay)
  1846. @classmethod
  1847. def create_cancel_scope(
  1848. cls, *, deadline: float = math.inf, shield: bool = False
  1849. ) -> CancelScope:
  1850. return CancelScope(deadline=deadline, shield=shield)
  1851. @classmethod
  1852. def current_effective_deadline(cls) -> float:
  1853. try:
  1854. cancel_scope = _task_states[
  1855. current_task() # type: ignore[index]
  1856. ].cancel_scope
  1857. except KeyError:
  1858. return math.inf
  1859. deadline = math.inf
  1860. while cancel_scope:
  1861. deadline = min(deadline, cancel_scope.deadline)
  1862. if cancel_scope._cancel_called:
  1863. deadline = -math.inf
  1864. break
  1865. elif cancel_scope.shield:
  1866. break
  1867. else:
  1868. cancel_scope = cancel_scope._parent_scope
  1869. return deadline
  1870. @classmethod
  1871. def create_task_group(cls) -> abc.TaskGroup:
  1872. return TaskGroup()
  1873. @classmethod
  1874. def create_event(cls) -> abc.Event:
  1875. return Event()
  1876. @classmethod
  1877. def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
  1878. return Lock(fast_acquire=fast_acquire)
  1879. @classmethod
  1880. def create_semaphore(
  1881. cls,
  1882. initial_value: int,
  1883. *,
  1884. max_value: int | None = None,
  1885. fast_acquire: bool = False,
  1886. ) -> abc.Semaphore:
  1887. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  1888. @classmethod
  1889. def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
  1890. return CapacityLimiter(total_tokens)
  1891. @classmethod
  1892. async def run_sync_in_worker_thread(
  1893. cls,
  1894. func: Callable[[Unpack[PosArgsT]], T_Retval],
  1895. args: tuple[Unpack[PosArgsT]],
  1896. abandon_on_cancel: bool = False,
  1897. limiter: abc.CapacityLimiter | None = None,
  1898. ) -> T_Retval:
  1899. await cls.checkpoint()
  1900. # If this is the first run in this event loop thread, set up the necessary
  1901. # variables
  1902. try:
  1903. idle_workers = _threadpool_idle_workers.get()
  1904. workers = _threadpool_workers.get()
  1905. except LookupError:
  1906. idle_workers = deque()
  1907. workers = set()
  1908. _threadpool_idle_workers.set(idle_workers)
  1909. _threadpool_workers.set(workers)
  1910. async with limiter or cls.current_default_thread_limiter():
  1911. with CancelScope(shield=not abandon_on_cancel) as scope:
  1912. future: asyncio.Future = asyncio.Future()
  1913. root_task = find_root_task()
  1914. if not idle_workers:
  1915. worker = WorkerThread(root_task, workers, idle_workers)
  1916. worker.start()
  1917. workers.add(worker)
  1918. root_task.add_done_callback(worker.stop)
  1919. else:
  1920. worker = idle_workers.pop()
  1921. # Prune any other workers that have been idle for MAX_IDLE_TIME
  1922. # seconds or longer
  1923. now = cls.current_time()
  1924. while idle_workers:
  1925. if (
  1926. now - idle_workers[0].idle_since
  1927. < WorkerThread.MAX_IDLE_TIME
  1928. ):
  1929. break
  1930. expired_worker = idle_workers.popleft()
  1931. expired_worker.root_task.remove_done_callback(
  1932. expired_worker.stop
  1933. )
  1934. expired_worker.stop()
  1935. context = copy_context()
  1936. context.run(sniffio.current_async_library_cvar.set, None)
  1937. if abandon_on_cancel or scope._parent_scope is None:
  1938. worker_scope = scope
  1939. else:
  1940. worker_scope = scope._parent_scope
  1941. worker.queue.put_nowait((context, func, args, future, worker_scope))
  1942. return await future
  1943. @classmethod
  1944. def check_cancelled(cls) -> None:
  1945. scope: CancelScope | None = threadlocals.current_cancel_scope
  1946. while scope is not None:
  1947. if scope.cancel_called:
  1948. raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
  1949. if scope.shield:
  1950. return
  1951. scope = scope._parent_scope
  1952. @classmethod
  1953. def run_async_from_thread(
  1954. cls,
  1955. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  1956. args: tuple[Unpack[PosArgsT]],
  1957. token: object,
  1958. ) -> T_Retval:
  1959. async def task_wrapper(scope: CancelScope) -> T_Retval:
  1960. __tracebackhide__ = True
  1961. task = cast(asyncio.Task, current_task())
  1962. _task_states[task] = TaskState(None, scope)
  1963. scope._tasks.add(task)
  1964. try:
  1965. return await func(*args)
  1966. except CancelledError as exc:
  1967. raise concurrent.futures.CancelledError(str(exc)) from None
  1968. finally:
  1969. scope._tasks.discard(task)
  1970. loop = cast(AbstractEventLoop, token)
  1971. context = copy_context()
  1972. context.run(sniffio.current_async_library_cvar.set, "asyncio")
  1973. wrapper = task_wrapper(threadlocals.current_cancel_scope)
  1974. f: concurrent.futures.Future[T_Retval] = context.run(
  1975. asyncio.run_coroutine_threadsafe, wrapper, loop
  1976. )
  1977. return f.result()
  1978. @classmethod
  1979. def run_sync_from_thread(
  1980. cls,
  1981. func: Callable[[Unpack[PosArgsT]], T_Retval],
  1982. args: tuple[Unpack[PosArgsT]],
  1983. token: object,
  1984. ) -> T_Retval:
  1985. @wraps(func)
  1986. def wrapper() -> None:
  1987. try:
  1988. sniffio.current_async_library_cvar.set("asyncio")
  1989. f.set_result(func(*args))
  1990. except BaseException as exc:
  1991. f.set_exception(exc)
  1992. if not isinstance(exc, Exception):
  1993. raise
  1994. f: concurrent.futures.Future[T_Retval] = Future()
  1995. loop = cast(AbstractEventLoop, token)
  1996. loop.call_soon_threadsafe(wrapper)
  1997. return f.result()
  1998. @classmethod
  1999. def create_blocking_portal(cls) -> abc.BlockingPortal:
  2000. return BlockingPortal()
  2001. @classmethod
  2002. async def open_process(
  2003. cls,
  2004. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  2005. *,
  2006. stdin: int | IO[Any] | None,
  2007. stdout: int | IO[Any] | None,
  2008. stderr: int | IO[Any] | None,
  2009. **kwargs: Any,
  2010. ) -> Process:
  2011. await cls.checkpoint()
  2012. if isinstance(command, PathLike):
  2013. command = os.fspath(command)
  2014. if isinstance(command, (str, bytes)):
  2015. process = await asyncio.create_subprocess_shell(
  2016. command,
  2017. stdin=stdin,
  2018. stdout=stdout,
  2019. stderr=stderr,
  2020. **kwargs,
  2021. )
  2022. else:
  2023. process = await asyncio.create_subprocess_exec(
  2024. *command,
  2025. stdin=stdin,
  2026. stdout=stdout,
  2027. stderr=stderr,
  2028. **kwargs,
  2029. )
  2030. stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
  2031. stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
  2032. stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
  2033. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  2034. @classmethod
  2035. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  2036. create_task(
  2037. _shutdown_process_pool_on_exit(workers),
  2038. name="AnyIO process pool shutdown task",
  2039. )
  2040. find_root_task().add_done_callback(
  2041. partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
  2042. )
  2043. @classmethod
  2044. async def connect_tcp(
  2045. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  2046. ) -> abc.SocketStream:
  2047. transport, protocol = cast(
  2048. Tuple[asyncio.Transport, StreamProtocol],
  2049. await get_running_loop().create_connection(
  2050. StreamProtocol, host, port, local_addr=local_address
  2051. ),
  2052. )
  2053. transport.pause_reading()
  2054. return SocketStream(transport, protocol)
  2055. @classmethod
  2056. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  2057. await cls.checkpoint()
  2058. loop = get_running_loop()
  2059. raw_socket = socket.socket(socket.AF_UNIX)
  2060. raw_socket.setblocking(False)
  2061. while True:
  2062. try:
  2063. raw_socket.connect(path)
  2064. except BlockingIOError:
  2065. f: asyncio.Future = asyncio.Future()
  2066. loop.add_writer(raw_socket, f.set_result, None)
  2067. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2068. await f
  2069. except BaseException:
  2070. raw_socket.close()
  2071. raise
  2072. else:
  2073. return UNIXSocketStream(raw_socket)
  2074. @classmethod
  2075. def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
  2076. return TCPSocketListener(sock)
  2077. @classmethod
  2078. def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
  2079. return UNIXSocketListener(sock)
  2080. @classmethod
  2081. async def create_udp_socket(
  2082. cls,
  2083. family: AddressFamily,
  2084. local_address: IPSockAddrType | None,
  2085. remote_address: IPSockAddrType | None,
  2086. reuse_port: bool,
  2087. ) -> UDPSocket | ConnectedUDPSocket:
  2088. transport, protocol = await get_running_loop().create_datagram_endpoint(
  2089. DatagramProtocol,
  2090. local_addr=local_address,
  2091. remote_addr=remote_address,
  2092. family=family,
  2093. reuse_port=reuse_port,
  2094. )
  2095. if protocol.exception:
  2096. transport.close()
  2097. raise protocol.exception
  2098. if not remote_address:
  2099. return UDPSocket(transport, protocol)
  2100. else:
  2101. return ConnectedUDPSocket(transport, protocol)
  2102. @classmethod
  2103. async def create_unix_datagram_socket( # type: ignore[override]
  2104. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  2105. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  2106. await cls.checkpoint()
  2107. loop = get_running_loop()
  2108. if remote_path:
  2109. while True:
  2110. try:
  2111. raw_socket.connect(remote_path)
  2112. except BlockingIOError:
  2113. f: asyncio.Future = asyncio.Future()
  2114. loop.add_writer(raw_socket, f.set_result, None)
  2115. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2116. await f
  2117. except BaseException:
  2118. raw_socket.close()
  2119. raise
  2120. else:
  2121. return ConnectedUNIXDatagramSocket(raw_socket)
  2122. else:
  2123. return UNIXDatagramSocket(raw_socket)
  2124. @classmethod
  2125. async def getaddrinfo(
  2126. cls,
  2127. host: bytes | str | None,
  2128. port: str | int | None,
  2129. *,
  2130. family: int | AddressFamily = 0,
  2131. type: int | SocketKind = 0,
  2132. proto: int = 0,
  2133. flags: int = 0,
  2134. ) -> list[
  2135. tuple[
  2136. AddressFamily,
  2137. SocketKind,
  2138. int,
  2139. str,
  2140. tuple[str, int] | tuple[str, int, int, int],
  2141. ]
  2142. ]:
  2143. return await get_running_loop().getaddrinfo(
  2144. host, port, family=family, type=type, proto=proto, flags=flags
  2145. )
  2146. @classmethod
  2147. async def getnameinfo(
  2148. cls, sockaddr: IPSockAddrType, flags: int = 0
  2149. ) -> tuple[str, str]:
  2150. return await get_running_loop().getnameinfo(sockaddr, flags)
  2151. @classmethod
  2152. async def wait_socket_readable(cls, sock: socket.socket) -> None:
  2153. await cls.checkpoint()
  2154. try:
  2155. read_events = _read_events.get()
  2156. except LookupError:
  2157. read_events = {}
  2158. _read_events.set(read_events)
  2159. if read_events.get(sock):
  2160. raise BusyResourceError("reading from") from None
  2161. loop = get_running_loop()
  2162. event = read_events[sock] = asyncio.Event()
  2163. loop.add_reader(sock, event.set)
  2164. try:
  2165. await event.wait()
  2166. finally:
  2167. if read_events.pop(sock, None) is not None:
  2168. loop.remove_reader(sock)
  2169. readable = True
  2170. else:
  2171. readable = False
  2172. if not readable:
  2173. raise ClosedResourceError
  2174. @classmethod
  2175. async def wait_socket_writable(cls, sock: socket.socket) -> None:
  2176. await cls.checkpoint()
  2177. try:
  2178. write_events = _write_events.get()
  2179. except LookupError:
  2180. write_events = {}
  2181. _write_events.set(write_events)
  2182. if write_events.get(sock):
  2183. raise BusyResourceError("writing to") from None
  2184. loop = get_running_loop()
  2185. event = write_events[sock] = asyncio.Event()
  2186. loop.add_writer(sock.fileno(), event.set)
  2187. try:
  2188. await event.wait()
  2189. finally:
  2190. if write_events.pop(sock, None) is not None:
  2191. loop.remove_writer(sock)
  2192. writable = True
  2193. else:
  2194. writable = False
  2195. if not writable:
  2196. raise ClosedResourceError
  2197. @classmethod
  2198. def current_default_thread_limiter(cls) -> CapacityLimiter:
  2199. try:
  2200. return _default_thread_limiter.get()
  2201. except LookupError:
  2202. limiter = CapacityLimiter(40)
  2203. _default_thread_limiter.set(limiter)
  2204. return limiter
  2205. @classmethod
  2206. def open_signal_receiver(
  2207. cls, *signals: Signals
  2208. ) -> ContextManager[AsyncIterator[Signals]]:
  2209. return _SignalReceiver(signals)
  2210. @classmethod
  2211. def get_current_task(cls) -> TaskInfo:
  2212. return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
  2213. @classmethod
  2214. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  2215. return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
  2216. @classmethod
  2217. async def wait_all_tasks_blocked(cls) -> None:
  2218. await cls.checkpoint()
  2219. this_task = current_task()
  2220. while True:
  2221. for task in all_tasks():
  2222. if task is this_task:
  2223. continue
  2224. waiter = task._fut_waiter # type: ignore[attr-defined]
  2225. if waiter is None or waiter.done():
  2226. await sleep(0.1)
  2227. break
  2228. else:
  2229. return
  2230. @classmethod
  2231. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  2232. return TestRunner(**options)
# Module-level alias for this module's backend implementation class.
# NOTE(review): presumably looked up by name by AnyIO's backend loader — confirm.
backend_class = AsyncIOBackend