from __future__ import annotations
import hashlib
import hmac
import secrets
from secrets import compare_digest
from typing import TYPE_CHECKING, Any
from litestar.datastructures import MutableScopeHeaders
from litestar.datastructures.cookie import Cookie
from litestar.enums import RequestEncodingType, ScopeType
from litestar.exceptions import PermissionDeniedException
from litestar.middleware._utils import (
build_exclude_path_pattern,
should_bypass_middleware,
)
from litestar.middleware.base import MiddlewareProtocol
from litestar.utils.scope.state import ScopeState
if TYPE_CHECKING:
from litestar.config.csrf import CSRFConfig
from litestar.connection import Request
from litestar.types import (
ASGIApp,
HTTPSendMessage,
Message,
Receive,
Scope,
Scopes,
Send,
)
__all__ = ("CSRFMiddleware",)

# Count of random bytes drawn for the unsigned part of every CSRF token.
CSRF_SECRET_BYTES = 32
# ``secrets.token_hex`` renders each byte as two hex characters, so this is the
# length of the random part that prefixes the HMAC hex digest in a full token.
CSRF_SECRET_LENGTH = 2 * CSRF_SECRET_BYTES
def generate_csrf_hash(token: str, secret: str) -> str:
    """Sign ``token`` with an HMAC-SHA256 keyed by ``secret``.

    Args:
        token: The value to sign (the random part of a CSRF token).
        secret: The signing key.

    Returns:
        The hex-encoded HMAC-SHA256 digest of ``token``.
    """
    key = secret.encode()
    signer = hmac.new(key, msg=token.encode(), digestmod=hashlib.sha256)
    return signer.hexdigest()
def generate_csrf_token(secret: str) -> str:
    """Create a fresh, signed CSRF token.

    The token is a random hex string of ``CSRF_SECRET_BYTES`` bytes immediately
    followed by the HMAC of that hex string, keyed with ``secret``.

    Args:
        secret: The signing key.

    Returns:
        A new CSRF token (random part + HMAC hex digest).
    """
    random_part = secrets.token_hex(CSRF_SECRET_BYTES)
    signature = generate_csrf_hash(token=random_part, secret=secret)
    return f"{random_part}{signature}"
