background.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from __future__ import annotations
  2. import sys
  3. import typing
  4. if sys.version_info >= (3, 10): # pragma: no cover
  5. from typing import ParamSpec
  6. else: # pragma: no cover
  7. from typing_extensions import ParamSpec
  8. from starlette._utils import is_async_callable
  9. from starlette.concurrency import run_in_threadpool
  # ParamSpec that ties the extra positional/keyword arguments accepted by
  # BackgroundTask.__init__ and BackgroundTasks.add_task to the signature of
  # the callable they are later forwarded to, so type checkers can verify them.
  P = ParamSpec("P")
  class BackgroundTask:
      """Bundle a callable with the arguments it will be invoked with.

      Awaiting the instance performs the call. Callables classified as async by
      ``is_async_callable`` are awaited directly; anything else is handed to
      ``run_in_threadpool``.
      """

      def __init__(
          self,
          func: typing.Callable[P, typing.Any],
          *args: P.args,
          **kwargs: P.kwargs,
      ) -> None:
          """Store *func* and its arguments; nothing is called yet.

          Args:
              func: The callable to run later.
              *args: Positional arguments forwarded to *func*.
              **kwargs: Keyword arguments forwarded to *func*.
          """
          self.func, self.args, self.kwargs = func, args, kwargs
          # Classify the callable once, up front, rather than on every run.
          self.is_async = is_async_callable(func)

      async def __call__(self) -> None:
          """Invoke the stored callable with the stored arguments.

          The callable's return value is discarded, and any exception it raises
          propagates to the awaiting caller.
          """
          if not self.is_async:
              # Synchronous callables go through the thread-pool helper.
              await run_in_threadpool(self.func, *self.args, **self.kwargs)
              return
          await self.func(*self.args, **self.kwargs)
  class BackgroundTasks(BackgroundTask):
      """An ordered list of BackgroundTask objects, awaited one at a time.

      Subclasses BackgroundTask but never calls its ``__init__``. As a result,
      the ``func``/``args``/``kwargs``/``is_async`` attributes are not set on
      instances; only ``tasks`` is. ``__call__`` is overridden to run the
      collected tasks instead.
      """

      def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
          """Start from a fresh list copy of *tasks*, or an empty list when it is None or empty."""
          self.tasks = [] if not tasks else list(tasks)

      def add_task(
          self,
          func: typing.Callable[P, typing.Any],
          *args: P.args,
          **kwargs: P.kwargs,
      ) -> None:
          """Wrap *func* and its arguments in a BackgroundTask and append it."""
          self.tasks.append(BackgroundTask(func, *args, **kwargs))

      async def __call__(self) -> None:
          """Await every queued task in insertion order.

          Tasks run sequentially. If one raises, the exception propagates and
          the tasks after it are not run.
          """
          for queued in self.tasks:
              await queued()