diff --git a/litellm/proxy/management_endpoints/common_utils.py b/litellm/proxy/management_endpoints/common_utils.py index c16b4c4b93d..8f7dd4f8dfa 100644 --- a/litellm/proxy/management_endpoints/common_utils.py +++ b/litellm/proxy/management_endpoints/common_utils.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TYPE_CHECKING +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache from litellm.proxy._types import ( KeyRequestBase, LiteLLM_ManagementEndpoint_MetadataFields, @@ -11,6 +13,9 @@ ) from litellm.proxy.utils import _premium_user_check +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient, ProxyLogging + def _user_has_admin_view(user_api_key_dict: UserAPIKeyAuth) -> bool: return ( @@ -31,6 +36,78 @@ def _is_user_team_admin( return False +async def _user_has_admin_privileges( + user_api_key_dict: UserAPIKeyAuth, + prisma_client: Optional["PrismaClient"] = None, + user_api_key_cache: Optional["DualCache"] = None, + proxy_logging_obj: Optional["ProxyLogging"] = None, +) -> bool: + """ + Check if user has admin privileges (proxy admin, team admin, or org admin). + + Args: + user_api_key_dict: User API key authentication object + prisma_client: Prisma client for database operations + user_api_key_cache: Cache for user API keys + proxy_logging_obj: Proxy logging object + + Returns: + True if user is proxy admin, team admin for any team, or org admin for any organization + """ + # Check if user is proxy admin + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN: + return True + + # If no database connection, can't check team/org admin status + if prisma_client is None or user_api_key_dict.user_id is None: + return False + + # Get user object to check team and org admin status + from litellm.caching import DualCache as DualCacheImport + from litellm.proxy.auth.auth_checks import get_user_object + + try: + user_obj = await get_user_object( + user_id=user_api_key_dict.user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache or DualCacheImport(), + user_id_upsert=False, + proxy_logging_obj=proxy_logging_obj, + ) + + if user_obj is None: + return False + + # Check if user is org admin for any organization + if user_obj.organization_memberships is not None: + for membership in user_obj.organization_memberships: + if membership.user_role == LitellmUserRoles.ORG_ADMIN.value: + return True + + # Check if user is team admin for any team + if user_obj.teams is not None and len(user_obj.teams) > 0: + # Get all teams user is in + teams = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": user_obj.teams}} + ) + + for team in teams: + team_obj = LiteLLM_TeamTable(**team.model_dump()) + if _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=team_obj + ): + return True + + except Exception as e: + # If there's an error checking, default to False for security + verbose_proxy_logger.debug( + f"Error checking admin privileges for user {user_api_key_dict.user_id}: {e}" + ) + return False + + return False + + def _set_object_metadata_field( object_data: Union[ LiteLLM_TeamTable, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8ae07121177..f49493286cb 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -5070,6 +5070,7 @@ async def model_list( only_model_access_groups: Optional[bool] = False, include_metadata: Optional[bool] = False, fallback_type: Optional[str] = None, + scope: Optional[str] = None, ): """ Use `/model/info` - to get detailed model information, example - pricing, mode, etc. @@ -5080,14 +5081,85 @@ async def model_list( - include_metadata: Include additional metadata in the response with fallback information - fallback_type: Type of fallbacks to include ("general", "context_window", "content_policy") Defaults to "general" when include_metadata=true + - scope: Optional scope parameter. Currently only accepts "expand". + When scope=expand is passed, proxy admins, team admins, and org admins + will receive all proxy models as if they are a proxy admin. """ global llm_model_list, general_settings, llm_router, prisma_client, user_api_key_cache, proxy_logging_obj + from litellm.proxy.management_endpoints.common_utils import ( + _user_has_admin_privileges, + ) from litellm.proxy.utils import ( create_model_info_response, get_available_models_for_user, ) + # Validate scope parameter if provided + if scope is not None and scope != "expand": + raise HTTPException( + status_code=400, + detail=f"Invalid scope parameter. Only 'expand' is currently supported. Received: {scope}", + ) + + # Check if scope=expand is requested and user has admin privileges + should_expand_scope = False + if scope == "expand": + should_expand_scope = await _user_has_admin_privileges( + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # If scope=expand and user has admin privileges, return all proxy models + if should_expand_scope: + # Get all proxy models as if user is a proxy admin + if llm_router is None: + proxy_model_list = [] + model_access_groups = {} + else: + proxy_model_list = llm_router.get_model_names() + model_access_groups = llm_router.get_model_access_groups() + + # Include model access groups if requested + if include_model_access_groups: + proxy_model_list = list(set(proxy_model_list + list(model_access_groups.keys()))) + + # Get complete model list including wildcard routes if requested + from litellm.proxy.auth.model_checks import get_complete_model_list + + all_models = get_complete_model_list( + key_models=[], + team_models=[], + proxy_model_list=proxy_model_list, + user_model=None, + infer_model_from_keys=False, + return_wildcard_routes=return_wildcard_routes or False, + llm_router=llm_router, + model_access_groups=model_access_groups, + include_model_access_groups=include_model_access_groups or False, + only_model_access_groups=only_model_access_groups or False, + ) + + # Build response data with all proxy models + model_data = [] + for model in all_models: + model_info = create_model_info_response( + model_id=model, + provider="openai", + include_metadata=include_metadata or False, + fallback_type=fallback_type, + llm_router=llm_router, + ) + model_data.append(model_info) + + return dict( + data=model_data, + object="list", + ) + + # Otherwise, use the normal behavior (current implementation) # Get available models for the user all_models = await get_available_models_for_user( user_api_key_dict=user_api_key_dict, diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index f1854380efe..cb519e9f509 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -3625,3 +3625,442 @@ def test_enrich_model_info_with_litellm_data(): assert call_args["model_info"]["id"] == "existing-id" assert call_args["model_info"]["custom_key"] == "custom_value" assert call_args["model_info"]["input_cost_per_token"] == 0.001 + + +@pytest.mark.asyncio +async def test_model_list_scope_parameter_validation(monkeypatch): + """Test that invalid scope parameter raises HTTPException""" + from fastapi import HTTPException + from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles + from litellm.proxy.proxy_server import model_list + + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="test-user", + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="test-key", + ) + + # Test invalid scope parameter + with pytest.raises(HTTPException) as exc_info: + await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope="invalid_scope", + ) + + assert exc_info.value.status_code == 400 + assert "Invalid scope parameter" in exc_info.value.detail + assert "Only 'expand' is currently supported" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_model_list_scope_expand_proxy_admin(monkeypatch): + """Test that proxy admin with scope=expand returns all proxy models""" + from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, LiteLLM_UserTable + from litellm.proxy.proxy_server import model_list + + # Mock user API key dict for proxy admin + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="proxy-admin-user", + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="test-key", + ) + + # Mock llm_router with proxy models + mock_router = MagicMock() + mock_router.get_model_names.return_value = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + mock_router.get_model_access_groups.return_value = {} + + # Mock prisma_client + mock_prisma_client = MagicMock() + + # Mock user_api_key_cache + mock_user_api_key_cache = MagicMock() + + # Mock proxy_logging_obj + mock_proxy_logging_obj = MagicMock() + + # Mock get_complete_model_list + mock_all_models = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + + # Mock create_model_info_response + def mock_create_model_info_response(model_id, provider, include_metadata=False, fallback_type=None, llm_router=None): + return {"id": model_id, "object": "model"} + + # Apply monkeypatches + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None) + monkeypatch.setattr( + "litellm.proxy.auth.model_checks.get_complete_model_list", + lambda **kwargs: mock_all_models, + ) + monkeypatch.setattr( + "litellm.proxy.utils.create_model_info_response", + mock_create_model_info_response, + ) + + # Call model_list with scope=expand + result = await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope="expand", + ) + + # Verify result contains all proxy models + assert result["object"] == "list" + assert len(result["data"]) == 3 + assert all(model["id"] in mock_all_models for model in result["data"]) + + # Verify router methods were called + mock_router.get_model_names.assert_called_once() + mock_router.get_model_access_groups.assert_called_once() + + +@pytest.mark.asyncio +async def test_model_list_scope_expand_org_admin(monkeypatch): + """Test that org admin with scope=expand returns all proxy models""" + from litellm.proxy._types import ( + UserAPIKeyAuth, + LitellmUserRoles, + LiteLLM_UserTable, + ) + from litellm.proxy.proxy_server import model_list + + # Mock user API key dict for org admin + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="org-admin-user", + user_role=LitellmUserRoles.INTERNAL_USER, # Not proxy admin, but org admin + api_key="test-key", + ) + + # Mock user object with org admin membership + from litellm.proxy._types import LiteLLM_OrganizationMembershipTable + from datetime import datetime + + mock_user_obj = LiteLLM_UserTable( + user_id="org-admin-user", + user_email="org-admin@example.com", + organization_memberships=[ + LiteLLM_OrganizationMembershipTable( + user_id="org-admin-user", + organization_id="org-123", + user_role=LitellmUserRoles.ORG_ADMIN.value, + spend=0.0, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + ], + teams=[], + ) + + # Mock llm_router with proxy models + mock_router = MagicMock() + mock_router.get_model_names.return_value = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + mock_router.get_model_access_groups.return_value = {} + + # Mock prisma_client + mock_prisma_client = MagicMock() + + # Mock user_api_key_cache + mock_user_api_key_cache = MagicMock() + + # Mock proxy_logging_obj + mock_proxy_logging_obj = MagicMock() + + # Mock get_user_object to return user with org admin role + async def mock_get_user_object(*args, **kwargs): + return mock_user_obj + + # Mock get_complete_model_list + mock_all_models = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + + # Mock create_model_info_response + def mock_create_model_info_response(model_id, provider, include_metadata=False, fallback_type=None, llm_router=None): + return {"id": model_id, "object": "model"} + + # Apply monkeypatches + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None) + monkeypatch.setattr( + "litellm.proxy.auth.auth_checks.get_user_object", + mock_get_user_object, + ) + monkeypatch.setattr( + "litellm.proxy.auth.model_checks.get_complete_model_list", + lambda **kwargs: mock_all_models, + ) + monkeypatch.setattr( + "litellm.proxy.utils.create_model_info_response", + mock_create_model_info_response, + ) + + # Call model_list with scope=expand + result = await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope="expand", + ) + + # Verify result contains all proxy models + assert result["object"] == "list" + assert len(result["data"]) == 3 + assert all(model["id"] in mock_all_models for model in result["data"]) + + # Verify router methods were called + mock_router.get_model_names.assert_called_once() + mock_router.get_model_access_groups.assert_called_once() + + +@pytest.mark.asyncio +async def test_model_list_scope_expand_team_admin(monkeypatch): + """Test that team admin with scope=expand returns all proxy models""" + from litellm.proxy._types import ( + UserAPIKeyAuth, + LitellmUserRoles, + LiteLLM_UserTable, + LiteLLM_TeamTable, + ) + from litellm.proxy.proxy_server import model_list + + # Mock user API key dict for team admin + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="team-admin-user", + user_role=LitellmUserRoles.INTERNAL_USER, # Not proxy admin, but team admin + api_key="test-key", + ) + + # Mock team with user as admin - use dict structure that matches Prisma return + mock_team = MagicMock() + mock_team.model_dump.return_value = { + "team_id": "team-123", + "members_with_roles": [ + {"user_id": "team-admin-user", "role": "admin"} + ], + } + # Create team object from the dict (validator will convert members_with_roles to Member objects) + mock_team_obj = LiteLLM_TeamTable(**mock_team.model_dump()) + + # Mock user object with team membership + mock_user_obj = LiteLLM_UserTable( + user_id="team-admin-user", + user_email="team-admin@example.com", + organization_memberships=[], + teams=["team-123"], + ) + + # Mock llm_router with proxy models + mock_router = MagicMock() + mock_router.get_model_names.return_value = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + mock_router.get_model_access_groups.return_value = {} + + # Mock prisma_client + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_teamtable.find_many = AsyncMock( + return_value=[mock_team] + ) + + # Mock user_api_key_cache + mock_user_api_key_cache = MagicMock() + + # Mock proxy_logging_obj + mock_proxy_logging_obj = MagicMock() + + # Mock get_user_object to return user with team membership + async def mock_get_user_object(*args, **kwargs): + return mock_user_obj + + # Mock get_complete_model_list + mock_all_models = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + + # Mock create_model_info_response + def mock_create_model_info_response(model_id, provider, include_metadata=False, fallback_type=None, llm_router=None): + return {"id": model_id, "object": "model"} + + # Apply monkeypatches + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None) + monkeypatch.setattr( + "litellm.proxy.auth.auth_checks.get_user_object", + mock_get_user_object, + ) + monkeypatch.setattr( + "litellm.proxy.auth.model_checks.get_complete_model_list", + lambda **kwargs: mock_all_models, + ) + monkeypatch.setattr( + "litellm.proxy.utils.create_model_info_response", + mock_create_model_info_response, + ) + + # Call model_list with scope=expand + result = await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope="expand", + ) + + # Verify result contains all proxy models + assert result["object"] == "list" + assert len(result["data"]) == 3 + assert all(model["id"] in mock_all_models for model in result["data"]) + + # Verify router methods were called + mock_router.get_model_names.assert_called_once() + mock_router.get_model_access_groups.assert_called_once() + + +@pytest.mark.asyncio +async def test_model_list_scope_expand_normal_user(monkeypatch): + """Test that normal internal user with scope=expand returns only their models (not expanded)""" + from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, LiteLLM_UserTable + from litellm.proxy.proxy_server import model_list + + # Mock user API key dict for normal internal user + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="normal-user", + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="test-key", + models=["gpt-3.5-turbo"], # User only has access to this model + ) + + # Mock user object without admin privileges + mock_user_obj = LiteLLM_UserTable( + user_id="normal-user", + user_email="normal@example.com", + organization_memberships=[], # No org admin + teams=[], # No teams + ) + + # Mock llm_router + mock_router = MagicMock() + mock_router.get_model_names.return_value = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus"] + + # Mock prisma_client + mock_prisma_client = MagicMock() + + # Mock user_api_key_cache + mock_user_api_key_cache = MagicMock() + + # Mock proxy_logging_obj + mock_proxy_logging_obj = MagicMock() + + # Mock get_user_object to return user without admin privileges + async def mock_get_user_object(*args, **kwargs): + return mock_user_obj + + # Mock get_available_models_for_user to return only user's models + async def mock_get_available_models_for_user(*args, **kwargs): + return ["gpt-3.5-turbo"] # Only user's accessible models + + # Mock create_model_info_response + def mock_create_model_info_response(model_id, provider, include_metadata=False, fallback_type=None, llm_router=None): + return {"id": model_id, "object": "model"} + + # Apply monkeypatches + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None) + monkeypatch.setattr( + "litellm.proxy.auth.auth_checks.get_user_object", + mock_get_user_object, + ) + monkeypatch.setattr( + "litellm.proxy.utils.get_available_models_for_user", + mock_get_available_models_for_user, + ) + monkeypatch.setattr( + "litellm.proxy.utils.create_model_info_response", + mock_create_model_info_response, + ) + + # Call model_list with scope=expand + result = await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope="expand", + ) + + # Verify result contains only user's models (not all proxy models) + assert result["object"] == "list" + assert len(result["data"]) == 1 + assert result["data"][0]["id"] == "gpt-3.5-turbo" + + # Verify router methods were NOT called (normal path, not expanded) + mock_router.get_model_names.assert_not_called() + mock_router.get_model_access_groups.assert_not_called() + + +@pytest.mark.asyncio +async def test_model_list_no_scope_parameter(monkeypatch): + """Test that model_list without scope parameter uses normal behavior""" + from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles + from litellm.proxy.proxy_server import model_list + + # Mock user API key dict + mock_user_api_key_dict = UserAPIKeyAuth( + user_id="test-user", + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="test-key", + models=["gpt-3.5-turbo"], + ) + + # Mock llm_router + mock_router = MagicMock() + + # Mock prisma_client + mock_prisma_client = MagicMock() + + # Mock user_api_key_cache + mock_user_api_key_cache = MagicMock() + + # Mock proxy_logging_obj + mock_proxy_logging_obj = MagicMock() + + # Mock get_available_models_for_user + async def mock_get_available_models_for_user(*args, **kwargs): + return ["gpt-3.5-turbo"] + + # Mock create_model_info_response + def mock_create_model_info_response(model_id, provider, include_metadata=False, fallback_type=None, llm_router=None): + return {"id": model_id, "object": "model"} + + # Apply monkeypatches + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None) + monkeypatch.setattr( + "litellm.proxy.utils.get_available_models_for_user", + mock_get_available_models_for_user, + ) + monkeypatch.setattr( + "litellm.proxy.utils.create_model_info_response", + mock_create_model_info_response, + ) + + # Call model_list without scope parameter + result = await model_list( + user_api_key_dict=mock_user_api_key_dict, + scope=None, + ) + + # Verify result uses normal behavior + assert result["object"] == "list" + assert len(result["data"]) == 1 + assert result["data"][0]["id"] == "gpt-3.5-turbo" + + # Verify router methods were NOT called (normal path) + mock_router.get_model_names.assert_not_called() + mock_router.get_model_access_groups.assert_not_called()