diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 4c4e6fa6342..317613420a5 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -33,10 +33,9 @@ if TYPE_CHECKING: from fastapi import HTTPException - - from litellm.caching.caching import DualCache from opentelemetry.trace import Span as _Span + from litellm.caching.caching import DualCache from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy._types import UserAPIKeyAuth from litellm.types.mcp import ( @@ -484,6 +483,138 @@ async def async_post_mcp_tool_call_hook( """ return None + ######################################################### + # AGENTIC LOOP HOOKS (for litellm.messages + future completion support) + ######################################################### + + async def async_should_run_agentic_loop( + self, + response: Any, + model: str, + messages: List[Dict], + tools: Optional[List[Dict]], + stream: bool, + custom_llm_provider: str, + kwargs: Dict, + ) -> Tuple[bool, Dict]: + """ + Hook to determine if agentic loop should be executed. + + Called after receiving response from model, before returning to user. + + USE CASE: Enables transparent server-side tool execution for models that + don't natively support server-side tools. User makes ONE API call and gets + back the final answer - the agentic loop happens transparently on the server. + + Example use cases: + - WebSearch: Intercept WebSearch tool calls for Bedrock/Claude, execute + litellm.search(), return final answer with search results + - Code execution: Execute code in sandboxed environment, return results + - Database queries: Execute queries server-side, return data to model + - API calls: Make external API calls and inject responses back into context + + Flow: + 1. User calls litellm.messages.acreate(tools=[...]) + 2. Model responds with tool_use + 3. 
THIS HOOK checks if tool should run server-side + 4. If True, async_run_agentic_loop executes the tool + 5. User receives final answer (never sees intermediate tool_use) + + Args: + response: Response from model (AnthropicMessagesResponse or AsyncIterator) + model: Model name + messages: Original messages sent to model + tools: List of tool definitions from request + stream: Whether response is streaming + custom_llm_provider: Provider name (e.g., "bedrock", "anthropic") + kwargs: Additional request parameters + + Returns: + (should_run, tools): + should_run: True if agentic loop should execute + tools: Dict with tool_calls and metadata for execution + + Example: + # Detect WebSearch tool call + if has_websearch_tool_use(response): + return True, { + "tool_calls": extract_tool_calls(response), + "tool_type": "websearch" + } + return False, {} + """ + return False, {} + + async def async_run_agentic_loop( + self, + tools: Dict, + model: str, + messages: List[Dict], + response: Any, + anthropic_messages_provider_config: Any, + anthropic_messages_optional_request_params: Dict, + logging_obj: "LiteLLMLoggingObj", + stream: bool, + kwargs: Dict, + ) -> Any: + """ + Hook to execute agentic loop based on context from should_run hook. + + Called only if async_messages_should_run_agentic_loop returns True. + + USE CASE: Execute server-side tools and orchestrate the agentic loop to + return a complete answer to the user in a single API call. + + What to do here: + 1. Extract tool calls from tools dict + 2. Execute the tools (litellm.search, code execution, DB queries, etc.) + 3. Build assistant message with tool_use blocks + 4. Build user message with tool_result blocks containing results + 5. Make follow-up litellm.messages.acreate() call with results + 6. 
Return the final response + + Args: + tools: Dict from async_should_run_agentic_loop + Contains tool_calls and metadata + model: Model name + messages: Original messages sent to model + response: Original response from model (with tool_use) + anthropic_messages_provider_config: Provider config for making requests + anthropic_messages_optional_request_params: Request parameters (tools, etc.) + logging_obj: LiteLLM logging object + stream: Whether response is streaming + kwargs: Additional request parameters + + Returns: + Final response after executing agentic loop + (AnthropicMessagesResponse with final answer) + + Example: + # Extract tool calls + tool_calls = agentic_context["tool_calls"] + + # Execute searches in parallel + search_results = await asyncio.gather( + *[litellm.asearch(tc["input"]["query"]) for tc in tool_calls] + ) + + # Build messages with tool results + assistant_msg = {"role": "assistant", "content": [...tool_use blocks...]} + user_msg = {"role": "user", "content": [...tool_result blocks...]} + + # Make follow-up request + from litellm.anthropic_interface import messages + final_response = await messages.acreate( + model=model, + messages=messages + [assistant_msg, user_msg], + max_tokens=anthropic_messages_optional_request_params.get("max_tokens"), + **anthropic_messages_optional_request_params + ) + + return final_response + """ + pass + # Useful helpers for custom logger classes def truncate_standard_logging_payload_content( diff --git a/litellm/integrations/websearch_interception/ARCHITECTURE.md b/litellm/integrations/websearch_interception/ARCHITECTURE.md new file mode 100644 index 00000000000..345741c3c03 --- /dev/null +++ b/litellm/integrations/websearch_interception/ARCHITECTURE.md @@ -0,0 +1,182 @@ +# WebSearch Interception Architecture + +Server-side WebSearch tool execution for models that don't natively support it (e.g., Bedrock/Claude). 
+ +## How It Works + +User makes **ONE** `litellm.messages.acreate()` call → Gets final answer with search results. +The agentic loop happens transparently on the server. + +--- + +## Request Flow + +### Without Interception (Client-Side) +User manually handles tool execution: +1. User calls `litellm.messages.acreate()` → Gets `tool_use` response +2. User executes `litellm.asearch()` +3. User calls `litellm.messages.acreate()` again with results +4. User gets final answer + +**Result**: 2 API calls, manual tool execution + +### With Interception (Server-Side) +Server handles tool execution automatically: + +```mermaid +sequenceDiagram + participant User + participant Messages as litellm.messages.acreate() + participant Handler as llm_http_handler.py + participant Logger as WebSearchInterceptionLogger + participant Router as proxy_server.llm_router + participant Search as litellm.asearch() + participant Provider as Bedrock API + + User->>Messages: acreate(tools=[WebSearch]) + Messages->>Handler: async_anthropic_messages_handler() + Handler->>Provider: Request + Provider-->>Handler: Response (tool_use) + Handler->>Logger: async_should_run_agentic_loop() + Logger->>Logger: Detect WebSearch tool_use + Logger-->>Handler: (True, tools) + Handler->>Logger: async_run_agentic_loop(tools) + Logger->>Router: Get search_provider from search_tools + Router-->>Logger: search_provider + Logger->>Search: asearch(query, provider) + Search-->>Logger: Search results + Logger->>Logger: Build tool_result message + Logger->>Messages: acreate() with results + Messages->>Provider: Request with search results + Provider-->>Messages: Final answer + Messages-->>Logger: Final response + Logger-->>Handler: Final response + Handler-->>User: Final answer (with search results) +``` + +**Result**: 1 API call from user, server handles agentic loop + +--- + +## Key Components + +| Component | File | Purpose | +|-----------|------|---------| +| **WebSearchInterceptionLogger** | `handler.py` | 
CustomLogger that implements agentic loop hooks | +| **Transformation Logic** | `transformation.py` | Detect tool_use, build tool_result messages, format search responses | +| **Agentic Loop Hooks** | `integrations/custom_logger.py` | Base hooks: `async_should_run_agentic_loop()`, `async_run_agentic_loop()` | +| **Hook Orchestration** | `llms/custom_httpx/llm_http_handler.py` | `_call_agentic_completion_hooks()` - calls hooks after response | +| **Router Search Tools** | `proxy/proxy_server.py` | `llm_router.search_tools` - configured search providers | +| **Search Endpoints** | `proxy/search_endpoints/endpoints.py` | Router logic for selecting search provider | + +--- + +## Configuration + +```python +from litellm.integrations.websearch_interception import WebSearchInterceptionLogger +from litellm.types.utils import LlmProviders + +# Enable for Bedrock with specific search tool +litellm.callbacks = [ + WebSearchInterceptionLogger( + enabled_providers=[LlmProviders.BEDROCK], + search_tool_name="my-perplexity-tool" # Optional: uses router's first tool if None + ) +] + +# Make request (streaming or non-streaming both work) +response = await litellm.messages.acreate( + model="bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": "What is LiteLLM?"}], + tools=[{"name": "WebSearch", ...}], + max_tokens=1024, + stream=True # Streaming is automatically converted to non-streaming for WebSearch +) +``` + +--- + +## Streaming Support + +WebSearch interception works transparently with both streaming and non-streaming requests. + +**How streaming is handled:** +1. User makes request with `stream=True` and WebSearch tool +2. Before API call, `anthropic_messages()` detects WebSearch + interception enabled +3. Converts `stream=True` → `stream=False` internally +4. Agentic loop executes with non-streaming responses +5. 
Final response returned to user (non-streaming) + +**Why this approach:** +- Server-side agentic loops require consuming full responses to detect tool_use +- User opts into this behavior by enabling WebSearch interception +- Provides seamless experience without client changes + +**Testing:** +- **Non-streaming**: `test_websearch_interception_e2e.py` +- **Streaming**: `test_websearch_interception_streaming_e2e.py` + +--- + +## Search Provider Selection + +1. If `search_tool_name` specified → Look up in `llm_router.search_tools` +2. If not found or None → Use first available search tool +3. If no router or no tools → Fallback to `perplexity` + +Example router config: +```yaml +search_tools: + - search_tool_name: "my-perplexity-tool" + litellm_params: + search_provider: "perplexity" + - search_tool_name: "my-tavily-tool" + litellm_params: + search_provider: "tavily" +``` + +--- + +## Message Flow + +### Initial Request +```python +messages = [{"role": "user", "content": "What is LiteLLM?"}] +tools = [{"name": "WebSearch", ...}] +``` + +### First API Call (Internal) +**Response**: `tool_use` with `name="WebSearch"`, `input={"query": "what is litellm"}` + +### Server Processing +1. Logger detects WebSearch tool_use +2. Looks up search provider from router +3. Executes `litellm.asearch(query="what is litellm", search_provider="perplexity")` +4. Gets results: `"Title: LiteLLM Docs\nURL: docs.litellm.ai\n..."` + +### Follow-Up Request (Internal) +```python +messages = [ + {"role": "user", "content": "What is LiteLLM?"}, + {"role": "assistant", "content": [{"type": "tool_use", ...}]}, + {"role": "user", "content": [{"type": "tool_result", "content": "search results..."}]} +] +``` + +### User Receives +```python +response.content[0].text +# "Based on the search results, LiteLLM is a unified interface..." 
+``` + +--- + +## Testing + +**E2E Tests**: +- `test_websearch_interception_e2e.py` - Non-streaming real API calls to Bedrock +- `test_websearch_interception_streaming_e2e.py` - Streaming real API calls to Bedrock + +**Unit Tests**: `test_websearch_interception.py` +Mocked tests for tool detection, provider filtering, edge cases. diff --git a/litellm/integrations/websearch_interception/__init__.py b/litellm/integrations/websearch_interception/__init__.py new file mode 100644 index 00000000000..c0feb5235e2 --- /dev/null +++ b/litellm/integrations/websearch_interception/__init__.py @@ -0,0 +1,12 @@ +""" +WebSearch Interception Module + +Provides server-side WebSearch tool execution for models that don't natively +support server-side tool calling (e.g., Bedrock/Claude). +""" + +from litellm.integrations.websearch_interception.handler import ( + WebSearchInterceptionLogger, +) + +__all__ = ["WebSearchInterceptionLogger"] diff --git a/litellm/integrations/websearch_interception/handler.py b/litellm/integrations/websearch_interception/handler.py new file mode 100644 index 00000000000..0b08bc2312a --- /dev/null +++ b/litellm/integrations/websearch_interception/handler.py @@ -0,0 +1,422 @@ +""" +WebSearch Interception Handler + +CustomLogger that intercepts WebSearch tool calls for models that don't +natively support web search (e.g., Bedrock/Claude) and executes them +server-side using litellm router's search tools. 
+""" + +import asyncio +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +import litellm +from litellm._logging import verbose_logger +from litellm.anthropic_interface import messages as anthropic_messages +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.websearch_interception.transformation import ( + WebSearchTransformation, +) +from litellm.types.integrations.websearch_interception import ( + WebSearchInterceptionConfig, +) +from litellm.types.utils import LlmProviders + + +class WebSearchInterceptionLogger(CustomLogger): + """ + CustomLogger that intercepts WebSearch tool calls for models that don't + natively support web search. + + Implements agentic loop: + 1. Detects WebSearch tool_use in model response + 2. Executes litellm.asearch() for each query using router's search tools + 3. Makes follow-up request with search results + 4. Returns final response + """ + + def __init__( + self, + enabled_providers: Optional[List[Union[LlmProviders, str]]] = None, + search_tool_name: Optional[str] = None, + ): + """ + Args: + enabled_providers: List of LLM providers to enable interception for. + Use LlmProviders enum values (e.g., [LlmProviders.BEDROCK]) + Default: [LlmProviders.BEDROCK] + search_tool_name: Name of search tool configured in router's search_tools. + If None, will attempt to use first available search tool. + """ + super().__init__() + # Convert enum values to strings for comparison + if enabled_providers is None: + self.enabled_providers = [LlmProviders.BEDROCK.value] + else: + self.enabled_providers = [ + p.value if isinstance(p, LlmProviders) else p + for p in enabled_providers + ] + self.search_tool_name = search_tool_name + + @classmethod + def from_config_yaml( + cls, config: WebSearchInterceptionConfig + ) -> "WebSearchInterceptionLogger": + """ + Initialize WebSearchInterceptionLogger from proxy config.yaml parameters. 
+ + Args: + config: Configuration dictionary from litellm_settings.websearch_interception_params + + Returns: + Configured WebSearchInterceptionLogger instance + + Example: + From proxy_config.yaml: + litellm_settings: + websearch_interception_params: + enabled_providers: ["bedrock"] + search_tool_name: "my-perplexity-search" + + Usage: + config = litellm_settings.get("websearch_interception_params", {}) + logger = WebSearchInterceptionLogger.from_config_yaml(config) + """ + # Extract parameters from config + enabled_providers_str = config.get("enabled_providers", None) + search_tool_name = config.get("search_tool_name", None) + + # Convert string provider names to LlmProviders enum values + enabled_providers: Optional[List[Union[LlmProviders, str]]] = None + if enabled_providers_str is not None: + enabled_providers = [] + for provider in enabled_providers_str: + try: + # Try to convert string to LlmProviders enum + provider_enum = LlmProviders(provider) + enabled_providers.append(provider_enum) + except ValueError: + # If conversion fails, keep as string + enabled_providers.append(provider) + + return cls( + enabled_providers=enabled_providers, + search_tool_name=search_tool_name, + ) + + async def async_should_run_agentic_loop( + self, + response: Any, + model: str, + messages: List[Dict], + tools: Optional[List[Dict]], + stream: bool, + custom_llm_provider: str, + kwargs: Dict, + ) -> Tuple[bool, Dict]: + """Check if WebSearch tool interception is needed""" + + verbose_logger.debug(f"WebSearchInterception: Hook called! provider={custom_llm_provider}, stream={stream}") + verbose_logger.debug(f"WebSearchInterception: Response type: {type(response)}") + + # Check if provider should be intercepted + # Note: custom_llm_provider is already normalized by get_llm_provider() + # (e.g., "bedrock/invoke/..." 
-> "bedrock") + if custom_llm_provider not in self.enabled_providers: + verbose_logger.debug( + f"WebSearchInterception: Skipping provider {custom_llm_provider} (not in enabled list: {self.enabled_providers})" + ) + return False, {} + + # Check if tools include WebSearch + has_websearch_tool = any(t.get("name") == "WebSearch" for t in (tools or [])) + if not has_websearch_tool: + verbose_logger.debug( + "WebSearchInterception: No WebSearch tool in request" + ) + return False, {} + + # Detect WebSearch tool_use in response + should_intercept, tool_calls = WebSearchTransformation.transform_request( + response=response, + stream=stream, + ) + + if not should_intercept: + verbose_logger.debug( + "WebSearchInterception: No WebSearch tool_use detected in response" + ) + return False, {} + + verbose_logger.debug( + f"WebSearchInterception: Detected {len(tool_calls)} WebSearch tool call(s), executing agentic loop" + ) + + # Return tools dict with tool calls + tools_dict = { + "tool_calls": tool_calls, + "tool_type": "websearch", + "provider": custom_llm_provider, + } + return True, tools_dict + + async def async_run_agentic_loop( + self, + tools: Dict, + model: str, + messages: List[Dict], + response: Any, + anthropic_messages_provider_config: Any, + anthropic_messages_optional_request_params: Dict, + logging_obj: Any, + stream: bool, + kwargs: Dict, + ) -> Any: + """Execute agentic loop with WebSearch execution""" + + tool_calls = tools["tool_calls"] + + verbose_logger.debug( + f"WebSearchInterception: Executing agentic loop for {len(tool_calls)} search(es)" + ) + + return await self._execute_agentic_loop( + model=model, + messages=messages, + tool_calls=tool_calls, + anthropic_messages_optional_request_params=anthropic_messages_optional_request_params, + logging_obj=logging_obj, + stream=stream, + kwargs=kwargs, + ) + + async def _execute_agentic_loop( + self, + model: str, + messages: List[Dict], + tool_calls: List[Dict], + anthropic_messages_optional_request_params: 
Dict, + logging_obj: Any, + stream: bool, + kwargs: Dict, + ) -> Any: + """Execute litellm.search() and make follow-up request""" + + # Extract search queries from tool_use blocks + search_tasks = [] + for tool_call in tool_calls: + query = tool_call["input"].get("query") + if query: + verbose_logger.debug( + f"WebSearchInterception: Queuing search for query='{query}'" + ) + search_tasks.append(self._execute_search(query)) + else: + verbose_logger.warning( + f"WebSearchInterception: Tool call {tool_call['id']} has no query" + ) + # Add empty result for tools without query + search_tasks.append(self._create_empty_search_result()) + + # Execute searches in parallel + verbose_logger.debug( + f"WebSearchInterception: Executing {len(search_tasks)} search(es) in parallel" + ) + search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + # Handle any exceptions in search results + final_search_results: List[str] = [] + for i, result in enumerate(search_results): + if isinstance(result, Exception): + verbose_logger.error( + f"WebSearchInterception: Search {i} failed with error: {str(result)}" + ) + final_search_results.append( + f"Search failed: {str(result)}" + ) + elif isinstance(result, str): + # Explicitly cast to str for type checker + final_search_results.append(cast(str, result)) + else: + # Should never happen, but handle for type safety + verbose_logger.warning( + f"WebSearchInterception: Unexpected result type {type(result)} at index {i}" + ) + final_search_results.append(str(result)) + + # Build assistant and user messages using transformation + assistant_message, user_message = WebSearchTransformation.transform_response( + tool_calls=tool_calls, + search_results=final_search_results, + ) + + # Make follow-up request with search results + follow_up_messages = messages + [assistant_message, user_message] + + verbose_logger.debug( + "WebSearchInterception: Making follow-up request with search results" + ) + verbose_logger.debug( + 
f"WebSearchInterception: Follow-up messages count: {len(follow_up_messages)}" + ) + verbose_logger.debug( + f"WebSearchInterception: Last message (tool_result): {user_message}" + ) + + # Use anthropic_messages.acreate for follow-up request + try: + # Extract max_tokens from optional params or kwargs + # max_tokens is a required parameter for anthropic_messages.acreate() + max_tokens = anthropic_messages_optional_request_params.get( + "max_tokens", + kwargs.get("max_tokens", 1024) # Default to 1024 if not found + ) + + verbose_logger.debug( + f"WebSearchInterception: Using max_tokens={max_tokens} for follow-up request" + ) + + # Create a copy of optional params without max_tokens (since we pass it explicitly) + optional_params_without_max_tokens = { + k: v for k, v in anthropic_messages_optional_request_params.items() + if k != 'max_tokens' + } + + # Get model from logging_obj.model_call_details["agentic_loop_params"] + # This preserves the full model name with provider prefix (e.g., "bedrock/invoke/...") + full_model_name = model + if logging_obj is not None: + agentic_params = logging_obj.model_call_details.get("agentic_loop_params", {}) + full_model_name = agentic_params.get("model", model) + verbose_logger.debug( + f"WebSearchInterception: Using model name: {full_model_name}" + ) + + final_response = await anthropic_messages.acreate( + max_tokens=max_tokens, + messages=follow_up_messages, + model=full_model_name, + **optional_params_without_max_tokens, + **kwargs, + ) + verbose_logger.debug( + f"WebSearchInterception: Follow-up request completed, response type: {type(final_response)}" + ) + verbose_logger.debug( + f"WebSearchInterception: Final response: {final_response}" + ) + return final_response + except Exception as e: + verbose_logger.exception( + f"WebSearchInterception: Follow-up request failed: {str(e)}" + ) + raise + + async def _execute_search(self, query: str) -> str: + """Execute a single web search using router's search tools""" + try: + # Import 
router from proxy_server + try: + from litellm.proxy.proxy_server import llm_router + except ImportError: + verbose_logger.warning( + "WebSearchInterception: Could not import llm_router from proxy_server, " + "falling back to direct litellm.asearch() with perplexity" + ) + llm_router = None + + # Determine search provider from router's search_tools + search_provider: Optional[str] = None + if llm_router is not None and hasattr(llm_router, "search_tools"): + if self.search_tool_name: + # Find specific search tool by name + matching_tools = [ + tool for tool in llm_router.search_tools + if tool.get("search_tool_name") == self.search_tool_name + ] + if matching_tools: + search_tool = matching_tools[0] + search_provider = search_tool.get("litellm_params", {}).get("search_provider") + verbose_logger.debug( + f"WebSearchInterception: Found search tool '{self.search_tool_name}' " + f"with provider '{search_provider}'" + ) + else: + verbose_logger.warning( + f"WebSearchInterception: Search tool '{self.search_tool_name}' not found in router, " + "falling back to first available or perplexity" + ) + + # If no specific tool or not found, use first available + if not search_provider and llm_router.search_tools: + first_tool = llm_router.search_tools[0] + search_provider = first_tool.get("litellm_params", {}).get("search_provider") + verbose_logger.debug( + f"WebSearchInterception: Using first available search tool with provider '{search_provider}'" + ) + + # Fallback to perplexity if no router or no search tools configured + if not search_provider: + search_provider = "perplexity" + verbose_logger.debug( + "WebSearchInterception: No search tools configured in router, " + f"using default provider '{search_provider}'" + ) + + verbose_logger.debug( + f"WebSearchInterception: Executing search for '{query}' using provider '{search_provider}'" + ) + result = await litellm.asearch( + query=query, search_provider=search_provider + ) + + # Format using transformation function + 
search_result_text = WebSearchTransformation.format_search_response(result) + + verbose_logger.debug( + f"WebSearchInterception: Search completed for '{query}', got {len(search_result_text)} chars" + ) + return search_result_text + except Exception as e: + verbose_logger.error( + f"WebSearchInterception: Search failed for '{query}': {str(e)}" + ) + raise + + async def _create_empty_search_result(self) -> str: + """Create an empty search result for tool calls without queries""" + return "No search query provided" + + @staticmethod + def initialize_from_proxy_config( + litellm_settings: Dict[str, Any], + callback_specific_params: Dict[str, Any], + ) -> "WebSearchInterceptionLogger": + """ + Static method to initialize WebSearchInterceptionLogger from proxy config. + + Used in callback_utils.py to simplify initialization logic. + + Args: + litellm_settings: Dictionary containing litellm_settings from proxy_config.yaml + callback_specific_params: Dictionary containing callback-specific parameters + + Returns: + Configured WebSearchInterceptionLogger instance + + Example: + From callback_utils.py: + websearch_obj = WebSearchInterceptionLogger.initialize_from_proxy_config( + litellm_settings=litellm_settings, + callback_specific_params=callback_specific_params + ) + """ + # Get websearch_interception_params from litellm_settings or callback_specific_params + websearch_params: WebSearchInterceptionConfig = {} + if "websearch_interception_params" in litellm_settings: + websearch_params = litellm_settings["websearch_interception_params"] + elif "websearch_interception" in callback_specific_params: + websearch_params = callback_specific_params["websearch_interception"] + + # Use classmethod to initialize from config + return WebSearchInterceptionLogger.from_config_yaml(websearch_params) diff --git a/litellm/integrations/websearch_interception/transformation.py b/litellm/integrations/websearch_interception/transformation.py new file mode 100644 index 00000000000..e8211311281 
--- /dev/null +++ b/litellm/integrations/websearch_interception/transformation.py @@ -0,0 +1,184 @@ +""" +WebSearch Tool Transformation + +Transforms between Anthropic tool_use format and LiteLLM search format. +""" + +from typing import Any, Dict, List, Tuple + +from litellm._logging import verbose_logger +from litellm.llms.base_llm.search.transformation import SearchResponse + + +class WebSearchTransformation: + """ + Transformation class for WebSearch tool interception. + + Handles transformation between: + - Anthropic tool_use format → LiteLLM search requests + - LiteLLM SearchResponse → Anthropic tool_result format + """ + + @staticmethod + def transform_request( + response: Any, + stream: bool, + ) -> Tuple[bool, List[Dict]]: + """ + Transform Anthropic response to extract WebSearch tool calls. + + Detects if response contains WebSearch tool_use blocks and extracts + the search queries for execution. + + Args: + response: Model response (dict or AnthropicMessagesResponse) + stream: Whether response is streaming + + Returns: + (has_websearch, tool_calls): + has_websearch: True if WebSearch tool_use found + tool_calls: List of tool_use dicts with id, name, input + + Note: + Streaming requests are handled by converting stream=True to stream=False + in the WebSearchInterceptionLogger.async_log_pre_api_call hook before + the API request is made. This means by the time this method is called, + streaming requests have already been converted to non-streaming. 
+ """ + if stream: + # This should not happen in practice since we convert streaming to non-streaming + # in async_log_pre_api_call, but keep this check for safety + verbose_logger.warning( + "WebSearchInterception: Unexpected streaming response, skipping interception" + ) + return False, [] + + # Parse non-streaming response + return WebSearchTransformation._detect_from_non_streaming_response(response) + + @staticmethod + def _detect_from_non_streaming_response( + response: Any, + ) -> Tuple[bool, List[Dict]]: + """Parse non-streaming response for WebSearch tool_use""" + + # Handle both dict and object responses + if isinstance(response, dict): + content = response.get("content", []) + else: + if not hasattr(response, "content"): + verbose_logger.debug( + "WebSearchInterception: Response has no content attribute" + ) + return False, [] + content = response.content or [] + + if not content: + verbose_logger.debug( + "WebSearchInterception: Response has empty content" + ) + return False, [] + + # Find all WebSearch tool_use blocks + tool_calls = [] + for block in content: + # Handle both dict and object blocks + if isinstance(block, dict): + block_type = block.get("type") + block_name = block.get("name") + block_id = block.get("id") + block_input = block.get("input", {}) + else: + block_type = getattr(block, "type", None) + block_name = getattr(block, "name", None) + block_id = getattr(block, "id", None) + block_input = getattr(block, "input", {}) + + if block_type == "tool_use" and block_name == "WebSearch": + # Convert to dict for easier handling + tool_call = { + "id": block_id, + "type": "tool_use", + "name": "WebSearch", + "input": block_input, + } + tool_calls.append(tool_call) + verbose_logger.debug( + f"WebSearchInterception: Found WebSearch tool_use with id={tool_call['id']}" + ) + + return len(tool_calls) > 0, tool_calls + + @staticmethod + def transform_response( + tool_calls: List[Dict], + search_results: List[str], + ) -> Tuple[Dict, Dict]: + """ + 
Transform LiteLLM search results to Anthropic tool_result format. + + Builds the assistant and user messages needed for the agentic loop + follow-up request. + + Args: + tool_calls: List of tool_use dicts from transform_request + search_results: List of search result strings (one per tool_call) + + Returns: + (assistant_message, user_message): + assistant_message: Message with tool_use blocks + user_message: Message with tool_result blocks + """ + # Build assistant message with tool_use blocks + assistant_message = { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": tc["id"], + "name": tc["name"], + "input": tc["input"], + } + for tc in tool_calls + ], + } + + # Build user message with tool_result blocks + user_message = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_calls[i]["id"], + "content": search_results[i], + } + for i in range(len(tool_calls)) + ], + } + + return assistant_message, user_message + + @staticmethod + def format_search_response(result: SearchResponse) -> str: + """ + Format SearchResponse as text for tool_result content. 
+ + Args: + result: SearchResponse from litellm.asearch() + + Returns: + Formatted text with Title, URL, Snippet for each result + """ + # Convert SearchResponse to string + if hasattr(result, "results") and result.results: + # Format results as text + search_result_text = "\n\n".join( + [ + f"Title: {r.title}\nURL: {r.url}\nSnippet: {r.snippet}" + for r in result.results + ] + ) + else: + search_result_text = str(result) + + return search_result_text diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py index 908b46c11e2..11245b1bdba 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -57,6 +57,38 @@ async def anthropic_messages( """ Async: Make llm api request in Anthropic /messages API spec """ + # WebSearch Interception: Convert stream=True to stream=False if WebSearch interception is enabled + # This allows transparent server-side agentic loop execution for streaming requests + if stream and tools and any(t.get("name") == "WebSearch" for t in tools): + # Extract provider using litellm's helper function + try: + _, provider, _, _ = litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + except Exception: + # Fallback to simple split if helper fails + provider = model.split("/")[0] if "/" in model else "" + + # Check if WebSearch interception is enabled in callbacks + from litellm._logging import verbose_logger + from litellm.integrations.websearch_interception import ( + WebSearchInterceptionLogger, + ) + if litellm.callbacks: + for callback in litellm.callbacks: + if isinstance(callback, WebSearchInterceptionLogger): + # Check if provider is enabled for interception + if provider in callback.enabled_providers: + verbose_logger.debug( + f"WebSearchInterception: Converting stream=True 
to stream=False for WebSearch interception " + f"(provider={provider})" + ) + stream = False + break + local_vars = locals() loop = asyncio.get_event_loop() kwargs["is_async"] = True @@ -145,6 +177,10 @@ def anthropic_messages_handler( # Use provided client or create a new one litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore + # Store original model name before get_llm_provider strips the provider prefix + # This is needed by agentic hooks (e.g., websearch_interception) to make follow-up requests + original_model = model + litellm_params = GenericLiteLLMParams( **kwargs, api_key=api_key, @@ -162,6 +198,14 @@ def anthropic_messages_handler( api_base=litellm_params.api_base, api_key=litellm_params.api_key, ) + + # Store agentic loop params in logging object for agentic hooks + # This provides original request context needed for follow-up calls + if litellm_logging_obj is not None: + litellm_logging_obj.model_call_details["agentic_loop_params"] = { + "model": original_model, + "custom_llm_provider": custom_llm_provider, + } if litellm_params.mock_response and isinstance(litellm_params.mock_response, str): diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 2f6d74eb7a7..490786155c6 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -1929,6 +1929,7 @@ async def async_anthropic_messages_handler( # used for logging + cost tracking logging_obj.model_call_details["httpx_response"] = response + initial_response: Union[AsyncIterator, AnthropicMessagesResponse] if stream: completion_stream = anthropic_messages_provider_config.get_async_streaming_response_iterator( model=model, @@ -1936,14 +1937,29 @@ async def async_anthropic_messages_handler( request_body=request_body, litellm_logging_obj=logging_obj, ) - return completion_stream + initial_response = completion_stream else: - return 
anthropic_messages_provider_config.transform_anthropic_messages_response( + model=model, + raw_response=response, + logging_obj=logging_obj, ) + # Call agentic completion hooks + final_response = await self._call_agentic_completion_hooks( + response=initial_response, + model=model, + messages=messages, + anthropic_messages_provider_config=anthropic_messages_provider_config, + anthropic_messages_optional_request_params=anthropic_messages_optional_request_params, + logging_obj=logging_obj, + stream=stream or False, + custom_llm_provider=custom_llm_provider, + kwargs=kwargs, + ) + + return final_response if final_response is not None else initial_response + def anthropic_messages_handler( self, model: str, @@ -4334,6 +4350,76 @@ def _prepare_fake_stream_request( return stream, data return stream, data + async def _call_agentic_completion_hooks( + self, + response: Any, + model: str, + messages: List[Dict], + anthropic_messages_provider_config: "BaseAnthropicMessagesConfig", + anthropic_messages_optional_request_params: Dict, + logging_obj: "LiteLLMLoggingObj", + stream: bool, + custom_llm_provider: str, + kwargs: Dict, + ) -> Optional[Any]: + """ + Call agentic completion hooks for all custom loggers. + + 1. Call async_should_run_agentic_loop to check if agentic loop is needed + 2. If yes, call async_run_agentic_loop to execute the loop + + Returns the response from agentic loop, or None if no hook runs. 
+ """ + from litellm._logging import verbose_logger + from litellm.integrations.custom_logger import CustomLogger + + callbacks = litellm.callbacks + ( + logging_obj.dynamic_success_callbacks or [] + ) + tools = anthropic_messages_optional_request_params.get("tools", []) + + for callback in callbacks: + try: + if isinstance(callback, CustomLogger): + # First: Check if agentic loop should run + should_run, tool_calls = ( + await callback.async_should_run_agentic_loop( + response=response, + model=model, + messages=messages, + tools=tools, + stream=stream, + custom_llm_provider=custom_llm_provider, + kwargs=kwargs, + ) + ) + + if should_run: + # Second: Execute agentic loop + # Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name + kwargs_with_provider = kwargs.copy() if kwargs else {} + kwargs_with_provider["custom_llm_provider"] = custom_llm_provider + agentic_response = await callback.async_run_agentic_loop( + tools=tool_calls, + model=model, + messages=messages, + response=response, + anthropic_messages_provider_config=anthropic_messages_provider_config, + anthropic_messages_optional_request_params=anthropic_messages_optional_request_params, + logging_obj=logging_obj, + stream=stream, + kwargs=kwargs_with_provider, + ) + # First hook that runs agentic loop wins + return agentic_response + + except Exception as e: + verbose_logger.exception( + f"LiteLLM.AgenticHookError: Exception in agentic completion hooks: {str(e)}" + ) + + return None + def _handle_error( self, e: Exception, diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 52f7f227b52..0d3e61b75c7 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -749,19 +749,25 @@ async def base_process_llm_request( headers=custom_headers, ) elif route_type == "anthropic_messages": - selected_data_generator = ( - ProxyBaseLLMRequestProcessing.async_sse_data_generator( - 
response=response, - user_api_key_dict=user_api_key_dict, - request_data=self.data, - proxy_logging_obj=proxy_logging_obj, + # Check if response is actually a streaming response (async generator) + # Non-streaming responses (dict) should be returned directly + # This handles cases like websearch_interception agentic loop + # which returns a non-streaming dict even for streaming requests + if self._is_streaming_response(response): + selected_data_generator = ( + ProxyBaseLLMRequestProcessing.async_sse_data_generator( + response=response, + user_api_key_dict=user_api_key_dict, + request_data=self.data, + proxy_logging_obj=proxy_logging_obj, + ) ) - ) - return await create_response( - generator=selected_data_generator, - media_type="text/event-stream", - headers=custom_headers, - ) + return await create_response( + generator=selected_data_generator, + media_type="text/event-stream", + headers=custom_headers, + ) + # Non-streaming response - fall through to normal response handling elif select_data_generator: selected_data_generator = select_data_generator( response=response, diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 76c54332fa3..cb434da55b3 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -269,6 +269,16 @@ def initialize_callbacks_on_proxy( # noqa: PLR0915 **azure_content_safety_params, ) imported_list.append(azure_content_safety_obj) + elif isinstance(callback, str) and callback == "websearch_interception": + from litellm.integrations.websearch_interception.handler import ( + WebSearchInterceptionLogger, + ) + + websearch_interception_obj = WebSearchInterceptionLogger.initialize_from_proxy_config( + litellm_settings=litellm_settings, + callback_specific_params=callback_specific_params, + ) + imported_list.append(websearch_interception_obj) elif isinstance(callback, CustomLogger): imported_list.append(callback) else: diff --git 
a/litellm/proxy/example_config_yaml/websearch_interception_config.yaml b/litellm/proxy/example_config_yaml/websearch_interception_config.yaml new file mode 100644 index 00000000000..2c1cd623c30 --- /dev/null +++ b/litellm/proxy/example_config_yaml/websearch_interception_config.yaml @@ -0,0 +1,16 @@ +model_list: + - model_name: claude-3-5-sonnet + litellm_params: + model: bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0 + +# Search tools configuration +search_tools: + - search_tool_name: "my-perplexity-search" + litellm_params: + search_provider: "perplexity" + +litellm_settings: + success_callback: ["websearch_interception"] + websearch_interception_params: + enabled_providers: ["bedrock"] + search_tool_name: "my-perplexity-search" diff --git a/litellm/types/integrations/websearch_interception.py b/litellm/types/integrations/websearch_interception.py new file mode 100644 index 00000000000..d8a36169b88 --- /dev/null +++ b/litellm/types/integrations/websearch_interception.py @@ -0,0 +1,23 @@ +""" +Type definitions for WebSearch Interception integration. +""" + +from typing import List, Optional, TypedDict + + +class WebSearchInterceptionConfig(TypedDict, total=False): + """ + Configuration parameters for WebSearchInterceptionLogger. + + Used in proxy_config.yaml under litellm_settings: + litellm_settings: + websearch_interception_params: + enabled_providers: ["bedrock"] + search_tool_name: "my-perplexity-search" + """ + + enabled_providers: List[str] + """List of LLM provider names to enable interception for (e.g., ['bedrock', 'vertex_ai'])""" + + search_tool_name: Optional[str] + """Name of search tool configured in router's search_tools. 
If None, uses first available.""" diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 8301a6da2d9..8c30e4d7e80 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -109,6 +109,20 @@ class SearchContextCostPerQuery(TypedDict, total=False): search_context_size_high: float +class AgenticLoopParams(TypedDict, total=False): + """ + Parameters passed to agentic loop hooks (e.g., WebSearch interception). + + Stored in logging_obj.model_call_details["agentic_loop_params"] to provide + agentic hooks with the original request context needed for follow-up calls. + """ + model: str + """The model string with provider prefix (e.g., 'bedrock/invoke/...')""" + + custom_llm_provider: str + """The LLM provider name (e.g., 'bedrock', 'anthropic')""" + + class ModelInfoBase(ProviderSpecificModelInfo, total=False): key: Required[str] # the key in litellm.model_cost which is returned diff --git a/tests/pass_through_unit_tests/test_websearch_interception_e2e.py b/tests/pass_through_unit_tests/test_websearch_interception_e2e.py new file mode 100644 index 00000000000..2dec9da8b70 --- /dev/null +++ b/tests/pass_through_unit_tests/test_websearch_interception_e2e.py @@ -0,0 +1,325 @@ +""" +Real E2E Tests for WebSearch Interception + +Makes actual calls to test WebSearch interception with Perplexity. +Tests both streaming and non-streaming requests. +""" + +import os +import sys + +sys.path.insert(0, os.path.abspath("../..")) + +import litellm +from litellm.integrations.websearch_interception import ( + WebSearchInterceptionLogger, +) +from litellm.anthropic_interface import messages +from litellm.types.utils import LlmProviders + + +async def test_websearch_interception_non_streaming(): + """ + Test WebSearch interception with non-streaming request. + Validates that agentic loop executes transparently. 
+ """ + litellm._turn_on_debug() + + print("\n" + "="*80) + print("E2E TEST 1: WebSearch Interception (Non-Streaming)") + print("="*80) + + # Initialize real router with search_tools configuration + import litellm.proxy.proxy_server as proxy_server + from litellm import Router + + # Create real router with search_tools + router = Router( + search_tools=[ + { + "search_tool_name": "my-perplexity-search", + "litellm_params": { + "search_provider": "perplexity" + } + } + ] + ) + proxy_server.llm_router = router + + print("\n✅ Initialized router with search_tools:") + print(f" - search_tool_name: my-perplexity-search") + print(f" - search_provider: perplexity") + + # Enable WebSearch interception for bedrock + websearch_logger = WebSearchInterceptionLogger( + enabled_providers=[LlmProviders.BEDROCK], + search_tool_name="my-perplexity-search", + ) + litellm.callbacks = [websearch_logger] + litellm.set_verbose = True + + print("\n✅ Configured WebSearch interception for Bedrock") + print("✅ Will use search tool from router") + + try: + # Make request with WebSearch tool (non-streaming) + print("\n📞 Making litellm.messages.acreate() call...") + print(f" Model: bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0") + print(f" Query: 'What is LiteLLM?'") + print(f" Tools: WebSearch") + print(f" Stream: False") + + response = await messages.acreate( + model="bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": "What is LiteLLM? 
Give me a brief overview."}], + tools=[ + { + "name": "WebSearch", + "description": "Search the web for information", + "input_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query", + } + }, + "required": ["query"], + }, + } + ], + max_tokens=1024, + stream=False, + ) + + print("\n✅ Received response!") + + # Handle both dict and object responses + if isinstance(response, dict): + response_id = response.get("id") + response_model = response.get("model") + response_stop_reason = response.get("stop_reason") + response_content = response.get("content", []) + else: + response_id = response.id + response_model = response.model + response_stop_reason = response.stop_reason + response_content = response.content + + print(f"\n📄 Response ID: {response_id}") + print(f"📄 Model: {response_model}") + print(f"📄 Stop Reason: {response_stop_reason}") + print(f"📄 Content blocks: {len(response_content)}") + + # Debug: Print all content block types + for i, block in enumerate(response_content): + block_type = block.get("type") if isinstance(block, dict) else block.type + print(f" Block {i}: type={block_type}") + if block_type == "tool_use": + block_name = block.get("name") if isinstance(block, dict) else block.name + print(f" name={block_name}") + + # Validate response + assert response is not None, "Response should not be None" + assert response_content is not None, "Response should have content" + assert len(response_content) > 0, "Response should have at least one content block" + + # Check if response contains tool_use (means interception didn't work) + has_tool_use = any( + (block.get("type") if isinstance(block, dict) else block.type) == "tool_use" + for block in response_content + ) + + # Check if we got a text response + has_text = any( + (block.get("type") if isinstance(block, dict) else block.type) == "text" + for block in response_content + ) + + if has_tool_use: + print("\n❌ TEST 1 FAILED: Interception did 
not work") + print(f"❌ Stop reason: {response_stop_reason}") + print("❌ Response contains tool_use blocks") + return False + + elif has_text and response_stop_reason != "tool_use": + text_block = next( + block for block in response_content + if (block.get("type") if isinstance(block, dict) else block.type) == "text" + ) + text_content = text_block.get("text") if isinstance(text_block, dict) else text_block.text + + print(f"\n📝 Response Text:") + print(f" {text_content[:200]}...") + + if "litellm" in text_content.lower(): + print("\n" + "="*80) + print("✅ TEST 1 PASSED!") + print("="*80) + print("✅ User made ONE litellm.messages.acreate() call") + print("✅ Got back final answer (not tool_use)") + print("✅ Agentic loop executed transparently") + print("✅ WebSearch interception working!") + print("="*80) + return True + else: + print("\n⚠️ Got text response but doesn't mention LiteLLM") + return False + else: + print("\n❌ Unexpected response format") + return False + + except Exception as e: + print(f"\n❌ Test 1 failed with error: {str(e)}") + import traceback + traceback.print_exc() + return False + + +async def test_websearch_interception_streaming(): + """ + Test WebSearch interception with streaming request. + Validates that stream=True is converted to stream=False transparently. 
+ """ + print("\n" + "="*80) + print("E2E TEST 2: WebSearch Interception (Streaming)") + print("="*80) + + # Router already initialized from test 1 + print("\n✅ Using existing router configuration") + print("✅ WebSearch interception already enabled for Bedrock") + print("✅ Streaming will be converted to non-streaming for WebSearch interception") + + try: + # Make request with WebSearch tool AND stream=True + print("\n📞 Making litellm.messages.acreate() call with stream=True...") + print(f" Model: bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0") + print(f" Query: 'What is LiteLLM?'") + print(f" Tools: WebSearch") + print(f" Stream: True (will be converted to False)") + + response = await messages.acreate( + model="bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": "What is LiteLLM? Give me a brief overview."}], + tools=[ + { + "name": "WebSearch", + "description": "Search the web for information", + "input_schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query", + } + }, + "required": ["query"], + }, + } + ], + max_tokens=1024, + stream=True, # REQUEST STREAMING + ) + + print("\n✅ Received response!") + + # Check if response is actually a stream (async generator) + import inspect + is_stream = inspect.isasyncgen(response) + + if is_stream: + print("\n⚠️ WARNING: Response is a stream (async_generator)") + print("⚠️ This means stream conversion didn't work!") + print("\n📦 Consuming stream chunks:") + + chunks = [] + chunk_count = 0 + async for chunk in response: + chunk_count += 1 + print(f"\n--- Chunk {chunk_count} ---") + print(chunk) + chunks.append(chunk) + + print(f"\n❌ TEST 2 FAILED: Got {len(chunks)} stream chunks instead of single response") + return False + + # If not a stream, validate as normal response + print("✅ Response is NOT a stream (conversion worked!)") + + # Handle both dict and object responses + if isinstance(response, dict): + 
response_id = response.get("id") + response_model = response.get("model") + response_stop_reason = response.get("stop_reason") + response_content = response.get("content", []) + else: + response_id = response.id + response_model = response.model + response_stop_reason = response.stop_reason + response_content = response.content + + print(f"\n📄 Response ID: {response_id}") + print(f"📄 Model: {response_model}") + print(f"📄 Stop Reason: {response_stop_reason}") + print(f"📄 Content blocks: {len(response_content)}") + + # Debug: Print all content block types + for i, block in enumerate(response_content): + block_type = block.get("type") if isinstance(block, dict) else block.type + print(f" Block {i}: type={block_type}") + + # Validate response + assert response is not None, "Response should not be None" + assert response_content is not None, "Response should have content" + assert len(response_content) > 0, "Response should have at least one content block" + + # Check if response contains tool_use (means interception didn't work) + has_tool_use = any( + (block.get("type") if isinstance(block, dict) else block.type) == "tool_use" + for block in response_content + ) + + # Check if we got a text response + has_text = any( + (block.get("type") if isinstance(block, dict) else block.type) == "text" + for block in response_content + ) + + if has_tool_use: + print("\n❌ TEST 2 FAILED: Interception did not work") + print("❌ Response contains tool_use blocks") + return False + + elif has_text and response_stop_reason != "tool_use": + text_block = next( + block for block in response_content + if (block.get("type") if isinstance(block, dict) else block.type) == "text" + ) + text_content = text_block.get("text") if isinstance(text_block, dict) else text_block.text + + print(f"\n📝 Response Text:") + print(f" {text_content[:200]}...") + + if "litellm" in text_content.lower(): + print("\n" + "="*80) + print("✅ TEST 2 PASSED!") + print("="*80) + print("✅ User made ONE 
litellm.messages.acreate() call with stream=True") + print("✅ Stream was transparently converted to non-streaming") + print("✅ Got back final answer (not tool_use)") + print("✅ Agentic loop executed transparently") + print("✅ WebSearch interception working with streaming!") + print("="*80) + return True + else: + print("\n⚠️ Got text response but doesn't mention LiteLLM") + return False + else: + print("\n❌ Unexpected response format") + return False + + except Exception as e: + print(f"\n❌ Test 2 failed with error: {str(e)}") + import traceback + traceback.print_exc() + return False diff --git a/tests/test_litellm/integrations/websearch_interception/test_handler.py b/tests/test_litellm/integrations/websearch_interception/test_handler.py new file mode 100644 index 00000000000..8ac53315aa0 --- /dev/null +++ b/tests/test_litellm/integrations/websearch_interception/test_handler.py @@ -0,0 +1,69 @@ +""" +Unit tests for WebSearch Interception Handler + +Tests the WebSearchInterceptionLogger class and helper functions. 
+""" + +from unittest.mock import Mock + +import pytest + +from litellm.integrations.websearch_interception.handler import ( + WebSearchInterceptionLogger, +) +from litellm.types.utils import LlmProviders + + +def test_initialize_from_proxy_config(): + """Test initialization from proxy config with litellm_settings""" + litellm_settings = { + "websearch_interception_params": { + "enabled_providers": ["bedrock", "vertex_ai"], + "search_tool_name": "my-search", + } + } + callback_specific_params = {} + + logger = WebSearchInterceptionLogger.initialize_from_proxy_config( + litellm_settings=litellm_settings, + callback_specific_params=callback_specific_params, + ) + + assert LlmProviders.BEDROCK.value in logger.enabled_providers + assert LlmProviders.VERTEX_AI.value in logger.enabled_providers + assert logger.search_tool_name == "my-search" + + +@pytest.mark.asyncio +async def test_async_should_run_agentic_loop(): + """Test that agentic loop is NOT triggered for wrong provider or missing WebSearch tool""" + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + # Test 1: Wrong provider (not in enabled_providers) + response = Mock() + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="gpt-4", + messages=[], + tools=[{"name": "WebSearch"}], + stream=False, + custom_llm_provider="openai", # Not in enabled_providers + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + # Test 2: No WebSearch tool in request + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=[{"name": "SomeOtherTool"}], # No WebSearch + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {}