from __future__ import annotations
import io
import re
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from litestar.concurrency import sync_to_thread
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import Stream
from litestar.utils import AsyncIteratorWrapper
if TYPE_CHECKING:
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType
# Any of the three line terminators: CRLF, CR, or LF. CRLF is listed first so a
# "\r\n" pair is treated as one break rather than two.
_LINE_BREAK_RE: re.Pattern[str] = re.compile(r"\r\n|\r|\n")

# Line terminator written after every serialized SSE field.
DEFAULT_SEPARATOR: str = "\r\n"
class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
    """Async iterator that turns arbitrary content into a Server-Sent Events byte stream.

    Stream-level fields given to the constructor (comment, id, event, retry) are yielded
    first as raw field lines. They are not followed by a blank line, so on the wire they
    become part of the first message. After that, every item produced by ``content`` is
    encoded with :class:`ServerSentEventMessage`.
    """

    __slots__ = ("comment_message", "content_async_iterator", "event_id", "event_type", "retry_duration")

    content_async_iterator: AsyncIterable[SSEData]

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
    ) -> None:
        """Prepare the header field lines and normalize ``content`` into an async iterable.

        Args:
            content: A string/bytes payload, a sync or async iterable/iterator, or a
                callable returning one of those.
            event_type: Value for the ``event:`` field, also used as the default for plain items.
            event_id: Value for the ``id:`` field, also used as the default for plain items.
            retry_duration: Value for the ``retry:`` field, also used as the default for plain items.
            comment_message: Comment text; each line becomes a separate ``:`` line.

        Raises:
            ImproperlyConfiguredException: If ``content`` is of an unsupported type.
        """
        self.comment_message = comment_message
        self.event_id = event_id
        self.event_type = event_type
        self.retry_duration = retry_duration

        chunks: list[bytes] = []
        if comment_message is not None:
            # A multi-line comment needs a ": " prefix on every line.
            chunks.extend(
                f": {line}{DEFAULT_SEPARATOR}".encode() for line in _LINE_BREAK_RE.split(comment_message)
            )
        if event_id is not None:
            # Remove line breaks so the value cannot end the field early and inject
            # additional fields (ServerSentEventMessage.encode does the same).
            chunks.append((_LINE_BREAK_RE.sub("", f"id: {event_id}") + DEFAULT_SEPARATOR).encode())
        if event_type is not None:
            chunks.append((_LINE_BREAK_RE.sub("", f"event: {event_type}") + DEFAULT_SEPARATOR).encode())
        if retry_duration is not None:
            chunks.append(f"retry: {retry_duration}{DEFAULT_SEPARATOR}".encode())
        super().__init__(iterator=chunks)

        # Iterators and wrappers can be callable objects themselves; only invoke true factories.
        if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
            content = content()  # type: ignore[unreachable]

        if isinstance(content, (str, bytes)):
            # A single payload is streamed as exactly one message.
            self.content_async_iterator = AsyncIteratorWrapper([content])
        elif isinstance(content, (Iterable, Iterator)):
            self.content_async_iterator = AsyncIteratorWrapper(content)
        elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
            self.content_async_iterator = content
        else:
            raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")

    def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage | Any, sep: str) -> bytes:
        """Encode one content item as a complete SSE message.

        Args:
            data: A ready-made :class:`ServerSentEventMessage` (encoded as-is), a dict of
                :class:`ServerSentEventMessage` keyword arguments, or a plain payload. A plain
                payload inherits this stream's event id, event type and retry duration.
            sep: Line separator for dict and plain payloads. It overrides any ``"sep"`` key
                present in a dict payload.

        Returns:
            The UTF-8 encoded message bytes.
        """
        if isinstance(data, ServerSentEventMessage):
            return data.encode()
        if isinstance(data, dict):
            # Build a new mapping rather than writing "sep" into the caller's dict.
            return ServerSentEventMessage(**{**data, "sep": sep}).encode()
        return ServerSentEventMessage(
            data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep
        ).encode()

    def _call_next(self) -> bytes:
        """Return the next header chunk, translating StopIteration into ValueError.

        This mirrors the wrapper's helper of the same name. It is kept for compatibility,
        although ``_async_generator`` below no longer calls it.
        """
        try:
            return next(self.iterator)
        except StopIteration as e:
            raise ValueError from e

    async def _async_generator(self) -> AsyncGenerator[bytes, None]:
        """Yield the header field chunks, then one encoded message per content item."""
        # The header chunks are an in-memory list built in __init__, so there is no
        # blocking work to hand off to a worker thread.
        for chunk in self.iterator:
            yield chunk
        # Stream the content outside any exception handler, so errors raised by user
        # content carry no unrelated exception context.
        async for value in self.content_async_iterator:
            yield self.ensure_bytes(value, DEFAULT_SEPARATOR)
