diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 278f3bdaafd..7b3ad0423ac 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -190,9 +190,15 @@ def process_sso_jwt_access_token( if access_token_str and result: import jwt - access_token_payload = jwt.decode( - access_token_str, options={"verify_signature": False} - ) + try: + access_token_payload = jwt.decode( + access_token_str, options={"verify_signature": False} + ) + except jwt.exceptions.DecodeError: + verbose_proxy_logger.debug( + "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction" + ) + return # Extract team IDs from access token if sso_jwt_handler is available if sso_jwt_handler: diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 16f80826798..ef84699a8fa 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -2374,47 +2374,6 @@ def sample_jwt_payload(self): "groups": ["team1", "team2", "team3"], } - def test_process_sso_jwt_access_token_with_valid_token( - self, mock_jwt_handler, sample_jwt_token, sample_jwt_payload - ): - """Test processing a valid JWT access token with team extraction""" - from litellm.proxy.management_endpoints.ui_sso import ( - process_sso_jwt_access_token, - ) - - # Create a result object without team_ids - result = CustomOpenID( - id="test_user", - email="test@example.com", - first_name="Test", - last_name="User", - display_name="Test User", - provider="generic", - team_ids=[], - ) - - with patch("jwt.decode", return_value=sample_jwt_payload) as mock_jwt_decode: - # Act - process_sso_jwt_access_token( - access_token_str=sample_jwt_token, - sso_jwt_handler=mock_jwt_handler, - result=result, - ) - - # Assert - # Verify JWT was decoded correctly - mock_jwt_decode.assert_called_once_with( - sample_jwt_token, options={"verify_signature": False} - ) - - # Verify team IDs were extracted from JWT - mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with( - sample_jwt_payload - ) - - # Verify team IDs were set on the result object - assert result.team_ids == ["team1", "team2", "team3"] - def test_process_sso_jwt_access_token_with_existing_team_ids( self, mock_jwt_handler, sample_jwt_token ): @@ -2549,27 +2508,6 @@ def test_process_sso_jwt_access_token_no_access_token(self, mock_jwt_handler): mock_jwt_handler.get_team_ids_from_jwt.assert_not_called() assert result.team_ids == [] - def test_process_sso_jwt_access_token_no_sso_jwt_handler(self, sample_jwt_token): - """Test that JWT is decoded for role extraction even when sso_jwt_handler is None, - but team_ids are not extracted (team extraction requires sso_jwt_handler).""" - from litellm.proxy.management_endpoints.ui_sso import ( - process_sso_jwt_access_token, - ) - - result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[]) - - mock_payload = {"sub": "test_user", "email": "test@example.com"} - with patch("jwt.decode", return_value=mock_payload) as mock_jwt_decode: - # Act - process_sso_jwt_access_token( - access_token_str=sample_jwt_token, sso_jwt_handler=None, result=result - ) - - # JWT is decoded (for role extraction) but team_ids are not extracted - mock_jwt_decode.assert_called_once() - assert result.team_ids == [] - assert result.user_role is None - def test_process_sso_jwt_access_token_no_result( self, mock_jwt_handler, sample_jwt_token ): @@ -2590,10 +2528,12 @@ def test_process_sso_jwt_access_token_no_result( mock_jwt_decode.assert_not_called() mock_jwt_handler.get_team_ids_from_jwt.assert_not_called() - def test_process_sso_jwt_access_token_jwt_decode_exception( + def test_process_sso_jwt_access_token_non_decode_exception_propagates( self, mock_jwt_handler, sample_jwt_token ): - """Test that JWT decode exceptions are not caught (should propagate up)""" + """Test that non-DecodeError JWT exceptions still propagate up.""" + import jwt as pyjwt + from litellm.proxy.management_endpoints.ui_sso import ( process_sso_jwt_access_token, ) @@ -2601,19 +2541,16 @@ def test_process_sso_jwt_access_token_jwt_decode_exception( result = CustomOpenID(id="test_user", email="test@example.com", team_ids=[]) with patch( - "jwt.decode", side_effect=Exception("JWT decode error") + "jwt.decode", side_effect=pyjwt.exceptions.InvalidKeyError("Invalid key") ) as mock_jwt_decode: - # Act & Assert - with pytest.raises(Exception, match="JWT decode error"): + with pytest.raises(pyjwt.exceptions.InvalidKeyError, match="Invalid key"): process_sso_jwt_access_token( access_token_str=sample_jwt_token, sso_jwt_handler=mock_jwt_handler, result=result, ) - # Verify JWT decode was attempted mock_jwt_decode.assert_called_once() - # But team extraction should not have been called mock_jwt_handler.get_team_ids_from_jwt.assert_not_called() def test_process_sso_jwt_access_token_empty_team_ids_from_jwt( @@ -2646,6 +2583,124 @@ def test_process_sso_jwt_access_token_empty_team_ids_from_jwt( # Even empty team IDs should be set assert result.team_ids == [] + def test_process_sso_jwt_access_token_with_opaque_token(self, mock_jwt_handler): + """Test that opaque (non-JWT) access tokens are handled gracefully without raising.""" + from litellm.proxy.management_endpoints.ui_sso import ( + process_sso_jwt_access_token, + ) + + result = CustomOpenID( + id="test_user", + email="test@example.com", + first_name="Test", + last_name="User", + display_name="Test User", + provider="generic", + team_ids=["existing_team"], + user_role=None, + ) + + # Opaque tokens like those from Logto are short random strings, not JWTs + opaque_token = "uTxyjXbS_random_opaque_token_string" + + # Should NOT raise - opaque tokens should be silently skipped + process_sso_jwt_access_token( + access_token_str=opaque_token, + sso_jwt_handler=mock_jwt_handler, + result=result, + ) + + # Result should be untouched + mock_jwt_handler.get_team_ids_from_jwt.assert_not_called() + assert result.team_ids == ["existing_team"] + assert result.user_role is None + + def test_process_sso_jwt_access_token_real_jwt_with_role_and_teams( + self, mock_jwt_handler + ): + """Test that a real JWT containing role and team fields is correctly processed.""" + import jwt as pyjwt + + from litellm.proxy.management_endpoints.ui_sso import ( + process_sso_jwt_access_token, + ) + + payload = { + "sub": "user123", + "email": "admin@example.com", + "role": "proxy_admin", + "groups": ["team_alpha", "team_beta"], + } + real_jwt_token = pyjwt.encode(payload, "test-secret", algorithm="HS256") + + mock_jwt_handler.get_team_ids_from_jwt.return_value = [ + "team_alpha", + "team_beta", + ] + + result = CustomOpenID( + id="user123", + email="admin@example.com", + first_name="Admin", + last_name="User", + display_name="Admin User", + provider="generic", + team_ids=[], + user_role=None, + ) + + process_sso_jwt_access_token( + access_token_str=real_jwt_token, + sso_jwt_handler=mock_jwt_handler, + result=result, + ) + + # Team IDs should be extracted via sso_jwt_handler + mock_jwt_handler.get_team_ids_from_jwt.assert_called_once_with(payload) + assert result.team_ids == ["team_alpha", "team_beta"] + + # Role should be extracted from the "role" field in the JWT + from litellm.proxy._types import LitellmUserRoles + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + + def test_process_sso_jwt_access_token_real_jwt_without_role_and_teams(self): + """Test that a real JWT without role/team fields leaves result unchanged.""" + import jwt as pyjwt + + from litellm.proxy.management_endpoints.ui_sso import ( + process_sso_jwt_access_token, + ) + + payload = { + "sub": "user456", + "email": "plain@example.com", + "iat": 1700000000, + } + real_jwt_token = pyjwt.encode(payload, "test-secret", algorithm="HS256") + + result = CustomOpenID( + id="user456", + email="plain@example.com", + first_name="Plain", + last_name="User", + display_name="Plain User", + provider="generic", + team_ids=[], + user_role=None, + ) + + # No sso_jwt_handler, no role/team fields in JWT + process_sso_jwt_access_token( + access_token_str=real_jwt_token, + sso_jwt_handler=None, + result=result, + ) + + # Nothing should be modified + assert result.team_ids == [] + assert result.user_role is None + @pytest.mark.asyncio async def test_get_ui_settings_includes_api_doc_base_url():