diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4140273ea25..be1e4dbcdc9 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2104,6 +2104,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d") last_rotation_at: Optional[datetime] = None # When this key was last rotated key_rotation_at: Optional[datetime] = None # When this key should next be rotated + router_settings: Optional[Dict] = None # Router settings for this key (Key > Team > Global precedence) model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 537b48f06ed..e9ce10ccf31 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -440,6 +440,29 @@ async def common_processing_pre_call_logic( user_api_key_dict=user_api_key_dict, data=self.data, call_type=route_type # type: ignore ) + # Apply hierarchical router_settings (Key > Team > Global) + if llm_router is not None and proxy_config is not None: + from litellm.proxy.proxy_server import prisma_client + + router_settings = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ) + + # If router_settings found (from key, team, or global), apply them + # This ensures key/team settings override global settings + if router_settings is not None and router_settings: + # Get model_list from current router + model_list = llm_router.get_model_list() + if model_list is not None: + # Create user_config with model_list and router_settings + # This creates a per-request router with the hierarchical settings + user_config = { + "model_list": model_list, + **router_settings + } + self.data["user_config"] = user_config + if "messages" in self.data and self.data["messages"]: logging_obj.update_messages(self.data["messages"]) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d264b82b873..3cc16937dca 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3222,6 +3222,84 @@ def _decrypt_db_variables(self, variables_dict: dict) -> dict: decrypted_variables[k] = decrypted_value return decrypted_variables + async def _get_hierarchical_router_settings( + self, + user_api_key_dict: Optional["UserAPIKeyAuth"], + prisma_client: Optional[PrismaClient], + ) -> Optional[dict]: + """ + Get router_settings in priority order: Key > Team > Global + + Returns: + dict: Combined router_settings, or None if no settings found + """ + if prisma_client is None: + return None + + import json + import yaml + + # 1. Try key-level router_settings + if user_api_key_dict is not None: + # Check if router_settings is available on the key object + key_router_settings_value = getattr(user_api_key_dict, "router_settings", None) + if key_router_settings_value is not None: + key_router_settings = None + if isinstance(key_router_settings_value, str): + try: + key_router_settings = yaml.safe_load(key_router_settings_value) + except (yaml.YAMLError, json.JSONDecodeError): + try: + key_router_settings = json.loads(key_router_settings_value) + except json.JSONDecodeError: + pass + elif isinstance(key_router_settings_value, dict): + key_router_settings = key_router_settings_value + + # If key has router_settings (non-empty dict), use it + if key_router_settings is not None and isinstance(key_router_settings, dict) and key_router_settings: + return key_router_settings + + # 2. Try team-level router_settings + if user_api_key_dict is not None and user_api_key_dict.team_id is not None: + try: + team_obj = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": user_api_key_dict.team_id} + ) + if team_obj is not None: + team_router_settings_value = getattr(team_obj, "router_settings", None) + if team_router_settings_value is not None: + team_router_settings = None + if isinstance(team_router_settings_value, str): + try: + team_router_settings = yaml.safe_load(team_router_settings_value) + except (yaml.YAMLError, json.JSONDecodeError): + try: + team_router_settings = json.loads(team_router_settings_value) + except json.JSONDecodeError: + pass + elif isinstance(team_router_settings_value, dict): + team_router_settings = team_router_settings_value + + # If team has router_settings (non-empty dict), use it + if team_router_settings is not None and isinstance(team_router_settings, dict) and team_router_settings: + return team_router_settings + except Exception: + # If team lookup fails, continue to global settings + pass + + # 3. Try global router_settings + try: + db_router_settings = await prisma_client.db.litellm_config.find_first( + where={"param_name": "router_settings"} + ) + if db_router_settings is not None and isinstance(db_router_settings.param_value, dict) and db_router_settings.param_value: + return db_router_settings.param_value + except Exception: + pass + + return None + async def _add_router_settings_from_db_config( self, config_data: dict, diff --git a/tests/test_litellm/proxy/test_common_request_processing.py b/tests/test_litellm/proxy/test_common_request_processing.py index b5d44385698..bab95feec06 100644 --- a/tests/test_litellm/proxy/test_common_request_processing.py +++ b/tests/test_litellm/proxy/test_common_request_processing.py @@ -75,6 +75,84 @@ async def mock_common_processing_pre_call_logic( pytest.fail("litellm_call_id is not a valid UUID") assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"] + @pytest.mark.asyncio + async def test_should_apply_hierarchical_router_settings_to_user_config( + self, monkeypatch + ): + processing_obj = ProxyBaseLLMRequestProcessing(data={}) + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + + async def mock_add_litellm_data_to_request(*args, **kwargs): + return {} + + async def mock_common_processing_pre_call_logic( + user_api_key_dict, data, call_type + ): + data_copy = copy.deepcopy(data) + return data_copy + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + mock_proxy_logging_obj.pre_call_hook = AsyncMock( + side_effect=mock_common_processing_pre_call_logic + ) + monkeypatch.setattr( + litellm.proxy.common_request_processing, + "add_litellm_data_to_request", + mock_add_litellm_data_to_request, + ) + + mock_general_settings = {} + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_proxy_config = MagicMock(spec=ProxyConfig) + + mock_router_settings = { + "routing_strategy": "least-busy", + "timeout": 30.0, + "num_retries": 3, + } + mock_proxy_config._get_hierarchical_router_settings = AsyncMock( + return_value=mock_router_settings + ) + + mock_model_list = [ + {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + mock_llm_router = MagicMock() + mock_llm_router.get_model_list = MagicMock(return_value=mock_model_list) + + mock_prisma_client = MagicMock() + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", + mock_prisma_client, + ) + + route_type = "acompletion" + + returned_data, logging_obj = await processing_obj.common_processing_pre_call_logic( + request=mock_request, + general_settings=mock_general_settings, + user_api_key_dict=mock_user_api_key_dict, + proxy_logging_obj=mock_proxy_logging_obj, + proxy_config=mock_proxy_config, + route_type=route_type, + llm_router=mock_llm_router, + ) + + mock_proxy_config._get_hierarchical_router_settings.assert_called_once_with( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + mock_llm_router.get_model_list.assert_called_once() + + assert "user_config" in returned_data + user_config = returned_data["user_config"] + assert user_config["model_list"] == mock_model_list + assert user_config["routing_strategy"] == "least-busy" + assert user_config["timeout"] == 30.0 + assert user_config["num_retries"] == 3 + @pytest.mark.asyncio async def test_stream_timeout_header_processing(self): """ diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index 53d89df5026..751a9033871 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -3124,3 +3124,95 @@ def test_deep_merge_dicts_skips_none_and_empty_lists(monkeypatch): assert result["general_settings"]["nested"]["key1"] == "updated_value1" assert result["general_settings"]["nested"]["key2"] == "value2" assert result["general_settings"]["nested"]["key3"] == "value3" + + +@pytest.mark.asyncio +async def test_get_hierarchical_router_settings(): + """ + Test _get_hierarchical_router_settings method's priority order: Key > Team > Global + """ + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.proxy_server import ProxyConfig + + proxy_config = ProxyConfig() + + # Test Case 1: Returns None when prisma_client is None + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=None, + prisma_client=None, + ) + assert result is None + + # Test Case 2: Returns key-level router_settings when available (as dict) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_user_api_key_dict.router_settings = {"routing_strategy": "key-level", "timeout": 10} + mock_user_api_key_dict.team_id = None + + mock_prisma_client = MagicMock() + + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + assert result == {"routing_strategy": "key-level", "timeout": 10} + + # Test Case 3: Returns key-level router_settings when available (as YAML string) + mock_user_api_key_dict.router_settings = "routing_strategy: key-yaml\ntimeout: 20" + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + assert result == {"routing_strategy": "key-yaml", "timeout": 20} + + # Test Case 4: Falls back to team-level router_settings when key-level is not available + mock_user_api_key_dict.router_settings = None + mock_user_api_key_dict.team_id = "team-123" + + mock_team_obj = MagicMock() + mock_team_obj.router_settings = {"routing_strategy": "team-level", "timeout": 30} + + mock_prisma_client.db.litellm_teamtable.find_unique = AsyncMock( + return_value=mock_team_obj + ) + + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + assert result == {"routing_strategy": "team-level", "timeout": 30} + mock_prisma_client.db.litellm_teamtable.find_unique.assert_called_once_with( + where={"team_id": "team-123"} + ) + + # Test Case 5: Falls back to global router_settings when neither key nor team settings are available + mock_user_api_key_dict.router_settings = None + mock_prisma_client.db.litellm_teamtable.find_unique = AsyncMock(return_value=None) + + mock_db_config = MagicMock() + mock_db_config.param_value = {"routing_strategy": "global-level", "timeout": 40} + + mock_prisma_client.db.litellm_config.find_first = AsyncMock( + return_value=mock_db_config + ) + + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + assert result == {"routing_strategy": "global-level", "timeout": 40} + mock_prisma_client.db.litellm_config.find_first.assert_called_once_with( + where={"param_name": "router_settings"} + ) + + # Test Case 6: Returns None when no settings are found + mock_user_api_key_dict.router_settings = None + mock_prisma_client.db.litellm_teamtable.find_unique = AsyncMock(return_value=None) + mock_prisma_client.db.litellm_config.find_first = AsyncMock(return_value=None) + + result = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + ) + assert result is None