# Source code for litestar.openapi.plugins

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import msgspec

from litestar.constants import OPENAPI_JSON_HANDLER_NAME
from litestar.enums import MediaType, OpenAPIMediaType
from litestar.handlers import get
from litestar.serialization import encode_json, get_serializer

if TYPE_CHECKING:
    from collections.abc import Sequence

    from litestar.config.csrf import CSRFConfig
    from litestar.connection import Request
    from litestar.router import Router


# Names exported by a star-import of this module.
# NOTE(review): ``JsonRenderPlugin`` is defined in this module but is not
# listed here — confirm the omission is intentional.
__all__ = (
    "OpenAPIRenderPlugin",
    "RapidocRenderPlugin",
    "RedocRenderPlugin",
    "ScalarRenderPlugin",
    "StoplightRenderPlugin",
    "SwaggerRenderPlugin",
    "YamlRenderPlugin",
)

# Location of the Litestar badge image used as the default favicon.
_favicon_url = (
    "https://cdn.jsdelivr.net/gh/litestar-org/branding@main/assets/"
    "Branding%20-%20PNG%20-%20Transparent/Badge%20-%20Blue%20and%20Yellow.png"
)
# Default ``<link>`` tag passed as the ``favicon`` argument of render plugins.
_default_favicon = "<link rel='icon' type='image/png' href='" + _favicon_url + "'>"
# Default ``<style>`` tag passed as the ``style`` argument: no body margin or padding.
_default_style = "<style>body { margin: 0; padding: 0 }</style>"


def _get_cookie_value_or_undefined(cookie_name: str) -> str:
    """Build a JavaScript expression that reads a cookie by name.

    Args:
        cookie_name: Name of the cookie to look up in ``document.cookie``.

    Returns:
        JavaScript source (terminated by ``;``) that evaluates to the text after
        the first ``=`` of the matching cookie entry, or to ``undefined`` when no
        entry starts with ``<cookie_name>=``.
    """
    # Select the "name=value" entry whose prefix matches the requested cookie.
    find_row = f"document.cookie.split('; ').find((row) => row.startsWith('{cookie_name}='))"
    # Optional chaining yields ``undefined`` when no entry matched.
    take_value = "?.split('=')[1];"
    return find_row + take_value


