from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import TYPE_CHECKING, Callable, overload
import asyncpg
from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable, Iterable
class AsyncPgChannelsBackend(ChannelsBackend):
    """Channels backend built on PostgreSQL ``LISTEN``/``NOTIFY`` via asyncpg.

    One long-lived connection (opened in :meth:`on_startup`) carries every
    ``LISTEN`` registration; notifications received on it are pushed into an
    in-memory queue that :meth:`stream_events` drains. :meth:`publish` opens a
    short-lived connection per call and issues ``pg_notify`` for each channel.

    History is not supported: :meth:`get_history` raises ``NotImplementedError``.
    """

    # Long-lived connection holding the LISTEN registrations; assigned in on_startup().
    _listener_conn: asyncpg.Connection

    @overload
    def __init__(self, dsn: str) -> None: ...

    @overload
    def __init__(
        self,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]],
    ) -> None: ...

    def __init__(
        self,
        dsn: str | None = None,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
    ) -> None:
        """Configure how database connections are created.

        Args:
            dsn: Connection string forwarded to ``asyncpg.connect``.
            make_connection: Zero-argument callable returning an awaitable that
                resolves to a connection. Takes precedence over ``dsn`` when
                both are given.

        Raises:
            ImproperlyConfiguredException: If neither ``dsn`` nor
                ``make_connection`` is provided.
        """
        if not (dsn or make_connection):
            raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

        # Channels we have issued LISTEN for on the listener connection.
        self._subscribed_channels: set[str] = set()
        # NOTE(review): not used anywhere in this class; kept for compatibility.
        self._exit_stack = AsyncExitStack()
        self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
        # Created in on_startup(); None means "not started" (or already shut down).
        self._queue: asyncio.Queue[tuple[str, bytes]] | None = None

    async def on_startup(self) -> None:
        """Create the event queue and open the long-lived listener connection."""
        self._queue = asyncio.Queue()
        self._listener_conn = await self._connect()

    async def on_shutdown(self) -> None:
        """Close the listener connection and reset local subscription state.

        Closing the connection discards every ``LISTEN`` registered on it, so the
        tracked channel set is cleared too; otherwise a later ``on_startup`` +
        ``subscribe`` would skip channels it wrongly believes are still active.
        Calling this without a prior ``on_startup`` is a no-op for the connection.
        """
        listener_conn = getattr(self, "_listener_conn", None)
        if listener_conn is not None:
            await listener_conn.close()
        self._subscribed_channels.clear()
        self._queue = None

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` as a ``NOTIFY`` payload on each channel in ``channels``.

        A dedicated connection is opened for the call and always closed again,
        even if a notification fails.

        Raises:
            RuntimeError: If :meth:`on_startup` has not been called.
            UnicodeDecodeError: If ``data`` is not valid UTF-8.
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        # pg_notify() takes a text payload, so the bytes are decoded up front.
        dec_data = data.decode("utf-8")
        conn = await self._connect()
        try:
            for channel in channels:
                # Channel and payload are bind parameters, so no identifier quoting is needed.
                await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
        finally:
            await conn.close()

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Start listening on every channel in ``channels`` not already subscribed."""
        for channel in set(channels) - self._subscribed_channels:
            await self._listener_conn.add_listener(channel, self._listener)  # type: ignore[arg-type]
            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Stop listening on every channel in ``channels``.

        ``channels`` is materialized once: it is needed both to remove the
        listeners and to update the tracked set, and a one-shot iterable (e.g. a
        generator) would otherwise be empty on the second pass.
        """
        to_remove = set(channels)
        for channel in to_remove:
            await self._listener_conn.remove_listener(channel, self._listener)  # type: ignore[arg-type]
        self._subscribed_channels -= to_remove

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` pairs as notifications arrive; loops indefinitely.

        Raises:
            RuntimeError: If :meth:`on_startup` has not been called.
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        while True:
            channel, message = await self._queue.get()
            self._queue.task_done()
            # an UNLISTEN may be in transit while we're getting here, so we double-check
            # that we are actually supposed to deliver this message
            if channel in self._subscribed_channels:
                yield channel, message

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
        """Notification callback registered via ``add_listener``; enqueues UTF-8 bytes.

        Raises:
            RuntimeError: If ``payload`` is not a ``str``.
        """
        if not isinstance(payload, str):
            raise RuntimeError("Invalid data received")
        queue = self._queue
        if queue is None:
            # Backend has been shut down; nothing remains to consume this notification.
            return
        queue.put_nowait((channel, payload.encode("utf-8")))