Source code for litestar.response.sse

from __future__ import annotations

import re
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator
from dataclasses import dataclass
from functools import partial
from io import StringIO
from typing import TYPE_CHECKING, Any

import anyio

from litestar.concurrency import sync_to_thread
from litestar.enums import MediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import ASGIStreamingResponse, Stream
from litestar.utils import AsyncIteratorWrapper
from litestar.utils.helpers import get_enum_string_value

__all__ = ("ASGIStreamingSSEResponse", "ServerSentEvent", "ServerSentEventMessage")

if TYPE_CHECKING:
    from litestar.background_tasks import BackgroundTask, BackgroundTasks
    from litestar.connection import Request
    from litestar.datastructures.cookie import Cookie
    from litestar.response.base import ASGIResponse
    from litestar.types import Receive, ResponseCookies, ResponseHeaders, Send, SSEData, StreamType, TypeEncodersMap

_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n")
DEFAULT_SEPARATOR = "\r\n"


class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
    __slots__ = ("comment_message", "content_async_iterator", "event_id", "event_type", "retry_duration")

    content_async_iterator: AsyncIterable[SSEData]

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData] | Callable[[], str | bytes | StreamType[SSEData]],
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
    ) -> None:
        self.comment_message = comment_message
        self.event_id = event_id
        self.event_type = event_type
        self.retry_duration = retry_duration
        chunks: list[bytes] = []

        if comment_message is not None:
            chunks.extend(f": {chunk}{DEFAULT_SEPARATOR}".encode() for chunk in _LINE_BREAK_RE.split(comment_message))

        if event_id is not None:
            chunks.append(f"id: {event_id}{DEFAULT_SEPARATOR}".encode())

        if event_type is not None:
            chunks.append(f"event: {event_type}{DEFAULT_SEPARATOR}".encode())

        if retry_duration is not None:
            chunks.append(f"retry: {retry_duration}{DEFAULT_SEPARATOR}".encode())

        super().__init__(iterator=chunks)

        if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
            content = content()

        if isinstance(content, (str, bytes)):
            self.content_async_iterator = AsyncIteratorWrapper([content])
        elif isinstance(content, Iterable):
            self.content_async_iterator = AsyncIteratorWrapper(content)
        elif isinstance(content, (AsyncIterable, AsyncIteratorWrapper)):
            self.content_async_iterator = content
        else:
            raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")

    def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage, sep: str) -> bytes:
        if isinstance(data, ServerSentEventMessage):
            return data.encode()
        if isinstance(data, dict):
            data["sep"] = sep
            return ServerSentEventMessage(**data).encode()

        return ServerSentEventMessage(
            data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep
        ).encode()

    def _call_next(self) -> bytes:
        try:
            return next(self.iterator)
        except StopIteration as e:
            raise ValueError from e

    async def _async_generator(self) -> AsyncGenerator[bytes, None]:
        while True:
            try:
                yield await sync_to_thread(self._call_next)
            except ValueError:
                async for value in self.content_async_iterator:
                    yield self.ensure_bytes(value, DEFAULT_SEPARATOR)
                break


