diff --git a/docs/my-website/docs/mcp_semantic_filter.md b/docs/my-website/docs/mcp_semantic_filter.md new file mode 100644 index 00000000000..c58be80a680 --- /dev/null +++ b/docs/my-website/docs/mcp_semantic_filter.md @@ -0,0 +1,158 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# MCP Semantic Tool Filter + +Automatically filter MCP tools by semantic relevance. When you have many MCP tools registered, LiteLLM semantically matches the user's query against tool descriptions and sends only the most relevant tools to the LLM. + +## How It Works + +Tool search shifts tool selection from a prompt-engineering problem to a retrieval problem. Instead of injecting a large static list of tools into every prompt, the semantic filter: + +1. Builds a semantic index of all available MCP tools on startup +2. On each request, semantically matches the user's query against tool descriptions +3. Returns only the top-K most relevant tools to the LLM + +This approach improves context efficiency, increases reliability by reducing tool confusion, and enables scalability to ecosystems with hundreds or thousands of MCP tools. 
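
The matching in steps 2–3 above is plain embedding retrieval. As a sketch independent of LiteLLM: score every tool description embedding against the query embedding by cosine similarity, drop scores below the threshold, and keep the top-K. The tool names and vectors below are hypothetical stand-ins for real embedding-model output.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def filter_tools(query_vec, tool_vecs, top_k=5, threshold=0.3):
    """Return names of the top_k tools whose description embeddings
    score at least `threshold` against the query embedding."""
    scored = [(name, cosine(query_vec, vec)) for name, vec in tool_vecs.items()]
    ranked = sorted(
        (pair for pair in scored if pair[1] >= threshold),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [name for name, _ in ranked[:top_k]]

# Hypothetical embeddings standing in for a real model's output
tools = {
    "wikipedia-fetch": [0.9, 0.1, 0.0],
    "slack-post": [0.0, 0.9, 0.1],
    "github-search": [0.2, 0.1, 0.9],
}
query = [0.8, 0.2, 0.1]  # embedding of e.g. "Search Wikipedia for LiteLLM"

print(filter_tools(query, tools, top_k=2))  # ['wikipedia-fetch', 'github-search']
```

In the real filter the per-tool embeddings are computed once at startup, so each request pays only for one query embedding plus the similarity scan.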
+ +```mermaid +sequenceDiagram + participant Client + participant LiteLLM as LiteLLM Proxy + participant SemanticFilter as Semantic Filter + participant MCP as MCP Registry + participant LLM as LLM Provider + + Note over LiteLLM,MCP: Startup: Build Semantic Index + LiteLLM->>MCP: Fetch all registered MCP tools + MCP->>LiteLLM: Return all tools (e.g., 50 tools) + LiteLLM->>SemanticFilter: Build semantic router with embeddings + SemanticFilter->>LLM: Generate embeddings for tool descriptions + LLM->>SemanticFilter: Return embeddings + Note over SemanticFilter: Index ready for fast lookup + + Note over Client,LLM: Request: Semantic Tool Filtering + Client->>LiteLLM: POST /v1/responses with MCP tools + LiteLLM->>SemanticFilter: Expand MCP references (50 tools available) + SemanticFilter->>SemanticFilter: Extract user query from request + SemanticFilter->>LLM: Generate query embedding + LLM->>SemanticFilter: Return query embedding + SemanticFilter->>SemanticFilter: Match query against tool embeddings + SemanticFilter->>LiteLLM: Return top-K tools (e.g., 3 most relevant) + LiteLLM->>LLM: Forward request with filtered tools (3 tools) + LLM->>LiteLLM: Return response + LiteLLM->>Client: Response with headers
x-litellm-semantic-filter: 50->3
x-litellm-semantic-filter-tools: tool1,tool2,tool3 +``` + +## Configuration + +Enable semantic filtering in your LiteLLM config: + +```yaml title="config.yaml" showLineNumbers +litellm_settings: + mcp_semantic_tool_filter: + enabled: true + embedding_model: "text-embedding-3-small" # Model for semantic matching + top_k: 5 # Max tools to return + similarity_threshold: 0.3 # Min similarity score +``` + +**Configuration Options:** +- `enabled` - Enable/disable semantic filtering (default: `false`) +- `embedding_model` - Model for generating embeddings (default: `"text-embedding-3-small"`) +- `top_k` - Maximum number of tools to return (default: `10`) +- `similarity_threshold` - Minimum similarity score for matches (default: `0.3`) + +## Usage + +Use MCP tools normally with the Responses API or Chat Completions. The semantic filter runs automatically: + + + + +```bash title="Responses API with Semantic Filtering" showLineNumbers +curl --location 'http://localhost:4000/v1/responses' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer sk-1234" \ +--data '{ + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": "give me TLDR of what BerriAI/litellm repo is about", + "type": "message" + } + ], + "tools": [ + { + "type": "mcp", + "server_url": "litellm_proxy", + "require_approval": "never" + } + ], + "tool_choice": "required" +}' +``` + + + + +```bash title="Chat Completions with Semantic Filtering" showLineNumbers +curl --location 'http://localhost:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer sk-1234" \ +--data '{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Search Wikipedia for LiteLLM"} + ], + "tools": [ + { + "type": "mcp", + "server_url": "litellm_proxy" + } + ] +}' +``` + + + + +## Response Headers + +The semantic filter adds diagnostic headers to every response: + +``` +x-litellm-semantic-filter: 10->3 +x-litellm-semantic-filter-tools: 
wikipedia-fetch,github-search,slack-post +``` + +- **`x-litellm-semantic-filter`** - Shows before→after tool count (e.g., `10->3` means 10 tools were filtered down to 3) +- **`x-litellm-semantic-filter-tools`** - CSV list of the filtered tool names (max 150 chars, clipped with `...` if longer) + +These headers help you understand which tools were selected for each request and verify the filter is working correctly. + +## Example + +If you have 50 MCP tools registered and make a request asking about Wikipedia, the semantic filter will: + +1. Semantically match your query `"Search Wikipedia for LiteLLM"` against all 50 tool descriptions +2. Select the top 5 most relevant tools (e.g., `wikipedia-fetch`, `wikipedia-search`, etc.) +3. Pass only those 5 tools to the LLM +4. Add headers showing `x-litellm-semantic-filter: 50->5` + +This dramatically reduces prompt size while ensuring the LLM has access to the right tools for the task. + +## Performance + +The semantic filter is optimized for production: +- Router builds once on startup (no per-request overhead) +- Semantic matching typically takes under 50ms +- Fails gracefully - returns all tools if filtering fails +- No impact on latency for requests without MCP tools + +## Related + +- [MCP Overview](./mcp.md) - Learn about MCP in LiteLLM +- [MCP Permission Management](./mcp_control.md) - Control tool access by key/team +- [Using MCP](./mcp_usage.md) - Complete MCP usage guide diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index a9248d83dd6..38af2b1f594 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -538,6 +538,7 @@ const sidebars = { items: [ "mcp", "mcp_usage", + "mcp_semantic_filter", "mcp_control", "mcp_cost", "mcp_guardrail", diff --git a/litellm/constants.py b/litellm/constants.py index 3c84547d7ce..6427c367924 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -67,6 +67,20 @@ os.getenv("DEFAULT_REASONING_EFFORT_DISABLE_THINKING_BUDGET", 0) ) +# 
MCP Semantic Tool Filter Defaults +DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL = str( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL", "text-embedding-3-small") +) +DEFAULT_MCP_SEMANTIC_FILTER_TOP_K = int( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_TOP_K", 10) +) +DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD = float( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD", 0.3) +) +MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH = int( + os.getenv("MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH", 150) +) + # Gemini model-specific minimal thinking budget constants DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET_GEMINI_2_5_FLASH = int( os.getenv("DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET_GEMINI_2_5_FLASH", 1) diff --git a/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py new file mode 100644 index 00000000000..c83ef13a64a --- /dev/null +++ b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py @@ -0,0 +1,248 @@ +""" +Semantic MCP Tool Filtering using semantic-router + +Filters MCP tools semantically for /chat/completions and /responses endpoints. +""" +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from litellm._logging import verbose_logger + +if TYPE_CHECKING: + from mcp.types import Tool as MCPTool + from semantic_router.routers import SemanticRouter + + from litellm.router import Router + + +class SemanticMCPToolFilter: + """Filters MCP tools using semantic similarity to reduce context window size.""" + + def __init__( + self, + embedding_model: str, + litellm_router_instance: "Router", + top_k: int = 10, + similarity_threshold: float = 0.3, + enabled: bool = True, + ): + """ + Initialize the semantic tool filter. 
+ + Args: + embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small") + litellm_router_instance: Router instance for embedding generation + top_k: Maximum number of tools to return + similarity_threshold: Minimum similarity score for filtering + enabled: Whether filtering is enabled + """ + self.enabled = enabled + self.top_k = top_k + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + self.router_instance = litellm_router_instance + self.tool_router: Optional["SemanticRouter"] = None + self._tool_map: Dict[str, Any] = {} # MCPTool objects or OpenAI function dicts + + async def build_router_from_mcp_registry(self) -> None: + """Build semantic router from all MCP tools in the registry (no auth checks).""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + try: + # Get all servers from registry without auth checks + registry = global_mcp_server_manager.get_registry() + if not registry: + verbose_logger.warning("MCP registry is empty") + self.tool_router = None + return + + # Fetch tools from all servers in parallel + all_tools = [] + for server_id, server in registry.items(): + try: + tools = await global_mcp_server_manager.get_tools_for_server(server_id) + all_tools.extend(tools) + except Exception as e: + verbose_logger.warning(f"Failed to fetch tools from server {server_id}: {e}") + continue + + if not all_tools: + verbose_logger.warning("No MCP tools found in registry") + self.tool_router = None + return + + verbose_logger.info(f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers") + self._build_router(all_tools) + + except Exception as e: + verbose_logger.error(f"Failed to build router from MCP registry: {e}") + self.tool_router = None + raise + + def _extract_tool_info(self, tool) -> tuple[str, str]: + """Extract name and description from MCP tool or OpenAI function dict.""" + if isinstance(tool, dict): + # OpenAI function format 
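            # e.g. a flat dict such as
            # {"name": "wikipedia-fetch", "description": "Fetch Wikipedia content"}
            # (hypothetical values; nested {"function": {...}} wrappers are not
            # unpacked here, so callers are expected to pass flat tool dicts)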
+ name = tool.get("name", "") + description = tool.get("description", name) + else: + # MCPTool object + name = tool.name + description = tool.description or tool.name + + return name, description + + def _build_router(self, tools: List) -> None: + """Build semantic router with tools (MCPTool objects or OpenAI function dicts).""" + from semantic_router.routers import SemanticRouter + from semantic_router.routers.base import Route + + from litellm.router_strategy.auto_router.litellm_encoder import ( + LiteLLMRouterEncoder, + ) + + if not tools: + self.tool_router = None + return + + try: + # Convert tools to routes + routes = [] + self._tool_map = {} + + for tool in tools: + name, description = self._extract_tool_info(tool) + self._tool_map[name] = tool + + routes.append( + Route( + name=name, + description=description, + utterances=[description], + score_threshold=self.similarity_threshold, + ) + ) + + self.tool_router = SemanticRouter( + routes=routes, + encoder=LiteLLMRouterEncoder( + litellm_router_instance=self.router_instance, + model_name=self.embedding_model, + score_threshold=self.similarity_threshold, + ), + auto_sync="local", + ) + + verbose_logger.info( + f"Built semantic router with {len(routes)} tools" + ) + + except Exception as e: + verbose_logger.error(f"Failed to build semantic router: {e}") + self.tool_router = None + raise + + async def filter_tools( + self, + query: str, + available_tools: List[Any], + top_k: Optional[int] = None, + ) -> List[Any]: + """ + Filter tools semantically based on query. 
+ + Args: + query: User query to match against tools + available_tools: Full list of available MCP tools + top_k: Override default top_k (optional) + + Returns: + Filtered and ordered list of tools (up to top_k) + """ + # Early returns for cases where we can't/shouldn't filter + if not self.enabled: + return available_tools + + if not available_tools: + return available_tools + + if not query or not query.strip(): + return available_tools + + # Router should be built on startup - if not, something went wrong + if self.tool_router is None: + verbose_logger.warning("Router not initialized - was build_router_from_mcp_registry() called on startup?") + return available_tools + + # Run semantic filtering + try: + limit = top_k or self.top_k + matches = self.tool_router(text=query, limit=limit) + matched_tool_names = self._extract_tool_names_from_matches(matches) + + if not matched_tool_names: + return available_tools + + return self._get_tools_by_names(matched_tool_names, available_tools) + + except Exception as e: + verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True) + return available_tools + + def _extract_tool_names_from_matches(self, matches) -> List[str]: + """Extract tool names from semantic router match results.""" + if not matches: + return [] + + # Handle single match + if hasattr(matches, "name") and matches.name: + return [matches.name] + + # Handle list of matches + if isinstance(matches, list): + return [m.name for m in matches if hasattr(m, "name") and m.name] + + return [] + + def _get_tools_by_names( + self, tool_names: List[str], available_tools: List[Any] + ) -> List[Any]: + """Get tools from available_tools by their names, preserving order.""" + # Match tools from available_tools (preserves format - dict or MCPTool) + matched_tools = [] + for tool in available_tools: + tool_name, _ = self._extract_tool_info(tool) + if tool_name in tool_names: + matched_tools.append(tool) + + # Reorder to match semantic router's ordering + tool_map 
= {self._extract_tool_info(t)[0]: t for t in matched_tools} + return [tool_map[name] for name in tool_names if name in tool_map] + + def extract_user_query(self, messages: List[Dict[str, Any]]) -> str: + """ + Extract user query from messages for /chat/completions or /responses. + + Args: + messages: List of message dictionaries (from 'messages' or 'input' field) + + Returns: + Extracted query string + """ + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content", "") + + if isinstance(content, str): + return content + + if isinstance(content, list): + texts = [ + block.get("text", "") if isinstance(block, dict) else str(block) + for block in content + if isinstance(block, (dict, str)) + ] + return " ".join(texts) + + return "" diff --git a/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md b/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md new file mode 100644 index 00000000000..f2f9a1d4856 --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md @@ -0,0 +1,96 @@ +# MCP Semantic Tool Filter Architecture + +## Why Filter MCP Tools + +When multiple MCP servers are connected, the proxy may expose hundreds of tools. Sending all tools in every request wastes context window tokens and increases cost. The semantic filter keeps only the top-K most relevant tools based on embedding similarity. + +```mermaid +sequenceDiagram + participant Client + participant Hook as SemanticToolFilterHook + participant Filter as SemanticMCPToolFilter + participant Router as semantic-router + participant LLM + + Client->>Hook: POST /chat/completions + Note over Client,Hook: tools: [100+ MCP tools] + Note over Client,Hook: messages: [{"role": "user", "content": "Get my Jira issues"}] + + rect rgb(240, 240, 240) + Note over Hook: 1. Extract User Query + Hook->>Filter: filter_tools("Get my Jira issues", tools) + end + + rect rgb(240, 240, 240) + Note over Filter: 2. 
Convert Tools → Routes + Note over Filter: Tool name + description → Route + end + + rect rgb(240, 240, 240) + Note over Filter: 3. Semantic Matching + Filter->>Router: router(query) + Router->>Router: Embeddings + similarity + Router-->>Filter: [top 10 matches] + end + + rect rgb(240, 240, 240) + Note over Filter: 4. Return Filtered Tools + Filter-->>Hook: [10 relevant tools] + end + + Hook->>LLM: POST /chat/completions + Note over Hook,LLM: tools: [10 Jira-related tools] ← FILTERED + Note over Hook,LLM: messages: [...] ← UNCHANGED + + LLM-->>Client: Response (unchanged) +``` + +## Filter Operations + +The hook intercepts requests before they reach the LLM: + +| Operation | Description | +|-----------|-------------| +| **Extract query** | Get user message from `messages[-1]` | +| **Convert to Routes** | Transform MCP tools into semantic-router Routes | +| **Semantic match** | Use `semantic-router` to find top-K similar tools | +| **Filter tools** | Replace request `tools` with filtered subset | + +## Trigger Conditions + +The filter only runs when: +- Call type is `completion` or `acompletion` +- Request contains `tools` field +- Request contains `messages` field +- Filter is enabled in config + +## What Does NOT Change + +- Request messages +- Response body +- Non-tool parameters + +## Integration with semantic-router + +Reuses existing LiteLLM infrastructure: +- `semantic-router` - Already an optional dependency +- `LiteLLMRouterEncoder` - Wraps `Router.aembedding()` for embeddings +- `SemanticRouter` - Handles similarity calculation and top-K selection + +## Configuration + +```yaml +litellm_settings: + mcp_semantic_tool_filter: + enabled: true + embedding_model: "openai/text-embedding-3-small" + top_k: 10 + similarity_threshold: 0.3 +``` + +## Error Handling + +The filter fails gracefully: +- If filtering fails → Return all tools (no impact on functionality) +- If query extraction fails → Skip filtering +- If no matches found → Return all tools diff --git 
a/litellm/proxy/hooks/mcp_semantic_filter/__init__.py b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py new file mode 100644 index 00000000000..36d357d560f --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py @@ -0,0 +1,9 @@ +""" +MCP Semantic Tool Filter Hook + +Semantic filtering for MCP tools to reduce context window size +and improve tool selection accuracy. +""" +from litellm.proxy.hooks.mcp_semantic_filter.hook import SemanticToolFilterHook + +__all__ = ["SemanticToolFilterHook"] diff --git a/litellm/proxy/hooks/mcp_semantic_filter/hook.py b/litellm/proxy/hooks/mcp_semantic_filter/hook.py new file mode 100644 index 00000000000..fc9349c2a42 --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/hook.py @@ -0,0 +1,353 @@ +""" +Semantic Tool Filter Hook + +Pre-call hook that filters MCP tools semantically before LLM inference. +Reduces context window size and improves tool selection accuracy. +""" +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.constants import ( + DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL, + DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD, + DEFAULT_MCP_SEMANTIC_FILTER_TOP_K, +) +from litellm.integrations.custom_logger import CustomLogger + +if TYPE_CHECKING: + from litellm.caching.caching import DualCache + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + from litellm.proxy._types import UserAPIKeyAuth + from litellm.router import Router + + +class SemanticToolFilterHook(CustomLogger): + """ + Pre-call hook that filters MCP tools semantically. + + This hook: + 1. Extracts the user query from messages + 2. Filters tools based on semantic similarity to the query + 3. Returns only the top-k most relevant tools to the LLM + """ + + def __init__(self, semantic_filter: "SemanticMCPToolFilter"): + """ + Initialize the hook. 
+ + Args: + semantic_filter: SemanticMCPToolFilter instance + """ + super().__init__() + self.filter = semantic_filter + + verbose_proxy_logger.debug( + f"Initialized SemanticToolFilterHook with filter: " + f"enabled={semantic_filter.enabled}, top_k={semantic_filter.top_k}" + ) + + def _should_expand_mcp_tools(self, tools: List[Any]) -> bool: + """ + Check if tools contain MCP references with server_url="litellm_proxy". + + Only expands MCP tools pointing to litellm proxy, not external MCP servers. + """ + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, + ) + + return LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools) + + async def _expand_mcp_tools( + self, + tools: List[Any], + user_api_key_dict: "UserAPIKeyAuth", + ) -> List[Dict[str, Any]]: + """ + Expand MCP references to actual tool definitions. + + Reuses LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format + which internally does: parse -> fetch -> filter -> deduplicate -> transform + """ + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, + ) + + # Parse to separate MCP tools from other tools + mcp_tools, _ = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools) + + if not mcp_tools: + return [] + + # Use single combined method instead of 3 separate calls + # This already handles: fetch -> filter by allowed_tools -> deduplicate -> transform + openai_tools, _ = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format( + user_api_key_auth=user_api_key_dict, + mcp_tools_with_litellm_proxy=mcp_tools + ) + + # Convert Pydantic models to dicts for compatibility + openai_tools_as_dicts = [] + for tool in openai_tools: + if hasattr(tool, "model_dump"): + tool_dict = tool.model_dump(exclude_none=True) + verbose_proxy_logger.debug(f"Converted Pydantic tool to dict: {type(tool).__name__} -> dict with keys: {list(tool_dict.keys())}") + openai_tools_as_dicts.append(tool_dict) + elif hasattr(tool, "dict"): + 
tool_dict = tool.dict(exclude_none=True) + verbose_proxy_logger.debug(f"Converted Pydantic tool (v1) to dict: {type(tool).__name__} -> dict") + openai_tools_as_dicts.append(tool_dict) + elif isinstance(tool, dict): + verbose_proxy_logger.debug(f"Tool is already a dict with keys: {list(tool.keys())}") + openai_tools_as_dicts.append(tool) + else: + verbose_proxy_logger.warning(f"Tool is unknown type: {type(tool)}, passing as-is") + openai_tools_as_dicts.append(tool) + + verbose_proxy_logger.debug( + f"Expanded {len(mcp_tools)} MCP reference(s) to {len(openai_tools_as_dicts)} tools (all as dicts)" + ) + + return openai_tools_as_dicts + + def _get_metadata_variable_name(self, data: dict) -> str: + if "litellm_metadata" in data: + return "litellm_metadata" + return "metadata" + + async def async_pre_call_hook( + self, + user_api_key_dict: "UserAPIKeyAuth", + cache: "DualCache", + data: dict, + call_type: str, + ) -> Optional[Union[Exception, str, dict]]: + """ + Filter tools before LLM call based on user query. + + This hook is called before the LLM request is made. It filters the + tools list to only include semantically relevant tools. + + Args: + user_api_key_dict: User authentication + cache: Cache instance + data: Request data containing messages and tools + call_type: Type of call (completion, acompletion, etc.) 
+ + Returns: + Modified data dict with filtered tools, or None if no changes + """ + # Only filter endpoints that support tools + if call_type not in ("completion", "acompletion", "aresponses"): + verbose_proxy_logger.debug( + f"Skipping semantic filter for call_type={call_type}" + ) + return None + + # Check if tools are present + tools = data.get("tools") + if not tools: + verbose_proxy_logger.debug("No tools in request, skipping semantic filter") + return None + + original_tool_count = len(tools) + + # Check for MCP references (server_url="litellm_proxy") and expand them + if self._should_expand_mcp_tools(tools): + verbose_proxy_logger.debug( + "Detected litellm_proxy MCP references, expanding before semantic filtering" + ) + + try: + expanded_tools = await self._expand_mcp_tools( + tools, user_api_key_dict + ) + + if not expanded_tools: + verbose_proxy_logger.warning( + "No tools expanded from MCP references" + ) + return None + + verbose_proxy_logger.info( + f"Expanded {len(tools)} MCP reference(s) to {len(expanded_tools)} tools" + ) + + # Update tools for filtering + tools = expanded_tools + original_tool_count = len(tools) + + except Exception as e: + verbose_proxy_logger.error( + f"Failed to expand MCP references: {e}", exc_info=True + ) + return None + + # Check if messages are present (try both "messages" and "input" for responses API) + messages = data.get("messages", []) + if not messages: + messages = data.get("input", []) + if not messages: + verbose_proxy_logger.debug("No messages in request, skipping semantic filter") + return None + + # Check if filter is enabled + if not self.filter.enabled: + verbose_proxy_logger.debug("Semantic filter disabled, skipping") + return None + + try: + # Extract user query from messages + user_query = self.filter.extract_user_query(messages) + if not user_query: + verbose_proxy_logger.debug("No user query found, skipping semantic filter") + return None + + verbose_proxy_logger.debug( + f"Applying semantic filter to 
{len(tools)} tools " + f"with query: '{user_query[:50]}...'" + ) + + # Filter tools semantically + filtered_tools = await self.filter.filter_tools( + query=user_query, + available_tools=tools, # type: ignore + ) + + # Always update tools and emit header (even if count unchanged) + data["tools"] = filtered_tools + + # Store filter stats and tool names for response header + filter_stats = f"{original_tool_count}->{len(filtered_tools)}" + tool_names_csv = self._get_tool_names_csv(filtered_tools) + + _metadata_variable_name = self._get_metadata_variable_name(data) + data[_metadata_variable_name]["litellm_semantic_filter_stats"] = filter_stats + data[_metadata_variable_name]["litellm_semantic_filter_tools"] = tool_names_csv + + verbose_proxy_logger.info( + f"Semantic tool filter: {filter_stats} tools" + ) + + return data + + except Exception as e: + verbose_proxy_logger.warning( + f"Semantic tool filter hook failed: {e}. Proceeding with all tools." + ) + return None + + async def async_post_call_response_headers_hook( + self, + data: dict, + user_api_key_dict: "UserAPIKeyAuth", + response: Any, + request_headers: Optional[Dict[str, str]] = None, + ) -> Optional[Dict[str, str]]: + """Add semantic filter stats and tool names to response headers.""" + from litellm.constants import MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH + + _metadata_variable_name = self._get_metadata_variable_name(data) + metadata = data[_metadata_variable_name] + + filter_stats = metadata.get("litellm_semantic_filter_stats") + if not filter_stats: + return None + + headers = {"x-litellm-semantic-filter": filter_stats} + + # Add CSV of filtered tool names (nginx-safe length) + tool_names_csv = metadata.get("litellm_semantic_filter_tools", "") + if tool_names_csv: + if len(tool_names_csv) > MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH: + tool_names_csv = tool_names_csv[:MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH - 3] + "..." 
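            # Worked example of the clipping above, with a hypothetical 20-char cap:
            # "wikipedia-fetch,github-search" (29 chars) -> "wikipedia-fetch,g..." (20 chars).
            # Staying under MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH keeps the
            # header small enough for reverse proxies such as nginx.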
+ + headers["x-litellm-semantic-filter-tools"] = tool_names_csv + + return headers + + def _get_tool_names_csv(self, tools: List[Any]) -> str: + """Extract tool names and return as CSV string.""" + if not tools: + return "" + + tool_names = [] + for tool in tools: + name = tool.get("name", "") if isinstance(tool, dict) else getattr(tool, "name", "") + if name: + tool_names.append(name) + + return ",".join(tool_names) + + @staticmethod + async def initialize_from_config( + config: Optional[Dict[str, Any]], + llm_router: Optional["Router"], + ) -> Optional["SemanticToolFilterHook"]: + """ + Initialize semantic tool filter from proxy config. + + Args: + config: Proxy configuration dict (litellm_settings.mcp_semantic_tool_filter) + llm_router: LiteLLM router instance for embeddings + + Returns: + SemanticToolFilterHook instance if enabled, None otherwise + """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + if not config or not config.get("enabled", False): + verbose_proxy_logger.debug("Semantic tool filter not enabled in config") + return None + + if llm_router is None: + verbose_proxy_logger.warning( + "Cannot initialize semantic filter: llm_router is None" + ) + return None + + try: + + embedding_model = config.get( + "embedding_model", DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL + ) + top_k = config.get("top_k", DEFAULT_MCP_SEMANTIC_FILTER_TOP_K) + similarity_threshold = config.get( + "similarity_threshold", DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD + ) + + semantic_filter = SemanticMCPToolFilter( + embedding_model=embedding_model, + litellm_router_instance=llm_router, + top_k=top_k, + similarity_threshold=similarity_threshold, + enabled=True, + ) + + # Build router from MCP registry on startup + await semantic_filter.build_router_from_mcp_registry() + + hook = SemanticToolFilterHook(semantic_filter) + + verbose_proxy_logger.info( + f"✅ MCP Semantic Tool Filter enabled: " + 
f"embedding_model={embedding_model}, top_k={top_k}, "
+                f"similarity_threshold={similarity_threshold}"
+            )
+
+            return hook
+
+        except ImportError as e:
+            verbose_proxy_logger.warning(
+                f"semantic-router not installed. Install with: "
+                f"pip install 'litellm[semantic-router]'. Error: {e}"
+            )
+            return None
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                f"Failed to initialize MCP semantic tool filter: {e}"
+            )
+            return None
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index e12e75b54ff..d87ae8b14ca 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,4 +1,14 @@
 model_list:
+  - model_name: gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+
+  - model_name: text-embedding-3-small
+    litellm_params:
+      model: openai/text-embedding-3-small
+      api_key: os.environ/OPENAI_API_KEY
+
   - model_name: bedrock-claude-sonnet-3.5
     litellm_params:
       model: "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
@@ -22,4 +32,31 @@ model_list:
   - model_name: bedrock-nova-premier
     litellm_params:
       model: "bedrock/us.amazon.nova-premier-v1:0"
-      aws_region_name: "us-east-1"
\ No newline at end of file
+      aws_region_name: "us-east-1"
+
+# MCP Server Configuration
+mcp_servers:
+  # Uses mcp-server-fetch - reliable and works without external deps
+  wikipedia:
+    transport: "stdio"
+    command: "uvx"
+    args: ["mcp-server-fetch"]
+    description: "Fetch web pages and Wikipedia content"
+  deepwiki:
+    transport: "http"
+    url: "https://mcp.deepwiki.com/mcp"
+
+# General Settings
+general_settings:
+  master_key: sk-1234
+  store_model_in_db: false
+
+# LiteLLM Settings
+litellm_settings:
+  # Enable MCP Semantic Tool Filter
+  mcp_semantic_tool_filter:
+    enabled: true
+    embedding_model: "text-embedding-3-small"
+    top_k: 5
+    similarity_threshold: 0.3
+
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 2343fe8c35d..07598eaff15 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -793,6 +793,21 @@ async def proxy_startup_event(app: FastAPI):  # noqa: PLR0915
         redis_usage_cache=redis_usage_cache,
     )
 
+    ## SEMANTIC TOOL FILTER ##
+    # Read litellm_settings from config for semantic filter initialization
+    try:
+        verbose_proxy_logger.debug("About to initialize semantic tool filter")
+        _config = proxy_config.get_config_state()
+        _litellm_settings = _config.get("litellm_settings", {})
+        verbose_proxy_logger.debug(f"litellm_settings keys = {list(_litellm_settings.keys())}")
+        await ProxyStartupEvent._initialize_semantic_tool_filter(
+            llm_router=llm_router,
+            litellm_settings=_litellm_settings,
+        )
+        verbose_proxy_logger.debug("After semantic tool filter initialization")
+    except Exception as e:
+        verbose_proxy_logger.error(f"Semantic filter init failed: {e}", exc_info=True)
+
     ## JWT AUTH ##
     ProxyStartupEvent._initialize_jwt_auth(
         general_settings=general_settings,
@@ -4741,6 +4756,34 @@ def _initialize_startup_logging(
             llm_router=llm_router, redis_usage_cache=redis_usage_cache
         )
 
+    @classmethod
+    async def _initialize_semantic_tool_filter(
+        cls,
+        llm_router: Optional[Router],
+        litellm_settings: Dict[str, Any],
+    ):
+        """Initialize the MCP semantic tool filter if configured."""
+        from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook
+
+        verbose_proxy_logger.info(
+            f"Initializing semantic tool filter: llm_router={llm_router is not None}, "
+            f"litellm_settings keys={list(litellm_settings.keys())}"
+        )
+
+        mcp_semantic_filter_config = litellm_settings.get("mcp_semantic_tool_filter", None)
+        verbose_proxy_logger.debug(f"Semantic filter config: {mcp_semantic_filter_config}")
+
+        hook = await SemanticToolFilterHook.initialize_from_config(
+            config=mcp_semantic_filter_config,
+            llm_router=llm_router,
+        )
+
+        if hook:
+            verbose_proxy_logger.debug("✅ Semantic tool filter hook registered")
+            litellm.logging_callback_manager.add_litellm_callback(hook)
+        else:
+            verbose_proxy_logger.warning("❌ Semantic tool filter hook not initialized")
+
     @classmethod
     def _initialize_jwt_auth(
         cls,
diff --git a/tests/mcp_tests/test_semantic_tool_filter_e2e.py b/tests/mcp_tests/test_semantic_tool_filter_e2e.py
new file mode 100644
index 00000000000..cf951c1884b
--- /dev/null
+++ b/tests/mcp_tests/test_semantic_tool_filter_e2e.py
@@ -0,0 +1,74 @@
+"""
+End-to-end test for MCP Semantic Tool Filtering
+"""
+import asyncio
+import os
+import sys
+from unittest.mock import Mock
+
+import pytest
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+from mcp.types import Tool as MCPTool
+
+
+@pytest.mark.asyncio
+async def test_e2e_semantic_filter():
+    """E2E: Load router/filter and verify hook filters tools."""
+    from litellm import Router
+    from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    # Create router and filter
+    router = Router(
+        model_list=[{
+            "model_name": "text-embedding-3-small",
+            "litellm_params": {"model": "openai/text-embedding-3-small"},
+        }]
+    )
+
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=router,
+        top_k=3,
+        enabled=True,
+    )
+
+    hook = SemanticToolFilterHook(filter_instance)
+
+    # Create 10 tools
+    tools = [
+        MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_create", description="Create a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="file_upload", description="Upload a file", inputSchema={"type": "object"}),
+        MCPTool(name="web_search", description="Search the web", inputSchema={"type": "object"}),
+        MCPTool(name="slack_send", description="Send Slack message", inputSchema={"type": "object"}),
+        MCPTool(name="doc_read", description="Read document", inputSchema={"type": "object"}),
+        MCPTool(name="db_query", description="Query database", inputSchema={"type": "object"}),
+        MCPTool(name="api_call", description="Make API call", inputSchema={"type": "object"}),
+        MCPTool(name="task_create", description="Create task", inputSchema={"type": "object"}),
+        MCPTool(name="note_add", description="Add note", inputSchema={"type": "object"}),
+    ]
+
+    data = {
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Send an email and create a calendar event"}],
+        "tools": tools,
+    }
+
+    # Call hook
+    result = await hook.async_pre_call_hook(
+        user_api_key_dict=Mock(),
+        cache=Mock(),
+        data=data,
+        call_type="completion",
+    )
+
+    # Single assertion: hook filtered tools
+    assert result and len(result["tools"]) < len(tools), (
+        f"Expected filtered tools, got {len(result['tools'])} tools (original: {len(tools)})"
+    )
+
+    print(f"✅ E2E test passed: Filtering reduced tools from {len(tools)} to {len(result['tools'])}")
+    print(f"   Filtered tools: {[t.name for t in result['tools']]}")
diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py
new file mode 100644
index 00000000000..8d35f5bbdc9
--- /dev/null
+++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py
@@ -0,0 +1,384 @@
+"""
+Unit tests for MCP Semantic Tool Filtering
+
+Tests the core filtering logic that takes a long list of tools and returns
+an ordered set of top K tools based on semantic similarity.
+"""
+import asyncio
+import os
+import sys
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+from mcp.types import Tool as MCPTool
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_basic_filtering():
+    """
+    Test that the semantic filter correctly filters tools based on query.
+
+    Given: 10 email/calendar tools
+    When: Query is "send an email"
+    Then: At most top_k tools are returned from the candidate set
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    # Create mock tools - mix of email and calendar tools
+    tools = [
+        MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}),
+        MCPTool(name="outlook_send", description="Send an email via Outlook", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_create", description="Create a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_update", description="Update a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="email_read", description="Read emails from inbox", inputSchema={"type": "object"}),
+        MCPTool(name="email_delete", description="Delete an email", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_delete", description="Delete a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="email_search", description="Search for emails", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_list", description="List calendar events", inputSchema={"type": "object"}),
+        MCPTool(name="email_forward", description="Forward an email to someone", inputSchema={"type": "object"}),
+    ]
+
+    # Mock router that returns mock embeddings
+    from litellm.types.utils import Embedding, EmbeddingResponse
+
+    mock_router = Mock()
+
+    def mock_embedding_sync(*args, **kwargs):
+        return EmbeddingResponse(
+            data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")],
+            model="text-embedding-3-small",
+            object="list",
+            usage={"prompt_tokens": 10, "total_tokens": 10},
+        )
+
+    async def mock_embedding_async(*args, **kwargs):
+        return mock_embedding_sync()
+
+    mock_router.embedding = mock_embedding_sync
+    mock_router.aembedding = mock_embedding_async
+
+    # Create filter
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Filter tools with email-related query
+    filtered = await filter_instance.filter_tools(
+        query="send an email to john@example.com",
+        available_tools=tools,
+    )
+
+    # Assertions - validate filtering mechanics work
+    assert len(filtered) <= 3, f"Should return at most 3 tools (top_k), got {len(filtered)}"
+    assert len(filtered) > 0, "Should return at least some tools"
+    assert len(filtered) < len(tools), f"Should filter down from {len(tools)} tools, got {len(filtered)}"
+
+    # Validate tools are actual MCPTool objects
+    for tool in filtered:
+        assert hasattr(tool, "name"), "Filtered result should be MCPTool with name"
+        assert hasattr(tool, "description"), "Filtered result should be MCPTool with description"
+
+    filtered_names = [t.name for t in filtered]
+    print(f"✅ Successfully filtered {len(tools)} tools down to top {len(filtered)}: {filtered_names}")
+    print("   Filter respects top_k parameter correctly")
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_top_k_limiting():
+    """
+    Test that the filter respects the top_k parameter.
+
+    Given: 20 tools
+    When: top_k=5
+    Then: Should return at most 5 tools
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    # Create 20 tools
+    tools = [
+        MCPTool(name=f"tool_{i}", description=f"Tool number {i} for testing", inputSchema={"type": "object"})
+        for i in range(20)
+    ]
+
+    # Mock router
+    from litellm.types.utils import Embedding, EmbeddingResponse
+
+    mock_router = Mock()
+
+    def mock_embedding_sync(*args, **kwargs):
+        return EmbeddingResponse(
+            data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")],
+            model="text-embedding-3-small",
+            object="list",
+            usage={"prompt_tokens": 10, "total_tokens": 10},
+        )
+
+    async def mock_embedding_async(*args, **kwargs):
+        return mock_embedding_sync()
+
+    mock_router.embedding = mock_embedding_sync
+    mock_router.aembedding = mock_embedding_async
+
+    # Create filter with top_k=5
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=5,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Filter tools
+    filtered = await filter_instance.filter_tools(
+        query="test query",
+        available_tools=tools,
+    )
+
+    # Should return at most 5 tools
+    assert len(filtered) <= 5, f"Expected at most 5 tools, got {len(filtered)}"
+    print(f"Returned {len(filtered)} tools out of {len(tools)} (top_k=5)")
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_disabled():
+    """
+    Test that when the filter is disabled, all tools are returned.
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    tools = [
+        MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"})
+        for i in range(10)
+    ]
+
+    mock_router = Mock()
+
+    # Create disabled filter
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=False,  # Disabled
+    )
+
+    # Filter tools
+    filtered = await filter_instance.filter_tools(
+        query="test query",
+        available_tools=tools,
+    )
+
+    # Should return all tools when disabled
+    assert len(filtered) == len(tools), f"Expected all {len(tools)} tools, got {len(filtered)}"
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_empty_tools():
+    """
+    Test that the filter handles an empty tool list gracefully.
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    mock_router = Mock()
+
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Filter empty list
+    filtered = await filter_instance.filter_tools(
+        query="test query",
+        available_tools=[],
+    )
+
+    assert len(filtered) == 0, "Should return empty list for empty input"
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_extract_user_query():
+    """
+    Test that user query extraction works correctly from messages.
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    mock_router = Mock()
+
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Test string content
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "Send an email to john@example.com"},
+    ]
+
+    query = filter_instance.extract_user_query(messages)
+    assert query == "Send an email to john@example.com"
+
+    # Test list content blocks
+    messages_with_blocks = [
+        {"role": "user", "content": [
+            {"type": "text", "text": "Hello, "},
+            {"type": "text", "text": "send email please"},
+        ]},
+    ]
+
+    query2 = filter_instance.extract_user_query(messages_with_blocks)
+    assert "Hello" in query2 and "send email" in query2
+
+    # Test no user messages
+    messages_no_user = [
+        {"role": "system", "content": "System message only"},
+    ]
+
+    query3 = filter_instance.extract_user_query(messages_no_user)
+    assert query3 == ""
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_hook_triggers_on_completion():
+    """
+    Test that the hook triggers for completion requests with tools.
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+    from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook
+    from litellm.types.utils import Embedding, EmbeddingResponse
+
+    # Create mock filter
+    mock_router = Mock()
+
+    def mock_embedding_sync(*args, **kwargs):
+        return EmbeddingResponse(
+            data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")],
+            model="text-embedding-3-small",
+            object="list",
+            usage={"prompt_tokens": 10, "total_tokens": 10},
+        )
+
+    async def mock_embedding_async(*args, **kwargs):
+        return mock_embedding_sync()
+
+    mock_router.embedding = mock_embedding_sync
+    mock_router.aembedding = mock_embedding_async
+
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Create hook
+    hook = SemanticToolFilterHook(filter_instance)
+
+    # Prepare data - completion request with tools
+    tools = [
+        MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"})
+        for i in range(10)
+    ]
+
+    data = {
+        "model": "gpt-4",
+        "messages": [
+            {"role": "user", "content": "Send an email"}
+        ],
+        "tools": tools,
+    }
+
+    # Mock user API key dict and cache
+    mock_user_api_key_dict = Mock()
+    mock_cache = Mock()
+
+    # Call hook
+    result = await hook.async_pre_call_hook(
+        user_api_key_dict=mock_user_api_key_dict,
+        cache=mock_cache,
+        data=data,
+        call_type="completion",
+    )
+
+    # Assertions
+    assert result is not None, "Hook should return modified data"
+    assert "tools" in result, "Result should contain tools"
+    assert len(result["tools"]) < len(tools), f"Hook should filter tools, got {len(result['tools'])}/{len(tools)}"
+
+    print(f"✅ Hook triggered correctly: {len(tools)} -> {len(result['tools'])} tools")
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_hook_skips_no_tools():
+    """
+    Test that the hook does NOT trigger when there are no tools.
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+    from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook
+
+    # Create mock filter
+    mock_router = Mock()
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Create hook
+    hook = SemanticToolFilterHook(filter_instance)
+
+    # Prepare data - completion without tools
+    data = {
+        "model": "gpt-4",
+        "messages": [
+            {"role": "user", "content": "Hello"}
+        ],
+    }
+
+    # Mock user API key dict and cache
+    mock_user_api_key_dict = Mock()
+    mock_cache = Mock()
+
+    # Call hook
+    result = await hook.async_pre_call_hook(
+        user_api_key_dict=mock_user_api_key_dict,
+        cache=mock_cache,
+        data=data,
+        call_type="completion",
+    )
+
+    # Should return None (no modification)
+    assert result is None, "Hook should skip requests without tools"
+    print("✅ Hook correctly skips requests without tools")