Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 90 additions & 42 deletions litellm/integrations/websearch_interception/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
import litellm
from litellm._logging import verbose_logger
from litellm.anthropic_interface import messages as anthropic_messages
from litellm.constants import LITELLM_WEB_SEARCH_TOOL_NAME
from litellm.constants import DEFAULT_MAX_TOKENS, LITELLM_WEB_SEARCH_TOOL_NAME
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import filter_internal_params
from litellm.integrations.websearch_interception.tools import (
get_litellm_web_search_tool,
is_web_search_tool,
is_web_search_tool_chat_completion,
)
from litellm.integrations.websearch_interception.transformation import (
ResponseFormat,
WebSearchTransformation,
)
from litellm.types.integrations.websearch_interception import (
Expand Down Expand Up @@ -76,8 +78,9 @@ async def async_pre_call_deployment_hook(
Instead, we convert it to a regular tool so the model returns tool_use blocks
that we can intercept and execute ourselves.
"""
# Check if this is for an enabled provider
# Get provider from litellm_params (set by router in _add_deployment)
custom_llm_provider = kwargs.get("litellm_params", {}).get("custom_llm_provider", "")

if custom_llm_provider not in self.enabled_providers:
return None

Expand Down Expand Up @@ -111,8 +114,9 @@ async def async_pre_call_deployment_hook(
# Keep other tools as-is
converted_tools.append(tool)

# Return modified kwargs with converted tools
return {"tools": converted_tools}
# Return full kwargs with modified tools - spread preserves all other
# parameters (model, messages, etc.) for the pre_api_call hook contract
return {**kwargs, "tools": converted_tools}

@classmethod
def from_config_yaml(
Expand Down Expand Up @@ -275,29 +279,32 @@ async def async_should_run_agentic_loop(
return False, {}

# Detect WebSearch tool_use in response (Anthropic format)
should_intercept, tool_calls = WebSearchTransformation.transform_request(
transformed = WebSearchTransformation.transform_request(
response=response,
stream=stream,
response_format="anthropic",
)

if not should_intercept:
if not transformed.has_websearch:
verbose_logger.debug(
"WebSearchInterception: No WebSearch tool_use detected in response"
)
return False, {}

verbose_logger.debug(
f"WebSearchInterception: Detected {len(tool_calls)} WebSearch tool call(s), executing agentic loop"
f"WebSearchInterception: Detected {len(transformed.tool_calls)} WebSearch tool call(s), "
f"{len(transformed.thinking_blocks)} thinking block(s), executing agentic loop"
)

# Return tools dict with tool calls
# Return tools dict with tool calls and thinking blocks (if any)
tools_dict = {
"tool_calls": tool_calls,
"tool_calls": transformed.tool_calls,
"tool_type": "websearch",
"provider": custom_llm_provider,
"response_format": "anthropic",
}
if transformed.thinking_blocks:
tools_dict["thinking_blocks"] = transformed.thinking_blocks
return True, tools_dict

async def async_should_run_chat_completion_agentic_loop(
Expand Down Expand Up @@ -335,29 +342,32 @@ async def async_should_run_chat_completion_agentic_loop(
return False, {}

# Detect WebSearch tool_calls in response (OpenAI format)
should_intercept, tool_calls = WebSearchTransformation.transform_request(
transformed = WebSearchTransformation.transform_request(
response=response,
stream=stream,
response_format="openai",
)

if not should_intercept:
if not transformed.has_websearch:
verbose_logger.debug(
"WebSearchInterception: No WebSearch tool_calls detected in response"
)
return False, {}

verbose_logger.debug(
f"WebSearchInterception: Detected {len(tool_calls)} WebSearch tool call(s), executing agentic loop"
f"WebSearchInterception: Detected {len(transformed.tool_calls)} WebSearch tool call(s), "
f"{len(transformed.thinking_blocks)} thinking block(s), executing agentic loop"
)

# Return tools dict with tool calls
# Return tools dict with tool calls and thinking blocks (if any)
tools_dict = {
"tool_calls": tool_calls,
"tool_calls": transformed.tool_calls,
"tool_type": "websearch",
"provider": custom_llm_provider,
"response_format": "openai",
}
if transformed.thinking_blocks:
tools_dict["thinking_blocks"] = transformed.thinking_blocks
return True, tools_dict

async def async_run_agentic_loop(
Expand All @@ -379,6 +389,7 @@ async def async_run_agentic_loop(
"""

tool_calls = tools["tool_calls"]
thinking_blocks = tools.get("thinking_blocks", [])

verbose_logger.debug(
f"WebSearchInterception: Executing agentic loop for {len(tool_calls)} search(es)"
Expand All @@ -388,6 +399,7 @@ async def async_run_agentic_loop(
model=model,
messages=messages,
tool_calls=tool_calls,
thinking_blocks=thinking_blocks,
anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
logging_obj=logging_obj,
stream=stream,
Expand Down Expand Up @@ -429,19 +441,8 @@ async def async_run_chat_completion_agentic_loop(
response_format=response_format,
)

async def _execute_agentic_loop(
self,
model: str,
messages: List[Dict],
tool_calls: List[Dict],
anthropic_messages_optional_request_params: Dict,
logging_obj: Any,
stream: bool,
kwargs: Dict,
) -> Any:
"""Execute litellm.search() and make follow-up request"""

# Extract search queries from tool_use blocks
async def _execute_searches(self, tool_calls: List[Dict]) -> List[str]:
"""Execute search queries from tool_use blocks in parallel and return results."""
search_tasks = []
for tool_call in tool_calls:
query = tool_call["input"].get("query")
Expand All @@ -454,39 +455,50 @@ async def _execute_agentic_loop(
verbose_logger.warning(
f"WebSearchInterception: Tool call {tool_call['id']} has no query"
)
# Add empty result for tools without query
search_tasks.append(self._create_empty_search_result())

# Execute searches in parallel
verbose_logger.debug(
f"WebSearchInterception: Executing {len(search_tasks)} search(es) in parallel"
)
search_results = await asyncio.gather(*search_tasks, return_exceptions=True)

# Handle any exceptions in search results
final_search_results: List[str] = []
for i, result in enumerate(search_results):
if isinstance(result, Exception):
verbose_logger.error(
f"WebSearchInterception: Search {i} failed with error: {str(result)}"
)
final_search_results.append(
f"Search failed: {str(result)}"
)
final_search_results.append(f"Search failed: {str(result)}")
elif isinstance(result, str):
# Explicitly cast to str for type checker
final_search_results.append(cast(str, result))
else:
# Should never happen, but handle for type safety
verbose_logger.warning(
f"WebSearchInterception: Unexpected result type {type(result)} at index {i}"
)
final_search_results.append(str(result))
return final_search_results

async def _execute_agentic_loop(
self,
model: str,
messages: List[Dict],
tool_calls: List[Dict],
thinking_blocks: List[Dict],
anthropic_messages_optional_request_params: Dict,
logging_obj: Any,
stream: bool,
kwargs: Dict,
) -> Any:
"""Execute litellm.search() and make follow-up request"""

final_search_results = await self._execute_searches(tool_calls)

# Build assistant and user messages using transformation
# Include thinking_blocks to satisfy Anthropic's thinking mode requirements
assistant_message, user_message = WebSearchTransformation.transform_response(
tool_calls=tool_calls,
search_results=final_search_results,
thinking_blocks=thinking_blocks,
)

# Make follow-up request with search results
Expand All @@ -512,6 +524,26 @@ async def _execute_agentic_loop(
kwargs.get("max_tokens", 1024) # Default to 1024 if not found
)

# Validate and adjust max_tokens if needed to meet Anthropic's requirement
# Anthropic requires: max_tokens > thinking.budget_tokens
if "thinking" in anthropic_messages_optional_request_params:
thinking_param = anthropic_messages_optional_request_params.get("thinking", {})
if isinstance(thinking_param, dict) and thinking_param.get("type") == "enabled":
budget_tokens = thinking_param.get("budget_tokens", 0)

# Check if adjustment is needed
if budget_tokens > 0 and max_tokens <= budget_tokens:
# Use a formula that ensures sufficient tokens for response
# Follow pattern from litellm/llms/base_llm/chat/transformation.py
original_max_tokens = max_tokens
max_tokens = budget_tokens + DEFAULT_MAX_TOKENS

verbose_logger.warning(
f"WebSearchInterception: max_tokens ({original_max_tokens}) <= budget_tokens ({budget_tokens}). "
f"Adjusting max_tokens to {max_tokens} (budget_tokens + DEFAULT_MAX_TOKENS={DEFAULT_MAX_TOKENS}) "
f"to meet Anthropic's requirement"
)

verbose_logger.debug(
f"WebSearchInterception: Using max_tokens={max_tokens} for follow-up request"
)
Expand All @@ -524,9 +556,14 @@ async def _execute_agentic_loop(

# Remove internal websearch interception flags from kwargs before follow-up request
# These flags are used internally and should not be passed to the LLM provider
kwargs_for_followup = filter_internal_params(kwargs)

# Remove keys already present in optional_params or passed explicitly to avoid
# "got multiple values for keyword argument" errors (e.g. context_management)
explicit_keys = {"max_tokens", "messages", "model"}
kwargs_for_followup = {
k: v for k, v in kwargs.items()
if not k.startswith('_websearch_interception')
k: v for k, v in kwargs_for_followup.items()
if k not in optional_params_without_max_tokens and k not in explicit_keys
}

# Get model from logging_obj.model_call_details["agentic_loop_params"]
Expand Down Expand Up @@ -572,8 +609,10 @@ async def _execute_search(self, query: str) -> str:
)
llm_router = None

# Determine search provider from router's search_tools
# Determine search provider and credentials from router's search_tools
search_provider: Optional[str] = None
api_key: Optional[str] = None
api_base: Optional[str] = None
if llm_router is not None and hasattr(llm_router, "search_tools"):
if self.search_tool_name:
# Find specific search tool by name
Expand All @@ -583,7 +622,10 @@ async def _execute_search(self, query: str) -> str:
]
if matching_tools:
search_tool = matching_tools[0]
search_provider = search_tool.get("litellm_params", {}).get("search_provider")
litellm_params = search_tool.get("litellm_params", {})
search_provider = litellm_params.get("search_provider")
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
verbose_logger.debug(
f"WebSearchInterception: Found search tool '{self.search_tool_name}' "
f"with provider '{search_provider}'"
Expand All @@ -597,7 +639,10 @@ async def _execute_search(self, query: str) -> str:
# If no specific tool or not found, use first available
if not search_provider and llm_router.search_tools:
first_tool = llm_router.search_tools[0]
search_provider = first_tool.get("litellm_params", {}).get("search_provider")
litellm_params = first_tool.get("litellm_params", {})
search_provider = litellm_params.get("search_provider")
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
verbose_logger.debug(
f"WebSearchInterception: Using first available search tool with provider '{search_provider}'"
)
Expand All @@ -614,7 +659,10 @@ async def _execute_search(self, query: str) -> str:
f"WebSearchInterception: Executing search for '{query}' using provider '{search_provider}'"
)
result = await litellm.asearch(
query=query, search_provider=search_provider
query=query,
search_provider=search_provider,
api_key=api_key,
api_base=api_base,
)

# Format using transformation function
Expand All @@ -639,7 +687,7 @@ async def _execute_chat_completion_agentic_loop( # noqa: PLR0915
logging_obj: Any,
stream: bool,
kwargs: Dict,
response_format: str = "openai",
response_format: ResponseFormat = "openai",
) -> Any:
"""Execute litellm.search() and make follow-up chat completion request"""

Expand Down
Loading