from __future__ import annotations

import typing

from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    AuthenticationError,
    UnauthenticatedUser,
)
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send


class AuthenticationMiddleware:
    """ASGI middleware that runs an authentication backend per connection.

    For ``http`` and ``websocket`` scopes the backend's ``authenticate`` is
    awaited and its result is stored in ``scope["auth"]`` and
    ``scope["user"]`` before the wrapped application is called. Every other
    scope type is forwarded to the wrapped application untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response]
        | None = None,
    ) -> None:
        """Store the wrapped app, the backend and the error handler.

        Args:
            app: The ASGI application to call after authentication.
            backend: Object whose ``authenticate(conn)`` coroutine is awaited
                for each connection.
            on_error: Callable invoked with the connection and the raised
                ``AuthenticationError``; falls back to ``default_on_error``
                when ``None``.
        """
        self.app = app
        self.backend = backend

        handler: typing.Callable[[HTTPConnection, AuthenticationError], Response]
        if on_error is None:
            handler = self.default_on_error
        else:
            handler = on_error
        self.on_error: typing.Callable[
            [HTTPConnection, AuthenticationError], Response
        ] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Authenticate the connection, then delegate to the wrapped app.

        On ``AuthenticationError`` the wrapped app is never called: a
        websocket receives a ``websocket.close`` message with code 1000,
        anything else receives the response produced by ``on_error``.
        """
        if scope["type"] not in ("http", "websocket"):
            # Nothing to authenticate for other scope types.
            await self.app(scope, receive, send)
            return

        connection = HTTPConnection(scope)
        try:
            outcome = await self.backend.authenticate(connection)
        except AuthenticationError as error:
            # The handler is invoked for both scope types, although its
            # result is only sent on the non-websocket path.
            failure = self.on_error(connection, error)
            if scope["type"] != "websocket":
                await failure(scope, receive, send)
            else:
                await send({"type": "websocket.close", "code": 1000})
            return

        # A backend returning None means "no credentials supplied": fall
        # back to empty credentials and an unauthenticated user.
        credentials, user = (
            outcome
            if outcome is not None
            else (AuthCredentials(), UnauthenticatedUser())
        )
        scope["auth"] = credentials
        scope["user"] = user
        await self.app(scope, receive, send)

    @staticmethod
    def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
        """Return a plain-text 400 response whose body is ``str(exc)``."""
        message = str(exc)
        return PlainTextResponse(message, status_code=400)