63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import typing
|
|
|
|
from starlette.datastructures import URL, Headers
|
|
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
|
|
|
|
|
class TrustedHostMiddleware:
|
|
def __init__(
|
|
self,
|
|
app: ASGIApp,
|
|
allowed_hosts: typing.Sequence[str] | None = None,
|
|
www_redirect: bool = True,
|
|
) -> None:
|
|
if allowed_hosts is None:
|
|
allowed_hosts = ["*"]
|
|
|
|
for pattern in allowed_hosts:
|
|
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
|
if pattern.startswith("*") and pattern != "*":
|
|
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
|
self.app = app
|
|
self.allowed_hosts = list(allowed_hosts)
|
|
self.allow_any = "*" in allowed_hosts
|
|
self.www_redirect = www_redirect
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if self.allow_any or scope["type"] not in (
|
|
"http",
|
|
"websocket",
|
|
): # pragma: no cover
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
headers = Headers(scope=scope)
|
|
host = headers.get("host", "").split(":")[0]
|
|
is_valid_host = False
|
|
found_www_redirect = False
|
|
for pattern in self.allowed_hosts:
|
|
if host == pattern or (
|
|
pattern.startswith("*") and host.endswith(pattern[1:])
|
|
):
|
|
is_valid_host = True
|
|
break
|
|
elif "www." + host == pattern:
|
|
found_www_redirect = True
|
|
|
|
if is_valid_host:
|
|
await self.app(scope, receive, send)
|
|
else:
|
|
response: Response
|
|
if found_www_redirect and self.www_redirect:
|
|
url = URL(scope=scope)
|
|
redirect_url = url.replace(netloc="www." + url.netloc)
|
|
response = RedirectResponse(url=str(redirect_url))
|
|
else:
|
|
response = PlainTextResponse("Invalid host header", status_code=400)
|
|
await response(scope, receive, send)
|