# Source code for litestar.middleware.rate_limit

from __future__ import annotations

from dataclasses import dataclass, field
from time import time
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from litestar.datastructures import MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.exceptions import TooManyRequestsException
from litestar.middleware.base import AbstractMiddleware, DefineMiddleware
from litestar.serialization import decode_json, encode_json
from litestar.utils import ensure_async_callable

__all__ = (
    "CacheObject",
    "RateLimitConfig",
    "RateLimitMiddleware",
    "get_remote_address",
)


if TYPE_CHECKING:
    from collections.abc import Awaitable

    from litestar import Litestar
    from litestar.connection import Request
    from litestar.stores.base import Store
    from litestar.types import ASGIApp, Message, Receive, Scope, Send, SyncOrAsyncUnion


# Accepted rate-limit window units.
DurationUnit = Literal["second", "minute", "hour", "day"]

# Length of each rate-limit window, expressed in seconds.
DURATION_VALUES: dict[DurationUnit, int] = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}


@dataclass
class CacheObject:
    """Representation of a cached object's metadata.

    Serialized to/from JSON by the middleware when stored in a ``Store``.
    """

    # restrict instances to exactly these two attributes (no per-instance __dict__)
    __slots__ = ("history", "reset")

    # epoch-second timestamps of requests within the current window
    history: list[int]
    # epoch second at which the current window resets
    reset: int
def get_remote_address(request: Request[Any, Any, Any]) -> str:
    """Get a client's remote address from a ``Request``.

    Args:
        request: A :class:`Request <.connection.Request>` instance.

    Returns:
        An address, uniquely identifying this client. Falls back to the IPv4
        loopback address when the connection carries no client information.
    """
    client = request.client
    if not client:
        return "127.0.0.1"
    return client.host
