# Source code for litestar.data_extractors

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, cast

from litestar._parsers import parse_cookie_string
from litestar.connection.request import Request
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, RequestEncodingType

# Public API of this module, i.e. the names bound by ``from litestar.data_extractors import *``.
__all__ = (
    "ConnectionDataExtractor", "ExtractedRequestData", "ExtractedResponseData",
    "RequestExtractorField", "ResponseDataExtractor", "ResponseExtractorField",
)


if TYPE_CHECKING:
    from collections.abc import Coroutine, Iterable

    from litestar.connection import ASGIConnection
    from litestar.types import Method
    from litestar.types.asgi_types import HTTPResponseBodyEvent, HTTPResponseStartEvent


def _obfuscate(values: dict[str, Any], fields_to_obfuscate: set[str]) -> dict[str, Any]:
    """Obfuscate values in a dictionary, replacing values with `******`

    Args:
        values: A dictionary of strings
        fields_to_obfuscate: keys to obfuscate

    Returns:
        A dictionary with obfuscated strings
    """
    return {key: "*****" if key.lower() in fields_to_obfuscate else value for key, value in values.items()}


# Field names that ``ConnectionDataExtractor`` knows how to extract from a connection/request.
RequestExtractorField = Literal[
    "path",
    "method",
    "content_type",
    "headers",
    "cookies",
    "query",
    "path_params",
    "body",
    "scheme",
    "client",
]

# Field names that ``ResponseDataExtractor`` knows how to extract from a response's ASGI messages.
ResponseExtractorField = Literal[
    "status_code",
    "headers",
    "body",
    "cookies",
]


class ExtractedRequestData(TypedDict, total=False):
    """Dictionary representing extracted request data.

    Declared with ``total=False``: a key is present only if the corresponding extractor ran.
    """

    # When produced by ``ConnectionDataExtractor.__call__`` this is a not-yet-awaited coroutine.
    body: Coroutine[Any, Any, Any]
    # (host, port) pair.
    client: tuple[str, int]
    # (media type, options) pair.
    content_type: tuple[str, dict[str, str]]
    cookies: dict[str, str]
    headers: dict[str, str]
    method: Method
    path: str
    path_params: dict[str, Any]
    # Raw query byte-string, or a dict when query parsing is enabled.
    query: bytes | dict[str, Any]
    scheme: str
