from __future__ import annotations
import math
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import AbstractAsyncContextManager as AsyncContextManager
from contextlib import AsyncExitStack
from functools import partial
from typing import TYPE_CHECKING, Any
import anyio
from litestar.exceptions import ImproperlyConfiguredException
if TYPE_CHECKING:
from collections.abc import Sequence
from types import TracebackType
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from litestar.events.listener import EventListener
__all__ = ("BaseEventEmitterBackend", "SimpleEventEmitter")
[docs]
class BaseEventEmitterBackend(AsyncContextManager["BaseEventEmitterBackend"], ABC):
"""Abstract class used to define event emitter backends."""
__slots__ = ("listeners",)
listeners: defaultdict[str, set[EventListener]]
[docs]
def __init__(self, listeners: Sequence[EventListener]) -> None:
"""Create an event emitter instance.
Args:
listeners: A list of listeners.
"""
self.listeners = defaultdict(set)
for listener in listeners:
for event_id in listener.event_ids:
self.listeners[event_id].add(listener)
[docs]
@abstractmethod
def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None:
"""Emit an event to all attached listeners.
Args:
event_id: The ID of the event to emit, e.g 'my_event'.
*args: args to pass to the listener(s).
**kwargs: kwargs to pass to the listener(s)
Returns:
None
"""
raise NotImplementedError("not implemented")
[docs]
class SimpleEventEmitter(BaseEventEmitterBackend):
"""Event emitter the works only in the current process"""
__slots__ = ("_exit_stack", "_queue", "_receive_stream", "_send_stream")
[docs]
def __init__(self, listeners: Sequence[EventListener]) -> None:
"""Create an event emitter instance.
Args:
listeners: A list of listeners.
"""
super().__init__(listeners=listeners)
self._receive_stream: MemoryObjectReceiveStream | None = None
self._send_stream: MemoryObjectSendStream | None = None
self._exit_stack: AsyncExitStack | None = None
async def _worker(self, receive_stream: MemoryObjectReceiveStream) -> None:
"""Run items from ``receive_stream`` in a task group.
Returns:
None
"""
async with receive_stream, anyio.create_task_group() as task_group:
async for item in receive_stream:
fn, args, kwargs = item
if kwargs:
fn = partial(fn, **kwargs)
task_group.start_soon(fn, *args) # pyright: ignore[reportGeneralTypeIssues]
async def __aenter__(self) -> SimpleEventEmitter:
self._exit_stack = AsyncExitStack()
send_stream, receive_stream = anyio.create_memory_object_stream(math.inf) # type: ignore[var-annotated]
self._send_stream = send_stream
task_group = anyio.create_task_group()
await self._exit_stack.enter_async_context(task_group)
await self._exit_stack.enter_async_context(send_stream)
task_group.start_soon(self._worker, receive_stream)
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._exit_stack:
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
self._exit_stack = None
self._send_stream = None
[docs]
def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None:
"""Emit an event to all attached listeners.
Args:
event_id: The ID of the event to emit, e.g 'my_event'.
*args: args to pass to the listener(s).
**kwargs: kwargs to pass to the listener(s)
Returns:
None
"""
if not (self._send_stream and self._exit_stack):
raise RuntimeError("Emitter not initialized")
if listeners := self.listeners.get(event_id):
for listener in listeners:
self._send_stream.send_nowait((listener.fn, args, kwargs))
return
raise ImproperlyConfiguredException(f"no event listeners are registered for event ID: {event_id}")