diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index af2574d88e..71ae1348f3 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -64,27 +64,6 @@ def _get_models_from_access_groups( return all_models -def get_access_groups_from_models( - model_access_groups: Dict[str, List[str]], - models: List[str], -) -> List[str]: - """ - Extract access group names from a models list. - - Given a models list like ["gpt-4", "beta-models", "claude-v1"] - and access groups like {"beta-models": ["gpt-5", "gpt-6"]}, - returns ["beta-models"]. - - This is used to pass allowed access groups to the router for filtering - deployments during load balancing (GitHub issue #18333). - """ - access_groups = [] - for model in models: - if model in model_access_groups: - access_groups.append(model) - return access_groups - - async def get_mcp_server_ids( user_api_key_dict: UserAPIKeyAuth, ) -> List[str]: @@ -101,6 +80,7 @@ async def get_mcp_server_ids( # Make a direct SQL query to get just the mcp_servers try: + result = await prisma_client.db.litellm_objectpermissiontable.find_unique( where={"object_permission_id": user_api_key_dict.object_permission_id}, ) @@ -196,7 +176,6 @@ def get_complete_model_list( """ unique_models = [] - def append_unique(models): for model in models: if model not in unique_models: @@ -209,7 +188,7 @@ def append_unique(models): else: append_unique(proxy_model_list) if include_model_access_groups: - append_unique(list(model_access_groups.keys())) # TODO: keys order + append_unique(list(model_access_groups.keys())) # TODO: keys order if user_model: append_unique([user_model]) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 7a49c1f652..ad0ab6b7a3 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -173,12 +173,12 @@ def _get_dynamic_logging_metadata( user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - key_dynamic_logging_settings: Optional[dict] = ( - KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) - ) - team_dynamic_logging_settings: Optional[dict] = ( - KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) - ) + key_dynamic_logging_settings: Optional[ + dict + ] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) + team_dynamic_logging_settings: Optional[ + dict + ] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) ######################################################################################### # Key-based callbacks ######################################################################################### @@ -661,11 +661,11 @@ def add_key_level_controls( ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name]["tags"] = ( - LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], - ) + data[_metadata_variable_name][ + "tags" + ] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], ) if "disable_global_guardrails" in key_metadata and isinstance( key_metadata["disable_global_guardrails"], bool @@ -933,9 +933,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name]["global_max_parallel_requests"] = ( - general_settings.get("global_max_parallel_requests", None) - ) + data[_metadata_variable_name][ + "global_max_parallel_requests" + ] = general_settings.get("global_max_parallel_requests", None) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata @@ -1002,37 +1002,6 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_model_max_budget" ] = user_api_key_dict.model_max_budget - # Extract allowed access groups for router filtering (GitHub issue #18333) - # This allows the router to filter deployments based on key's and team's access groups - # NOTE: We keep key and team access groups SEPARATE because a key doesn't always - # inherit all team access groups (per maintainer feedback). - if llm_router is not None: - from litellm.proxy.auth.model_checks import get_access_groups_from_models - - model_access_groups = llm_router.get_model_access_groups() - - # Key-level access groups (from user_api_key_dict.models) - key_models = list(user_api_key_dict.models) if user_api_key_dict.models else [] - key_allowed_access_groups = get_access_groups_from_models( - model_access_groups=model_access_groups, models=key_models - ) - if key_allowed_access_groups: - data[_metadata_variable_name][ - "user_api_key_allowed_access_groups" - ] = key_allowed_access_groups - - # Team-level access groups (from user_api_key_dict.team_models) - team_models = ( - list(user_api_key_dict.team_models) if user_api_key_dict.team_models else [] - ) - team_allowed_access_groups = get_access_groups_from_models( - model_access_groups=model_access_groups, models=team_models - ) - if team_allowed_access_groups: - data[_metadata_variable_name][ - "user_api_key_team_allowed_access_groups" - ] = team_allowed_access_groups - data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) _headers.pop( diff --git a/litellm/router.py b/litellm/router.py index f73d907c8c..dc07280ea1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -86,7 +86,6 @@ is_clientside_credential, ) from litellm.router_utils.common_utils import ( - filter_deployments_by_access_groups, filter_team_based_models, filter_web_search_deployments, ) @@ -7847,17 +7846,10 @@ async def async_get_healthy_deployments( request_kwargs=request_kwargs, ) - verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}") - - # Filter by allowed access groups (GitHub issue #18333) - # This prevents cross-team load balancing when teams have models with same name in different access groups - healthy_deployments = filter_deployments_by_access_groups( - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, + verbose_router_logger.debug( + f"healthy_deployments after web search filter: {healthy_deployments}" ) - verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}") - if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 2c0ea5976d..10acc343ab 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -75,7 +75,6 @@ def filter_team_based_models( if deployment.get("model_info", {}).get("id") not in ids_to_remove ] - def _deployment_supports_web_search(deployment: Dict) -> bool: """ Check if a deployment supports web search. @@ -113,7 +112,7 @@ def filter_web_search_deployments( is_web_search_request = False tools = request_kwargs.get("tools") or [] for tool in tools: - # These are the two websearch tools for OpenAI / Azure. + # These are the two websearch tools for OpenAI / Azure. if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview": is_web_search_request = True break @@ -122,82 +121,8 @@ def filter_web_search_deployments( return healthy_deployments # Filter out deployments that don't support web search - final_deployments = [ - d for d in healthy_deployments if _deployment_supports_web_search(d) - ] + final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)] if len(healthy_deployments) > 0 and len(final_deployments) == 0: verbose_logger.warning("No deployments support web search for request") return final_deployments - -def filter_deployments_by_access_groups( - healthy_deployments: Union[List[Dict], Dict], - request_kwargs: Optional[Dict] = None, -) -> Union[List[Dict], Dict]: - """ - Filter deployments to only include those matching the user's allowed access groups. - - Reads from TWO separate metadata fields (per maintainer feedback): - - `user_api_key_allowed_access_groups`: Access groups from the API Key's models. - - `user_api_key_team_allowed_access_groups`: Access groups from the Team's models. - - A deployment is included if its access_groups overlap with EITHER the key's - or the team's allowed access groups. Deployments with no access_groups are - always included (not restricted). - - This prevents cross-team load balancing when multiple teams have models with - the same name but in different access groups (GitHub issue #18333). - """ - if request_kwargs is None: - return healthy_deployments - - if isinstance(healthy_deployments, dict): - return healthy_deployments - - metadata = request_kwargs.get("metadata") or {} - litellm_metadata = request_kwargs.get("litellm_metadata") or {} - - # Gather key-level allowed access groups - key_allowed_access_groups = ( - metadata.get("user_api_key_allowed_access_groups") - or litellm_metadata.get("user_api_key_allowed_access_groups") - or [] - ) - - # Gather team-level allowed access groups - team_allowed_access_groups = ( - metadata.get("user_api_key_team_allowed_access_groups") - or litellm_metadata.get("user_api_key_team_allowed_access_groups") - or [] - ) - - # Combine both for the final allowed set - combined_allowed_access_groups = list(key_allowed_access_groups) + list( - team_allowed_access_groups - ) - - # If no access groups specified from either source, return all deployments (backwards compatible) - if not combined_allowed_access_groups: - return healthy_deployments - - allowed_set = set(combined_allowed_access_groups) - filtered = [] - for deployment in healthy_deployments: - model_info = deployment.get("model_info") or {} - deployment_access_groups = model_info.get("access_groups") or [] - - # If deployment has no access groups, include it (not restricted) - if not deployment_access_groups: - filtered.append(deployment) - continue - - # Include if any of deployment's groups overlap with allowed groups - if set(deployment_access_groups) & allowed_set: - filtered.append(deployment) - - if len(healthy_deployments) > 0 and len(filtered) == 0: - verbose_logger.warning( - f"No deployments match allowed access groups {combined_allowed_access_groups}" - ) - - return filtered diff --git a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py deleted file mode 100644 index 9ac5072c5d..0000000000 --- a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py +++ /dev/null @@ -1,227 +0,0 @@ -""" -Unit tests for filter_deployments_by_access_groups function. - -Tests the fix for GitHub issue #18333: Models loadbalanced outside of Model Access Group. -""" - -import pytest - -from litellm.router_utils.common_utils import filter_deployments_by_access_groups - - -class TestFilterDeploymentsByAccessGroups: - """Tests for the filter_deployments_by_access_groups function.""" - - def test_no_filter_when_no_access_groups_in_metadata(self): - """When no allowed_access_groups in metadata, return all deployments.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_team_id": "team-1"}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 2 # All deployments returned - - def test_filter_to_single_access_group(self): - """Filter to only deployments matching allowed access group.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "2" - - def test_filter_with_multiple_allowed_groups(self): - """Filter with multiple allowed access groups.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - {"model_info": {"id": "3", "access_groups": ["AG3"]}}, - ] - request_kwargs = { - "metadata": {"user_api_key_allowed_access_groups": ["AG1", "AG2"]} - } - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 2 - ids = [d["model_info"]["id"] for d in result] - assert "1" in ids - assert "2" in ids - assert "3" not in ids - - def test_deployment_with_multiple_access_groups(self): - """Deployment with multiple access groups should match if any overlap.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1", "AG2"]}}, - {"model_info": {"id": "2", "access_groups": ["AG3"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "1" - - def test_deployment_without_access_groups_included(self): - """Deployments without access groups should be included (not restricted).""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2"}}, # No access_groups - {"model_info": {"id": "3", "access_groups": []}}, # Empty access_groups - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - # Should include deployments 2 and 3 (no restrictions) - assert len(result) == 2 - ids = [d["model_info"]["id"] for d in result] - assert "2" in ids - assert "3" in ids - - def test_dict_deployment_passes_through(self): - """When deployment is a dict (specific deployment), pass through.""" - deployment = {"model_info": {"id": "1", "access_groups": ["AG1"]}} - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployment, - request_kwargs=request_kwargs, - ) - - assert result == deployment # Unchanged - - def test_none_request_kwargs_passes_through(self): - """When request_kwargs is None, return deployments unchanged.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - ] - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=None, - ) - - assert result == deployments - - def test_litellm_metadata_fallback(self): - """Should also check litellm_metadata for allowed access groups.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = { - "litellm_metadata": {"user_api_key_allowed_access_groups": ["AG1"]} - } - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "1" - - -def test_filter_deployments_by_access_groups_issue_18333(): - """ - Regression test for GitHub issue #18333. - - Scenario: Two models named 'gpt-5' in different access groups (AG1, AG2). - Team2 has access to AG2 only. When Team2 requests 'gpt-5', only the AG2 - deployment should be available for load balancing. - """ - deployments = [ - { - "model_name": "gpt-5", - "litellm_params": {"model": "gpt-4.1", "api_key": "key-1"}, - "model_info": {"id": "ag1-deployment", "access_groups": ["AG1"]}, - }, - { - "model_name": "gpt-5", - "litellm_params": {"model": "gpt-4o", "api_key": "key-2"}, - "model_info": {"id": "ag2-deployment", "access_groups": ["AG2"]}, - }, - ] - - # Team2's request with allowed access groups - request_kwargs = { - "metadata": { - "user_api_key_team_id": "team-2", - "user_api_key_allowed_access_groups": ["AG2"], - } - } - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - # Only AG2 deployment should be returned - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "ag2-deployment" - assert result[0]["litellm_params"]["model"] == "gpt-4o" - - -def test_get_access_groups_from_models(): - """ - Test the helper function that extracts access group names from models list. - This is used by the proxy to populate user_api_key_allowed_access_groups. - """ - from litellm.proxy.auth.model_checks import get_access_groups_from_models - - # Setup: access groups definition - model_access_groups = { - "AG1": ["gpt-4", "gpt-5"], - "AG2": ["claude-v1", "claude-v2"], - "beta-models": ["gpt-5-turbo"], - } - - # Test 1: Extract access groups from models list - models = ["gpt-4", "AG1", "AG2", "some-other-model"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert set(result) == {"AG1", "AG2"} - - # Test 2: No access groups in models list - models = ["gpt-4", "claude-v1", "some-model"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert result == [] - - # Test 3: Empty models list - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=[] - ) - assert result == [] - - # Test 4: All access groups - models = ["AG1", "AG2", "beta-models"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert set(result) == {"AG1", "AG2", "beta-models"}