From bec6b6a2a98c3d0996a2b2430e3dfd0da6f51e2b Mon Sep 17 00:00:00 2001 From: Harshit Jain Date: Sat, 10 Jan 2026 21:24:28 +0530 Subject: [PATCH 1/2] fixes: ci pipeline --- litellm/proxy/auth/model_checks.py | 25 +- litellm/proxy/litellm_pre_call_utils.py | 39 ++- litellm/router.py | 50 +++- litellm/router_utils/common_utils.py | 65 ++++- ...est_filter_deployments_by_access_groups.py | 225 ++++++++++++++++++ 5 files changed, 373 insertions(+), 31 deletions(-) create mode 100644 tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index 71ae1348f39..af2574d88ee 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -64,6 +64,27 @@ def _get_models_from_access_groups( return all_models +def get_access_groups_from_models( + model_access_groups: Dict[str, List[str]], + models: List[str], +) -> List[str]: + """ + Extract access group names from a models list. + + Given a models list like ["gpt-4", "beta-models", "claude-v1"] + and access groups like {"beta-models": ["gpt-5", "gpt-6"]}, + returns ["beta-models"]. + + This is used to pass allowed access groups to the router for filtering + deployments during load balancing (GitHub issue #18333). + """ + access_groups = [] + for model in models: + if model in model_access_groups: + access_groups.append(model) + return access_groups + + async def get_mcp_server_ids( user_api_key_dict: UserAPIKeyAuth, ) -> List[str]: @@ -80,7 +101,6 @@ async def get_mcp_server_ids( # Make a direct SQL query to get just the mcp_servers try: - result = await prisma_client.db.litellm_objectpermissiontable.find_unique( where={"object_permission_id": user_api_key_dict.object_permission_id}, ) @@ -176,6 +196,7 @@ def get_complete_model_list( """ unique_models = [] + def append_unique(models): for model in models: if model not in unique_models: @@ -188,7 +209,7 @@ def append_unique(models): else: append_unique(proxy_model_list) if include_model_access_groups: - append_unique(list(model_access_groups.keys())) # TODO: keys order + append_unique(list(model_access_groups.keys())) # TODO: keys order if user_model: append_unique([user_model]) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9be78264e85..641ac3d481f 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,7 +1,7 @@ import asyncio import copy import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union from fastapi import Request from starlette.datastructures import Headers @@ -41,7 +41,7 @@ if TYPE_CHECKING: from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig - ProxyConfig = _ProxyConfig + ProxyConfig: TypeAlias = _ProxyConfig else: ProxyConfig = Any @@ -562,11 +562,15 @@ def add_litellm_metadata_from_request_headers( trace_id_from_header = headers.get("x-litellm-trace-id") if agent_id_from_header: metadata_from_headers["agent_id"] = agent_id_from_header - verbose_proxy_logger.debug(f"Extracted agent_id from header: {agent_id_from_header}") - + verbose_proxy_logger.debug( + f"Extracted agent_id from header: {agent_id_from_header}" + ) + if trace_id_from_header: metadata_from_headers["trace_id"] = trace_id_from_header - verbose_proxy_logger.debug(f"Extracted trace_id from header: {trace_id_from_header}") + verbose_proxy_logger.debug( + f"Extracted trace_id from header: {trace_id_from_header}" + ) if isinstance(data[_metadata_variable_name], dict): data[_metadata_variable_name].update(metadata_from_headers) @@ -1012,14 +1016,23 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_model_max_budget" ] = user_api_key_dict.model_max_budget - # User spend, budget - used by prometheus.py - # Follow same pattern as team and API key budgets - data[_metadata_variable_name][ - "user_api_key_user_spend" - ] = user_api_key_dict.user_spend - data[_metadata_variable_name][ - "user_api_key_user_max_budget" - ] = user_api_key_dict.user_max_budget + # Extract allowed access groups for router filtering (GitHub issue #18333) + # This allows the router to filter deployments based on team's access groups + if llm_router is not None: + from litellm.proxy.auth.model_checks import get_access_groups_from_models + + model_access_groups = llm_router.get_model_access_groups() + # Combine key models and team models to get all allowed access groups + all_models = list(user_api_key_dict.models) + list( + user_api_key_dict.team_models or [] + ) + allowed_access_groups = get_access_groups_from_models( + model_access_groups=model_access_groups, models=all_models + ) + if allowed_access_groups: + data[_metadata_variable_name][ + "user_api_key_allowed_access_groups" + ] = allowed_access_groups data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) diff --git a/litellm/router.py b/litellm/router.py index d01c8443dab..57d7ddfdfef 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -88,6 +88,7 @@ is_clientside_credential, ) from litellm.router_utils.common_utils import ( + filter_deployments_by_access_groups, filter_team_based_models, filter_web_search_deployments, ) @@ -619,11 +620,12 @@ def __init__( # noqa: PLR0915 self.retry_policy = RetryPolicy(**retry_policy) elif isinstance(retry_policy, RetryPolicy): self.retry_policy = retry_policy - verbose_router_logger.info( - "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format( - self.retry_policy.model_dump(exclude_none=True) + if self.retry_policy is not None: + verbose_router_logger.info( + "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format( + self.retry_policy.model_dump(exclude_none=True) + ) ) - ) self.model_group_retry_policy: Optional[ Dict[str, RetryPolicy] @@ -636,11 +638,12 @@ def __init__( # noqa: PLR0915 elif isinstance(allowed_fails_policy, AllowedFailsPolicy): self.allowed_fails_policy = allowed_fails_policy - verbose_router_logger.info( - "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format( - self.allowed_fails_policy.model_dump(exclude_none=True) + if self.allowed_fails_policy is not None: + verbose_router_logger.info( + "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format( + self.allowed_fails_policy.model_dump(exclude_none=True) + ) ) - ) self.alerting_config: Optional[AlertingConfig] = alerting_config @@ -4769,9 +4772,23 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 if hasattr(original_exception, "message"): # add the available fallbacks to the exception - original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore - model_group, - fallback_model_group, + deployment_info = "" + if kwargs is not None: + metadata = kwargs.get("metadata", {}) + if metadata and "deployment" in metadata: + deployment_info = f"\nUsed Deployment: {metadata['deployment']}" + if "model_info" in metadata: + model_info = metadata["model_info"] + if isinstance(model_info, dict): + deployment_info += ( + f"\nDeployment ID: {model_info.get('id', 'unknown')}" + ) + + original_exception.message += ( # type: ignore + f". Received Model Group={model_group}" + f"\nAvailable Model Group Fallbacks={fallback_model_group}" + f"{deployment_info}" + f"\n\n💡 Tip: If using wildcard patterns (e.g., 'openai/*'), ensure all matching deployments have credentials with access to this model." ) if len(fallback_failure_exception_str) > 0: original_exception.message += ( # type: ignore @@ -8091,6 +8108,17 @@ async def async_get_healthy_deployments( f"healthy_deployments after web search filter: {healthy_deployments}" ) + # Filter by allowed access groups (GitHub issue #18333) + # This prevents cross-team load balancing when teams have models with same name in different access groups + healthy_deployments = filter_deployments_by_access_groups( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, + ) + + verbose_router_logger.debug( + f"healthy_deployments after access group filter: {healthy_deployments}" + ) + if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 10acc343abd..0cd572acf52 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -32,9 +32,9 @@ def add_model_file_id_mappings( model_file_id_mapping = {} if isinstance(healthy_deployments, list): for deployment, response in zip(healthy_deployments, responses): - model_file_id_mapping[deployment.get("model_info", {}).get("id")] = ( - response.id - ) + model_file_id_mapping[ + deployment.get("model_info", {}).get("id") + ] = response.id elif isinstance(healthy_deployments, dict): for model_id, file_id in healthy_deployments.items(): model_file_id_mapping[model_id] = file_id @@ -75,6 +75,7 @@ def filter_team_based_models( if deployment.get("model_info", {}).get("id") not in ids_to_remove ] + def _deployment_supports_web_search(deployment: Dict) -> bool: """ Check if a deployment supports web search. @@ -112,7 +113,7 @@ def filter_web_search_deployments( is_web_search_request = False tools = request_kwargs.get("tools") or [] for tool in tools: - # These are the two websearch tools for OpenAI / Azure. + # These are the two websearch tools for OpenAI / Azure. if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview": is_web_search_request = True break @@ -121,8 +122,62 @@ def filter_web_search_deployments( return healthy_deployments # Filter out deployments that don't support web search - final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)] + final_deployments = [ + d for d in healthy_deployments if _deployment_supports_web_search(d) + ] if len(healthy_deployments) > 0 and len(final_deployments) == 0: verbose_logger.warning("No deployments support web search for request") return final_deployments + +def filter_deployments_by_access_groups( + healthy_deployments: Union[List[Dict], Dict], + request_kwargs: Optional[Dict] = None, +) -> Union[List[Dict], Dict]: + """ + Filter deployments to only include those matching the team's allowed access groups. + + If the request includes `user_api_key_allowed_access_groups` in metadata, + only return deployments where at least one of the deployment's access_groups + matches the allowed list. + + This prevents cross-team load balancing when multiple teams have models with + the same name but in different access groups (GitHub issue #18333). + """ + if request_kwargs is None: + return healthy_deployments + + if isinstance(healthy_deployments, dict): + return healthy_deployments + + metadata = request_kwargs.get("metadata") or {} + litellm_metadata = request_kwargs.get("litellm_metadata") or {} + allowed_access_groups = metadata.get( + "user_api_key_allowed_access_groups" + ) or litellm_metadata.get("user_api_key_allowed_access_groups") + + # If no access groups specified, return all deployments (backwards compatible) + if not allowed_access_groups: + return healthy_deployments + + allowed_set = set(allowed_access_groups) + filtered = [] + for deployment in healthy_deployments: + model_info = deployment.get("model_info") or {} + deployment_access_groups = model_info.get("access_groups") or [] + + # If deployment has no access groups, include it (not restricted) + if not deployment_access_groups: + filtered.append(deployment) + continue + + # Include if any of deployment's groups overlap with allowed groups + if set(deployment_access_groups) & allowed_set: + filtered.append(deployment) + + if len(healthy_deployments) > 0 and len(filtered) == 0: + verbose_logger.warning( + f"No deployments match allowed access groups {allowed_access_groups}" + ) + + return filtered diff --git a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py new file mode 100644 index 00000000000..c46e4bba140 --- /dev/null +++ b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py @@ -0,0 +1,225 @@ +""" +Unit tests for filter_deployments_by_access_groups function. + +Tests the fix for GitHub issue #18333: Models loadbalanced outside of Model Access Group. +""" + +from litellm.router_utils.common_utils import filter_deployments_by_access_groups + + +class TestFilterDeploymentsByAccessGroups: + """Tests for the filter_deployments_by_access_groups function.""" + + def test_no_filter_when_no_access_groups_in_metadata(self): + """When no allowed_access_groups in metadata, return all deployments.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_team_id": "team-1"}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 2 # All deployments returned + + def test_filter_to_single_access_group(self): + """Filter to only deployments matching allowed access group.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "2" + + def test_filter_with_multiple_allowed_groups(self): + """Filter with multiple allowed access groups.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + {"model_info": {"id": "3", "access_groups": ["AG3"]}}, + ] + request_kwargs = { + "metadata": {"user_api_key_allowed_access_groups": ["AG1", "AG2"]} + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 2 + ids = [d["model_info"]["id"] for d in result] + assert "1" in ids + assert "2" in ids + assert "3" not in ids + + def test_deployment_with_multiple_access_groups(self): + """Deployment with multiple access groups should match if any overlap.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1", "AG2"]}}, + {"model_info": {"id": "2", "access_groups": ["AG3"]}}, + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "1" + + def test_deployment_without_access_groups_included(self): + """Deployments without access groups should be included (not restricted).""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2"}}, # No access_groups + {"model_info": {"id": "3", "access_groups": []}}, # Empty access_groups + ] + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + # Should include deployments 2 and 3 (no restrictions) + assert len(result) == 2 + ids = [d["model_info"]["id"] for d in result] + assert "2" in ids + assert "3" in ids + + def test_dict_deployment_passes_through(self): + """When deployment is a dict (specific deployment), pass through.""" + deployment = {"model_info": {"id": "1", "access_groups": ["AG1"]}} + request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} + + result = filter_deployments_by_access_groups( + healthy_deployments=deployment, + request_kwargs=request_kwargs, + ) + + assert result == deployment # Unchanged + + def test_none_request_kwargs_passes_through(self): + """When request_kwargs is None, return deployments unchanged.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + ] + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=None, + ) + + assert result == deployments + + def test_litellm_metadata_fallback(self): + """Should also check litellm_metadata for allowed access groups.""" + deployments = [ + {"model_info": {"id": "1", "access_groups": ["AG1"]}}, + {"model_info": {"id": "2", "access_groups": ["AG2"]}}, + ] + request_kwargs = { + "litellm_metadata": {"user_api_key_allowed_access_groups": ["AG1"]} + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "1" + + +def test_filter_deployments_by_access_groups_issue_18333(): + """ + Regression test for GitHub issue #18333. + + Scenario: Two models named 'gpt-5' in different access groups (AG1, AG2). + Team2 has access to AG2 only. When Team2 requests 'gpt-5', only the AG2 + deployment should be available for load balancing. + """ + deployments = [ + { + "model_name": "gpt-5", + "litellm_params": {"model": "gpt-4.1", "api_key": "key-1"}, + "model_info": {"id": "ag1-deployment", "access_groups": ["AG1"]}, + }, + { + "model_name": "gpt-5", + "litellm_params": {"model": "gpt-4o", "api_key": "key-2"}, + "model_info": {"id": "ag2-deployment", "access_groups": ["AG2"]}, + }, + ] + + # Team2's request with allowed access groups + request_kwargs = { + "metadata": { + "user_api_key_team_id": "team-2", + "user_api_key_allowed_access_groups": ["AG2"], + } + } + + result = filter_deployments_by_access_groups( + healthy_deployments=deployments, + request_kwargs=request_kwargs, + ) + + # Only AG2 deployment should be returned + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "ag2-deployment" + assert result[0]["litellm_params"]["model"] == "gpt-4o" + + +def test_get_access_groups_from_models(): + """ + Test the helper function that extracts access group names from models list. + This is used by the proxy to populate user_api_key_allowed_access_groups. + """ + from litellm.proxy.auth.model_checks import get_access_groups_from_models + + # Setup: access groups definition + model_access_groups = { + "AG1": ["gpt-4", "gpt-5"], + "AG2": ["claude-v1", "claude-v2"], + "beta-models": ["gpt-5-turbo"], + } + + # Test 1: Extract access groups from models list + models = ["gpt-4", "AG1", "AG2", "some-other-model"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert set(result) == {"AG1", "AG2"} + + # Test 2: No access groups in models list + models = ["gpt-4", "claude-v1", "some-model"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert result == [] + + # Test 3: Empty models list + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=[] + ) + assert result == [] + + # Test 4: All access groups + models = ["AG1", "AG2", "beta-models"] + result = get_access_groups_from_models( + model_access_groups=model_access_groups, models=models + ) + assert set(result) == {"AG1", "AG2", "beta-models"} From 025d96eb846c6ba5c097dd5cc600609a7af339e7 Mon Sep 17 00:00:00 2001 From: Harshit Jain <48647625+Harshit28j@users.noreply.github.com> Date: Tue, 3 Feb 2026 20:20:21 +0530 Subject: [PATCH 2/2] Update litellm/proxy/litellm_pre_call_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- litellm/proxy/litellm_pre_call_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 641ac3d481f..361e50aaf21 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1016,6 +1016,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_model_max_budget" ] = user_api_key_dict.model_max_budget + # User spend, budget - used by prometheus.py + # Follow same pattern as team and API key budgets + data[_metadata_variable_name][ + "user_api_key_user_spend" + ] = user_api_key_dict.user_spend + data[_metadata_variable_name][ + "user_api_key_user_max_budget" + ] = user_api_key_dict.user_max_budget + # Extract allowed access groups for router filtering (GitHub issue #18333) # This allows the router to filter deployments based on team's access groups if llm_router is not None: