Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions litellm/proxy/management_endpoints/ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,15 +712,13 @@ async def get_user_info_from_db(
)

# Upsert SSO User to LiteLLM DB

if user_info is None:
user_info = await SSOAuthenticationHandler.upsert_sso_user(
result=result,
user_info=user_info,
user_email=user_email,
user_defined_values=user_defined_values,
prisma_client=prisma_client,
)
user_info = await SSOAuthenticationHandler.upsert_sso_user(
result=result,
user_info=user_info,
user_email=user_email,
user_defined_values=user_defined_values,
prisma_client=prisma_client,
)

await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
result=result,
Expand Down
190 changes: 187 additions & 3 deletions tests/test_litellm/proxy/management_endpoints/test_ui_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,9 +844,9 @@ def test_get_user_email_and_id_extracts_microsoft_role():


@pytest.mark.asyncio
async def test_get_user_info_from_db():
async def test_get_user_info_from_db_user_exists():
"""
received args in get_user_info_from_db: {'result': CustomOpenID(id='krrishd', email='krrishdholakia@gmail.com', first_name=None, last_name=None, display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce', picture=None, provider=None, team_ids=[]), 'prisma_client': <litellm.proxy.utils.PrismaClient object at 0x14a74e3c0>, 'user_api_key_cache': <litellm.caching.dual_cache.DualCache object at 0x148d37110>, 'proxy_logging_obj': <litellm.proxy.utils.ProxyLogging object at 0x148dd9090>, 'user_email': 'krrishdholakia@gmail.com', 'user_defined_values': {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': None, 'budget_duration': None}}
Test that get_user_info_from_db finds existing user and calls upsert_sso_user to update.
"""
from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

Expand Down Expand Up @@ -888,7 +888,7 @@ async def test_get_user_info_from_db():


@pytest.mark.asyncio
async def test_get_user_info_from_db_alternate_user_id():
async def test_get_user_info_from_db_user_exists_alternate_user_id():
from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

prisma_client = MagicMock()
Expand Down Expand Up @@ -929,6 +929,190 @@ async def test_get_user_info_from_db_alternate_user_id():
assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd-email1234"


@pytest.mark.asyncio
async def test_get_user_info_from_db_user_not_exists_creates_user():
"""
Test that get_user_info_from_db creates a new user when user doesn't exist in DB.

When get_existing_user_info_from_db returns None, get_user_info_from_db should:
1. Call upsert_sso_user with user_info=None
2. upsert_sso_user should call insert_sso_user to create the user
3. Add user to teams from SSO response
"""
from litellm.proxy._types import NewUserResponse, SSOUserDefinedValues
from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

prisma_client = MagicMock()
user_api_key_cache = MagicMock()
proxy_logging_obj = MagicMock()
user_email = "newuser@example.com"
user_defined_values: SSOUserDefinedValues = {
"models": [],
"user_id": "new-user-123",
"user_email": "newuser@example.com",
"max_budget": None,
"user_role": None,
"budget_duration": None,
}

sso_result = CustomOpenID(
id="new-user-123",
email="newuser@example.com",
first_name="New",
last_name="User",
display_name="New User",
picture=None,
provider="microsoft",
team_ids=["team-1", "team-2"],
)

args = {
"result": sso_result,
"prisma_client": prisma_client,
"user_api_key_cache": user_api_key_cache,
"proxy_logging_obj": proxy_logging_obj,
"user_email": user_email,
"user_defined_values": user_defined_values,
}

# Mock new user response
mock_new_user = NewUserResponse(
user_id="new-user-123",
key="sk-xxxxx",
teams=None,
)

with patch(
"litellm.proxy.management_endpoints.ui_sso.get_existing_user_info_from_db",
return_value=None, # User doesn't exist
) as mock_get_existing, patch(
"litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.upsert_sso_user",
return_value=mock_new_user,
) as mock_upsert, patch(
"litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.add_user_to_teams_from_sso_response",
) as mock_add_teams:
# Act
user_info = await get_user_info_from_db(**args)

