Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 2 additions & 23 deletions litellm/proxy/auth/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,6 @@ def _get_models_from_access_groups(
return all_models


def get_access_groups_from_models(
model_access_groups: Dict[str, List[str]],
models: List[str],
) -> List[str]:
"""
Extract access group names from a models list.

Given a models list like ["gpt-4", "beta-models", "claude-v1"]
and access groups like {"beta-models": ["gpt-5", "gpt-6"]},
returns ["beta-models"].

This is used to pass allowed access groups to the router for filtering
deployments during load balancing (GitHub issue #18333).
"""
access_groups = []
for model in models:
if model in model_access_groups:
access_groups.append(model)
return access_groups


async def get_mcp_server_ids(
user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
Expand All @@ -101,6 +80,7 @@ async def get_mcp_server_ids(

# Make a direct SQL query to get just the mcp_servers
try:

result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": user_api_key_dict.object_permission_id},
)
Expand Down Expand Up @@ -196,7 +176,6 @@ def get_complete_model_list(
"""

unique_models = []

def append_unique(models):
for model in models:
if model not in unique_models:
Expand All @@ -209,7 +188,7 @@ def append_unique(models):
else:
append_unique(proxy_model_list)
if include_model_access_groups:
append_unique(list(model_access_groups.keys())) # TODO: keys order
append_unique(list(model_access_groups.keys())) # TODO: keys order

if user_model:
append_unique([user_model])
Expand Down
59 changes: 14 additions & 45 deletions litellm/proxy/litellm_pre_call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ def _get_dynamic_logging_metadata(
user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
callback_settings_obj: Optional[TeamCallbackMetadata] = None
key_dynamic_logging_settings: Optional[dict] = (
KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
)
team_dynamic_logging_settings: Optional[dict] = (
KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
)
key_dynamic_logging_settings: Optional[
dict
] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
team_dynamic_logging_settings: Optional[
dict
] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
#########################################################################################
# Key-based callbacks
#########################################################################################
Expand Down Expand Up @@ -661,11 +661,11 @@ def add_key_level_controls(

## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
data[_metadata_variable_name]["tags"] = (
LiteLLMProxyRequestSetup._merge_tags(
request_tags=data[_metadata_variable_name].get("tags"),
tags_to_add=key_metadata["tags"],
)
data[_metadata_variable_name][
"tags"
] = LiteLLMProxyRequestSetup._merge_tags(
request_tags=data[_metadata_variable_name].get("tags"),
tags_to_add=key_metadata["tags"],
)
if "disable_global_guardrails" in key_metadata and isinstance(
key_metadata["disable_global_guardrails"], bool
Expand Down Expand Up @@ -933,9 +933,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
data[_metadata_variable_name]["litellm_api_version"] = version

if general_settings is not None:
data[_metadata_variable_name]["global_max_parallel_requests"] = (
general_settings.get("global_max_parallel_requests", None)
)
data[_metadata_variable_name][
"global_max_parallel_requests"
] = general_settings.get("global_max_parallel_requests", None)

### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata
Expand Down Expand Up @@ -1002,37 +1002,6 @@ async def add_litellm_data_to_request( # noqa: PLR0915
"user_api_key_model_max_budget"
] = user_api_key_dict.model_max_budget

# Extract allowed access groups for router filtering (GitHub issue #18333)
# This allows the router to filter deployments based on key's and team's access groups
# NOTE: We keep key and team access groups SEPARATE because a key doesn't always
# inherit all team access groups (per maintainer feedback).
if llm_router is not None:
from litellm.proxy.auth.model_checks import get_access_groups_from_models

model_access_groups = llm_router.get_model_access_groups()

# Key-level access groups (from user_api_key_dict.models)
key_models = list(user_api_key_dict.models) if user_api_key_dict.models else []
key_allowed_access_groups = get_access_groups_from_models(
model_access_groups=model_access_groups, models=key_models
)
if key_allowed_access_groups:
data[_metadata_variable_name][
"user_api_key_allowed_access_groups"
] = key_allowed_access_groups

# Team-level access groups (from user_api_key_dict.team_models)
team_models = (
list(user_api_key_dict.team_models) if user_api_key_dict.team_models else []
)
team_allowed_access_groups = get_access_groups_from_models(
model_access_groups=model_access_groups, models=team_models
)
if team_allowed_access_groups:
data[_metadata_variable_name][
"user_api_key_team_allowed_access_groups"
] = team_allowed_access_groups

data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
Expand Down
12 changes: 2 additions & 10 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
is_clientside_credential,
)
from litellm.router_utils.common_utils import (
filter_deployments_by_access_groups,
filter_team_based_models,
filter_web_search_deployments,
)
Expand Down Expand Up @@ -7847,17 +7846,10 @@ async def async_get_healthy_deployments(
request_kwargs=request_kwargs,
)

verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}")

# Filter by allowed access groups (GitHub issue #18333)
# This prevents cross-team load balancing when teams have models with same name in different access groups
healthy_deployments = filter_deployments_by_access_groups(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
verbose_router_logger.debug(
f"healthy_deployments after web search filter: {healthy_deployments}"
)

verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}")

if isinstance(healthy_deployments, dict):
return healthy_deployments

Expand Down
79 changes: 2 additions & 77 deletions litellm/router_utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def filter_team_based_models(
if deployment.get("model_info", {}).get("id") not in ids_to_remove
]


def _deployment_supports_web_search(deployment: Dict) -> bool:
"""
Check if a deployment supports web search.
Expand Down Expand Up @@ -113,7 +112,7 @@ def filter_web_search_deployments(
is_web_search_request = False
tools = request_kwargs.get("tools") or []
for tool in tools:
# These are the two websearch tools for OpenAI / Azure.
# These are the two websearch tools for OpenAI / Azure.
if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview":
is_web_search_request = True
break
Expand All @@ -122,82 +121,8 @@ def filter_web_search_deployments(
return healthy_deployments

# Filter out deployments that don't support web search
final_deployments = [
d for d in healthy_deployments if _deployment_supports_web_search(d)
]
final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)]
if len(healthy_deployments) > 0 and len(final_deployments) == 0:
verbose_logger.warning("No deployments support web search for request")
return final_deployments


def filter_deployments_by_access_groups(
healthy_deployments: Union[List[Dict], Dict],
request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
"""
Filter deployments to only include those matching the user's allowed access groups.

Reads from TWO separate metadata fields (per maintainer feedback):
- `user_api_key_allowed_access_groups`: Access groups from the API Key's models.
- `user_api_key_team_allowed_access_groups`: Access groups from the Team's models.

A deployment is included if its access_groups overlap with EITHER the key's
or the team's allowed access groups. Deployments with no access_groups are
always included (not restricted).

This prevents cross-team load balancing when multiple teams have models with
the same name but in different access groups (GitHub issue #18333).
"""
if request_kwargs is None:
return healthy_deployments

if isinstance(healthy_deployments, dict):
return healthy_deployments

metadata = request_kwargs.get("metadata") or {}
litellm_metadata = request_kwargs.get("litellm_metadata") or {}

# Gather key-level allowed access groups
key_allowed_access_groups = (
metadata.get("user_api_key_allowed_access_groups")
or litellm_metadata.get("user_api_key_allowed_access_groups")
or []
)

# Gather team-level allowed access groups
team_allowed_access_groups = (
metadata.get("user_api_key_team_allowed_access_groups")
or litellm_metadata.get("user_api_key_team_allowed_access_groups")
or []
)

# Combine both for the final allowed set
combined_allowed_access_groups = list(key_allowed_access_groups) + list(
team_allowed_access_groups
)

# If no access groups specified from either source, return all deployments (backwards compatible)
if not combined_allowed_access_groups:
return healthy_deployments

allowed_set = set(combined_allowed_access_groups)
filtered = []
for deployment in healthy_deployments:
model_info = deployment.get("model_info") or {}
deployment_access_groups = model_info.get("access_groups") or []

# If deployment has no access groups, include it (not restricted)
if not deployment_access_groups:
filtered.append(deployment)
continue

# Include if any of deployment's groups overlap with allowed groups
if set(deployment_access_groups) & allowed_set:
filtered.append(deployment)

if len(healthy_deployments) > 0 and len(filtered) == 0:
verbose_logger.warning(
f"No deployments match allowed access groups {combined_allowed_access_groups}"
)

return filtered
Loading
Loading