class ConnectionDataExtractor:
    """Utility class to extract data from an :class:`ASGIConnection <litestar.connection.ASGIConnection>`,
    :class:`Request <litestar.connection.Request>` or :class:`WebSocket <litestar.connection.WebSocket>` instance.
    """

    __slots__ = (
        "connection_extractors",
        "obfuscate_cookies",
        "obfuscate_headers",
        "parse_body",
        "parse_query",
        "request_extractors",
        "skip_parse_malformed_body",
    )

    def __init__(
        self,
        extract_body: bool = True,
        extract_client: bool = True,
        extract_content_type: bool = True,
        extract_cookies: bool = True,
        extract_headers: bool = True,
        extract_method: bool = True,
        extract_path: bool = True,
        extract_path_params: bool = True,
        extract_query: bool = True,
        extract_scheme: bool = True,
        obfuscate_cookies: set[str] | None = None,
        obfuscate_headers: set[str] | None = None,
        parse_body: bool = False,
        parse_query: bool = False,
        skip_parse_malformed_body: bool = False,
    ) -> None:
        """Initialize ``ConnectionDataExtractor``

        Args:
            extract_body: Whether to extract body, (for requests only).
            extract_client: Whether to extract the client (host, port) mapping.
            extract_content_type: Whether to extract the content type and any options, (for requests only).
            extract_cookies: Whether to extract cookies.
            extract_headers: Whether to extract headers.
            extract_method: Whether to extract the HTTP method, (for requests only).
            extract_path: Whether to extract the path.
            extract_path_params: Whether to extract path parameters.
            extract_query: Whether to extract query parameters.
            extract_scheme: Whether to extract the http scheme.
            obfuscate_cookies: cookie keys to obfuscate (case-insensitive). Obfuscated values are replaced
                with '*****'.
            obfuscate_headers: headers keys to obfuscate (case-insensitive). Obfuscated values are replaced
                with '*****'.
            parse_body: Whether to parse the body value or return the raw byte string, (for requests only).
            parse_query: Whether to parse query parameters or return the raw byte string.
            skip_parse_malformed_body: Whether to fall back to the raw body if parsing it fails, instead of
                re-raising the parsing error.
        """
        self.parse_body = parse_body
        self.parse_query = parse_query
        self.skip_parse_malformed_body = skip_parse_malformed_body
        # Stored lower-cased; ``_obfuscate`` lower-cases each candidate key before the lookup.
        self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
        self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
        # Extractors usable on any connection vs. extractors that need a ``Request``.
        self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {}
        self.request_extractors: dict[RequestExtractorField, Callable[[Request[Any, Any, Any]], Any]] = {}
        if extract_scheme:
            self.connection_extractors["scheme"] = self.extract_scheme
        if extract_client:
            self.connection_extractors["client"] = self.extract_client
        if extract_path:
            self.connection_extractors["path"] = self.extract_path
        if extract_headers:
            self.connection_extractors["headers"] = self.extract_headers
        if extract_cookies:
            self.connection_extractors["cookies"] = self.extract_cookies
        if extract_query:
            self.connection_extractors["query"] = self.extract_query
        if extract_path_params:
            self.connection_extractors["path_params"] = self.extract_path_params
        if extract_method:
            self.request_extractors["method"] = self.extract_method
        if extract_content_type:
            self.request_extractors["content_type"] = self.extract_content_type
        if extract_body:
            self.request_extractors["body"] = self.extract_body

    def _get_extractors(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, Callable[[Any], Any]]:
        """Return the extractors applicable to ``connection``; request-only ones are added for a ``Request``."""
        if isinstance(connection, Request):
            return {**self.connection_extractors, **self.request_extractors}  # type: ignore[misc]
        return self.connection_extractors

    def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedRequestData:
        """Extract data from the connection, returning a dictionary of values.

        Notes:
            - The value for ``body`` - if present - is an unresolved Coroutine and as such should be awaited by the receiver.

        Args:
            connection: An ASGI connection or its subclasses.

        Returns:
            A string keyed dictionary of extracted values.
        """
        extractors = self._get_extractors(connection)
        return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()})

    async def extract(
        self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str]
    ) -> ExtractedRequestData:
        """Extract only the requested ``fields`` from the connection.

        Unlike :meth:`__call__`, asynchronous extractors (such as :meth:`extract_body`) are awaited, so the
        returned values are fully resolved.

        Args:
            connection: An ASGI connection or its subclasses.
            fields: Names of the fields to extract. Names with no enabled extractor for this connection are
                ignored.

        Returns:
            A string keyed dictionary of extracted values, in extractor registration order.
        """
        # Materialize once: ``fields`` may be a one-shot iterator (e.g. a generator) that repeated ``in``
        # checks would silently exhaust; a set also makes each membership test O(1).
        wanted = set(fields)
        data: dict[str, Any] = {}
        for key, extractor in self._get_extractors(connection).items():
            if key not in wanted:
                continue
            if inspect.iscoroutinefunction(extractor):
                value = await extractor(connection)
            else:
                value = extractor(connection)
            data[key] = value
        return cast("ExtractedRequestData", data)

    @staticmethod
    def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
        """Extract the scheme from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            The connection's scope["scheme"] value
        """
        return connection.scope["scheme"]

    @staticmethod
    def extract_client(connection: ASGIConnection[Any, Any, Any, Any]) -> tuple[str, int]:
        """Extract the client from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            The connection's scope["client"] value, or ``("", 0)`` if it is missing or empty.
        """
        return connection.scope.get("client") or ("", 0)

    @staticmethod
    def extract_path(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
        """Extract the path from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            The connection's scope["path"] value
        """
        return connection.scope["path"]

    def extract_headers(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]:
        """Extract headers from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            A dictionary with the connection's headers decoded as latin-1. If a header name repeats, the
            last value wins. Names listed in ``obfuscate_headers`` have their values masked.
        """
        headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in connection.scope["headers"]}
        return _obfuscate(headers, self.obfuscate_headers) if self.obfuscate_headers else headers

    def extract_cookies(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]:
        """Extract cookies from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            A dictionary with the connection's cookies. Names listed in ``obfuscate_cookies`` have their
            values masked.
        """
        return _obfuscate(connection.cookies, self.obfuscate_cookies) if self.obfuscate_cookies else connection.cookies

    def extract_query(self, connection: ASGIConnection[Any, Any, Any, Any]) -> Any:
        """Extract query from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            ``connection.query_params.dict()`` when ``parse_query`` is enabled, otherwise the raw
            scope["query_string"] byte-string (``b""`` if absent).
        """
        return connection.query_params.dict() if self.parse_query else connection.scope.get("query_string", b"")

    @staticmethod
    def extract_path_params(connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, Any]:
        """Extract the path parameters from an ``ASGIConnection``

        Args:
            connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.

        Returns:
            A dictionary with the connection's path parameters.
        """
        return connection.path_params

    @staticmethod
    def extract_method(request: Request[Any, Any, Any]) -> Method:
        """Extract the method from a ``Request``

        Args:
            request: A :class:`Request <litestar.connection.Request>` instance.

        Returns:
            The request's scope["method"] value.
        """
        return request.scope["method"]

    @staticmethod
    def extract_content_type(request: Request[Any, Any, Any]) -> tuple[str, dict[str, str]]:
        """Extract the content-type from a ``Request``

        Args:
            request: A :class:`Request <litestar.connection.Request>` instance.

        Returns:
            A tuple containing the request's parsed 'Content-Type' header.
        """
        return request.content_type

    async def extract_body(self, request: Request[Any, Any, Any]) -> Any:
        """Extract the body from a ``Request``

        Args:
            request: A :class:`Request <litestar.connection.Request>` instance.

        Returns:
            ``None`` for GET requests. Otherwise, if ``parse_body`` is disabled, the raw body bytes. If it is
            enabled: the decoded JSON for JSON requests, a ``dict`` of the form data for url-encoded requests,
            and for any other content type a ``dict`` built from the form's items with ``UploadFile`` values
            replaced by their ``repr``.

        Raises:
            Exception: Any error raised while parsing the body, unless ``skip_parse_malformed_body`` is
                enabled, in which case the raw body bytes are returned instead.
        """
        if request.method == HttpMethod.GET:
            return None
        if not self.parse_body:
            return await request.body()
        try:
            request_encoding_type = request.content_type[0]
            if request_encoding_type == RequestEncodingType.JSON:
                return await request.json()
            form_data = await request.form()
            if request_encoding_type == RequestEncodingType.URL_ENCODED:
                return dict(form_data)
            # Upload files are not serialisable log data; keep only their repr.
            return {
                key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()
            }
        except Exception:
            if self.skip_parse_malformed_body:
                return await request.body()
            raise
