_utils.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from __future__ import annotations
  2. import asyncio
  3. import functools
  4. import sys
  5. import typing
  6. from contextlib import contextmanager
  7. from starlette.types import Scope
  8. if sys.version_info >= (3, 10): # pragma: no cover
  9. from typing import TypeGuard
  10. else: # pragma: no cover
  11. from typing_extensions import TypeGuard
  # Exception groups are builtin from Python 3.11 onwards. Older interpreters
  # need the optional ``exceptiongroup`` backport; without it the flag is False.
  if sys.version_info >= (3, 11):  # pragma: no cover
      has_exceptiongroups = True
  else:  # pragma: no cover
      try:
          from exceptiongroup import BaseExceptionGroup
      except ImportError:
          has_exceptiongroups = False
      else:
          has_exceptiongroups = True

  T = typing.TypeVar("T")
  # A callable taking any arguments whose result is an awaitable resolving to T.
  AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
  @typing.overload
  def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...


  @typing.overload
  def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...


  def is_async_callable(obj: typing.Any) -> typing.Any:
      """Return True if calling *obj* produces an awaitable.

      Recognised as async:
        * coroutine functions and bound async methods;
        * ``functools.partial`` wrappers (nested to any depth) around any
          of the accepted callables;
        * instances whose class defines ``async def __call__``.

      A *class* that defines ``async def __call__`` is NOT async. Calling
      the class constructs an instance synchronously. The fallback therefore
      inspects ``type(obj).__call__``, the method Python actually invokes
      for ``obj(...)``, rather than the attribute ``obj.__call__``.
      """
      # Peel off (possibly nested) partials to reach the underlying callable.
      while isinstance(obj, functools.partial):
          obj = obj.func

      # NOTE(review): asyncio.iscoroutinefunction is deprecated starting with
      # Python 3.14 in favour of inspect.iscoroutinefunction. Switching may
      # change detection of objects carrying asyncio's legacy coroutine
      # marker, so confirm before migrating.
      if asyncio.iscoroutinefunction(obj):
          return True
      return callable(obj) and asyncio.iscoroutinefunction(type(obj).__call__)
  # Covariant type variable for the value produced by the protocol below.
  T_co = typing.TypeVar("T_co", covariant=True)


  class AwaitableOrContextManager(
      typing.Awaitable[T_co],
      typing.AsyncContextManager[T_co],
      typing.Protocol[T_co],
  ):
      """Structural type for objects usable as ``await x`` or ``async with x as y``.

      Awaiting the object yields a ``T_co``, and ``__aenter__`` also returns
      a ``T_co``.
      """
  class SupportsAsyncClose(typing.Protocol):
      """Structural type for any object exposing an awaitable ``close()`` method."""

      async def close(self) -> None: ...  # pragma: no cover


  # Invariant TypeVar restricted to types that satisfy SupportsAsyncClose.
  SupportsAsyncCloseType = typing.TypeVar(
      "SupportsAsyncCloseType",
      bound=SupportsAsyncClose,
      covariant=False,
  )
  class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
      """Let one awaitable be consumed either via ``await`` or via ``async with``.

      ``await wrapper`` delegates straight to the wrapped awaitable.
      ``async with wrapper as obj`` awaits it on entry and remembers the
      result, then awaits ``obj.close()`` on exit. ``__aexit__`` returns
      ``None``, so exceptions raised in the body are not suppressed.

      NOTE(review): if ``aw`` is a coroutine object it can only be awaited
      once, so a given wrapper should be used one way or the other, not both.
      """

      __slots__ = ("aw", "entered")

      def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
          # Store the awaitable untouched; nothing runs until the wrapper is used.
          self.aw = aw

      def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
          # Reuse the wrapped awaitable's own iterator so `await wrapper`
          # behaves exactly like `await aw`.
          wrapped = self.aw
          return wrapped.__await__()

      async def __aenter__(self) -> SupportsAsyncCloseType:
          resource = await self.aw
          # Keep the result so __aexit__ knows what to close.
          self.entered = resource
          return resource

      async def __aexit__(self, *args: typing.Any) -> None | bool:
          # Close unconditionally; a None return lets any body exception propagate.
          resource = self.entered
          await resource.close()
          return None
  @contextmanager
  def collapse_excgroups() -> typing.Generator[None, None, None]:
      """Context manager that unwraps single-member exception groups.

      Any exception escaping the ``with`` body is re-raised. When
      ``has_exceptiongroups`` is true, the following groups are peeled away
      and their sole member is raised instead:
        * a ``BaseExceptionGroup`` holding exactly one exception;
        * nested chains of such single-member groups, down to the innermost
          exception.
      Groups with several members are re-raised unchanged.
      """
      try:
          yield
      except BaseException as error:
          if has_exceptiongroups:  # pragma: no cover
              while isinstance(error, BaseExceptionGroup) and len(error.exceptions) == 1:
                  (error,) = error.exceptions
          raise error
  def get_route_path(scope: Scope) -> str:
      """Return ``scope["path"]`` with the application's ``root_path`` removed.

      The prefix is stripped only on a path-segment boundary:
        * a path equal to the root path yields ``""``;
        * a path continuing with ``"/"`` after the root path yields that
          remainder.
      In every other case the path is returned unchanged:
        * a missing or empty root path;
        * a path outside the root path;
        * a path that merely shares a textual prefix (e.g. ``"/apiv2"``
          under ``"/api"``).
      """
      full_path: str = scope["path"]
      prefix = scope.get("root_path", "")
      if prefix and full_path.startswith(prefix):
          remainder = full_path[len(prefix):]
          # Segment boundary: exact match, or the next character is "/".
          if remainder == "" or remainder.startswith("/"):
              return remainder
      return full_path