diff --git a/litellm/proxy/management_endpoints/types.py b/litellm/proxy/management_endpoints/types.py index ad2ad0a5fe5..a35fc4a5f3f 100644 --- a/litellm/proxy/management_endpoints/types.py +++ b/litellm/proxy/management_endpoints/types.py @@ -28,17 +28,24 @@ def is_valid_litellm_user_role(role_str: str) -> bool: return False -def get_litellm_user_role(role_str: str) -> Optional[LitellmUserRoles]: +def get_litellm_user_role(role_str) -> Optional[LitellmUserRoles]: """ - Convert a string to a LitellmUserRoles enum if valid (case-insensitive). + Convert a string (or list of strings) to a LitellmUserRoles enum if valid (case-insensitive). + + Handles list inputs since some SSO providers (e.g., Keycloak) return roles + as arrays like ["proxy_admin"] instead of plain strings. Args: - role_str: String to convert (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user") + role_str: String or list to convert (e.g., "proxy_admin", ["proxy_admin"]) Returns: LitellmUserRoles enum if valid, None otherwise """ try: + if isinstance(role_str, list): + if len(role_str) == 0: + return None + role_str = role_str[0] # Use _value2member_map_ for O(1) lookup, case-insensitive result = LitellmUserRoles._value2member_map_.get(role_str.lower()) return cast(Optional[LitellmUserRoles], result) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 2d248dc81f3..278f3bdaafd 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -171,36 +171,86 @@ def process_sso_jwt_access_token( access_token_str: Optional[str], sso_jwt_handler: Optional[JWTHandler], result: Union[OpenID, dict, None], + role_mappings: Optional["RoleMappings"] = None, ) -> None: """ - Process SSO JWT access token and extract team IDs if available. + Process SSO JWT access token and extract team IDs and user role if available. - This function decodes the JWT access token and extracts team IDs using the - sso_jwt_handler, then sets the team_ids attribute on the result object. + This function decodes the JWT access token and extracts team IDs and user + role, then sets them on the result object. Role extraction from the access + token is needed because some SSO providers (e.g., Keycloak) do not include + role claims in the UserInfo endpoint response. Args: access_token_str: The JWT access token string sso_jwt_handler: SSO-specific JWT handler for team ID extraction - result: The SSO result object to update with team IDs + result: The SSO result object to update with team IDs and role + role_mappings: Optional role mappings configuration for group-based role determination """ - if access_token_str and sso_jwt_handler and result: + if access_token_str and result: import jwt access_token_payload = jwt.decode( access_token_str, options={"verify_signature": False} ) - # Handle both dict and object result types - if isinstance(result, dict): - result_team_ids: Optional[List[str]] = result.get("team_ids", []) - if not result_team_ids: - team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) - result["team_ids"] = team_ids - else: - result_team_ids = getattr(result, "team_ids", []) if result else [] - if not result_team_ids: - team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) - setattr(result, "team_ids", team_ids) + # Extract team IDs from access token if sso_jwt_handler is available + if sso_jwt_handler: + if isinstance(result, dict): + result_team_ids: Optional[List[str]] = result.get("team_ids", []) + if not result_team_ids: + team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) + result["team_ids"] = team_ids + else: + result_team_ids = getattr(result, "team_ids", []) if result else [] + if not result_team_ids: + team_ids = sso_jwt_handler.get_team_ids_from_jwt(access_token_payload) + setattr(result, "team_ids", team_ids) + + # Extract user role from access token if not already set from UserInfo + existing_role = result.get("user_role") if isinstance(result, dict) else getattr(result, "user_role", None) + if existing_role is None: + user_role: Optional[LitellmUserRoles] = None + + # Try role_mappings first (group-based role determination) + if role_mappings is not None and role_mappings.roles: + group_claim = role_mappings.group_claim + user_groups_raw: Any = get_nested_value(access_token_payload, group_claim) + + user_groups: List[str] = [] + if isinstance(user_groups_raw, list): + user_groups = [str(g) for g in user_groups_raw] + elif isinstance(user_groups_raw, str): + user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()] + elif user_groups_raw is not None: + user_groups = [str(user_groups_raw)] + + if user_groups: + user_role = determine_role_from_groups(user_groups, role_mappings) + verbose_proxy_logger.debug( + f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings" + ) + elif role_mappings.default_role: + user_role = role_mappings.default_role + + # Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload + if user_role is None: + generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role") + user_role_from_token = get_nested_value(access_token_payload, generic_user_role_attribute_name) + if user_role_from_token is not None: + user_role = get_litellm_user_role(user_role_from_token) + verbose_proxy_logger.debug( + f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'" + ) + + if user_role is not None: + if isinstance(result, dict): + result["user_role"] = user_role + else: + setattr(result, "user_role", user_role) + verbose_proxy_logger.debug( + f"Set user_role='{user_role}' from JWT access token" + ) @router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) @@ -688,7 +738,7 @@ def response_convertor(response, client): ) access_token_str: Optional[str] = generic_sso.access_token - process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result) + process_sso_jwt_access_token(access_token_str, sso_jwt_handler, result, role_mappings=role_mappings) except Exception as e: verbose_proxy_logger.exception( diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 41096503a2e..16f80826798 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -25,6 +25,8 @@ MicrosoftSSOHandler, SSOAuthenticationHandler, normalize_email, + process_sso_jwt_access_token, + determine_role_from_groups, _setup_team_mappings, ) from litellm.types.proxy.management_endpoints.ui_sso import ( @@ -1298,6 +1300,7 @@ async def test_get_generic_sso_response_with_additional_headers(): # Mock the SSO provider and its methods mock_sso_instance = MagicMock() mock_sso_instance.verify_and_process = AsyncMock(return_value=mock_sso_response) + mock_sso_instance.access_token = None # Avoid triggering JWT decode in process_sso_jwt_access_token mock_sso_class = MagicMock(return_value=mock_sso_instance) @@ -1359,6 +1362,7 @@ async def test_get_generic_sso_response_with_empty_headers(): # Mock the SSO provider and its methods mock_sso_instance = MagicMock() mock_sso_instance.verify_and_process = AsyncMock(return_value=mock_sso_response) + mock_sso_instance.access_token = None # Avoid triggering JWT decode in process_sso_jwt_access_token mock_sso_class = MagicMock(return_value=mock_sso_instance) @@ -2546,22 +2550,25 @@ def test_process_sso_jwt_access_token_no_access_token(self, mock_jwt_handler): assert result.team_ids == [] def test_process_sso_jwt_access_token_no_sso_jwt_handler(self, sample_jwt_token): - """Test that nothing happens when sso_jwt_handler is None""" + """Test that JWT is decoded for role extraction even when sso_jwt_handler is None, + but team_ids are not extracted (team extraction requires sso_jwt_handler).""" from litellm.proxy.management_endpoints.ui_sso import ( process_sso_jwt_access_token, ) result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[]) - with patch("jwt.decode") as mock_jwt_decode: + mock_payload = {"sub": "test_user", "email": "test@example.com"} + with patch("jwt.decode", return_value=mock_payload) as mock_jwt_decode: # Act process_sso_jwt_access_token( access_token_str=sample_jwt_token, sso_jwt_handler=None, result=result ) - # Assert nothing was processed - mock_jwt_decode.assert_not_called() + # JWT is decoded (for role extraction) but team_ids are not extracted + mock_jwt_decode.assert_called_once() assert result.team_ids == [] + assert result.user_role is None def test_process_sso_jwt_access_token_no_result( self, mock_jwt_handler, sample_jwt_token @@ -3848,3 +3855,219 @@ async def test_setup_team_mappings(): mock_prisma.db.litellm_ssoconfig.find_unique.assert_called_once_with( where={"id": "sso_config"} ) + + +# ============================================================================ +# Tests for get_litellm_user_role with list inputs (Keycloak returns lists) +# ============================================================================ + + +def test_get_litellm_user_role_with_string(): + """Test that get_litellm_user_role works with a plain string.""" + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role("proxy_admin") + assert result == LitellmUserRoles.PROXY_ADMIN + + +def test_get_litellm_user_role_with_list(): + """ + Test that get_litellm_user_role handles list inputs. + Keycloak returns roles as arrays like ["proxy_admin"] instead of strings. + """ + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role(["proxy_admin"]) + assert result == LitellmUserRoles.PROXY_ADMIN + + +def test_get_litellm_user_role_with_empty_list(): + """Test that get_litellm_user_role returns None for empty lists.""" + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role([]) + assert result is None + + +def test_get_litellm_user_role_with_invalid_role(): + """Test that get_litellm_user_role returns None for invalid roles.""" + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role("not_a_real_role") + assert result is None + + +def test_get_litellm_user_role_with_list_multiple_roles(): + """Test that get_litellm_user_role takes the first element from a multi-element list.""" + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.management_endpoints.types import get_litellm_user_role + + result = get_litellm_user_role(["proxy_admin", "internal_user"]) + assert result == LitellmUserRoles.PROXY_ADMIN + + +# ============================================================================ +# Tests for process_sso_jwt_access_token role extraction +# ============================================================================ + + +def test_process_sso_jwt_access_token_extracts_role_from_access_token(): + """ + Test that process_sso_jwt_access_token extracts user role from the JWT + access token when the UserInfo response did not include it. + + This is the core fix for the Keycloak SSO role mapping bug: Keycloak's + UserInfo endpoint does not return role claims, but the JWT access token + contains them. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + # Create a JWT access token with role claims (as Keycloak would) + access_token_payload = { + "sub": "user-123", + "email": "admin@test.com", + "litellm_role": ["proxy_admin"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + # Result object with no role set (simulating UserInfo response without roles) + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + # Call with GENERIC_USER_ROLE_ATTRIBUTE pointing to litellm_role + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "litellm_role"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_does_not_override_existing_role(): + """ + Test that process_sso_jwt_access_token does NOT override a role that was + already extracted from the UserInfo response. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + access_token_payload = { + "sub": "user-123", + "litellm_role": ["internal_user"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + # Result already has a role (e.g., set from UserInfo) + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "litellm_role"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + # Should keep the original role + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_extracts_role_from_nested_field(): + """ + Test role extraction from a nested JWT field like resource_access.client.roles. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + + access_token_payload = { + "sub": "user-123", + "resource_access": { + "my-client": { + "roles": ["proxy_admin"] + } + }, + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + with patch.dict(os.environ, {"GENERIC_USER_ROLE_ATTRIBUTE": "resource_access.my-client.roles"}): + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + +def test_process_sso_jwt_access_token_with_role_mappings(): + """ + Test role extraction using role_mappings (group-based role determination) + from the JWT access token. + """ + import jwt as pyjwt + + from litellm.proxy._types import LitellmUserRoles + from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings + + access_token_payload = { + "sub": "user-123", + "groups": ["keycloak-admins", "developers"], + } + access_token_str = pyjwt.encode(access_token_payload, "secret", algorithm="HS256") + + result = CustomOpenID( + id="user-123", + email="admin@test.com", + display_name="Admin User", + team_ids=[], + user_role=None, + ) + + role_mappings = RoleMappings( + provider="generic", + group_claim="groups", + default_role=LitellmUserRoles.INTERNAL_USER, + roles={ + LitellmUserRoles.PROXY_ADMIN: ["keycloak-admins"], + LitellmUserRoles.INTERNAL_USER: ["developers"], + }, + ) + + process_sso_jwt_access_token( + access_token_str=access_token_str, + sso_jwt_handler=None, + result=result, + role_mappings=role_mappings, + ) + + # Should get highest privilege role + assert result.user_role == LitellmUserRoles.PROXY_ADMIN