from __future__ import annotations
import asyncio
from asyncio import CancelledError, Queue, QueueFull
from collections import deque
from collections.abc import AsyncGenerator, Awaitable
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar
if TYPE_CHECKING:
    from litestar.channels import ChannelsPlugin
T = TypeVar("T")
BacklogStrategy = Literal["backoff", "dropleft"]
EventCallback = Callable[[bytes], Awaitable[Any]]
class AsyncDeque(Queue, Generic[T]):
    """An :class:`asyncio.Queue` backed by a bounded deque that drops its oldest item when full."""

    def __init__(self, maxsize: int | None) -> None:
        self._deque_maxlen = maxsize
        super().__init__()

    def _init(self, maxsize: int) -> None:
        # Use a deque with a fixed ``maxlen`` so that appends beyond the limit evict
        # items from the left instead of raising ``QueueFull``
        self._queue: deque[T] = deque(maxlen=self._deque_maxlen)
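
# A minimal sketch (not part of the library) illustrating the drop-left behaviour above:
# putting into a full ``AsyncDeque`` silently evicts the oldest item instead of raising
# ``QueueFull``. The function name is illustrative only.
async def _demo_async_deque_dropleft() -> None:
    queue: AsyncDeque[int] = AsyncDeque(maxsize=2)
    for item in (1, 2, 3):
        queue.put_nowait(item)  # the third put evicts ``1`` from the left
    assert [queue.get_nowait(), queue.get_nowait()] == [2, 3]
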
class Subscriber:
    """A wrapper around a stream of events published to subscribed channels"""
    def __init__(
        self,
        plugin: ChannelsPlugin,
        max_backlog: int | None = None,
        backlog_strategy: BacklogStrategy = "backoff",
    ) -> None:
        self._task: asyncio.Task | None = None
        self._plugin = plugin
        self._backend = plugin._backend
        self._queue: Queue[bytes | None] | AsyncDeque[bytes | None]
        # "dropleft" evicts the oldest event once the backlog is full; the default
        # "backoff" keeps the backlog intact, making ``put`` wait and ``put_nowait``
        # return ``False``
        if max_backlog and backlog_strategy == "dropleft":
            self._queue = AsyncDeque(maxsize=max_backlog or 0)
        else:
            self._queue = Queue(maxsize=max_backlog or 0)

    async def put(self, item: bytes | None) -> None:
        await self._queue.put(item)
    def put_nowait(self, item: bytes | None) -> bool:
        """Put an item in the subscriber's stream without waiting.

        Returns ``True`` if the item was enqueued, ``False`` if the backlog is full.
        """
        try:
            self._queue.put_nowait(item)
            return True
        except QueueFull:
            return False

    @property
    def qsize(self) -> int:
        return self._queue.qsize()
    async def iter_events(self) -> AsyncGenerator[bytes, None]:
        """Iterate over the stream of events. If no items are available, block until
        one becomes available. A ``None`` item acts as a sentinel, terminating the stream.
        """
        while True:
            item = await self._queue.get()
            if item is None:
                self._queue.task_done()
                break
            yield item
            self._queue.task_done()
    @asynccontextmanager
    async def run_in_background(self, on_event: EventCallback, join: bool = True) -> AsyncGenerator[None, None]:
        """Start a task in the background that invokes ``on_event`` for events from the
        subscriber's stream as they become available. On exit, the stream will stop
        accepting new events, and the currently enqueued ones will be processed before
        the worker is stopped. Should the context be left with an exception, the task
        will be cancelled immediately.

        Args:
            on_event: Callback to invoke with the event data for every event
            join: If ``True``, wait for all items in the stream to be processed before
                stopping the worker. Note that an error occurring within the context
                will always lead to the immediate cancellation of the worker
        """
        self._start_in_background(on_event=on_event)
        async with AsyncExitStack() as exit_stack:
            # If the body raises, this callback stops the worker immediately without
            # draining the backlog
            exit_stack.push_async_callback(self.stop, join=False)
            yield
            # Clean exit: discard the error-path callback and stop, optionally joining
            exit_stack.pop_all()
            await self.stop(join=join)
    async def _worker(self, on_event: EventCallback) -> None:
        async for event in self.iter_events():
            await on_event(event)

    def _start_in_background(self, on_event: EventCallback) -> None:
        """Start a task in the background that invokes ``on_event`` for events from the
        subscriber's stream as they become available.

        Args:
            on_event: Callback to invoke with the event data for every event
        """
        if self._task is not None:
            raise RuntimeError("Subscriber is already running")
        self._task = asyncio.create_task(self._worker(on_event))

    @property
    def is_running(self) -> bool:
        """Return whether a sending task is currently running"""
        return self._task is not None
    async def stop(self, join: bool = False) -> None:
        """Stop a task that was previously started with :meth:`run_in_background`. If the
        task is not yet done, it will be cancelled and awaited.

        Args:
            join: If ``True``, wait for all items to be processed before stopping the task
        """
        if not self._task:
            return

        if join:
            await self._queue.join()

        if not self._task.done():
            self._task.cancel()
        with suppress(CancelledError):
            await self._task
        self._task = None
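
# A minimal usage sketch (not part of the library). The ``ChannelsPlugin`` argument is
# replaced with a ``SimpleNamespace`` stand-in exposing only the ``_backend`` attribute
# read in ``Subscriber.__init__``; in real code the plugin is supplied by Litestar.
async def _demo_subscriber() -> None:
    from types import SimpleNamespace

    subscriber = Subscriber(
        plugin=SimpleNamespace(_backend=None),  # type: ignore[arg-type]  # hypothetical stand-in
        max_backlog=10,
        backlog_strategy="backoff",
    )
    received: list[bytes] = []

    async def on_event(data: bytes) -> None:
        received.append(data)

    # The context manager starts the worker; on a clean exit it joins the queue, so both
    # events are handled by ``on_event`` before the worker is cancelled
    async with subscriber.run_in_background(on_event):
        subscriber.put_nowait(b"hello")
        subscriber.put_nowait(b"world")
    assert received == [b"hello", b"world"]


if __name__ == "__main__":
    asyncio.run(_demo_subscriber())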