225 lines
8.6 KiB
Python
225 lines
8.6 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import enum
|
||
|
import json
|
||
|
import typing
|
||
|
|
||
|
from starlette.requests import HTTPConnection
|
||
|
from starlette.responses import Response
|
||
|
from starlette.types import Message, Receive, Scope, Send
|
||
|
|
||
|
|
||
|
class WebSocketState(enum.Enum):
    """
    Lifecycle states tracked separately for each side of a websocket.

    ``WebSocket`` keeps one value for the client side and one for the
    application side, and uses them to validate ASGI message ordering.
    """

    # Initial state of both sides, before connect/accept has happened.
    CONNECTING = 0
    # The handshake step for that side has completed.
    CONNECTED = 1
    # A disconnect was received, or a close / final response body was sent.
    DISCONNECTED = 2
    # The application began an HTTP response instead of accepting.
    RESPONSE = 3
|
||
|
|
||
|
|
||
|
class WebSocketDisconnect(Exception):
    """
    Signals that the websocket connection has gone away.

    Attributes:
        code: The websocket close code (defaults to 1000).
        reason: The close reason; always a string, empty when none was given.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        # Normalise a missing (or empty) reason to "" so callers can rely
        # on ``reason`` always being a ``str``.
        normalized_reason = reason if reason else ""
        self.code, self.reason = code, normalized_reason
|
||
|
|
||
|
|
||
|
class WebSocket(HTTPConnection):
    """
    Wrapper around an ASGI websocket connection.

    Two state machines are tracked independently:

    * ``client_state`` advances as messages are *received* from the client.
    * ``application_state`` advances as messages are *sent* by the application.

    ``receive`` and ``send`` raise ``RuntimeError`` for any ASGI message that
    is not valid in the current state. The remaining methods are convenience
    helpers built on top of those two.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Bind the connection to its ASGI ``scope``, ``receive`` and ``send``.

        The scope must have ``scope["type"] == "websocket"``.
        """
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        # Both sides start out waiting for the handshake.
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Returns:
            The raw ASGI message.

        Raises:
            RuntimeError: If the message type is not valid for the current
                client state, or if a disconnect message has already been
                received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connect event.
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not valid for the current
                application state, or if the connection was already closed
                by the application.
            WebSocketDisconnect: With code 1006 if the underlying send fails
                with ``OSError`` while the connection is established.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {
                "websocket.accept",
                "websocket.close",
                "websocket.http.response.start",
            }:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", '
                    '"websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # The transport failed underneath us: mark the connection as
                # gone and report it as a disconnect. 1006 is the WebSocket
                # "abnormal closure" code. The OSError is kept as the cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(
                    'Expected ASGI message "websocket.http.response.body", '
                    f"but got {message_type!r}"
                )
            # The final body chunk ends the response and the connection.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the websocket connection.

        Args:
            subprotocol: Value for the ``"subprotocol"`` key of the
                ``websocket.accept`` message.
            headers: Value for the ``"headers"`` key of the
                ``websocket.accept`` message; ``None`` is sent as an
                empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect event."""
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive one message and return its ``"text"`` payload.

        Raises:
            RuntimeError: If the application has not accepted the connection.
            WebSocketDisconnect: If the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its ``"bytes"`` payload.

        Raises:
            RuntimeError: If the application has not accepted the connection.
            WebSocketDisconnect: If the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive one message and decode its payload as JSON.

        Args:
            mode: ``"text"`` to read the ``"text"`` payload, or ``"binary"``
                to read the ``"bytes"`` payload and decode it as UTF-8.

        Raises:
            RuntimeError: If ``mode`` is invalid, or if the application has
                not accepted the connection.
            WebSocketDisconnect: If the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield JSON-decoded text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as the ``"text"`` payload of a ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as the ``"bytes"`` payload of a ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialise ``data`` to compact JSON and send it.

        Args:
            data: Any object accepted by ``json.dumps``.
            mode: ``"text"`` to send the JSON as a text payload, or
                ``"binary"`` to send it UTF-8 encoded as a bytes payload.

        Raises:
            RuntimeError: If ``mode`` is neither ``"text"`` nor ``"binary"``.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.close`` message; a missing reason is sent as ``""``."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending ``response`` over this connection.

        Raises:
            RuntimeError: If the scope's ``"extensions"`` mapping does not
                contain ``"websocket.http.response"``.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            # The response is invoked as an ASGI callable using this
            # connection's state-checking receive/send wrappers.
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError(
                "The server doesn't support the Websocket Denial Response extension."
            )
|
||
|
|
||
|
|
||
|
class WebSocketClose:
    """
    Minimal ASGI callable that closes a websocket connection.

    When called, it sends a single ``websocket.close`` message carrying the
    configured code and reason.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        """Store the close code and reason; a missing reason becomes ``""``."""
        self.code = code
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the close message; ``scope`` and ``receive`` are not used."""
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)
|