Skip to content

Commit 3b92258

Browse files
committed
Add support for JWT Token authentication with JWK
This commit introduces a new authentication module `jwk-token` that allows the application to authenticate users using JSON Web Tokens (JWT) signed with JSON Web Keys (JWK). The new module fetches JWKs from a specified URL and validates incoming JWTs against these keys. For now, the module expects the JWT to contain specific fields for user ID and username, which will cause an error if they are not present. These fields will be configurable in the future but for now, they match the fields used in Red Hat SSO JWTs.
1 parent c454da0 commit 3b92258

File tree

6 files changed

+148
-2
lines changed

6 files changed

+148
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ dependencies = [
1313
"llama-stack>=0.2.13",
1414
"rich>=14.0.0",
1515
"cachetools>=6.1.0",
16+
"aiohttp>=3.12.14",
17+
"authlib>=1.6.0",
1618
]
1719

1820
[tool.pyright]

src/auth/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from auth.interface import AuthInterface
6-
from auth import noop, noop_with_token, k8s
6+
from auth import noop, noop_with_token, k8s, jwk_token
77
from configuration import configuration
88
import constants
99

@@ -15,7 +15,13 @@ def get_auth_dependency(
1515
virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
1616
) -> AuthInterface:
1717
"""Select the configured authentication dependency interface."""
18-
module = configuration.authentication_configuration.module # pyright: ignore
18+
19+
if configuration.authentication_configuration is None:
20+
err_msg = "Authentication configuration is not set"
21+
logger.error(err_msg)
22+
raise ValueError(err_msg)
23+
24+
module = configuration.authentication_configuration.module
1925

