# Source code for litestar.testing.websocket_test_session

from __future__ import annotations

import contextlib
import math
from typing import TYPE_CHECKING, Any, Literal, cast

import anyio
import anyio.abc
from anyio.streams.stapled import StapledObjectStream

from litestar.exceptions import WebSocketDisconnect
from litestar.serialization import decode_json, decode_msgpack, encode_json, encode_msgpack
from litestar.status_codes import WS_1000_NORMAL_CLOSURE

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
    from types import TracebackType

    from anyio.streams.memory import MemoryObjectReceiveStream

    from litestar.testing.client.sync_client import TestClient
    from litestar.types import (
        ASGIApp,
        WebSocketDisconnectEvent,
        WebSocketReceiveMessage,
        WebSocketScope,
        WebSocketSendMessage,
    )


# Public names exported by this module.
__all__ = ("AsyncWebSocketTestSession", "WebSocketTestSession")


class WebSocketTestSession:
    """Synchronous wrapper around :class:`AsyncWebSocketTestSession`.

    Every operation is forwarded through ``portal`` to an async session that runs in
    the portal's event loop, so WebSocket handlers can be exercised from synchronous
    test code. Use it as a context manager: entering starts the application and waits
    for the ``websocket.accept`` handshake, exiting tears the session down.
    """

    def __init__(
        self,
        client: TestClient[Any],
        scope: WebSocketScope,
        portal: anyio.abc.BlockingPortal,
        connect_timeout: float | None = None,
    ) -> None:
        """Store the parameters needed to open the session on ``__enter__``.

        Args:
            client: The test client whose ``app`` is the ASGI application under test.
            scope: The ASGI WebSocket scope passed to the application.
            portal: Blocking portal used to run the async session from sync code.
            connect_timeout: Seconds to wait for the ``websocket.accept`` message when
                entering; ``None`` waits indefinitely.
        """
        self._exit_stack = contextlib.ExitStack()
        self._portal = portal
        self._client = client
        self._scope = scope
        self._connect_timeout = connect_timeout

    @contextlib.asynccontextmanager
    async def _run_session(self) -> AsyncGenerator[AsyncWebSocketTestSession]:
        """Run an async session inside its own task group for the life of the context."""
        async with (
            anyio.create_task_group() as tg,
            AsyncWebSocketTestSession(
                app=self._client.app,
                scope=self._scope,
                connect_timeout=self._connect_timeout,
                tg=tg,
            ) as session,
        ):
            yield session

    def __enter__(self) -> WebSocketTestSession:
        """Open the async session through the portal and return ``self``."""
        with contextlib.ExitStack() as exit_stack:
            self._async_session = exit_stack.enter_context(
                self._portal.wrap_async_context_manager(self._portal.call(self._run_session))
            )
            # Transfer ownership of the entered context to ``self`` so it stays open
            # after this ``with`` block; if entering raised, it is unwound here instead.
            self._exit_stack = exit_stack.pop_all()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Close the async session opened by ``__enter__``."""
        self._exit_stack.__exit__(exc_type, exc_value, traceback)

    @property
    def accepted_subprotocol(self) -> str | None:
        """The subprotocol from the ``websocket.accept`` message (available after entering)."""
        return self._async_session.accepted_subprotocol

    @property
    def extra_headers(self) -> list[tuple[bytes, bytes]]:
        """Headers from the ``websocket.accept`` message (available after entering)."""
        return self._async_session.extra_headers

    @property
    def scope(self) -> WebSocketScope:
        """The ASGI scope of the running session (available after entering)."""
        return self._async_session.scope

    def send(self, data: str | bytes, mode: Literal["text", "binary"] = "text", encoding: str = "utf-8") -> None:
        """Send a ``websocket.receive`` event to the application.

        This is the inverse of the ASGI send method.

        Args:
            data: Either a string or a byte string.
            mode: ``"text"`` sends the payload under the ``text`` key, ``"binary"``
                under the ``bytes`` key.
            encoding: The encoding used to convert ``data`` when its type does not
                match ``mode``.

        Returns:
            None.
        """
        self._portal.call(self._async_session.send, data, mode, encoding)

    def send_text(self, data: str, encoding: str = "utf-8") -> None:
        """Send the data using the ``text`` key.

        Args:
            data: Data to send.
            encoding: Encoding to use.

        Returns:
            None
        """
        self._portal.call(self._async_session.send_text, data, encoding)

    def send_bytes(self, data: bytes, encoding: str = "utf-8") -> None:
        """Send the data using the ``bytes`` key.

        Args:
            data: Data to send.
            encoding: Encoding to use.

        Returns:
            None
        """
        self._portal.call(self._async_session.send_bytes, data, encoding)

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Encode ``data`` as JSON and send it.

        Args:
            data: The data to send.
            mode: Either ``"text"`` or ``"binary"``.

        Returns:
            None
        """
        self.send(encode_json(data), mode=mode)

    def send_msgpack(self, data: Any) -> None:
        """Encode ``data`` as MessagePack and send it in ``"binary"`` mode.

        Args:
            data: The data to send.

        Returns:
            None
        """
        self.send(encode_msgpack(data), mode="binary")

    def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
        """Send a ``websocket.disconnect`` event to the application.

        Args:
            code: Status code for closing the connection.
            reason: Reason for closure.

        Returns:
            None.
        """
        self._portal.call(self._async_session.close, code, reason)

    def receive(self, block: bool = True, timeout: float | None = None) -> WebSocketSendMessage:
        """Receive the next message sent by the application.

        This is the base receive method; the other ``receive_*`` methods extract and
        decode the payload from the message.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Raises:
            WebSocketDisconnect: If the application sent a ``websocket.close`` message.

        Returns:
            A websocket message.
        """
        return self._portal.call(self._async_session.receive, block, timeout)

    def receive_text(self, block: bool = True, timeout: float | None = None) -> str:
        """Receive a message and return its ``text`` payload.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            A string value.
        """
        return self._portal.call(self._async_session.receive_text, block, timeout)

    def receive_bytes(self, block: bool = True, timeout: float | None = None) -> bytes:
        """Receive a message and return its ``bytes`` payload.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            A bytes value.
        """
        return self._portal.call(self._async_session.receive_bytes, block, timeout)

    def receive_json(
        self, mode: Literal["text", "binary"] = "text", block: bool = True, timeout: float | None = None
    ) -> Any:
        """Receive a message in ``text`` or ``binary`` mode and decode it as JSON.

        Args:
            mode: Either ``"text"`` or ``"binary"``.
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            An arbitrary value.
        """
        return self._portal.call(self._async_session.receive_json, mode, block, timeout)

    def receive_msgpack(self, block: bool = True, timeout: float | None = None) -> Any:
        """Receive a message and decode its ``bytes`` payload as MessagePack.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            An arbitrary value.
        """
        return self._portal.call(self._async_session.receive_msgpack, block, timeout)
