Source code for litestar.handlers.websocket_handlers.stream

from __future__ import annotations

import dataclasses
import functools
import warnings
from collections.abc import AsyncGenerator, Awaitable, Mapping
from typing import TYPE_CHECKING, Any, Callable, cast

import anyio
from msgspec.json import Encoder as JsonEncoder

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning, WebSocketDisconnect
from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler
from litestar.types import Empty
from litestar.types.builtin_types import NoneType
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature

if TYPE_CHECKING:
    from litestar import Litestar, WebSocket
    from litestar.dto import AbstractDTO
    from litestar.routes import BaseRoute
    from litestar.types import Dependencies, EmptyType, ExceptionHandler, Guard, Middleware, TypeEncodersMap
    from litestar.types.asgi_types import WebSocketMode


async def send_websocket_stream(
    socket: WebSocket,
    stream: AsyncGenerator[Any, Any],
    *,
    close: bool = True,
    mode: WebSocketMode = "text",
    send_handler: Callable[[WebSocket, Any], Awaitable[Any]] | None = None,
    listen_for_disconnect: bool = False,
    warn_on_data_discard: bool = True,
) -> None:
    """Stream data to the ``socket`` from an asynchronous generator.

    Example:
        Sending the current time to the connected client every 0.5 seconds:

        .. code-block:: python

            async def stream_current_time() -> AsyncGenerator[str, None]:
                while True:
                    yield str(time.time())
                    await asyncio.sleep(0.5)


            @websocket("/time")
            async def time_handler(socket: WebSocket) -> None:
                await socket.accept()
                await send_websocket_stream(
                    socket,
                    stream_current_time(),
                    listen_for_disconnect=True,
                )

    Args:
        socket: The :class:`~litestar.connection.WebSocket` to send to
        stream: An asynchronous generator yielding data to send. It is closed
            (``aclose()``) before this function returns, whether it was exhausted,
            the client disconnected, or an error occurred
        close: If ``True``, close the socket after the generator is exhausted
        mode: WebSocket mode to use for sending when no ``send_handler`` is specified
        send_handler: Callable to handle the send process. If ``None``, defaults to
            ``type(socket).send_data``
        listen_for_disconnect: If ``True``, listen for client disconnects in the
            background. If a client disconnects, stop the generator and cancel
            sending data. Should always be ``True`` unless disconnects are handled
            elsewhere, for example by reading data from the socket concurrently.
            Should never be set to ``True`` when reading data from socket
            concurrently, as it can lead to data loss
        warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if
            during listening for client disconnects, data is received from the socket
    """
    if send_handler is None:
        send_handler = functools.partial(type(socket).send_data, mode=mode)

    async def send_stream() -> None:
        try:
            # client might have disconnected elsewhere, so we stop sending
            while socket.connection_state != "disconnect":
                await send_handler(socket, await stream.__anext__())
        except StopAsyncIteration:
            pass

    try:
        if listen_for_disconnect:
            # wrap 'send_stream' and disconnect listener, so they'll cancel the other
            # once one of them finishes
            async def wrapped_stream() -> None:
                await send_stream()
                # stream exhausted, we can stop listening for a disconnect
                tg.cancel_scope.cancel()

            async def disconnect_listener() -> None:
                try:
                    # run this in a loop - we might receive other data than
                    # disconnects. listen_for_disconnect is explicitly not safe when
                    # consuming WS data in other places, so discarding that data
                    # here is fine
                    while True:
                        await socket.receive_data("text")
                        if warn_on_data_discard:
                            warnings.warn(
                                "received data from websocket while listening for client "
                                "disconnect in a websocket_stream. listen_for_disconnect "
                                "is not safe to use when attempting to receive data from "
                                "the same socket concurrently with a websocket_stream. set "
                                "listen_for_disconnect=False if you're attempting to "
                                "receive data from this socket or set "
                                "warn_on_data_discard=False to disable this warning",
                                stacklevel=2,
                                category=LitestarWarning,
                            )
                except WebSocketDisconnect:
                    # client disconnected, we can stop streaming
                    tg.cancel_scope.cancel()

            async with anyio.create_task_group() as tg:
                tg.start_soon(wrapped_stream)
                tg.start_soon(disconnect_listener)
        else:
            await send_stream()
    finally:
        # Streaming can stop while the generator is still suspended at a 'yield'
        # (socket already disconnected, or the send task was cancelled while
        # sending). Close it explicitly so its cleanup code runs now instead of
        # whenever it gets garbage collected. This runs after the task group has
        # exited, i.e. outside its cancelled scope. 'getattr' keeps plain async
        # iterators without 'aclose' working.
        close_stream = getattr(stream, "aclose", None)
        if close_stream is not None:
            await close_stream()

    if close and socket.connection_state != "disconnect":
        await socket.close()
