Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/my-website/docs/proxy/config_settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,6 @@ router_settings:
| DEFAULT_MAX_TOKENS | Default maximum tokens for LLM calls. Default is 4096
| DEFAULT_MAX_TOKENS_FOR_TRITON | Default maximum tokens for Triton models. Default is 2000
| DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE | Default maximum size for redis batch cache. Default is 1000
| DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL | Default embedding model for MCP semantic tool filtering. Default is "text-embedding-3-small"
| DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD | Default similarity threshold for MCP semantic tool filtering. Default is 0.3
| DEFAULT_MCP_SEMANTIC_FILTER_TOP_K | Default number of top results to return for MCP semantic tool filtering. Default is 10
| DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT | Default token count for mock response completions. Default is 20
| DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT | Default token count for mock response prompts. Default is 10
| DEFAULT_MODEL_CREATED_AT_TIME | Default creation timestamp for models. Default is 1677610602
Expand Down Expand Up @@ -805,7 +802,6 @@ router_settings:
| MAXIMUM_TRACEBACK_LINES_TO_LOG | Maximum number of lines to log in traceback in LiteLLM Logs UI. Default is 100
| MAX_RETRY_DELAY | Maximum delay in seconds for retrying requests. Default is 8.0
| MAX_LANGFUSE_INITIALIZED_CLIENTS | Maximum number of Langfuse clients to initialize on proxy. Default is 50. This is set since langfuse initializes 1 thread every time a client is initialized. We've had an incident in the past where we reached 100% cpu utilization because Langfuse was initialized several times.
| MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH | Maximum header length for MCP semantic filter tools. Default is 150
| MIN_NON_ZERO_TEMPERATURE | Minimum non-zero temperature value. Default is 0.0001
| MINIMUM_PROMPT_CACHE_TOKEN_COUNT | Minimum token count for caching a prompt. Default is 1024
| MISTRAL_API_BASE | Base URL for Mistral API. Default is https://api.mistral.ai
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- CreateIndex
-- Supports lookups of a user's tokens optionally filtered by team
-- (WHERE user_id = $1 [AND team_id ...]); the leading user_id column
-- also serves user_id-only queries.
-- NOTE(review): if this migration has already been applied, editing the
-- file will invalidate Prisma's stored migration checksum — confirm it
-- is still unapplied before changing it.
CREATE INDEX "LiteLLM_VerificationToken_user_id_team_id_idx" ON "LiteLLM_VerificationToken"("user_id", "team_id");

-- CreateIndex
-- Supports team-only token lookups (WHERE team_id = $1), which the
-- composite index above cannot serve since team_id is its second column.
CREATE INDEX "LiteLLM_VerificationToken_team_id_idx" ON "LiteLLM_VerificationToken"("team_id");

-- CreateIndex
-- Supports the budget-reset sweep: WHERE (expires IS NULL OR expires > $1)
-- AND budget_reset_at < $2.
CREATE INDEX "LiteLLM_VerificationToken_budget_reset_at_expires_idx" ON "LiteLLM_VerificationToken"("budget_reset_at", "expires");
10 changes: 10 additions & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken {
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])

// SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub"
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2
@@index([user_id, team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2
@@index([team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
@@index([budget_reset_at, expires])
}

// Audit table for deleted keys - preserves spend and key information for historical tracking
Expand Down
25 changes: 23 additions & 2 deletions litellm/proxy/auth/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ def _get_models_from_access_groups(
return all_models


def get_access_groups_from_models(
    model_access_groups: Dict[str, List[str]],
    models: List[str],
) -> List[str]:
    """
    Return the subset of ``models`` that name an access group.

    Example: given models ``["gpt-4", "beta-models", "claude-v1"]`` and
    access groups ``{"beta-models": ["gpt-5", "gpt-6"]}``, the result is
    ``["beta-models"]``. Original order (and any duplicates) in ``models``
    is preserved.

    Used to hand the caller's allowed access groups to the router so it can
    filter deployments during load balancing (GitHub issue #18333).
    """
    # Membership test against the dict's keys identifies access-group names.
    return [candidate for candidate in models if candidate in model_access_groups]


async def get_mcp_server_ids(
user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
Expand All @@ -80,7 +101,6 @@ async def get_mcp_server_ids(

# Make a direct SQL query to get just the mcp_servers
try:

result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": user_api_key_dict.object_permission_id},
)
Expand Down Expand Up @@ -176,6 +196,7 @@ def get_complete_model_list(
"""

unique_models = []

def append_unique(models):
for model in models:
if model not in unique_models:
Expand All @@ -188,7 +209,7 @@ def append_unique(models):
else:
append_unique(proxy_model_list)
if include_model_access_groups:
append_unique(list(model_access_groups.keys())) # TODO: keys order
append_unique(list(model_access_groups.keys())) # TODO: keys order

if user_model:
append_unique([user_model])
Expand Down
31 changes: 31 additions & 0 deletions litellm/proxy/litellm_pre_call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,37 @@ async def add_litellm_data_to_request( # noqa: PLR0915
"user_api_key_user_max_budget"
] = user_api_key_dict.user_max_budget

# Extract allowed access groups for router filtering (GitHub issue #18333)
# This allows the router to filter deployments based on key's and team's access groups
# NOTE: We keep key and team access groups SEPARATE because a key doesn't always
# inherit all team access groups (per maintainer feedback).
if llm_router is not None:
from litellm.proxy.auth.model_checks import get_access_groups_from_models

model_access_groups = llm_router.get_model_access_groups()
Comment on lines +1028 to +1031
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling llm_router.get_model_access_groups() on every request may cause performance issues. This method calls get_model_list() and iterates over all models (line 7530-7542 in router.py), which could be expensive for routers with many models. Consider caching the result or moving this computation outside the critical request path.

Context Used: Rule from dashboard - What: Avoid creating new database requests or Router objects in the critical request path.

Why: Cre... (source)

Prompt To Fix With AI
This is a comment left during a code review.
Path: litellm/proxy/litellm_pre_call_utils.py
Line: 1028:1031

Comment:
Calling `llm_router.get_model_access_groups()` on every request may cause performance issues. This method calls `get_model_list()` and iterates over all models (line 7530-7542 in `router.py`), which could be expensive for routers with many models. Consider caching the result or moving this computation outside the critical request path.

**Context Used:** Rule from `dashboard` - What: Avoid creating new database requests or Router objects in the critical request path.

Why: Cre... ([source](https://app.greptile.com/review/custom-context?memory=0c2a17ad-5f29-423f-a48b-371852ac4169))

How can I resolve this? If you propose a fix, please make it concise.


# Key-level access groups (from user_api_key_dict.models)
key_models = list(user_api_key_dict.models) if user_api_key_dict.models else []
key_allowed_access_groups = get_access_groups_from_models(
model_access_groups=model_access_groups, models=key_models
)
if key_allowed_access_groups:
data[_metadata_variable_name][
"user_api_key_allowed_access_groups"
] = key_allowed_access_groups

# Team-level access groups (from user_api_key_dict.team_models)
team_models = (
list(user_api_key_dict.team_models) if user_api_key_dict.team_models else []
)
team_allowed_access_groups = get_access_groups_from_models(
model_access_groups=model_access_groups, models=team_models
)
if team_allowed_access_groups:
data[_metadata_variable_name][
"user_api_key_team_allowed_access_groups"
] = team_allowed_access_groups

data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
Expand Down
10 changes: 10 additions & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken {
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])

// SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub"
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2
@@index([user_id, team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2
@@index([team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
@@index([budget_reset_at, expires])
}

// Audit table for deleted keys - preserves spend and key information for historical tracking
Expand Down
12 changes: 10 additions & 2 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
is_clientside_credential,
)
from litellm.router_utils.common_utils import (
filter_deployments_by_access_groups,
filter_team_based_models,
filter_web_search_deployments,
)
Expand Down Expand Up @@ -8087,10 +8088,17 @@ async def async_get_healthy_deployments(
request_kwargs=request_kwargs,
)

verbose_router_logger.debug(
f"healthy_deployments after web search filter: {healthy_deployments}"
verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}")

# Filter by allowed access groups (GitHub issue #18333)
# This prevents cross-team load balancing when teams have models with same name in different access groups
healthy_deployments = filter_deployments_by_access_groups(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)

verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}")

if isinstance(healthy_deployments, dict):
return healthy_deployments

Expand Down
79 changes: 77 additions & 2 deletions litellm/router_utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def filter_team_based_models(
if deployment.get("model_info", {}).get("id") not in ids_to_remove
]


def _deployment_supports_web_search(deployment: Dict) -> bool:
"""
Check if a deployment supports web search.
Expand Down Expand Up @@ -112,7 +113,7 @@ def filter_web_search_deployments(
is_web_search_request = False
tools = request_kwargs.get("tools") or []
for tool in tools:
# These are the two websearch tools for OpenAI / Azure.
# These are the two websearch tools for OpenAI / Azure.
if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview":
is_web_search_request = True
break
Expand All @@ -121,8 +122,82 @@ def filter_web_search_deployments(
return healthy_deployments

# Filter out deployments that don't support web search
final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)]
final_deployments = [
d for d in healthy_deployments if _deployment_supports_web_search(d)
]
if len(healthy_deployments) > 0 and len(final_deployments) == 0:
verbose_logger.warning("No deployments support web search for request")
return final_deployments


def filter_deployments_by_access_groups(
    healthy_deployments: Union[List[Dict], Dict],
    request_kwargs: Optional[Dict] = None,
) -> Union[List[Dict], Dict]:
    """
    Restrict ``healthy_deployments`` to those matching the caller's allowed
    access groups.

    Two metadata fields are consulted independently (kept SEPARATE per
    maintainer feedback, since a key does not always inherit all of its
    team's access groups):

    - ``user_api_key_allowed_access_groups``: groups from the API key's models.
    - ``user_api_key_team_allowed_access_groups``: groups from the team's models.

    A deployment is kept when its ``model_info.access_groups`` intersects
    either list. Deployments that declare no access groups are always kept
    (unrestricted). When neither list is present, the input is returned
    unchanged (backwards compatible).

    Prevents cross-team load balancing when multiple teams expose models with
    the same name under different access groups (GitHub issue #18333).
    """
    # Nothing to filter on, or an error/dict response — pass through untouched.
    if request_kwargs is None or isinstance(healthy_deployments, dict):
        return healthy_deployments

    metadata = request_kwargs.get("metadata") or {}
    litellm_metadata = request_kwargs.get("litellm_metadata") or {}

    def _read_groups(field_name: str) -> List:
        # "metadata" wins; fall back to "litellm_metadata"; default empty.
        return (
            metadata.get(field_name)
            or litellm_metadata.get(field_name)
            or []
        )

    combined_allowed_access_groups = list(
        _read_groups("user_api_key_allowed_access_groups")
    ) + list(_read_groups("user_api_key_team_allowed_access_groups"))

    # Neither the key nor the team restricts access groups — no-op.
    if not combined_allowed_access_groups:
        return healthy_deployments

    allowed_set = set(combined_allowed_access_groups)
    filtered = []
    for deployment in healthy_deployments:
        deployment_access_groups = (
            (deployment.get("model_info") or {}).get("access_groups") or []
        )
        # Keep unrestricted deployments; otherwise require an overlap.
        if not deployment_access_groups or allowed_set.intersection(
            deployment_access_groups
        ):
            filtered.append(deployment)

    if healthy_deployments and not filtered:
        verbose_logger.warning(
            f"No deployments match allowed access groups {combined_allowed_access_groups}"
        )

    return filtered
10 changes: 10 additions & 0 deletions schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ model LiteLLM_VerificationToken {
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])

// SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub"
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2
@@index([user_id, team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2
@@index([team_id])

// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
@@index([budget_reset_at, expires])
}

// Audit table for deleted keys - preserves spend and key information for historical tracking
Expand Down
19 changes: 2 additions & 17 deletions tests/mcp_tests/test_semantic_tool_filter_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,8 @@

from mcp.types import Tool as MCPTool

# Check if semantic-router is available
try:
import semantic_router
SEMANTIC_ROUTER_AVAILABLE = True
except ImportError:
SEMANTIC_ROUTER_AVAILABLE = False


@pytest.mark.asyncio
@pytest.mark.skipif(
not SEMANTIC_ROUTER_AVAILABLE,
reason="semantic-router not installed. Install with: pip install 'litellm[semantic-router]'"
)
async def test_e2e_semantic_filter():
"""E2E: Load router/filter and verify hook filters tools."""
from litellm import Router
Expand All @@ -48,6 +37,8 @@ async def test_e2e_semantic_filter():
enabled=True,
)

hook = SemanticToolFilterHook(filter_instance)

# Create 10 tools
tools = [
MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}),
Expand All @@ -62,16 +53,10 @@ async def test_e2e_semantic_filter():
MCPTool(name="note_add", description="Add note", inputSchema={"type": "object"}),
]

# Build router with test tools
filter_instance._build_router(tools)

hook = SemanticToolFilterHook(filter_instance)

data = {
"model": "gpt-4",
"messages": [{"role": "user", "content": "Send an email and create a calendar event"}],
"tools": tools,
"metadata": {}, # Initialize metadata dict for hook to store filter stats
}

# Call hook
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ async def mock_embedding_async(*args, **kwargs):
enabled=True,
)

# Build router with the tools before filtering
filter_instance._build_router(tools)

# Filter tools with email-related query
filtered = await filter_instance.filter_tools(
query="send an email to john@example.com",
Expand Down Expand Up @@ -142,9 +139,6 @@ async def mock_embedding_async(*args, **kwargs):
enabled=True,
)

# Build router with the tools before filtering
filter_instance._build_router(tools)

# Filter tools
filtered = await filter_instance.filter_tools(
query="test query",
Expand Down Expand Up @@ -303,25 +297,21 @@ async def mock_embedding_async(*args, **kwargs):
enabled=True,
)

# Create hook
hook = SemanticToolFilterHook(filter_instance)

# Prepare data - completion request with tools
tools = [
MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"})
for i in range(10)
]

# Build router with the tools before filtering
filter_instance._build_router(tools)

# Create hook
hook = SemanticToolFilterHook(filter_instance)

data = {
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Send an email"}
],
"tools": tools,
"metadata": {}, # Hook needs metadata field to store filter stats
}

# Mock user API key dict and cache
Expand Down
Loading
Loading