# Source code for litestar.serialization.msgspec_hooks

from __future__ import annotations

from collections import deque
from datetime import date, datetime, time
from decimal import Decimal
from functools import partial
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
)
from pathlib import Path, PurePath
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from uuid import UUID

import msgspec

from litestar.datastructures.secret_values import SecretBytes, SecretString
from litestar.exceptions import SerializationException
from litestar.types import Empty, EmptyType, Serializer, TypeDecodersSequence
from litestar.utils.typing import get_origin_or_inner_type

if TYPE_CHECKING:
    from collections.abc import Mapping

    from litestar.types import TypeEncodersMap

# Names exported by ``from litestar.serialization.msgspec_hooks import *``.
__all__ = (
    "decode_json",
    "decode_msgpack",
    "default_deserializer",
    "default_serializer",
    "encode_json",
    "encode_msgpack",
    "get_serializer",
)

# Type variable linking the ``target_type`` argument of the decode helpers to
# the type they return.
T = TypeVar("T")

def _encode_decimal(val: Decimal) -> int | float:
    """Convert a ``Decimal`` into an ``int`` (no fractional part) or a ``float``.

    ``Decimal.as_tuple().exponent`` is an ``int`` for finite numbers but a string
    (``'F'``, ``'n'`` or ``'N'``) for Infinity, NaN and signaling NaN. Comparing that
    string with ``0`` raises an opaque ``TypeError``, so the exponent's type is checked
    before the comparison and non-finite values are converted with ``float()``.

    Raises:
        TypeError: if the value has no float representation (signaling NaN)
    """
    exponent = val.as_tuple().exponent
    if isinstance(exponent, int) and exponent >= 0:
        return int(val)
    try:
        return float(val)
    except ValueError as exc:
        # float() rejects signaling NaN with a ValueError; keep raising TypeError so
        # the encode helpers continue to translate it into a SerializationException.
        raise TypeError(f"Unsupported type: {type(val)!r}") from exc


# Encoders applied by ``default_serializer`` to values ``msgspec`` cannot encode
# natively. Keys are types, values are callables returning an encodable object.
DEFAULT_TYPE_ENCODERS: TypeEncodersMap = {
    Path: str,
    PurePath: str,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    datetime: lambda val: val.isoformat(),
    date: lambda val: val.isoformat(),
    time: lambda val: val.isoformat(),
    deque: list,
    Decimal: _encode_decimal,
    Pattern: lambda val: val.pattern,
    # secrets are serialized via ``get_obscured()``, not their underlying value
    SecretBytes: lambda val: val.get_obscured().decode("utf-8"),
    SecretString: lambda val: val.get_obscured(),
    # Support subclasses of stdlib types. If no previous type matched, these will be
    # the last type in the MRO, so we use this to (attempt to) convert a subclass into
    # its base class. See https://github.com/jcrist/msgspec/issues/248
    # and https://github.com/litestar-org/litestar/issues/1003
    str: str,
    int: int,
    float: float,
    set: set,
    frozenset: frozenset,
    bytes: bytes,
}


def default_serializer(value: Any, type_encoders: Mapping[Any, Callable[[Any], Any]] | None = None) -> Any:
    """Transform values non-natively supported by ``msgspec``

    The MRO of ``value``'s class is walked from most to least specific (``object``
    excluded) and the first class with a registered encoder wins. For any given class,
    an entry in ``type_encoders`` takes precedence over ``DEFAULT_TYPE_ENCODERS``.

    Args:
        value: A value to serialize
        type_encoders: Mapping of types to callables that transform them

    Returns:
        A serialized value

    Raises:
        TypeError: if value is not supported
    """
    # Consult both mappings directly instead of merging them into a fresh dict on
    # every call: this hook runs once per unsupported value, so the merge is pure
    # overhead. Lookup precedence is identical to ``{**defaults, **custom}``.
    custom_encoders = type_encoders or {}

    for base in value.__class__.__mro__[:-1]:
        if base in custom_encoders:
            encoder = custom_encoders[base]
        elif base in DEFAULT_TYPE_ENCODERS:
            encoder = DEFAULT_TYPE_ENCODERS[base]
        else:
            continue
        return encoder(value)

    raise TypeError(f"Unsupported type: {type(value)!r}")