def websocket_stream(
    path: str | list[str] | None = None,
    *,
    dependencies: Dependencies | None = None,
    exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None,
    guards: list[Guard] | None = None,
    middleware: list[Middleware] | None = None,
    name: str | None = None,
    opt: dict[str, Any] | None = None,
    signature_namespace: Mapping[str, Any] | None = None,
    websocket_class: type[WebSocket] | None = None,
    mode: WebSocketMode = "text",
    return_dto: type[AbstractDTO] | None | EmptyType = Empty,
    type_encoders: TypeEncodersMap | None = None,
    listen_for_disconnect: bool = True,
    warn_on_data_discard: bool = True,
    **kwargs: Any,
) -> Callable[[Callable[..., AsyncGenerator[Any, Any]]], WebsocketRouteHandler]:
    """Create a WebSocket handler that accepts a connection and sends data to it from an
    async generator.

    Example:
        Sending the current time to the connected client every 0.5 seconds:

        .. code-block:: python

            @websocket_stream("/time")
            async def send_time() -> AsyncGenerator[str, None]:
                while True:
                    yield str(time.time())
                    await asyncio.sleep(0.5)

    Args:
        path: A path fragment for the route handler function or a sequence of path
            fragments. If not given defaults to ``/``
        dependencies: A string keyed mapping of dependency
            :class:`Provider <.di.Provide>` instances.
        exception_handlers: A mapping of status codes and/or exception types to
            handler functions.
        guards: A sequence of :class:`Guard <.types.Guard>` callables.
        middleware: A sequence of :class:`Middleware <.types.Middleware>`.
        name: A string identifying the route handler.
        opt: A string keyed mapping of arbitrary values that can be accessed in
            :class:`Guards <.types.Guard>` or wherever you have access to
            :class:`Request <.connection.Request>` or
            :class:`ASGI Scope <.types.Scope>`.
        signature_namespace: A mapping of names to types for use in forward reference
            resolution during signature modelling.
        websocket_class: A custom subclass of
            :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
            default websocket class.
        mode: WebSocket mode used for sending
        return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for
            serializing outbound response data.
        type_encoders: A mapping of types to callables that transform them into types
            supported for serialization.
        listen_for_disconnect: If ``True``, listen for client disconnects in the
            background. If a client disconnects, stop the generator and cancel
            sending data. Should always be ``True`` unless disconnects are handled
            elsewhere, for example by reading data from the socket concurrently.
            Should never be set to ``True`` when reading data from socket
            concurrently, as it can lead to data loss
        warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if
            during listening for client disconnects, data is received from the socket
        **kwargs: Any additional kwarg - will be set in the opt dictionary.

    Returns:
        A decorator that wraps an async generator function in a
        ``WebSocketStreamHandler``.
    """

    def decorator(fn: Callable[..., AsyncGenerator[Any, Any]]) -> WebsocketRouteHandler:
        return WebSocketStreamHandler(
            fn=fn,  # type: ignore[arg-type]
            path=path,
            dependencies=dependencies,
            exception_handlers=exception_handlers,
            # must be 'guards' (plural): an unrecognised keyword such as 'guard'
            # would be treated as an extra kwarg and stored in 'opt', so the guards
            # would silently never run
            guards=guards,
            middleware=middleware,
            name=name,
            opt=opt,
            signature_namespace=signature_namespace,
            websocket_class=websocket_class,
            return_dto=return_dto,
            type_encoders=type_encoders,
            # not a handler argument: travels via 'opt' and is read back by
            # WebSocketStreamHandler.on_registration as opt["stream_options"]
            stream_options=_WebSocketStreamOptions(
                generator_fn=fn,
                send_mode=mode,
                listen_for_disconnect=listen_for_disconnect,
                warn_on_data_discard=warn_on_data_discard,
            ),
            **kwargs,
        )

    return decorator
