def get_access_groups_from_models(
    model_access_groups: Dict[str, List[str]],
    models: List[str],
) -> List[str]:
    """
    Extract access group names from a models list.

    Given a models list like ["gpt-4", "beta-models", "claude-v1"]
    and access groups like {"beta-models": ["gpt-5", "gpt-6"]},
    returns ["beta-models"].

    This is used to pass allowed access groups to the router for filtering
    deployments during load balancing (GitHub issue #18333).

    Args:
        model_access_groups: Mapping of access group name -> model names in
            that group. Only the keys are consulted here.
        models: Entries from a key's or team's allowed-models list; each entry
            may be a plain model name or an access group name.

    Returns:
        The entries of `models` that are access group names, in first-seen
        order and without duplicates. Empty when nothing matches.
    """
    # dict.fromkeys keeps insertion order while dropping repeated group names,
    # so a models list that mentions a group twice yields it once.
    return list(
        dict.fromkeys(model for model in models if model in model_access_groups)
    )
proxy_config: ProxyConfig ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - key_dynamic_logging_settings: Optional[ - dict - ] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) - team_dynamic_logging_settings: Optional[ - dict - ] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) + key_dynamic_logging_settings: Optional[dict] = ( + KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) + ) + team_dynamic_logging_settings: Optional[dict] = ( + KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) + ) ######################################################################################### # Key-based callbacks ######################################################################################### @@ -661,11 +661,11 @@ def add_key_level_controls( ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name][ - "tags" - ] = LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], + data[_metadata_variable_name]["tags"] = ( + LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], + ) ) if "disable_global_guardrails" in key_metadata and isinstance( key_metadata["disable_global_guardrails"], bool @@ -846,7 +846,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 # Add headers to metadata for guardrails to access (fixes #17477) # Guardrails use metadata["headers"] to access request headers (e.g., User-Agent) - if _metadata_variable_name in data and isinstance(data[_metadata_variable_name], dict): + if _metadata_variable_name in data and isinstance( + data[_metadata_variable_name], dict + ): data[_metadata_variable_name]["headers"] = _headers # check for forwardable headers @@ -931,9 +933,9 @@ 
async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name][ - "global_max_parallel_requests" - ] = general_settings.get("global_max_parallel_requests", None) + data[_metadata_variable_name]["global_max_parallel_requests"] = ( + general_settings.get("global_max_parallel_requests", None) + ) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata @@ -1000,6 +1002,37 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_model_max_budget" ] = user_api_key_dict.model_max_budget + # Extract allowed access groups for router filtering (GitHub issue #18333) + # This allows the router to filter deployments based on key's and team's access groups + # NOTE: We keep key and team access groups SEPARATE because a key doesn't always + # inherit all team access groups (per maintainer feedback). + if llm_router is not None: + from litellm.proxy.auth.model_checks import get_access_groups_from_models + + model_access_groups = llm_router.get_model_access_groups() + + # Key-level access groups (from user_api_key_dict.models) + key_models = list(user_api_key_dict.models) if user_api_key_dict.models else [] + key_allowed_access_groups = get_access_groups_from_models( + model_access_groups=model_access_groups, models=key_models + ) + if key_allowed_access_groups: + data[_metadata_variable_name][ + "user_api_key_allowed_access_groups" + ] = key_allowed_access_groups + + # Team-level access groups (from user_api_key_dict.team_models) + team_models = ( + list(user_api_key_dict.team_models) if user_api_key_dict.team_models else [] + ) + team_allowed_access_groups = get_access_groups_from_models( + model_access_groups=model_access_groups, models=team_models + ) + if team_allowed_access_groups: + data[_metadata_variable_name][ + "user_api_key_team_allowed_access_groups" + ] = team_allowed_access_groups + 
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) _headers.pop( diff --git a/litellm/router.py b/litellm/router.py index 638df49ac05..91acb91428e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -86,6 +86,7 @@ is_clientside_credential, ) from litellm.router_utils.common_utils import ( + filter_deployments_by_access_groups, filter_team_based_models, filter_web_search_deployments, ) @@ -7781,10 +7782,17 @@ async def async_get_healthy_deployments( request_kwargs=request_kwargs, ) - verbose_router_logger.debug( - f"healthy_deployments after web search filter: {healthy_deployments}" + verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}") + + # Filter by allowed access groups (GitHub issue #18333) + # This prevents cross-team load balancing when teams have models with same name in different access groups + healthy_deployments = filter_deployments_by_access_groups( + healthy_deployments=healthy_deployments, + request_kwargs=request_kwargs, ) + verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}") + if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 10acc343abd..2c0ea5976d6 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -75,6 +75,7 @@ def filter_team_based_models( if deployment.get("model_info", {}).get("id") not in ids_to_remove ] + def _deployment_supports_web_search(deployment: Dict) -> bool: """ Check if a deployment supports web search. @@ -112,7 +113,7 @@ def filter_web_search_deployments( is_web_search_request = False tools = request_kwargs.get("tools") or [] for tool in tools: - # These are the two websearch tools for OpenAI / Azure. + # These are the two websearch tools for OpenAI / Azure. 
def _normalize_access_groups(
    value: Optional[Union[str, List[str]]],
) -> List[str]:
    """Coerce an access-group setting into a list of group names.

    A bare string is treated as ONE group name rather than being split into
    characters; any falsy value (None, "", []) yields an empty list.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def filter_deployments_by_access_groups(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    Filter deployments to only include those matching the caller's allowed access groups.

    Allowed groups are read from TWO separate metadata fields, looked up first in
    `request_kwargs["metadata"]` and then in `request_kwargs["litellm_metadata"]`:
    - `user_api_key_allowed_access_groups`: access groups from the API key's models.
    - `user_api_key_team_allowed_access_groups`: access groups from the team's models.

    A deployment is kept if its `model_info["access_groups"]` overlaps with EITHER
    the key's or the team's allowed groups. Deployments with no access groups are
    always kept (they are not restricted).

    This prevents cross-team load balancing when multiple teams have models with
    the same name but in different access groups (GitHub issue #18333).

    Args:
        healthy_deployments: Candidate deployments as a list of dicts, or a single
            deployment dict (which is returned untouched).
        request_kwargs: The request kwargs carrying `metadata` / `litellm_metadata`.
            When None, no filtering is applied.

    Returns:
        `healthy_deployments` itself when there is nothing to filter on (no request
        kwargs, a single-dict deployment, or no allowed groups in either field);
        otherwise a new list holding the permitted deployments in original order.
    """
    # Nothing to filter on: no request context, or one pre-selected deployment.
    if request_kwargs is None or isinstance(healthy_deployments, dict):
        return healthy_deployments

    metadata = request_kwargs.get("metadata") or {}
    litellm_metadata = request_kwargs.get("litellm_metadata") or {}

    def _allowed_groups(field: str) -> List[str]:
        # `metadata` wins; `litellm_metadata` is consulted only when the field
        # is missing or empty there.
        return _normalize_access_groups(
            metadata.get(field) or litellm_metadata.get(field)
        )

    # Key-level groups first, then team-level groups; the union is what is allowed.
    combined_allowed_access_groups = _allowed_groups(
        "user_api_key_allowed_access_groups"
    ) + _allowed_groups("user_api_key_team_allowed_access_groups")

    # If no access groups specified from either source, return all deployments
    # (backwards compatible).
    if not combined_allowed_access_groups:
        return healthy_deployments

    allowed_set = set(combined_allowed_access_groups)

    def _is_allowed(deployment: Dict) -> bool:
        model_info = deployment.get("model_info") or {}
        deployment_groups = _normalize_access_groups(model_info.get("access_groups"))
        # Unrestricted deployments pass; restricted ones need at least one
        # group in common with the allowed set.
        return not deployment_groups or not allowed_set.isdisjoint(deployment_groups)

    filtered = [
        deployment for deployment in healthy_deployments if _is_allowed(deployment)
    ]

    if healthy_deployments and not filtered:
        verbose_logger.warning(
            f"No deployments match allowed access groups {combined_allowed_access_groups}"
        )

    return filtered
"""
Unit tests for filter_deployments_by_access_groups function.

Tests the fix for GitHub issue #18333: Models loadbalanced outside of Model Access Group.
"""

import pytest

from litellm.router_utils.common_utils import filter_deployments_by_access_groups

# Metadata field names the router reads allowed access groups from.
KEY_GROUPS_FIELD = "user_api_key_allowed_access_groups"
TEAM_GROUPS_FIELD = "user_api_key_team_allowed_access_groups"


def _deployment(deployment_id, access_groups=None):
    """Build a minimal deployment dict; omit `access_groups` when None."""
    model_info = {"id": deployment_id}
    if access_groups is not None:
        model_info["access_groups"] = access_groups
    return {"model_info": model_info}


def _ids(deployments):
    """Return the model_info ids of the given deployments, in order."""
    return [d["model_info"]["id"] for d in deployments]


class TestFilterDeploymentsByAccessGroups:
    """Tests for the filter_deployments_by_access_groups function."""

    def test_no_filter_when_no_access_groups_in_metadata(self):
        """When no allowed_access_groups in metadata, return all deployments."""
        deployments = [_deployment("1", ["AG1"]), _deployment("2", ["AG2"])]
        request_kwargs = {"metadata": {"user_api_key_team_id": "team-1"}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert len(result) == 2  # All deployments returned

    def test_filter_to_single_access_group(self):
        """Filter to only deployments matching allowed access group."""
        deployments = [_deployment("1", ["AG1"]), _deployment("2", ["AG2"])]
        request_kwargs = {"metadata": {KEY_GROUPS_FIELD: ["AG2"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert len(result) == 1
        assert result[0]["model_info"]["id"] == "2"

    def test_filter_with_multiple_allowed_groups(self):
        """Filter with multiple allowed access groups."""
        deployments = [
            _deployment("1", ["AG1"]),
            _deployment("2", ["AG2"]),
            _deployment("3", ["AG3"]),
        ]
        request_kwargs = {"metadata": {KEY_GROUPS_FIELD: ["AG1", "AG2"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert len(result) == 2
        ids = _ids(result)
        assert "1" in ids
        assert "2" in ids
        assert "3" not in ids

    def test_deployment_with_multiple_access_groups(self):
        """Deployment with multiple access groups should match if any overlap."""
        deployments = [_deployment("1", ["AG1", "AG2"]), _deployment("2", ["AG3"])]
        request_kwargs = {"metadata": {KEY_GROUPS_FIELD: ["AG2"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert len(result) == 1
        assert result[0]["model_info"]["id"] == "1"

    def test_deployment_without_access_groups_included(self):
        """Deployments without access groups should be included (not restricted)."""
        deployments = [
            _deployment("1", ["AG1"]),
            _deployment("2"),  # No access_groups
            _deployment("3", []),  # Empty access_groups
        ]
        request_kwargs = {"metadata": {KEY_GROUPS_FIELD: ["AG2"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        # Should include deployments 2 and 3 (no restrictions)
        assert len(result) == 2
        ids = _ids(result)
        assert "2" in ids
        assert "3" in ids

    def test_dict_deployment_passes_through(self):
        """When deployment is a dict (specific deployment), pass through."""
        deployment = _deployment("1", ["AG1"])
        request_kwargs = {"metadata": {KEY_GROUPS_FIELD: ["AG2"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployment,
            request_kwargs=request_kwargs,
        )

        assert result == deployment  # Unchanged

    def test_none_request_kwargs_passes_through(self):
        """When request_kwargs is None, return deployments unchanged."""
        deployments = [_deployment("1", ["AG1"])]

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=None,
        )

        assert result == deployments

    def test_litellm_metadata_fallback(self):
        """Should also check litellm_metadata for allowed access groups."""
        deployments = [_deployment("1", ["AG1"]), _deployment("2", ["AG2"])]
        request_kwargs = {"litellm_metadata": {KEY_GROUPS_FIELD: ["AG1"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert len(result) == 1
        assert result[0]["model_info"]["id"] == "1"

    def test_team_level_access_groups(self):
        """Team-level allowed access groups alone should drive the filter."""
        deployments = [_deployment("1", ["AG1"]), _deployment("2", ["AG2"])]
        request_kwargs = {"metadata": {TEAM_GROUPS_FIELD: ["AG1"]}}

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert _ids(result) == ["1"]

    def test_key_and_team_access_groups_are_combined(self):
        """A deployment matching EITHER the key's or the team's groups is kept."""
        deployments = [
            _deployment("1", ["AG1"]),
            _deployment("2", ["AG2"]),
            _deployment("3", ["AG3"]),
        ]
        request_kwargs = {
            "metadata": {
                KEY_GROUPS_FIELD: ["AG1"],
                TEAM_GROUPS_FIELD: ["AG2"],
            }
        }

        result = filter_deployments_by_access_groups(
            healthy_deployments=deployments,
            request_kwargs=request_kwargs,
        )

        assert _ids(result) == ["1", "2"]


def test_filter_deployments_by_access_groups_issue_18333():
    """
    Regression test for GitHub issue #18333.

    Scenario: Two models named 'gpt-5' in different access groups (AG1, AG2).
    Team2 has access to AG2 only. When Team2 requests 'gpt-5', only the AG2
    deployment should be available for load balancing.
    """
    deployments = [
        {
            "model_name": "gpt-5",
            "litellm_params": {"model": "gpt-4.1", "api_key": "key-1"},
            "model_info": {"id": "ag1-deployment", "access_groups": ["AG1"]},
        },
        {
            "model_name": "gpt-5",
            "litellm_params": {"model": "gpt-4o", "api_key": "key-2"},
            "model_info": {"id": "ag2-deployment", "access_groups": ["AG2"]},
        },
    ]

    # Team2's request with allowed access groups
    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team-2",
            KEY_GROUPS_FIELD: ["AG2"],
        }
    }

    result = filter_deployments_by_access_groups(
        healthy_deployments=deployments,
        request_kwargs=request_kwargs,
    )

    # Only AG2 deployment should be returned
    assert len(result) == 1
    assert result[0]["model_info"]["id"] == "ag2-deployment"
    assert result[0]["litellm_params"]["model"] == "gpt-4o"


@pytest.mark.parametrize(
    "models, expected",
    [
        # Extract access groups from a mixed models list
        (["gpt-4", "AG1", "AG2", "some-other-model"], ["AG1", "AG2"]),
        # No access groups in models list
        (["gpt-4", "claude-v1", "some-model"], []),
        # Empty models list
        ([], []),
        # All access groups
        (["AG1", "AG2", "beta-models"], ["AG1", "AG2", "beta-models"]),
    ],
)
def test_get_access_groups_from_models(models, expected):
    """
    Test the helper function that extracts access group names from models list.
    This is used by the proxy to populate user_api_key_allowed_access_groups.
    """
    from litellm.proxy.auth.model_checks import get_access_groups_from_models

    # Setup: access groups definition
    model_access_groups = {
        "AG1": ["gpt-4", "gpt-5"],
        "AG2": ["claude-v1", "claude-v2"],
        "beta-models": ["gpt-5-turbo"],
    }

    result = get_access_groups_from_models(
        model_access_groups=model_access_groups, models=models
    )

    # Order is not part of the contract being tested here.
    assert sorted(result) == sorted(expected)