class CSRFMiddleware(MiddlewareProtocol):
    """CSRF Middleware class.

    This Middleware protects against attacks by setting a CSRF cookie with a token and verifying it in request headers.

    Requests whose method is in ``config.safe_methods`` are passed through; if they carried no CSRF cookie, one is
    added to the response. Every other request must present a token (in the ``config.header_name`` header, or in a
    ``_csrf_token`` field of a URL-encoded / multipart form body) that is validly signed with ``config.secret`` and
    matches the token stored in the CSRF cookie. Otherwise ``PermissionDeniedException`` is raised.
    """

    scopes: Scopes = {ScopeType.HTTP}

    def __init__(self, app: ASGIApp, config: CSRFConfig) -> None:
        """Initialize ``CSRFMiddleware``.

        Args:
            app: The ``next`` ASGI app to call.
            config: The CSRFConfig instance.
        """
        self.app = app
        self.config = config
        self.exclude = build_exclude_path_pattern(exclude=config.exclude, middleware_cls=type(self))

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None

        Raises:
            PermissionDeniedException: If an unsafe request's token is missing, invalid, or does not match the
                CSRF cookie.
        """
        if scope["type"] != ScopeType.HTTP:
            await self.app(scope, receive, send)
            return

        if should_bypass_middleware(
            scope=scope,
            scopes=self.scopes,
            exclude_opt_key=self.config.exclude_from_csrf_key,
            exclude_path_pattern=self.exclude,
        ):
            await self.app(scope, receive, send)
            return

        request: Request[Any, Any, Any] = scope["litestar_app"].request_class(scope=scope, receive=receive)
        csrf_cookie = request.cookies.get(self.config.cookie_name)
        connection_state = ScopeState.from_scope(scope)

        if request.method in self.config.safe_methods:
            # Reuse the cookie's token when present, otherwise mint a new one. The send wrapper only emits a
            # ``set-cookie`` header when the request did not already carry the cookie.
            token = connection_state.csrf_token = csrf_cookie or generate_csrf_token(secret=self.config.secret)
            await self.app(scope, receive, self.create_send_wrapper(send=send, csrf_cookie=csrf_cookie, token=token))
            return

        # Only unsafe methods need the submitted token, so the form body is parsed only on this path.
        existing_csrf_token = await self._get_request_csrf_token(request)
        if (
            existing_csrf_token is not None
            and csrf_cookie is not None
            and self._csrf_tokens_match(existing_csrf_token, csrf_cookie)
        ):
            connection_state.csrf_token = existing_csrf_token
            await self.app(scope, receive, send)
            return

        raise PermissionDeniedException("CSRF token verification failed")

    async def _get_request_csrf_token(self, request: Request[Any, Any, Any]) -> str | None:
        """Return the CSRF token submitted with ``request``, or ``None`` when no usable token was found.

        The ``config.header_name`` header takes precedence. When it is absent or empty and the body is URL-encoded
        or multipart form data, the ``_csrf_token`` form field is used instead. A non-``str`` form value (for
        example an uploaded file, if the form layer returns one) is treated as missing.
        """
        token = request.headers.get(self.config.header_name)
        if token:
            return token
        content_type, _ = request.content_type
        if content_type in {RequestEncodingType.URL_ENCODED, RequestEncodingType.MULTI_PART}:
            form = await request.form()
            form_token = form.get("_csrf_token", None)
            return form_token if isinstance(form_token, str) else None
        # No form body to fall back on: return the header value as-is (``None`` or an empty string, which
        # later fails verification).
        return token

    def create_send_wrapper(self, send: Send, token: str, csrf_cookie: str | None) -> Send:
        """Wrap ``send`` so the CSRF cookie is set on the response when the request did not carry one.

        Args:
            token: The CSRF token to store in the cookie.
            send: The ASGI send function.
            csrf_cookie: The CSRF cookie value sent with the request, or ``None`` if it was absent.

        Returns:
            An ASGI send function.
        """

        async def send_wrapper(message: Message) -> None:
            """Send function that wraps the original send to inject a cookie.

            Args:
                message: An ASGI ``Message``

            Returns:
                None
            """
            # Headers can only be added to the response-start message.
            if csrf_cookie is None and message["type"] == "http.response.start":
                message.setdefault("headers", [])
                self._set_cookie_if_needed(message=message, token=token)
            await send(message)

        return send_wrapper

    def _set_cookie_if_needed(self, message: HTTPSendMessage, token: str) -> None:
        """Append a ``set-cookie`` header carrying ``token`` to ``message``.

        Cookie attributes (name, path, secure, httponly, samesite, domain) come from ``self.config``. This method
        always adds the header; the caller decides whether a cookie is needed.
        """
        headers = MutableScopeHeaders.from_message(message)
        cookie = Cookie(
            key=self.config.cookie_name,
            value=token,
            path=self.config.cookie_path,
            secure=self.config.cookie_secure,
            httponly=self.config.cookie_httponly,
            samesite=self.config.cookie_samesite,
            domain=self.config.cookie_domain,
        )
        headers.add("set-cookie", cookie.to_header(header=""))

    def _decode_csrf_token(self, token: str) -> str | None:
        """Split a CSRF token into its random part and HMAC, and validate the HMAC.

        Args:
            token: A token in the ``<hex random part><hex HMAC>`` format produced by ``generate_csrf_token``.

        Returns:
            The random part if the HMAC is valid, otherwise ``None``.
        """
        # Tokens come from client-controlled headers, cookies and form fields. ``compare_digest`` raises
        # ``TypeError`` for ``str`` arguments containing non-ASCII characters, and a genuine token is pure hex,
        # so such input is rejected here instead of surfacing as a server error.
        if not isinstance(token, str) or not token.isascii() or len(token) < CSRF_SECRET_LENGTH + 1:
            return None
        token_secret = token[:CSRF_SECRET_LENGTH]
        existing_hash = token[CSRF_SECRET_LENGTH:]
        expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret)
        return token_secret if compare_digest(existing_hash, expected_hash) else None

    def _csrf_tokens_match(self, request_csrf_token: str, cookie_csrf_token: str) -> bool:
        """Take the CSRF tokens from the request and the cookie and verify both are valid and identical."""
        decoded_request_token = self._decode_csrf_token(request_csrf_token)
        decoded_cookie_token = self._decode_csrf_token(cookie_csrf_token)
        if decoded_request_token is None or decoded_cookie_token is None:
            return False
        return compare_digest(decoded_request_token, decoded_cookie_token)