Source code for litestar.middleware.constraints

import collections
import dataclasses
import functools
import inspect
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

from typing_extensions import Self

from litestar.exceptions import LitestarException
from litestar.middleware.base import ASGIMiddleware
from litestar.types import Middleware
from litestar.utils.module_loader import import_string

if TYPE_CHECKING:
    from litestar.types.composite_types import MiddlewareFactory

# Names that make up this module's public API (what ``from ... import *`` exports).
__all__ = (
    "ConstraintViolationError",
    "CycleError",
    "MiddlewareConstraintError",
    "MiddlewareConstraints",
    "MiddlewareForwardRef",
    "check_middleware_constraints",
)


class MiddlewareConstraintError(LitestarException):
    """Base class for all middleware constraint errors."""
class ConstraintViolationError(MiddlewareConstraintError):
    """A middleware constraint is violated by the given middleware order."""
class CycleError(MiddlewareConstraintError):
    """The ``before`` / ``after`` constraints of some middlewares form a cycle."""
@dataclasses.dataclass(frozen=True)
class MiddlewareForwardRef:
    """Reference to a middleware by import path, resolved on demand."""

    target: str
    """Absolute path to an importable name of the middleware"""
    ignore_import_error: bool
    """
    If ``True``, an :exc:`ImportError` raised while resolving ``target`` is
    suppressed and the reference resolves to ``None``
    """

    @staticmethod
    @functools.cache
    def _resolve(target: str, ignore_not_found: bool) -> "Middleware | None":
        # Return values (including ``None``) are memoised per
        # ``(target, ignore_not_found)`` pair; a raised ImportError is not cached.
        try:
            imported = import_string(target)
        except ImportError:
            if not ignore_not_found:
                raise
            return None
        return cast("Middleware", imported)

    def resolve(self) -> "Middleware | None":
        """Import :attr:`target` and return the imported object.

        If :attr:`ignore_import_error` is ``True`` and the import raises an
        :exc:`ImportError`, swallow the error and return ``None``. Otherwise the
        :exc:`ImportError` propagates.
        """
        return self._resolve(self.target, self.ignore_import_error)
@dataclasses.dataclass
class _ResolvedMiddlewareConstraints:
    """Constraint set holding concrete middlewares and plain boolean flags."""

    before: tuple["Middleware | MiddlewareFactory", ...]
    after: tuple["Middleware | MiddlewareFactory", ...]
    first: bool
    last: bool
    unique: bool

    @property
    def is_empty(self) -> bool:
        """``True`` if no constraint is set, i.e. every field is falsy."""
        requirements = (self.before, self.after, self.first, self.last, self.unique)
        return not any(requirements)
