diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 385b4b0de32..264c7d765b3 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -545,9 +545,6 @@ router_settings: | DEFAULT_MAX_TOKENS | Default maximum tokens for LLM calls. Default is 4096 | DEFAULT_MAX_TOKENS_FOR_TRITON | Default maximum tokens for Triton models. Default is 2000 | DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE | Default maximum size for redis batch cache. Default is 1000 -| DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL | Default embedding model for MCP semantic tool filtering. Default is "text-embedding-3-small" -| DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD | Default similarity threshold for MCP semantic tool filtering. Default is 0.3 -| DEFAULT_MCP_SEMANTIC_FILTER_TOP_K | Default number of top results to return for MCP semantic tool filtering. Default is 10 | DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT | Default token count for mock response completions. Default is 20 | DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT | Default token count for mock response prompts. Default is 10 | DEFAULT_MODEL_CREATED_AT_TIME | Default creation timestamp for models. Default is 1677610602 @@ -805,7 +802,6 @@ router_settings: | MAXIMUM_TRACEBACK_LINES_TO_LOG | Maximum number of lines to log in traceback in LiteLLM Logs UI. Default is 100 | MAX_RETRY_DELAY | Maximum delay in seconds for retrying requests. Default is 8.0 | MAX_LANGFUSE_INITIALIZED_CLIENTS | Maximum number of Langfuse clients to initialize on proxy. Default is 50. This is set since langfuse initializes 1 thread everytime a client is initialized. We've had an incident in the past where we reached 100% cpu utilization because Langfuse was initialized several times. -| MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH | Maximum header length for MCP semantic filter tools. Default is 150 | MIN_NON_ZERO_TEMPERATURE | Minimum non-zero temperature value. Default is 0.0001 | MINIMUM_PROMPT_CACHE_TOKEN_COUNT | Minimum token count for caching a prompt. Default is 1024 | MISTRAL_API_BASE | Base URL for Mistral API. Default is https://api.mistral.ai diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql new file mode 100644 index 00000000000..572eea9b529 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql @@ -0,0 +1,8 @@ +-- CreateIndex +CREATE INDEX "LiteLLM_VerificationToken_user_id_team_id_idx" ON "LiteLLM_VerificationToken"("user_id", "team_id"); + +-- CreateIndex +CREATE INDEX "LiteLLM_VerificationToken_team_id_idx" ON "LiteLLM_VerificationToken"("team_id"); + +-- CreateIndex +CREATE INDEX "LiteLLM_VerificationToken_budget_reset_at_expires_idx" ON "LiteLLM_VerificationToken"("budget_reset_at", "expires"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index b118400b620..3b81da10923 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken { litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) + + // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub" + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2 + @@index([user_id, team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2 + @@index([team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3 + @@index([budget_reset_at, expires]) } // Audit table for deleted keys - preserves spend and key information for historical tracking diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index 71ae1348f39..af2574d88ee 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -64,6 +64,27 @@ def _get_models_from_access_groups( return all_models +def get_access_groups_from_models( + model_access_groups: Dict[str, List[str]], + models: List[str], +) -> List[str]: + """ + Extract access group names from a models list. + + Given a models list like ["gpt-4", "beta-models", "claude-v1"] + and access groups like {"beta-models": ["gpt-5", "gpt-6"]}, + returns ["beta-models"]. + + This is used to pass allowed access groups to the router for filtering + deployments during load balancing (GitHub issue #18333). + """ + access_groups = [] + for model in models: + if model in model_access_groups: + access_groups.append(model) + return access_groups + + async def get_mcp_server_ids( user_api_key_dict: UserAPIKeyAuth, ) -> List[str]: @@ -80,7 +101,6 @@ async def get_mcp_server_ids( # Make a direct SQL query to get just the mcp_servers try: - result = await prisma_client.db.litellm_objectpermissiontable.find_unique( where={"object_permission_id": user_api_key_dict.object_permission_id}, ) @@ -176,6 +196,7 @@ def get_complete_model_list( """ unique_models = [] + def append_unique(models): for model in models: if model not in unique_models: @@ -188,7 +209,7 @@ def append_unique(models): else: append_unique(proxy_model_list) if include_model_access_groups: - append_unique(list(model_access_groups.keys())) # TODO: keys order + append_unique(list(model_access_groups.keys())) # TODO: keys order if user_model: append_unique([user_model]) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9be78264e85..72f23e609ab 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1021,6 +1021,37 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_user_max_budget" ] = user_api_key_dict.user_max_budget + # Extract allowed access groups for router filtering (GitHub issue #18333) + # This allows the router to filter deployments based on key's and team's access groups + # NOTE: We keep key and team access groups SEPARATE because a key doesn't always + # inherit all team access groups (per maintainer feedback). + if llm_router is not None: + from litellm.proxy.auth.model_checks import get_access_groups_from_models + + model_access_groups = llm_router.get_model_access_groups() + + # Key-level access groups (from user_api_key_dict.models) + key_models = list(user_api_key_dict.models) if user_api_key_dict.models else [] + key_allowed_access_groups = get_access_groups_from_models( + model_access_groups=model_access_groups, models=key_models + ) + if key_allowed_access_groups: + data[_metadata_variable_name][ + "user_api_key_allowed_access_groups" + ] = key_allowed_access_groups + + # Team-level access groups (from user_api_key_dict.team_models) + team_models = ( + list(user_api_key_dict.team_models) if user_api_key_dict.team_models else [] + ) + team_allowed_access_groups = get_access_groups_from_models( + model_access_groups=model_access_groups, models=team_models + ) + if team_allowed_access_groups: + data[_metadata_variable_name][ + "user_api_key_team_allowed_access_groups" + ] = team_allowed_access_groups + data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) _headers.pop( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index b118400b620..3b81da10923 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken { litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) + + // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub" + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2 + @@index([user_id, team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2 + @@index([team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3 + @@index([budget_reset_at, expires]) } // Audit table for deleted keys - preserves spend and key information for historical tracking diff --git a/litellm/router.py b/litellm/router.py index d01c8443dab..65445e29c41 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -88,6 +88,7 @@ is_clientside_credential, ) from litellm.router_utils.common_utils import ( + filter_deployments_by_access_groups, filter_team_based_models, filter_web_search_deployments, ) @@ -8087,10 +8088,17 @@ async def async_get_healthy_deployments( request_kwargs=request_kwargs, ) - verbose_router_logger.debug( - f"healthy_deployments after web search filter: {healthy_deployments}" + verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}") + + # Filter by allowed access groups (GitHub issue #18333) + # This prevents cross-team load balancing when teams have models with same name in different access groups + healthy_deployments = filter_deployments_by_access_groups( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, ) + verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}") + if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 10acc343abd..2c0ea5976d6 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -75,6 +75,7 @@ def filter_team_based_models( if deployment.get("model_info", {}).get("id") not in ids_to_remove ] + def _deployment_supports_web_search(deployment: Dict) -> bool: """ Check if a deployment supports web search. @@ -112,7 +113,7 @@ def filter_web_search_deployments( is_web_search_request = False tools = request_kwargs.get("tools") or [] for tool in tools: - # These are the two websearch tools for OpenAI / Azure. + # These are the two websearch tools for OpenAI / Azure. if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview": is_web_search_request = True break @@ -121,8 +122,82 @@ def filter_web_search_deployments( return healthy_deployments # Filter out deployments that don't support web search - final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)] + final_deployments = [ + d for d in healthy_deployments if _deployment_supports_web_search(d) + ] if len(healthy_deployments) > 0 and len(final_deployments) == 0: verbose_logger.warning("No deployments support web search for request") return final_deployments + +def filter_deployments_by_access_groups( + healthy_deployments: Union[List[Dict], Dict], + request_kwargs: Optional[Dict] = None, +) -> Union[List[Dict], Dict]: + """ + Filter deployments to only include those matching the user's allowed access groups. + + Reads from TWO separate metadata fields (per maintainer feedback): + - `user_api_key_allowed_access_groups`: Access groups from the API Key's models. + - `user_api_key_team_allowed_access_groups`: Access groups from the Team's models. + + A deployment is included if its access_groups overlap with EITHER the key's + or the team's allowed access groups. Deployments with no access_groups are + always included (not restricted). + + This prevents cross-team load balancing when multiple teams have models with + the same name but in different access groups (GitHub issue #18333). + """ + if request_kwargs is None: + return healthy_deployments + + if isinstance(healthy_deployments, dict): + return healthy_deployments + + metadata = request_kwargs.get("metadata") or {} + litellm_metadata = request_kwargs.get("litellm_metadata") or {} + + # Gather key-level allowed access groups + key_allowed_access_groups = ( + metadata.get("user_api_key_allowed_access_groups") + or litellm_metadata.get("user_api_key_allowed_access_groups") + or [] + ) + + # Gather team-level allowed access groups + team_allowed_access_groups = ( + metadata.get("user_api_key_team_allowed_access_groups") + or litellm_metadata.get("user_api_key_team_allowed_access_groups") + or [] + ) + + # Combine both for the final allowed set + combined_allowed_access_groups = list(key_allowed_access_groups) + list( + team_allowed_access_groups + ) + + # If no access groups specified from either source, return all deployments (backwards compatible) + if not combined_allowed_access_groups: + return healthy_deployments + + allowed_set = set(combined_allowed_access_groups) + filtered = [] + for deployment in healthy_deployments: + model_info = deployment.get("model_info") or {} + deployment_access_groups = model_info.get("access_groups") or [] + + # If deployment has no access groups, include it (not restricted) + if not deployment_access_groups: + filtered.append(deployment) + continue + + # Include if any of deployment's groups overlap with allowed groups + if set(deployment_access_groups) & allowed_set: + filtered.append(deployment) + + if len(healthy_deployments) > 0 and len(filtered) == 0: + verbose_logger.warning( + f"No deployments match allowed access groups {combined_allowed_access_groups}" + ) + + return filtered diff --git a/schema.prisma b/schema.prisma index b118400b620..3b81da10923 100644 --- a/schema.prisma +++ b/schema.prisma @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken { litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) + + // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub" + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2 + @@index([user_id, team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2 + @@index([team_id]) + + // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3 + @@index([budget_reset_at, expires]) } // Audit table for deleted keys - preserves spend and key information for historical tracking diff --git a/tests/mcp_tests/test_semantic_tool_filter_e2e.py b/tests/mcp_tests/test_semantic_tool_filter_e2e.py index 91c072ae8a3..cf951c1884b 100644 --- a/tests/mcp_tests/test_semantic_tool_filter_e2e.py +++ b/tests/mcp_tests/test_semantic_tool_filter_e2e.py @@ -12,19 +12,8 @@ from mcp.types import Tool as MCPTool -# Check if semantic-router is available -try: - import semantic_router - SEMANTIC_ROUTER_AVAILABLE = True -except ImportError: - SEMANTIC_ROUTER_AVAILABLE = False - @pytest.mark.asyncio -@pytest.mark.skipif( - not SEMANTIC_ROUTER_AVAILABLE, - reason="semantic-router not installed. Install with: pip install 'litellm[semantic-router]'" -) async def test_e2e_semantic_filter(): """E2E: Load router/filter and verify hook filters tools.""" from litellm import Router @@ -48,6 +37,8 @@ async def test_e2e_semantic_filter(): enabled=True, ) + hook = SemanticToolFilterHook(filter_instance) + # Create 10 tools tools = [ MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}), @@ -62,16 +53,10 @@ async def test_e2e_semantic_filter(): MCPTool(name="note_add", description="Add note", inputSchema={"type": "object"}), ] - # Build router with test tools - filter_instance._build_router(tools) - - hook = SemanticToolFilterHook(filter_instance) - data = { "model": "gpt-4", "messages": [{"role": "user", "content": "Send an email and create a calendar event"}], "tools": tools, - "metadata": {}, # Initialize metadata dict for hook to store filter stats } # Call hook diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py index 87c597c659b..8d35f5bbdc9 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py @@ -71,9 +71,6 @@ async def mock_embedding_async(*args, **kwargs): enabled=True, ) - # Build router with the tools before filtering - filter_instance._build_router(tools) - # Filter tools with email-related query filtered = await filter_instance.filter_tools( query="send an email to john@example.com", @@ -142,9 +139,6 @@ async def mock_embedding_async(*args, **kwargs): enabled=True, ) - # Build router with the tools before filtering - filter_instance._build_router(tools) - # Filter tools filtered = await filter_instance.filter_tools( query="test query", @@ -303,25 +297,21 @@ async def mock_embedding_async(*args, **kwargs): enabled=True, ) + # Create hook + hook = SemanticToolFilterHook(filter_instance) + # Prepare data - completion request with tools tools = [ MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"}) for i in range(10) ] - # Build router with the tools before filtering - filter_instance._build_router(tools) - - # Create hook - hook = SemanticToolFilterHook(filter_instance) - data = { "model": "gpt-4", "messages": [ {"role": "user", "content": "Send an email"} ], "tools": tools, - "metadata": {}, # Hook needs metadata field to store filter stats } # Mock user API key dict and cache diff --git a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py new file mode 100644 index 00000000000..9ac5072c5d8 --- /dev/null +++ b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py @@ -0,0 +1,227 @@ +""" +Unit tests for filter_deployments_by_access_groups function. + +Tests the fix for GitHub issue #18333: Models loadbalanced outside of Model Access Group. +""" + +import pytest + +from litellm.router_utils.common_utils import filter_deployments_by_access_groups + + +class TestFilterDeploymentsByAccessGroups: + """Tests for the filter_deployments_by_access_groups function.""" + + def test_no_filter_when_no_access_groups_in_metadata(self): + """When no allowed_access_groups in metadata, return all deployments.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_team_id": "team-1"}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 2 # All deployments returned + + def test_filter_to_single_access_group(self): + """Filter to only deployments matching allowed access group.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "2" + + def test_filter_with_multiple_allowed_groups(self): + """Filter with multiple allowed access groups.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + {"model_info": {"id": "3", "access_groups": ["AG3"]}}, + ] + request_kwargs = { + "metadata": {"user_api_key_allowed_access_groups": ["AG1", "AG2"]} + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 2 + ids = [d["model_info"]["id"] for d in result] + assert "1" in ids + assert "2" in ids + assert "3" not in ids + + def test_deployment_with_multiple_access_groups(self): + """Deployment with multiple access groups should match if any overlap.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1", "AG2"]}}, + {"model_info": {"id": "2", "access_groups": ["AG3"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "1" + + def test_deployment_without_access_groups_included(self): + """Deployments without access groups should be included (not restricted).""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2"}}, # No access_groups + {"model_info": {"id": "3", "access_groups": []}}, # Empty access_groups + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + # Should include deployments 2 and 3 (no restrictions) + assert len(result) == 2 + ids = [d["model_info"]["id"] for d in result] + assert "2" in ids + assert "3" in ids + + def test_dict_deployment_passes_through(self): + """When deployment is a dict (specific deployment), pass through.""" + deployment = {"model_info": {"id": "1", "access_groups": ["AG1"]}} + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployment, + request_kwargs=request_kwargs, + ) + + assert result == deployment # Unchanged + + def test_none_request_kwargs_passes_through(self): + """When request_kwargs is None, return deployments unchanged.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + ] + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=None, + ) + + assert result == deployments + + def test_litellm_metadata_fallback(self): + """Should also check litellm_metadata for allowed access groups.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = { + "litellm_metadata": {"user_api_key_allowed_access_groups": ["AG1"]} + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "1" + + +def test_filter_deployments_by_access_groups_issue_18333(): + """ + Regression test for GitHub issue #18333. + + Scenario: Two models named 'gpt-5' in different access groups (AG1, AG2). + Team2 has access to AG2 only. When Team2 requests 'gpt-5', only the AG2 + deployment should be available for load balancing. + """ + deployments = [ + { + "model_name": "gpt-5", + "litellm_params": {"model": "gpt-4.1", "api_key": "key-1"}, + "model_info": {"id": "ag1-deployment", "access_groups": ["AG1"]}, + }, + { + "model_name": "gpt-5", + "litellm_params": {"model": "gpt-4o", "api_key": "key-2"}, + "model_info": {"id": "ag2-deployment", "access_groups": ["AG2"]}, + }, + ] + + # Team2's request with allowed access groups + request_kwargs = { + "metadata": { + "user_api_key_team_id": "team-2", + "user_api_key_allowed_access_groups": ["AG2"], + } + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + # Only AG2 deployment should be returned + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "ag2-deployment" + assert result[0]["litellm_params"]["model"] == "gpt-4o" + + +def test_get_access_groups_from_models(): + """ + Test the helper function that extracts access group names from models list. + This is used by the proxy to populate user_api_key_allowed_access_groups. + """ + from litellm.proxy.auth.model_checks import get_access_groups_from_models + + # Setup: access groups definition + model_access_groups = { + "AG1": ["gpt-4", "gpt-5"], + "AG2": ["claude-v1", "claude-v2"], + "beta-models": ["gpt-5-turbo"], + } + + # Test 1: Extract access groups from models list + models = ["gpt-4", "AG1", "AG2", "some-other-model"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert set(result) == {"AG1", "AG2"} + + # Test 2: No access groups in models list + models = ["gpt-4", "claude-v1", "some-model"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert result == [] + + # Test 3: Empty models list + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=[] + ) + assert result == [] + + # Test 4: All access groups + models = ["AG1", "AG2", "beta-models"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert set(result) == {"AG1", "AG2", "beta-models"}