from_thread.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. from __future__ import annotations
  2. import sys
  3. from collections.abc import Awaitable, Callable, Generator
  4. from concurrent.futures import Future
  5. from contextlib import AbstractContextManager, contextmanager
  6. from dataclasses import dataclass, field
  7. from inspect import isawaitable
  8. from threading import Lock, Thread, get_ident
  9. from types import TracebackType
  10. from typing import (
  11. Any,
  12. AsyncContextManager,
  13. ContextManager,
  14. Generic,
  15. TypeVar,
  16. cast,
  17. overload,
  18. )
  19. from ._core import _eventloop
  20. from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
  21. from ._core._synchronization import Event
  22. from ._core._tasks import CancelScope, create_task_group
  23. from .abc import AsyncBackend
  24. from .abc._tasks import TaskStatus
  25. if sys.version_info >= (3, 11):
  26. from typing import TypeVarTuple, Unpack
  27. else:
  28. from typing_extensions import TypeVarTuple, Unpack
# Return type of the callables dispatched by the helpers and portal methods below
T_Retval = TypeVar("T_Retval")
# Covariant value type produced by a wrapped (async) context manager
T_co = TypeVar("T_co", covariant=True)
# Variadic positional-argument types forwarded unchanged to the target callable
PosArgsT = TypeVarTuple("PosArgsT")
  32. def run(
  33. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
  34. ) -> T_Retval:
  35. """
  36. Call a coroutine function from a worker thread.
  37. :param func: a coroutine function
  38. :param args: positional arguments for the callable
  39. :return: the return value of the coroutine function
  40. """
  41. try:
  42. async_backend = threadlocals.current_async_backend
  43. token = threadlocals.current_token
  44. except AttributeError:
  45. raise RuntimeError(
  46. "This function can only be run from an AnyIO worker thread"
  47. ) from None
  48. return async_backend.run_async_from_thread(func, args, token=token)
  49. def run_sync(
  50. func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  51. ) -> T_Retval:
  52. """
  53. Call a function in the event loop thread from a worker thread.
  54. :param func: a callable
  55. :param args: positional arguments for the callable
  56. :return: the return value of the callable
  57. """
  58. try:
  59. async_backend = threadlocals.current_async_backend
  60. token = threadlocals.current_token
  61. except AttributeError:
  62. raise RuntimeError(
  63. "This function can only be run from an AnyIO worker thread"
  64. ) from None
  65. return async_backend.run_sync_from_thread(func, args, token=token)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an async context manager via a portal.

    ``__enter__`` starts :meth:`run_async_cm` as a portal task and blocks until that
    task has entered the async context manager. The task then waits on
    ``_exit_event``. ``__exit__`` sets the event through the portal and blocks until
    the task has run ``__aexit__()``, returning that method's result.
    """

    # Resolved by run_async_cm() with the __aenter__() value, or with its exception
    _enter_future: Future[T_co]
    # Future of the run_async_cm() task itself; resolves to __aexit__()'s result
    _exit_future: Future[bool | None]
    # Set by __exit__() (through the portal) to let run_async_cm() continue
    _exit_event: Event
    # Exception info forwarded to __aexit__(). The class-level default is used if
    # run_async_cm() reaches __aexit__() before __exit__() has recorded anything.
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for the sync side to exit, then exit it.

        Runs as a task in the portal's event loop (started from :meth:`__enter__`).

        :return: the return value of the async context manager's ``__aexit__()``
        """
        try:
            # Created here in the portal task rather than in __enter__() (which runs
            # in the calling thread) -- presumably because the Event must be created
            # inside the running event loop.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Unblock the thread waiting in __enter__() with the failure
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            # NOTE(review): returning from inside ``finally`` discards any exception
            # in flight (such as the cancellation described above), so the task
            # completes normally with __aexit__()'s result. Confirm that this
            # suppression is intended; linters flag the pattern (e.g. ruff B012).
            return result

    def __enter__(self) -> T_co:
        # The future must exist before the task starts, since the task resolves it
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        # Block until run_async_cm() has entered (or failed to enter) the async CM
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        # Record the exception info before waking the task, so __aexit__() sees it
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        # Event.set is run in the event loop thread via the portal
        self._portal.call(self._exit_event.set)
        # Block until the task has run __aexit__(); its result decides suppression
        return self._exit_future.result()
  110. class _BlockingPortalTaskStatus(TaskStatus):
  111. def __init__(self, future: Future):
  112. self._future = future
  113. def started(self, value: object = None) -> None:
  114. self._future.set_result(value)
  115. class BlockingPortal:
  116. """An object that lets external threads run code in an asynchronous event loop."""
  117. def __new__(cls) -> BlockingPortal:
  118. return get_async_backend().create_blocking_portal()
  119. def __init__(self) -> None:
  120. self._event_loop_thread_id: int | None = get_ident()
  121. self._stop_event = Event()
  122. self._task_group = create_task_group()
  123. self._cancelled_exc_class = get_cancelled_exc_class()
  124. async def __aenter__(self) -> BlockingPortal:
  125. await self._task_group.__aenter__()
  126. return self
  127. async def __aexit__(
  128. self,
  129. exc_type: type[BaseException] | None,
  130. exc_val: BaseException | None,
  131. exc_tb: TracebackType | None,
  132. ) -> bool | None:
  133. await self.stop()
  134. return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
  135. def _check_running(self) -> None:
  136. if self._event_loop_thread_id is None:
  137. raise RuntimeError("This portal is not running")
  138. if self._event_loop_thread_id == get_ident():
  139. raise RuntimeError(
  140. "This method cannot be called from the event loop thread"
  141. )
  142. async def sleep_until_stopped(self) -> None:
  143. """Sleep until :meth:`stop` is called."""
  144. await self._stop_event.wait()
  145. async def stop(self, cancel_remaining: bool = False) -> None:
  146. """
  147. Signal the portal to shut down.
  148. This marks the portal as no longer accepting new calls and exits from
  149. :meth:`sleep_until_stopped`.
  150. :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
  151. to let them finish before returning
  152. """
  153. self._event_loop_thread_id = None
  154. self._stop_event.set()
  155. if cancel_remaining:
  156. self._task_group.cancel_scope.cancel()
  157. async def _call_func(
  158. self,
  159. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  160. args: tuple[Unpack[PosArgsT]],
  161. kwargs: dict[str, Any],
  162. future: Future[T_Retval],
  163. ) -> None:
  164. def callback(f: Future[T_Retval]) -> None:
  165. if f.cancelled() and self._event_loop_thread_id not in (
  166. None,
  167. get_ident(),
  168. ):
  169. self.call(scope.cancel)
  170. try:
  171. retval_or_awaitable = func(*args, **kwargs)
  172. if isawaitable(retval_or_awaitable):
  173. with CancelScope() as scope:
  174. if future.cancelled():
  175. scope.cancel()
  176. else:
  177. future.add_done_callback(callback)
  178. retval = await retval_or_awaitable
  179. else:
  180. retval = retval_or_awaitable
  181. except self._cancelled_exc_class:
  182. future.cancel()
  183. future.set_running_or_notify_cancel()
  184. except BaseException as exc:
  185. if not future.cancelled():
  186. future.set_exception(exc)
  187. # Let base exceptions fall through
  188. if not isinstance(exc, Exception):
  189. raise
  190. else:
  191. if not future.cancelled():
  192. future.set_result(retval)
  193. finally:
  194. scope = None # type: ignore[assignment]
  195. def _spawn_task_from_thread(
  196. self,
  197. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  198. args: tuple[Unpack[PosArgsT]],
  199. kwargs: dict[str, Any],
  200. name: object,
  201. future: Future[T_Retval],
  202. ) -> None:
  203. """
  204. Spawn a new task using the given callable.
  205. Implementors must ensure that the future is resolved when the task finishes.
  206. :param func: a callable
  207. :param args: positional arguments to be passed to the callable
  208. :param kwargs: keyword arguments to be passed to the callable
  209. :param name: name of the task (will be coerced to a string if not ``None``)
  210. :param future: a future that will resolve to the return value of the callable,
  211. or the exception raised during its execution
  212. """
  213. raise NotImplementedError
  214. @overload
  215. def call(
  216. self,
  217. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  218. *args: Unpack[PosArgsT],
  219. ) -> T_Retval: ...
  220. @overload
  221. def call(
  222. self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  223. ) -> T_Retval: ...
  224. def call(
  225. self,
  226. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  227. *args: Unpack[PosArgsT],
  228. ) -> T_Retval:
  229. """
  230. Call the given function in the event loop thread.
  231. If the callable returns a coroutine object, it is awaited on.
  232. :param func: any callable
  233. :raises RuntimeError: if the portal is not running or if this method is called
  234. from within the event loop thread
  235. """
  236. return cast(T_Retval, self.start_task_soon(func, *args).result())
  237. @overload
  238. def start_task_soon(
  239. self,
  240. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  241. *args: Unpack[PosArgsT],
  242. name: object = None,
  243. ) -> Future[T_Retval]: ...
  244. @overload
  245. def start_task_soon(
  246. self,
  247. func: Callable[[Unpack[PosArgsT]], T_Retval],
  248. *args: Unpack[PosArgsT],
  249. name: object = None,
  250. ) -> Future[T_Retval]: ...
  251. def start_task_soon(
  252. self,
  253. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  254. *args: Unpack[PosArgsT],
  255. name: object = None,
  256. ) -> Future[T_Retval]:
  257. """
  258. Start a task in the portal's task group.
  259. The task will be run inside a cancel scope which can be cancelled by cancelling
  260. the returned future.
  261. :param func: the target function
  262. :param args: positional arguments passed to ``func``
  263. :param name: name of the task (will be coerced to a string if not ``None``)
  264. :return: a future that resolves with the return value of the callable if the
  265. task completes successfully, or with the exception raised in the task
  266. :raises RuntimeError: if the portal is not running or if this method is called
  267. from within the event loop thread
  268. :rtype: concurrent.futures.Future[T_Retval]
  269. .. versionadded:: 3.0
  270. """
  271. self._check_running()
  272. f: Future[T_Retval] = Future()
  273. self._spawn_task_from_thread(func, args, {}, name, f)
  274. return f
  275. def start_task(
  276. self,
  277. func: Callable[..., Awaitable[T_Retval]],
  278. *args: object,
  279. name: object = None,
  280. ) -> tuple[Future[T_Retval], Any]:
  281. """
  282. Start a task in the portal's task group and wait until it signals for readiness.
  283. This method works the same way as :meth:`.abc.TaskGroup.start`.
  284. :param func: the target function
  285. :param args: positional arguments passed to ``func``
  286. :param name: name of the task (will be coerced to a string if not ``None``)
  287. :return: a tuple of (future, task_status_value) where the ``task_status_value``
  288. is the value passed to ``task_status.started()`` from within the target
  289. function
  290. :rtype: tuple[concurrent.futures.Future[T_Retval], Any]
  291. .. versionadded:: 3.0
  292. """
  293. def task_done(future: Future[T_Retval]) -> None:
  294. if not task_status_future.done():
  295. if future.cancelled():
  296. task_status_future.cancel()
  297. elif future.exception():
  298. task_status_future.set_exception(future.exception())
  299. else:
  300. exc = RuntimeError(
  301. "Task exited without calling task_status.started()"
  302. )
  303. task_status_future.set_exception(exc)
  304. self._check_running()
  305. task_status_future: Future = Future()
  306. task_status = _BlockingPortalTaskStatus(task_status_future)
  307. f: Future = Future()
  308. f.add_done_callback(task_done)
  309. self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
  310. return f, task_status_future.result()
  311. def wrap_async_context_manager(
  312. self, cm: AsyncContextManager[T_co]
  313. ) -> ContextManager[T_co]:
  314. """
  315. Wrap an async context manager as a synchronous context manager via this portal.
  316. Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
  317. in the middle until the synchronous context manager exits.
  318. :param cm: an asynchronous context manager
  319. :return: a synchronous context manager
  320. .. versionadded:: 2.1
  321. """
  322. return _BlockingAsyncContextManager(cm, self)
  323. @dataclass
  324. class BlockingPortalProvider:
  325. """
  326. A manager for a blocking portal. Used as a context manager. The first thread to
  327. enter this context manager causes a blocking portal to be started with the specific
  328. parameters, and the last thread to exit causes the portal to be shut down. Thus,
  329. there will be exactly one blocking portal running in this context as long as at
  330. least one thread has entered this context manager.
  331. The parameters are the same as for :func:`~anyio.run`.
  332. :param backend: name of the backend
  333. :param backend_options: backend options
  334. .. versionadded:: 4.4
  335. """
  336. backend: str = "asyncio"
  337. backend_options: dict[str, Any] | None = None
  338. _lock: Lock = field(init=False, default_factory=Lock)
  339. _leases: int = field(init=False, default=0)
  340. _portal: BlockingPortal = field(init=False)
  341. _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
  342. init=False, default=None
  343. )
  344. def __enter__(self) -> BlockingPortal:
  345. with self._lock:
  346. if self._portal_cm is None:
  347. self._portal_cm = start_blocking_portal(
  348. self.backend, self.backend_options
  349. )
  350. self._portal = self._portal_cm.__enter__()
  351. self._leases += 1
  352. return self._portal
  353. def __exit__(
  354. self,
  355. exc_type: type[BaseException] | None,
  356. exc_val: BaseException | None,
  357. exc_tb: TracebackType | None,
  358. ) -> None:
  359. portal_cm: AbstractContextManager[BlockingPortal] | None = None
  360. with self._lock:
  361. assert self._portal_cm
  362. assert self._leases > 0
  363. self._leases -= 1
  364. if not self._leases:
  365. portal_cm = self._portal_cm
  366. self._portal_cm = None
  367. del self._portal
  368. if portal_cm:
  369. portal_cm.__exit__(None, None, None)
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        # Main task of the new event loop: publish the portal to the waiting
        # thread, then keep the loop alive until the portal is stopped.
        # ``future`` is the closure variable assigned further below.
        async with BlockingPortal() as portal_:
            future.set_result(portal_)
            await portal_.sleep_until_stopped()

    def run_blocking_portal() -> None:
        # Thread target; does nothing if the future was cancelled before it ran
        if future.set_running_or_notify_cancel():
            try:
                _eventloop.run(
                    run_portal, backend=backend, backend_options=backend_options
                )
            except BaseException as exc:
                # Report a startup failure to the thread blocked on future.result().
                # NOTE(review): an error raised after the portal was published is
                # silently dropped here -- confirm this is intended.
                if not future.done():
                    future.set_exception(exc)

    # Resolved with the portal once the event loop thread has started it
    future: Future[BlockingPortal] = Future()
    thread = Thread(target=run_blocking_portal, daemon=True)
    thread.start()
    try:
        cancel_remaining_tasks = False
        portal = future.result()
        try:
            yield portal
        except BaseException:
            # The with-block raised: cancel the portal's remaining tasks on shutdown
            cancel_remaining_tasks = True
            raise
        finally:
            try:
                portal.call(portal.stop, cancel_remaining_tasks)
            except RuntimeError:
                # call() raises RuntimeError e.g. when the portal is no longer
                # running; it is then already stopped, so there is nothing to do
                pass
    finally:
        # Always wait for the event loop thread, including when startup failed
        thread.join()
  414. def check_cancelled() -> None:
  415. """
  416. Check if the cancel scope of the host task's running the current worker thread has
  417. been cancelled.
  418. If the host task's current cancel scope has indeed been cancelled, the
  419. backend-specific cancellation exception will be raised.
  420. :raises RuntimeError: if the current thread was not spawned by
  421. :func:`.to_thread.run_sync`
  422. """
  423. try:
  424. async_backend: AsyncBackend = threadlocals.current_async_backend
  425. except AttributeError:
  426. raise RuntimeError(
  427. "This function can only be run from an AnyIO worker thread"
  428. ) from None
  429. async_backend.check_cancelled()