from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from litestar.constants import (
HTTP_RESPONSE_BODY,
HTTP_RESPONSE_START,
)
from litestar.data_extractors import (
ConnectionDataExtractor,
RequestExtractorField,
ResponseDataExtractor,
ResponseExtractorField,
)
from litestar.enums import ScopeType
from litestar.middleware.base import ASGIMiddleware
from litestar.serialization import encode_json
from litestar.utils.empty import value_or_default
from litestar.utils.scope import get_serializer_from_scope
from litestar.utils.scope.state import ScopeState
__all__ = ("LoggingMiddleware",)
if TYPE_CHECKING:
    # These names are needed for annotations only. Because the module uses
    # ``from __future__ import annotations``, they are never evaluated at
    # runtime and can stay behind the type-checking guard.
    import logging
    from collections.abc import Iterable, Sequence

    from litestar.connection import Request
    from litestar.types import (
        ASGIApp,
        Logger,
        Message,
        Receive,
        Scope,
        Send,
        Serializer,
    )
class LoggingMiddleware(ASGIMiddleware):
    """Logging middleware.

    Emits a log message for each incoming HTTP request and for each outgoing
    response body message. The logged values are produced by a
    ``ConnectionDataExtractor`` (request side) and a ``ResponseDataExtractor``
    (response side), both configured in ``__init__``.
    """

    # Only the HTTP scope type is declared here; presumably the
    # ``ASGIMiddleware`` base uses this to skip other connection types.
    scopes = (ScopeType.HTTP,)

    def __init__(
        self,
        logger: logging.Logger | Logger | str | Callable[[], Logger],
        *,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str | None = None,
        include_compressed_body: bool = False,
        request_cookies_to_obfuscate: Iterable[str] = ("session",),
        request_headers_to_obfuscate: Iterable[str] = ("Authorization", "X-API-KEY"),
        response_cookies_to_obfuscate: Iterable[str] = ("session",),
        response_headers_to_obfuscate: Iterable[str] = ("Authorization", "X-API-KEY"),
        request_log_message: str = "HTTP Request",
        response_log_message: str = "HTTP Response",
        request_log_fields: Sequence[RequestExtractorField] = (
            "path",
            "method",
            "content_type",
            "query",
            "path_params",
        ),
        response_log_fields: Sequence[ResponseExtractorField] = ("status_code",),
        parse_body: bool = False,
        parse_query: bool = True,
        log_structured: bool = False,
    ) -> None:
        """Initialize ``LoggingMiddleware``.

        Args:
            logger: The logger to use. May be a logger instance, a logger name
                (resolved with ``logging.getLogger``), or a zero-argument
                callable returning a logger.
            exclude: A pattern or list of patterns, stored as
                ``exclude_path_pattern`` (a list is converted to a tuple).
                Presumably consumed by the ``ASGIMiddleware`` base to skip
                matching paths.
            exclude_opt_key: Stored as-is on the instance; presumably an
                ``opt`` key the base class checks to skip a route.
            include_compressed_body: Stored as-is on the instance; not read by
                any method visible in this class.
            request_cookies_to_obfuscate: Request cookie names handed to the
                request extractor for obfuscation.
            request_headers_to_obfuscate: Request header names handed to the
                request extractor for obfuscation.
            response_cookies_to_obfuscate: Response cookie names handed to the
                response extractor for obfuscation.
            response_headers_to_obfuscate: Response header names handed to the
                response extractor for obfuscation.
            request_log_message: Stored as-is on the instance as the request
                log message.
            response_log_message: Stored as-is on the instance as the response
                log message.
            request_log_fields: Names of the request fields to extract. An
                empty sequence disables request logging in ``handle``.
            response_log_fields: Names of the response fields to extract. An
                empty sequence disables response logging in ``handle``.
            parse_body: Forwarded to ``ConnectionDataExtractor``.
            parse_query: Forwarded to ``ConnectionDataExtractor``.
            log_structured: If ``True``, ``log_message`` passes the values to
                the logger as keyword arguments instead of formatting them
                into the message string.
        """
        self.exclude_opt_key = exclude_opt_key
        self.exclude_path_pattern = tuple(exclude) if isinstance(exclude, list) else exclude
        self.include_compressed_body = include_compressed_body
        self.request_cookies_to_obfuscate = frozenset(request_cookies_to_obfuscate)
        self.request_headers_to_obfuscate = frozenset(request_headers_to_obfuscate)
        self.response_cookies_to_obfuscate = frozenset(response_cookies_to_obfuscate)
        self.response_headers_to_obfuscate = frozenset(response_headers_to_obfuscate)
        self.request_log_message = request_log_message
        self.response_log_message = response_log_message
        self.request_log_fields = request_log_fields
        self.response_log_fields = response_log_fields
        self.log_structured = log_structured

        # Resolve the three accepted ``logger`` forms to a logger instance.
        if isinstance(logger, str):
            # ``logging`` is only imported under TYPE_CHECKING at module
            # level, so it has to be imported here for runtime use.
            import logging

            self.logger: Logger | logging.Logger = logging.getLogger(logger)
        elif callable(logger):
            self.logger = logger()
        else:
            self.logger = logger

        # Each ``extract_*`` flag is enabled only when the corresponding field
        # name was requested in ``request_log_fields``.
        self.request_extractor = ConnectionDataExtractor(
            extract_body="body" in self.request_log_fields,
            extract_client="client" in self.request_log_fields,
            extract_content_type="content_type" in self.request_log_fields,
            extract_cookies="cookies" in self.request_log_fields,
            extract_headers="headers" in self.request_log_fields,
            extract_method="method" in self.request_log_fields,
            extract_path="path" in self.request_log_fields,
            extract_path_params="path_params" in self.request_log_fields,
            extract_query="query" in self.request_log_fields,
            extract_scheme="scheme" in self.request_log_fields,
            obfuscate_cookies=self.request_cookies_to_obfuscate,
            obfuscate_headers=self.request_headers_to_obfuscate,
            parse_body=parse_body,
            parse_query=parse_query,
            skip_parse_malformed_body=True,
        )
        self.response_extractor = ResponseDataExtractor(
            extract_body="body" in self.response_log_fields,
            extract_headers="headers" in self.response_log_fields,
            extract_status_code="status_code" in self.response_log_fields,
            obfuscate_cookies=self.response_cookies_to_obfuscate,
            obfuscate_headers=self.response_headers_to_obfuscate,
        )

    async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
        """Log the request and arrange for the response to be logged.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.
            next_app: The next ASGI application to call.

        Returns:
            None
        """
        # Wrap ``send`` first so the response is captured even though the
        # request is logged before the downstream app runs.
        if self.response_log_fields:
            send = self.create_send_wrapper(scope=scope, send=send)

        if self.request_log_fields:
            await self.log_request(scope=scope, receive=receive)

        await next_app(scope, receive, send)

    async def log_request(self, scope: Scope, receive: Receive) -> None:
        """Extract request data and log the message.

        Args:
            scope: The ASGI connection scope.
            receive: ASGI receive callable

        Returns:
            None
        """
        # NOTE(review): ``extract_request_data`` is not defined in the visible
        # part of this class; it is awaited here, so it must be a coroutine
        # method provided elsewhere.
        extracted_data = await self.extract_request_data(request=scope["litestar_app"].request_class(scope, receive))
        self.log_message(values=extracted_data)

    def log_response(self, scope: Scope) -> None:
        """Extract the response data and log the message.

        Args:
            scope: The ASGI connection scope.

        Returns:
            None
        """
        # NOTE(review): ``extract_response_data`` is not defined in the visible
        # part of this class; it is expected to be provided elsewhere.
        extracted_data = self.extract_response_data(scope=scope)
        self.log_message(values=extracted_data)

    def log_message(self, values: dict[str, Any]) -> None:
        """Log a message.

        The ``"message"`` entry is removed from ``values`` (the dict is mutated
        in place) and used as the log message; the remaining entries are the
        logged data.

        Args:
            values: Extract values to log. Must contain a ``"message"`` key.

        Returns:
            None

        Raises:
            KeyError: If ``values`` has no ``"message"`` key.
        """
        message = values.pop("message")
        if self.log_structured:
            # The logger must accept arbitrary keyword arguments here.
            self.logger.info(message, **values)
        else:
            extra_str = ", ".join(f"{k}={v}" for k, v in values.items())
            self.logger.info(f"{message}: {extra_str}")  # noqa: G004

    def _serialize_value(self, serializer: Serializer | None, value: Any) -> Any:
        """Prepare a single extracted value for logging.

        Args:
            serializer: Serializer passed through to ``encode_json``.
            value: The value to prepare.

        Returns:
            When structured logging is off, containers (``dict``, ``list``,
            ``tuple``, ``set``) are JSON-encoded first. Any ``bytes`` value is
            then decoded as UTF-8 with ``backslashreplace`` error handling;
            every other value is returned unchanged.
        """
        if not self.log_structured and isinstance(value, (dict, list, tuple, set)):
            value = encode_json(value, serializer)
        return value.decode("utf-8", errors="backslashreplace") if isinstance(value, bytes) else value

    def create_send_wrapper(self, scope: Scope, send: Send) -> Send:
        """Create a ``send`` wrapper, which handles logging response data.

        The wrapper records the response start and body messages in the
        connection state's ``log_context``, logs the response whenever a body
        message is seen, and always forwards the message to ``send``.

        Args:
            scope: The ASGI connection scope.
            send: The ASGI send function.

        Returns:
            An ASGI send function.
        """
        connection_state = ScopeState.from_scope(scope)

        async def send_wrapper(message: Message) -> None:
            if message["type"] == HTTP_RESPONSE_START:
                connection_state.log_context[HTTP_RESPONSE_START] = message
            elif message["type"] == HTTP_RESPONSE_BODY:
                connection_state.log_context[HTTP_RESPONSE_BODY] = message
                self.log_response(scope=scope)

                # Clear only after the final body chunk. This check is tied to
                # the body branch because a response-start message carries no
                # ``more_body`` key; checking it at the outer level would drop
                # the stored start message before the body is logged.
                if not message.get("more_body"):
                    connection_state.log_context.clear()

            await send(message)

        return send_wrapper