78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
import typing
|
|
|
|
from starlette._exception_handler import (
|
|
ExceptionHandlers,
|
|
StatusHandlers,
|
|
wrap_app_handling_exceptions,
|
|
)
|
|
from starlette.exceptions import HTTPException, WebSocketException
|
|
from starlette.requests import Request
|
|
from starlette.responses import PlainTextResponse, Response
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
from starlette.websockets import WebSocket
|
|
|
|
|
|
class ExceptionMiddleware:
    """ASGI middleware that keeps per-app exception and status-code handlers.

    Handlers are stored in two tables: one keyed by integer status code and
    one keyed by exception class. For ``http`` and ``websocket`` scopes both
    tables are published in the scope under ``"starlette.exception_handlers"``
    and the wrapped app is run through ``wrap_app_handling_exceptions``.
    Every other scope type is passed to the wrapped app untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Mapping[
            typing.Any, typing.Callable[[Request, Exception], Response]
        ]
        | None = None,
        debug: bool = False,
    ) -> None:
        """Store the wrapped app and register the default and custom handlers.

        Args:
            app: The ASGI application to wrap.
            handlers: Optional mapping of status code or exception class to
                handler; each entry is registered via
                ``add_exception_handler``.
            debug: Stored on the instance; not otherwise used in this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: StatusHandlers = {}
        # Defaults are installed first so that user-supplied handlers for the
        # same exception classes overwrite them below.
        self._exception_handlers: ExceptionHandlers = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        Integer keys go into the status-code table; anything else must be an
        ``Exception`` subclass and goes into the exception-class table. An
        existing entry for the same key is replaced.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Scopes other than ``http`` and ``websocket`` are forwarded to the
        wrapped app unchanged. Otherwise the handler tables are written into
        ``scope`` and the app is invoked through
        ``wrap_app_handling_exceptions`` with a ``Request`` or ``WebSocket``
        built from the connection.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # Publish both tables as a (exception_handlers, status_handlers) tuple.
        scope["starlette.exception_handlers"] = (
            self._exception_handlers,
            self._status_handlers,
        )

        conn: Request | WebSocket
        if scope["type"] == "http":
            conn = Request(scope, receive, send)
        else:
            conn = WebSocket(scope, receive, send)

        await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

    def http_exception(self, request: Request, exc: Exception) -> Response:
        """Default handler for ``HTTPException``.

        Returns a plain ``Response`` without content for status codes 204 and
        304, and a ``PlainTextResponse`` carrying ``exc.detail`` for every
        other status code. ``exc.headers`` is forwarded in both cases.
        """
        assert isinstance(exc, HTTPException)
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
        """Default handler for ``WebSocketException``.

        Closes ``websocket`` with the exception's ``code`` and ``reason``.
        """
        assert isinstance(exc, WebSocketException)
        await websocket.close(code=exc.code, reason=exc.reason)  # pragma: no cover
|