Source code for litestar.channels.backends.memory

from __future__ import annotations

from asyncio import Queue
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any

from litestar.channels.backends.base import ChannelsBackend

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Iterable


class MemoryChannelsBackend(ChannelsBackend):
    """An in-memory channels backend.

    Published messages for subscribed channels are put on a single
    :class:`asyncio.Queue` and yielded by :meth:`stream_events`. Optionally, a
    bounded per-channel message history is kept for :meth:`get_history`.
    """

    def __init__(self, history: int = 0) -> None:
        """Create the backend.

        Args:
            history: Maximum number of messages kept per channel. ``0`` (the
                default) disables history recording.
        """
        self._max_history_length = history
        # Channels whose published messages are forwarded to the event stream.
        self._channels: set[str] = set()
        # Created by ``on_startup`` and discarded by ``on_shutdown``.
        self._queue: Queue[tuple[str, bytes]] | None = None
        # Per-channel history; ``deque(maxlen=...)`` drops the oldest entry once full.
        self._history: defaultdict[str, deque[bytes]] = defaultdict(
            lambda: deque(maxlen=self._max_history_length)
        )

    async def on_startup(self) -> None:
        """Create the queue that carries published messages to :meth:`stream_events`."""
        self._queue = Queue()

    async def on_shutdown(self) -> None:
        """Discard the message queue; ``publish`` raises until ``on_startup`` runs again."""
        self._queue = None

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Publish ``data`` to ``channels``.

        Only channels that are currently subscribed receive the message on the
        event stream; for channels that have not been subscribed to, nothing is
        streamed. If history is enabled, ``data`` is recorded in the history of
        every given channel, subscribed or not.

        Args:
            data: Data to publish
            channels: Channels to publish to. Any iterable is accepted, including a
                one-shot generator; it is consumed exactly once.

        Returns:
            None

        Raises:
            RuntimeError: If ``on_startup`` has not been called yet
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        # Materialize once: the channels are needed for both streaming and history,
        # and a second pass over a generator would silently yield nothing.
        targets = list(channels)

        for channel in targets:
            if channel in self._channels:
                self._queue.put_nowait((channel, data))

        if self._max_history_length:
            for channel in targets:
                self._history[channel].append(data)

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Subscribe to ``channels``, and enable publishing to them"""
        self._channels.update(channels)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Unsubscribe from ``channels`` and drop their recorded history.

        Args:
            channels: Channels to unsubscribe from. Any iterable is accepted,
                including a one-shot generator; it is consumed exactly once.
        """
        # Materialize once so the history cleanup below sees every channel even
        # when ``channels`` is a generator.
        removed = set(channels)
        self._channels -= removed
        for channel in removed:
            self._history.pop(channel, None)

    async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]:
        """Return a generator, iterating over events of subscribed channels as they become available.

        Raises:
            RuntimeError: If ``on_startup`` has not been called yet
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        while True:
            channel, message = await self._queue.get()
            self._queue.task_done()
            # if a message is published to a channel and the channel is then
            # unsubscribed before retrieving that message from the stream, it can still
            # end up here, so we double-check if we still are interested in this message
            if channel in self._channels:
                yield channel, message

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Return the event history of ``channel``, at most ``limit`` entries.

        Args:
            channel: Channel whose history to return
            limit: Maximum number of most recent entries to return. A falsy value
                (``None`` or ``0``) returns the full recorded history.

        Returns:
            The recorded messages, oldest first. Empty if nothing was recorded.
        """
        # Use ``.get`` instead of indexing: indexing the defaultdict would insert an
        # empty deque for every channel ever queried, growing ``_history`` without bound.
        history = list(self._history.get(channel, ()))
        if limit:
            history = history[-limit:]
        return history