class RateLimitMiddleware(AbstractMiddleware):
    """Rate-limiting middleware.

    Keeps a per-identifier history of request timestamps in a
    :class:`Store <.stores.base.Store>` and raises
    :class:`TooManyRequestsException <.exceptions.TooManyRequestsException>` once
    the configured quota for the window is exhausted.
    """

    def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None:
        """Initialize ``RateLimitMiddleware``.

        Args:
            app: The ``next`` ASGI app to call.
            config: An instance of RateLimitConfig.
        """
        super().__init__(
            app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes={ScopeType.HTTP}
        )
        # RateLimitConfig.__post_init__ has already wrapped the handler with
        # ensure_async_callable, so it can always be awaited.
        self.check_throttle_handler = cast("Callable[[Request], Awaitable[bool]] | None", config.check_throttle_handler)
        self.config = config
        self.max_requests: int = config.rate_limit[1]
        self.unit: DurationUnit = config.rate_limit[0]
        self.get_identifier_for_request = config.identifier_for_request

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Raises:
            TooManyRequestsException: If the identifier's history already holds
                ``max_requests`` entries for the current window.

        Returns:
            None
        """
        app = scope["litestar_app"]
        request: Request[Any, Any, Any] = app.request_class(scope)
        store = self.config.get_store_from_app(app)
        if await self.should_check_request(request=request):
            identifier = self.get_identifier_for_request(request)
            key = f"{type(self).__name__}::{identifier}"
            route_handler = request.scope["route_handler"]
            # handlers mounted at a path prefix get a distinct key suffix
            # (presumably to separate them from regular route buckets — confirm intent)
            if getattr(route_handler, "is_mount", False):
                key += "::mount"
            cache_object = await self.retrieve_cached_history(key, store)
            # reject BEFORE recording, so blocked requests don't extend the history
            if len(cache_object.history) >= self.max_requests:
                raise TooManyRequestsException(
                    headers=self.create_response_headers(cache_object=cache_object)
                    if self.config.set_rate_limit_headers
                    else None
                )
            await self.set_cached_history(key=key, cache_object=cache_object, store=store)
            if self.config.set_rate_limit_headers:
                send = self.create_send_wrapper(send=send, cache_object=cache_object)
        await self.app(scope, receive, send)  # pyright: ignore

    def create_send_wrapper(self, send: Send, cache_object: CacheObject) -> Send:
        """Create a ``send`` function that wraps the original send to inject response headers.

        Args:
            send: The ASGI send function.
            cache_object: A :class:`CacheObject` used to compute the header values.

        Returns:
            Send wrapper callable.
        """

        async def send_wrapper(message: Message) -> None:
            """Wrap the ASGI ``Send`` callable.

            Args:
                message: An ASGI ``Message``

            Returns:
                None
            """
            if message["type"] == "http.response.start":
                # make sure a headers list exists before MutableScopeHeaders mutates it
                message.setdefault("headers", [])
                headers = MutableScopeHeaders(message)
                for key, value in self.create_response_headers(cache_object=cache_object).items():
                    headers[key] = value
            await send(message)

        return send_wrapper

    async def retrieve_cached_history(self, key: str, store: Store) -> CacheObject:
        """Retrieve a list of time stamps for the given duration unit.

        Args:
            key: Cache key.
            store: A :class:`Store <.stores.base.Store>`

        Returns:
            An :class:`CacheObject`.
        """
        duration = DURATION_VALUES[self.unit]
        now = int(time())
        cached_string = await store.get(key)
        if cached_string:
            cache_object = CacheObject(**decode_json(value=cached_string))
            if cache_object.reset <= now:
                # window has elapsed: start over with an empty history
                return CacheObject(history=[], reset=now + duration)

            # history is kept newest-first (see set_cached_history), so expired
            # timestamps accumulate at the tail and are popped off
            while cache_object.history and cache_object.history[-1] <= now - duration:
                cache_object.history.pop()
            return cache_object
        return CacheObject(history=[], reset=now + duration)

    async def set_cached_history(self, key: str, cache_object: CacheObject, store: Store) -> None:
        """Store history extended with the current timestamp in cache.

        Args:
            key: Cache key.
            cache_object: A :class:`CacheObject`.
            store: A :class:`Store <.stores.base.Store>`

        Returns:
            None
        """
        # prepend so the newest timestamp is always at index 0
        cache_object.history = [int(time()), *cache_object.history]
        await store.set(key, encode_json(cache_object), expires_in=DURATION_VALUES[self.unit])

    async def should_check_request(self, request: Request[Any, Any, Any]) -> bool:
        """Return a boolean indicating if a request should be checked for rate limiting.

        Args:
            request: A :class:`Request <.connection.Request>` instance.

        Returns:
            Boolean dictating whether the request should be checked for rate-limiting.
        """
        if self.check_throttle_handler:
            return await self.check_throttle_handler(request)
        # no handler configured: every request is subject to rate limiting
        return True

    def create_response_headers(self, cache_object: CacheObject) -> dict[str, str]:
        """Create ratelimit response headers.

        Notes:
            * see the `IETF RateLimit draft <https://datatracker.ietf.org/doc/draft-ietf-httpapi-ratelimit-headers/>`_

        Args:
            cache_object: A :class:`CacheObject`.

        Returns:
            A dict of http headers.
        """
        # clamp at zero so an over-limit history never produces a negative remainder
        remaining_requests = str(
            self.max_requests - len(cache_object.history) if len(cache_object.history) <= self.max_requests else 0
        )

        return {
            self.config.rate_limit_policy_header_key: f"{self.max_requests}; w={DURATION_VALUES[self.unit]}",
            self.config.rate_limit_limit_header_key: str(self.max_requests),
            self.config.rate_limit_remaining_header_key: remaining_requests,
            self.config.rate_limit_reset_header_key: str(cache_object.reset - int(time())),
        }
@dataclass
class RateLimitConfig:
    """Configuration for ``RateLimitMiddleware``"""

    rate_limit: tuple[DurationUnit, int]
    """A tuple containing a time unit (second, minute, hour, day) and quantity, e.g. ("day", 1) or ("minute", 5)."""
    exclude: str | list[str] | None = field(default=None)
    """A pattern or list of patterns to skip in the rate limiting middleware."""
    exclude_opt_key: str | None = field(default=None)
    """An identifier to use on routes to disable rate limiting for a particular route."""
    identifier_for_request: Callable[[Request], str] = get_remote_address
    """ A callable that receives the request and returns an identifier for which the limit should be applied.

    Defaults to :func:`~litestar.middleware.rate_limit.get_remote_address`, which returns the client's address.

    Note that :func:`~litestar.middleware.rate_limit.get_remote_address` does *NOT* honour ``X-FORWARDED-FOR``
    headers, as these cannot be trusted implicitly. If running behind a proxy, a secure way of updating the
    client's address should be implemented, such as uvicorn's
    `ProxyHeaderMiddleware <https://github.com/encode/uvicorn/blob/master/uvicorn/middleware/proxy_headers.py>`_
    or hypercorn's
    `ProxyFixMiddleware <https://hypercorn.readthedocs.io/en/latest/how_to_guides/proxy_fix.html>`_ .
    """
    check_throttle_handler: Callable[[Request[Any, Any, Any]], SyncOrAsyncUnion[bool]] | None = field(default=None)
    """Handler callable that receives the request instance, returning a boolean dictating whether or not the request
    should be checked for rate limiting.
    """
    middleware_class: type[RateLimitMiddleware] = field(default=RateLimitMiddleware)
    """The middleware class to use."""
    set_rate_limit_headers: bool = field(default=True)
    """Boolean dictating whether to set the rate limit headers on the response."""
    rate_limit_policy_header_key: str = field(default="RateLimit-Policy")
    """Key to use for the rate limit policy header."""
    rate_limit_remaining_header_key: str = field(default="RateLimit-Remaining")
    """Key to use for the rate limit remaining header."""
    rate_limit_reset_header_key: str = field(default="RateLimit-Reset")
    """Key to use for the rate limit reset header."""
    rate_limit_limit_header_key: str = field(default="RateLimit-Limit")
    """Key to use for the rate limit limit header."""
    store: str = "rate_limit"
    """Name of the :class:`Store <.stores.base.Store>` to use"""

    def __post_init__(self) -> None:
        if self.check_throttle_handler:
            # normalize a possibly-sync handler to async so the middleware can always await it
            self.check_throttle_handler = ensure_async_callable(self.check_throttle_handler)  # type: ignore[arg-type]

    @property
    def middleware(self) -> DefineMiddleware:
        """Use this property to insert the config into a middleware list on one of the application layers.

        Examples:
            .. code-block:: python

                from litestar import Litestar, Request, get
                from litestar.middleware.rate_limit import RateLimitConfig

                # limit to 10 requests per minute, excluding the schema path
                throttle_config = RateLimitConfig(rate_limit=("minute", 10), exclude=["/schema"])


                @get("/")
                def my_handler(request: Request) -> None: ...


                app = Litestar(route_handlers=[my_handler], middleware=[throttle_config.middleware])

        Returns:
            An instance of :class:`DefineMiddleware <.middleware.base.DefineMiddleware>` including ``self`` as the
            config kwarg value.
        """
        return DefineMiddleware(self.middleware_class, config=self)

    def get_store_from_app(self, app: Litestar) -> Store:
        """Get the store defined in :attr:`store` from an :class:`Litestar <.app.Litestar>` instance."""
        return app.stores.get(self.store)