from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
overload,
)
from litestar.connection import WebSocket
from litestar.exceptions import ImproperlyConfiguredException, WebSocketDisconnect
from litestar.types import (
AnyCallable,
Dependencies,
Empty,
EmptyType,
ExceptionHandler,
Guard,
Middleware,
TypeEncodersMap,
)
from litestar.utils import ensure_async_callable
from litestar.utils.signature import ParsedSignature, get_fn_type_hints
from ._utils import (
ListenerHandler,
create_handle_receive,
create_handle_send,
create_handler_signature,
create_stub_dependency,
)
from .route_handler import WebsocketRouteHandler
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Coroutine, Mapping, Sequence
from litestar import Litestar, Router
from litestar.dto import AbstractDTO
from litestar.routes import BaseRoute
from litestar.types.asgi_types import WebSocketMode
from litestar.types.composite_types import ParametersMap, TypeDecodersSequence
__all__ = ("WebsocketListener", "WebsocketListenerRouteHandler", "websocket_listener")
[docs]
class WebsocketListenerRouteHandler(WebsocketRouteHandler):
    """A websocket route handler that manages the connection around a callback.

    It automatically accepts a connection, handles disconnects, invokes a callback
    function every time new data is received and sends any data returned by that
    callback.
    """

    # ``__slots__`` is written as a mapping: each key is a slot name and each value
    # is that slot's docstring. Private slots are left undocumented (``None``).
    __slots__ = {  # noqa: RUF023
        "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept",
        "on_accept": "Callback invoked after a WebSocket connection has been accepted",
        "on_disconnect": "Callback invoked after a WebSocket connection has been closed",
        "_connection_lifespan": None,
        "_receive_handler": None,
        "_receive_mode": None,
        "_send_handler": None,
        "_send_mode": None,
    }
