From f78fc4e0feb3195d2c6817aa944cd148ba47272a Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Thu, 22 Jan 2026 15:32:10 -0800 Subject: [PATCH] Fix org all proxy model case --- .../management_endpoints/team_endpoints.py | 20 ++-- .../test_team_endpoints.py | 94 +++++++++++++++++++ 2 files changed, 106 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 4d313fb1235..98b1e9cb3e6 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -496,14 +496,18 @@ async def _check_org_team_limits( # Validate team models against organization's allowed models if data.models is not None and len(org_table.models) > 0: - for m in data.models: - if m not in org_table.models: - raise HTTPException( - status_code=400, - detail={ - "error": f"Model '{m}' not in organization's allowed models. Organization allowed models={org_table.models}. Organization: {org_table.organization_id}" - }, - ) + # If organization has 'all-proxy-models', skip validation as it allows all models + if SpecialModelNames.all_proxy_models.value in org_table.models: + pass + else: + for m in data.models: + if m not in org_table.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"Model '{m}' not in organization's allowed models. Organization allowed models={org_table.models}. Organization: {org_table.organization_id}" + }, + ) # Validate team TPM/RPM against organization's TPM/RPM limits (direct comparison) if ( diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index 0d78545823f..a97ed93a85c 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -3654,6 +3654,100 @@ async def test_update_team_org_scoped_models_not_in_org_models(): assert "claude-3-opus" in str(exc_info.value.message) or "organization" in str(exc_info.value.message).lower() +@pytest.mark.asyncio +async def test_update_team_org_scoped_models_with_all_proxy_models(): + """ + Test that /team/update for an org-scoped team succeeds when organization has 'all-proxy-models'. + + Scenario: + - Organization has models=['all-proxy-models'] (catch-all for all models) + - Org-scoped team exists + - User tries to update team models to ['rerank-english-v3.0', 'text-embedding-3-small', 'gpt-4o-mini-test'] + - Expected: Should succeed because 'all-proxy-models' allows all models + """ + from fastapi import Request + + from litellm.proxy._types import ( + LiteLLM_OrganizationTable, + SpecialModelNames, + UpdateTeamRequest, + UserAPIKeyAuth, + ) + from litellm.proxy.management_endpoints.team_endpoints import update_team + + # Create user (org admin) + org_admin_user = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="org-admin-all-proxy-models-test", + models=[], + ) + + # Create update request with models that aren't explicitly in org's models list + # but should be allowed because org has 'all-proxy-models' + update_request = UpdateTeamRequest( + team_id="org-team-all-proxy-models-123", + models=["rerank-english-v3.0", "text-embedding-3-small", "gpt-4o-mini-test"], + ) + + dummy_request = MagicMock(spec=Request) + + # Mock organization with 'all-proxy-models' (catch-all) + mock_org = MagicMock(spec=LiteLLM_OrganizationTable) + mock_org.organization_id = "test-org-all-proxy-models" + mock_org.models = [SpecialModelNames.all_proxy_models.value] # Allows all models + mock_org.litellm_budget_table = None + + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.proxy_server.user_api_key_cache" + ) as mock_cache, patch( + "litellm.proxy.proxy_server.litellm_proxy_admin_name", "admin" + ), patch( + "litellm.proxy.proxy_server.create_audit_log_for_update", new=AsyncMock() + ) as mock_audit, patch( + "litellm.proxy.management_endpoints.team_endpoints.get_org_object", + new=AsyncMock(return_value=mock_org) + ) as mock_get_org: + + # Mock existing org-scoped team + mock_existing_team = MagicMock() + mock_existing_team.team_id = "org-team-all-proxy-models-123" + mock_existing_team.organization_id = "test-org-all-proxy-models" + mock_existing_team.models = ["gpt-4"] + mock_existing_team.model_id = None + mock_existing_team.model_dump.return_value = { + "team_id": "org-team-all-proxy-models-123", + "organization_id": "test-org-all-proxy-models", + "models": ["gpt-4"], + } + mock_prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=mock_existing_team) + mock_prisma.jsonify_team_object = lambda db_data: db_data + mock_cache.async_set_cache = AsyncMock() # Mock cache set for _cache_team_object + + # Mock team update + mock_updated_team = MagicMock() + mock_updated_team.team_id = "org-team-all-proxy-models-123" + mock_updated_team.organization_id = "test-org-all-proxy-models" + mock_updated_team.models = ["rerank-english-v3.0", "text-embedding-3-small", "gpt-4o-mini-test"] + mock_updated_team.litellm_model_table = None + mock_updated_team.model_dump.return_value = { + "team_id": "org-team-all-proxy-models-123", + "organization_id": "test-org-all-proxy-models", + "models": ["rerank-english-v3.0", "text-embedding-3-small", "gpt-4o-mini-test"], + } + mock_prisma.db.litellm_teamtable.update = AsyncMock(return_value=mock_updated_team) + + # Should NOT raise an exception - 'all-proxy-models' allows all models + result = await update_team( + data=update_request, + http_request=dummy_request, + user_api_key_dict=org_admin_user, + ) + + # Verify the team was updated successfully with the new models + assert result is not None + assert result["data"].models == ["rerank-english-v3.0", "text-embedding-3-small", "gpt-4o-mini-test"] + + @pytest.mark.asyncio async def test_update_team_tpm_limit_exceeds_user_limit(): """