class WebSocketStreamHandler(WebsocketRouteHandler):
    """WebSocket route handler that streams the values yielded by an async generator.

    The stream configuration is read from ``self.opt["stream_options"]`` during
    registration; at that point the generator function is replaced by a regular
    handler function that drives the generator via ``send_websocket_stream``.
    """

    __slots__ = ("_ws_stream_options",)
    _ws_stream_options: _WebSocketStreamOptions

    def on_registration(self, route: BaseRoute, app: Litestar) -> None:
        """Validate the generator function and swap it for a streaming handler.

        Args:
            route: The route this handler is being registered on
            app: The application this handler is being registered on

        Raises:
            ImproperlyConfiguredException: If the handler function is not annotated
                as returning an ``AsyncGenerator``, or if that annotation does not
                specify the type of the yielded values
        """
        self._ws_stream_options = self.opt["stream_options"]

        parsed_handler_signature = parsed_stream_fn_signature = ParsedSignature.from_fn(
            self.fn, self.signature_namespace
        )
        stream_return_type = parsed_stream_fn_signature.return_type

        if not stream_return_type.is_subclass_of(AsyncGenerator):
            # report the annotation itself; 'type(annotation)' would only show its
            # metaclass (e.g. "<class 'type'>"), which doesn't help the user
            raise ImproperlyConfiguredException(
                f"Route handler {self}: 'websocket_stream' handlers must return an "
                f"'AsyncGenerator', not {stream_return_type.raw!r}"
            )

        if not stream_return_type.inner_types:
            # a bare 'AsyncGenerator' has no type arguments; fail with a clear
            # message instead of an IndexError on 'inner_types[0]' below
            raise ImproperlyConfiguredException(
                f"Route handler {self}: 'websocket_stream' handlers must specify the "
                f"type of the yielded values, e.g. 'AsyncGenerator[str, None]', "
                f"not {stream_return_type.raw!r}"
            )

        # important not to use 'self._ws_stream_options.generator_fn' here; This would
        # break in cases the decorator has been used inside a controller, as it would
        # be a reference to the unbound method. The bound method is patched in later
        # after the controller has been initialized. This is a workaround that should
        # go away with v3.0's static handlers
        stream_fn = cast("Callable[..., AsyncGenerator[Any, Any]]", self.fn)

        # construct a fake signature for the kwargs modelling, using the generator
        # function passed to the handler as a base, to include all the dependencies,
        # params, injection kwargs, etc. + 'socket', so DI works properly, but the
        # signature looks to kwargs/signature modelling like a plain '@websocket'
        # handler that returns 'None'
        parsed_handler_signature = dataclasses.replace(
            parsed_handler_signature, return_type=FieldDefinition.from_annotation(NoneType)
        )

        receives_socket_parameter = "socket" in parsed_stream_fn_signature.parameters

        if not receives_socket_parameter:
            parsed_handler_signature = dataclasses.replace(
                parsed_handler_signature,
                parameters={
                    **parsed_handler_signature.parameters,
                    "socket": FieldDefinition.from_annotation("WebSocket", name="socket"),
                },
            )

        self._parsed_fn_signature = parsed_handler_signature
        # the first type argument of the generator annotation, i.e. the yielded type
        self._parsed_return_field = stream_return_type.inner_types[0]

        json_encoder = JsonEncoder(enc_hook=self.default_serializer)
        self._dto = self._resolve_data_dto(app=app)
        self._return_dto = return_dto = self._resolve_return_dto(app=app, data_dto=self._dto)

        # make sure the closure doesn't capture self._ws_stream / self
        send_mode: WebSocketMode = self._ws_stream_options.send_mode  # pyright: ignore
        listen_for_disconnect = self._ws_stream_options.listen_for_disconnect
        warn_on_data_discard = self._ws_stream_options.warn_on_data_discard

        async def send_handler(socket: WebSocket, data: Any) -> None:
            # 'str' and 'bytes' are sent untouched; everything else is JSON encoded,
            # after being passed through the return DTO if one is configured
            if not isinstance(data, (str, bytes)):
                if return_dto:
                    data = return_dto(socket).data_to_encodable_type(data)
                data = json_encoder.encode(data)
            await socket.send_data(data=data, mode=send_mode)

        @functools.wraps(stream_fn)
        async def handler_fn(*args: Any, socket: WebSocket, **kw: Any) -> None:
            # 'socket' is always injected for the handler; only pass it on to the
            # generator function if it declared the parameter itself
            if receives_socket_parameter:
                kw["socket"] = socket

            await send_websocket_stream(
                socket=socket,
                stream=stream_fn(*args, **kw),
                mode=send_mode,
                close=True,
                listen_for_disconnect=listen_for_disconnect,
                warn_on_data_discard=warn_on_data_discard,
                send_handler=send_handler,
            )

        self.fn = handler_fn  # pyright: ignore

        super().on_registration(route, app)


class _WebSocketStreamOptions:
    """Plain container for the settings ``websocket_stream`` hands to its handler."""

    def __init__(
        self,
        generator_fn: Callable[..., AsyncGenerator[Any, Any]],
        listen_for_disconnect: bool,
        warn_on_data_discard: bool,
        send_mode: WebSocketMode,
    ) -> None:
        # the decorated async generator function
        self.generator_fn = generator_fn
        # forwarded to 'send_websocket_stream' under the same name
        self.listen_for_disconnect = listen_for_disconnect
        # forwarded to 'send_websocket_stream' under the same name
        self.warn_on_data_discard = warn_on_data_discard
        # WebSocket mode used when sending data
        self.send_mode = send_mode