Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions litestar/channels/backends/psycopg.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,45 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable
from typing import Any, AsyncGenerator, Iterable

import psycopg
from psycopg import AsyncConnection
from psycopg.sql import SQL, Identifier

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
return '"{}"'.format(ident.replace('"', '""')) # sourcery skip
from litestar.channels.backends.base import ChannelsBackend


class PsycoPgChannelsBackend(ChannelsBackend):
_listener_conn: psycopg.AsyncConnection
_listener_conn: AsyncConnection[Any]

def __init__(self, pg_dsn: str) -> None:
self._pg_dsn = pg_dsn
self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()

async def on_startup(self) -> None:
self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
self._listener_conn = await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True)
await self._exit_stack.enter_async_context(self._listener_conn)

async def on_shutdown(self) -> None:
await self._exit_stack.aclose()

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")
async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
async with await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True) as conn:
for channel in channels:
await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))
await conn.execute(SQL("NOTIFY {channel}, {data}").format(channel=Identifier(channel), data=dec_data))

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
# can't use placeholders in LISTEN
await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore

await self._listener_conn.execute(SQL("LISTEN {channel}").format(channel=Identifier(channel)))
self._subscribed_channels.add(channel)
await self._listener_conn.commit()

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
# can't use placeholders in UNLISTEN
await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
await self._listener_conn.execute(SQL("UNLISTEN {channel}").format(channel=Identifier(channel)))
await self._listener_conn.commit()
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_channels/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from litestar.exceptions import ImproperlyConfiguredException, LitestarException
from litestar.testing import TestClient, create_test_client
from litestar.types.asgi_types import WebSocketMode

from .util import get_from_stream
from tests.unit.test_channels.util import get_from_stream


@pytest.fixture(
Expand Down
Loading