2026
logger.debug(
2127
"Initializing authentication dependency: module='%s', virtual_path='%s'",
@@ -32,6 +38,15 @@ def get_auth_dependency(
3238
)
3339
case constants.AUTH_MOD_K8S:
3440
return k8s.K8SAuthDependency(virtual_path=virtual_path)
41+
case constants.AUTH_MOD_JWK_TOKEN:
42+
if configuration.authentication_configuration.jwk_url is None:
43+
err_msg = "JWK URL must be provided for JWK Token authentication"
44+
logger.error(err_msg)
45+
raise ValueError(err_msg)
46+
return jwk_token.JwkTokenAuthDependency(
47+
configuration.authentication_configuration.jwk_url,
48+
virtual_path=virtual_path,
49+
)
3550
case _:
3651
err_msg = f"Unsupported authentication module '{module}'"
3752
logger.error(err_msg)

src/auth/jwk_token.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Manage authentication flow for FastAPI endpoints with no-op auth."""
2+
3+
import logging
4+
from asyncio import Lock
5+
6+
from fastapi import Request
7+
from authlib.jose import JsonWebKey, KeySet, jwt, JoseError
8+
from cachetools import TTLCache
9+
import aiohttp
10+
from pydantic import AnyHttpUrl
11+
12+
from constants import (
13+
DEFAULT_USER_NAME,
14+
DEFAULT_USER_UID,
15+
DEFAULT_VIRTUAL_PATH,
16+
)
17+
from auth.interface import AuthInterface
18+
from auth.utils import extract_user_token
19+
20+
logger = logging.getLogger(__name__)
21+
22+
# Global JWK registry to avoid re-fetching JWKs for each request. Cached for 1
23+
# hour, keys are unlikely to change frequently.
24+
_jwk_cache: TTLCache[str, KeySet] = TTLCache(maxsize=3, ttl=3600)
25+
# Ideally this would be a RWLock, but it would require adding a dependency on
26+
# aiorwlock
27+
_jwk_cache_lock = Lock()
28+
29+
# TODO(omertuc): These default field names currently match the values in the Red Hat SSO JWTs
30+
# (e.g. `ocm token` command). Ideally they would be configurable so users of
31+
# lightspeed-core could set them to match their own JWT claims.
32+
DEFAULT_JWT_FIELD_USER_UID = "user_id"
33+
DEFAULT_JWT_FIELD_USER_NAME = "username"
34+
35+
36+
async def get_jwk_set(jwk_url):
37+
"""
38+
Fetch the JWK set from the cache, or fetch it from the URL if not cached.
39+
"""
40+
async with _jwk_cache_lock:
41+
if jwk_url not in _jwk_cache:
42+
async with aiohttp.ClientSession() as session:
43+
async with session.get(jwk_url) as resp:
44+
resp.raise_for_status()
45+
data = await resp.json()
46+
_jwk_cache[jwk_url] = JsonWebKey.import_key_set(data)
47+
return _jwk_cache[jwk_url]
48+
49+
50+
class JwkTokenAuthDependency(AuthInterface): # pylint: disable=too-few-public-methods
51+
"""JWK AuthDependency class that checks that the given bearer token is
52+
valid based on a given JWKS URL."""
53+
54+
def __init__(
55+
self, jwks_url: AnyHttpUrl, virtual_path: str = DEFAULT_VIRTUAL_PATH
56+
) -> None:
57+
"""Initialize the required allowed paths for authorization checks."""
58+
self.virtual_path = virtual_path
59+
self.jwks_url: AnyHttpUrl = jwks_url
60+
61+
async def __call__(self, request: Request) -> tuple[str, str, str]:
62+
user_token = extract_user_token(request.headers)
63+
64+
jwk_set = await get_jwk_set(str(self.jwks_url))
65+
66+
if not user_token:
67+
logger.error("No authorization token found in request")
68+
return DEFAULT_USER_UID, DEFAULT_USER_NAME, self.virtual_path
69+
70+
try:
71+
claims = jwt.decode(user_token, jwk_set)
72+
73+
try:
74+
user_id = claims[DEFAULT_JWT_FIELD_USER_UID]
75+
except KeyError as exc:
76+
raise JoseError(
77+
f"JWT token does not contain required field '{DEFAULT_JWT_FIELD_USER_UID}'"
78+
) from exc
79+
80+
try:
81+
username = claims[DEFAULT_JWT_FIELD_USER_NAME]
82+
except KeyError as exc:
83+
raise JoseError(
84+
f"JWT token does not contain required field '{DEFAULT_JWT_FIELD_USER_NAME}'"
85+
) from exc
86+
87+
logger.info(
88+
"Successfully authenticated user %s (ID: %s)", username, user_id
89+
)
90+
return user_id, username, user_token
91+
except JoseError:
92+
logger.error("JWT token validation failed")
93+
raise
94+
except Exception:
95+
logger.error("Unexpected error during token validation")
96+
raise

src/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@
3333
AUTH_MOD_K8S = "k8s"
3434
AUTH_MOD_NOOP = "noop"
3535
AUTH_MOD_NOOP_WITH_TOKEN = "noop-with-token"
36+
AUTH_MOD_JWK_TOKEN = "jwk-token"
3637
# Supported authentication modules
3738
SUPPORTED_AUTHENTICATION_MODULES = frozenset(
3839
{
3940
AUTH_MOD_K8S,
4041
AUTH_MOD_NOOP,
4142
AUTH_MOD_NOOP_WITH_TOKEN,
43+
AUTH_MOD_JWK_TOKEN,
4244
}
4345
)
4446
DEFAULT_AUTHENTICATION_MODULE = AUTH_MOD_NOOP

src/models/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class AuthenticationConfiguration(BaseModel):
112112
skip_tls_verification: bool = False
113113
k8s_cluster_api: Optional[AnyHttpUrl] = None
114114
k8s_ca_cert_path: Optional[FilePath] = None
115+
jwk_url: Optional[AnyHttpUrl] = None
115116

116117
@model_validator(mode="after")
117118
def check_authentication_model(self) -> Self:

0 commit comments

Comments
 (0)