from __future__ import annotations
import dataclasses
from collections.abc import Sequence # noqa: TC003
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, TypedDict
import jwt
import msgspec
from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException
if TYPE_CHECKING:
from typing_extensions import Self
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = (
    "JWTDecodeOptions",
    "Token",
)
def _normalize_datetime(value: datetime) -> datetime:
"""Convert the given value into UTC and strip microseconds.
Args:
value: A datetime instance
Returns:
A datetime instance
"""
if value.tzinfo is not None:
value.astimezone(timezone.utc)
return value.replace(microsecond=0)
class JWTDecodeOptions(TypedDict, total=False):
    """``options`` for PyJWTs :func:`jwt.decode`

    Declared with ``total=False``, so every key is optional. ``Token.decode``
    builds a dict of this shape from its arguments and hands it to
    ``Token.decode_payload``, which forwards it to :func:`jwt.decode`.
    """

    # ``Token.decode`` sets this to ``bool(audience)``
    verify_aud: bool
    # ``Token.decode`` sets this to ``bool(issuer)``
    verify_iss: bool
    # set to ``False`` by ``Token.decode`` when called with ``verify_exp=False``
    verify_exp: bool
    # set to ``False`` by ``Token.decode`` when called with ``verify_nbf=False``
    verify_nbf: bool
    # set to ``True`` by ``Token.decode`` when called with ``strict_audience=True``
    strict_aud: bool
    # claim names; filled by ``Token.decode`` from its ``require_claims`` argument
    require: list[str]