[docs] @dataclasses.dataclass(frozen=True) class MiddlewareConstraints: """Constraints for a middleware.""" before: tuple["MiddlewareForwardRef | Middleware | MiddlewareFactory", ...] = () """ Tuple of middlewares that, if present, need to appear *before* the middleware this constraint is applied to (i.e. closer to the application) """ after: tuple["MiddlewareForwardRef | Middleware | MiddlewareFactory", ...] = () """ Tuple of middlewares that, if present, need to appear *after* the middleware this constraint is applied to (i.e. closer to the handler) """ first: bool = False """ If ``True``, require the middleware to be the first (i.e. the first middleware on the application). Mutually exclusive with ``last=True``. Implicitly sets ``unique=True`` """ last: bool = False """ If ``True``, require the middleware to be the last (i.e. the last middleware on the handler). Mutually exclusive with ``first=True``. Implicitly sets ``unique=True`` """ unique: Optional[bool] = None """ If ``True``, require the middleware to be the only one of its type """ def __post_init__(self) -> None: if self.first: if self.last: raise MiddlewareConstraintError("Cannot set 'first=True' if 'last=True'") if self.unique is False: raise MiddlewareConstraintError("Cannot set 'first=True' if 'unique=False'") if self.after: raise MiddlewareConstraintError("Cannot set 'first=True' if if 'after' is not empty") if self.last: if self.unique is False: raise MiddlewareConstraintError("Cannot set 'last=True' if 'unique=False'") if self.before: raise MiddlewareConstraintError("Cannot set 'last=True' if 'before' is not empty")
[docs] def require_unique(self, unique: bool) -> Self: """Return a new constraint with a ``unique`` value set""" return dataclasses.replace(self, unique=unique)
[docs] def apply_first(self) -> Self: """Return a new constraint with ``first=True``. Overrides ``last=True``""" return dataclasses.replace(self, first=True, last=False, unique=True)
[docs] def apply_last(self) -> Self: """Return a new constraint with ``first=True``. Overrides ``first=True``""" return dataclasses.replace(self, first=False, last=True, unique=True)
[docs] def apply_before( self, other: "str | Middleware | MiddlewareFactory | MiddlewareForwardRef", ignore_import_error: bool = False, ) -> Self: """Return new :class:`~litestar.middleware.constraints.MiddlewareConstraints` with ``other`` added to existing ``before`` constraint. :param other: Middleware this middleware needs to be applied before. If passed a string, create a :class:`~litestar.middleware.constraints.MiddlewareForwardRef` that resolves to the actual middleware at runtime :param ignore_import_error: If ``True`` and ``other`` is a string, ignore the constraint if an :exc:`ImportError` occurs when trying to import it """ if isinstance(other, str): other = MiddlewareForwardRef(target=other, ignore_import_error=ignore_import_error) return dataclasses.replace(self, before=(*self.before, other))
[docs] def apply_after( self, other: "str | Middleware | MiddlewareFactory | MiddlewareForwardRef", ignore_import_error: bool = False, ) -> Self: """Return new :class:`~litestar.middleware.constraints.MiddlewareConstraints` with ``other`` added to existing ``after`` constraint. :param other: Middleware this middleware needs to be applied before. If passed a string, create a :class:`~litestar.middleware.constraints.MiddlewareForwardRef` that resolves to the actual middleware at runtime :param ignore_import_error: If ``True`` and ``other`` is a string, ignore the constraint if an :exc:`ImportError` occurs when trying to import it """ if isinstance(other, str): other = MiddlewareForwardRef(target=other, ignore_import_error=ignore_import_error) return dataclasses.replace(self, after=(*self.after, other))
@staticmethod def _resolve_middleware( middlewares: tuple["Middleware | MiddlewareFactory | MiddlewareForwardRef", ...], ) -> tuple["Middleware | MiddlewareFactory", ...]: resolved = [] for middleware in middlewares: if isinstance(middleware, MiddlewareForwardRef): if (resolved_middleware := middleware.resolve()) is None: continue middleware = resolved_middleware resolved.append(middleware) return tuple(resolved) def _resolve(self) -> _ResolvedMiddlewareConstraints: return _ResolvedMiddlewareConstraints( before=self._resolve_middleware(self.before), after=self._resolve_middleware(self.after), first=self.first, last=self.last, unique=False if self.unique is None else self.unique, )
def _fully_qualified_name(obj: Any) -> str:
    """Return ``obj``'s dotted name, built from ``__module__`` and ``__qualname__``."""
    return f"{obj.__module__}.{obj.__qualname__}"


def _dfs(node: object, graph: dict[object, list[object]], visiting: set[object], visited: set[object]) -> bool:
    """Depth-first search from ``node``; return ``True`` if a cycle is reachable.

    :param node: Node to start from
    :param graph: Adjacency mapping; nodes without an entry have no neighbours
    :param visiting: Nodes on the current search path (mutated in place)
    :param visited: Nodes already fully explored without finding a cycle (mutated
        in place)
    """
    if node in visiting:
        # we came back to a node on the current path
        return True
    if node in visited:
        return False

    visiting.add(node)
    # membership test instead of indexing, so no key is created for nodes that
    # only ever appear as a neighbour
    if node in graph:
        for neighbor in graph[node]:
            if _dfs(neighbor, graph=graph, visiting=visiting, visited=visited):
                return True
    visiting.remove(node)
    visited.add(node)
    return False


