Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions litellm/proxy/auth/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ def _get_models_from_access_groups(
return all_models


def get_access_groups_from_models(
    model_access_groups: Dict[str, List[str]],
    models: List[str],
) -> List[str]:
    """
    Extract access group names from a models list.

    Given a models list like ["gpt-4", "beta-models", "claude-v1"]
    and access groups like {"beta-models": ["gpt-5", "gpt-6"]},
    returns ["beta-models"].

    This is used to pass allowed access groups to the router for filtering
    deployments during load balancing (GitHub issue #18333).

    Args:
        model_access_groups: Mapping of access group name -> list of model
            names in that group.
        models: Model entries a key/team may use; each entry is either a plain
            model name or an access group name.

    Returns:
        The subset of ``models`` that are access group names, in the original
        order (an entry repeated in ``models`` is repeated in the result).
    """
    # An entry counts as an access group iff it is a key of the mapping.
    return [model for model in models if model in model_access_groups]


async def get_mcp_server_ids(
user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
Expand All @@ -80,7 +101,6 @@ async def get_mcp_server_ids(

# Make a direct SQL query to get just the mcp_servers
try:

result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": user_api_key_dict.object_permission_id},
)
Expand Down Expand Up @@ -176,6 +196,7 @@ def get_complete_model_list(
"""

unique_models = []

def append_unique(models):
for model in models:
if model not in unique_models:
Expand All @@ -188,7 +209,7 @@ def append_unique(models):
else:
append_unique(proxy_model_list)
if include_model_access_groups:
append_unique(list(model_access_groups.keys())) # TODO: keys order
append_unique(list(model_access_groups.keys())) # TODO: keys order

if user_model:
append_unique([user_model])
Expand Down
32 changes: 27 additions & 5 deletions litellm/proxy/litellm_pre_call_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import copy
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union

from fastapi import Request
from starlette.datastructures import Headers
Expand Down Expand Up @@ -41,7 +41,7 @@
if TYPE_CHECKING:
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig

ProxyConfig = _ProxyConfig
ProxyConfig: TypeAlias = _ProxyConfig
else:
ProxyConfig = Any

Expand Down Expand Up @@ -562,11 +562,15 @@ def add_litellm_metadata_from_request_headers(
trace_id_from_header = headers.get("x-litellm-trace-id")
if agent_id_from_header:
metadata_from_headers["agent_id"] = agent_id_from_header
verbose_proxy_logger.debug(f"Extracted agent_id from header: {agent_id_from_header}")

verbose_proxy_logger.debug(
f"Extracted agent_id from header: {agent_id_from_header}"
)

if trace_id_from_header:
metadata_from_headers["trace_id"] = trace_id_from_header
verbose_proxy_logger.debug(f"Extracted trace_id from header: {trace_id_from_header}")
verbose_proxy_logger.debug(
f"Extracted trace_id from header: {trace_id_from_header}"
)

if isinstance(data[_metadata_variable_name], dict):
data[_metadata_variable_name].update(metadata_from_headers)
Expand Down Expand Up @@ -1021,6 +1025,24 @@ async def add_litellm_data_to_request( # noqa: PLR0915
"user_api_key_user_max_budget"
] = user_api_key_dict.user_max_budget

# Extract allowed access groups for router filtering (GitHub issue #18333)
# This allows the router to filter deployments based on team's access groups
if llm_router is not None:
from litellm.proxy.auth.model_checks import get_access_groups_from_models

model_access_groups = llm_router.get_model_access_groups()
# Combine key models and team models to get all allowed access groups
all_models = list(user_api_key_dict.models) + list(
user_api_key_dict.team_models or []
)
allowed_access_groups = get_access_groups_from_models(
model_access_groups=model_access_groups, models=all_models
)
if allowed_access_groups:
data[_metadata_variable_name][
"user_api_key_allowed_access_groups"
] = allowed_access_groups

data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
Expand Down
50 changes: 39 additions & 11 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
is_clientside_credential,
)
from litellm.router_utils.common_utils import (
filter_deployments_by_access_groups,
filter_team_based_models,
filter_web_search_deployments,
)
Expand Down Expand Up @@ -619,11 +620,12 @@ def __init__( # noqa: PLR0915
self.retry_policy = RetryPolicy(**retry_policy)
elif isinstance(retry_policy, RetryPolicy):
self.retry_policy = retry_policy
verbose_router_logger.info(
"\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
self.retry_policy.model_dump(exclude_none=True)
if self.retry_policy is not None:
verbose_router_logger.info(
"\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
self.retry_policy.model_dump(exclude_none=True)
)
)
)

self.model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
Expand All @@ -636,11 +638,12 @@ def __init__( # noqa: PLR0915
elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
self.allowed_fails_policy = allowed_fails_policy

verbose_router_logger.info(
"\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
self.allowed_fails_policy.model_dump(exclude_none=True)
if self.allowed_fails_policy is not None:
verbose_router_logger.info(
"\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
self.allowed_fails_policy.model_dump(exclude_none=True)
)
)
)

self.alerting_config: Optional[AlertingConfig] = alerting_config

Expand Down Expand Up @@ -4769,9 +4772,23 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915

if hasattr(original_exception, "message"):
# add the available fallbacks to the exception
original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
model_group,
fallback_model_group,
deployment_info = ""
if kwargs is not None:
metadata = kwargs.get("metadata", {})
if metadata and "deployment" in metadata:
deployment_info = f"\nUsed Deployment: {metadata['deployment']}"
if "model_info" in metadata:
model_info = metadata["model_info"]
if isinstance(model_info, dict):
deployment_info += (
f"\nDeployment ID: {model_info.get('id', 'unknown')}"
)

original_exception.message += ( # type: ignore
f". Received Model Group={model_group}"
f"\nAvailable Model Group Fallbacks={fallback_model_group}"
f"{deployment_info}"
f"\n\n💡 Tip: If using wildcard patterns (e.g., 'openai/*'), ensure all matching deployments have credentials with access to this model."
)
if len(fallback_failure_exception_str) > 0:
original_exception.message += ( # type: ignore
Expand Down Expand Up @@ -8091,6 +8108,17 @@ async def async_get_healthy_deployments(
f"healthy_deployments after web search filter: {healthy_deployments}"
)

# Filter by allowed access groups (GitHub issue #18333)
# This prevents cross-team load balancing when teams have models with same name in different access groups
healthy_deployments = filter_deployments_by_access_groups(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)

verbose_router_logger.debug(
f"healthy_deployments after access group filter: {healthy_deployments}"
)

if isinstance(healthy_deployments, dict):
return healthy_deployments

Expand Down
65 changes: 60 additions & 5 deletions litellm/router_utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def add_model_file_id_mappings(
model_file_id_mapping = {}
if isinstance(healthy_deployments, list):
for deployment, response in zip(healthy_deployments, responses):
model_file_id_mapping[deployment.get("model_info", {}).get("id")] = (
response.id
)
model_file_id_mapping[
deployment.get("model_info", {}).get("id")
] = response.id
elif isinstance(healthy_deployments, dict):
for model_id, file_id in healthy_deployments.items():
model_file_id_mapping[model_id] = file_id
Expand Down Expand Up @@ -75,6 +75,7 @@ def filter_team_based_models(
if deployment.get("model_info", {}).get("id") not in ids_to_remove
]


def _deployment_supports_web_search(deployment: Dict) -> bool:
"""
Check if a deployment supports web search.
Expand Down Expand Up @@ -112,7 +113,7 @@ def filter_web_search_deployments(
is_web_search_request = False
tools = request_kwargs.get("tools") or []
for tool in tools:
# These are the two websearch tools for OpenAI / Azure.
# These are the two websearch tools for OpenAI / Azure.
if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview":
is_web_search_request = True
break
Expand All @@ -121,8 +122,62 @@ def filter_web_search_deployments(
return healthy_deployments

# Filter out deployments that don't support web search
final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)]
final_deployments = [
d for d in healthy_deployments if _deployment_supports_web_search(d)
]
if len(healthy_deployments) > 0 and len(final_deployments) == 0:
verbose_logger.warning("No deployments support web search for request")
return final_deployments


def filter_deployments_by_access_groups(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    Restrict deployments to those compatible with the caller's allowed access groups.

    When the request metadata (or litellm_metadata) carries
    `user_api_key_allowed_access_groups`, keep only deployments whose
    `model_info.access_groups` overlaps the allowed list. Deployments that
    declare no access groups are always kept — they are unrestricted.

    This prevents cross-team load balancing when multiple teams have models with
    the same name but in different access groups (GitHub issue #18333).
    """
    # Without request context there is nothing to filter on; dict-shaped
    # deployment payloads are passed through untouched as well.
    if request_kwargs is None or isinstance(healthy_deployments, dict):
        return healthy_deployments

    allowed = (request_kwargs.get("metadata") or {}).get(
        "user_api_key_allowed_access_groups"
    ) or (request_kwargs.get("litellm_metadata") or {}).get(
        "user_api_key_allowed_access_groups"
    )

    # No access-group restriction on this request -> backwards-compatible no-op.
    if not allowed:
        return healthy_deployments

    allowed_set = set(allowed)

    def _is_permitted(dep: Dict) -> bool:
        # Unrestricted deployments (no access groups) are always permitted;
        # otherwise require at least one group in common with the allowed set.
        groups = (dep.get("model_info") or {}).get("access_groups") or []
        return not groups or bool(set(groups) & allowed_set)

    filtered = [dep for dep in healthy_deployments if _is_permitted(dep)]

    if len(healthy_deployments) > 0 and len(filtered) == 0:
        verbose_logger.warning(
            f"No deployments match allowed access groups {allowed}"
        )

    return filtered
Loading
Loading