Skip to content

Commit 08d8ac1

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. Example config: ```yaml authentication: module: jwk-token jwk_config: url: https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/certs jwt_configuration: user_id_claim: user_id username_claim: username ```
1 parent 79b5427 commit 08d8ac1

File tree

9 files changed

+629
-3
lines changed

9 files changed

+629
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ dependencies = [
1313
"llama-stack>=0.2.13",
1414
"rich>=14.0.0",
1515
"cachetools>=6.1.0",
16+
"aiohttp>=3.12.14",
17+
"authlib>=1.6.0",
1618
]
1719

1820
[tool.pyright]

src/auth/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,7 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
module = configuration.authentication_configuration.module
1919

2020
logger.debug(
2121
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +32,11 @@ def get_auth_dependency(
3232
)
3333
case constants.AUTH_MOD_K8S:
3434
return k8s.K8SAuthDependency(virtual_path=virtual_path)
35+
case constants.AUTH_MOD_JWK_TOKEN:
36+
return jwk_token.JwkTokenAuthDependency(
37+
configuration.authentication_configuration.jwk_configuration,
38+
virtual_path=virtual_path,
39+
)
3540
case _:
3641
err_msg = f"Unsupported authentication module '{module}'"
3742
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""Manage authentication flow for FastAPI endpoints with no-op auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
6+
from fastapi import Request, HTTPException, status
7+
from authlib.jose import JsonWebKey, KeySet, jwt
8+
from authlib.jose.errors import (
9+
BadSignatureError,
10+
DecodeError,
11+
ExpiredTokenError,
12+
JoseError,
13+
)
14+
from cachetools import TTLCache
15+
import aiohttp
16+
17+
from constants import (
18+
DEFAULT_VIRTUAL_PATH,
19+
)
20+
from auth.interface import AuthInterface
21+
from auth.utils import extract_user_token
22+
from models.config import JwkConfiguration
23+
24+
logger = logging.getLogger(__name__)
25+
26+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
27+
# hour, keys are unlikely to change frequently.
28+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
29+
# Ideally this would be an RWLock, but it would require adding a dependency on
30+
# aiorwlock
31+
_jwk_cache_lock = Lock()
32+
33+
34+
async def get_jwk_set(url: str):
35+
"""Fetch the JWK set from the cache, or fetch it from the URL if not cached."""
36+
async with _jwk_cache_lock:
37+
if url not in _jwk_cache:
38+
async with aiohttp.ClientSession() as session:
39+
async with session.get(url) as resp:
40+
resp.raise_for_status()
41+
_jwk_cache[url] = JsonWebKey.import_key_set(await resp.json())
42+
return _jwk_cache[url]
43+
44+
45+
class KeyNotFoundError(Exception):
46+
"""Exception raised when a key is not found in the JWK set based on kid/alg."""
47+
48+
49+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
50+
"""JWK AuthDependency class for JWK-based JWT authentication."""
51+
52+
def __init__(
53+
self, config: JwkConfiguration, virtual_path: str = DEFAULT_VIRTUAL_PATH
54+
) -> None:
55+
"""Initialize the required allowed paths for authorization checks."""
56+
self.virtual_path: str = virtual_path
57+
self.config: JwkConfiguration = config
58+
59+
async def __call__(self, request: Request) -> tuple[str, str, str]:
60+
"""Authenticate the JWT in the headers against the keys from the JWK url."""
61+
user_token = extract_user_token(request.headers)
62+
63+
jwk_set = await get_jwk_set(str(self.config.url))
64+
65+
def resolve_key(header, _payload):
66+
"""Match kid and alg from the JWT header to the JWK set.
67+
68+
Resolve the key from the JWK set based on the JWT header. Also
69+
match the algorithm to make sure the algorithm stated by the user
70+
is the same algorithm the key itself expects.
71+
"""
72+
key = jwk_set.find_by_kid(header["kid"])
73+
74+
if header["alg"] != key["alg"]:
75+
raise KeyNotFoundError
76+
77+
return key
78+
79+
if not user_token:
80+
raise HTTPException(
81+
status_code=status.HTTP_401_UNAUTHORIZED,
82+
detail="Authorization header is missing or invalid",
83+
)
84+
85+
try:
86+
claims = jwt.decode(user_token, key=resolve_key)
87+
except KeyNotFoundError as exc:
88+
raise HTTPException(
89+
status_code=status.HTTP_401_UNAUTHORIZED,
90+
detail="Invalid token: signed by unknown key or algorithm mismatch",
91+
) from exc
92+
except BadSignatureError as exc:
93+
raise HTTPException(
94+
status_code=status.HTTP_401_UNAUTHORIZED,
95+
detail="Invalid token: bad signature",
96+
) from exc
97+
except DecodeError as exc:
98+
raise HTTPException(
99+
status_code=status.HTTP_400_BAD_REQUEST,
100+
detail="Invalid token: decode error",
101+
) from exc
102+
except JoseError as exc:
103+
raise HTTPException(
104+
status_code=status.HTTP_400_BAD_REQUEST,
105+
detail="Invalid token: unknown error",
106+
) from exc
107+
except Exception as exc:
108+
raise HTTPException(
109+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
110+
detail="Internal server error",
111+
) from exc
112+
113+
try:
114+
claims.validate()
115+
except ExpiredTokenError as exc:
116+
raise HTTPException(
117+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
118+
) from exc
119+
except JoseError as exc:
120+
raise HTTPException(
121+
status_code=status.HTTP_401_UNAUTHORIZED,
122+
detail="Error validating token",
123+
) from exc
124+
except Exception as exc:
125+
raise HTTPException(
126+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
127+
detail="Internal server error during token validation",
128+
) from exc
129+
130+
try:
131+
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
132+
except KeyError as exc:
133+
raise HTTPException(
134+
status_code=status.HTTP_401_UNAUTHORIZED,
135+
detail=f"Token missing claim: {self.config.jwt_configuration.user_id_claim}",
136+
) from exc
137+
138+
try:
139+
username: str = claims[self.config.jwt_configuration.username_claim]
140+
except KeyError as exc:
141+
raise HTTPException(
142+
status_code=status.HTTP_401_UNAUTHORIZED,
143+
detail=f"Token missing claim: {self.config.jwt_configuration.username_claim}",
144+
) from exc
145+
146+
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
147+
148+
return user_id, username, user_token

src/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ def mcp_servers(self) -> list[ModelContextProtocolServer]:
8484
return self._configuration.mcp_servers
8585

8686
@property
87-
def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
87+
def authentication_configuration(self) -> AuthenticationConfiguration:
8888
"""Return authentication configuration."""
8989
assert (
9090
self._configuration is not None
9191
), "logic error: configuration is not loaded"
92+
93+
assert (
94+
self._configuration.authentication is not None
95+
), "logic error: authentication configuration is not loaded"
96+
9297
return self._configuration.authentication
9398

9499
@property

src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
47+
DEFAULT_JWT_UID_CLAIM = "user_id"
48+
DEFAULT_JWT_USER_NAME_CLAIM = "username"
4549

4650
# Data collector constants
4751
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds

src/models/config.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,40 @@ def check_storage_location_is_set_when_needed(self) -> Self:
131131
return self
132132

133133

134+
class JwtConfiguration(BaseModel):
135+
"""JWT configuration."""
136+
137+
user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
138+
username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM
139+
140+
@model_validator(mode="after")
141+
def check_jwt_configuration(self) -> Self:
142+
"""Validate JWT configuration."""
143+
return self
144+
145+
146+
class JwkConfiguration(BaseModel):
147+
"""JWK configuration."""
148+
149+
url: Optional[AnyHttpUrl] = None
150+
jwt_configuration: JwtConfiguration = JwtConfiguration()
151+
152+
@model_validator(mode="after")
153+
def check_jwk_configuration(self) -> Self:
154+
"""Validate JWK configuration."""
155+
if self.url is None:
156+
raise ValueError("JWK URL must be specified")
157+
return self
158+
159+
134160
class AuthenticationConfiguration(BaseModel):
135161
"""Authentication configuration."""
136162

137163
module: str = constants.DEFAULT_AUTHENTICATION_MODULE
138164
skip_tls_verification: bool = False
139165
k8s_cluster_api: Optional[AnyHttpUrl] = None
140166
k8s_ca_cert_path: Optional[FilePath] = None
167+
jwk_config: Optional[JwkConfiguration] = None
141168

142169
@model_validator(mode="after")
143170
def check_authentication_model(self) -> Self:
@@ -148,8 +175,25 @@ def check_authentication_model(self) -> Self:
148175
f"Unsupported authentication module '{self.module}'. "
149176
f"Supported modules: {supported_modules}"
150177
)
178+
179+
if self.module == constants.AUTH_MOD_JWK_TOKEN:
180+
if self.jwk_config is None:
181+
raise ValueError(
182+
"JWK configuration must be specified when using JWK token authentication"
183+
)
184+
151185
return self
152186

187+
@property
188+
def jwk_configuration(self) -> JwkConfiguration:
189+
"""Return JWK configuration if the module is JWK token."""
190+
if self.module != constants.AUTH_MOD_JWK_TOKEN:
191+
raise ValueError(
192+
"JWK configuration is only available for JWK token authentication module"
193+
)
194+
assert self.jwk_config is not None, "JWK configuration should not be None"
195+
return self.jwk_config
196+
153197

154198
class Customization(BaseModel):
155199
"""Service customization."""

0 commit comments

Comments
 (0)