from __future__ import annotations
import itertools
import re
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator
from dataclasses import dataclass
from functools import partial
from io import StringIO
from typing import TYPE_CHECKING, Any
import anyio
from litestar.concurrency import sync_to_thread
from litestar.enums import MediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import ASGIStreamingResponse, Stream
from litestar.utils import AsyncIteratorWrapper
from litestar.utils.helpers import get_enum_string_value
__all__ = ("ASGIStreamingSSEResponse", "ServerSentEvent", "ServerSentEventMessage")
if TYPE_CHECKING:
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.connection import Request
from litestar.datastructures.cookie import Cookie
from litestar.response.base import ASGIResponse
from litestar.types import Receive, ResponseCookies, ResponseHeaders, Send, SSEData, StreamType, TypeEncodersMap
# Matches one line terminator: CRLF, a lone CR, or a lone LF. CRLF is the first alternative,
# so a "\r\n" pair is consumed as a single break instead of two.
_LINE_BREAK_RE = re.compile("|".join((r"\r\n", r"\r", r"\n")))
# Terminator written after every field line and after every complete event.
DEFAULT_SEPARATOR: str = "\r\n"
class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
    """Async iterator producing the encoded bytes of a server-sent event stream.

    The stream starts with "header" lines built from ``comment_message``, ``event_id``,
    ``event_type`` and ``retry_duration``. It is followed by one encoded
    :class:`ServerSentEventMessage` per item produced by ``content``.
    """

    __slots__ = ("comment_message", "content_async_iterator", "event_id", "event_type", "retry_duration")

    content_async_iterator: AsyncIterable[SSEData]

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData] | Callable[[], str | bytes | StreamType[SSEData]],
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
    ) -> None:
        """Initialize the iterator.

        Args:
            content: A string, bytes, sync/async iterable, or a zero-argument callable returning one
                of those. Each item it yields becomes one event.
            event_type: Value of the ``event`` field for items that are not already messages/dicts.
            event_id: Value of the ``id`` field for items that are not already messages/dicts.
            retry_duration: Value of the ``retry`` field for items that are not already messages/dicts.
            comment_message: Comment emitted once at the start of the stream.

        Raises:
            ImproperlyConfiguredException: If ``content`` (after calling it, when callable) is not
                a string, bytes, iterable or async iterable.
        """
        self.comment_message = comment_message
        self.event_id = event_id
        self.event_type = event_type
        self.retry_duration = retry_duration

        chunks: list[bytes] = []
        if comment_message is not None:
            chunks.extend(f": {chunk}{DEFAULT_SEPARATOR}".encode() for chunk in _LINE_BREAK_RE.split(comment_message))
        if event_id is not None:
            chunks.append(self._encode_field("id", event_id))
        if event_type is not None:
            chunks.append(self._encode_field("event", event_type))
        if retry_duration is not None:
            chunks.append(self._encode_field("retry", retry_duration))
        super().__init__(iterator=chunks)

        # Iterators are callable-free in practice, but guard so only "factory" callables are invoked.
        if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
            content = content()

        if isinstance(content, (str, bytes)):
            self.content_async_iterator = AsyncIteratorWrapper([content])
        elif isinstance(content, Iterable):
            self.content_async_iterator = AsyncIteratorWrapper(content)
        elif isinstance(content, (AsyncIterable, AsyncIteratorWrapper)):
            self.content_async_iterator = content
        else:
            raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")

    @staticmethod
    def _encode_field(name: str, value: object) -> bytes:
        """Encode a single ``name: value`` header line with line breaks removed from the value.

        Stripping line breaks prevents a value from ending the field early and injecting extra
        fields or events. It mirrors how ``ServerSentEventMessage.encode`` treats ``id`` and ``event``.
        """
        return (_LINE_BREAK_RE.sub("", f"{name}: {value}") + DEFAULT_SEPARATOR).encode()

    def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage, sep: str) -> bytes:
        """Encode one content item as an SSE message.

        Args:
            data: A ready message, a dict of ``ServerSentEventMessage`` keyword arguments, or a
                plain value used as the message ``data`` together with this iterator's id/event/retry.
            sep: Line separator used for plain values and dicts.

        Returns:
            The encoded message bytes.
        """
        if isinstance(data, ServerSentEventMessage):
            return data.encode()
        if isinstance(data, dict):
            # Build a new mapping rather than writing ``sep`` into the caller's dict.
            return ServerSentEventMessage(**{**data, "sep": sep}).encode()
        return ServerSentEventMessage(
            data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep
        ).encode()

    def _call_next(self) -> bytes:
        """Return the next header chunk, translating exhaustion into ``ValueError``."""
        try:
            return next(self.iterator)
        except StopIteration as e:
            raise ValueError from e

    async def _async_generator(self) -> AsyncGenerator[bytes, None]:
        """Yield all header chunks, then one encoded message per content item."""
        while True:
            try:
                chunk = await sync_to_thread(self._call_next)
            except ValueError:
                break
            yield chunk
        # Content is consumed outside of any ``except`` block, so errors raised by the user's
        # iterator are not chained to the internal StopIteration/ValueError used above.
        async for value in self.content_async_iterator:
            yield self.ensure_bytes(value, DEFAULT_SEPARATOR)
