# Source code for litestar.channels.backends.psycopg

from __future__ import annotations

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any

from psycopg import AsyncConnection
from psycopg.sql import SQL, Identifier

from litestar.channels.backends.base import ChannelsBackend

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Iterable


class PsycoPgChannelsBackend(ChannelsBackend):
    """Channels backend built on PostgreSQL ``LISTEN`` / ``NOTIFY`` via psycopg.

    One long-lived connection, opened in :meth:`on_startup`, is used to
    ``LISTEN`` on subscribed channels and to receive notifications in
    :meth:`stream_events`. Each :meth:`publish` call opens its own
    short-lived connection to send ``NOTIFY``. History is not supported.
    """

    # Connection used for LISTEN/UNLISTEN and for receiving notifications;
    # assigned in on_startup().
    _listener_conn: AsyncConnection[Any]

    def __init__(self, pg_dsn: str) -> None:
        """Create the backend.

        Args:
            pg_dsn: Connection string passed to ``AsyncConnection.connect``.
        """
        self._pg_dsn = pg_dsn
        # Channels this backend has issued LISTEN for on the listener connection.
        self._subscribed_channels: set[str] = set()
        # Owns the listener connection so on_shutdown() can close it.
        self._exit_stack = AsyncExitStack()

    async def on_startup(self) -> None:
        """Open the listener connection and register it for closing on shutdown."""
        self._listener_conn = await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True)
        await self._exit_stack.enter_async_context(self._listener_conn)

    async def on_shutdown(self) -> None:
        """Close the listener connection opened in :meth:`on_startup`."""
        await self._exit_stack.aclose()

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` as a ``NOTIFY`` payload on each of ``channels``.

        Args:
            data: Payload; must be valid UTF-8, since it is decoded to text.
            channels: Channel names to notify.

        Raises:
            UnicodeDecodeError: If ``data`` is not valid UTF-8.
        """
        dec_data = data.decode("utf-8")
        async with await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True) as conn:
            for channel in channels:
                # NOTIFY does not accept bind parameters, so the statement is
                # composed with psycopg.sql: the channel is quoted as an
                # identifier and the non-Composable payload is escaped as a literal.
                await conn.execute(
                    SQL("NOTIFY {channel}, {data}").format(channel=Identifier(channel), data=dec_data)
                )

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Start listening on every channel in ``channels`` not already subscribed."""
        for channel in set(channels) - self._subscribed_channels:
            await self._listener_conn.execute(SQL("LISTEN {channel}").format(channel=Identifier(channel)))
            self._subscribed_channels.add(channel)
            # NOTE(review): the listener connection is opened with
            # autocommit=True, so this commit is presumably a no-op.
            await self._listener_conn.commit()

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Stop listening on every channel in ``channels``."""
        # Materialize once. ``channels`` may be a one-shot iterator, and it is
        # needed both for the UNLISTEN loop and for updating the bookkeeping
        # set. Iterating it twice would leave _subscribed_channels stale.
        channel_set = set(channels)
        for channel in channel_set:
            await self._listener_conn.execute(SQL("UNLISTEN {channel}").format(channel=Identifier(channel)))
            await self._listener_conn.commit()
        self._subscribed_channels = self._subscribed_channels - channel_set

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` for each notification on the listener connection.

        The payload is re-encoded to UTF-8 bytes, mirroring the decoding in
        :meth:`publish`.
        """
        async for notify in self._listener_conn.notifies():
            yield notify.channel, notify.payload.encode("utf-8")

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()