"""``litestar.testing.life_span_handler``: drive an ASGI application's lifespan protocol."""

from __future__ import annotations

import contextlib
from math import inf
from typing import TYPE_CHECKING, Optional, cast

import anyio
from anyio import create_memory_object_stream
from anyio.streams.stapled import StapledObjectStream

if TYPE_CHECKING:
    from types import TracebackType

    from litestar.types import (
        ASGIApp,
        LifeSpanReceiveMessage,  # noqa: F401
        LifeSpanSendMessage,
        LifeSpanShutdownEvent,
        LifeSpanStartupEvent,
    )


class LifeSpanHandler:
    """Drive an ASGI application's lifespan protocol as an async context manager.

    On entry, the application is started in an anyio task group with a
    ``{"type": "lifespan"}`` scope. A ``lifespan.startup`` event is then sent and
    the handler waits for the application's reply. On exit, a ``lifespan.shutdown``
    event is sent and its reply awaited. After that the task group is exited and
    both message streams are closed.

    Attributes:
        stream_send: Carries messages the application *sends*; the application's
            ``send`` callable is ``stream_send.send``.
        stream_receive: Carries messages the application *receives*; the
            application's ``receive`` callable is ``stream_receive.receive``.
        app: The ASGI application whose lifespan protocol is driven.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Create the message streams for ``app``; nothing runs until ``__aenter__``.

        Args:
            app: The ASGI application whose lifespan protocol is driven.
        """
        # Buffers are sized ``inf`` so sends do not wait for a reader.
        self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf))  # type: ignore[arg-type]
        self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf))  # type: ignore[arg-type]
        self.app = app
        # Empty placeholder: makes ``__aexit__`` a no-op if ``__aenter__`` never
        # got as far as handing over its own exit stack.
        self._exit_stack = contextlib.AsyncExitStack()

    async def __aenter__(self) -> LifeSpanHandler:
        """Start the application's lifespan task and complete the startup handshake.

        Returns:
            This handler.

        Errors from ``wait_startup`` (such as its ``RuntimeError``) or from the
        application propagate after the task group and streams are unwound.
        NOTE(review): the task group may wrap them in an exception group,
        depending on the anyio version in use.
        """
        async with contextlib.AsyncExitStack() as exit_stack:
            await exit_stack.enter_async_context(self.stream_send)
            await exit_stack.enter_async_context(self.stream_receive)
            self._tg = await exit_stack.enter_async_context(anyio.create_task_group())
            # ``lifespan`` cancels ``cs`` if the application raises. That interrupts
            # a ``wait_startup`` still blocked waiting for a reply.
            with anyio.CancelScope() as cs:
                self._tg.start_soon(self.lifespan, cs)
                await self.wait_startup()
            # Registered last, so it runs first on exit, while the task group and
            # both streams are still open.
            exit_stack.push_async_callback(self.wait_shutdown)
            # Hand the open resources over to ``__aexit__``. If anything above
            # raised, this ``async with`` unwinds them instead.
            self._exit_stack = exit_stack.pop_all()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Run the shutdown handshake and release everything ``__aenter__`` acquired.

        Returns:
            The exit stack's own result. Suppose one of the managed contexts (for
            example the task group) suppresses the in-flight exception. Forwarding
            this value stops the ``async with`` statement from re-raising it.
        """
        return await self._exit_stack.__aexit__(exc_type, exc_value, traceback)

    async def receive(self) -> LifeSpanSendMessage:
        """Return the next message the application sent.

        The stream is typed ``Optional``; this narrows it with a ``cast`` and no
        runtime ``None`` check.
        """
        message = await self.stream_send.receive()
        return cast("LifeSpanSendMessage", message)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the application's reply.

        Raises:
            RuntimeError: If the reply is neither ``lifespan.startup.complete`` nor
                ``lifespan.startup.failed``.
        """
        event: LifeSpanStartupEvent = {"type": "lifespan.startup"}
        await self.stream_receive.send(event)

        message = await self.receive()
        if message["type"] not in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        ):
            raise RuntimeError(
                "Received unexpected ASGI message type. Expected 'lifespan.startup.complete' or "
                f"'lifespan.startup.failed'. Got {message['type']!r}",
            )
        if message["type"] == "lifespan.startup.failed":
            # After a failure report, block on one more message. Presumably the app
            # raises next, and ``lifespan`` cancels the enclosing scope so the error
            # surfaces. NOTE(review): if the app neither raises nor sends again,
            # this waits forever.
            await self.receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the application's reply.

        Raises:
            RuntimeError: If the reply is neither ``lifespan.shutdown.complete`` nor
                ``lifespan.shutdown.failed``.
        """
        lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"}
        await self.stream_receive.send(lifespan_shutdown_event)

        message = await self.receive()
        if message["type"] not in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        ):
            raise RuntimeError(
                "Received unexpected ASGI message type. Expected 'lifespan.shutdown.complete' or "
                f"'lifespan.shutdown.failed'. Got {message['type']!r}",
            )
        if message["type"] == "lifespan.shutdown.failed":
            # Same idea as in ``wait_startup``: wait for the app's follow-up (or
            # for cancellation once it raises). NOTE(review): confirm the app
            # always raises or sends again after a failure report.
            await self.receive()

    async def lifespan(self, cs: anyio.CancelScope) -> None:
        """Run the application with a lifespan scope wired to this handler's streams.

        Args:
            cs: Cancelled if the application raises, which releases whatever is
                waiting inside that scope; the exception is then re-raised.
        """
        scope = {"type": "lifespan"}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)  # type: ignore[arg-type]
        except BaseException:
            cs.cancel()
            raise