# Source code for litestar.typing

from __future__ import annotations

import dataclasses
import warnings
from collections import abc
from collections.abc import Collection, Mapping
from dataclasses import dataclass, is_dataclass, replace
from enum import Enum
from inspect import Parameter, Signature
from typing import Any, AnyStr, Callable, ForwardRef, Literal, TypeVar, cast

from litestar.types import Empty

try:
    import annotated_types
except ImportError:
    annotated_types = Empty  # type: ignore[assignment]

from typing import get_type_hints

from msgspec import UnsetType
from typing_extensions import (
    NewType,
    NotRequired,
    Required,
    get_args,
    get_origin,
    is_typeddict,
)
from typing_extensions import (
    TypeAliasType as TeTypeAliasType,
)

try:
    from typing import TypeAliasType  # type: ignore[attr-defined]

    TypeAliasTypes = (TypeAliasType, TeTypeAliasType)
except ImportError:
    TypeAliasTypes = (TeTypeAliasType,)  # type: ignore[assignment]

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.params import BodyKwarg, DependencyKwarg, KwargDefinition, ParameterKwarg
from litestar.types.builtin_types import NoneType, UnionTypes
from litestar.utils.predicates import (
    is_any,
    is_class_and_subclass,
    is_generic,
    is_non_string_iterable,
    is_non_string_sequence,
    is_union,
)
from litestar.utils.typing import (
    get_instantiable_origin,
    get_safe_generic_origin,
    get_type_hints_with_generics_resolved,
    make_non_optional_union,
    unwrap_annotation,
)

# Names exported by ``from litestar.typing import *``.
__all__ = ("FieldDefinition",)

# Type variable bound to ``KwargDefinition``. NOTE(review): it is not referenced in the visible part of
# this module - confirm whether it is still used (or imported) elsewhere before removing it.
T = TypeVar("T", bound=KwargDefinition)


def _annotated_types_extractor(meta: Any, is_sequence_container: bool) -> dict[str, Any]:  # noqa: C901
    """Translate one ``annotated_types`` constraint into ``KwargDefinition`` keyword arguments.

    Args:
        meta: A single metadata object taken from an ``Annotated[...]`` annotation.
        is_sequence_container: When ``True``, length constraints are emitted as ``min_items`` /
            ``max_items``; otherwise they are emitted as ``min_length`` / ``max_length``.

    Returns:
        A new dict of keyword arguments. It is empty when ``annotated_types`` is not installed or when
        ``meta`` is not one of the recognised constraint types.
    """
    if annotated_types is Empty:  # type: ignore[comparison-overlap]  # pragma: no branch
        return {}  # type: ignore[unreachable]  # pragma: no cover

    # Grouped metadata is iterated and every member is translated recursively; when two members
    # produce the same key, the later member wins.
    if isinstance(meta, annotated_types.GroupedMetadata):
        combined: dict[str, Any] = {}
        for member in meta:
            combined.update(_annotated_types_extractor(member, is_sequence_container=is_sequence_container))
        return combined

    # Every recognised constraint yields exactly one key; types are tested in this order and the first
    # match wins.
    if isinstance(meta, annotated_types.Gt):
        return {"gt": meta.gt}
    if isinstance(meta, annotated_types.Ge):
        return {"ge": meta.ge}
    if isinstance(meta, annotated_types.Lt):
        return {"lt": meta.lt}
    if isinstance(meta, annotated_types.Le):
        return {"le": meta.le}
    if isinstance(meta, annotated_types.MultipleOf):
        return {"multiple_of": meta.multiple_of}
    if isinstance(meta, annotated_types.MinLen):
        return {("min_items" if is_sequence_container else "min_length"): meta.min_length}
    if isinstance(meta, annotated_types.MaxLen):
        return {("max_items" if is_sequence_container else "max_length"): meta.max_length}
    if isinstance(meta, annotated_types.Predicate):
        # Only a few well-known ``str`` predicates have a kwarg equivalent. They are compared with ``==``
        # in this order; anything else is ignored.
        known_predicates = (
            (str.islower, "lower_case", True),
            (str.isupper, "upper_case", True),
            (str.isascii, "pattern", "[[:ascii:]]"),
            (str.isdigit, "pattern", "[[:digit:]]"),
        )
        for predicate, key, value in known_predicates:
            if meta.func == predicate:
                return {key: value}
    return {}


