218 lines
8.6 KiB
Python
218 lines
8.6 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import typing
|
||
|
|
||
|
import anyio
|
||
|
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||
|
|
||
|
from starlette._utils import collapse_excgroups
|
||
|
from starlette.background import BackgroundTask
|
||
|
from starlette.requests import ClientDisconnect, Request
|
||
|
from starlette.responses import ContentStream, Response, StreamingResponse
|
||
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||
|
|
||
|
# Generic result type; used by helpers that await a zero-argument callable
# and hand back whatever it produced.
T = typing.TypeVar("T")

# Signature of ``call_next``: takes the request, resolves to a response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# Signature of a dispatch function: receives the request plus ``call_next``
# and resolves to the response that should be sent.
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
|
||
|
|
||
|
|
||
|
class _CachedRequest(Request):
    """
    Request wrapper shared between the dispatch function and the downstream app.

    When the dispatch function reads the body through Request.body(), the
    whole body is kept in memory and replayed to downstream middlewares.
    When it drains Request.stream() instead, downstream only gets an empty
    body, which is enough to keep it from hanging forever.
    """

    def __init__(self, scope: Scope, receive: Receive):
        super().__init__(scope, receive)
        # True once an "http.disconnect" generated here has been handed downstream.
        self._wrapped_rcv_disconnected = False
        # True once the final "http.request" message has been handed downstream.
        self._wrapped_rcv_consumed = False
        self._wrapped_rc_stream = self.stream()

    async def wrapped_receive(self) -> Message:
        """
        ASGI ``receive`` callable handed to the downstream app.

        Returns "http.request" messages built from the cached body or from
        the request stream, followed by "http.disconnect" once the body has
        been delivered.

        Raises RuntimeError if, after the body was fully delivered, the
        underlying receive yields anything other than "http.disconnect".
        """
        # State 1: a disconnect was already reported downstream.
        # No need to wait for another one (most ASGI servers would simply
        # keep sending it anyway).
        if self._wrapped_rcv_disconnected:
            return {"type": "http.disconnect"}

        # State 2: body fully delivered, disconnect not yet reported.
        # The only thing left to hand over is the disconnect itself.
        if self._wrapped_rcv_consumed:
            if not self._is_disconnected:
                # Unknown whether the client is gone: block until the
                # server tells us.
                msg = await self.receive()
                if msg["type"] != "http.disconnect":  # pragma: no cover
                    # Only a disconnect is legal at this point; anything
                    # else means something went wrong upstream.
                    raise RuntimeError(f"Unexpected message received: {msg['type']}")
                return msg
            # The middleware already observed the disconnect, so there is
            # nothing to wait for.
            self._wrapped_rcv_disconnected = True
            return {"type": "http.disconnect"}

        # State 3: body not yet delivered downstream.
        if getattr(self, "_body", None) is not None:
            # body() was called: replay it, even if the client disconnected.
            self._wrapped_rcv_consumed = True
            return {"type": "http.request", "body": self._body, "more_body": False}

        if self._stream_consumed:
            # stream() was run to completion: send an empty body so the
            # downstream app does not sit waiting for a disconnect.
            self._wrapped_rcv_consumed = True
            return {"type": "http.request", "body": b"", "more_body": False}

        # Neither body() nor a complete stream(): pull the next chunk now.
        body_iter = self.stream()
        try:
            chunk = await body_iter.__anext__()
        except ClientDisconnect:
            self._wrapped_rcv_disconnected = True
            return {"type": "http.disconnect"}

        finished = self._stream_consumed
        self._wrapped_rcv_consumed = finished
        return {"type": "http.request", "body": chunk, "more_body": not finished}
