diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index f38f94f4c98..501cfacaf91 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2162,6 +2162,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d") last_rotation_at: Optional[datetime] = None # When this key was last rotated key_rotation_at: Optional[datetime] = None # When this key should next be rotated + router_settings: Optional[dict] = None model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 769c250d9fc..3e079045774 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -616,6 +616,23 @@ async def common_processing_pre_call_logic( user_api_key_dict=user_api_key_dict, data=self.data, call_type=route_type # type: ignore ) + # Apply hierarchical router_settings (Key > Team) + # Global router_settings are already on the Router object itself. + if llm_router is not None and proxy_config is not None: + from litellm.proxy.proxy_server import prisma_client + + router_settings = await proxy_config._get_hierarchical_router_settings( + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + proxy_logging_obj=proxy_logging_obj, + ) + + # If router_settings found (from key or team), apply them + # Pass settings as per-request overrides instead of creating a new Router + # This avoids expensive Router instantiation on each request + if router_settings is not None: + self.data["router_settings_override"] = router_settings + if "messages" in self.data and self.data["messages"]: logging_obj.update_messages(self.data["messages"]) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 637893872d9..60327d8da17 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3249,17 +3249,18 @@ async def _update_llm_router( _model_list: list = self.decrypt_model_list_from_db( new_models=models_list ) - if len(_model_list) > 0: - verbose_proxy_logger.debug(f"_model_list: {_model_list}") - llm_router = litellm.Router( - model_list=_model_list, - router_general_settings=RouterGeneralSettings( - async_only_mode=True # only init async clients - ), - search_tools=search_tools, - ignore_invalid_deployments=True, - ) - verbose_proxy_logger.debug(f"updated llm_router: {llm_router}") + # Create router even with empty model list to support search_tools + # Router can function with model_list=[] and only search_tools + verbose_proxy_logger.debug(f"_model_list: {_model_list}") + llm_router = litellm.Router( + model_list=_model_list, + router_general_settings=RouterGeneralSettings( + async_only_mode=True # only init async clients + ), + search_tools=search_tools, + ignore_invalid_deployments=True, + ) + verbose_proxy_logger.debug(f"updated llm_router: {llm_router}") else: verbose_proxy_logger.debug(f"len new_models: {len(models_list)}") if search_tools is not None and llm_router is not None: @@ -3402,6 +3403,86 @@ def _decrypt_db_variables(self, variables_dict: dict) -> dict: decrypted_variables[k] = decrypted_value return decrypted_variables + @staticmethod + def _parse_router_settings_value(value: Any) -> Optional[dict]: + """ + Parse a router_settings value that may be a dict or a JSON/YAML string. + + Returns a non-empty dict if valid, otherwise None. + """ + if value is None: + return None + + parsed: Optional[dict] = None + if isinstance(value, dict): + parsed = value + elif isinstance(value, str): + import json + + import yaml + + try: + parsed = yaml.safe_load(value) + except (yaml.YAMLError, json.JSONDecodeError): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + pass + + if isinstance(parsed, dict) and parsed: + return parsed + return None + + async def _get_hierarchical_router_settings( + self, + user_api_key_dict: Optional["UserAPIKeyAuth"], + prisma_client: Optional[PrismaClient], + proxy_logging_obj: Optional["ProxyLogging"] = None, + ) -> Optional[dict]: + """ + Get router_settings in priority order: Key > Team + + Uses the already-cached key object and the cached team lookup + (get_team_object) to avoid direct DB queries on the hot path. + + Global router_settings are NOT looked up here — they are already + applied to the Router object at config-load / DB-sync time. + + Returns: + dict: router_settings, or None if no settings found + """ + # 1. Try key-level router_settings + # user_api_key_dict is already the cached/authenticated key object — + # no DB call needed. + if user_api_key_dict is not None: + key_settings = self._parse_router_settings_value( + getattr(user_api_key_dict, "router_settings", None) + ) + if key_settings is not None: + return key_settings + + # 2. Try team-level router_settings using cached team lookup + # get_team_object checks in-memory cache / Redis first, only falls + # back to DB on a cache miss. + if user_api_key_dict is not None and user_api_key_dict.team_id is not None: + try: + team_obj = await get_team_object( + team_id=user_api_key_dict.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + team_settings = self._parse_router_settings_value( + getattr(team_obj, "router_settings", None) + ) + if team_settings is not None: + return team_settings + except Exception: + # If team lookup fails, no team-level settings available + pass + + return None + async def _add_router_settings_from_db_config( self, config_data: dict, diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py index 6baa2047d73..e941964644e 100644 --- a/litellm/proxy/route_llm_request.py +++ b/litellm/proxy/route_llm_request.py @@ -210,18 +210,33 @@ async def route_request( models = [model.strip() for model in data.pop("model").split(",")] return llm_router.abatch_completion(models=models, **data) - elif "user_config" in data: - router_config = data.pop("user_config") + elif "router_settings_override" in data: + # Apply per-request router settings overrides from key/team config + # Instead of creating a new Router (expensive), merge settings into kwargs + # The Router already supports per-request overrides for these settings + override_settings = data.pop("router_settings_override") - # Filter router_config to only include valid Router.__init__ arguments - # This prevents TypeError when invalid parameters are stored in the database - valid_args = litellm.Router.get_valid_args() - filtered_config = {k: v for k, v in router_config.items() if k in valid_args} + # Settings that the Router accepts as per-request kwargs + # These override the global router settings for this specific request + per_request_settings = [ + "fallbacks", + "context_window_fallbacks", + "content_policy_fallbacks", + "num_retries", + "timeout", + "model_group_retry_policy", + ] - user_router = litellm.Router(**filtered_config) - ret_val = getattr(user_router, f"{route_type}")(**data) - user_router.discard() - return ret_val + # Merge override settings into data (only if not already set in request) + for key in per_request_settings: + if key in override_settings and key not in data: + data[key] = override_settings[key] + + # Use main router with overridden kwargs + if llm_router is not None: + return getattr(llm_router, f"{route_type}")(**data) + else: + return getattr(litellm, f"{route_type}")(**data) elif llm_router is not None: # Skip model-based routing for container operations if route_type in [ diff --git a/litellm/router.py b/litellm/router.py index 6dc3278225e..374f80db361 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4893,6 +4893,10 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 content_policy_fallbacks = kwargs.pop( "content_policy_fallbacks", self.content_policy_fallbacks ) + # Support per-request model_group_retry_policy override (from key/team settings) + model_group_retry_policy = kwargs.pop( + "model_group_retry_policy", self.model_group_retry_policy + ) model_group: Optional[str] = kwargs.get("model") num_retries = kwargs.pop("num_retries") @@ -4941,7 +4945,7 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 _retry_policy_applies = False if ( self.retry_policy is not None - or self.model_group_retry_policy is not None + or model_group_retry_policy is not None ): # get num_retries from retry policy # Use the model_group captured at the start of the function, or get it from metadata @@ -4949,9 +4953,12 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 _model_group_for_retry_policy = ( model_group or _metadata.get("model_group") or kwargs.get("model") ) - _retry_policy_retries = self.get_num_retries_from_retry_policy( + # Use per-request model_group_retry_policy if provided, otherwise use self + _retry_policy_retries = _get_num_retries_from_retry_policy( exception=original_exception, model_group=_model_group_for_retry_policy, + model_group_retry_policy=model_group_retry_policy, + retry_policy=self.retry_policy, ) if _retry_policy_retries is not None: num_retries = _retry_policy_retries diff --git a/tests/test_litellm/proxy/test_common_request_processing.py b/tests/test_litellm/proxy/test_common_request_processing.py index 69cf8240c63..7bebe00d61e 100644 --- a/tests/test_litellm/proxy/test_common_request_processing.py +++ b/tests/test_litellm/proxy/test_common_request_processing.py @@ -77,6 +77,93 @@ async def mock_common_processing_pre_call_logic( pytest.fail("litellm_call_id is not a valid UUID") assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"] + @pytest.mark.asyncio + async def test_should_apply_hierarchical_router_settings_as_override( + self, monkeypatch + ): + """ + Test that hierarchical router settings are stored as router_settings_override + instead of creating a full user_config with model_list. + + This approach avoids expensive per-request Router instantiation by passing + settings as kwargs overrides to the main router. + """ + processing_obj = ProxyBaseLLMRequestProcessing(data={}) + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + + async def mock_add_litellm_data_to_request(*args, **kwargs): + return {} + + async def mock_common_processing_pre_call_logic( + user_api_key_dict, data, call_type + ): + data_copy = copy.deepcopy(data) + return data_copy + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + mock_proxy_logging_obj.pre_call_hook = AsyncMock( + side_effect=mock_common_processing_pre_call_logic + ) + monkeypatch.setattr( + litellm.proxy.common_request_processing, + "add_litellm_data_to_request", + mock_add_litellm_data_to_request, + ) + + mock_general_settings = {} + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_proxy_config = MagicMock(spec=ProxyConfig) + + mock_router_settings = { + "routing_strategy": "least-busy", + "timeout": 30.0, + "num_retries": 3, + } + mock_proxy_config._get_hierarchical_router_settings = AsyncMock( + return_value=mock_router_settings + ) + + mock_llm_router = MagicMock() + + mock_prisma_client = MagicMock() + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", + mock_prisma_client, + ) + + route_type = "acompletion" + + returned_data, logging_obj = await processing_obj.common_processing_pre_call_logic( + request=mock_request, + general_settings=mock_general_settings, + user_api_key_dict=mock_user_api_key_dict, + proxy_logging_obj=mock_proxy_logging_obj, + proxy_config=mock_proxy_config, + route_type=route_type, + llm_router=mock_llm_router, + ) + + mock_proxy_config._get_hierarchical_router_settings.assert_called_once_with( + user_api_key_dict=mock_user_api_key_dict, + prisma_client=mock_prisma_client, + proxy_logging_obj=mock_proxy_logging_obj, + ) + # get_model_list should NOT be called - we no longer copy model list for per-request routers + mock_llm_router.get_model_list.assert_not_called() + + # Settings should be stored as router_settings_override (not user_config) + # This allows passing them as kwargs to the main router instead of creating a new one + assert "router_settings_override" in returned_data + assert "user_config" not in returned_data + + router_settings_override = returned_data["router_settings_override"] + assert router_settings_override["routing_strategy"] == "least-busy" + assert router_settings_override["timeout"] == 30.0 + assert router_settings_override["num_retries"] == 3 + # model_list should NOT be in the override settings + assert "model_list" not in router_settings_override + @pytest.mark.asyncio async def test_stream_timeout_header_processing(self): """ diff --git a/tests/test_litellm/proxy/test_route_llm_request.py b/tests/test_litellm/proxy/test_route_llm_request.py index 90eace63714..1283d2ccbe7 100644 --- a/tests/test_litellm/proxy/test_route_llm_request.py +++ b/tests/test_litellm/proxy/test_route_llm_request.py @@ -137,62 +137,103 @@ async def test_route_request_no_model_required_with_router_settings_and_no_route @pytest.mark.asyncio -async def test_route_request_with_invalid_router_params(): +async def test_route_request_with_router_settings_override(): """ - Test that route_request filters out invalid Router init params from 'user_config'. - This covers the fix for https://github.com/BerriAI/litellm/issues/19693 + Test that route_request handles router_settings_override by merging settings into kwargs + instead of creating a new Router (which is expensive and was the old behavior). + """ + # Mock data with router_settings_override containing per-request settings + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "router_settings_override": { + "fallbacks": [{"gpt-3.5-turbo": ["gpt-4"]}], + "num_retries": 5, + "timeout": 30, + "model_group_retry_policy": {"gpt-3.5-turbo": {"RateLimitErrorRetries": 3}}, + # These settings should be ignored (not in per_request_settings list) + "routing_strategy": "least-busy", + "model_group_alias": {"alias": "real_model"}, + }, + } + + llm_router = MagicMock() + llm_router.acompletion.return_value = "success" + + response = await route_request(data, llm_router, None, "acompletion") + + assert response == "success" + # Verify the router method was called with merged settings + call_kwargs = llm_router.acompletion.call_args[1] + assert call_kwargs["fallbacks"] == [{"gpt-3.5-turbo": ["gpt-4"]}] + assert call_kwargs["num_retries"] == 5 + assert call_kwargs["timeout"] == 30 + assert call_kwargs["model_group_retry_policy"] == {"gpt-3.5-turbo": {"RateLimitErrorRetries": 3}} + # Verify unsupported settings were NOT merged + assert "routing_strategy" not in call_kwargs + assert "model_group_alias" not in call_kwargs + # Verify router_settings_override was removed from data + assert "router_settings_override" not in call_kwargs + + +@pytest.mark.asyncio +async def test_route_request_with_router_settings_override_no_router(): + """ + Test that router_settings_override works when no router is provided, + falling back to litellm module directly. """ import litellm - from litellm.router import Router - from unittest.mock import AsyncMock - # Mock data with user_config containing invalid keys (simulating DB entry) data = { "model": "gpt-3.5-turbo", - "user_config": { - "model_list": [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "test"}, - } - ], - "model_alias_map": {"alias": "real_model"}, # INVALID PARAM - "invalid_garbage_key": "crash_me", # INVALID PARAM + "messages": [{"role": "user", "content": "Hello"}], + "router_settings_override": { + "fallbacks": [{"gpt-3.5-turbo": ["gpt-4"]}], + "num_retries": 3, }, } - # We expect Router(**config) to succeed because of the filtering. - # If filtering fails, this will raise TypeError and fail the test. + # Use MagicMock explicitly to avoid auto-AsyncMock behavior in Python 3.12+ + mock_completion = MagicMock(return_value="success") + original_acompletion = litellm.acompletion + litellm.acompletion = mock_completion + try: - # route_request calls getattr(user_router, route_type)(**data) - # We'll mock the internal call to avoid making real network requests - with pytest.MonkeyPatch.context() as m: - # Mock the method that gets called on the router instance - # We don't easily have access to the instance created INSIDE existing route_request - # So we will wrap litellm.Router to spy on it or verify it doesn't crash - - original_router_init = litellm.Router.__init__ - - def safe_router_init(self, **kwargs): - # Verify that invalid keys are NOT present in kwargs - assert "model_alias_map" not in kwargs - assert "invalid_garbage_key" not in kwargs - # Call original init (which would raise TypeError if invalid keys were present) - original_router_init(self, **kwargs) - - m.setattr(litellm.Router, "__init__", safe_router_init) - - # Use 'acompletion' as the route_type - # We also need to mock the completion method to avoid real calls - m.setattr(Router, "acompletion", AsyncMock(return_value="success")) - - response = await route_request(data, None, None, "acompletion") - assert response == "success" - - except TypeError as e: - pytest.fail( - f"route_request raised TypeError, implying invalid params were passed to Router: {e}" - ) - except Exception: - # Other exceptions might happen (e.g. valid config issues) but we care about TypeError here - pass + response = await route_request(data, None, None, "acompletion") + + assert response == "success" + # Verify litellm.acompletion was called with merged settings + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["fallbacks"] == [{"gpt-3.5-turbo": ["gpt-4"]}] + assert call_kwargs["num_retries"] == 3 + finally: + litellm.acompletion = original_acompletion + + +@pytest.mark.asyncio +async def test_route_request_with_router_settings_override_preserves_existing(): + """ + Test that router_settings_override does not override settings already in the request. + Request-level settings take precedence over key/team settings. + """ + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "num_retries": 10, # Request-level setting + "router_settings_override": { + "num_retries": 3, # Key/team setting - should NOT override + "timeout": 30, # Key/team setting - should be applied + }, + } + + llm_router = MagicMock() + llm_router.acompletion.return_value = "success" + + response = await route_request(data, llm_router, None, "acompletion") + + assert response == "success" + call_kwargs = llm_router.acompletion.call_args[1] + # Request-level num_retries should take precedence + assert call_kwargs["num_retries"] == 10 + # Key/team timeout should be applied since not in request + assert call_kwargs["timeout"] == 30