@dataclass
class Token:
    """JWT Token DTO.

    Instances are validated on construction (see ``__post_init__``), can be
    serialized with :meth:`encode` and re-created from an encoded string with
    :meth:`decode`.
    """

    exp: datetime
    """Expiration - datetime for token expiration."""
    sub: str
    """Subject - usually a unique identifier of the user or equivalent entity."""
    iat: datetime = field(default_factory=lambda: _normalize_datetime(datetime.now(timezone.utc)))
    """Issued at - should always be current now."""
    iss: str | None = field(default=None)
    """Issuer - optional unique identifier for the issuer."""
    aud: str | Sequence[str] | None = field(default=None)
    """Audience - intended audience(s)."""
    jti: str | None = field(default=None)
    """JWT ID - a unique identifier of the JWT between different issuers."""
    extras: dict[str, Any] = field(default_factory=dict)
    """Extra fields that were found on the JWT token."""

    def __post_init__(self) -> None:
        """Validate ``sub``, ``exp`` and ``iat`` and normalize both datetimes.

        On success ``exp`` and ``iat`` are replaced by the result of passing
        them through ``_normalize_datetime``.

        Raises:
            ImproperlyConfiguredException: If ``sub`` is empty, if ``exp`` is not
                a datetime at or after the current time, or if ``iat`` is not a
                datetime at or before the current time.
        """
        if len(self.sub) < 1:
            raise ImproperlyConfiguredException("sub must be a string with a length greater than 0")

        # both sides are truncated to whole seconds before comparing, so an
        # ``exp`` that falls within the current second is still accepted
        if isinstance(self.exp, datetime) and (
            (exp := _normalize_datetime(self.exp)).timestamp()
            >= _normalize_datetime(datetime.now(timezone.utc)).timestamp()
        ):
            self.exp = exp
        else:
            raise ImproperlyConfiguredException("exp value must be a datetime in the future")

        if isinstance(self.iat, datetime) and (
            (iat := _normalize_datetime(self.iat)).timestamp()
            <= _normalize_datetime(datetime.now(timezone.utc)).timestamp()
        ):
            self.iat = iat
        else:
            raise ImproperlyConfiguredException("iat must be a current or past time")

    @classmethod
    def decode_payload(
        cls,
        encoded_token: str,
        secret: str | bytes,
        algorithms: list[str],
        issuer: list[str] | None = None,
        audience: str | Sequence[str] | None = None,
        options: JWTDecodeOptions | None = None,
    ) -> Any:
        """Decode and verify the JWT and return its payload

        Thin wrapper that forwards every argument to :func:`jwt.decode`.

        Args:
            encoded_token: The encoded JWT, passed as ``jwt``.
            secret: The key used for verification, passed as ``key``.
            algorithms: Algorithms accepted for verification.
            issuer: Accepted issuers, or ``None``.
            audience: Accepted audience(s), or ``None``.
            options: Verification options, passed through unchanged.

        Returns:
            Whatever :func:`jwt.decode` returns for the given token.
        """
        return jwt.decode(
            jwt=encoded_token,
            key=secret,
            algorithms=algorithms,
            issuer=issuer,
            audience=audience,
            options=options,  # type: ignore[arg-type]
        )

    @classmethod
    def decode(
        cls,
        encoded_token: str,
        secret: str | bytes,
        algorithm: str,
        audience: str | Sequence[str] | None = None,
        issuer: str | Sequence[str] | None = None,
        require_claims: Sequence[str] | None = None,
        verify_exp: bool = True,
        verify_nbf: bool = True,
        strict_audience: bool = False,
    ) -> Self:
        """Decode a passed in token string and return a Token instance.

        Args:
            encoded_token: A base64 string containing an encoded JWT.
            secret: The secret with which the JWT is encoded.
            algorithm: The algorithm used to encode the JWT.
            audience: Verify the audience when decoding the token. If the audience in
                the token does not match any audience given, raise a
                :exc:`NotAuthorizedException`
            issuer: Verify the issuer when decoding the token. If the issuer in the
                token does not match any issuer given, raise a
                :exc:`NotAuthorizedException`. A single string is treated as one
                issuer, a sequence as a set of accepted issuers
            require_claims: Verify that the given claims are present in the token
            verify_exp: Verify that the value of the ``exp`` (*expiration*) claim is in
                the future
            verify_nbf: Verify that the value of the ``nbf`` (*not before*) claim is in
                the past
            strict_audience: Verify that the value of the ``aud`` (*audience*) claim is
                a single value, and not a list of values, and matches ``audience``
                exactly. Requires the value passed to the ``audience`` to be a sequence
                of length 1

        Returns:
            A decoded Token instance.

        Raises:
            NotAuthorizedException: If the token is invalid.
            ValueError: If ``strict_audience`` is set and ``audience`` is ``None`` or a
                non-string sequence whose length is not 1.
        """
        options: JWTDecodeOptions = {
            "verify_aud": bool(audience),
            "verify_iss": bool(issuer),
        }

        if require_claims:
            options["require"] = list(require_claims)

        if verify_exp is False:
            options["verify_exp"] = False

        if verify_nbf is False:
            options["verify_nbf"] = False

        if strict_audience:
            if audience is None or (not isinstance(audience, str) and len(audience) != 1):
                raise ValueError("When using 'strict_audience=True', 'audience' must be a sequence of length 1")
            options["strict_aud"] = True
            # although not documented, pyjwt requires audience to be a string if
            # using the strict_aud option
            if not isinstance(audience, str):
                audience = audience[0]

        # ``decode_payload`` expects a list of issuers. A bare string must be wrapped,
        # not passed to ``list()``, which would split it into single characters and
        # make every token fail issuer verification.
        if not issuer:
            issuers = None
        elif isinstance(issuer, str):
            issuers = [issuer]
        else:
            issuers = list(issuer)

        try:
            payload = cls.decode_payload(
                encoded_token=encoded_token,
                secret=secret,
                algorithms=[algorithm],
                audience=audience,
                issuer=issuers,
                options=options,
            )
            # msgspec can do these conversions as well, but to keep backwards
            # compatibility, we do it ourselves, since the datetime parsing works a
            # little bit different there
            payload["exp"] = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
            payload["iat"] = datetime.fromtimestamp(payload["iat"], tz=timezone.utc)

            # every claim that is not a dataclass field is moved into ``extras``
            extra_fields = payload.keys() - {f.name for f in dataclasses.fields(cls)}
            extras = payload.setdefault("extras", {})
            for key in extra_fields:
                extras[key] = payload.pop(key)
            return msgspec.convert(payload, cls, strict=False)
        except (
            KeyError,
            jwt.exceptions.InvalidTokenError,
            ImproperlyConfiguredException,
            msgspec.ValidationError,
        ) as e:
            raise NotAuthorizedException("Invalid token") from e

    def encode(
        self,
        secret: str | bytes,
        algorithm: str,
        headers: dict[str, Any] | None = None,
    ) -> str:
        """Encode the token instance into a string.

        The payload is built from ``dataclasses.asdict(self)`` with all
        ``None``-valued fields left out.

        Args:
            secret: The secret with which the JWT is encoded.
            algorithm: The algorithm used to encode the JWT.
            headers: Optional headers to include in the JWT (e.g., {"kid": "..."}).

        Returns:
            An encoded token string.

        Raises:
            ImproperlyConfiguredException: If encoding fails.
        """
        try:
            return jwt.encode(
                payload={k: v for k, v in asdict(self).items() if v is not None},
                key=secret,
                algorithm=algorithm,
                headers=headers,
            )
        except (jwt.DecodeError, NotImplementedError) as e:
            raise ImproperlyConfiguredException("Failed to encode token") from e