diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 2d248dc81f3..1dc239e50a4 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -14,7 +14,7 @@ import os import secrets from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse @@ -82,7 +82,15 @@ get_server_root_path, ) from litellm.secret_managers.main import get_secret_bool, str_to_bool -from litellm.types.proxy.management_endpoints.ui_sso import * +from litellm.types.proxy.management_endpoints.ui_sso import ( + DefaultTeamSSOParams, + MicrosoftGraphAPIUserGroupDirectoryObject, + MicrosoftGraphAPIUserGroupResponse, + MicrosoftServicePrincipalTeam, + RoleMappings, + TeamMappings, +) +from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401 from litellm.types.proxy.ui_sso import ParsedOpenIDResult if TYPE_CHECKING: @@ -96,15 +104,15 @@ def normalize_email(email: Optional[str]) -> Optional[str]: """ Normalize email address to lowercase for consistent storage and comparison. - + Email addresses should be treated as case-insensitive for SSO purposes, even though RFC 5321 technically allows case-sensitive local parts. This prevents issues where SSO providers return emails with different casing than what's stored in the database. - + Args: email: Email address to normalize, can be None - + Returns: Lowercased email address, or None if input is None """ @@ -280,7 +288,7 @@ async def google_login( # check if user defined a custom auth sso sign in handler, if yes, use it if user_custom_ui_sso_sign_in_handler is not None: try: - from litellm_enterprise.proxy.auth.custom_sso_handler import ( + from litellm_enterprise.proxy.auth.custom_sso_handler import ( # type: ignore[import-untyped] EnterpriseCustomSSOHandler, ) @@ -428,7 +436,9 @@ def generic_response_convertor( display_name=get_nested_value( response, generic_user_display_name_attribute_name ), - email=normalize_email(get_nested_value(response, generic_user_email_attribute_name)), + email=normalize_email( + get_nested_value(response, generic_user_email_attribute_name) + ), first_name=get_nested_value(response, generic_user_first_name_attribute_name), last_name=get_nested_value(response, generic_user_last_name_attribute_name), provider=get_nested_value(response, generic_provider_attribute_name), @@ -517,6 +527,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]: if team_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings + if isinstance(team_mappings_data, dict): team_mappings = TeamMappings(**team_mappings_data) elif isinstance(team_mappings_data, TeamMappings): @@ -554,6 +565,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: if role_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings + if isinstance(role_mappings_data, dict): role_mappings = RoleMappings(**role_mappings_data) elif isinstance(role_mappings_data, RoleMappings): @@ -567,7 +579,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: verbose_proxy_logger.debug( f"Could not load role_mappings from database: {e}. Continuing with existing role logic." ) - + generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None) generic_role_mappings_group_claim = os.getenv( "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None @@ -577,8 +589,8 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: ) if generic_role_mappings is not None: verbose_proxy_logger.debug( - "Found role_mappings for generic provider in environment variables" - ) + "Found role_mappings for generic provider in environment variables" + ) import ast try: @@ -603,7 +615,9 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: ) return role_mappings except TypeError as e: - verbose_proxy_logger.warning(f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic.") + verbose_proxy_logger.warning( + f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic." + ) return role_mappings @@ -680,7 +694,7 @@ def response_convertor(response, client): try: result = await generic_sso.verify_and_process( request, - params=SSOAuthenticationHandler.prepare_token_exchange_parameters( + params=await SSOAuthenticationHandler.prepare_token_exchange_parameters( request=request, generic_include_client_id=generic_include_client_id, ), @@ -875,7 +889,7 @@ def _build_sso_user_update_data( Returns: dict: Update data containing user_email and optionally user_role if valid - """ + """ update_data: dict = {"user_email": normalize_email(user_email)} # Get SSO role from result and include if valid @@ -1673,7 +1687,7 @@ async def get_generic_sso_redirect_response( """ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse - from litellm.proxy.proxy_server import user_api_key_cache + from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache with generic_sso: # TODO: state should be a random string and added to the user session with cookie @@ -1702,13 +1716,21 @@ async def get_generic_sso_redirect_response( # If PKCE is enabled, add PKCE parameters to the redirect URL if code_verifier and "state" in redirect_params: - # Store code_verifier in cache (10 min TTL) + # Store code_verifier in cache (10 min TTL). Use Redis when available + # so callbacks landing on another pod can retrieve it (multi-pod SSO). cache_key = f"pkce_verifier:{redirect_params['state']}" - user_api_key_cache.set_cache( - key=cache_key, - value=code_verifier, - ttl=600, - ) + if redis_usage_cache is not None: + await redis_usage_cache.async_set_cache( + key=cache_key, + value=code_verifier, + ttl=600, + ) + else: + await user_api_key_cache.async_set_cache( + key=cache_key, + value=code_verifier, + ttl=600, + ) # Add PKCE parameters to the authorization URL if pkce_params: @@ -2305,7 +2327,7 @@ async def get_redirect_response_from_openid( # noqa: PLR0915 return redirect_response @staticmethod - def prepare_token_exchange_parameters( + async def prepare_token_exchange_parameters( request: Request, generic_include_client_id: bool, ) -> dict: @@ -2319,27 +2341,38 @@ def prepare_token_exchange_parameters( Returns: dict: Token exchange parameters """ - # Prepare token exchange parameters - token_params = {"include_client_id": generic_include_client_id} + # Prepare token exchange parameters (may add code_verifier: str later) + token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id} - # Retrieve PKCE code_verifier if PKCE was used in authorization + # Retrieve PKCE code_verifier if PKCE was used in authorization. + # Use same cache as store: Redis when available (multi-pod), else in-memory. query_params = dict(request.query_params) state = query_params.get("state") if state: - from litellm.proxy.proxy_server import user_api_key_cache + from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache cache_key = f"pkce_verifier:{state}" - code_verifier = user_api_key_cache.get_cache(key=cache_key) + if redis_usage_cache is not None: + code_verifier = await redis_usage_cache.async_get_cache(key=cache_key) + else: + code_verifier = await user_api_key_cache.async_get_cache(key=cache_key) if code_verifier: - # Add code_verifier to token exchange parameters - token_params["code_verifier"] = code_verifier + # Add code_verifier to token exchange parameters (Redis returns decoded string) + token_params["code_verifier"] = ( + code_verifier + if isinstance(code_verifier, str) + else str(code_verifier) + ) verbose_proxy_logger.debug( "PKCE code_verifier retrieved and will be included in token exchange" ) # Clean up the cache entry (single-use verifier) - user_api_key_cache.delete_cache(key=cache_key) + if redis_usage_cache is not None: + await redis_usage_cache.async_delete_cache(key=cache_key) + else: + await user_api_key_cache.async_delete_cache(key=cache_key) return token_params @staticmethod @@ -2482,7 +2515,9 @@ def openid_from_response( response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") openid_response = CustomOpenID( - email=normalize_email(response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")), + email=normalize_email( + response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail") + ), display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE), provider="microsoft", id=response.get(MICROSOFT_USER_ID_ATTRIBUTE), diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 41096503a2e..938d4e8c871 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -2,12 +2,10 @@ import json import os import sys -from typing import Optional, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import Request -from fastapi.testclient import TestClient from litellm._uuid import uuid @@ -16,7 +14,7 @@ ) # Adds the parent directory to the system path import litellm -from litellm.proxy._types import LiteLLM_UserTable, NewTeamRequest, NewUserResponse +from litellm.proxy._types import LiteLLM_UserTable, NewUserResponse from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO from litellm.proxy.management_endpoints.types import CustomOpenID @@ -134,16 +132,32 @@ def test_microsoft_sso_handler_openid_from_response_with_custom_attributes(): expected_team_ids = ["team1"] # Act - with patch("litellm.constants.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field"), \ - patch("litellm.constants.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name"), \ - patch("litellm.constants.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field"), \ - patch("litellm.constants.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name"), \ - patch("litellm.constants.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name"): + with patch( + "litellm.constants.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field" + ), patch( + "litellm.constants.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name" + ), patch( + "litellm.constants.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field" + ), patch( + "litellm.constants.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name" + ), patch( + "litellm.constants.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name" + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_EMAIL_ATTRIBUTE", + "custom_email_field", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", + "custom_display_name", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_ID_ATTRIBUTE", + "custom_id_field", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", + "custom_first_name", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", + "custom_last_name", + ): result = MicrosoftSSOHandler.openid_from_response( response=mock_response, team_ids=expected_team_ids, user_role=None ) @@ -229,7 +243,6 @@ def test_get_microsoft_callback_response_raw_sso_response(): ) # Assert - print("result from verify_and_process", result) assert isinstance(result, dict) assert result["mail"] == "microsoft_user@example.com" assert result["displayName"] == "Microsoft User" @@ -453,10 +466,6 @@ def mock_jsonify_team_object(db_data): # Assert # Verify team was created with correct parameters mock_prisma.db.litellm_teamtable.create.assert_called_once() - print( - "mock_prisma.db.litellm_teamtable.create.call_args", - mock_prisma.db.litellm_teamtable.create.call_args, - ) create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[ "data" ] @@ -581,7 +590,7 @@ def test_apply_user_info_values_to_sso_user_defined_values_with_models(): def test_apply_user_info_values_sso_role_takes_precedence(): """ Test that SSO role takes precedence over DB role. - + When Microsoft SSO returns a user_role, it should be used instead of the role stored in the database. This ensures SSO is the authoritative source for user roles. """ @@ -676,16 +685,16 @@ def test_normalize_email(): """ # Test with lowercase email assert normalize_email("test@example.com") == "test@example.com" - + # Test with uppercase email assert normalize_email("TEST@EXAMPLE.COM") == "test@example.com" - + # Test with mixed case email assert normalize_email("Test.User@Example.COM") == "test.user@example.com" - + # Test with None assert normalize_email(None) is None - + # Test with empty string assert normalize_email("") == "" @@ -898,7 +907,7 @@ async def test_upsert_sso_user_no_role_in_sso_response(): def test_get_user_email_and_id_extracts_microsoft_role(): """ Test that _get_user_email_and_id_from_result extracts user_role from Microsoft SSO. - + This ensures Microsoft SSO roles (from app_roles in id_token) are properly extracted and converted from enum to string. """ @@ -964,7 +973,7 @@ async def test_get_user_info_from_db_user_exists(): with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_object" ) as mock_get_user_object: - user_info = await get_user_info_from_db(**args) + await get_user_info_from_db(**args) mock_get_user_object.assert_called_once() assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd" @@ -1006,7 +1015,7 @@ async def test_get_user_info_from_db_user_exists_alternate_user_id(): with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_object" ) as mock_get_user_object: - user_info = await get_user_info_from_db(**args) + await get_user_info_from_db(**args) mock_get_user_object.assert_called_once() assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd-email1234" @@ -1015,7 +1024,7 @@ async def test_get_user_info_from_db_user_exists_alternate_user_id(): async def test_get_user_info_from_db_user_not_exists_creates_user(): """ Test that get_user_info_from_db creates a new user when user doesn't exist in DB. - + When get_existing_user_info_from_db returns None, get_user_info_from_db should: 1. Call upsert_sso_user with user_info=None 2. upsert_sso_user should call insert_sso_user to create the user @@ -1103,7 +1112,7 @@ async def test_get_user_info_from_db_user_not_exists_creates_user(): async def test_get_user_info_from_db_user_exists_updates_user(): """ Test that get_user_info_from_db updates existing user when user exists in DB. - + When get_existing_user_info_from_db returns a user, get_user_info_from_db should: 1. Call upsert_sso_user with the existing user_info 2. upsert_sso_user should update the user in the database @@ -1195,6 +1204,7 @@ async def test_get_user_info_from_db_user_exists_updates_user(): # Should return the updated user assert user_info == updated_user + @pytest.mark.asyncio async def test_check_and_update_if_proxy_admin_id(): """ @@ -1302,10 +1312,10 @@ async def test_get_generic_sso_response_with_additional_headers(): mock_sso_class = MagicMock(return_value=mock_sso_instance) with patch.dict(os.environ, test_env_vars): - with patch("fastapi_sso.sso.base.DiscoveryDocument") as mock_discovery: + with patch("fastapi_sso.sso.base.DiscoveryDocument"): with patch( "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class - ) as mock_create_provider: + ): # Act result, received_response = await get_generic_sso_response( request=mock_request, @@ -1363,10 +1373,10 @@ async def test_get_generic_sso_response_with_empty_headers(): mock_sso_class = MagicMock(return_value=mock_sso_instance) with patch.dict(os.environ, test_env_vars): - with patch("fastapi_sso.sso.base.DiscoveryDocument") as mock_discovery: + with patch("fastapi_sso.sso.base.DiscoveryDocument"): with patch( "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class - ) as mock_create_provider: + ): # Act result, received_response = await get_generic_sso_response( request=mock_request, @@ -1751,8 +1761,6 @@ def test_enterprise_import_error_handling(self): """Test that proper error is raised when enterprise module is not available""" from unittest.mock import MagicMock, patch - from litellm.proxy.management_endpoints.ui_sso import google_login - # Mock request mock_request = MagicMock() mock_request.base_url = "https://test.example.com/" @@ -1774,7 +1782,7 @@ async def mock_google_login(): # This mimics the relevant part of google_login that would trigger the import error try: from enterprise.litellm_enterprise.proxy.auth.custom_sso_handler import ( - EnterpriseCustomSSOHandler, + EnterpriseCustomSSOHandler, # noqa: F401 ) return "success" @@ -1978,59 +1986,56 @@ async def test_cli_sso_callback_stores_session(self): # Test data session_key = "sk-session-456" - + # Mock user info mock_user_info = LiteLLM_UserTable( user_id="test-user-123", user_role="internal_user", teams=["team1", "team2"], - models=["gpt-4"] + models=["gpt-4"], ) # Mock SSO result - mock_sso_result = { - "user_email": "test@example.com", - "user_id": "test-user-123" - } + mock_sso_result = {"user_email": "test@example.com", "user_id": "test-user-123"} # Mock cache mock_cache = MagicMock() - + with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db", - return_value=mock_user_info - ), patch( - "litellm.proxy.proxy_server.prisma_client", MagicMock() - ), patch( + return_value=mock_user_info, + ), patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), patch( "litellm.proxy.proxy_server.user_api_key_cache", mock_cache ), patch( "litellm.proxy.common_utils.html_forms.cli_sso_success.render_cli_sso_success_page", return_value="Success", ): - # Act result = await cli_sso_callback( - request=mock_request, key=session_key, existing_key=None, result=mock_sso_result + request=mock_request, + key=session_key, + existing_key=None, + result=mock_sso_result, ) # Assert - verify session was stored in cache mock_cache.set_cache.assert_called_once() call_args = mock_cache.set_cache.call_args - + # Verify cache key format assert "cli_sso_session:" in call_args.kwargs["key"] assert session_key in call_args.kwargs["key"] - + # Verify session data structure session_data = call_args.kwargs["value"] assert session_data["user_id"] == "test-user-123" assert session_data["user_role"] == "internal_user" assert session_data["teams"] == ["team1", "team2"] assert session_data["models"] == ["gpt-4"] - + # Verify TTL assert call_args.kwargs["ttl"] == 600 # 10 minutes - + assert result.status_code == 200 # Verify response contains success message (response is HTML) assert result.body is not None @@ -2046,17 +2051,14 @@ async def test_cli_poll_key_returns_teams_for_selection(self): "user_id": "test-user-456", "user_role": "internal_user", "teams": ["team-a", "team-b", "team-c"], - "models": ["gpt-4"] + "models": ["gpt-4"], } # Mock cache mock_cache = MagicMock() mock_cache.get_cache.return_value = session_data - - with patch( - "litellm.proxy.proxy_server.user_api_key_cache", mock_cache - ): + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act - First poll without team_id result = await cli_poll_key(key_id=session_key, team_id=None) @@ -2066,7 +2068,7 @@ async def test_cli_poll_key_returns_teams_for_selection(self): assert result["user_id"] == "test-user-456" assert result["teams"] == ["team-a", "team-b", "team-c"] assert "key" not in result # JWT should not be generated yet - + # Verify session was NOT deleted mock_cache.delete_cache.assert_not_called() @@ -2170,34 +2172,33 @@ async def test_cli_poll_key_generates_jwt_with_team(self): "user_role": "internal_user", "teams": ["team-a", "team-b", "team-c"], "models": ["gpt-4"], - "user_email": "test@example.com" + "user_email": "test@example.com", } - + # Mock user info mock_user_info = LiteLLM_UserTable( user_id="test-user-789", user_role="internal_user", teams=["team-a", "team-b", "team-c"], - models=["gpt-4"] + models=["gpt-4"], ) # Mock cache mock_cache = MagicMock() mock_cache.get_cache.return_value = session_data - + mock_jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test.token" - - with patch( - "litellm.proxy.proxy_server.user_api_key_cache", mock_cache - ), patch( + + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache), patch( "litellm.proxy.proxy_server.prisma_client" ) as mock_prisma, patch( "litellm.proxy.auth.auth_checks.ExperimentalUIJWTToken.get_cli_jwt_auth_token", - return_value=mock_jwt_token + return_value=mock_jwt_token, ) as mock_get_jwt: - # Mock the user lookup - mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=mock_user_info) + mock_prisma.db.litellm_usertable.find_unique = AsyncMock( + return_value=mock_user_info + ) # Act - Second poll with team_id result = await cli_poll_key(key_id=session_key, team_id=selected_team) @@ -2208,12 +2209,12 @@ async def test_cli_poll_key_generates_jwt_with_team(self): assert result["user_id"] == "test-user-789" assert result["team_id"] == selected_team assert result["teams"] == ["team-a", "team-b", "team-c"] - + # Verify JWT was generated with correct team mock_get_jwt.assert_called_once() jwt_call_args = mock_get_jwt.call_args assert jwt_call_args.kwargs["team_id"] == selected_team - + # Verify session was deleted after JWT generation mock_cache.delete_cache.assert_called_once() @@ -2223,7 +2224,6 @@ class TestGetAppRolesFromIdToken: def test_roles_picked_when_app_roles_not_exists(self): """Test that 'roles' is picked when 'app_roles' doesn't exist""" - import jwt # Create a token with only 'roles' claim token_payload = { @@ -2247,7 +2247,6 @@ def test_roles_picked_when_app_roles_not_exists(self): def test_app_roles_picked_when_both_exist(self): """Test that 'app_roles' takes precedence when both 'app_roles' and 'roles' exist""" - import jwt # Create a token with both 'app_roles' and 'roles' claims token_payload = { @@ -2268,7 +2267,6 @@ def test_app_roles_picked_when_both_exist(self): def test_roles_picked_when_app_roles_is_empty(self): """Test that 'roles' is picked when 'app_roles' exists but is empty""" - import jwt # Create a token with empty 'app_roles' and populated 'roles' token_payload = { @@ -2289,7 +2287,6 @@ def test_roles_picked_when_app_roles_is_empty(self): def test_empty_list_when_neither_exists(self): """Test that empty list is returned when neither 'app_roles' nor 'roles' exist""" - import jwt # Create a token without roles claims token_payload = {"sub": "user123", "email": "test@example.com"} @@ -2313,7 +2310,6 @@ def test_empty_list_when_no_token_provided(self): def test_empty_list_when_roles_not_a_list(self): """Test that empty list is returned when roles is not a list""" - import jwt # Create a token with non-list roles token_payload = { @@ -2333,7 +2329,6 @@ def test_empty_list_when_roles_not_a_list(self): def test_error_handling_on_jwt_decode_exception(self): """Test that exceptions during JWT decode are handled gracefully""" - import jwt mock_token = "invalid.jwt.token" @@ -2726,12 +2721,6 @@ def test_generic_response_convertor_with_nested_attributes(self): # to handle dotted paths like "attributes.userId" # Current behavior: returns None for nested paths - print(f"User ID result: {result.id}") - print(f"Email result: {result.email}") - print(f"First name result: {result.first_name}") - print(f"Last name result: {result.last_name}") - print(f"Display name result: {result.display_name}") - # Expected behavior with current implementation (no nested path support): assert result.id == "nested-user-456" assert ( @@ -2821,14 +2810,15 @@ def test_state_priority_cli_state_provided(self): # Arrange cli_state = "litellm-session-token:sk-test123" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": "env_state_value"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=cli_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=cli_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2843,14 +2833,15 @@ def test_state_priority_env_variable_when_no_cli_state(self): # Arrange env_state = "custom_env_state_value" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": env_state}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=None, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=None, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2867,13 +2858,14 @@ def test_state_priority_generated_uuid_fallback(self): with patch.dict(os.environ, {}, clear=False): # Remove GENERIC_CLIENT_STATE if it exists os.environ.pop("GENERIC_CLIENT_STATE", None) - + # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=None, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=None, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2893,26 +2885,27 @@ def test_state_with_pkce_enabled(self): # Arrange test_state = "test_state_123" - + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=test_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=test_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert state assert redirect_params["state"] == test_state - + # Assert PKCE parameters assert code_verifier is not None assert len(code_verifier) == 43 # Standard PKCE verifier length assert "code_challenge" in redirect_params assert "code_challenge_method" in redirect_params assert redirect_params["code_challenge_method"] == "S256" - + # Verify code_challenge is correctly derived from code_verifier expected_challenge_bytes = hashlib.sha256( code_verifier.encode("utf-8") @@ -2932,14 +2925,15 @@ def test_state_with_pkce_disabled(self): # Arrange test_state = "test_state_456" - + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "false"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=test_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=test_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2957,7 +2951,7 @@ def test_state_priority_cli_state_overrides_env_with_pkce(self): # Arrange cli_state = "cli_state_priority" env_state = "env_state_should_not_be_used" - + with patch.dict( os.environ, { @@ -2966,17 +2960,18 @@ def test_state_priority_cli_state_overrides_env_with_pkce(self): }, ): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=cli_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=cli_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert assert redirect_params["state"] == cli_state # CLI state takes priority assert redirect_params["state"] != env_state - + # PKCE should still be generated assert code_verifier is not None assert "code_challenge" in redirect_params @@ -2990,14 +2985,15 @@ def test_empty_string_state_uses_env_variable(self): # Arrange env_state = "env_state_for_empty_cli" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": env_state}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state="", # Empty string - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state="", # Empty string + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert - empty string is falsy, so env variable should be used @@ -3014,7 +3010,7 @@ def test_multiple_calls_generate_different_uuids(self): # Arrange - no state provided with patch.dict(os.environ, {}, clear=False): os.environ.pop("GENERIC_CLIENT_STATE", None) - + # Act params1, _ = SSOAuthenticationHandler._get_generic_sso_redirect_params( state=None, @@ -3077,15 +3073,18 @@ async def test_prepare_token_exchange_parameters_with_pkce(self): test_state = "test_oauth_state_123" mock_request.query_params = {"state": test_state} - # Mock cache + # Mock cache with async methods mock_cache = MagicMock() test_code_verifier = "test_code_verifier_abc123xyz" - mock_cache.get_cache.return_value = test_code_verifier + mock_cache.async_get_cache = AsyncMock(return_value=test_code_verifier) + mock_cache.async_delete_cache = AsyncMock() - with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): + with patch("litellm.proxy.proxy_server.redis_usage_cache", None), patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act - token_params = SSOAuthenticationHandler.prepare_token_exchange_parameters( - request=mock_request, generic_include_client_id=False + token_params = ( + await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) ) # Assert @@ -3093,10 +3092,10 @@ async def test_prepare_token_exchange_parameters_with_pkce(self): assert token_params["code_verifier"] == test_code_verifier # Verify cache was accessed and deleted - mock_cache.get_cache.assert_called_once_with( + mock_cache.async_get_cache.assert_called_once_with( key=f"pkce_verifier:{test_state}" ) - mock_cache.delete_cache.assert_called_once_with( + mock_cache.async_delete_cache.assert_called_once_with( key=f"pkce_verifier:{test_state}" ) @@ -3121,6 +3120,8 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): test_state = "test456" mock_cache = MagicMock() + mock_cache.async_set_cache = AsyncMock() + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act @@ -3131,9 +3132,9 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): ) # Assert - # Verify cache was called to store code_verifier - mock_cache.set_cache.assert_called_once() - cache_call = mock_cache.set_cache.call_args + # Verify async cache was called to store code_verifier + mock_cache.async_set_cache.assert_called_once() + cache_call = mock_cache.async_set_cache.call_args assert cache_call.kwargs["key"] == f"pkce_verifier:{test_state}" assert cache_call.kwargs["ttl"] == 600 assert len(cache_call.kwargs["value"]) == 43 @@ -3145,6 +3146,178 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): assert "code_challenge_method=S256" in updated_location assert f"state={test_state}" in updated_location + @pytest.mark.asyncio + async def test_pkce_redis_multi_pod_verifier_roundtrip(self): + """ + Mock Redis to verify PKCE code_verifier round-trip across "pods": + Pod A stores verifier in Redis; Pod B retrieves it (no real IdP). + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + # In-memory mock of Redis (shared between "pods") + class MockRedisCache: + def __init__(self): + self._store = {} + + async def async_set_cache(self, key, value, **kwargs): + self._store[key] = json.dumps(value) + + async def async_get_cache(self, key, **kwargs): + val = self._store.get(key) + if val is None: + return None + # Simulate RedisCache._get_cache_logic: stored as JSON string, return decoded + if isinstance(val, str): + try: + return json.loads(val) + except (ValueError, TypeError): + return val + return val + + async def async_delete_cache(self, key): + self._store.pop(key, None) + + mock_redis = MockRedisCache() + mock_in_memory = MagicMock() + + mock_sso = MagicMock() + mock_redirect_response = MagicMock() + mock_redirect_response.headers = { + "location": "https://auth.example.com/authorize?state=multi_pod_state_xyz&client_id=abc" + } + mock_sso.get_login_redirect = AsyncMock(return_value=mock_redirect_response) + mock_sso.__enter__ = MagicMock(return_value=mock_sso) + mock_sso.__exit__ = MagicMock(return_value=False) + + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): + with patch("litellm.proxy.proxy_server.redis_usage_cache", mock_redis): + with patch( + "litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory + ): + # Pod A: start login, store code_verifier in "Redis" + await SSOAuthenticationHandler.get_generic_sso_redirect_response( + generic_sso=mock_sso, + state="multi_pod_state_xyz", + generic_authorization_endpoint="https://auth.example.com/authorize", + ) + mock_in_memory.async_set_cache.assert_not_called() + # MockRedisCache is a real class; assert on state, not .assert_called_* + stored_key = "pkce_verifier:multi_pod_state_xyz" + assert stored_key in mock_redis._store + stored_value = mock_redis._store[stored_key] + assert isinstance(stored_value, str) and len(json.loads(stored_value)) == 43 + + # Pod B: callback with same state, retrieve from "Redis" + mock_request = MagicMock(spec=Request) + mock_request.query_params = {"state": "multi_pod_state_xyz"} + token_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + assert "code_verifier" in token_params + assert token_params["code_verifier"] == json.loads(stored_value) + mock_in_memory.async_get_cache.assert_not_called() + # delete_cache called; key removed (asserted below) + + # Verifier consumed (single-use); key removed from "Redis" + assert "pkce_verifier:multi_pod_state_xyz" not in mock_redis._store + + @pytest.mark.asyncio + async def test_pkce_fallback_in_memory_roundtrip_when_redis_none(self): + """ + Regression: When redis_usage_cache is None (no Redis configured), + code_verifier is stored and retrieved via user_api_key_cache. + Roundtrip works when callback hits same pod (same in-memory cache). + Single-pod or no-Redis deployments must continue to work. + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + # In-memory store (simulates user_api_key_cache on one pod) + in_memory_store = {} + + async def async_set_cache(key, value, **kwargs): + in_memory_store[key] = value + + async def async_get_cache(key, **kwargs): + return in_memory_store.get(key) + + async def async_delete_cache(key): + in_memory_store.pop(key, None) + + mock_in_memory = MagicMock() + mock_in_memory.async_set_cache = AsyncMock(side_effect=async_set_cache) + mock_in_memory.async_get_cache = AsyncMock(side_effect=async_get_cache) + mock_in_memory.async_delete_cache = AsyncMock(side_effect=async_delete_cache) + + mock_sso = MagicMock() + mock_redirect_response = MagicMock() + mock_redirect_response.headers = { + "location": "https://auth.example.com/authorize?state=fallback_state_xyz&client_id=abc" + } + mock_sso.get_login_redirect = AsyncMock(return_value=mock_redirect_response) + mock_sso.__enter__ = MagicMock(return_value=mock_sso) + mock_sso.__exit__ = MagicMock(return_value=False) + + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): + with patch("litellm.proxy.proxy_server.redis_usage_cache", None): + with patch( + "litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory + ): + # Pod A: start login, store code_verifier in in-memory cache + await SSOAuthenticationHandler.get_generic_sso_redirect_response( + generic_sso=mock_sso, + state="fallback_state_xyz", + generic_authorization_endpoint="https://auth.example.com/authorize", + ) + mock_in_memory.async_set_cache.assert_called_once() + stored_key = mock_in_memory.async_set_cache.call_args.kwargs["key"] + stored_value = mock_in_memory.async_set_cache.call_args.kwargs[ + "value" + ] + assert stored_key == "pkce_verifier:fallback_state_xyz" + assert isinstance(stored_value, str) and len(stored_value) == 43 + + # Same pod: callback retrieves from in-memory cache + mock_request = MagicMock(spec=Request) + mock_request.query_params = {"state": "fallback_state_xyz"} + token_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + assert "code_verifier" in token_params + assert token_params["code_verifier"] == stored_value + mock_in_memory.async_get_cache.assert_called_once_with( + key=stored_key + ) + mock_in_memory.async_delete_cache.assert_called_once_with( + key=stored_key + ) + + # Verifier consumed; key removed from in-memory + assert "pkce_verifier:fallback_state_xyz" not in in_memory_store + + @pytest.mark.asyncio + async def test_pkce_prepare_token_exchange_returns_nothing_when_no_state(self): + """ + Regression: prepare_token_exchange_parameters with no state in request + does not call cache and does not add code_verifier. + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + mock_redis = MagicMock() + mock_in_memory = MagicMock() + + with patch("litellm.proxy.proxy_server.redis_usage_cache", mock_redis): + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory): + mock_request = MagicMock(spec=Request) + mock_request.query_params = {} + token_params = ( + await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + ) + assert "code_verifier" not in token_params + mock_redis.async_get_cache.assert_not_called() + mock_in_memory.async_get_cache.assert_not_called() + # Tests for SSO user team assignment bug (Issue: SSO Users Not Added to Entra-Synced Teams on First Login) class TestAddMissingTeamMember: @@ -3268,9 +3441,7 @@ async def test_sso_first_login_full_flow_adds_user_to_teams(self): team_member_calls = [] async def track_team_member_add(team_id, user_info): - team_member_calls.append( - {"team_id": team_id, "user_id": user_info.user_id} - ) + team_member_calls.append({"team_id": team_id, "user_id": user_info.user_id}) # New SSO user with Entra groups new_user = NewUserResponse( @@ -3331,7 +3502,6 @@ async def test_add_missing_team_member_handles_all_user_types( """ Parametrized test ensuring add_missing_team_member works for all user types. """ - from litellm.proxy._types import LiteLLM_UserTable from litellm.proxy.management_endpoints.ui_sso import add_missing_team_member user_info = user_info_factory("test-user-id") @@ -3421,7 +3591,7 @@ async def test_role_mappings_override_default_internal_user_params(): return_value=mock_new_user_response, ) as mock_new_user: # Act - result = await insert_sso_user( + _ = await insert_sso_user( result_openid=mock_result_openid, user_defined_values=user_defined_values, ) @@ -3443,7 +3613,7 @@ async def test_role_mappings_override_default_internal_user_params(): assert ( new_user_request.budget_duration == "30d" ), "budget_duration from default_internal_user_params should be applied" - + # Note: models are applied via _update_internal_new_user_params inside new_user, # not in insert_sso_user, so we verify user_defined_values was updated correctly # by checking that the function completed successfully and other defaults were applied @@ -3558,7 +3728,10 @@ async def test_sso_readiness_google_missing_secret(self): assert data["sso_configured"] is True assert data["provider"] == "google" assert "GOOGLE_CLIENT_SECRET" in data["missing_environment_variables"] - assert "Google SSO is configured but missing required environment variables" in data["message"] + assert ( + "Google SSO is configured but missing required environment variables" + in data["message"] + ) finally: app.dependency_overrides.clear() @@ -3607,7 +3780,7 @@ async def test_sso_readiness_microsoft_configurations( response = client.get("/sso/readiness") assert response.status_code == expected_status - + if expected_status == 200: data = response.json() assert data["sso_configured"] is True @@ -3677,7 +3850,7 @@ async def test_sso_readiness_generic_configurations( response = client.get("/sso/readiness") assert response.status_code == expected_status - + if expected_status == 200: data = response.json() assert data["sso_configured"] is True @@ -3722,8 +3895,14 @@ async def test_custom_microsoft_sso_uses_default_endpoints_when_no_env_vars(self discovery = await sso.get_discovery_document() - assert discovery["authorization_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize" - assert discovery["token_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + assert ( + discovery["authorization_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize" + ) + assert ( + discovery["token_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + ) assert discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" @pytest.mark.asyncio @@ -3787,8 +3966,13 @@ async def test_custom_microsoft_sso_uses_partial_custom_endpoints(self): # Custom auth endpoint assert discovery["authorization_endpoint"] == custom_auth_endpoint # Default token and userinfo endpoints - assert discovery["token_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" - assert discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" + assert ( + discovery["token_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + ) + assert ( + discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" + ) def test_custom_microsoft_sso_uses_common_tenant_when_none(self): """ @@ -3825,11 +4009,7 @@ async def test_setup_team_mappings(): # Arrange mock_prisma = MagicMock() mock_sso_config = MagicMock() - mock_sso_config.sso_settings = { - "team_mappings": { - "team_ids_jwt_field": "groups" - } - } + mock_sso_config.sso_settings = {"team_mappings": {"team_ids_jwt_field": "groups"}} mock_prisma.db.litellm_ssoconfig.find_unique = AsyncMock( return_value=mock_sso_config )