class OpenAPIRenderPlugin(ABC):
    """Abstract base for plugins that render the OpenAPI schema or a UI for it."""

    paths: list[str]

    def __init__(
        self,
        *,
        path: str | Sequence[str],
        media_type: MediaType | OpenAPIMediaType = MediaType.HTML,
        favicon: str = _default_favicon,
        style: str = _default_style,
    ) -> None:
        """Store the settings shared by every render plugin.

        Args:
            path: A single path, or a sequence of paths, to serve the plugin at.
            media_type: Media type for the handler.
            favicon: Html ``<link>`` tag for the favicon.
            style: Base styling of the html body.
        """
        # Normalise to a list so that ``has_path`` only deals with one shape.
        if isinstance(path, str):
            self.paths = [path]
        else:
            self.paths = list(path)

        self.media_type = media_type
        self.favicon = favicon
        self.style = style

    @staticmethod
    def render_json(request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Serialize the OpenAPI schema to JSON.

        Args:
            request: The request that triggered the render; the type encoders of
                its route handler are used to build the serializer.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The JSON-encoded schema.
        """
        serializer = get_serializer(request.route_handler.type_encoders)
        return encode_json(openapi_schema, serializer=serializer)

    @abstractmethod
    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Produce the response body for this plugin.

        Args:
            request: The request that triggered the render.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered content as bytes.
        """
        raise NotImplementedError

    @staticmethod
    def get_openapi_json_route(request: Request) -> str:
        """Resolve the URL path of the OpenAPI JSON handler.

        Args:
            request: The current request; its application performs the reverse lookup.

        Returns:
            The route for the OpenAPI JSON schema.
        """
        application = request.app
        return application.route_reverse(OPENAPI_JSON_HANDLER_NAME)

    def receive_router(self, router: Router) -> None:
        """Hook called with the router that serves the OpenAPI UI.

        Subclasses can override this to configure the router further, e.g. to
        register additional routes. The base implementation does nothing.

        Args:
            router: The router that serves the OpenAPI UI.
        """

    def has_path(self, path: str) -> bool:
        """Tell whether this plugin is served at ``path``.

        Args:
            path: The path to check.

        Returns:
            ``True`` if ``path`` is one of the plugin's paths, ``False`` otherwise.
        """
        registered = self.paths
        return path in registered
class JsonRenderPlugin(OpenAPIRenderPlugin):
    """Serve the raw OpenAPI schema as a JSON document."""

    def __init__(
        self,
        *,
        path: str | Sequence[str] = "/openapi.json",
        media_type: MediaType | OpenAPIMediaType = OpenAPIMediaType.OPENAPI_JSON,
        **kwargs: Any,
    ) -> None:
        """Configure the JSON render plugin.

        Args:
            path: Path (or paths) to serve the JSON schema at.
            media_type: Media type for the handler.
            **kwargs: Additional arguments to pass to the base class.
        """
        super().__init__(path=path, media_type=media_type, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Return the schema encoded as JSON.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered OpenAPI schema as JSON.
        """
        encoded_schema = self.render_json(request, openapi_schema)
        return encoded_schema
class YamlRenderPlugin(OpenAPIRenderPlugin):
    """Serve the raw OpenAPI schema as a YAML document."""

    def __init__(
        self,
        *,
        path: str | Sequence[str] = ("/openapi.yaml", "/openapi.yml"),
        media_type: MediaType | OpenAPIMediaType = OpenAPIMediaType.OPENAPI_YAML,
        **kwargs: Any,
    ) -> None:
        """Configure the YAML render plugin.

        Args:
            path: Path (or paths) to serve the YAML schema at.
            media_type: Media type for the handler.
            **kwargs: Additional arguments to pass to the base class.
        """
        super().__init__(path=path, media_type=media_type, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Return the schema encoded as YAML.

        Args:
            request: The request; the type encoders of its route handler are used
                when converting the schema to builtin types.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered OpenAPI schema as UTF-8 encoded YAML.
        """
        # Imported lazily so that PyYAML is only needed when YAML is actually served.
        import yaml

        # Passing the schema through msgspec.to_builtins() ensures that any examples
        # generated by polyfactory that have the UNSET value (possible if the examples
        # are being generated for a partial DTO model which makes every type a union
        # with UNSET) are stripped out.
        enc_hook = get_serializer(request.route_handler.type_encoders)
        plain_schema = msgspec.to_builtins(openapi_schema, enc_hook=enc_hook)

        yaml_text = yaml.dump(plain_schema, default_flow_style=False)
        return yaml_text.encode("utf-8")
class RapidocRenderPlugin(OpenAPIRenderPlugin):
    """Serve an HTML page that displays the OpenAPI schema with Rapidoc."""

    def __init__(
        self,
        *,
        version: str = "9.3.4",
        js_url: str | None = None,
        path: str | Sequence[str] = "/rapidoc",
        **kwargs: Any,
    ) -> None:
        """Configure the Rapidoc render plugin.

        Args:
            version: Rapidoc version to download from the CDN. Ignored when ``js_url`` is provided.
            js_url: Download url for the RapiDoc JS bundle. When omitted, the url is built from ``version``.
            path: Path (or paths) to serve the OpenAPI UI at.
            **kwargs: Additional arguments to pass to the base class.
        """
        cdn_url = f"https://unpkg.com/rapidoc@{version}/dist/rapidoc-min.js"
        self.js_url = js_url or cdn_url
        super().__init__(path=path, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Render an HTML page for Rapidoc.

        .. note:: Override this method to customize the template.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered html page, encoded as bytes.
        """

        def build_csrf_script(csrf_config: CSRFConfig) -> str:
            # An http-only cookie cannot be read from JavaScript, so no interceptor
            # script is emitted in that case.
            if not csrf_config.cookie_httponly:
                return f"""
            <script>
              window.addEventListener('DOMContentLoaded', (event) => {{
                const rapidocEl = document.getElementsByTagName("rapi-doc")[0];

                rapidocEl.addEventListener('before-try', (e) => {{
                  const csrf_token = {_get_cookie_value_or_undefined(csrf_config.cookie_name)};

                  if (csrf_token !== undefined) {{
                    e.detail.request.headers.append('{csrf_config.header_name}', csrf_token);
                  }}
                }});
              }});
            </script>"""
            return ""

        title = openapi_schema["info"]["title"]
        spec_url = self.get_openapi_json_route(request)
        csrf_config = request.app.csrf_config
        csrf_script = build_csrf_script(csrf_config) if csrf_config else ""

        head = f"""
          <head>
            <title>{title}</title>
            {self.favicon}
            <meta charset="utf-8"/>
            <meta name="viewport" content="width=device-width, initial-scale=1">
            <script src="{self.js_url}" crossorigin></script>
            {self.style}
          </head>
        """

        body = f"""
          <body>
            <rapi-doc spec-url="{spec_url}" />
            {csrf_script}
          </body>
        """

        page = f"""
        <!DOCTYPE html>
            <html>
                {head}
                {body}
            </html>
        """
        return page.encode()
class RedocRenderPlugin(OpenAPIRenderPlugin):
    """Serve an HTML page that displays the OpenAPI schema with Redoc."""

    def __init__(
        self,
        *,
        version: str = "next",
        js_url: str | None = None,
        google_fonts: bool = True,
        path: str | Sequence[str] = "/redoc",
        **kwargs: Any,
    ) -> None:
        """Configure the Redoc render plugin.

        Args:
            version: Redoc version to download from the CDN. Ignored when ``js_url`` is provided.
            js_url: Download url for the Redoc JS bundle. When omitted, the url is built from ``version``.
            google_fonts: Download google fonts via CDN. Should be set to False when not using a CDN.
            path: Path (or paths) to serve the OpenAPI UI at.
            **kwargs: Additional arguments to pass to the base class.
        """
        cdn_url = f"https://cdn.jsdelivr.net/npm/redoc@{version}/bundles/redoc.standalone.js"
        self.js_url = js_url or cdn_url
        self.google_fonts = google_fonts
        super().__init__(path=path, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Render an HTML page for Redoc.

        .. note:: Override this method to customize the template.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered html page as bytes.
        """
        title = openapi_schema["info"]["title"]

        # The head is assembled from fragments because the fonts link is optional.
        head_parts = [
            f"""
          <head>
            <title>{title}</title>
            {self.favicon}
            <meta charset="utf-8"/>
            <meta name="viewport" content="width=device-width, initial-scale=1">
            """
        ]
        if self.google_fonts:
            head_parts.append(
                """
            <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
            """
            )
        head_parts.append(
            f"""
            <script src="{self.js_url}" crossorigin></script>
            {self.style}
          </head>
        """
        )
        head = "".join(head_parts)

        # The schema is inlined as the first argument of ``Redoc.init``.
        schema_json = self.render_json(request, openapi_schema)

        return b"".join(
            (
                b"<!DOCTYPE html><html>",
                head.encode(),
                b"<body><div id='redoc-container'/><script type='text/javascript'>Redoc.init(",
                schema_json,
                b",undefined,document.getElementById('redoc-container'))</script></body>",
                b"</html>",
            )
        )
class ScalarRenderPlugin(OpenAPIRenderPlugin):
    """Plugin to render an OpenAPI schema using Scalar.

    .. versionadded:: 2.8.0
    """

    # Stylesheet used when no ``css_url`` is passed to ``__init__``.
    _default_css_url = "https://cdn.jsdelivr.net/gh/litestar-org/branding@main/assets/openapi/scalar.css"

    def __init__(
        self,
        *,
        version: str = "latest",
        js_url: str | None = None,
        css_url: str | None = None,
        path: str | Sequence[str] = "/scalar",
        options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Scalar OpenAPI UI render plugin.

        Args:
            version: Scalar version to download from the CDN. If js_url is provided, this is ignored.
            js_url: Download url for the Scalar JS bundle. If not provided, the version will be used to construct the
                url.
            css_url: Download url for the Scalar CSS bundle. If not provided, the Litestar-provided CSS will be used.
            path: Path to serve the OpenAPI UI at.
            options: Scalar configuration options. If not provided the default Scalar configuration will be used.
            **kwargs: Additional arguments to pass to the base class.
        """
        self.js_url = js_url or f"https://cdn.jsdelivr.net/npm/@scalar/api-reference@{version}"
        self.css_url = css_url or self._default_css_url
        self.options = options
        super().__init__(path=path, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Render an HTML page for Scalar.

        .. note:: Override this method to customize the template.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered html page, encoded as bytes.
        """
        head = f"""
          <head>
            <title>{openapi_schema["info"]["title"]}</title>
            {self.style}
            <meta charset="utf-8"/>
            <meta name="viewport" content="width=device-width, initial-scale=1">
            {self.favicon}
            <link rel="stylesheet" type="text/css" href="{self.css_url}">
          </head>
        """

        # The configuration script must come after the ``api-reference`` element it
        # looks up, and before the Scalar bundle is loaded.
        body = f"""
          <noscript>
            Scalar requires Javascript to function. Please enable it to browse the documentation.
          </noscript>
          <script
            id="api-reference"
            data-url="{self.get_openapi_json_route(request)}">
          </script>
          {self.render_options()}
          <script src="{self.js_url}" crossorigin></script>
        """

        return f"""
        <!DOCTYPE html>
            <html>
                {head}
                {body}
            </html>
        """.encode()

    def render_options(self) -> str:
        """Render options to Scalar configuration.

        Returns:
            A ``<script>`` element that assigns the JSON-encoded options to the
            ``data-configuration`` attribute of the ``api-reference`` element, or an
            empty string when no options are set.
        """
        if not self.options:
            return ""

        # JSON text of the configuration object; this is the value the attribute must hold.
        config_json = msgspec.json.encode(self.options).decode()

        # Encode that text a second time to obtain a quoted and escaped JavaScript
        # string literal. Pasting the raw JSON between single quotes breaks the script
        # as soon as an option contains a quote or a backslash (e.g. custom CSS).
        # "<" is additionally escaped so that a "</script>" inside an option cannot
        # terminate the surrounding script element.
        js_literal = msgspec.json.encode(config_json).decode().replace("<", "\\u003c")

        return f"""
                <script>
                  document.getElementById('api-reference').dataset.configuration = {js_literal}
                </script>
                """
class StoplightRenderPlugin(OpenAPIRenderPlugin):
    """Serve an HTML page that displays the OpenAPI schema with StopLight Elements."""

    def __init__(
        self,
        *,
        version: str = "7.7.18",
        js_url: str | None = None,
        css_url: str | None = None,
        path: str | Sequence[str] = "/elements",
        **kwargs: Any,
    ) -> None:
        """Configure the StopLight Elements render plugin.

        Args:
            version: StopLight Elements version to download from the CDN. Ignored for a bundle whose url is
                provided explicitly.
            js_url: Download url for the StopLight Elements JS bundle. When omitted, the url is built from
                ``version``.
            css_url: Download url for the StopLight Elements CSS bundle. When omitted, the url is built from
                ``version``.
            path: Path (or paths) to serve the OpenAPI UI at.
            **kwargs: Additional arguments to pass to the base class.
        """
        cdn_root = f"https://unpkg.com/@stoplight/elements@{version}"
        self.js_url = js_url or f"{cdn_root}/web-components.min.js"
        self.css_url = css_url or f"{cdn_root}/styles.min.css"
        super().__init__(path=path, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Render an HTML page for StopLight Elements.

        .. note:: Override this method to customize the template.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered html page, encoded as bytes.
        """
        title = openapi_schema["info"]["title"]
        spec_url = self.get_openapi_json_route(request)

        head = f"""
          <head>
            <title>{title}</title>
            {self.favicon}
            <meta charset="utf-8"/>
            <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
            <link rel="stylesheet" href="{self.css_url}">
            <script src="{self.js_url}" crossorigin></script>
            {self.style}
          </head>
        """

        body = f"""
          <body>
            <elements-api
                apiDescriptionUrl="{spec_url}"
                router="hash"
                layout="sidebar"
            />
          </body>
        """

        page = f"""
        <!DOCTYPE html>
            <html>
                {head}
                {body}
            </html>
        """
        return page.encode()
class SwaggerRenderPlugin(OpenAPIRenderPlugin):
    """Render an OpenAPI schema using Swagger-UI."""

    def __init__(
        self,
        version: str = "5.18.2",
        js_url: str | None = None,
        css_url: str | None = None,
        standalone_preset_js_url: str | None = None,
        init_oauth: dict[str, Any] | bytes | None = None,
        path: str | Sequence[str] = "/swagger",
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAPI UI render plugin.

        Args:
            version: SwaggerUI version to download from the CDN. If js_url is provided, this is ignored.
            js_url: Download url for the Swagger UI JS bundle. If not provided, the version will be used to construct
                the url.
            css_url: Download url for the Swagger UI CSS bundle. If not provided, the version will be used to
                construct the url.
            standalone_preset_js_url: Download url for the Swagger Standalone Preset JS bundle. If not provided, the
                version will be used to construct the url.
            init_oauth: JSON to initialize Swagger UI OAuth2 by calling the ``initOAuth`` method. Either a dictionary,
                which is JSON-encoded when the page is rendered, or ``bytes`` holding already encoded JSON, which is
                embedded as-is. Refer to the following URL for details:
                `Swagger-UI <https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/>`_.
            path: Path to serve the OpenAPI UI at.
            **kwargs: Additional arguments to pass to the base class.
        """
        self.js_url = js_url or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui-bundle.js"
        self.css_url = css_url or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui.css"
        self.standalone_preset_js_url = (
            standalone_preset_js_url
            or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui-standalone-preset.js"
        )
        # ``None`` and other empty values fall back to an empty configuration object.
        self.init_oauth = init_oauth or {}
        super().__init__(path=path, **kwargs)

    def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes:
        """Render an HTML page for Swagger-UI.

        .. note:: Override this method to customize the template.

        Args:
            request: The request.
            openapi_schema: The OpenAPI schema as a dictionary.

        Returns:
            The rendered html page as bytes.
        """

        def create_request_interceptor(csrf_config: CSRFConfig) -> bytes:
            # An http-only cookie cannot be read from JavaScript, so no interceptor
            # is emitted in that case.
            if csrf_config.cookie_httponly:
                return b""

            return f"""
                    requestInterceptor: (request) => {{
                        const csrf_token = {_get_cookie_value_or_undefined(csrf_config.cookie_name)};

                        if (csrf_token !== undefined) {{
                            request.headers['{csrf_config.header_name}'] = csrf_token;
                        }}

                        return request;
                    }},""".encode()

        # Bytes are treated as JSON that has already been encoded and are embedded
        # verbatim; running them through ``encode_json`` would serialize the raw bytes
        # as a string instead of producing the intended configuration object.
        init_oauth_json = self.init_oauth if isinstance(self.init_oauth, bytes) else encode_json(self.init_oauth)

        head = f"""
          <head>
            <title>{openapi_schema["info"]["title"]}</title>
            {self.favicon}
            <meta charset="utf-8"/>
            <meta name="viewport" content="width=device-width, initial-scale=1">
            <link href="{self.css_url}" rel="stylesheet">
            <script src="{self.js_url}" crossorigin></script>
            <script src="{self.standalone_preset_js_url}" crossorigin></script>
            {self.style}
          </head>
        """

        # The line break between ``})`` and ``ui.initOAuth(`` is required: the first
        # statement has no terminating semicolon.
        body = b"".join(
            [
                b"""
              <body>
                <div id='swagger-container'/>
                <script type='text/javascript'>
                const ui = SwaggerUIBundle({
                    spec: """,
                self.render_json(request, openapi_schema),
                b""",
                    dom_id: '#swagger-container',
                    deepLinking: true,
                    showExtensions: true,
                    showCommonExtensions: true,
                    presets: [
                        SwaggerUIBundle.presets.apis,
                        SwaggerUIBundle.SwaggerUIStandalonePreset
                    ],""",
                create_request_interceptor(request.app.csrf_config) if request.app.csrf_config else b"",
                b"""
                })
                ui.initOAuth(""",
                init_oauth_json,
                b""")
                </script>
              </body>
            """,
            ]
        )

        return b"".join([b"<!DOCTYPE html><html>", head.encode(), body, b"</html>"])

    def receive_router(self, router: Router) -> None:
        """Receive the router that serves the OpenAPI UI.

        Adds a route to serve the OAuth2 redirect page.

        Args:
            router: The router that serves the OpenAPI UI.
        """
        router.register(
            get("/oauth2-redirect.html", media_type=MediaType.HTML, sync_to_thread=False)(self.render_oauth2_redirect),
        )

    @staticmethod
    def render_oauth2_redirect() -> bytes:
        """Render an HTML oauth2-redirect.html page for Swagger-UI.

        .. note:: Override this method to customize the template.

        Returns:
            The html page as bytes.
        """
        return rb"""<!doctype html>
        <html lang="en-US">
        <head>
            <title>Swagger UI: OAuth2 Redirect</title>
        </head>
        <body>
        <script>
            'use strict';
            function run () {
                var oauth2 = window.opener.swaggerUIRedirectOauth2;
                var sentState = oauth2.state;
                var redirectUrl = oauth2.redirectUrl;
                var isValid, qp, arr;

                if (/code|token|error/.test(window.location.hash)) {
                    qp = window.location.hash.substring(1).replace('?', '&');
                } else {
                    qp = location.search.substring(1);
                }

                arr = qp.split("&");
                arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
                qp = qp ? JSON.parse('{' + arr.join() + '}',
                        function (key, value) {
                            return key === "" ? value : decodeURIComponent(value);
                        }
                ) : {};

                isValid = qp.state === sentState;

                if ((
                  oauth2.auth.schema.get("flow") === "accessCode" ||
                  oauth2.auth.schema.get("flow") === "authorizationCode" ||
                  oauth2.auth.schema.get("flow") === "authorization_code"
                ) && !oauth2.auth.code) {
                    if (!isValid) {
                        oauth2.errCb({
                            authId: oauth2.auth.name,
                            source: "auth",
                            level: "warning",
                            message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
                        });
                    }

                    if (qp.code) {
                        delete oauth2.state;
                        oauth2.auth.code = qp.code;
                        oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
                    } else {
                        let oauthErrorMsg;
                        if (qp.error) {
                            oauthErrorMsg = "["+qp.error+"]: " +
                                (qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
                                (qp.error_uri ? "More info: "+qp.error_uri : "");
                        }

                        oauth2.errCb({
                            authId: oauth2.auth.name,
                            source: "auth",
                            level: "error",
                            message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
                        });
                    }
                } else {
                    oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
                }
                window.close();
            }

            if (document.readyState !== 'loading') {
                run();
            } else {
                document.addEventListener('DOMContentLoaded', function () {
                    run();
                });
            }
        </script>
        </body>
        </html>"""