Skip to content

Commit 345a87a

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. Example config: ```yaml authentication: module: jwk-token jwk_config: url: https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/certs jwt_configuration: user_id_claim: user_id username_claim: username ```
1 parent e49dcd5 commit 345a87a

File tree

9 files changed

+803
-3
lines changed

9 files changed

+803
-3
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ dependencies = [
1515
"cachetools>=6.1.0",
1616
"prometheus-client>=0.22.1",
1717
"starlette>=0.47.1",
18+
"aiohttp>=3.12.14",
19+
"authlib>=1.6.0",
1820
]
1921

2022
[tool.pyright]
@@ -53,6 +55,7 @@ dev = [
5355
"ruff>=0.11.13",
5456
"aiosqlite",
5557
"behave>=1.2.6",
58+
"types-cachetools>=6.1.0.20250717",
5659
]
5760
build = [
5861
"build>=1.2.2.post1",

src/auth/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,7 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
module = configuration.authentication_configuration.module
1919

2020
logger.debug(
2121
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +32,11 @@ def get_auth_dependency(
3232
)
3333
case constants.AUTH_MOD_K8S:
3434
return k8s.K8SAuthDependency(virtual_path=virtual_path)
35+
case constants.AUTH_MOD_JWK_TOKEN:
36+
return jwk_token.JwkTokenAuthDependency(
37+
configuration.authentication_configuration.jwk_configuration,
38+
virtual_path=virtual_path,
39+
)
3540
case _:
3641
err_msg = f"Unsupported authentication module '{module}'"
3742
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
from typing import Any
6+
7+
from fastapi import Request, HTTPException, status
8+
from authlib.jose import JsonWebKey, KeySet, jwt, Key
9+
from authlib.jose.errors import (
10+
BadSignatureError,
11+
DecodeError,
12+
ExpiredTokenError,
13+
JoseError,
14+
)
15+
from cachetools import TTLCache
16+
import aiohttp
17+
18+
from constants import (
19+
DEFAULT_VIRTUAL_PATH,
20+
)
21+
from auth.interface import AuthInterface
22+
from auth.utils import extract_user_token
23+
from models.config import JwkConfiguration
24+
25+
logger = logging.getLogger(__name__)
26+
27+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
28+
# hour, keys are unlikely to change frequently.
29+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
30+
# Ideally this would be an RWLock, but it would require adding a dependency on
31+
# aiorwlock
32+
_jwk_cache_lock = Lock()
33+
34+
35+
async def get_jwk_set(url: str) -> KeySet:
36+
"""Fetch the JWK set from the cache, or fetch it from the URL if not cached."""
37+
async with _jwk_cache_lock:
38+
if url not in _jwk_cache:
39+
async with aiohttp.ClientSession() as session:
40+
# TODO: handle connection errors, timeouts, etc.
41+
async with session.get(url) as resp:
42+
resp.raise_for_status()
43+
_jwk_cache[url] = JsonWebKey.import_key_set(await resp.json())
44+
return _jwk_cache[url]
45+
46+
47+
class KeyNotFoundError(Exception):
48+
"""Exception raised when a key is not found in the JWK set based on kid/alg."""
49+
50+
51+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
52+
"""JWK AuthDependency class for JWK-based JWT authentication."""
53+
54+
def __init__(
55+
self, config: JwkConfiguration, virtual_path: str = DEFAULT_VIRTUAL_PATH
56+
) -> None:
57+
"""Initialize the required allowed paths for authorization checks."""
58+
self.virtual_path: str = virtual_path
59+
self.config: JwkConfiguration = config
60+
61+
async def __call__(self, request: Request) -> tuple[str, str, str]:
62+
"""Authenticate the JWT in the headers against the keys from the JWK url."""
63+
user_token = extract_user_token(request.headers)
64+
65+
jwk_set = await get_jwk_set(str(self.config.url))
66+
67+
def resolve_key(header: dict[str, Any], _payload: dict[str, Any]) -> Key:
68+
"""Match kid and alg from the JWT header to the JWK set.
69+
70+
Resolve the key from the JWK set based on the JWT header. Also
71+
match the algorithm to make sure the algorithm stated by the user
72+
is the same algorithm the key itself expects.
73+
"""
74+
key = None
75+
if "kid" not in header:
76+
# Token has no kid - we will return the first key (which should
77+
# work well for a single key set), otherwise hopefully it will match
78+
# the token, but if not, unlucky - we're not going to
79+
# brute-force all keys until we find the one that matches, that
80+
# makes us more vulnerable to DoS
81+
key = list(jwk_set.keys)[0]
82+
else:
83+
try:
84+
key = jwk_set.find_by_kid(header["kid"])
85+
except ValueError as exc:
86+
raise KeyNotFoundError from exc
87+
88+
if key["kid"] != header["kid"]:
89+
# find_by_kid sometimes returns a key that does not match the kid
90+
raise KeyNotFoundError
91+
92+
if header["alg"] != key["alg"]:
93+
raise KeyNotFoundError
94+
95+
return key
96+
97+
try:
98+
claims = jwt.decode(user_token, key=resolve_key)
99+
except KeyNotFoundError as exc:
100+
raise HTTPException(
101+
status_code=status.HTTP_401_UNAUTHORIZED,
102+
detail="Invalid token: signed by unknown key or algorithm mismatch",
103+
) from exc
104+
except BadSignatureError as exc:
105+
raise HTTPException(
106+
status_code=status.HTTP_401_UNAUTHORIZED,
107+
detail="Invalid token: bad signature",
108+
) from exc
109+
except DecodeError as exc:
110+
raise HTTPException(
111+
status_code=status.HTTP_400_BAD_REQUEST,
112+
detail="Invalid token: decode error",
113+
) from exc
114+
except JoseError as exc:
115+
raise HTTPException(
116+
status_code=status.HTTP_400_BAD_REQUEST,
117+
detail="Invalid token: unknown error",
118+
) from exc
119+
except Exception as exc:
120+
raise HTTPException(
121+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
122+
detail="Internal server error",
123+
) from exc
124+
125+
try:
126+
claims.validate()
127+
except ExpiredTokenError as exc:
128+
raise HTTPException(
129+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
130+
) from exc
131+
except JoseError as exc:
132+
raise HTTPException(
133+
status_code=status.HTTP_401_UNAUTHORIZED,
134+
detail="Error validating token",
135+
) from exc
136+
except Exception as exc:
137+
raise HTTPException(
138+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
139+
detail="Internal server error during token validation",
140+
) from exc
141+
142+
try:
143+
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
144+
except KeyError as exc:
145+
raise HTTPException(
146+
status_code=status.HTTP_401_UNAUTHORIZED,
147+
detail=f"Token missing claim: {self.config.jwt_configuration.user_id_claim}",
148+
) from exc
149+
150+
try:
151+
username: str = claims[self.config.jwt_configuration.username_claim]
152+
except KeyError as exc:
153+
raise HTTPException(
154+
status_code=status.HTTP_401_UNAUTHORIZED,
155+
detail=f"Token missing claim: {self.config.jwt_configuration.username_claim}",
156+
) from exc
157+
158+
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
159+
160+
return user_id, username, user_token

src/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ def mcp_servers(self) -> list[ModelContextProtocolServer]:
8484
return self._configuration.mcp_servers
8585

8686
@property
87-
def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
87+
def authentication_configuration(self) -> AuthenticationConfiguration:
8888
"""Return authentication configuration."""
8989
assert (
9090
self._configuration is not None
9191
), "logic error: configuration is not loaded"
92+
93+
assert (
94+
self._configuration.authentication is not None
95+
), "logic error: authentication configuration is not loaded"
96+
9297
return self._configuration.authentication
9398

9499
@property

src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
47+
DEFAULT_JWT_UID_CLAIM = "user_id"
48+
DEFAULT_JWT_USER_NAME_CLAIM = "username"
4549

4650
# Data collector constants
4751
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds

src/models/config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,38 @@ def check_storage_location_is_set_when_needed(self) -> Self:
147147
return self
148148

149149

150+
class JwtConfiguration(BaseModel):
151+
"""JWT configuration."""
152+
153+
user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
154+
username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM
155+
156+
@model_validator(mode="after")
157+
def check_jwt_configuration(self) -> Self:
158+
"""Validate JWT configuration."""
159+
return self
160+
161+
162+
class JwkConfiguration(BaseModel):
163+
"""JWK configuration."""
164+
165+
url: AnyHttpUrl
166+
jwt_configuration: JwtConfiguration = JwtConfiguration()
167+
168+
@model_validator(mode="after")
169+
def check_jwk_configuration(self) -> Self:
170+
"""Validate JWK configuration."""
171+
return self
172+
173+
150174
class AuthenticationConfiguration(BaseModel):
151175
"""Authentication configuration."""
152176

153177
module: str = constants.DEFAULT_AUTHENTICATION_MODULE
154178
skip_tls_verification: bool = False
155179
k8s_cluster_api: Optional[AnyHttpUrl] = None
156180
k8s_ca_cert_path: Optional[FilePath] = None
181+
jwk_config: Optional[JwkConfiguration] = None
157182

158183
@model_validator(mode="after")
159184
def check_authentication_model(self) -> Self:
@@ -164,8 +189,25 @@ def check_authentication_model(self) -> Self:
164189
f"Unsupported authentication module '{self.module}'. "
165190
f"Supported modules: {supported_modules}"
166191
)
192+
193+
if self.module == constants.AUTH_MOD_JWK_TOKEN:
194+
if self.jwk_config is None:
195+
raise ValueError(
196+
"JWK configuration must be specified when using JWK token authentication"
197+
)
198+
167199
return self
168200

201+
@property
202+
def jwk_configuration(self) -> JwkConfiguration:
203+
"""Return JWK configuration if the module is JWK token."""
204+
if self.module != constants.AUTH_MOD_JWK_TOKEN:
205+
raise ValueError(
206+
"JWK configuration is only available for JWK token authentication module"
207+
)
208+
assert self.jwk_config is not None, "JWK configuration should not be None"
209+
return self.jwk_config
210+
169211

170212
class Customization(BaseModel):
171213
"""Service customization."""

0 commit comments

Comments
 (0)