def default_deserializer(
    target_type: Any, value: Any, type_decoders: TypeDecodersSequence | None = None
) -> Any:  # pragma: no cover
    """Transform values non-natively supported by ``msgspec``

    Resolution order: return ``value`` unchanged if it already is an instance of
    ``target_type``; otherwise use the first entry of ``type_decoders`` whose predicate
    accepts ``target_type``; otherwise fall back to the built-in conversions for paths,
    ``ImmutableState``, ``UUID`` and the secret value types.

    Args:
        target_type: Encountered type
        value: Value to coerce
        type_decoders: Optional sequence of ``(predicate, decoder)`` pairs

    Returns:
        A ``msgspec``-supported type

    Raises:
        TypeError: if no conversion for ``target_type`` is available
    """
    # NOTE(review): imported locally rather than at module level - presumably to avoid
    # a circular import; confirm before moving it.
    from litestar.datastructures.state import ImmutableState

    try:
        is_instance = isinstance(value, target_type)
    except TypeError:
        # we might get a TypeError here if target_type is a subscripted generic. For
        # performance reasons, we let this happen and only unwrap this when we're
        # certain this might be the case
        origin = get_origin_or_inner_type(target_type)
        if origin is None:
            raise
        target_type = origin
        is_instance = isinstance(value, target_type)

    if is_instance:
        return value

    # user-supplied decoders take precedence over the built-in fallbacks below
    for predicate, decoder in type_decoders or ():
        if predicate(target_type):
            return decoder(target_type, value)

    if issubclass(target_type, (PurePath, ImmutableState, UUID)):
        return target_type(value)

    if issubclass(target_type, SecretBytes) and isinstance(value, (bytes, str)):
        raw = value.encode("utf-8") if isinstance(value, str) else value
        return SecretBytes(raw)

    if issubclass(target_type, SecretString) and isinstance(value, str):
        return SecretString(value)

    raise TypeError(f"Unsupported type: {type(value)!r}")
# Shared encoder/decoder instances wired to the default hooks. The encode/decode
# helpers below use these whenever no custom serializer or target type is given.
_msgspec_json_encoder = msgspec.json.Encoder(enc_hook=default_serializer)
_msgspec_msgpack_encoder = msgspec.msgpack.Encoder(enc_hook=default_serializer)

_msgspec_json_decoder = msgspec.json.Decoder(dec_hook=default_deserializer)
_msgspec_msgpack_decoder = msgspec.msgpack.Decoder(dec_hook=default_deserializer)
def encode_json(value: Any, serializer: Callable[[Any], Any] | None = None) -> bytes:
    """Encode a value into JSON.

    Args:
        value: Value to encode
        serializer: Optional callable to support non-natively supported types.

    Returns:
        JSON as bytes

    Raises:
        SerializationException: If error encoding ``value``.
    """
    try:
        # Reuse the shared encoder when no serializer is given or when the serializer
        # is the default hook that encoder already uses (same shortcut as
        # ``encode_msgpack``).
        if not serializer or serializer is default_serializer:
            return _msgspec_json_encoder.encode(value)
        return msgspec.json.encode(value, enc_hook=serializer)
    except (TypeError, msgspec.EncodeError) as msgspec_error:
        raise SerializationException(str(msgspec_error)) from msgspec_error
# NOTE: ``strict`` (and ``type_decoders`` when no ``target_type`` is given) is declared
# keyword-only in the overloads below. In the implementation the second and third
# positional parameters are ``target_type`` and ``type_decoders``, so a positional
# ``strict``/``type_decoders`` in these call shapes would be bound to the wrong
# parameter at runtime.
@overload
def decode_json(value: str | bytes, *, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, *, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, target_type: type[T], *, strict: bool = ...) -> T: ...


