Source code for litestar.middleware.csrf

from __future__ import annotations

import hashlib
import hmac
import secrets
from secrets import compare_digest
from typing import TYPE_CHECKING, Any

from litestar.datastructures import MutableScopeHeaders
from litestar.datastructures.cookie import Cookie
from litestar.enums import RequestEncodingType, ScopeType
from litestar.exceptions import PermissionDeniedException
from litestar.middleware._utils import (
    build_exclude_path_pattern,
    should_bypass_middleware,
)
from litestar.middleware.base import MiddlewareProtocol
from litestar.utils.scope.state import ScopeState

if TYPE_CHECKING:
    from litestar.config.csrf import CSRFConfig
    from litestar.connection import Request
    from litestar.types import (
        ASGIApp,
        HTTPSendMessage,
        Message,
        Receive,
        Scope,
        Scopes,
        Send,
    )

# Names exported by a star-import of this module: only the middleware class is public.
__all__ = (
    "CSRFMiddleware",
)

# Number of random bytes drawn for the secret part of every CSRF token.
CSRF_SECRET_BYTES = 32
# The secret is hex-encoded (two characters per byte), so this is its length in the token string.
CSRF_SECRET_LENGTH = 2 * CSRF_SECRET_BYTES


def generate_csrf_hash(token: str, secret: str) -> str:
    """Sign ``token`` with ``secret`` using HMAC-SHA256.

    Args:
        token: The string to sign. ``generate_csrf_token`` passes the random hex part of a new token.
        secret: The signing key.

    Returns:
        The lowercase hexadecimal HMAC-SHA256 digest (64 characters).
    """
    key_bytes = secret.encode()
    message_bytes = token.encode()
    signer = hmac.new(key_bytes, message_bytes, hashlib.sha256)
    return signer.hexdigest()


def generate_csrf_token(secret: str) -> str:
    """Create a new signed CSRF token.

    The token is ``CSRF_SECRET_BYTES`` random bytes rendered as hex, immediately followed by
    the HMAC of that hex string computed with ``generate_csrf_hash``.

    Args:
        secret: The key used to sign the random part.

    Returns:
        The concatenation of the random hex string and its signature.
    """
    random_part = secrets.token_hex(CSRF_SECRET_BYTES)
    signature = generate_csrf_hash(token=random_part, secret=secret)
    return f"{random_part}{signature}"


