from __future__ import annotations
import abc
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union, cast, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Iterator
from inspect import Signature
from click import Group
from litestar._openapi.schema_generation import SchemaCreator
from litestar.app import Litestar
from litestar.config.app import AppConfig
from litestar.dto import AbstractDTO
from litestar.openapi.spec import Reference, Schema
from litestar.routes import BaseRoute
from litestar.typing import FieldDefinition
# Public API of this module: every name listed here is defined further down in
# this file. The ``PluginT`` type variable is intentionally not exported.
__all__ = (
"CLIPlugin",
"DIPlugin",
"InitPlugin",
"InitPluginProtocol",
"OpenAPISchemaPlugin",
"PluginProtocol",
"PluginRegistry",
"ReceiveRoutePlugin",
"SerializationPlugin",
)
[docs]
@runtime_checkable
class InitPluginProtocol(Protocol):
    """Structural protocol for plugins that take part in the application's init process.

    The class is decorated with ``runtime_checkable``, so ``isinstance`` checks
    succeed for any object that exposes an ``on_app_init`` method, whether or
    not it inherits from this class.

    .. deprecated:: 2.15
        Use 'InitPlugin' instead
    """

    __slots__ = ()

    def on_app_init(self, app_config: AppConfig) -> AppConfig:
        """Receive the :class:`AppConfig<.config.app.AppConfig>` instance after `on_app_init` hooks have been called.

        The default implementation hands the configuration back untouched;
        implementers override it to adjust the configuration and return it.

        Examples:
            .. code-block:: python

                from litestar import Litestar, get
                from litestar.config.app import AppConfig
                from litestar.di import Provide
                from litestar.plugins import InitPluginProtocol


                def get_name() -> str:
                    return "world"


                @get("/my-path")
                def my_route_handler(name: str) -> dict[str, str]:
                    return {"hello": name}


                class MyPlugin(InitPluginProtocol):
                    def on_app_init(self, app_config: AppConfig) -> AppConfig:
                        app_config.dependencies["name"] = Provide(get_name)
                        app_config.route_handlers.append(my_route_handler)
                        return app_config


                app = Litestar(plugins=[MyPlugin()])

        Args:
            app_config: The :class:`AppConfig <litestar.config.app.AppConfig>` instance.

        Returns:
            The app config object.
        """
        return app_config  # pragma: no cover
[docs]
class InitPlugin(InitPluginProtocol):
    """Base class for plugins that affect the application's init process.

    This is the replacement for the deprecated ``InitPluginProtocol``. Because
    it subclasses that protocol, instances still pass
    ``isinstance(plugin, InitPluginProtocol)`` checks.
    """

    __slots__ = ()

    def on_app_init(self, app_config: AppConfig) -> AppConfig:
        """Receive the :class:`AppConfig<.config.app.AppConfig>` instance after `on_app_init` hooks have been called.

        The default implementation returns ``app_config`` unchanged; override
        it to modify the configuration and return the result.

        Examples:
            .. code-block:: python

                from litestar import Litestar, get
                from litestar.config.app import AppConfig
                from litestar.di import Provide
                from litestar.plugins import InitPlugin


                def get_name() -> str:
                    return "world"


                @get("/my-path")
                def my_route_handler(name: str) -> dict[str, str]:
                    return {"hello": name}


                class MyPlugin(InitPlugin):
                    def on_app_init(self, app_config: AppConfig) -> AppConfig:
                        app_config.dependencies["name"] = Provide(get_name)
                        app_config.route_handlers.append(my_route_handler)
                        return app_config


                app = Litestar(plugins=[MyPlugin()])

        Args:
            app_config: The :class:`AppConfig <litestar.config.app.AppConfig>` instance.

        Returns:
            The app config object.
        """
        return app_config  # pragma: no cover
[docs]
class ReceiveRoutePlugin:
    """Base class for plugins that want to be handed routes added to the application."""

    __slots__ = ()

    def receive_route(self, route: BaseRoute) -> None:
        """Receive a route as it is registered on an application.

        The default implementation does nothing; subclasses override it to
        inspect or record the route.

        Args:
            route: The route being registered.
        """