# Typing overload for the ``connection_lifespan`` configuration. The connection
# hooks (``connection_accept_handler``, ``on_accept``, ``on_disconnect``) are not
# part of this signature, because the implementation rejects combining them with
# ``connection_lifespan``.
@overload
def __init__(
    self,
    path: str | list[str] | None = None,
    *,
    fn: AnyCallable,
    connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: Sequence[Guard] | None = None,
    middleware: Sequence[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    parameters: ParametersMap | None = None,
    **kwargs: Any,
) -> None: ...
# Typing overload for the connection-hook configuration
# (``connection_accept_handler``, ``on_accept``, ``on_disconnect``).
# ``connection_lifespan`` is not part of this signature, because the
# implementation rejects combining it with any of the hooks.
@overload
def __init__(
    self,
    path: str | list[str] | None = None,
    *,
    fn: AnyCallable,
    connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: Sequence[Guard] | None = None,
    middleware: Sequence[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    on_accept: AnyCallable | None = None,
    on_disconnect: AnyCallable | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    parameters: ParametersMap | None = None,
    **kwargs: Any,
) -> None: ...
[docs]
def __init__(
    self,
    path: str | list[str] | None = None,
    *,
    fn: AnyCallable,
    connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept,
    connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: Sequence[Guard] | None = None,
    middleware: Sequence[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    on_accept: AnyCallable | None = None,
    on_disconnect: AnyCallable | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    parameters: ParametersMap | None = None,
    **kwargs: Any,
) -> None:
    """Initialize ``WebsocketListenerRouteHandler``.

    Args:
        path: A path fragment for the route handler function or a sequence of path fragments. If not given
            defaults to ``/``
        fn: The handler function
        connection_accept_handler: A callable that accepts a :class:`WebSocket <.connection.WebSocket>` instance
            and returns a coroutine that when awaited, will accept the connection. Defaults to
            ``WebSocket.accept``.
        connection_lifespan: An asynchronous context manager, handling the lifespan of the connection. By
            default, it calls the ``connection_accept_handler``, ``on_accept`` and ``on_disconnect``. Can request
            any dependencies, for example the :class:`WebSocket <.connection.WebSocket>` connection
        dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances.
        dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and
            validation of request data.
        exception_handlers: A mapping of status codes and/or exception types to handler functions.
        guards: A sequence of :class:`Guard <.types.Guard>` callables.
        middleware: A sequence of :class:`Middleware <.types.Middleware>`.
        receive_mode: Websocket mode to receive data in, either ``text`` or ``binary``.
        send_mode: Websocket mode to send data in, either ``text`` or ``binary``.
        name: A string identifying the route handler.
        on_accept: Callback invoked after a connection has been accepted. Can request any dependencies, for
            example the :class:`WebSocket <.connection.WebSocket>` connection
        on_disconnect: Callback invoked after a connection has been closed. Can request any dependencies, for
            example the :class:`WebSocket <.connection.WebSocket>` connection
        opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or
            wherever you have access to :class:`Request <.connection.Request>` or
            :class:`ASGI Scope <.types.Scope>`.
        return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing
            outbound response data.
        signature_namespace: A mapping of names to types for use in forward reference resolution during
            signature modelling.
        type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec
            hook for deserialization.
        type_encoders: A mapping of types to callables that transform them into types supported for
            serialization.
        websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route
            handler's default websocket class.
        parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions
        **kwargs: Any additional kwarg - will be set in the opt dictionary.

    Raises:
        ImproperlyConfiguredException: If ``connection_lifespan`` is combined with any of the connection hooks
            (``on_accept``, ``on_disconnect`` or a non-default ``connection_accept_handler``).
    """
    # A custom lifespan replaces the default one, which is the only place the hooks
    # are invoked, so configuring both is rejected up front.
    uses_connection_hooks = bool(
        on_accept or on_disconnect or connection_accept_handler is not WebSocket.accept
    )
    if connection_lifespan and uses_connection_hooks:
        raise ImproperlyConfiguredException(
            "connection_lifespan can not be used with connection hooks "
            "(on_accept, on_disconnect, connection_accept_handler)",
        )

    self._receive_mode: WebSocketMode = receive_mode
    self._send_mode: WebSocketMode = send_mode
    self._connection_lifespan = connection_lifespan

    # The send/receive handlers are created lazily by the ``resolve_*`` methods.
    self._send_handler: Callable[[WebSocket, Any], Coroutine[None, None, None]] | EmptyType = Empty
    self._receive_handler: Callable[[WebSocket], Any] | EmptyType = Empty

    self.connection_accept_handler = connection_accept_handler

    if on_accept:
        self.on_accept = ensure_async_callable(on_accept)
    else:
        self.on_accept = None

    if on_disconnect:
        self.on_disconnect = ensure_async_callable(on_disconnect)
    else:
        self.on_disconnect = None

    # Work on a copy so the caller's mapping is never mutated.
    handler_dependencies = dict(dependencies or {})

    lifespan = connection_lifespan or self.default_connection_lifespan
    handler_dependencies["connection_lifespan_dependencies"] = create_stub_dependency(lifespan)

    # Each configured hook gets its own stub dependency entry.
    for dependency_key, hook in (
        ("on_accept_dependencies", self.on_accept),
        ("on_disconnect_dependencies", self.on_disconnect),
    ):
        if hook:
            handler_dependencies[dependency_key] = create_stub_dependency(hook)

    super().__init__(
        fn=fn,
        path=path,
        dependencies=handler_dependencies,
        exception_handlers=exception_handlers,
        guards=guards,
        middleware=middleware,
        name=name,
        opt=opt,
        signature_namespace=signature_namespace,
        dto=dto,
        return_dto=return_dto,
        type_decoders=type_decoders,
        type_encoders=type_encoders,
        websocket_class=websocket_class,
        parameters=parameters,
        **kwargs,
    )
def _get_merge_opts(self, others: tuple[Router, ...]) -> dict[str, Any]:
    """Extend the inherited merge options with the listener-specific settings.

    Args:
        others: Forwarded unchanged to the parent implementation.

    Returns:
        The mapping returned by the parent implementation, updated with this
        listener's receive/send modes, connection lifespan and connection hooks.
    """
    combined = super()._get_merge_opts(others)
    combined.update(
        {
            "receive_mode": self._receive_mode,
            "send_mode": self._send_mode,
            "connection_lifespan": self._connection_lifespan,
            "connection_accept_handler": self.connection_accept_handler,
            "on_accept": self.on_accept,
            "on_disconnect": self.on_disconnect,
        }
    )
    return combined
[docs]
def on_registration(self, route: BaseRoute, app: Litestar) -> None:
    """Replace the handler function with its listener wrapper, then register.

    Args:
        route: Forwarded to the parent implementation.
        app: Forwarded to the parent implementation.
    """
    # The wrapper must be in place before the parent implementation runs.
    listener_fn = self._prepare_fn()
    self.fn = listener_fn
    super().on_registration(route, app)
def _prepare_fn(self) -> ListenerHandler:
    """Validate the handler function's signature and wrap it in a ``ListenerHandler``.

    Returns:
        A ``ListenerHandler`` built from this listener, the original function, its
        parsed signature and the signature namespace.

    Raises:
        ImproperlyConfiguredException: If the function has no ``data`` parameter, or
            if it declares a ``request`` or ``body`` parameter.
    """
    namespace = self.signature_namespace
    parsed_signature = ParsedSignature.from_fn(self.fn, namespace)
    fn_parameters = parsed_signature.parameters

    if "data" not in fn_parameters:
        raise ImproperlyConfiguredException("Websocket listeners must accept a 'data' parameter")

    param = next((candidate for candidate in ("request", "body") if candidate in fn_parameters), None)
    if param is not None:
        raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket listeners")

    # The signature of the decorated function is manipulated below, so the original
    # values are stored first for use elsewhere.
    self._parsed_return_field = parsed_signature.return_type
    self._parsed_data_field = fn_parameters.get("data")

    handler_signature = create_handler_signature(parsed_signature.original_signature)
    fn_hints = get_fn_type_hints(self.fn, namespace=namespace)
    call_hints = get_fn_type_hints(ListenerHandler.__call__, namespace=namespace)
    # Hints from ``ListenerHandler.__call__`` take precedence on name clashes.
    self._parsed_fn_signature = ParsedSignature.from_signature(
        handler_signature,
        fn_type_hints={**fn_hints, **call_hints},
    )

    return ListenerHandler(
        listener=self,
        fn=self.fn,
        parsed_signature=parsed_signature,
        namespace=namespace,
    )
def _validate_handler_function(self) -> None:
    """Validate the route handler function once it's set by inspecting its return annotations.

    This implementation intentionally performs no checks and returns ``None``.
    """
    # validation occurs in the call method
[docs]
@asynccontextmanager
async def default_connection_lifespan(
    self,
    socket: WebSocket,
    on_accept_dependencies: Optional[dict[str, Any]] = None,  # noqa: UP045
    on_disconnect_dependencies: Optional[dict[str, Any]] = None,  # noqa: UP045
) -> AsyncGenerator[None, None]:
    """Handle the connection lifespan of a :class:`WebSocket <.connection.WebSocket>`.

    By default, this will

    - Call :attr:`connection_accept_handler` to accept a connection
    - Call :attr:`on_accept` if defined after a connection has been accepted
    - Call :attr:`on_disconnect` upon leaving the context

    Args:
        socket: The :class:`WebSocket <.connection.WebSocket>` connection
        on_accept_dependencies: Dependencies requested by the :attr:`on_accept` hook
        on_disconnect_dependencies: Dependencies requested by the :attr:`on_disconnect` hook
    """
    await self.connection_accept_handler(socket)

    if self.on_accept:
        # a missing dependency mapping is treated as "no keyword arguments"
        await self.on_accept(**(on_accept_dependencies or {}))

    try:
        yield
    except WebSocketDisconnect:
        # a disconnect raised inside the managed block is swallowed here instead of
        # being propagated to the caller
        pass
    finally:
        # runs however the managed block was left, including on other exceptions
        if self.on_disconnect:
            await self.on_disconnect(**(on_disconnect_dependencies or {}))
def resolve_receive_handler(self) -> Callable[[WebSocket], Any]:
    """Return the receive handler, creating and caching it on first use.

    Returns:
        The callable produced by ``create_handle_receive`` for this listener.
    """
    handler = self._receive_handler
    if handler is Empty:
        # Not built yet: create it once and keep it for subsequent calls.
        handler = self._receive_handler = create_handle_receive(self)
    return handler
def resolve_send_handler(self) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]:
    """Return the send handler, creating and caching it on first use.

    Returns:
        The callable produced by ``create_handle_send`` for this listener.
    """
    handler = self._send_handler
    if handler is Empty:
        # Not built yet: create it once and keep it for subsequent calls.
        handler = self._send_handler = create_handle_send(self)
    return handler
[docs]
class WebsocketListener(ABC):
    """Class-based definition of a websocket listener.

    Subclasses must implement :meth:`on_receive` and may override :meth:`on_accept`
    and :meth:`on_disconnect`. The class attributes below hold the configuration
    that :meth:`to_handler` passes to :class:`WebsocketListenerRouteHandler`.
    """

    path: str | list[str] | None = None
    """A path fragment for the route handler function or a sequence of path fragments. If not given defaults to ``/``"""
    dependencies: Dependencies | None = None
    """A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances."""
    dto: type[AbstractDTO] | None | EmptyType = Empty
    """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and validation of request data"""
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None
    """A mapping of status codes and/or exception types to handler functions."""
    guards: list[Guard] | None = None
    """A sequence of :class:`Guard <.types.Guard>` callables."""
    middleware: list[Middleware] | None = None
    """A sequence of :class:`Middleware <.types.Middleware>`."""
    receive_mode: WebSocketMode = "text"
    """:class:`WebSocket <.connection.WebSocket>` mode to receive data in, either ``text`` or ``binary``."""
    send_mode: WebSocketMode = "text"
    """Websocket mode to send data in, either `text` or `binary`."""
    name: str | None = None
    """A string identifying the route handler."""
    opt: dict[str, Any] | None = None
    """
    A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you
    have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`.
    """
    return_dto: type[AbstractDTO] | None | EmptyType = Empty
    """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data."""
    signature_namespace: Mapping[str, Any] | None = None
    """
    A mapping of names to types for use in forward reference resolution during signature modelling.
    """
    type_decoders: TypeDecodersSequence | None = None
    """
    type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec
    hook for deserialization.
    """
    type_encoders: TypeEncodersMap | None = None
    """
    type_encoders: A mapping of types to callables that transform them into types supported for serialization.
    """
    websocket_class: type[WebSocket] | None = None
    """
    websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
    default websocket class.
    """
def to_handler(self) -> WebsocketListenerRouteHandler:
    """Build the route handler described by this listener's attributes.

    ``on_receive`` becomes the handler function. ``on_accept`` and
    ``on_disconnect`` are only forwarded when they differ from the no-op
    defaults defined on :class:`WebsocketListener`; otherwise ``None`` is passed
    so that no hook is registered for them.

    Returns:
        A :class:`WebsocketListenerRouteHandler` configured from this instance.
    """
    # ``self.on_accept`` is a bound method while ``WebsocketListener.on_accept`` is a
    # plain function, and the two never compare equal. The function wrapped by the
    # bound method (``__func__``) is therefore compared by identity instead. Hooks
    # that are not bound methods (e.g. a callable assigned on the instance) have no
    # ``__func__`` and are always forwarded.
    on_accept = self.on_accept
    if getattr(on_accept, "__func__", None) is WebsocketListener.on_accept:
        on_accept = None

    on_disconnect = self.on_disconnect
    if getattr(on_disconnect, "__func__", None) is WebsocketListener.on_disconnect:
        on_disconnect = None

    return WebsocketListenerRouteHandler(
        dependencies=self.dependencies,
        dto=self.dto,
        exception_handlers=self.exception_handlers,
        guards=self.guards,
        middleware=self.middleware,
        send_mode=self.send_mode,
        receive_mode=self.receive_mode,
        name=self.name,
        on_accept=on_accept,
        on_disconnect=on_disconnect,
        opt=self.opt,
        path=self.path,
        return_dto=self.return_dto,
        signature_namespace=self.signature_namespace,
        type_decoders=self.type_decoders,
        type_encoders=self.type_encoders,
        websocket_class=self.websocket_class,
        fn=self.on_receive,
    )
[docs]
def on_accept(self, *args: Any, **kwargs: Any) -> Any:
    """Called after a :class:`WebSocket <.connection.WebSocket>` connection
    has been accepted. Can receive any dependencies.

    The default implementation does nothing; override it in a subclass to react
    to accepted connections.
    """
[docs]
def on_disconnect(self, *args: Any, **kwargs: Any) -> Any:
    """Called after a :class:`WebSocket <.connection.WebSocket>` connection
    has been disconnected. Can receive any dependencies.

    The default implementation does nothing; override it in a subclass to react
    to closed connections.
    """
[docs]
@abstractmethod
def on_receive(self, *args: Any, **kwargs: Any) -> Any:
    """Called after data has been received from the WebSocket.

    This should take a ``data`` argument, receiving the processed WebSocket data,
    and can additionally include handler dependencies such as ``state``, or other
    regular dependencies.

    Data returned from this function will be serialized and sent via the socket
    according to handler configuration.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError
# Typing overload for the ``connection_lifespan`` configuration. The connection
# hooks (``connection_accept_handler``, ``on_accept``, ``on_disconnect``) are not
# part of this signature, because ``WebsocketListenerRouteHandler`` rejects
# combining them with ``connection_lifespan``.
@overload
def websocket_listener(
    path: str | list[str] | None = None,
    *,
    connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: list[Guard] | None = None,
    middleware: list[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    **kwargs: Any,
) -> Callable[[AnyCallable], WebsocketListenerRouteHandler]: ...
# Typing overload for the connection-hook configuration
# (``connection_accept_handler``, ``on_accept``, ``on_disconnect``).
# ``connection_lifespan`` is not part of this signature, because
# ``WebsocketListenerRouteHandler`` rejects combining it with any of the hooks.
@overload
def websocket_listener(
    path: str | list[str] | None = None,
    *,
    connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: list[Guard] | None = None,
    middleware: list[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    on_accept: AnyCallable | None = None,
    on_disconnect: AnyCallable | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    **kwargs: Any,
) -> Callable[[AnyCallable], WebsocketListenerRouteHandler]: ...
[docs]
def websocket_listener(
    path: str | list[str] | None = None,
    *,
    connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept,
    connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None,
    dependencies: Dependencies | None = None,
    dto: type[AbstractDTO] | None | EmptyType = Empty,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: list[Guard] | None = None,
    middleware: list[Middleware] | None = None,
    receive_mode: WebSocketMode = "text",
    send_mode: WebSocketMode = "text",
    name: str | None = None,
    on_accept: AnyCallable | None = None,
    on_disconnect: AnyCallable | None = None,
    opt: dict[str, Any] | None = None,
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    signature_namespace: Mapping[str, Any] | None = None,
    type_decoders: TypeDecodersSequence | None = None,
    type_encoders: TypeEncodersMap | None = None,
    websocket_class: type[WebSocket] | None = None,
    **kwargs: Any,
) -> Callable[[AnyCallable], WebsocketListenerRouteHandler]:
    """Create a :class:`WebsocketListenerRouteHandler`.

    Args:
        path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults
            to ``/``
        connection_accept_handler: A callable that accepts a :class:`WebSocket <.connection.WebSocket>` instance
            and returns a coroutine that when awaited, will accept the connection. Defaults to ``WebSocket.accept``.
        connection_lifespan: An asynchronous context manager, handling the lifespan of the connection. By default,
            it calls the ``connection_accept_handler``, ``on_accept`` and ``on_disconnect``. Can request any
            dependencies, for example the :class:`WebSocket <.connection.WebSocket>` connection
        dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances.
        dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and
            validation of request data.
        exception_handlers: A mapping of status codes and/or exception types to handler functions.
        guards: A sequence of :class:`Guard <.types.Guard>` callables.
        middleware: A sequence of :class:`Middleware <.types.Middleware>`.
        receive_mode: Websocket mode to receive data in, either ``text`` or ``binary``.
        send_mode: Websocket mode to send data in, either ``text`` or ``binary``.
        name: A string identifying the route handler.
        on_accept: Callback invoked after a connection has been accepted. Can request any dependencies, for example
            the :class:`WebSocket <.connection.WebSocket>` connection
        on_disconnect: Callback invoked after a connection has been closed. Can request any dependencies, for
            example the :class:`WebSocket <.connection.WebSocket>` connection
        opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or
            wherever you have access to :class:`Request <.connection.Request>` or
            :class:`ASGI Scope <.types.Scope>`.
        return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing
            outbound response data.
        signature_namespace: A mapping of names to types for use in forward reference resolution during signature
            modelling.
        type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec
            hook for deserialization.
        type_encoders: A mapping of types to callables that transform them into types supported for serialization.
        websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
            default websocket class.
        **kwargs: Any additional kwarg - will be set in the opt dictionary.

    Returns:
        A decorator that wraps the decorated function in a :class:`WebsocketListenerRouteHandler`.
    """

    def decorator(fn: AnyCallable) -> WebsocketListenerRouteHandler:
        return WebsocketListenerRouteHandler(
            fn=fn,
            path=path,
            connection_accept_handler=connection_accept_handler,
            connection_lifespan=connection_lifespan,
            dependencies=dependencies,
            dto=dto,
            exception_handlers=exception_handlers,
            # The handler's parameter is ``guards``. Any other keyword would be
            # collected by its ``**kwargs`` and stored in ``opt``, so the guards
            # would never be applied.
            guards=guards,
            middleware=middleware,
            receive_mode=receive_mode,
            send_mode=send_mode,
            name=name,
            on_accept=on_accept,
            on_disconnect=on_disconnect,
            opt=opt,
            return_dto=return_dto,
            signature_namespace=signature_namespace,
            type_decoders=type_decoders,
            type_encoders=type_encoders,
            websocket_class=websocket_class,
            **kwargs,
        )

    return decorator