class CSRFMiddleware(MiddlewareProtocol):
    """CSRF Middleware class.

    This Middleware protects against attacks by setting a CSRF cookie with a token and verifying
    that unsafe requests echo a matching token back in a request header or form field.
    """

    scopes: Scopes = {ScopeType.HTTP}

    def __init__(self, app: ASGIApp, config: CSRFConfig) -> None:
        """Initialize ``CSRFMiddleware``.

        Args:
            app: The ``next`` ASGI app to call.
            config: The CSRFConfig instance.
        """
        self.app = app
        self.config = config
        self.exclude = build_exclude_path_pattern(exclude=config.exclude, middleware_cls=type(self))

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable.

        Non-HTTP and excluded connections are passed straight through. For requests whose method is
        in ``config.safe_methods``, a token is published on the connection state and the CSRF cookie
        is set on the response if the request did not carry a validly signed one. Any other method
        must present a token (header ``config.header_name`` or, for form bodies, the
        ``_csrf_token`` field) that matches the CSRF cookie.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None

        Raises:
            PermissionDeniedException: If an unsafe request fails CSRF verification.
        """
        if scope["type"] != ScopeType.HTTP:
            await self.app(scope, receive, send)
            return

        if should_bypass_middleware(
            scope=scope,
            scopes=self.scopes,
            exclude_opt_key=self.config.exclude_from_csrf_key,
            exclude_path_pattern=self.exclude,
        ):
            await self.app(scope, receive, send)
            return

        request: Request[Any, Any, Any] = scope["litestar_app"].request_class(scope=scope, receive=receive)
        content_type, _ = request.content_type
        csrf_cookie = request.cookies.get(self.config.cookie_name)
        existing_csrf_token = request.headers.get(self.config.header_name)

        if not existing_csrf_token and content_type in {
            RequestEncodingType.URL_ENCODED,
            RequestEncodingType.MULTI_PART,
        }:
            form = await request.form()
            form_token = form.get("_csrf_token", None)
            # A multipart field with this name may be a file upload rather than text; only a
            # string can be a token, anything else is treated as "no token supplied".
            existing_csrf_token = form_token if isinstance(form_token, str) else None

        connection_state = ScopeState.from_scope(scope)
        if request.method in self.config.safe_methods:
            # Discard a cookie that does not hold a validly signed token (tampered, empty, or
            # signed with a previous secret). Reusing it would publish a token that can never
            # verify, and the send wrapper only sets a fresh cookie when none was received.
            if csrf_cookie is not None and self._decode_csrf_token(csrf_cookie) is None:
                csrf_cookie = None
            token = connection_state.csrf_token = csrf_cookie or generate_csrf_token(secret=self.config.secret)
            await self.app(scope, receive, self.create_send_wrapper(send=send, csrf_cookie=csrf_cookie, token=token))
        elif (
            existing_csrf_token is not None
            and csrf_cookie is not None
            and self._csrf_tokens_match(existing_csrf_token, csrf_cookie)
        ):
            connection_state.csrf_token = existing_csrf_token
            await self.app(scope, receive, send)
        else:
            raise PermissionDeniedException("CSRF token verification failed")

    def create_send_wrapper(self, send: Send, token: str, csrf_cookie: str | None) -> Send:
        """Wrap ``send`` so the response sets the CSRF cookie when the request had none.

        Args:
            token: The CSRF token to store in the cookie.
            send: The ASGI send function.
            csrf_cookie: The CSRF cookie received with the request, or ``None``.

        Returns:
            An ASGI send function.
        """

        async def send_wrapper(message: Message) -> None:
            """Forward ``message``, adding a ``set-cookie`` header to the response start if needed.

            Args:
                message: An ASGI ``Message``

            Returns:
                None
            """
            if csrf_cookie is None and message["type"] == "http.response.start":
                # Make sure a headers list exists for MutableScopeHeaders to append to.
                message.setdefault("headers", [])
                self._set_cookie_if_needed(message=message, token=token)
            await send(message)

        return send_wrapper

    def _set_cookie_if_needed(self, message: HTTPSendMessage, token: str) -> None:
        """Append a ``set-cookie`` header carrying ``token`` to ``message``, using the cookie settings from the config."""
        headers = MutableScopeHeaders.from_message(message)
        cookie = Cookie(
            key=self.config.cookie_name,
            value=token,
            path=self.config.cookie_path,
            secure=self.config.cookie_secure,
            httponly=self.config.cookie_httponly,
            samesite=self.config.cookie_samesite,
            domain=self.config.cookie_domain,
        )
        headers.add("set-cookie", cookie.to_header(header=""))

    def _decode_csrf_token(self, token: str) -> str | None:
        """Validate a CSRF token's signature and return its random part.

        A token is ``CSRF_SECRET_LENGTH`` characters of random hex followed by the HMAC-SHA256 hex
        digest of that random part (as produced by ``generate_csrf_token``).

        Returns:
            The random part if the signature verifies, otherwise ``None``.
        """
        # ``compare_digest`` raises TypeError for ``str`` arguments containing non-ASCII
        # characters. Header, cookie and form values are client-controlled and may contain
        # such characters, so reject them here instead of failing with a server error.
        if not token.isascii() or len(token) < CSRF_SECRET_LENGTH + 1:
            return None
        token_secret = token[:CSRF_SECRET_LENGTH]
        existing_hash = token[CSRF_SECRET_LENGTH:]
        expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret)
        return token_secret if compare_digest(existing_hash, expected_hash) else None

    def _csrf_tokens_match(self, request_csrf_token: str, cookie_csrf_token: str) -> bool:
        """Return whether both tokens carry valid signatures and the same random part."""
        decoded_request_token = self._decode_csrf_token(request_csrf_token)
        decoded_cookie_token = self._decode_csrf_token(cookie_csrf_token)
        if decoded_request_token is None or decoded_cookie_token is None:
            return False
        return compare_digest(decoded_request_token, decoded_cookie_token)