Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d")
last_rotation_at: Optional[datetime] = None # When this key was last rotated
key_rotation_at: Optional[datetime] = None # When this key should next be rotated
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())


Expand Down
17 changes: 17 additions & 0 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,23 @@ async def common_processing_pre_call_logic(
user_api_key_dict=user_api_key_dict, data=self.data, call_type=route_type # type: ignore
)

# Apply hierarchical router_settings (Key > Team)
# Global router_settings are already on the Router object itself.
if llm_router is not None and proxy_config is not None:
from litellm.proxy.proxy_server import prisma_client

router_settings = await proxy_config._get_hierarchical_router_settings(
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
)

# If router_settings found (from key or team), apply them
# Pass settings as per-request overrides instead of creating a new Router
# This avoids expensive Router instantiation on each request
if router_settings is not None:
self.data["router_settings_override"] = router_settings

if "messages" in self.data and self.data["messages"]:
logging_obj.update_messages(self.data["messages"])

Expand Down
103 changes: 92 additions & 11 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,17 +3249,18 @@ async def _update_llm_router(
_model_list: list = self.decrypt_model_list_from_db(
new_models=models_list
)
if len(_model_list) > 0:
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
llm_router = litellm.Router(
model_list=_model_list,
router_general_settings=RouterGeneralSettings(
async_only_mode=True # only init async clients
),
search_tools=search_tools,
ignore_invalid_deployments=True,
)
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
# Create router even with empty model list to support search_tools
# Router can function with model_list=[] and only search_tools
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
llm_router = litellm.Router(
model_list=_model_list,
router_general_settings=RouterGeneralSettings(
async_only_mode=True # only init async clients
),
search_tools=search_tools,
ignore_invalid_deployments=True,
)
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
else:
verbose_proxy_logger.debug(f"len new_models: {len(models_list)}")
if search_tools is not None and llm_router is not None:
Expand Down Expand Up @@ -3402,6 +3403,86 @@ def _decrypt_db_variables(self, variables_dict: dict) -> dict:
decrypted_variables[k] = decrypted_value
return decrypted_variables

@staticmethod
def _parse_router_settings_value(value: Any) -> Optional[dict]:
"""
Parse a router_settings value that may be a dict or a JSON/YAML string.

Returns a non-empty dict if valid, otherwise None.
"""
if value is None:
return None

parsed: Optional[dict] = None
if isinstance(value, dict):
parsed = value
elif isinstance(value, str):
import json

import yaml

try:
parsed = yaml.safe_load(value)
except (yaml.YAMLError, json.JSONDecodeError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yaml.safe_load() doesn't raise JSONDecodeError, only YAMLError. The json.JSONDecodeError in the except clause won't catch yaml parsing failures.

Suggested change
except (yaml.YAMLError, json.JSONDecodeError):
except yaml.YAMLError:

try:
parsed = json.loads(value)
except json.JSONDecodeError:
pass

if isinstance(parsed, dict) and parsed:
return parsed
return None

async def _get_hierarchical_router_settings(
    self,
    user_api_key_dict: Optional["UserAPIKeyAuth"],
    prisma_client: Optional[PrismaClient],
    proxy_logging_obj: Optional["ProxyLogging"] = None,
) -> Optional[dict]:
    """
    Resolve router_settings with precedence: Key > Team.

    Works entirely off cached objects: the authenticated key is already
    in hand, and get_team_object consults the in-memory/Redis cache
    before falling back to the DB. Global router_settings are
    intentionally not resolved here — they are applied to the Router at
    config-load / DB-sync time.

    Returns:
        The highest-priority router_settings dict found, or None.
    """
    if user_api_key_dict is None:
        return None

    # Key-level settings win outright; no DB call is needed because the
    # key object is the cached/authenticated one.
    from_key = self._parse_router_settings_value(
        getattr(user_api_key_dict, "router_settings", None)
    )
    if from_key is not None:
        return from_key

    if user_api_key_dict.team_id is None:
        return None

    # Team-level fallback via the cached team lookup. Best-effort: any
    # failure during lookup/parse simply means no team settings exist.
    from_team: Optional[dict] = None
    try:
        team_obj = await get_team_object(
            team_id=user_api_key_dict.team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
        from_team = self._parse_router_settings_value(
            getattr(team_obj, "router_settings", None)
        )
    except Exception:
        from_team = None
    return from_team

async def _add_router_settings_from_db_config(
self,
config_data: dict,
Expand Down
35 changes: 25 additions & 10 deletions litellm/proxy/route_llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,33 @@ async def route_request(
models = [model.strip() for model in data.pop("model").split(",")]
return llm_router.abatch_completion(models=models, **data)

elif "user_config" in data:
router_config = data.pop("user_config")
elif "router_settings_override" in data:
# Apply per-request router settings overrides from key/team config
# Instead of creating a new Router (expensive), merge settings into kwargs
# The Router already supports per-request overrides for these settings
override_settings = data.pop("router_settings_override")

# Filter router_config to only include valid Router.__init__ arguments
# This prevents TypeError when invalid parameters are stored in the database
valid_args = litellm.Router.get_valid_args()
filtered_config = {k: v for k, v in router_config.items() if k in valid_args}
# Settings that the Router accepts as per-request kwargs
# These override the global router settings for this specific request
per_request_settings = [
"fallbacks",
"context_window_fallbacks",
"content_policy_fallbacks",
"num_retries",
"timeout",
"model_group_retry_policy",
]

user_router = litellm.Router(**filtered_config)
ret_val = getattr(user_router, f"{route_type}")(**data)
user_router.discard()
return ret_val
# Merge override settings into data (only if not already set in request)
for key in per_request_settings:
if key in override_settings and key not in data:
data[key] = override_settings[key]

# Use main router with overridden kwargs
if llm_router is not None:
return getattr(llm_router, f"{route_type}")(**data)
else:
return getattr(litellm, f"{route_type}")(**data)
elif llm_router is not None:
# Skip model-based routing for container operations
if route_type in [
Expand Down
11 changes: 9 additions & 2 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4893,6 +4893,10 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
# Support per-request model_group_retry_policy override (from key/team settings)
model_group_retry_policy = kwargs.pop(
"model_group_retry_policy", self.model_group_retry_policy
)
model_group: Optional[str] = kwargs.get("model")
num_retries = kwargs.pop("num_retries")

Expand Down Expand Up @@ -4941,17 +4945,20 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
_retry_policy_applies = False
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
or model_group_retry_policy is not None
):
# get num_retries from retry policy
# Use the model_group captured at the start of the function, or get it from metadata
# kwargs.get("model") at this point is the deployment model, not the model_group
_model_group_for_retry_policy = (
model_group or _metadata.get("model_group") or kwargs.get("model")
)
_retry_policy_retries = self.get_num_retries_from_retry_policy(
# Use per-request model_group_retry_policy if provided, otherwise use self
_retry_policy_retries = _get_num_retries_from_retry_policy(
exception=original_exception,
model_group=_model_group_for_retry_policy,
model_group_retry_policy=model_group_retry_policy,
retry_policy=self.retry_policy,
)
if _retry_policy_retries is not None:
num_retries = _retry_policy_retries
Expand Down
87 changes: 87 additions & 0 deletions tests/test_litellm/proxy/test_common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,93 @@ async def mock_common_processing_pre_call_logic(
pytest.fail("litellm_call_id is not a valid UUID")
assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]

@pytest.mark.asyncio
async def test_should_apply_hierarchical_router_settings_as_override(
    self, monkeypatch
):
    """
    Verify that key/team router settings surface in the request data as a
    lightweight `router_settings_override` entry rather than as a full
    `user_config` — avoiding an expensive per-request Router build.
    """
    processor = ProxyBaseLLMRequestProcessing(data={})

    fake_request = MagicMock(spec=Request)
    fake_request.headers = {}

    async def fake_add_litellm_data_to_request(*args, **kwargs):
        return {}

    monkeypatch.setattr(
        litellm.proxy.common_request_processing,
        "add_litellm_data_to_request",
        fake_add_litellm_data_to_request,
    )

    async def passthrough_pre_call_hook(user_api_key_dict, data, call_type):
        return copy.deepcopy(data)

    logging_mock = MagicMock(spec=ProxyLogging)
    logging_mock.pre_call_hook = AsyncMock(side_effect=passthrough_pre_call_hook)

    key_auth_mock = MagicMock(spec=UserAPIKeyAuth)

    hierarchical_settings = {
        "routing_strategy": "least-busy",
        "timeout": 30.0,
        "num_retries": 3,
    }
    proxy_config_mock = MagicMock(spec=ProxyConfig)
    proxy_config_mock._get_hierarchical_router_settings = AsyncMock(
        return_value=hierarchical_settings
    )

    router_mock = MagicMock()
    prisma_mock = MagicMock()
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.prisma_client",
        prisma_mock,
    )

    returned_data, logging_obj = await processor.common_processing_pre_call_logic(
        request=fake_request,
        general_settings={},
        user_api_key_dict=key_auth_mock,
        proxy_logging_obj=logging_mock,
        proxy_config=proxy_config_mock,
        route_type="acompletion",
        llm_router=router_mock,
    )

    proxy_config_mock._get_hierarchical_router_settings.assert_called_once_with(
        user_api_key_dict=key_auth_mock,
        prisma_client=prisma_mock,
        proxy_logging_obj=logging_mock,
    )
    # No per-request Router is built, so the model list is never copied.
    router_mock.get_model_list.assert_not_called()

    # Settings travel as an override blob, not as a user_config payload,
    # so route_request can merge them as kwargs on the shared router.
    assert "router_settings_override" in returned_data
    assert "user_config" not in returned_data

    override = returned_data["router_settings_override"]
    assert override["routing_strategy"] == "least-busy"
    assert override["timeout"] == 30.0
    assert override["num_retries"] == 3
    # The override must stay settings-only — never a model_list.
    assert "model_list" not in override

@pytest.mark.asyncio
async def test_stream_timeout_header_processing(self):
"""
Expand Down
Loading
Loading