@dataclass
class ServerSentEventMessage:
    """A single server-sent event that can be serialized with :meth:`encode`.

    Each field that is not ``None`` is written as one or more ``name: value`` lines, each followed
    by ``sep``. A final ``sep`` then terminates the message.
    """

    # Event payload. ``bytes`` are decoded with ``bytes.decode()`` (UTF-8 by default); other values
    # go through ``str()``. Each line of the payload becomes its own ``data:`` line.
    data: str | int | bytes | None = ""
    # Event type; line breaks are removed when encoding.
    event: str | None = None
    # Event id; line breaks are removed when encoding.
    id: int | str | None = None
    # Value of the ``retry`` field, written as-is.
    retry: int | None = None
    # Comment text; each line becomes its own ``: ``-prefixed comment line.
    comment: str | None = None
    # Line terminator used after every field line and to end the message.
    sep: str = DEFAULT_SEPARATOR

    def encode(self) -> bytes:
        """Serialize the message to UTF-8 ``text/event-stream`` bytes.

        Fields are emitted in the order comment, id, event, data, retry.

        Returns:
            The encoded message, ending with an empty line (a bare ``sep``).
        """
        buffer = StringIO()
        if self.comment is not None:
            for chunk in _LINE_BREAK_RE.split(self.comment):
                buffer.write(f": {chunk}")
                buffer.write(self.sep)
        if self.id is not None:
            # Strip line breaks so the value cannot start a new field.
            buffer.write(_LINE_BREAK_RE.sub("", f"id: {self.id}"))
            buffer.write(self.sep)
        if self.event is not None:
            buffer.write(_LINE_BREAK_RE.sub("", f"event: {self.event}"))
            buffer.write(self.sep)
        if self.data is not None:
            data = self.data
            for chunk in _LINE_BREAK_RE.split(data.decode() if isinstance(data, bytes) else str(data)):
                buffer.write(f"data: {chunk}")
                buffer.write(self.sep)
        if self.retry is not None:
            buffer.write(f"retry: {self.retry}")
            buffer.write(self.sep)
        # Blank line: dispatches the event on the client.
        buffer.write(self.sep)
        return buffer.getvalue().encode("utf-8")
class ASGIStreamingSSEResponse(ASGIStreamingResponse):
    """ASGI streaming response with optional keepalive ping support for SSE."""

    __slots__ = ("_ping_interval", "_send_lock")

    def __init__(self, *, ping_interval: float | None = None, **kwargs: Any) -> None:
        """Initialize the response.

        Args:
            ping_interval: Seconds between keepalive pings, or ``None`` to disable pinging.
            **kwargs: Forwarded unchanged to :class:`ASGIStreamingResponse`.
        """
        super().__init__(**kwargs)
        self._ping_interval = ping_interval
        # A lock is only needed when a ping task writes to ``send`` concurrently with the body.
        self._send_lock = anyio.Lock() if ping_interval is not None else None

    async def _send(self, send: Send, payload: bytes, more_body: bool = True) -> None:
        """Send a body event while holding the send lock (concurrent ping/stream safety).

        Args:
            send: The ASGI send callable.
            payload: Body bytes to send.
            more_body: The event's ``more_body`` flag; ``False`` only for the terminal frame.

        Raises:
            RuntimeError: If no send lock exists, i.e. ``ping_interval`` was not set.
        """
        if self._send_lock is None:
            raise RuntimeError("_send called without a send lock; ping_interval must be set")
        async with self._send_lock:
            await send({"type": "http.response.body", "body": payload, "more_body": more_body})

    async def _ping(self, send: Send, stop_event: anyio.Event) -> None:
        """Send ``: ping`` SSE comments every ``ping_interval`` seconds until ``stop_event`` is set.

        Raises:
            RuntimeError: If no ping interval is configured.
        """
        if self._ping_interval is None:
            raise RuntimeError("_ping called without a ping interval configured")
        while not stop_event.is_set():
            # Wait at most one interval; wakes early once the body has been fully streamed.
            with anyio.move_on_after(self._ping_interval):
                await stop_event.wait()
            if not stop_event.is_set():
                await self._send(send, b": ping\r\n\r\n")

    async def send_body(self, send: Send, receive: Receive) -> None:
        """Emit the response body, with optional keepalive pings.

        Without a ping interval this defers to the parent implementation. Otherwise the body, a
        ping task and the inherited disconnect listener run in one task group. The listener is
        given the group's cancel scope; presumably it cancels it when the client disconnects.
        """
        if self._ping_interval is None:
            await super().send_body(send, receive)
            return

        stop_event = anyio.Event()
        async with anyio.create_task_group() as tg:
            tg.start_soon(partial(self._listen_for_disconnect, tg.cancel_scope, receive))
            tg.start_soon(self._ping, send, stop_event)
            async for chunk in self.iterator:
                data = chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding)
                await self._send(send, data)
            # Stop pinging first; the lock then orders the terminal frame after any in-flight ping.
            stop_event.set()
            # Sent inside the task group: if the scope was cancelled (e.g. client disconnect),
            # execution never reaches this point, so nothing is sent on a closed connection.
            await self._send(send, b"", more_body=False)
            tg.cancel_scope.cancel()