[docs]
class CLIPlugin:
    """Base class for plugins that extend the CLI and the server lifespan.

    Both hooks have no-op defaults, so subclasses only override what they need.
    """

    def on_cli_init(self, cli: Group) -> None:
        """Called when the CLI is initialized.

        This can be used to extend or override existing commands. The default
        implementation does nothing.

        Args:
            cli: The root :class:`click.Group` of the Litestar CLI

        Examples:
            .. code-block:: python

                from litestar import Litestar
                from litestar.plugins import CLIPlugin
                from click import Group


                class MyCLIPlugin(CLIPlugin):
                    def on_cli_init(self, cli: Group) -> None:
                        @cli.command()
                        def is_debug_mode(app: Litestar):
                            print(app.debug)


                app = Litestar(plugins=[MyCLIPlugin()])
        """

    @contextmanager
    def server_lifespan(self, app: Litestar) -> Iterator[None]:
        """Context manager hook wrapped around the server's lifetime.

        The default implementation yields once without doing any work.
        Subclasses may override it, placing setup code before the ``yield``
        and teardown code after it.

        Args:
            app: The :class:`Litestar <litestar.app.Litestar>` application instance.

        Yields:
            ``None``.
        """
        # NOTE(review): the caller that enters this context manager is not part of
        # this module; presumably the CLI does so around the server run - confirm.
        yield
[docs]
class SerializationPlugin(abc.ABC):
    """Abstract base class for plugins that extend DTO functionality.

    Concrete plugins must implement both :meth:`supports_type` and
    :meth:`create_dto_for_type`.
    """

    @abc.abstractmethod
    def supports_type(self, field_definition: FieldDefinition) -> bool:
        """Tell whether this plugin can handle the type described by ``field_definition``.

        Args:
            field_definition: A parsed type.

        Returns:
            Whether the type is supported by the plugin.

        Raises:
            NotImplementedError: If the base implementation itself is invoked.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def create_dto_for_type(self, field_definition: FieldDefinition) -> type[AbstractDTO]:
        """Build a DTO class for the type described by ``field_definition``.

        Args:
            field_definition: A parsed type.

        Returns:
            A DTO class.

        Raises:
            NotImplementedError: If the base implementation itself is invoked.
        """
        raise NotImplementedError
[docs]
class DIPlugin(abc.ABC):
    """Abstract base class for plugins that extend dependency injection.

    Concrete plugins must implement both :meth:`has_typed_init` and
    :meth:`get_typed_init`.
    """

    @abc.abstractmethod
    def has_typed_init(self, type_: Any) -> bool:
        """Tell whether :meth:`get_typed_init` can supply typing information for ``type_``.

        Implementations should return ``True`` when ``type_`` carries type
        information for its ``__init__`` method that cannot be read from that
        method's own annotations (e.g. a Pydantic BaseModel subclass) and
        :meth:`DIPlugin.get_typed_init` is able to extract it.

        Args:
            type_: The type to inspect.

        Returns:
            ``True`` if this plugin supports ``type_``, otherwise ``False``.
        """
        ...

    @abc.abstractmethod
    def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]:
        """Return signature and type information about the ``__init__`` method of ``type_``.

        Args:
            type_: The type to inspect.

        Returns:
            A two-tuple of the ``__init__`` signature and a dictionary holding
            the associated type information.
        """
        ...
[docs]
class OpenAPISchemaPlugin(abc.ABC):
    """Plugin to extend the support of OpenAPI schema generation for non-library types.

    Subclasses must implement :meth:`to_openapi_schema` and must override either
    :meth:`is_plugin_supported_field` (preferred) or the legacy
    :meth:`is_plugin_supported_type`.
    """

    @staticmethod
    def is_plugin_supported_type(value: Any) -> bool:
        """Given a value of indeterminate type, determine if this value is supported by the plugin.

        This is called by the default implementation of :meth:`is_plugin_supported_field` for
        backwards compatibility. Users should prefer to override that method instead.

        Args:
            value: An arbitrary value.

        Returns:
            A bool indicating whether the value is supported by the plugin.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError(
            "One of either is_plugin_supported_type or is_plugin_supported_field should be defined. "
            "The default implementation of is_plugin_supported_field calls is_plugin_supported_type "
            "for backwards compatibility. Users should prefer to override is_plugin_supported_field "
            "as it receives a 'FieldDefinition' instance which is more useful than a raw type."
        )

    @abc.abstractmethod
    def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema | Reference:
        """Given a type annotation, transform it into an OpenAPI schema class.

        Args:
            field_definition: The :class:`FieldDefinition <litestar.typing.FieldDefinition>` describing the
                type to generate a schema for.
            schema_creator: An instance of the openapi SchemaCreator.

        Returns:
            A :class:`Schema <litestar.openapi.spec.Schema>` or
            :class:`Reference <litestar.openapi.spec.Reference>` instance.

        Raises:
            NotImplementedError: If the base implementation itself is invoked.
        """
        raise NotImplementedError()

    def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool:
        """Given a :class:`FieldDefinition <litestar.typing.FieldDefinition>` that represents an indeterminate type,
        determine if this value is supported by the plugin

        The default implementation delegates to :meth:`is_plugin_supported_type`, passing
        ``field_definition.annotation``.

        Args:
            field_definition: A parsed type.

        Returns:
            Whether the type is supported by the plugin.
        """
        return self.is_plugin_supported_type(field_definition.annotation)

    @staticmethod
    def is_undefined_sentinel(value: Any) -> bool:
        """Return ``True`` if ``value`` should be treated as an undefined field.

        The default implementation always returns ``False``.
        """
        return False

    @staticmethod
    def is_constrained_field(field_definition: FieldDefinition) -> bool:
        """Return ``True`` if the field should be treated as constrained. If returning
        ``True``, constraints should be defined in the field's extras

        The default implementation always returns ``False``.
        """
        return False