[docs] @dataclass(frozen=True) class FieldDefinition: """Represents a function parameter or type annotation.""" __slots__ = ( "annotation", "args", "default", "extra", "inner_types", "instantiable_origin", "kwarg_definition", "metadata", "name", "origin", "raw", "safe_generic_origin", "type_wrappers", ) raw: Any """The annotation exactly as received.""" annotation: Any """The annotation with any "wrapper" types removed, e.g. Annotated.""" type_wrappers: tuple[type, ...] """A set of all "wrapper" types, e.g. Annotated.""" origin: Any """The result of calling ``get_origin(annotation)`` after unwrapping Annotated, e.g. list.""" args: tuple[Any, ...] """The result of calling ``get_args(annotation)`` after unwrapping Annotated, e.g. (int,).""" metadata: tuple[Any, ...] """Any metadata associated with the annotation via ``Annotated``.""" instantiable_origin: Any """An equivalent type to ``origin`` that can be safely instantiated. E.g., ``Sequence`` -> ``list``.""" safe_generic_origin: Any """An equivalent type to ``origin`` that can be safely used as a generic type across all supported Python versions. This is to serve safely rebuilding a generic outer type with different args at runtime. """ inner_types: tuple[FieldDefinition, ...] """The type's generic args parsed as ``FieldDefinition``, if applicable.""" default: Any """Default value of the field.""" extra: dict[str, Any] """A mapping of extra values.""" kwarg_definition: KwargDefinition | DependencyKwarg | None """Kwarg Parameter.""" name: str """Field name.""" def __eq__(self, other: Any) -> bool: if not isinstance(other, FieldDefinition): return False if self.origin: return self.origin == other.origin and self.inner_types == other.inner_types return self.annotation == other.annotation # type: ignore[no-any-return] def __hash__(self) -> int: return hash((self.name, self.raw, self.annotation, self.origin, self.inner_types)) @property def has_default(self) -> bool: """Check if the field has a default value. 
Returns: True if the default is not Empty or Ellipsis otherwise False. """ return self.default is not Empty and self.default is not Ellipsis @property def is_non_string_iterable(self) -> bool: """Check if the field type is an Iterable. If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. See: https://github.com/litestar-org/litestar/issues/1106 """ annotation = self.annotation if self.is_optional: annotation = make_non_optional_union(annotation) return is_non_string_iterable(annotation) @property def is_non_string_sequence(self) -> bool: """Check if the field type is a non-string Sequence. If ``self.annotation`` is an optional union, only the non-optional members of the union are evaluated. See: https://github.com/litestar-org/litestar/issues/1106 """ annotation = self.annotation if self.is_optional: annotation = make_non_optional_union(annotation) return is_non_string_sequence(annotation) @property def is_any(self) -> bool: """Check if the field type is Any.""" return is_any(self.annotation) @property def is_generic(self) -> bool: """Check if the field type is a custom class extending Generic.""" return is_generic(self.annotation) @property def is_simple_type(self) -> bool: """Check if the field type is a singleton value (e.g. 
int, str etc.).""" return not ( self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable or self.is_new_type ) @property def is_parameter_field(self) -> bool: """Check if the field type is a parameter kwarg value.""" return isinstance(self.kwarg_definition, ParameterKwarg) @property def is_const(self) -> bool: """Check if the field is defined as constant value.""" return bool(self.kwarg_definition and getattr(self.kwarg_definition, "const", False)) @property def is_required(self) -> bool: """Check if the field should be marked as a required parameter.""" if Required in self.type_wrappers: # type: ignore[comparison-overlap] return True if NotRequired in self.type_wrappers or UnsetType in self.args: # type: ignore[comparison-overlap] return False if isinstance(self.kwarg_definition, ParameterKwarg) and self.kwarg_definition.required is not None: return self.kwarg_definition.required return not self.is_optional and not self.is_any and (not self.has_default or self.default is None) @property def is_annotated(self) -> bool: """Check if the field type is Annotated.""" return bool(self.metadata) @property def is_literal(self) -> bool: """Check if the field type is Literal.""" return self.origin is Literal @property def is_forward_ref(self) -> bool: """Whether the annotation is a forward reference or not.""" return isinstance(self.annotation, (str, ForwardRef)) @property def is_mapping(self) -> bool: """Whether the annotation is a mapping or not.""" return self.is_subclass_of(Mapping) @property def is_tuple(self) -> bool: """Whether the annotation is a ``tuple`` or not.""" return self.is_subclass_of(tuple) @property def is_new_type(self) -> bool: return isinstance(self.annotation, NewType) @property def is_type_alias_type(self) -> bool: """Whether the annotation is a ``TypeAliasType``""" return isinstance(self.annotation, TypeAliasTypes) @property def is_type_var(self) -> bool: """Whether the annotation is a TypeVar or not.""" 
return isinstance(self.annotation, TypeVar) @property def is_union(self) -> bool: """Whether the annotation is a union type or not.""" return self.origin in UnionTypes @property def is_optional(self) -> bool: """Whether the annotation is Optional or not.""" return bool(self.is_union and NoneType in self.args) @property def is_none_type(self) -> bool: """Whether the annotation is NoneType or not.""" return self.annotation is NoneType @property def is_collection(self) -> bool: """Whether the annotation is a collection type or not.""" return self.is_subclass_of(Collection) @property def is_non_string_collection(self) -> bool: """Whether the annotation is a non-string collection type or not.""" return self.is_collection and not self.is_subclass_of((str, bytes)) @property def bound_types(self) -> tuple[FieldDefinition, ...] | None: """A tuple of bound types - if the annotation is a TypeVar with bound types, otherwise None.""" if self.is_type_var and (bound := getattr(self.annotation, "__bound__", None)): if is_union(bound): return tuple(FieldDefinition.from_annotation(t) for t in get_args(bound)) return (FieldDefinition.from_annotation(bound),) return None @property def generic_types(self) -> tuple[FieldDefinition, ...] 
| None: """A tuple of generic types passed into the annotation - if its generic.""" if not (bases := getattr(self.annotation, "__orig_bases__", None)): return None args: list[FieldDefinition] = [] for base_args in [getattr(base, "__args__", ()) for base in bases]: for arg in base_args: field_definition = FieldDefinition.from_annotation(arg) if field_definition.generic_types: args.extend(field_definition.generic_types) else: args.append(field_definition) return tuple(args) @property def is_dataclass_type(self) -> bool: """Whether the annotation is a dataclass type or not.""" return is_dataclass(cast("type", self.origin or self.annotation)) @property def is_typeddict_type(self) -> bool: """Whether the type is TypedDict or not.""" return is_typeddict(self.origin or self.annotation) @property def is_enum(self) -> bool: return self.is_subclass_of(Enum) @property def type_(self) -> Any: """The type of the annotation with all the wrappers removed, including the generic types.""" return self.origin or self.annotation
[docs] def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: """Whether the annotation is a subclass of the given type. Where ``self.annotation`` is a union type, this method will return ``True`` when all members of the union are a subtype of ``cl``, otherwise, ``False``. Args: cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. Returns: Whether the annotation is a subtype of the given type(s). """ if self.origin: if self.origin in UnionTypes: return all(t.is_subclass_of(cl) for t in self.inner_types) return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl) if self.annotation is AnyStr: return is_class_and_subclass(str, cl) or is_class_and_subclass(bytes, cl) return self.annotation is not Any and not self.is_type_var and is_class_and_subclass(self.annotation, cl)
[docs] def has_inner_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: """Whether any generic args are a subclass of the given type. Args: cl: The type to check, or tuple of types. Passed as 2nd argument to ``issubclass()``. Returns: Whether any of the type's generic args are a subclass of the given type. """ return any(t.is_subclass_of(cl) for t in self.inner_types)
[docs] def get_type_hints(self, *, include_extras: bool = False, resolve_generics: bool = False) -> dict[str, Any]: """Get the type hints for the annotation. Args: include_extras: Flag to indicate whether to include ``Annotated[T, ...]`` or not. resolve_generics: Flag to indicate whether to resolve the generic types in the type hints or not. Returns: The type hints. """ if self.origin is not None or self.is_generic: if resolve_generics: return get_type_hints_with_generics_resolved(self.annotation, include_extras=include_extras) return get_type_hints(self.origin or self.annotation, include_extras=include_extras) return get_type_hints(self.annotation, include_extras=include_extras)
[docs] @classmethod def from_annotation(cls, annotation: Any, **kwargs: Any) -> FieldDefinition: """Initialize FieldDefinition. Args: annotation: The type annotation. This should be extracted from the return of ``get_type_hints(..., include_extras=True)`` so that forward references are resolved and recursive ``Annotated`` types are flattened. **kwargs: Additional keyword arguments to pass to the ``FieldDefinition`` constructor. Returns: FieldDefinition """ unwrapped, metadata, wrappers = unwrap_annotation(annotation if annotation is not Empty else Any) origin = get_origin(unwrapped) annotation_args = () if origin is abc.Callable else get_args(unwrapped) if not kwargs.get("kwarg_definition"): if isinstance(kwargs.get("default"), (KwargDefinition, DependencyKwarg)): kwargs["kwarg_definition"] = kwargs.pop("default") elif kwarg_definition := next( (v for v in metadata if isinstance(v, (KwargDefinition, DependencyKwarg))), None ): kwargs["kwarg_definition"] = kwarg_definition if kwarg_definition.default is not Empty: warnings.warn( f"Deprecated default value specification for annotation '{annotation}'. Setting defaults " f"inside 'typing.Annotated' is discouraged and support for this will be removed in a future " f"version. Defaults should be set with regular parameter default values. Use " "'param: Annotated[<type>, Parameter(...)] = <default>' instead of " "'param: Annotated[<type>, Parameter(..., default=<default>)].", category=DeprecationWarning, stacklevel=2, ) if kwargs.get("default", Empty) is not Empty and kwarg_definition.default != kwargs["default"]: warnings.warn( f"Ambiguous default values for annotation '{annotation}'. 
The default value " f"'{kwarg_definition.default!r}' set inside the parameter annotation differs from the " f"parameter default value '{kwargs['default']!r}'", category=LitestarWarning, stacklevel=2, ) metadata = tuple(v for v in metadata if not isinstance(v, (KwargDefinition, DependencyKwarg))) elif (extra := kwargs.get("extra", {})) and "kwarg_definition" in extra: kwargs["kwarg_definition"] = extra.pop("kwarg_definition") # there might be additional metadata if metadata: kwarg_definition_merge_args = {} is_sequence_container = is_non_string_sequence(annotation) # extract metadata into KwargDefinition attributes for meta in metadata: kwarg_definition_merge_args.update( _annotated_types_extractor(meta, is_sequence_container=is_sequence_container) ) # if we already have a KwargDefinition, merge it with the additional metadata if existing_kwargs_definition := kwargs.get("kwarg_definition"): kwargs["kwarg_definition"] = dataclasses.replace( existing_kwargs_definition, **kwarg_definition_merge_args ) # if not, create a new KwargDefinition else: model = BodyKwarg if kwargs.get("name") == "data" else ParameterKwarg kwargs["kwarg_definition"] = model(**kwarg_definition_merge_args) kwargs.setdefault("annotation", unwrapped) kwargs.setdefault("args", annotation_args) kwargs.setdefault("default", Empty) kwargs.setdefault("extra", {}) kwargs.setdefault("inner_types", tuple(FieldDefinition.from_annotation(arg) for arg in annotation_args)) kwargs.setdefault("instantiable_origin", get_instantiable_origin(origin, unwrapped)) kwargs.setdefault("kwarg_definition", None) kwargs.setdefault("metadata", metadata) kwargs.setdefault("name", "") kwargs.setdefault("origin", origin) kwargs.setdefault("raw", annotation) kwargs.setdefault("safe_generic_origin", get_safe_generic_origin(origin, unwrapped)) kwargs.setdefault("type_wrappers", wrappers) instance = FieldDefinition(**kwargs) if not instance.has_default and instance.kwarg_definition: return replace(instance, 
default=instance.kwarg_definition.default) return instance
[docs] @classmethod def from_kwarg( cls, annotation: Any, name: str, default: Any = Empty, inner_types: tuple[FieldDefinition, ...] | None = None, kwarg_definition: KwargDefinition | DependencyKwarg | None = None, extra: dict[str, Any] | None = None, ) -> FieldDefinition: """Create a new FieldDefinition instance. Args: annotation: The type of the kwarg. name: Field name. default: A default value. inner_types: A tuple of FieldDefinition instances representing the inner types, if any. kwarg_definition: Kwarg Parameter. extra: A mapping of extra values. Returns: FieldDefinition instance. """ return cls.from_annotation( annotation, name=name, default=default, **{ k: v for k, v in { "inner_types": inner_types, "kwarg_definition": kwarg_definition, "extra": extra, }.items() if v is not None }, )
[docs] @classmethod def from_parameter(cls, parameter: Parameter, fn_type_hints: dict[str, Any]) -> FieldDefinition: """Initialize ParsedSignatureParameter. Args: parameter: inspect.Parameter fn_type_hints: mapping of names to types. Should be result of ``get_type_hints()``, preferably via the :attr:``get_fn_type_hints() <.utils.signature_parsing.get_fn_type_hints>`` helper. Returns: ParsedSignatureParameter. """ from litestar.datastructures import ImmutableState try: annotation = fn_type_hints[parameter.name] except KeyError as e: raise ImproperlyConfiguredException( f"'{parameter.name}' does not have a type annotation. If it should receive any value, use 'Any'." ) from e if parameter.name == "state" and not issubclass(annotation, ImmutableState): raise ImproperlyConfiguredException( f"The type annotation `{annotation}` is an invalid type for the 'state' reserved kwarg. " "It must be typed to a subclass of `litestar.datastructures.ImmutableState` or " "`litestar.datastructures.State`." ) return FieldDefinition.from_kwarg( annotation=annotation, name=parameter.name, default=Empty if parameter.default is Signature.empty else parameter.default, )
[docs] def match_predicate_recursively(self, predicate: Callable[[FieldDefinition], bool]) -> bool: """Recursively test the passed in predicate against the field and any of its inner fields. Args: predicate: A callable that receives a field definition instance as an arg and returns a boolean. Returns: A boolean. """ return predicate(self) or any(t.match_predicate_recursively(predicate) for t in self.inner_types)