class AsyncWebSocketTestSession:
    """Asynchronous test session for an ASGI WebSocket application.

    The application runs as a background task in the supplied task group and
    exchanges ASGI messages with this session through in-memory object streams.
    Use it as an async context manager: entering sends ``websocket.connect`` and
    waits for ``websocket.accept``; exiting sends ``websocket.disconnect``, waits for
    the application to return and then cancels its task.
    """

    def __init__(
        self,
        *,
        app: ASGIApp,
        scope: WebSocketScope,
        connect_timeout: float | None = None,
        tg: anyio.abc.TaskGroup,
    ) -> None:
        """Prepare the message streams; the application is started on ``__aenter__``.

        Args:
            app: The ASGI application under test.
            scope: The ASGI WebSocket scope passed to the application.
            connect_timeout: Seconds to wait for ``websocket.accept`` when entering;
                ``None`` waits indefinitely.
            tg: Task group in which the application task is started.
        """
        self.scope = scope
        self.accepted_subprotocol: str | None = None
        self.extra_headers: list[tuple[bytes, bytes]] = []
        self.app = app
        self._tg = tg
        # Messages the application sends (its ASGI ``send``) and this session reads.
        self._send_stream = StapledObjectStream(*anyio.create_memory_object_stream["WebSocketSendMessage"](math.inf))
        # Messages this session sends that the application reads (its ASGI ``receive``).
        self._receive_stream = StapledObjectStream(
            *anyio.create_memory_object_stream["WebSocketReceiveMessage"](math.inf)
        )
        self._exit_stack = contextlib.AsyncExitStack()
        self._connect_timeout = connect_timeout

    async def __aenter__(self) -> AsyncWebSocketTestSession:
        """Start the application and complete the WebSocket handshake.

        Raises:
            RuntimeError: If the first message from the application is not
                ``websocket.accept``.
            WebSocketDisconnect: If the application closes the connection instead.
        """
        async with contextlib.AsyncExitStack() as exit_stack:
            cancel_scope = anyio.CancelScope()
            app_done = await self._tg.start(self._run, cancel_scope, self._receive_stream, self._send_stream)
            # Teardown runs in reverse registration order: send a disconnect, wait for
            # the app to return, then cancel the task (it is parked in sleep_forever).
            exit_stack.callback(cancel_scope.cancel)
            exit_stack.push_async_callback(app_done.wait)
            exit_stack.push_async_callback(self.close)
            await self._asgi_send({"type": "websocket.connect"})
            message = await self.receive(timeout=self._connect_timeout)
            if message["type"] != "websocket.accept":
                raise RuntimeError(
                    f"Unexpected ASGI message. Expected 'websocket.accept'. Received {message['type']!r}"
                )
            self.accepted_subprotocol = message.get("subprotocol")
            self.extra_headers = list(message.get("headers", []))
            # Keep the teardown callbacks alive until __aexit__.
            self._exit_stack = exit_stack.pop_all()
            return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Run the teardown callbacks registered in ``__aenter__``."""
        await self._exit_stack.__aexit__(exc_type, exc_value, traceback)

    async def _run(
        self,
        cancel_scope: anyio.CancelScope,
        receive_stream: StapledObjectStream,
        send_stream: StapledObjectStream,
        *,
        task_status: anyio.abc.TaskStatus,
    ) -> None:
        """Run the application; report an event that is set once it returns."""
        app_done = anyio.Event()
        with cancel_scope:
            async with send_stream, receive_stream:
                task_status.started(app_done)
                await self.app(self.scope, receive_stream.receive, send_stream.send)
                app_done.set()
                # Keep the streams open so already-queued messages can still be read;
                # the task is ended by cancelling ``cancel_scope``.
                await anyio.sleep_forever()

    async def _asgi_send(self, message: WebSocketReceiveMessage) -> None:
        """Queue ``message`` for the application's ASGI ``receive`` callable."""
        await self._receive_stream.send(message)

    async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
        """Send a ``websocket.disconnect`` event to the application.

        Args:
            code: Status code for closing the connection.
            reason: Reason for closure.

        Returns:
            None.
        """
        event: WebSocketDisconnectEvent = {"type": "websocket.disconnect", "code": code, "reason": reason}
        await self._asgi_send(event)

    async def send(self, data: str | bytes, mode: Literal["text", "binary"] = "text", encoding: str = "utf-8") -> None:
        """Send a ``websocket.receive`` event to the application.

        This is the inverse of the ASGI send method.

        Args:
            data: Either a string or a byte string.
            mode: ``"text"`` sends the payload under the ``text`` key (``bytes`` are
                decoded with ``encoding``); any other value sends it under the
                ``bytes`` key (``str`` is encoded with ``encoding``).
            encoding: The encoding used to convert ``data`` when its type does not
                match ``mode``.

        Returns:
            None.
        """
        if mode == "text":
            data = data.decode(encoding) if isinstance(data, bytes) else data
            text_event: WebSocketReceiveMessage = {"type": "websocket.receive", "text": data}  # type: ignore[assignment]
            await self._asgi_send(text_event)
        else:
            data = data if isinstance(data, bytes) else data.encode(encoding)
            binary_event: WebSocketReceiveMessage = {"type": "websocket.receive", "bytes": data}  # type: ignore[assignment]
            await self._asgi_send(binary_event)

    async def send_text(self, data: str, encoding: str = "utf-8") -> None:
        """Send the data using the ``text`` key.

        Args:
            data: Data to send.
            encoding: Encoding to use; only applied if ``data`` is ``bytes``.

        Returns:
            None
        """
        await self.send(data=data, encoding=encoding)

    async def send_bytes(self, data: bytes, encoding: str = "utf-8") -> None:
        """Send the data using the ``bytes`` key.

        Args:
            data: Data to send.
            encoding: Encoding to use; only applied if ``data`` is a ``str``.

        Returns:
            None
        """
        await self.send(data=data, mode="binary", encoding=encoding)

    async def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Encode ``data`` as JSON and send it.

        Args:
            data: The data to send.
            mode: Either ``"text"`` or ``"binary"``.

        Returns:
            None
        """
        await self.send(encode_json(data), mode=mode)

    async def send_msgpack(self, data: Any) -> None:
        """Encode ``data`` as MessagePack and send it in ``"binary"`` mode.

        Args:
            data: The data to send.

        Returns:
            None
        """
        await self.send(encode_msgpack(data), mode="binary")

    async def receive(self, block: bool = True, timeout: float | None = None) -> WebSocketSendMessage:
        """Receive the next message sent by the application.

        This is the base receive method; the other ``receive_*`` methods extract and
        decode the payload from the message.

        Args:
            block: Wait for a message. If ``False``, only an already-queued message
                is returned.
            timeout: If ``block`` is ``True``, wait at most ``timeout`` seconds
                (``None`` waits indefinitely). Ignored when ``block`` is ``False``.

        Raises:
            WebSocketDisconnect: If the application sent a ``websocket.close`` message.
            TimeoutError: If ``block`` is ``True`` and nothing arrives within ``timeout``.
            anyio.WouldBlock: If ``block`` is ``False`` and no message is queued.

        Returns:
            A websocket message.
        """
        message: WebSocketSendMessage
        if not block:
            message = cast("MemoryObjectReceiveStream", self._send_stream.receive_stream).receive_nowait()
        else:
            with anyio.fail_after(timeout):
                message = await self._send_stream.receive()
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(
                # ASGI: a missing or None close reason is equivalent to an empty string.
                detail=cast("str", message.get("reason") or ""),
                code=message.get("code", WS_1000_NORMAL_CLOSURE),
            )
        return message

    async def receive_text(self, block: bool = True, timeout: float | None = None) -> str:
        """Receive a message and return its ``text`` payload.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            The ``text`` payload, or ``""`` if it is missing or ``None``.
        """
        message = await self.receive(block=block, timeout=timeout)
        # ASGI allows the unused payload key to be present with a None value.
        return cast("str", message.get("text") or "")

    async def receive_bytes(self, block: bool = True, timeout: float | None = None) -> bytes:
        """Receive a message and return its ``bytes`` payload.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            The ``bytes`` payload, or ``b""`` if it is missing or ``None``.
        """
        message = await self.receive(block=block, timeout=timeout)
        # ASGI allows the unused payload key to be present with a None value.
        return cast("bytes", message.get("bytes") or b"")

    async def receive_json(
        self, mode: Literal["text", "binary"] = "text", block: bool = True, timeout: float | None = None
    ) -> Any:
        """Receive a message in ``text`` or ``binary`` mode and decode it as JSON.

        Args:
            mode: ``"text"`` decodes the ``text`` payload; any other value decodes
                the ``bytes`` payload.
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            An arbitrary value.
        """
        message = await self.receive(block=block, timeout=timeout)
        if mode == "text":
            return decode_json(cast("str", message.get("text") or ""))
        return decode_json(cast("bytes", message.get("bytes") or b""))

    async def receive_msgpack(self, block: bool = True, timeout: float | None = None) -> Any:
        """Receive a message and decode its ``bytes`` payload as MessagePack.

        Args:
            block: Block until a message is received.
            timeout: If ``block`` is ``True``, block at most ``timeout`` seconds.

        Returns:
            An arbitrary value.
        """
        message = await self.receive(block=block, timeout=timeout)
        return decode_msgpack(cast("bytes", message.get("bytes") or b""))