@dataclass
class ServerSentEventMessage:
    """A single Server-Sent Events message.

    Every attribute maps onto one SSE field. A value of ``None`` omits that field from
    the encoded output.
    """

    data: str | int | bytes | None = ""  # payload; written as one "data:" line per line of text
    event: str | None = None  # event type; line breaks are stripped before writing
    id: int | str | None = None  # event id; line breaks are stripped before writing
    retry: int | None = None  # retry duration, written verbatim as the "retry:" field
    comment: str | None = None  # comment text; written as one ":" line per line of text
    sep: str = DEFAULT_SEPARATOR  # terminator written after each field line

    def encode(self) -> bytes:
        """Serialize the message to UTF-8 bytes.

        Fields are written in the order comment, id, event, data, retry, each followed by
        ``sep``. A final ``sep`` produces the blank line that terminates the message.
        ``bytes`` data is decoded as UTF-8 before being split into lines.
        """
        buffer = io.StringIO()
        if self.comment is not None:
            for chunk in _LINE_BREAK_RE.split(str(self.comment)):
                buffer.write(f": {chunk}")
                buffer.write(self.sep)
        if self.id is not None:
            # Strip line breaks so the id cannot inject extra fields.
            buffer.write(_LINE_BREAK_RE.sub("", f"id: {self.id}"))
            buffer.write(self.sep)
        if self.event is not None:
            buffer.write(_LINE_BREAK_RE.sub("", f"event: {self.event}"))
            buffer.write(self.sep)
        if self.data is not None:
            data = self.data
            # Multi-line payloads need one "data:" line per line of text.
            for chunk in _LINE_BREAK_RE.split(data.decode() if isinstance(data, bytes) else str(data)):
                buffer.write(f"data: {chunk}")
                buffer.write(self.sep)
        if self.retry is not None:
            buffer.write(f"retry: {self.retry}")
            buffer.write(self.sep)
        # The empty line marks the end of this message.
        buffer.write(self.sep)
        return buffer.getvalue().encode("utf-8")
class ServerSentEvent(Stream):
    """Streaming response that delivers its content as ``text/event-stream`` Server-Sent Events."""

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
        status_code: int | None = None,
    ) -> None:
        """Initialize the response.

        Args:
            content: Bytes, string or a sync or async iterator or iterable.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
                Defaults to None.
            cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
                ``Set-Cookie`` header.
            encoding: The encoding to be used for the response headers.
            headers: A string keyed dictionary of response headers. Header keys are case-insensitive.
            event_type: The type of the SSE event. If given, the browser will send the event to any 'event-listener'
                declared for it (e.g. via 'addEventListener' in JS).
            event_id: The event ID. This sets the event source's 'last event id'.
            retry_duration: Retry duration in milliseconds.
            comment_message: A comment message. This value is ignored by clients and is used mostly for pinging.
            status_code: The response status code. Defaults to 200.
        """
        super().__init__(
            content=_ServerSentEventIterator(
                content=content,
                event_type=event_type,
                event_id=event_id,
                retry_duration=retry_duration,
                comment_message=comment_message,
            ),
            media_type="text/event-stream",
            background=background,
            cookies=cookies,
            encoding=encoding,
            headers=headers,
            status_code=status_code,
        )
        # Stop caches from storing the live stream, unless the caller already chose a policy.
        self.headers.setdefault("Cache-Control", "no-cache")
        self.headers["Connection"] = "keep-alive"
        # Ask reverse proxies that honour this header (e.g. nginx) not to buffer the stream.
        self.headers["X-Accel-Buffering"] = "no"