webdriver_template/telecli/lib/python3.11/site-packages/starlette/authentication.py
2024-08-10 17:48:21 +06:00

156 lines
4.9 KiB
Python

from __future__ import annotations
import functools
import inspect
import sys
import typing
from urllib.parse import urlencode
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket
# Parameter specification variable: lets `requires` annotate its wrappers with
# the same parameters as the callable being decorated.
_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Report whether the connection's credentials include every given scope.

    Args:
        conn: Connection whose ``auth.scopes`` collection is consulted.
        scopes: Scope names that must all be present.

    Returns:
        ``True`` when each entry of ``scopes`` is found in ``conn.auth.scopes``
        (trivially ``True`` for an empty ``scopes``), otherwise ``False``.
    """
    # `all` stops at the first missing scope, and never reads `conn.auth`
    # when there is nothing to check.
    return all(required in conn.auth.scopes for required in scopes)
def requires(
    scopes: str | typing.Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> typing.Callable[
    [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
]:
    """Build a decorator that guards an endpoint behind authentication scopes.

    The decorated callable must declare a parameter named ``request`` or
    ``websocket``; that argument is looked up on every call and its
    credentials are checked with ``has_required_scope``.

    Args:
        scopes: A single scope name or a sequence of scope names, all of
            which must be present.
        status_code: Status used for the ``HTTPException`` raised when a
            request lacks the scopes and no ``redirect`` is configured.
        redirect: Optional route name. When given, a request lacking the
            scopes receives a 303 redirect to ``request.url_for(redirect)``
            with the original URL in a ``next`` query parameter.

    Returns:
        A decorator producing a wrapper around the endpoint. Websocket
        endpoints lacking the scopes have their socket closed instead.

    Raises:
        Exception: At decoration time, when the callable has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    if isinstance(scopes, str):
        required = [scopes]
    else:
        required = list(scopes)

    def reject(request: Request) -> RedirectResponse:
        # Shared denial path for the sync and async request wrappers.
        if redirect is None:
            raise HTTPException(status_code=status_code)
        query = urlencode({"next": str(request.url)})
        target = f"{request.url_for(redirect)}?{query}"
        return RedirectResponse(url=target, status_code=303)

    def decorator(
        func: typing.Callable[_P, typing.Any],
    ) -> typing.Callable[_P, typing.Any]:
        # Find the first parameter carrying the connection, remembering its
        # position so it can also be located among positional arguments.
        located = next(
            (
                (i, p.name)
                for i, p in enumerate(inspect.signature(func).parameters.values())
                if p.name in ("request", "websocket")
            ),
            None,
        )
        if located is None:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )
        position, kind = located

        if kind == "websocket":
            # Websocket endpoints are always treated as async.
            @functools.wraps(func)
            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                fallback = args[position] if position < len(args) else None
                websocket = kwargs.get("websocket", fallback)
                assert isinstance(websocket, WebSocket)
                if has_required_scope(websocket, required):
                    await func(*args, **kwargs)
                else:
                    await websocket.close()

            return websocket_wrapper

        if is_async_callable(func):
            # Async request/response endpoint.
            @functools.wraps(func)
            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                fallback = args[position] if position < len(args) else None
                request = kwargs.get("request", fallback)
                assert isinstance(request, Request)
                if has_required_scope(request, required):
                    return await func(*args, **kwargs)
                return reject(request)

            return async_wrapper

        # Sync request/response endpoint.
        @functools.wraps(func)
        def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
            fallback = args[position] if position < len(args) else None
            request = kwargs.get("request", fallback)
            assert isinstance(request, Request)
            if has_required_scope(request, required):
                return func(*args, **kwargs)
            return reject(request)

        return sync_wrapper

    return decorator
class AuthenticationError(Exception):
    """Exception type representing an authentication failure.

    NOTE(review): nothing in this module raises it; presumably authentication
    backends do — confirm against callers.
    """
class AuthenticationBackend:
    """Base class for authentication backends; subclasses override `authenticate`."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> tuple[AuthCredentials, BaseUser] | None:
        """Authenticate the given connection.

        Annotated to return a ``(credentials, user)`` pair or ``None``; this
        base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError  # pragma: no cover
class AuthCredentials:
    """Holder for the list of scopes granted to a connection."""

    def __init__(self, scopes: typing.Sequence[str] | None = None):
        """Store a fresh list copy of ``scopes`` (empty when ``None``)."""
        if scopes is None:
            granted = []
        else:
            # Copy so later changes to the caller's sequence are not shared.
            granted = list(scopes)
        self.scopes = granted
class BaseUser:
    """Interface for user objects; every property here must be overridden."""

    @property
    def is_authenticated(self) -> bool:
        """Whether the user is authenticated (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Name to display for the user (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identity string for the user (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover
class SimpleUser(BaseUser):
    """Authenticated user identified only by a username."""

    def __init__(self, username: str) -> None:
        """Remember ``username`` for later display."""
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True`` for this user type."""
        return True

    @property
    def display_name(self) -> str:
        """The username supplied at construction."""
        return self.username
class UnauthenticatedUser(BaseUser):
    """User object that reports itself as not authenticated."""

    @property
    def is_authenticated(self) -> bool:
        """Always ``False`` for this user type."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""