Skip to content

Commit 7476215

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. Example config: ```yaml authentication: module: jwk-token jwk_config: url: https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/certs jwt_configuration: user_id_claim: user_id username_claim: username ```
1 parent c1c7ba3 commit 7476215

File tree

9 files changed

+823
-3
lines changed

9 files changed

+823
-3
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ dependencies = [
1515
"cachetools>=6.1.0",
1616
"prometheus-client>=0.22.1",
1717
"starlette>=0.47.1",
18+
"aiohttp>=3.12.14",
19+
"authlib>=1.6.0",
1820
]
1921

2022
[tool.pyright]
@@ -53,6 +55,7 @@ dev = [
5355
"ruff>=0.11.13",
5456
"aiosqlite",
5557
"behave>=1.2.6",
58+
"types-cachetools>=6.1.0.20250717",
5659
]
5760
build = [
5861
"build>=1.2.2.post1",

src/auth/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,7 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
module = configuration.authentication_configuration.module
1919

2020
logger.debug(
2121
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +32,11 @@ def get_auth_dependency(
3232
)
3333
case constants.AUTH_MOD_K8S:
3434
return k8s.K8SAuthDependency(virtual_path=virtual_path)
35+
case constants.AUTH_MOD_JWK_TOKEN:
36+
return jwk_token.JwkTokenAuthDependency(
37+
configuration.authentication_configuration.jwk_configuration,
38+
virtual_path=virtual_path,
39+
)
3540
case _:
3641
err_msg = f"Unsupported authentication module '{module}'"
3742
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""Manage authentication flow for FastAPI endpoints with JWK based JWT auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
from typing import Any
6+
7+
from fastapi import Request, HTTPException, status
8+
from authlib.jose import JsonWebKey, KeySet, jwt, Key
9+
from authlib.jose.errors import (
10+
BadSignatureError,
11+
DecodeError,
12+
ExpiredTokenError,
13+
JoseError,
14+
)
15+
from cachetools import TTLCache
16+
import aiohttp
17+
18+
from constants import (
19+
DEFAULT_VIRTUAL_PATH,
20+
)
21+
from auth.interface import AuthInterface
22+
from auth.utils import extract_user_token
23+
from models.config import JwkConfiguration
24+
25+
logger = logging.getLogger(__name__)
26+
27+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
28+
# hour, keys are unlikely to change frequently.
29+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
30+
# Ideally this would be an RWLock, but it would require adding a dependency on
31+
# aiorwlock
32+
_jwk_cache_lock = Lock()
33+
34+
35+
async def get_jwk_set(url: str) -> KeySet:
36+
"""Fetch the JWK set from the cache, or fetch it from the URL if not cached."""
37+
async with _jwk_cache_lock:
38+
if url not in _jwk_cache:
39+
async with aiohttp.ClientSession() as session:
40+
# TODO(omertuc): handle connection errors, timeouts, etc.
41+
async with session.get(url) as resp:
42+
resp.raise_for_status()
43+
_jwk_cache[url] = JsonWebKey.import_key_set(await resp.json())
44+
return _jwk_cache[url]
45+
46+
47+
class KeyNotFoundError(Exception):
48+
"""Exception raised when a key is not found in the JWK set based on kid/alg."""
49+
50+
51+
def key_resolver_func(jwk_set):
52+
"""
53+
Return a function to find a key in the given jwk_set. The function matches the
54+
signature expected by the jwt.decode key kwarg.
55+
"""
56+
57+
def _internal(header: dict[str, Any], _payload: dict[str, Any]) -> Key:
58+
"""Match kid and alg from the JWT header to the JWK set.
59+
60+
Resolve the key from the JWK set based on the JWT header. Also
61+
match the algorithm to make sure the algorithm stated by the user
62+
is the same algorithm the key itself expects.
63+
64+
# We intentionally do not use find_by_kid because it's a bad function
65+
# that doesn't take the alg into account
66+
"""
67+
if "alg" not in header:
68+
raise KeyNotFoundError("Token header missing 'alg' field")
69+
70+
if "kid" in header:
71+
keys = [key for key in jwk_set.keys if key.kid == header.get("kid")]
72+
73+
if len(keys) == 0:
74+
raise KeyNotFoundError(
75+
"No key found matching kid and alg in the JWK set"
76+
)
77+
78+
if len(keys) > 1:
79+
# This should never happen! Bad JWK set!
80+
raise KeyNotFoundError(
81+
"Internal server error, multiple keys found matching this kid"
82+
)
83+
84+
key = keys[0]
85+
86+
if key["alg"] != header["alg"]:
87+
raise KeyNotFoundError(
88+
"Key found by kid does not match the algorithm in the token header"
89+
)
90+
91+
return key
92+
93+
# No kid in the token header, we will try to find a key by alg
94+
keys = [key for key in jwk_set.keys if key["alg"] == header["alg"]]
95+
96+
if len(keys) == 0:
97+
raise KeyNotFoundError("No key found matching alg in the JWK set")
98+
99+
# Token has no kid and even we have more than one key with this algorithm - we will
100+
# return the first key which matches the algorithm, hopefully it will
101+
# match the token, but if not, unlucky - we're not going to brute-force all
102+
# keys until we find the one that matches, that makes us more vulnerable to DoS
103+
return keys[0]
104+
105+
return _internal
106+
107+
108+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
109+
"""JWK AuthDependency class for JWK-based JWT authentication."""
110+
111+
def __init__(
112+
self, config: JwkConfiguration, virtual_path: str = DEFAULT_VIRTUAL_PATH
113+
) -> None:
114+
"""Initialize the required allowed paths for authorization checks."""
115+
self.virtual_path: str = virtual_path
116+
self.config: JwkConfiguration = config
117+
118+
async def __call__(self, request: Request) -> tuple[str, str, str]:
119+
"""Authenticate the JWT in the headers against the keys from the JWK url."""
120+
user_token = extract_user_token(request.headers)
121+
122+
jwk_set = await get_jwk_set(str(self.config.url))
123+
124+
try:
125+
claims = jwt.decode(user_token, key=key_resolver_func(jwk_set))
126+
except KeyNotFoundError as exc:
127+
raise HTTPException(
128+
status_code=status.HTTP_401_UNAUTHORIZED,
129+
detail="Invalid token: signed by unknown key or algorithm mismatch",
130+
) from exc
131+
except BadSignatureError as exc:
132+
raise HTTPException(
133+
status_code=status.HTTP_401_UNAUTHORIZED,
134+
detail="Invalid token: bad signature",
135+
) from exc
136+
except DecodeError as exc:
137+
raise HTTPException(
138+
status_code=status.HTTP_400_BAD_REQUEST,
139+
detail="Invalid token: decode error",
140+
) from exc
141+
except JoseError as exc:
142+
raise HTTPException(
143+
status_code=status.HTTP_400_BAD_REQUEST,
144+
detail="Invalid token: unknown error",
145+
) from exc
146+
except Exception as exc:
147+
raise HTTPException(
148+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
149+
detail="Internal server error",
150+
) from exc
151+
152+
try:
153+
claims.validate()
154+
except ExpiredTokenError as exc:
155+
raise HTTPException(
156+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
157+
) from exc
158+
except JoseError as exc:
159+
raise HTTPException(
160+
status_code=status.HTTP_401_UNAUTHORIZED,
161+
detail="Error validating token",
162+
) from exc
163+
except Exception as exc:
164+
raise HTTPException(
165+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
166+
detail="Internal server error during token validation",
167+
) from exc
168+
169+
try:
170+
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
171+
except KeyError as exc:
172+
raise HTTPException(
173+
status_code=status.HTTP_401_UNAUTHORIZED,
174+
detail=f"Token missing claim: {self.config.jwt_configuration.user_id_claim}",
175+
) from exc
176+
177+
try:
178+
username: str = claims[self.config.jwt_configuration.username_claim]
179+
except KeyError as exc:
180+
raise HTTPException(
181+
status_code=status.HTTP_401_UNAUTHORIZED,
182+
detail=f"Token missing claim: {self.config.jwt_configuration.username_claim}",
183+
) from exc
184+
185+
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
186+
187+
return user_id, username, user_token

src/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,16 @@ def mcp_servers(self) -> list[ModelContextProtocolServer]:
8585
return self._configuration.mcp_servers
8686

8787
@property
88-
def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
88+
def authentication_configuration(self) -> AuthenticationConfiguration:
8989
"""Return authentication configuration."""
9090
assert (
9191
self._configuration is not None
9292
), "logic error: configuration is not loaded"
93+
94+
assert (
95+
self._configuration.authentication is not None
96+
), "logic error: authentication configuration is not loaded"
97+
9398
return self._configuration.authentication
9499

95100
@property

src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
47+
DEFAULT_JWT_UID_CLAIM = "user_id"
48+
DEFAULT_JWT_USER_NAME_CLAIM = "username"
4549

4650
# Data collector constants
4751
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds

src/models/config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,28 @@ def check_storage_location_is_set_when_needed(self) -> Self:
147147
return self
148148

149149

150+
class JwtConfiguration(BaseModel):
151+
"""JWT configuration."""
152+
153+
user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
154+
username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM
155+
156+
157+
class JwkConfiguration(BaseModel):
158+
"""JWK configuration."""
159+
160+
url: AnyHttpUrl
161+
jwt_configuration: JwtConfiguration = JwtConfiguration()
162+
163+
150164
class AuthenticationConfiguration(BaseModel):
151165
"""Authentication configuration."""
152166

153167
module: str = constants.DEFAULT_AUTHENTICATION_MODULE
154168
skip_tls_verification: bool = False
155169
k8s_cluster_api: Optional[AnyHttpUrl] = None
156170
k8s_ca_cert_path: Optional[FilePath] = None
171+
jwk_config: Optional[JwkConfiguration] = None
157172

158173
@model_validator(mode="after")
159174
def check_authentication_model(self) -> Self:
@@ -164,8 +179,25 @@ def check_authentication_model(self) -> Self:
164179
f"Unsupported authentication module '{self.module}'. "
165180
f"Supported modules: {supported_modules}"
166181
)
182+
183+
if self.module == constants.AUTH_MOD_JWK_TOKEN:
184+
if self.jwk_config is None:
185+
raise ValueError(
186+
"JWK configuration must be specified when using JWK token authentication"
187+
)
188+
167189
return self
168190

191+
@property
192+
def jwk_configuration(self) -> JwkConfiguration:
193+
"""Return JWK configuration if the module is JWK token."""
194+
if self.module != constants.AUTH_MOD_JWK_TOKEN:
195+
raise ValueError(
196+
"JWK configuration is only available for JWK token authentication module"
197+
)
198+
assert self.jwk_config is not None, "JWK configuration should not be None"
199+
return self.jwk_config
200+
169201

170202
class Customization(BaseModel):
171203
"""Service customization."""

0 commit comments

Comments
 (0)