# Assert
# Should try to find user by id
mock_get_existing.assert_called_once()
assert mock_get_existing.call_args.kwargs["user_id"] == "new-user-123"
assert mock_get_existing.call_args.kwargs["user_email"] == "newuser@example.com"

# Should call upsert_sso_user with None user_info
mock_upsert.assert_called_once()
upsert_call_args = mock_upsert.call_args
assert upsert_call_args.kwargs["user_info"] is None
assert upsert_call_args.kwargs["user_email"] == "newuser@example.com"
assert upsert_call_args.kwargs["user_defined_values"] == user_defined_values

# Should add user to teams
mock_add_teams.assert_called_once()
add_teams_call_args = mock_add_teams.call_args
assert add_teams_call_args.kwargs["result"] == sso_result
assert add_teams_call_args.kwargs["user_info"] == mock_new_user

# Should return the new user
assert user_info == mock_new_user


@pytest.mark.asyncio
async def test_get_user_info_from_db_user_exists_updates_user():
"""
Test that get_user_info_from_db updates existing user when user exists in DB.

When get_existing_user_info_from_db returns a user, get_user_info_from_db should:
1. Call upsert_sso_user with the existing user_info
2. upsert_sso_user should update the user in the database
3. Add user to teams from SSO response
"""
from litellm.proxy._types import LiteLLM_UserTable, SSOUserDefinedValues
from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

prisma_client = MagicMock()
user_api_key_cache = MagicMock()
proxy_logging_obj = MagicMock()
user_email = "existing@example.com"
user_defined_values: SSOUserDefinedValues = {
"models": [],
"user_id": "existing-user-456",
"user_email": "existing@example.com",
"max_budget": None,
"user_role": None,
"budget_duration": None,
}

sso_result = CustomOpenID(
id="existing-user-456",
email="existing@example.com",
first_name="Existing",
last_name="User",
display_name="Existing User",
picture=None,
provider="microsoft",
team_ids=["team-3"],
)

# Existing user in DB
existing_user = LiteLLM_UserTable(
user_id="existing-user-456",
user_email="old@example.com",
user_role="internal_user",
models=["gpt-4"],
teams=[],
)

# Updated user after upsert
updated_user = LiteLLM_UserTable(
user_id="existing-user-456",
user_email="existing@example.com", # Updated email
user_role="internal_user",
models=["gpt-4"],
teams=[],
)

args = {
"result": sso_result,
"prisma_client": prisma_client,
"user_api_key_cache": user_api_key_cache,
"proxy_logging_obj": proxy_logging_obj,
"user_email": user_email,
"user_defined_values": user_defined_values,
}

with patch(
"litellm.proxy.management_endpoints.ui_sso.get_existing_user_info_from_db",
return_value=existing_user, # User exists
) as mock_get_existing, patch(
"litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.upsert_sso_user",
return_value=updated_user,
) as mock_upsert, patch(
"litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.add_user_to_teams_from_sso_response",
) as mock_add_teams:
# Act
user_info = await get_user_info_from_db(**args)

# Assert
# Should find existing user
mock_get_existing.assert_called_once()
assert mock_get_existing.call_args.kwargs["user_id"] == "existing-user-456"

# Should call upsert_sso_user with existing user_info
mock_upsert.assert_called_once()
upsert_call_args = mock_upsert.call_args
assert upsert_call_args.kwargs["user_info"] == existing_user
assert upsert_call_args.kwargs["user_email"] == "existing@example.com"

# Should add user to teams
mock_add_teams.assert_called_once()
add_teams_call_args = mock_add_teams.call_args
assert add_teams_call_args.kwargs["result"] == sso_result
assert add_teams_call_args.kwargs["user_info"] == updated_user

# Should return the updated user
assert user_info == updated_user

@pytest.mark.asyncio
async def test_check_and_update_if_proxy_admin_id():
"""
Expand Down
Loading