| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- from __future__ import annotations
- import functools
- import inspect
- import sys
- import typing
- from urllib.parse import urlencode
- if sys.version_info >= (3, 10): # pragma: no cover
- from typing import ParamSpec
- else: # pragma: no cover
- from typing_extensions import ParamSpec
- from starlette._utils import is_async_callable
- from starlette.exceptions import HTTPException
- from starlette.requests import HTTPConnection, Request
- from starlette.responses import RedirectResponse
- from starlette.websockets import WebSocket
# ParamSpec used in ``requires``' signatures so the wrapper it returns is typed
# with the same parameters as the decorated endpoint.
_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Report whether ``conn.auth.scopes`` contains every entry of *scopes*.

    Checking stops at the first missing scope. An empty *scopes* is trivially
    satisfied, and in that case ``conn.auth`` is never accessed.
    """
    return all(required in conn.auth.scopes for required in scopes)
def requires(
    scopes: str | typing.Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
    """Restrict an endpoint to connections holding every scope in *scopes*.

    The decorated function must have a parameter named ``request`` or
    ``websocket``. The first such parameter is located by name at decoration
    time. At call time it is read from the keyword arguments, falling back to
    the matching positional slot.

    Args:
        scopes: One scope name, or a sequence of scope names that are all required.
        status_code: Status code for the ``HTTPException`` raised when an HTTP
            request lacks a scope and *redirect* is not set.
        redirect: Optional route name. When set, an HTTP request lacking a scope
            gets a 303 redirect to ``url_for(redirect)``, with the original
            request URL passed as the ``next`` query parameter.

    Returns:
        A decorator that handles three kinds of endpoint:

        - Websocket endpoints, which are always awaited. A missing scope closes
          the socket.
        - Async HTTP endpoints.
        - Sync HTTP endpoints.

    Raises:
        Exception: At decoration time, if the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def reject(request: Request) -> RedirectResponse:
        """Handle an HTTP request that lacks a required scope: redirect or raise."""
        if redirect is None:
            raise HTTPException(status_code=status_code)
        orig_request_qparam = urlencode({"next": str(request.url)})
        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
        return RedirectResponse(url=next_url, status_code=303)

    def decorator(
        func: typing.Callable[_P, typing.Any],
    ) -> typing.Callable[_P, typing.Any]:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(f'No "request" or "websocket" argument on function "{func}"')

        def connection(args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]) -> typing.Any:
            # A keyword argument wins; otherwise use the positional slot found above
            # (None if the caller passed fewer positional arguments).
            return kwargs.get(type_, args[idx] if idx < len(args) else None)

        if type_ == "websocket":
            # Websocket endpoints are always async; a missing scope closes the socket.
            @functools.wraps(func)
            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                websocket = connection(args, kwargs)
                assert isinstance(websocket, WebSocket)
                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        if is_async_callable(func):
            # Async request/response endpoint.
            @functools.wraps(func)
            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                request = connection(args, kwargs)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return reject(request)
                return await func(*args, **kwargs)

            return async_wrapper

        # Sync request/response endpoint.
        @functools.wraps(func)
        def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
            request = connection(args, kwargs)
            assert isinstance(request, Request)
            if not has_required_scope(request, scopes_list):
                return reject(request)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
class AuthenticationError(Exception):
    """Authentication-related error; adds nothing beyond a plain ``Exception``."""
class AuthenticationBackend:
    """Interface for authentication backends.

    ``authenticate`` is not implemented here and is meant to be overridden.
    """

    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        """Authenticate *conn*.

        Per the annotation, an implementation returns a
        ``(credentials, user)`` pair or ``None``. This base version always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError  # pragma: no cover
class AuthCredentials:
    """Holds the list of scope names granted to a connection."""

    def __init__(self, scopes: str | typing.Sequence[str] | None = None) -> None:
        """Store *scopes* as a fresh list in ``self.scopes``.

        Args:
            scopes: The granted scope names. ``None`` means no scopes. A single
                string is taken as one scope name, the same normalization
                ``requires`` applies to its ``scopes`` argument.
        """
        if scopes is None:
            self.scopes = []
        elif isinstance(scopes, str):
            # A str is itself a Sequence[str]; list("admin") would silently
            # produce the single-character scopes ['a', 'd', 'm', 'i', 'n'].
            self.scopes = [scopes]
        else:
            # Copy so later mutation of the caller's sequence doesn't leak in.
            self.scopes = list(scopes)
class BaseUser:
    """Interface for user objects.

    Every property here raises ``NotImplementedError``; subclasses override
    the ones they support.
    """

    @property
    def is_authenticated(self) -> bool:
        """Whether the user is authenticated (must be overridden)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Name to show for the user (must be overridden)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identifier string for the user (must be overridden)."""
        raise NotImplementedError  # pragma: no cover
class SimpleUser(BaseUser):
    """Always-authenticated user described only by a username.

    NOTE(review): ``identity`` is not overridden here, so the inherited
    ``BaseUser.identity`` is what callers get.
    """

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def display_name(self) -> str:
        """The username given at construction time."""
        return self.username

    @property
    def is_authenticated(self) -> bool:
        """Unconditionally ``True``."""
        return True
class UnauthenticatedUser(BaseUser):
    """User placeholder that is never authenticated and has an empty display name."""

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""

    @property
    def is_authenticated(self) -> bool:
        """Unconditionally ``False``."""
        return False
|