def _detect_constraints_cycle(graph: dict[object, list[object]]) -> None:
    """Raise :exc:`CycleError` if ``graph`` contains a cycle."""
    visited: set[object] = set()
    visiting: set[object] = set()

    for node in graph:
        if _dfs(node, graph=graph, visiting=visiting, visited=visited):
            raise CycleError("Cyclic dependency detected in middleware 'before' / 'after' constraints")


def _check_positional_constraints(
    graph: dict[object, list[object]],
    positions: dict[object, list[int]],
    directional_constraints: dict[tuple[object, object], Literal["before", "after"]],
) -> None:
    """Verify all ``before`` / ``after`` constraints against the actual positions.

    :param graph: Maps a node to the nodes it has to come before
    :param positions: Indices at which each node was found
    :param directional_constraints: Original constraint kind of each
        ``(node, neighbour)`` edge in ``graph``; used for the error message only
    :raises CycleError: If the constraints form a cycle
    :raises ConstraintViolationError: If any instance of a node does not come
        before every instance of one of its neighbours
    """
    _detect_constraints_cycle(graph)

    for node_u, predecessors in graph.items():
        u_positions = positions.get(node_u)
        if not u_positions:
            # constraints only apply to middlewares that are actually present
            continue

        max_u_pos = max(u_positions)

        for node_v in predecessors:
            if not (v_positions := positions.get(node_v)):
                continue

            min_v_pos = min(v_positions)

            # the last instance of 'u' has to come before the first instance of 'v'
            if max_u_pos >= min_v_pos:
                constraint = directional_constraints[(node_u, node_v)]
                first = node_u
                second = node_v
                first_idx = max_u_pos
                second_idx = min_v_pos

                # since we've converted all constraints to a 'before' check while
                # building the graph, retrieve the original constraint type
                # ('before', 'after') for this violation, and flip the nodes of 'after'
                # constraints again, so we can construct an error message
                if constraint == "after":
                    first, second = second, first
                    first_idx, second_idx = second_idx, first_idx

                first_name = _fully_qualified_name(first)
                second_name = _fully_qualified_name(second)

                msg = (
                    f"All instances of {first_name!r} must come {constraint} any "
                    f"instance of {second_name!r}. Found instance of {first_name!r} "
                    f"at index {first_idx}, instance of {second_name!r} at index "
                    f"{second_idx}."
                )
                raise ConstraintViolationError(msg)


def _check_first_last_constraints(
    want_first: list[object],
    want_last: list[object],
    positions: dict[object, list[int]],
    total_count: int,
) -> None:
    """Verify the ``first=True`` / ``last=True`` constraints.

    :param want_first: Middlewares requesting the first position
    :param want_last: Middlewares requesting the last position
    :param positions: Indices at which each middleware was found
    :param total_count: Total number of middlewares
    :raises MiddlewareConstraintError: If more than one entry requests the first,
        or more than one the last position
    :raises ConstraintViolationError: If the requesting middleware is not at index
        ``0`` (first) or ``total_count - 1`` (last)
    """
    if len(want_first) > 1:
        msg = f"Multiple middlewares define 'first=True': {', '.join(map(_fully_qualified_name, want_first))}"
        raise MiddlewareConstraintError(msg)

    if len(want_last) > 1:
        msg = f"Multiple middlewares define 'last=True': {', '.join(map(_fully_qualified_name, want_last))}"
        raise MiddlewareConstraintError(msg)

    if want_first:
        first = want_first[0]
        first_positions = positions[first]
        # using the highest index also rejects additional instances further down
        max_pos_first = max(first_positions)
        if max_pos_first > 0:
            msg = (
                f"Middleware {_fully_qualified_name(first)!r} is required to be in the "
                f"first position, but was found at index {', '.join(map(str, first_positions))}. "
                "(Violates constraint 'first=True')"
            )
            raise ConstraintViolationError(msg)

    if want_last:
        last = want_last[0]
        last_positions = positions[last]
        # using the lowest index also rejects additional instances further up
        min_pos_last = min(last_positions)
        expected_index = total_count - 1
        if min_pos_last != expected_index:
            msg = (
                f"Middleware {_fully_qualified_name(last)!r} is required to be in the "
                f"last position (index {expected_index} of {expected_index}), but was "
                f"found at index {', '.join(map(str, last_positions))}. "
                "(Violates constraint 'last=True')"
            )
            raise ConstraintViolationError(msg)


