From dc4da3edfa3413b2e70f8bd506681170e0f406ae Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Fri, 6 Feb 2026 18:32:03 +0100 Subject: [PATCH] fix(sso): extract user roles from JWT access token for Keycloak compatibility Keycloak (and similar OIDC providers) include role claims in the JWT access token but not in the UserInfo endpoint response. Previously, roles were only extracted from UserInfo, causing all SSO users to default to internal_user_view_only regardless of their actual role. Changes: - Extract user roles from JWT access token in process_sso_jwt_access_token() when UserInfo doesn't provide them (tries role_mappings first, then GENERIC_USER_ROLE_ATTRIBUTE) - Handle list-type role values in get_litellm_user_role() since Keycloak returns roles as arrays (e.g. ["proxy_admin"] instead of "proxy_admin") - Add 9 new unit tests covering role extraction and list handling - Update 3 existing tests for new JWT decode behavior Closes #20407 --- litellm/proxy/management_endpoints/types.py | 13 +- litellm/proxy/management_endpoints/ui_sso.py | 84 +++++-- .../proxy/management_endpoints/test_ui_sso.py | 231 +++++++++++++++++- 3 files changed, 304 insertions(+), 24 deletions(-) diff --git a/litellm/proxy/management_endpoints/types.py b/litellm/proxy/management_endpoints/types.py index ad2ad0a5fe5..a35fc4a5f3f 100644 --- a/litellm/proxy/management_endpoints/types.py +++ b/litellm/proxy/management_endpoints/types.py @@ -28,17 +28,24 @@ def is_valid_litellm_user_role(role_str: str) -> bool: return False -def get_litellm_user_role(role_str: str) -> Optional[LitellmUserRoles]: +def get_litellm_user_role(role_str) -> Optional[LitellmUserRoles]: """ - Convert a string to a LitellmUserRoles enum if valid (case-insensitive). + Convert a string (or list of strings) to a LitellmUserRoles enum if valid (case-insensitive). + + Handles list inputs since some SSO providers (e.g., Keycloak) return roles + as arrays like ["proxy_admin"] instead of plain strings. Args: - role_str: String to convert (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user") + role_str: String or list to convert (e.g., "proxy_admin", ["proxy_admin"]) Returns: LitellmUserRoles enum if valid, None otherwise """ try: + if isinstance(role_str, list): + if len(role_str) == 0: + return None + role_str = role_str[0] # Use _value2member_map_ for O(1) lookup, case-insensitive result = LitellmUserRoles._value2member_map_.get(role_str.lower()) return cast(Optional[LitellmUserRoles], result) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 2d248dc81f3..278f3bdaafd 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -171,36 +171,86 @@ def process_sso_jwt_access_token( access_token_str: Optional[str], sso_jwt_handler: Optional[JWTHandler], result: Union[OpenID, dict, None], + role_mappings: Optional["RoleMappings"] = None, ) -> None: """ - Process SSO JWT access token and extract team IDs if available. + Process SSO JWT access token and extract team IDs and user role if available. - This function decodes the JWT access token and extracts team IDs using the - sso_jwt_handler, then sets the team_ids attribute on the result object. + This function decodes the JWT access token and extracts team IDs and user + role, then sets them on the result object. Role extraction from the access + token is needed because some SSO providers (e.g., Keycloak) do not include + role claims in the UserInfo endpoint response. Args: access_token_str: The JWT access token string sso_jwt_handler: SSO-specific JWT handler for team ID extraction - result: The SSO result object to update with team IDs + result: The SSO result object to update with team IDs and role + role_mappings: Optional role mappings configuration for group-based role determination """ - if access_token_str and sso_jwt_handler and result: + if access_token_str and result: import jwt access_token_payload = jwt.decode( access_token_str, options={"verify_signature": False} ) - # Handle both dict and object result types - if isinstance(result, dict): - result_team_ids: Optional[List[str]] = result.get("team_ids", []) - if not result_team_ids: - team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) - result["team_ids"] = team_ids - else: - result_team_ids = getattr(result, "team_ids", []) if result else [] - if not result_team_ids: - team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) - setattr(result, "team_ids", team_ids) + # Extract team IDs from access token if sso_jwt_handler is available + if sso_jwt_handler: + if isinstance(result, dict): + result_team_ids: Optional[List[str]] = result.get("team_ids", []) + if not result_team_ids: + team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) + result["team_ids"] = team_ids + else: + result_team_ids = getattr(result, "team_ids", []) if result else [] + if not result_team_ids: + team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) + setattr(result, "team_ids", team_ids) + + # Extract user role from access token if not already set from UserInfo + existing_role = result.get("user_role") if isinstance(result, dict) else getattr(result, "user_role", None) + if existing_role is None: + user_role: Optional[LitellmUserRoles] = None + + # Try role_mappings first (group-based role determination) + if role_mappings is not None and role_mappings.roles: + group_claim = role_mappings.group_claim + user_groups_raw: Any = get_nested_value(access_token_payload, group_claim) + + user_groups: List[str] = [] + if isinstance(user_groups_raw, list): + user_groups = [str(g) for g in user_groups_raw] + elif isinstance(user_groups_raw, str): + user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()] + elif user_groups_raw is not None: + user_groups = [str(user_groups_raw)] + + if user_groups: + user_role = determine_role_from_groups(user_groups, role_mappings) + verbose_proxy_logger.debug( + f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings" + ) + elif role_mappings.default_role: + user_role = role_mappings.default_role + + # Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload + if user_role is None: + generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role") + user_role_from_token = get_nested_value(access_token_payload, generic_user_role_attribute_name) + if user_role_from_token is not None: + user_role = get_litellm_user_role(user_role_from_token) + verbose_proxy_logger.debug( + f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'" + ) + + if user_role is not None: + if isinstance(result, dict): + result["user_role"] = user_role + else: + setattr(result, "user_role", user_role) + verbose_proxy_logger.debug( + f"Set user_role='{user_role}' from JWT access token" + ) @router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) @@ -688,7 +738,7 @@ def response_convertor(response, client): ) access_token_str: Optional[str] = generic_sso.access_token - process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result) + process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result, role_mappings=role_mappings) except Exception as e: verbose_proxy_logger.exception( diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 41096503a2e..16f80826798 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -25,6 +25,8 @@ MicrosoftSSOHandler, SSOAuthenticationHandler, normalize_email, + process_sso_jwt_access_token, + determine_role_from_groups, _setup_team_mappings, ) from litellm.types.proxy.management_endpoints.ui_sso import ( @@ -1298,6 +1300,7 @@ async def test_get_generic_sso_response_with_additional_headers(): # Mock the SSO provider and its methods mock_sso_instance = MagicMock() mock_sso_instance.verify_and_process = AsyncMock(return_value=mock_sso_response) + mock_sso_instance.access_token = None # Avoid triggering JWT decode in process_sso_jwt_access_token mock_sso_class = MagicMock(return_value=mock_sso_instance) @@ -1359,6 +1362,7 @@ async def test_get_generic_sso_response_with_empty_headers(): # Mock the SSO provider and its methods mock_sso_instance = MagicMock() mock_sso_instance.verify_and_process = AsyncMock(return_value=mock_sso_response) + mock_sso_instance.access_token = None # Avoid triggering JWT decode in process_sso_jwt_access_token mock_sso_class = MagicMock(return_value=mock_sso_instance) @@ -2546,22 +2550,25 @@ def test_process_sso_jwt_access_token_no_access_token(self, mock_jwt_handler): assert result.team_ids == [] def test_process_sso_jwt_access_token_no_sso_jwt_handler(self, sample_jwt_token): - """Test that nothing happens when sso_jwt_handler is None""" + """Test that JWT is decoded for role extraction even when sso_jwt_handler is None, + but team_ids are not extracted (team extraction requires sso_jwt_handler).""" from litellm.proxy.management_endpoints.ui_sso import ( process_sso_jwt_access_token, ) result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[]) - with patch("jwt.decode") as mock_jwt_decode: + mock_payload = {"sub": "test_user", "email": "test@example.com"} + with patch("jwt.decode", return_value=mock_payload) as mock_jwt_decode: # Act process_sso_jwt_access_token( access_token_str=sample_jwt_token, sso_jwt_handler=None, result=result ) - # Assert nothing was processed - mock_jwt_decode.assert_not_called() + # JWT is decoded (for role extraction) but team_ids are not extracted + mock_jwt_decode.assert_called_once() assert result.team_ids == [] + assert result.user_role is None def test_process_sso_jwt_access_token_no_result( self, mock_jwt_handler, sample_jwt_token @@ -3848,3 +3855,219 @@ async def test_setup_team_mappings(): mock_prisma.db.litellm_ssoconfig.find_unique.assert_called_once_with( where={"id": "sso_config"} ) + + +# ============================================================================ +# Tests for get_litellm_user_role with list inputs (Keycloak returns lists) +# ============================================================================ + + +def test_get_litellm_user_role_with_string(): + """Test that get_litellm_user_role works with a plain string.""" + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role("proxy_admin") + assert result == LitellmUserRoles.PROXY_ADMIN + + +def test_get_litellm_user_role_with_list(): + """ + Test that get_litellm_user_role handles list inputs. + Keycloak returns roles as arrays like ["proxy_admin"] instead of strings. + """ + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role(["proxy_admin"]) + assert result == LitellmUserRoles.PROXY_ADMIN + + +def test_get_litellm_user_role_with_empty_list(): + """Test that get_litellm_user_role returns None for empty lists.""" + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role([]) + assert result is None + + +def test_get_litellm_user_role_with_invalid_role(): + """Test that get_litellm_user_role returns None for invalid roles.""" + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role("not_a_real_role") + assert result is None + + +def test_get_litellm_user_role_with_list_multiple_roles(): + """Test that get_litellm_user_role takes the first element from a multi-element list.""" + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role(["proxy_admin", "internal_user"]) + assert result == LitellmUserRoles.PROXY_ADMIN + + +# ============================================================================ +# Tests for process_sso_jwt_access_token role extraction +# ============================================================================ + + +def test_process_sso_jwt_access_token_extracts_role_from_access_token(): + """ + Test that process_sso_jwt_access_token extracts user role from the JWT + access token when the UserInfo response did not include it. + + This is the core fix for the Keycloak SSO role mapping bug: Keycloak's + UserInfo endpoint does not return role claims, but the JWT access token + contains them. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + # Create a JWT access token with role claims (as Keycloak would) + access_token_payload = { + "sub": "user-123", + "email": "admin@test.com", + "litellm_role": ["proxy_admin"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + # Result object with no role set (simulating UserInfo response without roles) + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + # Call with GENERIC_USER_ROLE_ATTRIBUTE pointing to litellm_role + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "litellm_role"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_does_not_override_existing_role(): + """ + Test that process_sso_jwt_access_token does NOT override a role that was + already extracted from the UserInfo response. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + access_token_payload = { + "sub": "user-123", + "litellm_role": ["internal_user"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + # Result already has a role (e.g., set from UserInfo) + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "litellm_role"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + # Should keep the original role + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_extracts_role_from_nested_field(): + """ + Test role extraction from a nested JWT field like resource_access.client.roles. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + access_token_payload = { + "sub": "user-123", + "resource_access": { + "my-client": { + "roles": ["proxy_admin"] + } + }, + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "resource_access.my-client.roles"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_with_role_mappings(): + """ + Test role extraction using role_mappings (group-based role determination) + from the JWT access token. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings + + access_token_payload = { + "sub": "user-123", + "groups": ["keycloak-admins", "developers"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + role_mappings = RoleMappings( + provider="generic", + group_claim="groups", + default_role=LitellmUserRoles.INTERNAL_USER, + roles={ + LitellmUserRoles.PROXY_ADMIN: ["keycloak-admins"], + LitellmUserRoles.INTERNAL_USER: ["developers"], + }, + ) + + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=role_mappings, + ) + + # Should get highest privilege role + assert result.user_role == LitellmUserRoles.PROXY_ADMIN