|
||
|
|
||
|
|
||
|
class BaseHTTPMiddleware:
    """
    ASGI middleware base class built around a request/response ``dispatch`` hook.

    For "http" scopes the wrapped app is invoked through ``call_next`` and its
    ASGI messages are turned into a streaming response object that the
    dispatch function can inspect or replace. Other scope types are passed
    straight through to the wrapped app.
    """

    def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
        """
        Store the wrapped app and select the dispatch function.

        ``dispatch`` overrides the ``dispatch`` method when given; otherwise
        the (subclass-provided) method is used.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        ASGI entry point.

        Raises whatever the dispatch function or the wrapped app raises, and
        RuntimeError("No response returned.") if the app finishes without
        sending any message.
        """
        if scope["type"] != "http":
            # Only HTTP requests go through dispatch; everything else is forwarded as-is.
            await self.app(scope, receive, send)
            return

        request = _CachedRequest(scope, receive)
        wrapped_receive = request.wrapped_receive
        # Set once the final response has been sent to the client (end of this method).
        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
            """
            Run the wrapped app and return its response as a streaming response.

            The app is started in the outer task group; its messages travel
            through a memory object stream and are consumed here and by
            ``body_stream``.
            """
            # Exception raised by the wrapped app, captured by ``coro`` below.
            app_exc: Exception | None = None
            send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def receive_or_disconnect() -> Message:
                """``receive`` for the wrapped app: reports a disconnect once the response is sent."""
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                # NOTE: this ``task_group`` is local to receive_or_disconnect and
                # shadows the outer one used further down in call_next.
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        # Await ``func`` and then cancel the surrounding task group,
                        # so whichever of the two waits finishes first stops the other.
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    # Race "response has been sent" against "next message from the client".
                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(wrapped_receive)

                # Re-check the event before touching ``message``: if the event won
                # the race, the receive above was presumably cancelled before
                # assigning it (relies on anyio cancel-scope semantics).
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def close_recv_stream_on_response_sent() -> None:
                # Closing the receive side makes later sends fail with
                # BrokenResourceError, which send_no_error swallows.
                await response_sent.wait()
                recv_stream.close()

            async def send_no_error(message: Message) -> None:
                """``send`` for the wrapped app: forwards into the stream, ignoring a closed receiver."""
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                """Run the wrapped app, recording (not propagating) any exception it raises."""
                nonlocal app_exc

                # The send stream is closed when the app returns or raises; the
                # receiving side handles the resulting end-of-stream below.
                async with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            # ``task_group`` here is the outer group created at the end of
            # __call__; it is bound by the time call_next is invoked from dispatch.
            task_group.start_soon(close_recv_stream_on_response_sent)
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                if message["type"] == "http.response.debug" and info is not None:
                    # A debug message carrying ``info`` precedes the real start
                    # message; keep ``info`` and read the next message.
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app finished without sending a response: surface its
                # exception if it had one, otherwise report the missing response.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                """Yield non-empty body chunks from the app until ``more_body`` is false."""
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        body = message.get("body", b"")
                        if body:
                            yield body
                        if not message.get("more_body", False):
                            break

                # An exception raised by the app after the response started is
                # re-raised once the body has been drained.
                if app_exc is not None:
                    raise app_exc

            response = _StreamingResponse(
                status_code=message["status"], content=body_stream(), info=info
            )
            # Use the app's raw headers verbatim instead of those built by the constructor.
            response.raw_headers = message["headers"]
            return response

        with collapse_excgroups():
            async with anyio.create_task_group() as task_group:
                response = await self.dispatch_func(request, call_next)
                await response(scope, wrapped_receive, send)
                # Signals the helper tasks started by call_next that the response is done.
                response_sent.set()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """
        Hook to be overridden by subclasses (or replaced via the ``dispatch`` constructor argument).

        Receives the request and ``call_next`` and must return the response
        to send. The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()  # pragma: no cover
|
||
|
|
||
|
|
||
|
class _StreamingResponse(StreamingResponse):
    """
    StreamingResponse that can carry an optional ``info`` mapping.

    When ``info`` is truthy, an "http.response.debug" message containing it
    is sent before the regular streamed response.
    """

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        info: typing.Mapping[str, typing.Any] | None = None,
    ) -> None:
        # Stored before the parent initialiser runs, as in the original ordering.
        self._info = info
        super().__init__(
            content=content,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    async def stream_response(self, send: Send) -> None:
        """Send the optional debug message, then delegate to the parent implementation."""
        debug_info = self._info
        if debug_info:
            await send({"type": "http.response.debug", "info": debug_info})
        return await super().stream_response(send)
|