def _check_unique_constraints(unique: list[object], positions: dict[object, list[int]]) -> None:
    """Verify that every middleware in ``unique`` was found at one index only.

    :raises ConstraintViolationError: If a middleware was found more than once
    """
    for middleware in unique:
        found_positions = positions[middleware]
        if len(found_positions) > 1:
            msg = (
                f"Middleware {_fully_qualified_name(middleware)!r} must be unique. "
                f"Found {len(found_positions)} instances (indices "
                f"{', '.join(map(str, found_positions))}). "
                "(Violates constraints 'unique=True')"
            )
            raise ConstraintViolationError(msg)


def check_middleware_constraints(middlewares: "tuple[Middleware, ...]") -> None:
    """Check that the order of ``middlewares`` satisfies all their constraints.

    Only :class:`ASGIMiddleware` instances with non-empty ``constraints`` contribute
    constraints; all other entries are only recorded by position. Uniqueness is
    checked first, then ``first`` / ``last``, then ``before`` / ``after``.

    :param middlewares: Middlewares in the order they are applied; index ``0`` is
        the first position
    :raises MiddlewareConstraintError: If several middlewares request the first or
        the last position
    :raises CycleError: If the ``before`` / ``after`` constraints form a cycle
    :raises ConstraintViolationError: If a constraint is violated
    """
    want_first: list[object] = []
    want_last: list[object] = []
    unique: list[object] = []

    # simple "graph" that tracks a middleware and its neighbors, according to the spec
    # we're given. to keep things simple, we're converting all requirements into a form
    # of 'node -> list[predecessor]', and remember the original constraint
    graph: dict[object, list[object]] = collections.defaultdict(list)
    directional_constraints: dict[tuple[object, object], Literal["before", "after"]] = {}

    # keep track of the positions of *all* instances of a middleware; there may be more
    # than one instance of a specific type
    positions: collections.defaultdict[object, list[int]] = collections.defaultdict(list)

    for i, middleware in enumerate(middlewares):
        middleware_type: Union[object, type]
        if inspect.isfunction(middleware):
            positions[middleware].append(i)
            middleware_type = middleware
        else:
            # 'middleware' might be a class or an instance of a class
            middleware_type = type(middleware) if not inspect.isclass(middleware) else middleware
            # record the position for every class in the MRO except the final entry,
            # so constraints that name a base class match subclasses as well
            for base in middleware_type.mro()[:-1]:  # pyright: ignore
                positions[base].append(i)

        if not (isinstance(middleware, ASGIMiddleware) and middleware.constraints):
            continue

        constraints = middleware.constraints._resolve()
        if constraints.is_empty:
            continue

        if constraints.first:
            want_first.append(middleware_type)

        if constraints.last:
            want_last.append(middleware_type)

        if constraints.unique:
            unique.append(middleware_type)

        for before in constraints.before:
            directional_constraints[(middleware_type, before)] = "before"
            graph[middleware_type].append(before)

        # "X after Y" is stored as "Y before X"
        for after in constraints.after:
            directional_constraints[(after, middleware_type)] = "after"
            graph[after].append(middleware_type)

    _check_unique_constraints(
        unique=unique,
        positions=positions,
    )

    _check_first_last_constraints(
        want_first=want_first,
        want_last=want_last,
        positions=positions,
        total_count=len(middlewares),
    )

    _check_positional_constraints(
        graph=dict(graph),  # convert defaultdict to a regular dict to avoid accidental key creation
        positions=positions,
        directional_constraints=directional_constraints,
    )