Skip to content

Commit 654d151

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. Example config: ```yaml authentication: module: jwk-token jwk_config: url: https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/certs jwt_configuration: user_id_claim: user_id username_claim: username ```
1 parent 79b5427 commit 654d151

File tree

8 files changed

+614
-3
lines changed

8 files changed

+614
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ dependencies = [
1313
"llama-stack>=0.2.13",
1414
"rich>=14.0.0",
1515
"cachetools>=6.1.0",
16+
"aiohttp>=3.12.14",
17+
"authlib>=1.6.0",
1618
]
1719

1820
[tool.pyright]

src/auth/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,7 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
module = configuration.authentication_configuration.module
1919

2020
logger.debug(
2121
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +32,11 @@ def get_auth_dependency(
3232
)
3333
case constants.AUTH_MOD_K8S:
3434
return k8s.K8SAuthDependency(virtual_path=virtual_path)
35+
case constants.AUTH_MOD_JWK_TOKEN:
36+
return jwk_token.JwkTokenAuthDependency(
37+
configuration.authentication_configuration.jwk_configuration,
38+
virtual_path=virtual_path,
39+
)
3540
case _:
3641
err_msg = f"Unsupported authentication module '{module}'"
3742
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Manage authentication flow for FastAPI endpoints with no-op auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
6+
from fastapi import Request, HTTPException, status
7+
from authlib.jose import JsonWebKey, KeySet, jwt
8+
from authlib.jose.errors import (
9+
BadSignatureError,
10+
DecodeError,
11+
ExpiredTokenError,
12+
JoseError,
13+
)
14+
from cachetools import TTLCache
15+
import aiohttp
16+
17+
from constants import (
18+
DEFAULT_VIRTUAL_PATH,
19+
)
20+
from auth.interface import AuthInterface
21+
from auth.utils import extract_user_token
22+
from models.config import JwkConfiguration
23+
24+
logger = logging.getLogger(__name__)
25+
26+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
27+
# hour, keys are unlikely to change frequently.
28+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
29+
# Ideally this would be an RWLock, but it would require adding a dependency on
30+
# aiorwlock
31+
_jwk_cache_lock = Lock()
32+
33+
34+
async def get_jwk_set(url: str):
35+
"""
36+
Fetch the JWK set from the cache, or fetch it from the URL if not cached.
37+
"""
38+
async with _jwk_cache_lock:
39+
if url not in _jwk_cache:
40+
async with aiohttp.ClientSession() as session:
41+
async with session.get(url) as resp:
42+
resp.raise_for_status()
43+
_jwk_cache[url] = JsonWebKey.import_key_set(await resp.json())
44+
return _jwk_cache[url]
45+
46+
47+
class KeyNotFoundError(Exception):
48+
"""
49+
Exception raised when a key is not found in the JWK set based on kid/alg.
50+
"""
51+
52+
53+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
54+
"""JWK AuthDependency class that checks that the given bearer token is
55+
valid based on a given JWKS URL."""
56+
57+
def __init__(
58+
self, config: JwkConfiguration, virtual_path: str = DEFAULT_VIRTUAL_PATH
59+
) -> None:
60+
"""Initialize the required allowed paths for authorization checks."""
61+
self.virtual_path: str = virtual_path
62+
self.config: JwkConfiguration = config
63+
64+
async def __call__(self, request: Request) -> tuple[str, str, str]:
65+
user_token = extract_user_token(request.headers)
66+
67+
jwk_set = await get_jwk_set(str(self.config.url))
68+
69+
def resolve_key(header, _payload):
70+
"""Resolve the key from the JWK set based on the JWT header. Also
71+
match the algorithm to make sure the algorithm stated by the user
72+
is the same algorithm the key itself expects."""
73+
key = jwk_set.find_by_kid(header["kid"])
74+
75+
if header["alg"] != key["alg"]:
76+
raise KeyNotFoundError
77+
78+
return key
79+
80+
if not user_token:
81+
raise HTTPException(
82+
status_code=status.HTTP_401_UNAUTHORIZED,
83+
detail="Authorization header is missing or invalid",
84+
)
85+
86+
try:
87+
claims = jwt.decode(user_token, key=resolve_key)
88+
except KeyNotFoundError as exc:
89+
raise HTTPException(
90+
status_code=status.HTTP_401_UNAUTHORIZED,
91+
detail="Invalid token: signed by unknown key or algorithm mismatch",
92+
) from exc
93+
except BadSignatureError as exc:
94+
raise HTTPException(
95+
status_code=status.HTTP_401_UNAUTHORIZED,
96+
detail="Invalid token: bad signature",
97+
) from exc
98+
except DecodeError as exc:
99+
raise HTTPException(
100+
status_code=status.HTTP_400_BAD_REQUEST,
101+
detail="Invalid token: decode error",
102+
) from exc
103+
except JoseError as exc:
104+
raise HTTPException(
105+
status_code=status.HTTP_400_BAD_REQUEST,
106+
detail="Invalid token: unknown error",
107+
) from exc
108+
except Exception as exc:
109+
raise HTTPException(
110+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
111+
detail="Internal server error",
112+
) from exc
113+
114+
try:
115+
claims.validate()
116+
except ExpiredTokenError as exc:
117+
raise HTTPException(
118+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
119+
) from exc
120+
except JoseError as exc:
121+
raise HTTPException(
122+
status_code=status.HTTP_401_UNAUTHORIZED,
123+
detail="Error validating token",
124+
) from exc
125+
except Exception as exc:
126+
raise HTTPException(
127+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
128+
detail="Internal server error during token validation",
129+
) from exc
130+
131+
try:
132+
user_id: str = claims[self.config.jwt_configuration.user_id_claim]
133+
except KeyError as exc:
134+
raise HTTPException(
135+
status_code=status.HTTP_401_UNAUTHORIZED,
136+
detail=f"Token missing required claim: {self.config.jwt_configuration.user_id_claim}",
137+
) from exc
138+
139+
try:
140+
username: str = claims[self.config.jwt_configuration.username_claim]
141+
except KeyError as exc:
142+
raise HTTPException(
143+
status_code=status.HTTP_401_UNAUTHORIZED,
144+
detail=f"Token missing required claim: {self.config.jwt_configuration.username_claim}",
145+
) from exc
146+
147+
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
148+
149+
return user_id, username, user_token

src/configuration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ def mcp_servers(self) -> list[ModelContextProtocolServer]:
8484
return self._configuration.mcp_servers
8585

8686
@property
87-
def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
87+
def authentication_configuration(self) -> AuthenticationConfiguration:
8888
"""Return authentication configuration."""
8989
assert (
9090
self._configuration is not None
9191
), "logic error: configuration is not loaded"
92+
93+
assert (
94+
self._configuration.authentication is not None
95+
), "logic error: authentication configuration is not loaded"
96+
9297
return self._configuration.authentication
9398

9499
@property

src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP
47+
DEFAULT_JWT_UID_CLAIM = "user_id"
48+
DEFAULT_JWT_USER_NAME_CLAIM = "username"
4549

4650
# Data collector constants
4751
DATA_COLLECTOR_COLLECTION_INTERVAL = 7200 # 2 hours in seconds

src/models/config.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,40 @@ def check_storage_location_is_set_when_needed(self) -> Self:
131131
return self
132132

133133

134+
class JwtConfiguration(BaseModel):
135+
"""JWT configuration."""
136+
137+
user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
138+
username_claim: str = constants.DEFAULT_JWT_USER_NAME_CLAIM
139+
140+
@model_validator(mode="after")
141+
def check_jwt_configuration(self) -> Self:
142+
"""Validate JWT configuration."""
143+
return self
144+
145+
146+
class JwkConfiguration(BaseModel):
147+
"""JWK configuration."""
148+
149+
url: Optional[AnyHttpUrl] = None
150+
jwt_configuration: JwtConfiguration = JwtConfiguration()
151+
152+
@model_validator(mode="after")
153+
def check_jwk_configuration(self) -> Self:
154+
"""Validate JWK configuration."""
155+
if self.url is None:
156+
raise ValueError("JWK URL must be specified")
157+
return self
158+
159+
134160
class AuthenticationConfiguration(BaseModel):
135161
"""Authentication configuration."""
136162

137163
module: str = constants.DEFAULT_AUTHENTICATION_MODULE
138164
skip_tls_verification: bool = False
139165
k8s_cluster_api: Optional[AnyHttpUrl] = None
140166
k8s_ca_cert_path: Optional[FilePath] = None
167+
jwk_config: Optional[JwkConfiguration] = None
141168

142169
@model_validator(mode="after")
143170
def check_authentication_model(self) -> Self:
@@ -148,8 +175,26 @@ def check_authentication_model(self) -> Self:
148175
f"Unsupported authentication module '{self.module}'. "
149176
f"Supported modules: {supported_modules}"
150177
)
178+
179+
if self.module == constants.AUTH_MOD_JWK_TOKEN:
180+
if self.jwk_config is None:
181+
raise ValueError(
182+
"JWK configuration must be specified when using JWK token authentication"
183+
)
184+
151185
return self
152186

187+
@property
188+
def jwk_configuration(self) -> JwkConfiguration:
189+
"""Return JWK configuration if the module is JWK token."""
190+
191+
if self.module != constants.AUTH_MOD_JWK_TOKEN:
192+
raise ValueError(
193+
"JWK configuration is only available for JWK token authentication module"
194+
)
195+
assert self.jwk_config is not None, "JWK configuration should not be None"
196+
return self.jwk_config
197+
153198

154199
class Customization(BaseModel):
155200
"""Service customization."""

0 commit comments

Comments
 (0)