from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
from litestar.connection.base import (
ASGIConnection,
AuthT,
StateT,
UserT,
empty_receive,
empty_send,
)
from litestar.datastructures.headers import Headers
from litestar.exceptions import WebSocketDisconnect
from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack
from litestar.status_codes import WS_1000_NORMAL_CLOSURE
__all__ = ("WebSocket",)

# These names are only needed for static type checking; ``from __future__ import annotations``
# keeps every annotation lazy, so they are never imported at runtime.
if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

    from litestar.handlers.websocket_handlers import WebsocketRouteHandler  # noqa: F401
    from litestar.types import Message, Serializer, WebSocketScope
    from litestar.types.asgi_types import (
        Receive,
        ReceiveMessage,
        Scope,
        Send,
        WebSocketAcceptEvent,
        WebSocketCloseEvent,
        WebSocketDisconnectEvent,
        WebSocketMode,
        WebSocketReceiveEvent,
        WebSocketSendEvent,
    )

# Detail attached to ``WebSocketDisconnect`` when a receive or send is attempted after the
# connection has already been marked as disconnected.
DISCONNECT_MESSAGE = "connection is disconnected"

class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]):
    """The Litestar WebSocket class.

    The ASGI ``receive`` and ``send`` callables are wrapped so that every received event
    updates :attr:`connection_state`. Once the connection is marked as disconnected,
    further receives or sends raise :class:`WebSocketDisconnect`.
    """

    __slots__ = ("connection_state",)

    scope: WebSocketScope  # pyright: ignore
    """The ASGI scope attached to the connection."""
    receive: Receive
    """The ASGI receive function."""
    send: Send
    """The ASGI send function."""

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
        """Initialize ``WebSocket``.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.
        """
        # The wrappers only read ``connection_state`` when they are awaited, so it is safe
        # to assign it after they have been handed to the base class.
        super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send))
        self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init"

    def receive_wrapper(self, receive: Receive) -> Receive:
        """Wrap ``receive`` to set ``self.connection_state`` and validate events.

        Args:
            receive: The ASGI receive function.

        Returns:
            An ASGI receive function that raises ``WebSocketDisconnect`` if called after
            the connection has been marked as disconnected.
        """

        async def wrapped_receive() -> ReceiveMessage:
            if self.connection_state == "disconnect":
                raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
            message = await receive()
            if message["type"] == "websocket.connect":
                self.connection_state = "connect"
            elif message["type"] == "websocket.receive":
                self.connection_state = "receive"
            else:
                # Any other event type (e.g. ``websocket.disconnect``) ends the connection.
                self.connection_state = "disconnect"
            return message

        return wrapped_receive

    def send_wrapper(self, send: Send) -> Send:
        """Wrap ``send`` to ensure that state is not disconnected.

        Args:
            send: The ASGI send function.

        Returns:
            An ASGI send function that raises ``WebSocketDisconnect`` if called after the
            connection has been marked as disconnected.
        """

        async def wrapped_send(message: Message) -> None:
            if self.connection_state == "disconnect":
                raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)  # pragma: no cover
            await send(message)

        return wrapped_send

    async def accept(
        self,
        subprotocols: str | None = None,
        headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """Accept the incoming connection. This method should be called before receiving data.

        If no event has been received yet, the pending ``websocket.connect`` event is
        consumed first.

        Args:
            subprotocols: Websocket sub-protocol to use.
            headers: Headers to set on the data sent. Accepts a :class:`Headers` instance,
                a ``dict``, or an already-encoded list of ``(name, value)`` byte pairs.

        Returns:
            None
        """
        if self.connection_state == "init":
            await self.receive()

        _headers: list[tuple[bytes, bytes]]
        # ``Headers`` is checked first so it always wins, matching the original
        # "last matching isinstance check wins" ordering.
        if isinstance(headers, Headers):
            _headers = headers.to_header_list()
        elif isinstance(headers, dict):
            _headers = Headers(headers=headers).to_header_list()
        elif isinstance(headers, list):
            _headers = headers
        else:
            _headers = []

        event: WebSocketAcceptEvent = {
            "type": "websocket.accept",
            "subprotocol": subprotocols,
            "headers": _headers,
        }
        await self.send(event)

    async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
        """Send a 'websocket.close' event.

        Args:
            code: Status code.
            reason: Reason for closing the connection. ``None`` is sent as an empty string.

        Returns:
            None
        """
        event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""}
        await self.send(event)

    @overload
    async def receive_data(self, mode: Literal["text"]) -> str: ...

    @overload
    async def receive_data(self, mode: Literal["binary"]) -> bytes: ...

    async def receive_data(self, mode: WebSocketMode) -> str | bytes:
        """Receive a 'websocket.receive' event and return the data stored on it.

        The connection is accepted automatically if it is still in the ``init`` state.

        Args:
            mode: The respective event key to use.

        Returns:
            The event's data. An empty ``str``/``bytes`` is returned when the event has no
            value under the requested key.

        Raises:
            WebSocketDisconnect: If a ``websocket.disconnect`` event is received.
        """
        if self.connection_state == "init":
            await self.accept()
        event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive())
        if event["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
        if mode == "text":
            return event.get("text") or ""
        return event.get("bytes") or b""

    @overload
    def iter_data(self, mode: Literal["text"] = "text") -> AsyncGenerator[str, None]: ...

    @overload
    def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ...

    async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]:
        """Continuously receive data and yield it.

        Iteration ends silently when ``receive_data`` raises ``WebSocketDisconnect``.

        Args:
            mode: Socket mode to use. Either ``text`` or ``binary``
        """
        try:
            while True:
                yield await self.receive_data(mode)
        except WebSocketDisconnect:
            pass

    async def receive_text(self) -> str:
        """Receive data as text.

        Returns:
            A string.
        """
        return await self.receive_data(mode="text")

    async def receive_bytes(self) -> bytes:
        """Receive data as bytes.

        Returns:
            A byte-string.
        """
        return await self.receive_data(mode="binary")

    async def receive_json(self, mode: WebSocketMode = "text") -> Any:
        """Receive data and decode it as JSON.

        Args:
            mode: Either ``text`` or ``binary``.

        Returns:
            An arbitrary value
        """
        data = await self.receive_data(mode=mode)
        return decode_json(value=data, type_decoders=self.route_handler.type_decoders)

    async def receive_msgpack(self) -> Any:
        """Receive data and decode it as MessagePack.

        Note that since MessagePack is a binary format, this method will always receive
        data in ``binary`` mode.

        Returns:
            An arbitrary value
        """
        data = await self.receive_data(mode="binary")
        return decode_msgpack(value=data, type_decoders=self.route_handler.type_decoders)

    async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]:
        """Continuously receive data and yield it, decoding it as JSON in the process.

        Args:
            mode: Socket mode to use. Either ``text`` or ``binary``
        """
        async for data in self.iter_data(mode):
            yield decode_json(value=data, type_decoders=self.route_handler.type_decoders)

    async def iter_msgpack(self) -> AsyncGenerator[Any, None]:
        """Continuously receive data and yield it, decoding it as MessagePack in the
        process.

        Note that since MessagePack is a binary format, this method will always receive
        data in ``binary`` mode.
        """
        async for data in self.iter_data(mode="binary"):
            yield decode_msgpack(value=data, type_decoders=self.route_handler.type_decoders)

    async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None:
        """Send a 'websocket.send' event.

        The connection is accepted automatically if it is still in the ``init`` state.

        Args:
            data: Data to send.
            mode: The respective event key to use.
            encoding: Encoding to use when converting bytes / str.

        Returns:
            None
        """
        if self.connection_state == "init":  # pragma: no cover
            await self.accept()
        event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None}
        # Exactly one of ``bytes``/``text`` is populated; the data is converted if its type
        # does not match the requested mode.
        if mode == "binary":
            event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding)
        else:
            event["text"] = data if isinstance(data, str) else data.decode(encoding)
        await self.send(event)

    @overload
    async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ...

    @overload
    async def send_text(self, data: str) -> None: ...

    async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None:
        """Send data using the ``text`` key of the send event.

        Args:
            data: Data to send
            encoding: Encoding to use for binary data.

        Returns:
            None
        """
        await self.send_data(data=data, encoding=encoding)

    @overload
    async def send_bytes(self, data: bytes) -> None: ...

    @overload
    async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ...

    async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None:
        """Send data using the ``bytes`` key of the send event.

        Args:
            data: Data to send
            encoding: Encoding to use for binary data.

        Returns:
            None
        """
        await self.send_data(data=data, mode="binary", encoding=encoding)

    async def send_json(
        self,
        data: Any,
        mode: WebSocketMode = "text",
        encoding: str = "utf-8",
        serializer: Serializer = default_serializer,
    ) -> None:
        """Send data as JSON.

        Args:
            data: A value to serialize.
            mode: Either ``text`` or ``binary``.
            encoding: Encoding forwarded to ``send_data`` for str / bytes conversion.
            serializer: A serializer function.

        Returns:
            None
        """
        await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding)

    async def send_msgpack(
        self,
        data: Any,
        encoding: str = "utf-8",
        serializer: Serializer = default_serializer,
    ) -> None:
        """Send data as MessagePack.

        Note that since MessagePack is a binary format, this method will always send
        data in ``binary`` mode.

        Args:
            data: A value to serialize.
            encoding: Encoding forwarded to ``send_data`` for str / bytes conversion.
            serializer: A serializer function.

        Returns:
            None
        """
        await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding)