diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 3075792f9f2..5adf54c1627 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -712,15 +712,13 @@ async def get_user_info_from_db( ) # Upsert SSO User to LiteLLM DB - - if user_info is None: - user_info = await SSOAuthenticationHandler.upsert_sso_user( - result=result, - user_info=user_info, - user_email=user_email, - user_defined_values=user_defined_values, - prisma_client=prisma_client, - ) + user_info = await SSOAuthenticationHandler.upsert_sso_user( + result=result, + user_info=user_info, + user_email=user_email, + user_defined_values=user_defined_values, + prisma_client=prisma_client, + ) await SSOAuthenticationHandler.add_user_to_teams_from_sso_response( result=result, diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index f0cf5459ad2..f983af2d0b2 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -844,9 +844,9 @@ def test_get_user_email_and_id_extracts_microsoft_role(): @pytest.mark.asyncio -async def test_get_user_info_from_db(): +async def test_get_user_info_from_db_user_exists(): """ - received args in get_user_info_from_db: {'result': CustomOpenID(id='krrishd', email='krrishdholakia@gmail.com', first_name=None, last_name=None, display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce', picture=None, provider=None, team_ids=[]), 'prisma_client': , 'user_api_key_cache': , 'proxy_logging_obj': , 'user_email': 'krrishdholakia@gmail.com', 'user_defined_values': {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': None, 'budget_duration': None}} + Test that get_user_info_from_db finds existing user and calls upsert_sso_user to update. """ from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db @@ -888,7 +888,7 @@ async def test_get_user_info_from_db(): @pytest.mark.asyncio -async def test_get_user_info_from_db_alternate_user_id(): +async def test_get_user_info_from_db_user_exists_alternate_user_id(): from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db prisma_client = MagicMock() @@ -929,6 +929,190 @@ async def test_get_user_info_from_db_alternate_user_id(): assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd-email1234" +@pytest.mark.asyncio +async def test_get_user_info_from_db_user_not_exists_creates_user(): + """ + Test that get_user_info_from_db creates a new user when user doesn't exist in DB. + + When get_existing_user_info_from_db returns None, get_user_info_from_db should: + 1. Call upsert_sso_user with user_info=None + 2. upsert_sso_user should call insert_sso_user to create the user + 3. Add user to teams from SSO response + """ + from litellm.proxy._types import NewUserResponse, SSOUserDefinedValues + from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db + + prisma_client = MagicMock() + user_api_key_cache = MagicMock() + proxy_logging_obj = MagicMock() + user_email = "newuser@example.com" + user_defined_values: SSOUserDefinedValues = { + "models": [], + "user_id": "new-user-123", + "user_email": "newuser@example.com", + "max_budget": None, + "user_role": None, + "budget_duration": None, + } + + sso_result = CustomOpenID( + id="new-user-123", + email="newuser@example.com", + first_name="New", + last_name="User", + display_name="New User", + picture=None, + provider="microsoft", + team_ids=["team-1", "team-2"], + ) + + args = { + "result": sso_result, + "prisma_client": prisma_client, + "user_api_key_cache": user_api_key_cache, + "proxy_logging_obj": proxy_logging_obj, + "user_email": user_email, + "user_defined_values": user_defined_values, + } + + # Mock new user response + mock_new_user = NewUserResponse( + user_id="new-user-123", + key="sk-xxxxx", + teams=None, + ) + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_existing_user_info_from_db", + return_value=None, # User doesn't exist + ) as mock_get_existing, patch( + "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.upsert_sso_user", + return_value=mock_new_user, + ) as mock_upsert, patch( + "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.add_user_to_teams_from_sso_response", + ) as mock_add_teams: + # Act + user_info = await get_user_info_from_db(**args) + + # Assert + # Should try to find user by id + mock_get_existing.assert_called_once() + assert mock_get_existing.call_args.kwargs["user_id"] == "new-user-123" + assert mock_get_existing.call_args.kwargs["user_email"] == "newuser@example.com" + + # Should call upsert_sso_user with None user_info + mock_upsert.assert_called_once() + upsert_call_args = mock_upsert.call_args + assert upsert_call_args.kwargs["user_info"] is None + assert upsert_call_args.kwargs["user_email"] == "newuser@example.com" + assert upsert_call_args.kwargs["user_defined_values"] == user_defined_values + + # Should add user to teams + mock_add_teams.assert_called_once() + add_teams_call_args = mock_add_teams.call_args + assert add_teams_call_args.kwargs["result"] == sso_result + assert add_teams_call_args.kwargs["user_info"] == mock_new_user + + # Should return the new user + assert user_info == mock_new_user + + +@pytest.mark.asyncio +async def test_get_user_info_from_db_user_exists_updates_user(): + """ + Test that get_user_info_from_db updates existing user when user exists in DB. + + When get_existing_user_info_from_db returns a user, get_user_info_from_db should: + 1. Call upsert_sso_user with the existing user_info + 2. upsert_sso_user should update the user in the database + 3. Add user to teams from SSO response + """ + from litellm.proxy._types import LiteLLM_UserTable, SSOUserDefinedValues + from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db + + prisma_client = MagicMock() + user_api_key_cache = MagicMock() + proxy_logging_obj = MagicMock() + user_email = "existing@example.com" + user_defined_values: SSOUserDefinedValues = { + "models": [], + "user_id": "existing-user-456", + "user_email": "existing@example.com", + "max_budget": None, + "user_role": None, + "budget_duration": None, + } + + sso_result = CustomOpenID( + id="existing-user-456", + email="existing@example.com", + first_name="Existing", + last_name="User", + display_name="Existing User", + picture=None, + provider="microsoft", + team_ids=["team-3"], + ) + + # Existing user in DB + existing_user = LiteLLM_UserTable( + user_id="existing-user-456", + user_email="old@example.com", + user_role="internal_user", + models=["gpt-4"], + teams=[], + ) + + # Updated user after upsert + updated_user = LiteLLM_UserTable( + user_id="existing-user-456", + user_email="existing@example.com", # Updated email + user_role="internal_user", + models=["gpt-4"], + teams=[], + ) + + args = { + "result": sso_result, + "prisma_client": prisma_client, + "user_api_key_cache": user_api_key_cache, + "proxy_logging_obj": proxy_logging_obj, + "user_email": user_email, + "user_defined_values": user_defined_values, + } + + with patch( + "litellm.proxy.management_endpoints.ui_sso.get_existing_user_info_from_db", + return_value=existing_user, # User exists + ) as mock_get_existing, patch( + "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.upsert_sso_user", + return_value=updated_user, + ) as mock_upsert, patch( + "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.add_user_to_teams_from_sso_response", + ) as mock_add_teams: + # Act + user_info = await get_user_info_from_db(**args) + + # Assert + # Should find existing user + mock_get_existing.assert_called_once() + assert mock_get_existing.call_args.kwargs["user_id"] == "existing-user-456" + + # Should call upsert_sso_user with existing user_info + mock_upsert.assert_called_once() + upsert_call_args = mock_upsert.call_args + assert upsert_call_args.kwargs["user_info"] == existing_user + assert upsert_call_args.kwargs["user_email"] == "existing@example.com" + + # Should add user to teams + mock_add_teams.assert_called_once() + add_teams_call_args = mock_add_teams.call_args + assert add_teams_call_args.kwargs["result"] == sso_result + assert add_teams_call_args.kwargs["user_info"] == updated_user + + # Should return the updated user + assert user_info == updated_user + @pytest.mark.asyncio async def test_check_and_update_if_proxy_admin_id(): """