Source code for litestar.response.sse

from __future__ import annotations

import io
import re
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from litestar.concurrency import sync_to_thread
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import Stream
from litestar.utils import AsyncIteratorWrapper

if TYPE_CHECKING:
    from litestar.background_tasks import BackgroundTask, BackgroundTasks
    from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType

# Matches one line break in any of the forms CRLF, CR or LF. ``\r\n`` is listed first so that a CRLF pair
# is consumed as a single break instead of being seen as two separate ones.
_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n")
# Line terminator used by ``ServerSentEventMessage`` when no explicit ``sep`` is given, and the separator
# ``_ServerSentEventIterator`` passes when encoding streamed content.
DEFAULT_SEPARATOR = "\r\n"


class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
    """Async iterator yielding the encoded byte chunks of a server-sent event stream.

    Iteration first emits the optional preamble lines built from the constructor arguments (comment, ``id``,
    ``event`` and ``retry``), then every item of ``content`` encoded through :meth:`ensure_bytes`.
    """

    __slots__ = ("comment_message", "content_async_iterator", "event_id", "event_type", "retry_duration")

    content_async_iterator: AsyncIterable[SSEData]

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
    ) -> None:
        """Build the preamble chunks and normalise ``content`` into an async iterable.

        Args:
            content: A string, bytes, a sync or async iterable/iterator, or a callable returning one of these.
            event_type: Default ``event`` field; also written once into the preamble.
            event_id: Default ``id`` field; also written once into the preamble.
            retry_duration: Default ``retry`` field; also written once into the preamble.
            comment_message: Comment text written into the preamble, one ``: `` line per input line.

        Raises:
            ImproperlyConfiguredException: If ``content`` is none of the supported types.
        """
        self.comment_message = comment_message
        self.event_id = event_id
        self.event_type = event_type
        self.retry_duration = retry_duration

        chunks: list[bytes] = []
        if comment_message is not None:
            chunks.extend(
                f": {chunk}{DEFAULT_SEPARATOR}".encode() for chunk in _LINE_BREAK_RE.split(comment_message)
            )

        # ``id`` and ``event`` are single-line fields: strip embedded line breaks the same way
        # ``ServerSentEventMessage.encode`` does, so a stray newline cannot start a new field in the stream.
        if event_id is not None:
            id_line = _LINE_BREAK_RE.sub("", f"id: {event_id}")
            chunks.append(f"{id_line}{DEFAULT_SEPARATOR}".encode())

        if event_type is not None:
            event_line = _LINE_BREAK_RE.sub("", f"event: {event_type}")
            chunks.append(f"{event_line}{DEFAULT_SEPARATOR}".encode())

        if retry_duration is not None:
            chunks.append(f"retry: {retry_duration}{DEFAULT_SEPARATOR}".encode())

        # The parent wrapper receives only the preamble; the content itself is kept separately below.
        super().__init__(iterator=chunks)

        # A callable that is not already an iterator is treated as a factory and invoked once.
        if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
            content = content()  # type: ignore[unreachable]

        if isinstance(content, (str, bytes)):
            # str/bytes are iterable themselves; wrap them so they are sent as one message, not per character.
            self.content_async_iterator = AsyncIteratorWrapper([content])
        elif isinstance(content, (Iterable, Iterator)):
            self.content_async_iterator = AsyncIteratorWrapper(content)
        elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
            self.content_async_iterator = content
        else:
            raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")

    def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage | Any, sep: str) -> bytes:
        """Encode one streamed item as a server-sent event message.

        Args:
            data: The item to encode. A ``ServerSentEventMessage`` is encoded as-is (using its own ``sep``);
                a ``dict`` is used as keyword arguments for a ``ServerSentEventMessage``; anything else becomes
                the ``data`` of a message carrying this iterator's default id, event type and retry duration.
            sep: Line separator for messages constructed here. Overrides any ``"sep"`` key in a ``dict`` item.

        Returns:
            The encoded message bytes.
        """
        if isinstance(data, ServerSentEventMessage):
            return data.encode()
        if isinstance(data, dict):
            # Merge into a fresh mapping: the dict belongs to the caller and may be reused, so it must not
            # gain a "sep" key as a side effect of being streamed.
            return ServerSentEventMessage(**{**data, "sep": sep}).encode()

        return ServerSentEventMessage(
            data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep
        ).encode()

    def _call_next(self) -> bytes:
        """Return the next preamble chunk, raising ``ValueError`` once the preamble is exhausted.

        NOTE(review): ``StopIteration`` is translated presumably because this method is awaited through
        ``sync_to_thread`` and ``StopIteration`` does not propagate cleanly through that path - confirm.
        """
        try:
            return next(self.iterator)
        except StopIteration as e:
            raise ValueError from e

    async def _async_generator(self) -> AsyncGenerator[bytes, None]:
        """Yield the preamble chunks, then each content item encoded with ``DEFAULT_SEPARATOR``."""
        while True:
            try:
                yield await sync_to_thread(self._call_next)
            except ValueError:
                # Preamble exhausted: switch over to the actual content and finish afterwards.
                async for value in self.content_async_iterator:
                    yield self.ensure_bytes(value, DEFAULT_SEPARATOR)
                break


