133 lines
5.1 KiB
Python
133 lines
5.1 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import json
|
||
|
import typing
|
||
|
|
||
|
from starlette import status
|
||
|
from starlette._utils import is_async_callable
|
||
|
from starlette.concurrency import run_in_threadpool
|
||
|
from starlette.exceptions import HTTPException
|
||
|
from starlette.requests import Request
|
||
|
from starlette.responses import PlainTextResponse, Response
|
||
|
from starlette.types import Message, Receive, Scope, Send
|
||
|
from starlette.websockets import WebSocket
|
||
|
|
||
|
|
||
|
class HTTPEndpoint:
|
||
|
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||
|
assert scope["type"] == "http"
|
||
|
self.scope = scope
|
||
|
self.receive = receive
|
||
|
self.send = send
|
||
|
self._allowed_methods = [
|
||
|
method
|
||
|
for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
|
||
|
if getattr(self, method.lower(), None) is not None
|
||
|
]
|
||
|
|
||
|
def __await__(self) -> typing.Generator[typing.Any, None, None]:
|
||
|
return self.dispatch().__await__()
|
||
|
|
||
|
async def dispatch(self) -> None:
|
||
|
request = Request(self.scope, receive=self.receive)
|
||
|
handler_name = (
|
||
|
"get"
|
||
|
if request.method == "HEAD" and not hasattr(self, "head")
|
||
|
else request.method.lower()
|
||
|
)
|
||
|
|
||
|
handler: typing.Callable[[Request], typing.Any] = getattr(
|
||
|
self, handler_name, self.method_not_allowed
|
||
|
)
|
||
|
is_async = is_async_callable(handler)
|
||
|
if is_async:
|
||
|
response = await handler(request)
|
||
|
else:
|
||
|
response = await run_in_threadpool(handler, request)
|
||
|
await response(self.scope, self.receive, self.send)
|
||
|
|
||
|
async def method_not_allowed(self, request: Request) -> Response:
|
||
|
# If we're running inside a starlette application then raise an
|
||
|
# exception, so that the configurable exception handler can deal with
|
||
|
# returning the response. For plain ASGI apps, just return the response.
|
||
|
headers = {"Allow": ", ".join(self._allowed_methods)}
|
||
|
if "app" in self.scope:
|
||
|
raise HTTPException(status_code=405, headers=headers)
|
||
|
return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
|
||
|
|
||
|
|
||
|
class WebSocketEndpoint:
|
||
|
encoding: str | None = None # May be "text", "bytes", or "json".
|
||
|
|
||
|
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||
|
assert scope["type"] == "websocket"
|
||
|
self.scope = scope
|
||
|
self.receive = receive
|
||
|
self.send = send
|
||
|
|
||
|
def __await__(self) -> typing.Generator[typing.Any, None, None]:
|
||
|
return self.dispatch().__await__()
|
||
|
|
||
|
async def dispatch(self) -> None:
|
||
|
websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
|
||
|
await self.on_connect(websocket)
|
||
|
|
||
|
close_code = status.WS_1000_NORMAL_CLOSURE
|
||
|
|
||
|
try:
|
||
|
while True:
|
||
|
message = await websocket.receive()
|
||
|
if message["type"] == "websocket.receive":
|
||
|
data = await self.decode(websocket, message)
|
||
|
await self.on_receive(websocket, data)
|
||
|
elif message["type"] == "websocket.disconnect":
|
||
|
close_code = int(
|
||
|
message.get("code") or status.WS_1000_NORMAL_CLOSURE
|
||
|
)
|
||
|
break
|
||
|
except Exception as exc:
|
||
|
close_code = status.WS_1011_INTERNAL_ERROR
|
||
|
raise exc
|
||
|
finally:
|
||
|
await self.on_disconnect(websocket, close_code)
|
||
|
|
||
|
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||
|
if self.encoding == "text":
|
||
|
if "text" not in message:
|
||
|
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||
|
raise RuntimeError("Expected text websocket messages, but got bytes")
|
||
|
return message["text"]
|
||
|
|
||
|
elif self.encoding == "bytes":
|
||
|
if "bytes" not in message:
|
||
|
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||
|
raise RuntimeError("Expected bytes websocket messages, but got text")
|
||
|
return message["bytes"]
|
||
|
|
||
|
elif self.encoding == "json":
|
||
|
if message.get("text") is not None:
|
||
|
text = message["text"]
|
||
|
else:
|
||
|
text = message["bytes"].decode("utf-8")
|
||
|
|
||
|
try:
|
||
|
return json.loads(text)
|
||
|
except json.decoder.JSONDecodeError:
|
||
|
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||
|
raise RuntimeError("Malformed JSON data received.")
|
||
|
|
||
|
assert (
|
||
|
self.encoding is None
|
||
|
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||
|
return message["text"] if message.get("text") else message["bytes"]
|
||
|
|
||
|
async def on_connect(self, websocket: WebSocket) -> None:
|
||
|
"""Override to handle an incoming websocket connection"""
|
||
|
await websocket.accept()
|
||
|
|
||
|
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||
|
"""Override to handle an incoming websocket message"""
|
||
|
|
||
|
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||
|
"""Override to handle a disconnecting websocket"""
|