[docs] @dataclass class ServerSentEventMessage: data: str | int | bytes | None = "" event: str | None = None id: int | str | None = None retry: int | None = None comment: str | None = None sep: str = DEFAULT_SEPARATOR def encode(self) -> bytes: buffer = StringIO() if self.comment is not None: for chunk in _LINE_BREAK_RE.split(self.comment): buffer.write(f": {chunk}") buffer.write(self.sep) if self.id is not None: buffer.write(_LINE_BREAK_RE.sub("", f"id: {self.id}")) buffer.write(self.sep) if self.event is not None: buffer.write(_LINE_BREAK_RE.sub("", f"event: {self.event}")) buffer.write(self.sep) if self.data is not None: data = self.data for chunk in _LINE_BREAK_RE.split(data.decode() if isinstance(data, bytes) else str(data)): buffer.write(f"data: {chunk}") buffer.write(self.sep) if self.retry is not None: buffer.write(f"retry: {self.retry}") buffer.write(self.sep) buffer.write(self.sep) return buffer.getvalue().encode("utf-8")
[docs] class ASGIStreamingSSEResponse(ASGIStreamingResponse): """ASGI streaming response with optional keepalive ping support for SSE.""" __slots__ = ("_ping_interval", "_send_lock")
[docs] def __init__(self, *, ping_interval: float | None = None, **kwargs: Any) -> None: super().__init__(**kwargs) self._ping_interval = ping_interval self._send_lock = anyio.Lock() if ping_interval is not None else None
async def _send(self, send: Send, payload: bytes) -> None: """Send a body chunk with lock for concurrent ping/stream safety.""" if self._send_lock is None: raise RuntimeError("_send called without a send lock; ping_interval must be set") async with self._send_lock: await send({"type": "http.response.body", "body": payload, "more_body": True}) async def _ping(self, send: Send, stop_event: anyio.Event) -> None: """Send SSE comment keepalive pings at the configured interval.""" if self._ping_interval is None: raise RuntimeError("_ping called without a ping interval configured") while not stop_event.is_set(): with anyio.move_on_after(self._ping_interval): await stop_event.wait() if not stop_event.is_set(): await self._send(send, b": ping\r\n\r\n")
[docs] async def send_body(self, send: Send, receive: Receive) -> None: """Emit the response body, with optional keepalive pings.""" if self._ping_interval is None: await super().send_body(send, receive) return stop_event = anyio.Event() async with anyio.create_task_group() as tg: tg.start_soon(partial(self._listen_for_disconnect, tg.cancel_scope, receive)) tg.start_soon(self._ping, send, stop_event) async for chunk in self.iterator: data = chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding) await self._send(send, data) stop_event.set() tg.cancel_scope.cancel() await send({"type": "http.response.body", "body": b"", "more_body": False})
[docs] class ServerSentEvent(Stream):
[docs] def __init__( self, content: str | bytes | StreamType[SSEData], *, background: BackgroundTask | BackgroundTasks | None = None, cookies: ResponseCookies | None = None, encoding: str = "utf-8", headers: ResponseHeaders | None = None, event_type: str | None = None, event_id: int | str | None = None, retry_duration: int | None = None, comment_message: str | None = None, status_code: int | None = None, ping_interval: float | None = None, ) -> None: """Initialize the response. Args: content: Bytes, string or a sync or async iterator or iterable. background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. Defaults to None. cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response ``Set-Cookie`` header. encoding: The encoding to be used for the response headers. headers: A string keyed dictionary of response headers. Header keys are insensitive. status_code: The response status code. Defaults to 200. event_type: The type of the SSE event. If given, the browser will sent the event to any 'event-listener' declared for it (e.g. via 'addEventListener' in JS). event_id: The event ID. This sets the event source's 'last event id'. retry_duration: Retry duration in milliseconds. comment_message: A comment message. This value is ignored by clients and is used mostly for pinging. ping_interval: Interval in seconds between keepalive pings. When set, an SSE comment (``: ping``) is sent at the specified interval to prevent connection timeouts from reverse proxies or clients. Defaults to ``None`` (no pings). """ if ping_interval is not None and ping_interval <= 0: raise ImproperlyConfiguredException("ping_interval must be a positive number") self.ping_interval = ping_interval super().__init__( content=_ServerSentEventIterator( content=content, event_type=event_type, event_id=event_id, retry_duration=retry_duration, comment_message=comment_message, ), media_type="text/event-stream", background=background, cookies=cookies, encoding=encoding, headers=headers, status_code=status_code, ) self.headers.setdefault("Cache-Control", "no-cache") self.headers["Connection"] = "keep-alive" self.headers["X-Accel-Buffering"] = "no"
[docs] def to_asgi_response( self, request: Request, *, background: BackgroundTask | BackgroundTasks | None = None, cookies: Iterable[Cookie] | None = None, headers: dict[str, str] | None = None, is_head_response: bool = False, media_type: MediaType | str | None = None, status_code: int | None = None, type_encoders: TypeEncodersMap | None = None, ) -> ASGIResponse: """Create an ASGI streaming response, with optional keepalive ping support. When ``ping_interval`` is set, returns an :class:`ASGIStreamingSSEResponse` that sends periodic SSE comment pings. Otherwise delegates to the parent implementation. Args: request: The :class:`Request <.connection.Request>` instance. background: Background task(s) to be executed after the response is sent. cookies: A list of cookies to be set on the response. headers: Additional headers to be merged with the response headers. Response headers take precedence. is_head_response: Whether the response is a HEAD response. media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored. status_code: Status code for the response. If ``status_code`` is already set on the response, this is ignored. type_encoders: A dictionary of type encoders to use for encoding the response content. Returns: An ASGIStreamingResponse (or ASGIStreamingSSEResponse when ping_interval is set). """ if self.ping_interval is None: return super().to_asgi_response( request, background=background, cookies=cookies, headers=headers, is_head_response=is_head_response, media_type=media_type, status_code=status_code, type_encoders=type_encoders, ) headers = {**headers, **self.headers} if headers is not None else self.headers media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON) return ASGIStreamingSSEResponse( ping_interval=self.ping_interval, background=self.background or background, content_length=0, cookies=self._merge_cookies(cookies), encoding=self.encoding, headers=headers, is_head_response=is_head_response, iterator=self.iterator, media_type=media_type, status_code=self.status_code or status_code, )