Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions litellm/proxy/management_endpoints/internal_user_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,38 +721,31 @@ async def _get_user_info_for_proxy_admin(user_api_key_dict: UserAPIKeyAuth):

- get all teams in LiteLLM_TeamTable
- get all keys in LiteLLM_VerificationToken table

Why separate helper for proxy admin ?
- To get Faster UI load times, get all teams and virtual keys in 1 query
"""

from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy.proxy_server import prisma_client

sql_query = """
SELECT
(SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
(SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
"""
if prisma_client is None:
raise Exception(
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)

results = await prisma_client.db.query_raw(sql_query)

verbose_proxy_logger.debug("results_keys: %s", results)
# Fetch teams and keys concurrently using Prisma instead of raw SQL.
_teams_in_db, keys_in_db = await asyncio.gather(
prisma_client.db.litellm_teamtable.find_many(),
prisma_client.db.litellm_verificationtoken.find_many(
where={
"OR": [
{"team_id": {"not": UI_SESSION_TOKEN_TEAM_ID}},
{"team_id": None},
]
}
),
)

_keys_in_db: List = results[0]["keys"] or []
# cast all keys to LiteLLM_VerificationToken
keys_in_db = []
for key in _keys_in_db:
if key.get("models") is None:
key["models"] = []
keys_in_db.append(LiteLLM_VerificationToken(**key))
verbose_proxy_logger.debug("results_keys: %s", keys_in_db)

# cast all teams to LiteLLM_TeamTable
_teams_in_db: List = results[0]["teams"] or []
_teams_in_db = [LiteLLM_TeamTable(**team) for team in _teams_in_db]
_teams_in_db.sort(key=lambda x: (getattr(x, "team_alias", "") or ""))
returned_keys = _process_keys_for_user_info(keys=keys_in_db, all_teams=_teams_in_db)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1728,4 +1728,85 @@ async def mock_find_unique(*args, **kwargs):
# Verify each condition uses {"in": ["admin-creator"]}
for condition in or_conditions:
field = list(condition.keys())[0]
assert condition[field] == {"in": ["admin-creator"]}
assert condition[field] == {"in": ["admin-creator"]}


@pytest.mark.asyncio
async def test_get_user_info_for_proxy_admin_uses_prisma(mocker):
"""
Test that _get_user_info_for_proxy_admin uses Prisma find_many calls
instead of raw SQL, and runs them concurrently.
"""
from unittest.mock import MagicMock

from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_VerificationToken,
LitellmUserRoles,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
_get_user_info_for_proxy_admin,
)

mock_team = LiteLLM_TeamTable(team_id="team-1", team_alias="Team One")
mock_key = MagicMock(spec=LiteLLM_VerificationToken)
mock_key.token = "hashed_token_123"
mock_key.team_id = "team-1"
mock_key.model_dump = MagicMock(return_value={
"token": "hashed_token_123",
"team_id": "team-1",
"models": [],
})

mock_prisma_client = MagicMock()
mock_prisma_client.db.litellm_teamtable.find_many = mocker.AsyncMock(
return_value=[mock_team]
)
mock_prisma_client.db.litellm_verificationtoken.find_many = mocker.AsyncMock(
return_value=[mock_key]
)
mock_prisma_client.get_data = mocker.AsyncMock(return_value=None)

mocker.patch(
"litellm.proxy.proxy_server.prisma_client",
mock_prisma_client,
)
mocker.patch(
"litellm.proxy.proxy_server.general_settings",
{},
)
mocker.patch(
"litellm.proxy.proxy_server.litellm_master_key_hash",
"master_key_hash",
)

user_api_key_dict = UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id="admin_user",
)

result = await _get_user_info_for_proxy_admin(user_api_key_dict)

# Verify Prisma calls (not raw SQL)
mock_prisma_client.db.litellm_teamtable.find_many.assert_called_once_with()
mock_prisma_client.db.litellm_verificationtoken.find_many.assert_called_once_with(
where={
"OR": [
{"team_id": {"not": UI_SESSION_TOKEN_TEAM_ID}},
{"team_id": None},
]
}
)

# Verify no raw SQL was used
mock_prisma_client.db.query_raw.assert_not_called()

# Verify response shape
assert result.user_id == "admin_user"
assert len(result.teams) == 1
assert result.teams[0].team_id == "team-1"
assert len(result.keys) == 1
assert result.keys[0]["team_id"] == "team-1"
assert result.keys[0]["team_alias"] == "Team One"
Loading