from __future__ import annotations

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any

from psycopg import AsyncConnection
from psycopg.sql import SQL, Identifier

from litestar.channels.backends.base import ChannelsBackend

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Iterable
class PsycoPgChannelsBackend(ChannelsBackend):
    """Channels backend that relays messages through PostgreSQL ``LISTEN`` / ``NOTIFY``.

    One long-lived autocommit connection (opened in :meth:`on_startup`) issues
    ``LISTEN`` / ``UNLISTEN`` and receives notifications. Each :meth:`publish` call
    opens its own short-lived connection to send ``NOTIFY``. Message history is not
    supported.
    """

    _listener_conn: AsyncConnection[Any]

    def __init__(self, pg_dsn: str) -> None:
        """Store connection settings; no connection is opened until :meth:`on_startup`.

        Args:
            pg_dsn: Connection string passed to ``AsyncConnection.connect``.
        """
        self._pg_dsn = pg_dsn
        # Channels this backend has successfully issued LISTEN for.
        self._subscribed_channels: set[str] = set()
        # Owns the listener connection so on_shutdown can close it.
        self._exit_stack = AsyncExitStack()

    async def on_startup(self) -> None:
        """Open the autocommit listener connection and register it for cleanup."""
        self._listener_conn = await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True)
        await self._exit_stack.enter_async_context(self._listener_conn)

    async def on_shutdown(self) -> None:
        """Close the listener connection opened in :meth:`on_startup`."""
        await self._exit_stack.aclose()

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` as a ``NOTIFY`` payload on each channel in ``channels``.

        Args:
            data: Message payload; decoded as UTF-8 before sending.
            channels: Names of the channels to notify.

        Raises:
            UnicodeDecodeError: If ``data`` is not valid UTF-8.
        """
        dec_data = data.decode("utf-8")
        # Materialize once; with nothing to notify there is no reason to open a connection.
        target_channels = list(channels)
        if not target_channels:
            return
        async with await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True) as conn:
            for channel in target_channels:
                # Identifier quotes the channel name; the plain-str payload is escaped as a
                # literal by SQL.format, so neither value is spliced into the statement raw.
                statement = SQL("NOTIFY {channel}, {data}").format(channel=Identifier(channel), data=dec_data)
                await conn.execute(statement)

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Issue ``LISTEN`` for every channel in ``channels`` not already subscribed.

        Args:
            channels: Names of the channels to listen on; duplicates are ignored.
        """
        new_channels = set(channels) - self._subscribed_channels
        for channel in new_channels:
            await self._listener_conn.execute(SQL("LISTEN {channel}").format(channel=Identifier(channel)))
            # Record only after LISTEN succeeded so state matches the server on partial failure.
            self._subscribed_channels.add(channel)
        # The listener connection is opened with autocommit=True in on_startup, so this is
        # expected to be a no-op; kept as a safeguard.
        await self._listener_conn.commit()

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Issue ``UNLISTEN`` for every channel in ``channels`` and forget it locally.

        Args:
            channels: Names of the channels to stop listening on; may be any iterable,
                including a one-shot iterator.
        """
        # Materialize first: ``channels`` may be a generator, and it was previously
        # consumed twice (loop + set difference), leaving stale subscription state.
        channels_to_drop = set(channels)
        for channel in channels_to_drop:
            await self._listener_conn.execute(SQL("UNLISTEN {channel}").format(channel=Identifier(channel)))
            self._subscribed_channels.discard(channel)
        # Expected no-op under autocommit (see on_startup); kept as a safeguard.
        await self._listener_conn.commit()

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` for each notification on the listener connection.

        Payloads arrive as text and are re-encoded to UTF-8 bytes, mirroring :meth:`publish`.
        """
        # NOTE(review): depending on the psycopg version, notifies() may hold the connection
        # lock while waiting, which could delay concurrent subscribe/unsubscribe calls on
        # the same connection — confirm against the installed psycopg.
        async for notify in self._listener_conn.notifies():
            yield notify.channel, notify.payload.encode("utf-8")

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Message history is not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()