@dataclass
class ServerSentEventMessage:
    """A single server-sent event, convertible to its wire representation via :meth:`encode`."""

    # Event payload; emitted as one ``data:`` line per input line. ``None`` omits the field entirely.
    data: str | int | bytes | None = ""
    # Value of the ``event:`` field, if any.
    event: str | None = None
    # Value of the ``id:`` field, if any.
    id: int | str | None = None
    # Value of the ``retry:`` field, if any.
    retry: int | None = None
    # Comment text; emitted as one ``: `` line per input line.
    comment: str | None = None
    # Terminator written after every line and once more at the end of the message.
    sep: str = DEFAULT_SEPARATOR

    def encode(self) -> bytes:
        """Serialise the message to UTF-8 bytes.

        Fields are written in the order comment, id, event, data, retry; fields set to ``None`` are skipped.
        Line breaks inside ``id`` and ``event`` are removed, while ``comment`` and ``data`` are split on line
        breaks into one line each. The message ends with one additional ``sep``.

        Returns:
            The encoded message.
        """
        terminator = self.sep
        lines: list[str] = []

        if self.comment is not None:
            lines.extend(f": {part}" for part in _LINE_BREAK_RE.split(str(self.comment)))

        if self.id is not None:
            lines.append(_LINE_BREAK_RE.sub("", f"id: {self.id}"))

        if self.event is not None:
            lines.append(_LINE_BREAK_RE.sub("", f"event: {self.event}"))

        if self.data is not None:
            payload = self.data
            text = payload.decode() if isinstance(payload, bytes) else str(payload)
            lines.extend(f"data: {part}" for part in _LINE_BREAK_RE.split(text))

        if self.retry is not None:
            lines.append(f"retry: {self.retry}")

        # Each line carries its own terminator; the trailing extra one closes the message.
        message = "".join(line + terminator for line in lines) + terminator
        return message.encode("utf-8")
class ServerSentEvent(Stream):
    """Streaming response that sends its content as server-sent events (``text/event-stream``)."""

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
        status_code: int | None = None,
    ) -> None:
        """Initialize the response.

        Args:
            content: Bytes, string or a sync or async iterator or iterable.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is
                finished. Defaults to None.
            cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
                ``Set-Cookie`` header.
            encoding: The encoding to be used for the response headers.
            headers: A string keyed dictionary of response headers. Header keys are insensitive.
            status_code: The response status code. Defaults to 200.
            event_type: The type of the SSE event. If given, the browser sends the event to any
                'event-listener' declared for it (e.g. via 'addEventListener' in JS).
            event_id: The event ID. This sets the event source's 'last event id'.
            retry_duration: Retry duration in milliseconds.
            comment_message: A comment message. This value is ignored by clients and is used mostly for
                pinging.
        """
        event_stream = _ServerSentEventIterator(
            content=content,
            event_type=event_type,
            event_id=event_id,
            retry_duration=retry_duration,
            comment_message=comment_message,
        )
        super().__init__(
            content=event_stream,
            media_type="text/event-stream",
            background=background,
            cookies=cookies,
            encoding=encoding,
            headers=headers,
            status_code=status_code,
        )

        # A caller-supplied Cache-Control wins; the remaining two headers are always overwritten.
        self.headers.setdefault("Cache-Control", "no-cache")
        for header_name, forced_value in (("Connection", "keep-alive"), ("X-Accel-Buffering", "no")):
            self.headers[header_name] = forced_value