# Union of the plugin base types defined above. ``InitPlugin`` is not listed
# separately because it subclasses ``InitPluginProtocol``.
PluginProtocol = Union[
CLIPlugin,
InitPluginProtocol,
OpenAPISchemaPlugin,
ReceiveRoutePlugin,
SerializationPlugin,
DIPlugin,
]
# Type variable bound to the plugin union; ``PluginRegistry.get`` uses it so the
# annotated return type matches the plugin type that was requested.
PluginT = TypeVar("PluginT", bound=PluginProtocol)
class PluginRegistry:
    """Container that groups plugin instances by the plugin interfaces they implement.

    Plugins are stored three ways: per-interface tuples (``init``, ``openapi``,
    ``receive_route``, ``serialization``, ``cli``, ``di``) that keep the order of
    the input list, an exact-type index used by :meth:`get`, and a ``frozenset``
    used for iteration and membership tests.
    """

    __slots__ = {  # noqa: RUF023
        "init": "Plugins that implement InitPlugin",
        "openapi": "Plugins that implement OpenAPISchemaPlugin",
        "receive_route": "ReceiveRoutePlugin instances",
        "serialization": "Plugins that implement SerializationPlugin",
        "cli": "Plugins that implement CLIPlugin",
        "di": "DIPlugin instances",
        "_plugins_by_type": None,
        "_plugins": None,
        # Declared but never assigned anywhere in this class.
        "_get_plugins_of_type": None,
    }

    def __init__(self, plugins: list[PluginProtocol]) -> None:
        """Index ``plugins`` by concrete type and by implemented interface.

        Args:
            plugins: The plugin instances to register. They must be hashable,
                because they are also stored in a ``frozenset``.
        """
        # Keyed by the exact class; if several plugins share a class, the last one wins.
        self._plugins_by_type = {type(p): p for p in plugins}
        # A frozenset de-duplicates, but does not remember registration order.
        self._plugins = frozenset(plugins)
        # A plugin implementing several interfaces appears in each matching tuple.
        self.init = tuple(p for p in plugins if isinstance(p, InitPluginProtocol))
        self.openapi = tuple(p for p in plugins if isinstance(p, OpenAPISchemaPlugin))
        self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin))
        self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPlugin))
        self.cli = tuple(p for p in plugins if isinstance(p, CLIPlugin))
        self.di = tuple(p for p in plugins if isinstance(p, DIPlugin))

    def get(self, type_: type[PluginT] | str) -> PluginT:
        """Return the registered plugin of ``type_``.

        This should be used with subclasses of the plugin protocols.

        Args:
            type_: Either the exact class of the plugin, or a string holding the
                class name or its module-qualified name
                (``"package.module.ClassName"``).

        Returns:
            The matching plugin instance. When looking up by string and several
            registered plugins match, which one is returned is unspecified,
            because the plugins are scanned in ``frozenset`` order.

        Raises:
            KeyError: If no matching plugin is registered.
        """
        if isinstance(type_, str):
            for plugin in self._plugins:
                plugin_cls = plugin.__class__
                module = plugin_cls.__module__
                qualified_name = plugin_cls.__qualname__
                # Builtin types carry no module prefix. The module is called
                # "builtins" on Python 3; "__builtin__" is the Python 2 name.
                if module is not None and module not in ("builtins", "__builtin__"):
                    qualified_name = f"{module}.{qualified_name}"
                if type_ in (plugin_cls.__name__, qualified_name):
                    return cast("PluginT", plugin)
            raise KeyError(f"No plugin of type {type_!r} registered")
        try:
            # Exact-type lookup: a subclass instance is not found via its base class.
            return cast("PluginT", self._plugins_by_type[type_])  # type: ignore[index]
        except KeyError as e:
            raise KeyError(f"No plugin of type {type_.__name__!r} registered") from e

    def __iter__(self) -> Iterator[PluginProtocol]:
        """Iterate over the unique registered plugins, in no particular order."""
        return iter(self._plugins)

    def __contains__(self, item: PluginProtocol) -> bool:
        """Return whether ``item`` is one of the registered plugin instances."""
        return item in self._plugins