@overload
def decode_json(
    value: str | bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...
def decode_json(  # type: ignore[misc]
    value: str | bytes,
    target_type: type[T] | EmptyType = Empty,  # pyright: ignore
    type_decoders: TypeDecodersSequence | None = None,
    strict: bool = True,
) -> Any:
    """Decode a JSON string/bytes into an object.

    Args:
        value: Value to decode
        target_type: An optional type to decode the data into. When omitted, the
            shared module-level decoder is used and ``type_decoders`` and ``strict``
            are not consulted.
        type_decoders: Optional sequence of type decoders
        strict: Whether type coercion rules should be strict. Setting to False enables
            a wider set of coercion rules from string to non-string types for all values

    Returns:
        An object

    Raises:
        SerializationException: If error decoding ``value``.
    """
    if target_type is Empty:
        decode = _msgspec_json_decoder.decode
    else:
        decode = partial(
            msgspec.json.decode,
            type=target_type,
            strict=strict,
            dec_hook=partial(default_deserializer, type_decoders=type_decoders),
        )

    try:
        return decode(value)
    except msgspec.DecodeError as exc:
        raise SerializationException(str(exc)) from exc
def encode_msgpack(value: Any, serializer: Callable[[Any], Any] | None = default_serializer) -> bytes:
    """Encode a value into MessagePack.

    Args:
        value: Value to encode
        serializer: Optional callable to support non-natively supported types

    Returns:
        MessagePack as bytes

    Raises:
        SerializationException: If error encoding ``value``.
    """
    # the shared encoder already uses ``default_serializer`` as its hook
    use_shared_encoder = serializer is None or serializer is default_serializer

    try:
        if not use_shared_encoder:
            return msgspec.msgpack.encode(value, enc_hook=serializer)
        return _msgspec_msgpack_encoder.encode(value)
    except (TypeError, msgspec.EncodeError) as exc:
        raise SerializationException(str(exc)) from exc
# NOTE: ``strict`` (and ``type_decoders`` when no ``target_type`` is given) is declared
# keyword-only in the overloads below. In the implementation the second and third
# positional parameters are ``target_type`` and ``type_decoders``, so a positional
# ``strict``/``type_decoders`` in these call shapes would be bound to the wrong
# parameter at runtime.
@overload
def decode_msgpack(value: bytes, *, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, *, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, target_type: type[T], *, strict: bool = ...) -> T: ...


@overload
def decode_msgpack(
    value: bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...
def decode_msgpack(  # type: ignore[misc]
    value: bytes,
    target_type: type[T] | EmptyType = Empty,  # pyright: ignore[reportInvalidTypeVarUse]
    type_decoders: TypeDecodersSequence | None = None,
    strict: bool = True,
) -> Any:
    """Decode MessagePack bytes into an object.

    Args:
        value: Value to decode
        target_type: An optional type to decode the data into. When omitted, the
            shared module-level decoder is used and ``type_decoders`` and ``strict``
            are not consulted.
        type_decoders: Optional sequence of type decoders
        strict: Whether type coercion rules should be strict. Setting to False enables
            a wider set of coercion rules from string to non-string types for all values

    Returns:
        An object

    Raises:
        SerializationException: If error decoding ``value``.
    """
    if target_type is Empty:
        decode = _msgspec_msgpack_decoder.decode
    else:
        decode = partial(
            msgspec.msgpack.decode,
            type=target_type,
            strict=strict,
            dec_hook=partial(default_deserializer, type_decoders=type_decoders),
        )

    try:
        return decode(value)
    except msgspec.DecodeError as exc:
        raise SerializationException(str(exc)) from exc
def get_serializer(type_encoders: TypeEncodersMap | None = None) -> Serializer:
    """Get the serializer for the given type encoders.

    Args:
        type_encoders: Optional mapping of types to encoder callables

    Returns:
        ``default_serializer`` itself when ``type_encoders`` is empty or ``None``,
        otherwise ``default_serializer`` with ``type_encoders`` pre-bound.
    """
    if not type_encoders:
        return default_serializer
    return partial(default_serializer, type_encoders=type_encoders)