authentication.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from __future__ import annotations
  2. import functools
  3. import inspect
  4. import sys
  5. import typing
  6. from urllib.parse import urlencode
  7. if sys.version_info >= (3, 10): # pragma: no cover
  8. from typing import ParamSpec
  9. else: # pragma: no cover
  10. from typing_extensions import ParamSpec
  11. from starlette._utils import is_async_callable
  12. from starlette.exceptions import HTTPException
  13. from starlette.requests import HTTPConnection, Request
  14. from starlette.responses import RedirectResponse
  15. from starlette.websockets import WebSocket
# Parameter specification for endpoints wrapped by ``requires``: it lets the
# returned wrapper advertise the wrapped function's parameters to type checkers.
_P = ParamSpec("_P")
  17. def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
  18. for scope in scopes:
  19. if scope not in conn.auth.scopes:
  20. return False
  21. return True
  22. def requires(
  23. scopes: str | typing.Sequence[str],
  24. status_code: int = 403,
  25. redirect: str | None = None,
  26. ) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
  27. scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
  28. def decorator(
  29. func: typing.Callable[_P, typing.Any],
  30. ) -> typing.Callable[_P, typing.Any]:
  31. sig = inspect.signature(func)
  32. for idx, parameter in enumerate(sig.parameters.values()):
  33. if parameter.name == "request" or parameter.name == "websocket":
  34. type_ = parameter.name
  35. break
  36. else:
  37. raise Exception(f'No "request" or "websocket" argument on function "{func}"')
  38. if type_ == "websocket":
  39. # Handle websocket functions. (Always async)
  40. @functools.wraps(func)
  41. async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
  42. websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
  43. assert isinstance(websocket, WebSocket)
  44. if not has_required_scope(websocket, scopes_list):
  45. await websocket.close()
  46. else:
  47. await func(*args, **kwargs)
  48. return websocket_wrapper
  49. elif is_async_callable(func):
  50. # Handle async request/response functions.
  51. @functools.wraps(func)
  52. async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
  53. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  54. assert isinstance(request, Request)
  55. if not has_required_scope(request, scopes_list):
  56. if redirect is not None:
  57. orig_request_qparam = urlencode({"next": str(request.url)})
  58. next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
  59. return RedirectResponse(url=next_url, status_code=303)
  60. raise HTTPException(status_code=status_code)
  61. return await func(*args, **kwargs)
  62. return async_wrapper
  63. else:
  64. # Handle sync request/response functions.
  65. @functools.wraps(func)
  66. def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
  67. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  68. assert isinstance(request, Request)
  69. if not has_required_scope(request, scopes_list):
  70. if redirect is not None:
  71. orig_request_qparam = urlencode({"next": str(request.url)})
  72. next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
  73. return RedirectResponse(url=next_url, status_code=303)
  74. raise HTTPException(status_code=status_code)
  75. return func(*args, **kwargs)
  76. return sync_wrapper
  77. return decorator
  78. class AuthenticationError(Exception):
  79. pass
  80. class AuthenticationBackend:
  81. async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
  82. raise NotImplementedError() # pragma: no cover
  83. class AuthCredentials:
  84. def __init__(self, scopes: typing.Sequence[str] | None = None):
  85. self.scopes = [] if scopes is None else list(scopes)
  86. class BaseUser:
  87. @property
  88. def is_authenticated(self) -> bool:
  89. raise NotImplementedError() # pragma: no cover
  90. @property
  91. def display_name(self) -> str:
  92. raise NotImplementedError() # pragma: no cover
  93. @property
  94. def identity(self) -> str:
  95. raise NotImplementedError() # pragma: no cover
  96. class SimpleUser(BaseUser):
  97. def __init__(self, username: str) -> None:
  98. self.username = username
  99. @property
  100. def is_authenticated(self) -> bool:
  101. return True
  102. @property
  103. def display_name(self) -> str:
  104. return self.username
  105. class UnauthenticatedUser(BaseUser):
  106. @property
  107. def is_authenticated(self) -> bool:
  108. return False
  109. @property
  110. def display_name(self) -> str:
  111. return ""