Skip to content

Commit 31cca10

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. Example config: ```yaml authentication: module: jwk-token jwk_config: url: https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/certs jwt_configuration: user_id_claim: user_id username_claim: username ```
1 parent 0fafc5c commit 31cca10

File tree

9 files changed

+825
-3
lines changed

9 files changed

+825
-3
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ dependencies = [
1515
"cachetools>=6.1.0",
1616
"prometheus-client>=0.22.1",
1717
"starlette>=0.47.1",
18+
"aiohttp>=3.12.14",
19+
"authlib>=1.6.0",
1820
]
1921

2022
[tool.pyright]
@@ -53,6 +55,7 @@ dev = [
5355
"ruff>=0.11.13",
5456
"aiosqlite",
5557
"behave>=1.2.6",
58+
"types-cachetools>=6.1.0.20250717",
5659
]
5760
build = [
5861
"build>=1.2.2.post1",

src/auth/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,7 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
module = configuration.authentication_configuration.module
1919

2020
logger.debug(
2121
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +32,11 @@ def get_auth_dependency(
3232
)
3333
case constants.AUTH_MOD_K8S:
3434
return k8s.K8SAuthDependency(virtual_path=virtual_path)
35+
case constants.AUTH_MOD_JWK_TOKEN:
36+
return jwk_token.JwkTokenAuthDependency(
37+
configuration.authentication_configuration.jwk_configuration,
38+
virtual_path=virtual_path,
39+
)
3540
case _:
3641
err_msg = f"Unsupported authentication module '{module}'"
3742
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
from typing import Any, Callable
6+
7+
from fastapi import Request, HTTPException, status
8+
from authlib.jose import JsonWebKey, KeySet, jwt, Key
9+
from authlib.jose.errors import (
10+
BadSignatureError,
11+
DecodeError,
12+
ExpiredTokenError,
13+
JoseError,
14+
)
15+
from cachetools import TTLCache
16+
import aiohttp
17+
18+
from constants import (
19+
DEFAULT_VIRTUAL_PATH,
20+
)
21+
from auth.interface import AuthInterface
22+
from auth.utils import extract_user_token
23+
from models.config import JwkConfiguration
24+
25+
logger = logging.getLogger(__name__)
26+
27+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
28+
# hour, keys are unlikely to change frequently.
29+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
30+
# Ideally this would be an RWLock, but it would require adding a dependency on
31+
# aiorwlock
32+
_jwk_cache_lock = Lock()
33+
34+
35+
async def get_jwk_set(url: str) -> KeySet:
36+
"""Fetch the JWK set from the cache, or fetch it from the URL if not cached."""
37+
async with _jwk_cache_lock:
38+
if url not in _jwk_cache:
39+
async with aiohttp.ClientSession() as session:
40+
# TODO(omertuc): handle connection errors, timeouts, etc.
41+
async with session.get(url) as resp:
42+
resp.raise_for_status()
43+
_jwk_cache[url] = JsonWebKey.import_key_set(await resp.json())
44+
return _jwk_cache[url]
45+
46+
47+
class KeyNotFoundError(Exception):
48+
"""Exception raised when a key is not found in the JWK set based on kid/alg."""
49+
50+
51+
def key_resolver_func(
52+
jwk_set: KeySet,
53+
) -> Callable[[dict[str, Any], dict[str, Any]], Key]:
54+
"""
55+
Create a key resolver function.
56+
57+
Return a function to find a key in the given jwk_set. The function matches the
58+
signature expected by the jwt.decode key kwarg.
59+
"""
60+
61+
def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key:
62+
"""Match kid and alg from the JWT header to the JWK set.
63+
64+
Resolve the key from the JWK set based on the JWT header. Also
65+
match the algorithm to make sure the algorithm stated by the user
66+
is the same algorithm the key itself expects.
67+
68+
# We intentionally do not use find_by_kid because it's a bad function
69+
# that doesn't take the alg into account
70+
"""
71+
if "alg" not in header:
72+
raise KeyNotFoundError("Token header missing 'alg' field")
73+
74+
if "kid" in header:
75+
keys = [key for key in jwk_set.keys if key.kid == header.get("kid")]
76+
77+
if len(keys) == 0:
78+
raise KeyNotFoundError(
79+
"No key found matching kid and alg in the JWK set"
80+
)
81+
82+
if len(keys) > 1:
83+
# This should never happen! Bad JWK set!
84+
raise KeyNotFoundError(
85+
"Internal server error, multiple keys found matching this kid"
86+
)
87+
88+
key = keys[0]
89+
90+
if key["alg"] != header["alg"]:
91+
raise KeyNotFoundError(
92+
"Key found by kid does not match the algorithm in the token header"
93+
)
94+
95+
return key
96+
97+
# No kid in the token header, we will try to find a key by alg
98+
keys = [key for key in jwk_set.keys if key["alg"] == header["alg"]]
99+
100+
if len(keys) == 0:
101+
raise KeyNotFoundError("No key found matching alg in the JWK set")
102+
103+
# Token has no kid and even we have more than one key with this algorithm - we will
104+
# return the first key which matches the algorithm, hopefully it will
105+
# match the token, but if not, unlucky - we're not going to brute-force all
106+
# keys until we find the one that matches, that makes us more vulnerable to DoS
107+
return keys[0]
108+
109+
return _internal
110+
111+
112+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
113+
"""JWK AuthDependency class for JWK-based JWT authentication."""
114+
115+
def __init__(
116+
self, config: JwkConfiguration, virtual_path: str = DEFAULT_VIRTUAL_PATH
117+
) -> None:
118+
"""Initialize the required allowed paths for authorization checks."""
119+
self.virtual_path: str = virtual_path
120+
self.config: JwkConfiguration = config
121+
122+
async def __call__(self, request: Request) -> tuple[str, str, str]:
123+
"""Authenticate the JWT in the headers against the keys from the JWK url."""
124+
user_token = extract_user_token(request.headers)
125+
126+
jwk_set = await get_jwk_set(str(self.config.url))
127+
128+
try:
129+
claims = jwt.decode(user_token, key=key_resolver_func(jwk_set))
130+
except KeyNotFoundError as exc:
131+
raise HTTPException(
132+
status_code=status.HTTP_401_UNAUTHORIZED,
133+
detail="Invalid token: signed by unknown key or algorithm mismatch",
134+
) from exc
135+
except BadSignatureError as exc:
136+
raise HTTPException(
137+
status_code=status.HTTP_401_UNAUTHORIZED,
138+
detail="Invalid token: bad signature",
139+
) from exc
140+
except DecodeError as exc:
141+
raise HTTPException(
142+
status_code=status.HTTP_400_BAD_REQUEST,
143+
detail="Invalid token: decode error",
144+
) from exc
145+
except JoseError as exc:
146+
raise HTTPException(
147+
status_code=status.HTTP_400_BAD_REQUEST,
148+
detail="Invalid token: unknown error",
149+
) from exc
150+
except Exception as exc:
151+
raise HTTPException(
152+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
153+
detail="Internal server error",
154+
) from exc
155+
156+
try:
157+
claims.validate()
158+
except ExpiredTokenError as exc:
159+
raise HTTPException(
160+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
161+
) from exc
162+
except JoseError as exc:
163+
raise HTTPException(
164+
status_code=status.HTTP_401_UNAUTHORIZED,
165+
detail="Error validating token",
166+
) from exc
167+
except Exception as exc:
168+
raise HTTPException(
169+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
170+
detail="Internal server error during token validation",
171+
) from exc
172+
173+
try:
174+
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
175+
except KeyError as exc:
176+
raise HTTPException(
177+
status_code=status.HTTP_401_UNAUTHORIZED,
178+
detail=f"Token missing claim: {self.config.jwt_configuration.user_id_claim}",
179+
) from exc
180+
181+
try:
182+
username: str = claims[self.config.jwt_configuration.username_claim]
183+
except KeyError as exc:
184+
raise HTTPException(
185+
status_code=status.HTTP_401_UNAUTHORIZED,
186+
detail=f"Token missing claim: {self.config.jwt_configuration.username_claim}",
187+
) from exc
188+
189+
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
190+
191+
return user_id, username, user_token

src/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,16 @@ def mcp_servers(self) -> list[ModelContextProtocolServer]:
8585
return self._configuration.mcp_servers
8686

8787
@property
88-
def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
88+
def authentication_configuration(self) -> AuthenticationConfiguration:
8989
"""Return authentication configuration."""
9090
assert (
9191
self._configuration is not None
9292
), "logic error: configuration is not loaded"
93+
94+
assert (
95+
self._configuration.authentication is not None
96+
), "logic error: authentication configuration is not loaded"
97+
9398
return self._configuration.authentication
9499

95100
@property

src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
47+
DEFAULT_JWT_UID_CLAIM = "user_id"
48+
DEFAULT_JWT_USER_NAME_CLAIM = "username"
4549

4650
# Data collector constants
4751
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds

src/models/config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,28 @@ def check_storage_location_is_set_when_needed(self) -> Self:
147147
return self
148148

149149

150+
class JwtConfiguration(BaseModel):
151+
"""JWT configuration."""
152+
153+
user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
154+
username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM
155+
156+
157+
class JwkConfiguration(BaseModel):
158+
"""JWK configuration."""
159+
160+
url: AnyHttpUrl
161+
jwt_configuration: JwtConfiguration = JwtConfiguration()
162+
163+
150164
class AuthenticationConfiguration(BaseModel):
151165
"""Authentication configuration."""
152166

153167
module: str = constants.DEFAULT_AUTHENTICATION_MODULE
154168
skip_tls_verification: bool = False
155169
k8s_cluster_api: Optional[AnyHttpUrl] = None
156170
k8s_ca_cert_path: Optional[FilePath] = None
171+
jwk_config: Optional[JwkConfiguration] = None
157172

158173
@model_validator(mode="after")
159174
def check_authentication_model(self) -> Self:
@@ -164,8 +179,25 @@ def check_authentication_model(self) -> Self:
164179
f"Unsupported authentication module '{self.module}'. "
165180
f"Supported modules: {supported_modules}"
166181
)
182+
183+
if self.module == constants.AUTH_MOD_JWK_TOKEN:
184+
if self.jwk_config is None:
185+
raise ValueError(
186+
"JWK configuration must be specified when using JWK token authentication"
187+
)
188+
167189
return self
168190

191+
@property
192+
def jwk_configuration(self) -> JwkConfiguration:
193+
"""Return JWK configuration if the module is JWK token."""
194+
if self.module != constants.AUTH_MOD_JWK_TOKEN:
195+
raise ValueError(
196+
"JWK configuration is only available for JWK token authentication module"
197+
)
198+
assert self.jwk_config is not None, "JWK configuration should not be None"
199+
return self.jwk_config
200+
169201

170202
class Customization(BaseModel):
171203
"""Service customization."""

0 commit comments

Comments
 (0)