from __future__ import annotations import typing from starlette.datastructures import URL, Headers from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." class TrustedHostMiddleware: def __init__( self, app: ASGIApp, allowed_hosts: typing.Sequence[str] | None = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: allowed_hosts = ["*"] for pattern in allowed_hosts: assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD if pattern.startswith("*") and pattern != "*": assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD self.app = app self.allowed_hosts = list(allowed_hosts) self.allow_any = "*" in allowed_hosts self.www_redirect = www_redirect async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.allow_any or scope["type"] not in ( "http", "websocket", ): # pragma: no cover await self.app(scope, receive, send) return headers = Headers(scope=scope) host = headers.get("host", "").split(":")[0] is_valid_host = False found_www_redirect = False for pattern in self.allowed_hosts: if host == pattern or ( pattern.startswith("*") and host.endswith(pattern[1:]) ): is_valid_host = True break elif "www." + host == pattern: found_www_redirect = True if is_valid_host: await self.app(scope, receive, send) else: response: Response if found_www_redirect and self.www_redirect: url = URL(scope=scope) redirect_url = url.replace(netloc="www." + url.netloc) response = RedirectResponse(url=str(redirect_url)) else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send)