class ServerSentEvent(Stream):
    """Streaming response whose content is serialized as server-sent events (``text/event-stream``)."""

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData] | Callable[[], str | bytes | StreamType[SSEData]],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
        status_code: int | None = None,
        ping_interval: float | None = None,
    ) -> None:
        """Initialize the response.

        Args:
            content: Bytes, string, a sync or async iterator or iterable, or a zero-argument callable
                returning one of those.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
                Defaults to None.
            cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
                ``Set-Cookie`` header.
            encoding: The response's character encoding, passed through to the underlying streaming response.
            headers: A string keyed dictionary of response headers. Header keys are case-insensitive.
            event_type: The type of the SSE event. If given, the browser will send the event to any 'event-listener'
                declared for it (e.g. via 'addEventListener' in JS).
            event_id: The event ID. This sets the event source's 'last event id'.
            retry_duration: Retry duration in milliseconds.
            comment_message: A comment message. This value is ignored by clients and is used mostly for pinging.
            status_code: The response status code. Defaults to 200.
            ping_interval: Interval in seconds between keepalive pings. When set, an SSE comment
                (``: ping``) is sent at the specified interval to prevent connection timeouts from
                reverse proxies or clients. Defaults to ``None`` (no pings).

        Raises:
            ImproperlyConfiguredException: If ``ping_interval`` is given but is not a positive number.
        """
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN, which compares False to everything.
        if ping_interval is not None and not ping_interval > 0:
            raise ImproperlyConfiguredException("ping_interval must be a positive number")
        self.ping_interval = ping_interval
        super().__init__(
            content=_ServerSentEventIterator(
                content=content,
                event_type=event_type,
                event_id=event_id,
                retry_duration=retry_duration,
                comment_message=comment_message,
            ),
            media_type="text/event-stream",
            background=background,
            cookies=cookies,
            encoding=encoding,
            headers=headers,
            status_code=status_code,
        )
        # Keep caller-provided caching policy, but always disable proxy buffering for live streams.
        self.headers.setdefault("Cache-Control", "no-cache")
        self.headers["Connection"] = "keep-alive"
        self.headers["X-Accel-Buffering"] = "no"

    def to_asgi_response(
        self,
        request: Request,
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: Iterable[Cookie] | None = None,
        headers: dict[str, str] | None = None,
        is_head_response: bool = False,
        media_type: MediaType | str | None = None,
        status_code: int | None = None,
        type_encoders: TypeEncodersMap | None = None,
    ) -> ASGIResponse:
        """Create an ASGI streaming response, with optional keepalive ping support.

        When ``ping_interval`` is set, returns an :class:`ASGIStreamingSSEResponse` that
        sends periodic SSE comment pings. Otherwise delegates to the parent implementation.

        Args:
            request: The :class:`Request <.connection.Request>` instance.
            background: Background task(s) to be executed after the response is sent.
            cookies: A list of cookies to be set on the response.
            headers: Additional headers to be merged with the response headers. Response headers take precedence.
            is_head_response: Whether the response is a HEAD response.
            media_type: Media type for the response. If ``media_type`` is already set on the response, this is ignored.
            status_code: Status code for the response. If ``status_code`` is already set on the response, this is
                ignored.
            type_encoders: A dictionary of type encoders to use for encoding the response content.

        Returns:
            An ASGIStreamingResponse (or ASGIStreamingSSEResponse when ping_interval is set).
        """
        if self.ping_interval is None:
            return super().to_asgi_response(
                request,
                background=background,
                cookies=cookies,
                headers=headers,
                is_head_response=is_head_response,
                media_type=media_type,
                status_code=status_code,
                type_encoders=type_encoders,
            )

        headers = {**headers, **self.headers} if headers is not None else self.headers
        cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)
        media_type = get_enum_string_value(media_type or self.media_type or MediaType.JSON)
        return ASGIStreamingSSEResponse(
            ping_interval=self.ping_interval,
            background=self.background or background,
            content_length=0,
            cookies=cookies,
            encoding=self.encoding,
            headers=headers,
            is_head_response=is_head_response,
            iterator=self.iterator,
            media_type=media_type,
            status_code=self.status_code or status_code,
        )