# Source code for litestar.openapi.spec.base

from __future__ import annotations

from dataclasses import asdict, dataclass, fields, is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Iterator
    from dataclasses import Field

# Public API of this module: only ``BaseSchemaObject`` is exported; the ``_normalize_*``
# helpers are private implementation details.
__all__ = ("BaseSchemaObject",)


def _normalize_key(key: str) -> str:
    if key.endswith("_in"):
        return "in"
    if key.startswith("schema_"):
        return key.split("_")[1]
    if "_" in key:
        components = key.split("_")
        return components[0] + "".join(component.title() for component in components[1:])
    return "$ref" if key == "ref" else key


def _normalize_value(value: Any) -> Any:
    """Recursively convert ``value`` into plain data for the rendered schema.

    Conversion rules, applied in order:

    * ``BaseSchemaObject`` instances are rendered with their own ``to_schema()``.
    * Other dataclass *instances* become a ``dict`` keyed by field name. Fields whose value is
      ``None`` are omitted and every remaining value is normalized recursively, so a
      ``BaseSchemaObject`` nested inside a plain dataclass is still rendered via ``to_schema()``.
    * ``dict`` keys and values are normalized recursively; entries whose value is ``None`` are
      omitted.
    * ``list`` elements are normalized recursively (``None`` elements are kept). Elements of a
      plain ``tuple`` are normalized the same way and the result stays a ``tuple``.
    * ``Enum`` members are replaced by their ``.value``.
    * Anything else (including dataclass *classes*) is returned unchanged.

    Args:
        value: The value to normalize.

    Returns:
        The normalized value.
    """
    if isinstance(value, BaseSchemaObject):
        return value.to_schema()
    # ``is_dataclass`` is also true for dataclass *types*; only instances carry field values
    # (``asdict``/``getattr`` on a class would fail or read class-level defaults).
    if is_dataclass(value) and not isinstance(value, type):
        # Walk the fields one level at a time instead of using ``dataclasses.asdict``: ``asdict``
        # flattens nested dataclasses itself, which would bypass ``to_schema()`` (key
        # normalization, aliases, excluded fields) for any nested ``BaseSchemaObject``.
        normalized: dict[str, Any] = {}
        for field in fields(value):
            field_value = getattr(value, field.name)
            if field_value is not None:
                normalized[field.name] = _normalize_value(field_value)
        return normalized
    if isinstance(value, dict):
        return {_normalize_value(k): _normalize_value(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [_normalize_value(v) for v in value]
    if type(value) is tuple:
        # Exact ``tuple`` only: namedtuples cannot be rebuilt from a single iterable argument,
        # so they fall through unchanged.
        return tuple(_normalize_value(v) for v in value)
    return value.value if isinstance(value, Enum) else value


@dataclass
class BaseSchemaObject:
    """Common base for the dataclasses in this spec package.

    Subclasses are turned into a string-keyed dictionary by :meth:`to_schema`.
    """

    @property
    def _exclude_fields(self) -> set[str]:
        """Names of fields that :meth:`to_schema` must skip.

        Empty here; presumably overridden by subclasses that need to hide fields.
        """
        return set()

    def _iter_fields(self) -> Iterator[Field[Any]]:
        """Yield the dataclass fields of this object, in declaration order."""
        yield from fields(self)

    def to_schema(self) -> dict[str, Any]:
        """Render this object as a string-keyed dictionary.

        Every field yielded by ``_iter_fields`` that is not listed in ``_exclude_fields`` is
        normalized recursively with ``_normalize_value``. Fields whose normalized value is
        ``None`` are left out. The output key is the field's ``metadata["alias"]`` when one is
        present, otherwise ``_normalize_key(field.name)``.

        Returns:
            The rendered dictionary.

        Raises:
            TypeError: If a field with a non-``None`` value declares a non-``str`` alias.
        """
        skipped = self._exclude_fields
        schema: dict[str, Any] = {}

        for spec_field in self._iter_fields():
            if spec_field.name in skipped:
                continue

            rendered = _normalize_value(getattr(self, spec_field.name, None))
            if rendered is None:
                continue

            metadata = spec_field.metadata
            if "alias" not in metadata:
                schema[_normalize_key(spec_field.name)] = rendered
                continue

            # An explicit alias overrides the automatic key normalization.
            alias = metadata["alias"]
            if not isinstance(alias, str):
                raise TypeError('metadata["alias"] must be a str')
            schema[alias] = rendered

        return schema