class ExtractedResponseData(TypedDict, total=False):
    """Dictionary representing extracted response data.

    Declared with ``total=False``: a key is present only if the corresponding extractor is enabled.
    """

    body: bytes  # raw response body
    status_code: int  # HTTP status of the response
    headers: dict[str, str]  # response headers
    cookies: dict[str, str]  # cookies set by the response
class ResponseDataExtractor:
    """Utility class to extract data from a ``Message``"""

    # NOTE(review): "parse_headers" is not referenced anywhere in this class; kept so existing code that
    # sets it keeps working.
    __slots__ = ("extractors", "obfuscate_cookies", "obfuscate_headers", "parse_headers")

    def __init__(
        self,
        extract_body: bool = True,
        extract_cookies: bool = True,
        extract_headers: bool = True,
        extract_status_code: bool = True,
        obfuscate_cookies: set[str] | None = None,
        obfuscate_headers: set[str] | None = None,
    ) -> None:
        """Initialize ``ResponseDataExtractor`` with options.

        Args:
            extract_body: Whether to extract the body.
            extract_cookies: Whether to extract the cookies.
            extract_headers: Whether to extract the headers.
            extract_status_code: Whether to extract the status code.
            obfuscate_cookies: cookie keys to obfuscate (case-insensitive). Obfuscated values are replaced
                with '*****'.
            obfuscate_headers: headers keys to obfuscate (case-insensitive). Obfuscated values are replaced
                with '*****'.
        """
        # Stored lower-cased; ``_obfuscate`` lower-cases each candidate key before the lookup.
        self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
        self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
        self.extractors: dict[
            ResponseExtractorField, Callable[[tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]], Any]
        ] = {}
        if extract_body:
            self.extractors["body"] = self.extract_response_body
        if extract_status_code:
            self.extractors["status_code"] = self.extract_status_code
        if extract_headers:
            self.extractors["headers"] = self.extract_headers
        if extract_cookies:
            self.extractors["cookies"] = self.extract_cookies

    def __call__(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> ExtractedResponseData:
        """Extract data from the response, returning a dictionary of values.

        Args:
            messages: A tuple containing
                :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
                and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.

        Returns:
            A string keyed dictionary of extracted values.
        """
        return cast("ExtractedResponseData", {key: extractor(messages) for key, extractor in self.extractors.items()})

    @staticmethod
    def extract_response_body(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> bytes:
        """Extract the response body from a ``Message``

        Args:
            messages: A tuple containing
                :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
                and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.

        Returns:
            The Response's body as a byte-string (``b""`` if the body message has no ``body`` key).
        """
        # The ASGI spec makes ``body`` optional on ``http.response.body``, defaulting to b"".
        return messages[1].get("body", b"")

    @staticmethod
    def extract_status_code(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> int:
        """Extract a status code from a ``Message``

        Args:
            messages: A tuple containing
                :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
                and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.

        Returns:
            The Response's status-code.
        """
        return messages[0]["status"]

    def extract_headers(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]:
        """Extract headers from a ``Message``

        Args:
            messages: A tuple containing
                :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
                and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.

        Returns:
            The Response's headers dict, decoded as latin-1. ``Set-Cookie`` headers are excluded (they are
            reported by :meth:`extract_cookies`), and names in ``obfuscate_headers`` have their values masked.
        """
        # ``headers`` is optional on ``http.response.start`` per the ASGI spec (defaults to empty).
        headers = {
            key.decode("latin-1"): value.decode("latin-1")
            for key, value in messages[0].get("headers", ())
            if key.lower() != b"set-cookie"
        }
        return _obfuscate(headers, self.obfuscate_headers) if self.obfuscate_headers else headers

    def extract_cookies(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]:
        """Extract cookies from a ``Message``

        Args:
            messages: A tuple containing
                :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
                and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.

        Returns:
            The Response's cookies dict (empty when there are no ``Set-Cookie`` headers). Names in
            ``obfuscate_cookies`` have their values masked.
        """
        # NOTE(review): Set-Cookie values may carry attributes (``Path=/``, ``HttpOnly`` ...); joining them with
        # ';' hands those attributes to ``parse_cookie_string`` alongside the cookie pairs — confirm intended.
        cookie_string = ";".join(
            value.decode("latin-1")
            for key, value in messages[0].get("headers", ())
            if key.lower() == b"set-cookie"
        )
        if not cookie_string:
            return {}
        parsed_cookies = parse_cookie_string(cookie_string)
        return _obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies