diff --git a/litellm/integrations/websearch_interception/handler.py b/litellm/integrations/websearch_interception/handler.py index 1277cac51d7..729b9774493 100644 --- a/litellm/integrations/websearch_interception/handler.py +++ b/litellm/integrations/websearch_interception/handler.py @@ -12,14 +12,16 @@ import litellm from litellm._logging import verbose_logger from litellm.anthropic_interface import messages as anthropic_messages -from litellm.constants import LITELLM_WEB_SEARCH_TOOL_NAME +from litellm.constants import DEFAULT_MAX_TOKENS, LITELLM_WEB_SEARCH_TOOL_NAME from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import filter_internal_params from litellm.integrations.websearch_interception.tools import ( get_litellm_web_search_tool, is_web_search_tool, is_web_search_tool_chat_completion, ) from litellm.integrations.websearch_interception.transformation import ( + ResponseFormat, WebSearchTransformation, ) from litellm.types.integrations.websearch_interception import ( @@ -76,8 +78,9 @@ async def async_pre_call_deployment_hook( Instead, we convert it to a regular tool so the model returns tool_use blocks that we can intercept and execute ourselves. """ - # Check if this is for an enabled provider + # Get provider from litellm_params (set by router in _add_deployment) custom_llm_provider = kwargs.get("litellm_params", {}).get("custom_llm_provider", "") + if custom_llm_provider not in self.enabled_providers: return None @@ -111,8 +114,9 @@ async def async_pre_call_deployment_hook( # Keep other tools as-is converted_tools.append(tool) - # Return modified kwargs with converted tools - return {"tools": converted_tools} + # Return full kwargs with modified tools - spread preserves all other + # parameters (model, messages, etc.) 
for the pre_api_call hook contract + return {**kwargs, "tools": converted_tools} @classmethod def from_config_yaml( @@ -275,29 +279,32 @@ async def async_should_run_agentic_loop( return False, {} # Detect WebSearch tool_use in response (Anthropic format) - should_intercept, tool_calls = WebSearchTransformation.transform_request( + transformed = WebSearchTransformation.transform_request( response=response, stream=stream, response_format="anthropic", ) - if not should_intercept: + if not transformed.has_websearch: verbose_logger.debug( "WebSearchInterception: No WebSearch tool_use detected in response" ) return False, {} verbose_logger.debug( - f"WebSearchInterception: Detected {len(tool_calls)} WebSearch tool call(s), executing agentic loop" + f"WebSearchInterception: Detected {len(transformed.tool_calls)} WebSearch tool call(s), " + f"{len(transformed.thinking_blocks)} thinking block(s), executing agentic loop" ) - # Return tools dict with tool calls + # Return tools dict with tool calls and thinking blocks (if any) tools_dict = { - "tool_calls": tool_calls, + "tool_calls": transformed.tool_calls, "tool_type": "websearch", "provider": custom_llm_provider, "response_format": "anthropic", } + if transformed.thinking_blocks: + tools_dict["thinking_blocks"] = transformed.thinking_blocks return True, tools_dict async def async_should_run_chat_completion_agentic_loop( @@ -335,29 +342,32 @@ async def async_should_run_chat_completion_agentic_loop( return False, {} # Detect WebSearch tool_calls in response (OpenAI format) - should_intercept, tool_calls = WebSearchTransformation.transform_request( + transformed = WebSearchTransformation.transform_request( response=response, stream=stream, response_format="openai", ) - if not should_intercept: + if not transformed.has_websearch: verbose_logger.debug( "WebSearchInterception: No WebSearch tool_calls detected in response" ) return False, {} verbose_logger.debug( - f"WebSearchInterception: Detected {len(tool_calls)} WebSearch 
tool call(s), executing agentic loop" + f"WebSearchInterception: Detected {len(transformed.tool_calls)} WebSearch tool call(s), " + f"{len(transformed.thinking_blocks)} thinking block(s), executing agentic loop" ) - # Return tools dict with tool calls + # Return tools dict with tool calls and thinking blocks (if any) tools_dict = { - "tool_calls": tool_calls, + "tool_calls": transformed.tool_calls, "tool_type": "websearch", "provider": custom_llm_provider, "response_format": "openai", } + if transformed.thinking_blocks: + tools_dict["thinking_blocks"] = transformed.thinking_blocks return True, tools_dict async def async_run_agentic_loop( @@ -379,6 +389,7 @@ async def async_run_agentic_loop( """ tool_calls = tools["tool_calls"] + thinking_blocks = tools.get("thinking_blocks", []) verbose_logger.debug( f"WebSearchInterception: Executing agentic loop for {len(tool_calls)} search(es)" @@ -388,6 +399,7 @@ async def async_run_agentic_loop( model=model, messages=messages, tool_calls=tool_calls, + thinking_blocks=thinking_blocks, anthropic_messages_optional_request_params=anthropic_messages_optional_request_params, logging_obj=logging_obj, stream=stream, @@ -429,19 +441,8 @@ async def async_run_chat_completion_agentic_loop( response_format=response_format, ) - async def _execute_agentic_loop( - self, - model: str, - messages: List[Dict], - tool_calls: List[Dict], - anthropic_messages_optional_request_params: Dict, - logging_obj: Any, - stream: bool, - kwargs: Dict, - ) -> Any: - """Execute litellm.search() and make follow-up request""" - - # Extract search queries from tool_use blocks + async def _execute_searches(self, tool_calls: List[Dict]) -> List[str]: + """Execute search queries from tool_use blocks in parallel and return results.""" search_tasks = [] for tool_call in tool_calls: query = tool_call["input"].get("query") @@ -454,39 +455,50 @@ async def _execute_agentic_loop( verbose_logger.warning( f"WebSearchInterception: Tool call {tool_call['id']} has no query" ) - 
# Add empty result for tools without query search_tasks.append(self._create_empty_search_result()) - # Execute searches in parallel verbose_logger.debug( f"WebSearchInterception: Executing {len(search_tasks)} search(es) in parallel" ) search_results = await asyncio.gather(*search_tasks, return_exceptions=True) - # Handle any exceptions in search results final_search_results: List[str] = [] for i, result in enumerate(search_results): if isinstance(result, Exception): verbose_logger.error( f"WebSearchInterception: Search {i} failed with error: {str(result)}" ) - final_search_results.append( - f"Search failed: {str(result)}" - ) + final_search_results.append(f"Search failed: {str(result)}") elif isinstance(result, str): - # Explicitly cast to str for type checker final_search_results.append(cast(str, result)) else: - # Should never happen, but handle for type safety verbose_logger.warning( f"WebSearchInterception: Unexpected result type {type(result)} at index {i}" ) final_search_results.append(str(result)) + return final_search_results + + async def _execute_agentic_loop( + self, + model: str, + messages: List[Dict], + tool_calls: List[Dict], + thinking_blocks: List[Dict], + anthropic_messages_optional_request_params: Dict, + logging_obj: Any, + stream: bool, + kwargs: Dict, + ) -> Any: + """Execute litellm.search() and make follow-up request""" + + final_search_results = await self._execute_searches(tool_calls) # Build assistant and user messages using transformation + # Include thinking_blocks to satisfy Anthropic's thinking mode requirements assistant_message, user_message = WebSearchTransformation.transform_response( tool_calls=tool_calls, search_results=final_search_results, + thinking_blocks=thinking_blocks, ) # Make follow-up request with search results @@ -512,6 +524,26 @@ async def _execute_agentic_loop( kwargs.get("max_tokens", 1024) # Default to 1024 if not found ) + # Validate and adjust max_tokens if needed to meet Anthropic's requirement + # Anthropic 
requires: max_tokens > thinking.budget_tokens + if "thinking" in anthropic_messages_optional_request_params: + thinking_param = anthropic_messages_optional_request_params.get("thinking", {}) + if isinstance(thinking_param, dict) and thinking_param.get("type") == "enabled": + budget_tokens = thinking_param.get("budget_tokens", 0) + + # Check if adjustment is needed + if budget_tokens > 0 and max_tokens <= budget_tokens: + # Use a formula that ensures sufficient tokens for response + # Follow pattern from litellm/llms/base_llm/chat/transformation.py + original_max_tokens = max_tokens + max_tokens = budget_tokens + DEFAULT_MAX_TOKENS + + verbose_logger.warning( + f"WebSearchInterception: max_tokens ({original_max_tokens}) <= budget_tokens ({budget_tokens}). " + f"Adjusting max_tokens to {max_tokens} (budget_tokens + DEFAULT_MAX_TOKENS={DEFAULT_MAX_TOKENS}) " + f"to meet Anthropic's requirement" + ) + verbose_logger.debug( f"WebSearchInterception: Using max_tokens={max_tokens} for follow-up request" ) @@ -524,9 +556,14 @@ async def _execute_agentic_loop( # Remove internal websearch interception flags from kwargs before follow-up request # These flags are used internally and should not be passed to the LLM provider + kwargs_for_followup = filter_internal_params(kwargs) + + # Remove keys already present in optional_params or passed explicitly to avoid + # "got multiple values for keyword argument" errors (e.g. 
context_management) + explicit_keys = {"max_tokens", "messages", "model"} kwargs_for_followup = { - k: v for k, v in kwargs.items() - if not k.startswith('_websearch_interception') + k: v for k, v in kwargs_for_followup.items() + if k not in optional_params_without_max_tokens and k not in explicit_keys } # Get model from logging_obj.model_call_details["agentic_loop_params"] @@ -572,8 +609,10 @@ async def _execute_search(self, query: str) -> str: ) llm_router = None - # Determine search provider from router's search_tools + # Determine search provider and credentials from router's search_tools search_provider: Optional[str] = None + api_key: Optional[str] = None + api_base: Optional[str] = None if llm_router is not None and hasattr(llm_router, "search_tools"): if self.search_tool_name: # Find specific search tool by name @@ -583,7 +622,10 @@ async def _execute_search(self, query: str) -> str: ] if matching_tools: search_tool = matching_tools[0] - search_provider = search_tool.get("litellm_params", {}).get("search_provider") + litellm_params = search_tool.get("litellm_params", {}) + search_provider = litellm_params.get("search_provider") + api_key = litellm_params.get("api_key") + api_base = litellm_params.get("api_base") verbose_logger.debug( f"WebSearchInterception: Found search tool '{self.search_tool_name}' " f"with provider '{search_provider}'" @@ -597,7 +639,10 @@ async def _execute_search(self, query: str) -> str: # If no specific tool or not found, use first available if not search_provider and llm_router.search_tools: first_tool = llm_router.search_tools[0] - search_provider = first_tool.get("litellm_params", {}).get("search_provider") + litellm_params = first_tool.get("litellm_params", {}) + search_provider = litellm_params.get("search_provider") + api_key = litellm_params.get("api_key") + api_base = litellm_params.get("api_base") verbose_logger.debug( f"WebSearchInterception: Using first available search tool with provider '{search_provider}'" ) @@ -614,7 
+659,10 @@ async def _execute_search(self, query: str) -> str: f"WebSearchInterception: Executing search for '{query}' using provider '{search_provider}'" ) result = await litellm.asearch( - query=query, search_provider=search_provider + query=query, + search_provider=search_provider, + api_key=api_key, + api_base=api_base, ) # Format using transformation function @@ -639,7 +687,7 @@ async def _execute_chat_completion_agentic_loop( # noqa: PLR0915 logging_obj: Any, stream: bool, kwargs: Dict, - response_format: str = "openai", + response_format: ResponseFormat = "openai", ) -> Any: """Execute litellm.search() and make follow-up chat completion request""" diff --git a/litellm/integrations/websearch_interception/transformation.py b/litellm/integrations/websearch_interception/transformation.py index e44ec35c3a2..5782c279e17 100644 --- a/litellm/integrations/websearch_interception/transformation.py +++ b/litellm/integrations/websearch_interception/transformation.py @@ -3,14 +3,40 @@ Transforms between Anthropic/OpenAI tool_use format and LiteLLM search format. 
""" + import json -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Literal, NamedTuple, Optional, Union + +ResponseFormat = Literal["openai", "anthropic"] from litellm._logging import verbose_logger from litellm.constants import LITELLM_WEB_SEARCH_TOOL_NAME from litellm.llms.base_llm.search.transformation import SearchResponse +class TransformedRequest(NamedTuple): + """Result of transform_request() for WebSearch tool detection.""" + + has_websearch: bool + """True if WebSearch tool_use was found in the response.""" + + tool_calls: List[Dict[str, Any]] + """List of tool_use dicts with id, name, input.""" + + thinking_blocks: List[Dict[str, Any]] + """List of thinking/redacted_thinking blocks to preserve.""" + + +class TransformedResponse(NamedTuple): + """Result of transform_response() for WebSearch tool results.""" + + assistant_message: Dict[str, Any] + """Assistant message with tool_use blocks (and thinking blocks for Anthropic).""" + + tool_result_messages: Union[Dict[str, Any], List[Dict[str, Any]]] + """User message with tool_result blocks (Anthropic) or list of tool messages (OpenAI).""" + + class WebSearchTransformation: """ Transformation class for WebSearch tool interception. @@ -25,13 +51,14 @@ class WebSearchTransformation: def transform_request( response: Any, stream: bool, - response_format: str = "anthropic", - ) -> Tuple[bool, List[Dict]]: + response_format: ResponseFormat = "anthropic", + ) -> TransformedRequest: """ Transform model response to extract WebSearch tool calls. Detects if response contains WebSearch tool_use/tool_calls blocks and extracts - the search queries for execution. + the search queries for execution. Also captures thinking blocks for + proper follow-up message construction. 
Args: response: Model response (dict, AnthropicMessagesResponse, or ModelResponse) @@ -39,9 +66,7 @@ def transform_request( response_format: Response format - "anthropic" or "openai" (default: "anthropic") Returns: - (has_websearch, tool_calls): - has_websearch: True if WebSearch tool_use found - tool_calls: List of tool_use/tool_calls dicts with id, name, input/function + TransformedRequest with has_websearch, tool_calls, and thinking_blocks Note: Streaming requests are handled by converting stream=True to stream=False @@ -52,10 +77,8 @@ def transform_request( if stream: # This should not happen in practice since we convert streaming to non-streaming # in async_log_pre_api_call, but keep this check for safety - verbose_logger.warning( - "WebSearchInterception: Unexpected streaming response, skipping interception" - ) - return False, [] + verbose_logger.warning("WebSearchInterception: Unexpected streaming response, skipping interception") + return TransformedRequest(False, [], []) # Parse non-streaming response based on format if response_format == "openai": @@ -66,28 +89,25 @@ def transform_request( @staticmethod def _detect_from_non_streaming_response( response: Any, - ) -> Tuple[bool, List[Dict]]: - """Parse non-streaming response for WebSearch tool_use""" + ) -> TransformedRequest: + """Parse non-streaming response for WebSearch tool_use and thinking blocks""" # Handle both dict and object responses if isinstance(response, dict): content = response.get("content", []) else: if not hasattr(response, "content"): - verbose_logger.debug( - "WebSearchInterception: Response has no content attribute" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: Response has no content attribute") + return TransformedRequest(False, [], []) content = response.content or [] if not content: - verbose_logger.debug( - "WebSearchInterception: Response has empty content" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: Response has empty content") 
+ return TransformedRequest(False, [], []) - # Find all WebSearch tool_use blocks + # Find all WebSearch tool_use blocks and thinking blocks tool_calls = [] + thinking_blocks = [] for block in content: # Handle both dict and object blocks if isinstance(block, dict): @@ -101,10 +121,26 @@ def _detect_from_non_streaming_response( block_id = getattr(block, "id", None) block_input = getattr(block, "input", {}) + # Capture thinking and redacted_thinking blocks for follow-up messages + # Normalize to dict to ensure JSON serialization works + if block_type in ("thinking", "redacted_thinking"): + if isinstance(block, dict): + thinking_blocks.append(block) + else: + # Normalize SDK objects to dicts for safe serialization in follow-up requests + normalized = {"type": block_type} + for attr in ("thinking", "data", "signature"): + if hasattr(block, attr): + normalized[attr] = getattr(block, attr) + thinking_blocks.append(normalized) + verbose_logger.debug(f"WebSearchInterception: Captured {block_type} block for follow-up") + # Check for LiteLLM standard or legacy web search tools # Handles: litellm_web_search, WebSearch, web_search if block_type == "tool_use" and block_name in ( - LITELLM_WEB_SEARCH_TOOL_NAME, "WebSearch", "web_search" + LITELLM_WEB_SEARCH_TOOL_NAME, + "WebSearch", + "web_search", ): # Convert to dict for easier handling tool_call = { @@ -114,34 +150,28 @@ def _detect_from_non_streaming_response( "input": block_input, } tool_calls.append(tool_call) - verbose_logger.debug( - f"WebSearchInterception: Found {block_name} tool_use with id={tool_call['id']}" - ) + verbose_logger.debug(f"WebSearchInterception: Found {block_name} tool_use with id={block_id}") - return len(tool_calls) > 0, tool_calls + return TransformedRequest(len(tool_calls) > 0, tool_calls, thinking_blocks) @staticmethod def _detect_from_openai_response( response: Any, - ) -> Tuple[bool, List[Dict]]: + ) -> TransformedRequest: """Parse OpenAI-style response for WebSearch tool_calls""" - + # Handle 
both dict and ModelResponse objects if isinstance(response, dict): choices = response.get("choices", []) else: if not hasattr(response, "choices"): - verbose_logger.debug( - "WebSearchInterception: Response has no choices attribute" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: Response has no choices attribute") + return TransformedRequest(False, [], []) choices = response.choices or [] if not choices: - verbose_logger.debug( - "WebSearchInterception: Response has empty choices" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: Response has empty choices") + return TransformedRequest(False, [], []) # Get first choice's message first_choice = choices[0] @@ -149,12 +179,10 @@ def _detect_from_openai_response( message = first_choice.get("message", {}) else: message = getattr(first_choice, "message", None) - + if not message: - verbose_logger.debug( - "WebSearchInterception: First choice has no message" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: First choice has no message") + return TransformedRequest(False, [], []) # Get tool_calls from message if isinstance(message, dict): @@ -163,10 +191,8 @@ def _detect_from_openai_response( openai_tool_calls = getattr(message, "tool_calls", None) or [] if not openai_tool_calls: - verbose_logger.debug( - "WebSearchInterception: Message has no tool_calls" - ) - return False, [] + verbose_logger.debug("WebSearchInterception: Message has no tool_calls") + return TransformedRequest(False, [], []) # Find all WebSearch tool calls tool_calls = [] @@ -177,7 +203,9 @@ def _detect_from_openai_response( tool_type = tool_call.get("type") function = tool_call.get("function", {}) function_name = function.get("name") if isinstance(function, dict) else getattr(function, "name", None) - function_arguments = function.get("arguments") if isinstance(function, dict) else getattr(function, "arguments", None) + function_arguments = ( + function.get("arguments") if 
isinstance(function, dict) else getattr(function, "arguments", None) + ) else: tool_id = getattr(tool_call, "id", None) tool_type = getattr(tool_call, "type", None) @@ -186,9 +214,7 @@ def _detect_from_openai_response( function_arguments = getattr(function, "arguments", None) if function else None # Check for LiteLLM standard or legacy web search tools - if tool_type == "function" and function_name in ( - LITELLM_WEB_SEARCH_TOOL_NAME, "WebSearch", "web_search" - ): + if tool_type == "function" and function_name in (LITELLM_WEB_SEARCH_TOOL_NAME, "WebSearch", "web_search"): # Parse arguments (might be JSON string) if isinstance(function_arguments, str): try: @@ -213,18 +239,17 @@ def _detect_from_openai_response( "input": arguments, # For compatibility with Anthropic format } tool_calls.append(tool_call_dict) - verbose_logger.debug( - f"WebSearchInterception: Found {function_name} tool_call with id={tool_id}" - ) + verbose_logger.debug(f"WebSearchInterception: Found {function_name} tool_call with id={tool_id}") - return len(tool_calls) > 0, tool_calls + return TransformedRequest(len(tool_calls) > 0, tool_calls, []) @staticmethod def transform_response( - tool_calls: List[Dict], + tool_calls: List[Dict[str, Any]], search_results: List[str], - response_format: str = "anthropic", - ) -> Tuple[Dict, Union[Dict, List[Dict]]]: + response_format: ResponseFormat = "anthropic", + thinking_blocks: Optional[List[Dict[str, Any]]] = None, + ) -> TransformedResponse: """ Transform LiteLLM search results to Anthropic/OpenAI tool_result format. 
@@ -235,31 +260,38 @@ def transform_response( tool_calls: List of tool_use/tool_calls dicts from transform_request search_results: List of search result strings (one per tool_call) response_format: Response format - "anthropic" or "openai" (default: "anthropic") + thinking_blocks: List of thinking/redacted_thinking blocks to include at the start of + assistant message (Anthropic format only) Returns: (assistant_message, user_or_tool_messages): - For Anthropic: assistant_message with tool_use blocks, user_message with tool_result blocks + For Anthropic: assistant_message with thinking + tool_use blocks, user_message with tool_result blocks For OpenAI: assistant_message with tool_calls, tool_messages list with tool results """ if response_format == "openai": - return WebSearchTransformation._transform_response_openai( - tool_calls, search_results - ) + return WebSearchTransformation._transform_response_openai(tool_calls, search_results) else: return WebSearchTransformation._transform_response_anthropic( - tool_calls, search_results + tool_calls, search_results, thinking_blocks or [] ) @staticmethod def _transform_response_anthropic( - tool_calls: List[Dict], + tool_calls: List[Dict[str, Any]], search_results: List[str], - ) -> Tuple[Dict, Dict]: - """Transform to Anthropic format (single user message with tool_result blocks)""" - # Build assistant message with tool_use blocks - assistant_message = { - "role": "assistant", - "content": [ + thinking_blocks: List[Dict[str, Any]], + ) -> TransformedResponse: + """Transform to Anthropic format with optional thinking blocks""" + # Build assistant message content - thinking blocks first, then tool_use + assistant_content: List[Dict[str, Any]] = [] + + # Add thinking blocks at the start (required when thinking is enabled) + if thinking_blocks: + assistant_content.extend(thinking_blocks) + + # Add tool_use blocks + assistant_content.extend( + [ { "type": "tool_use", "id": tc["id"], @@ -267,7 +299,12 @@ def 
_transform_response_anthropic( "input": tc["input"], } for tc in tool_calls - ], + ] + ) + + assistant_message = { + "role": "assistant", + "content": assistant_content, } # Build user message with tool_result blocks @@ -283,13 +320,13 @@ def _transform_response_anthropic( ], } - return assistant_message, user_message + return TransformedResponse(assistant_message, user_message) @staticmethod def _transform_response_openai( - tool_calls: List[Dict], + tool_calls: List[Dict[str, Any]], search_results: List[str], - ) -> Tuple[Dict, List[Dict]]: + ) -> TransformedResponse: """Transform to OpenAI format (assistant with tool_calls, separate tool messages)""" # Build assistant message with tool_calls assistant_message = { @@ -317,7 +354,7 @@ def _transform_response_openai( for i in range(len(tool_calls)) ] - return assistant_message, tool_messages + return TransformedResponse(assistant_message, tool_messages) @staticmethod def format_search_response(result: SearchResponse) -> str: @@ -334,10 +371,7 @@ def format_search_response(result: SearchResponse) -> str: if hasattr(result, "results") and result.results: # Format results as text search_result_text = "\n\n".join( - [ - f"Title: {r.title}\nURL: {r.url}\nSnippet: {r.snippet}" - for r in result.results - ] + [f"Title: {r.title}\nURL: {r.url}\nSnippet: {r.snippet}" for r in result.results] ) else: search_result_text = str(result) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 7c8e2ebeaff..985e3c9ad95 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,6 +1,6 @@ # What is this? 
## Helper utilities -from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union, Final import httpx @@ -16,10 +16,20 @@ else: Span = Any +# Known internal parameters that should never be sent to provider APIs +INTERNAL_PARAMS: Final = { + "skip_mcp_handler", + "mcp_handler_context", + "_skip_mcp_handler", +} -def safe_divide_seconds( - seconds: float, denominator: float, default: Optional[float] = None -) -> Optional[float]: +# Known internal parameters prefixes that should never be sent to provider APIs +INTERNAL_PARAMS_PREFIXES: Final = { + "_websearch_interception", +} + + +def safe_divide_seconds(seconds: float, denominator: float, default: Optional[float] = None) -> Optional[float]: """ Safely divide seconds by denominator, handling zero division. @@ -71,9 +81,7 @@ def map_finish_reason( return "length" elif finish_reason == "ERROR_TOXIC": return "content_filter" - elif ( - finish_reason == "ERROR" - ): # openai currently doesn't support an 'error' finish reason + elif finish_reason == "ERROR": # openai currently doesn't support an 'error' finish reason return "stop" # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream elif finish_reason == "eos_token" or finish_reason == "stop_sequence": @@ -107,9 +115,7 @@ def remove_index_from_tool_calls( _tool_calls = message.get("tool_calls") if _tool_calls is not None and isinstance(_tool_calls, list): for tool_call in _tool_calls: - if ( - isinstance(tool_call, dict) and "index" in tool_call - ): # Type guard to ensure it's a dict + if isinstance(tool_call, dict) and "index" in tool_call: # Type guard to ensure it's a dict tool_call.pop("index", None) return @@ -124,9 +130,7 @@ def remove_items_at_indices(items: Optional[List[Any]], indices: Iterable[int]) items.pop(index) -def add_missing_spend_metadata_to_litellm_metadata( - litellm_metadata: 
dict, metadata: dict -) -> dict: +def add_missing_spend_metadata_to_litellm_metadata(litellm_metadata: dict, metadata: dict) -> dict: """ Helper to get litellm metadata for spend tracking @@ -168,9 +172,7 @@ def get_litellm_metadata_from_kwargs(kwargs: dict): metadata = litellm_params.get("metadata", {}) litellm_metadata = litellm_params.get("litellm_metadata", {}) if litellm_metadata and metadata: - litellm_metadata = add_missing_spend_metadata_to_litellm_metadata( - litellm_metadata, metadata - ) + litellm_metadata = add_missing_spend_metadata_to_litellm_metadata(litellm_metadata, metadata) if litellm_metadata: return litellm_metadata elif metadata: @@ -219,9 +221,7 @@ def _get_parent_otel_span_from_kwargs( return kwargs["litellm_parent_otel_span"] return None except Exception as e: - verbose_logger.exception( - "Error in _get_parent_otel_span_from_kwargs: " + str(e) - ) + verbose_logger.exception("Error in _get_parent_otel_span_from_kwargs: " + str(e)) return None @@ -235,9 +235,7 @@ def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> di for k, v in response_headers.items(): if k in OPENAI_RESPONSE_HEADERS: # return openai-compatible headers openai_headers[k] = v - if k.startswith( - "llm_provider-" - ): # return raw provider headers (incl. openai-compatible ones) + if k.startswith("llm_provider-"): # return raw provider headers (incl. 
openai-compatible ones) processed_headers[k] = v else: additional_headers["{}-{}".format("llm_provider", k)] = v @@ -288,13 +286,8 @@ def safe_deep_copy(data): if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span") data["metadata"]["litellm_parent_otel_span"] = "placeholder" - if ( - "litellm_metadata" in data - and "litellm_parent_otel_span" in data["litellm_metadata"] - ): - litellm_parent_otel_span = data["litellm_metadata"].pop( - "litellm_parent_otel_span" - ) + if "litellm_metadata" in data and "litellm_parent_otel_span" in data["litellm_metadata"]: + litellm_parent_otel_span = data["litellm_metadata"].pop("litellm_parent_otel_span") data["litellm_metadata"]["litellm_parent_otel_span"] = "placeholder" # Step 2: Per-key deepcopy with fallback @@ -315,13 +308,8 @@ def safe_deep_copy(data): if isinstance(data, dict) and litellm_parent_otel_span is not None: if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span - if ( - "litellm_metadata" in data - and "litellm_parent_otel_span" in data["litellm_metadata"] - ): - data["litellm_metadata"][ - "litellm_parent_otel_span" - ] = litellm_parent_otel_span + if "litellm_metadata" in data and "litellm_parent_otel_span" in data["litellm_metadata"]: + data["litellm_metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span return new_data @@ -374,9 +362,7 @@ def filter_exceptions_from_params(data: Any, max_depth: int = 20) -> Any: result_list: list[Any] = [] for item in data: # Skip exception and callable items - if isinstance(item, Exception) or ( - callable(item) and not isinstance(item, type) - ): + if isinstance(item, Exception) or (callable(item) and not isinstance(item, type)): continue try: filtered = filter_exceptions_from_params(item, max_depth - 1) @@ -390,9 +376,28 @@ def filter_exceptions_from_params(data: Any, max_depth: 
int = 20) -> Any: return data -def filter_internal_params( - data: dict, additional_internal_params: Optional[set] = None -) -> dict: +def _is_param_internal(param: str, additional_internal_params: Optional[set]) -> bool: + """ + Check if a parameter is internal and should not be sent to provider APIs. + + Args: + param: Parameter name to check + additional_internal_params: Optional set of extra internal param names + + Returns: + True if param matches INTERNAL_PARAMS, additional_internal_params, + or starts with any INTERNAL_PARAMS_PREFIXES + """ + if param in INTERNAL_PARAMS: + return True + if additional_internal_params and param in additional_internal_params: + return True + if any(param.startswith(prefix) for prefix in INTERNAL_PARAMS_PREFIXES): + return True + return False + + +def filter_internal_params(data: dict, additional_internal_params: Optional[set] = None) -> dict: """ Filter out LiteLLM internal parameters that shouldn't be sent to provider APIs. @@ -409,16 +414,4 @@ def filter_internal_params( if not isinstance(data, dict): return data - # Known internal parameters that should never be sent to provider APIs - internal_params = { - "skip_mcp_handler", - "mcp_handler_context", - "_skip_mcp_handler", - } - - # Add any additional internal params if provided - if additional_internal_params: - internal_params.update(additional_internal_params) - - # Filter out internal parameters - return {k: v for k, v in data.items() if k not in internal_params} + return {k: v for k, v in data.items() if not _is_param_internal(k, additional_internal_params)} diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index a5f8fe22a2c..8d63e9fd343 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -65,6 +65,7 @@ any_assistant_message_has_thinking_blocks, get_max_tokens, has_tool_call_blocks, + last_assistant_message_has_no_thinking_blocks, 
last_assistant_with_tool_calls_has_no_thinking_blocks, supports_reasoning, token_counter, @@ -1200,7 +1201,9 @@ def transform_request( ) # Drop thinking param if thinking is enabled but thinking_blocks are missing - # This prevents the error: "Expected thinking or redacted_thinking, but found tool_use" + # This prevents Anthropic errors: + # - "Expected thinking or redacted_thinking, but found tool_use" (assistant with tool_calls) + # - "Expected thinking or redacted_thinking, but found text" (assistant with text content) # # IMPORTANT: Only drop thinking if NO assistant messages have thinking_blocks. # If any message has thinking_blocks, we must keep thinking enabled, otherwise @@ -1209,7 +1212,10 @@ def transform_request( if ( optional_params.get("thinking") is not None and messages is not None - and last_assistant_with_tool_calls_has_no_thinking_blocks(messages) + and ( + last_assistant_with_tool_calls_has_no_thinking_blocks(messages) + or last_assistant_message_has_no_thinking_blocks(messages) + ) and not any_assistant_message_has_thinking_blocks(messages) ): if litellm.modify_params: diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index ac209904e6e..6b8b406d114 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -110,8 +110,9 @@ def get_json_schema_from_pydantic_object( return type_to_response_format_param(response_format=response_format) def is_thinking_enabled(self, non_default_params: dict) -> bool: + thinking_type = non_default_params.get("thinking", {}).get("type") return ( - non_default_params.get("thinking", {}).get("type") == "enabled" + thinking_type in ("enabled", "adaptive") or non_default_params.get("reasoning_effort") is not None ) diff --git a/litellm/llms/bedrock/README.md b/litellm/llms/bedrock/README.md new file mode 100644 index 00000000000..2963eaa54d1 --- /dev/null +++ b/litellm/llms/bedrock/README.md @@ -0,0 +1,67 @@ +# AWS 
Bedrock Provider + +This directory contains the AWS Bedrock provider implementation for LiteLLM. + +## Beta Headers Management + +### Overview + +Bedrock anthropic-beta header handling uses a centralized whitelist-based filter (`beta_headers_config.py`) across all three Bedrock APIs to ensure: +- Only supported headers reach AWS (prevents API errors) +- Consistent behavior across Invoke Chat, Invoke Messages, and Converse APIs +- Zero maintenance when new Claude models are released + +### Key Features + +1. **Version-Based Filtering**: Headers specify minimum version (e.g., "requires Claude 4.5+") instead of hardcoded model lists +2. **Family Restrictions**: Can limit headers to specific families (opus/sonnet/haiku) +3. **Automatic Translation**: `advanced-tool-use` → `tool-search-tool` + `tool-examples` for backward compatibility + +### Adding New Beta Headers + +When AWS Bedrock adds support for a new Anthropic beta header, update `beta_headers_config.py`: + +```python +# 1. Add to whitelist +BEDROCK_CORE_SUPPORTED_BETAS.add("new-feature-2027-01-15") + +# 2. (Optional) Add version requirement +BETA_HEADER_MINIMUM_VERSION["new-feature-2027-01-15"] = 5.0 + +# 3. (Optional) Add family restriction +BETA_HEADER_FAMILY_RESTRICTIONS["new-feature-2027-01-15"] = ["opus"] +``` + +Then add tests in `tests/test_litellm/llms/bedrock/test_beta_headers_config.py`. 
+ +### Adding New Claude Models + +When Anthropic releases new models (e.g., Claude Opus 5): +- **Required code changes**: ZERO ✅ +- The version-based filter automatically handles new models +- No hardcoded lists to update + +### Testing + +```bash +# Test beta headers filtering +poetry run pytest tests/test_litellm/llms/bedrock/test_beta_headers_config.py -v + +# Test API integrations +poetry run pytest tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py -v + +# Test everything +poetry run pytest tests/test_litellm/llms/bedrock/ -v +``` + +### Debug Logging + +Enable debug logging to see filtering decisions: +```bash +LITELLM_LOG=DEBUG +``` + +### References + +- [AWS Bedrock Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html) +- [Anthropic Beta Headers](https://docs.anthropic.com/claude/reference/versioning) diff --git a/litellm/llms/bedrock/beta_headers_config.py b/litellm/llms/bedrock/beta_headers_config.py new file mode 100644 index 00000000000..98c9ed30fb6 --- /dev/null +++ b/litellm/llms/bedrock/beta_headers_config.py @@ -0,0 +1,392 @@ +""" +Shared configuration for Bedrock anthropic-beta header handling. + +This module provides centralized whitelist-based filtering for anthropic-beta +headers across all Bedrock APIs (Invoke Chat, Invoke Messages, Converse). 
+ +## Architecture + +All three Bedrock APIs use BedrockBetaHeaderFilter to ensure consistent filtering: +- Invoke Chat API: BedrockAPI.INVOKE_CHAT +- Invoke Messages API: BedrockAPI.INVOKE_MESSAGES (with advanced-tool-use translation) +- Converse API: BedrockAPI.CONVERSE + +## Future-Proof Design + +The filter uses version-based model support instead of hardcoded model lists: +- New Claude models (e.g., Opus 5, Sonnet 5) require ZERO code changes +- Beta headers specify minimum version (e.g., "requires 4.5+") +- Family restrictions (opus/sonnet/haiku) when needed + +## Adding New Beta Headers + +When AWS Bedrock adds support for a new Anthropic beta header: + +**Scenario 1: Works on all models** +```python +BEDROCK_CORE_SUPPORTED_BETAS.add("new-feature-2027-01-15") +# Done! Works on all models automatically. +``` + +**Scenario 2: Requires specific version** +```python +BEDROCK_CORE_SUPPORTED_BETAS.add("advanced-reasoning-2027-06-15") +BETA_HEADER_MINIMUM_VERSION["advanced-reasoning-2027-06-15"] = 5.0 +# Done! Works on all Claude 5.0+ models (Opus, Sonnet, Haiku). +``` + +**Scenario 3: Version + family restriction** +```python +BEDROCK_CORE_SUPPORTED_BETAS.add("ultra-context-2027-12-15") +BETA_HEADER_MINIMUM_VERSION["ultra-context-2027-12-15"] = 5.5 +BETA_HEADER_FAMILY_RESTRICTIONS["ultra-context-2027-12-15"] = ["opus"] +# Done! Works on Opus 5.5+ only. 
+``` + +**Always add tests** in `tests/test_litellm/llms/bedrock/test_beta_headers_config.py` + +## Testing + +Run the test suite to verify changes: +```bash +poetry run pytest tests/test_litellm/llms/bedrock/test_beta_headers_config.py -v +poetry run pytest tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py -v +``` + +## Debug Logging + +Enable debug logging to see filtering decisions: +```bash +LITELLM_LOG=DEBUG +``` + +Reference: +- AWS Bedrock Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html +""" + +import re +from enum import Enum +from typing import Dict, List, Optional, Set + +from litellm._logging import verbose_logger + + +class BedrockAPI(Enum): + """Enum for different Bedrock API types.""" + + INVOKE_CHAT = "invoke_chat" + INVOKE_MESSAGES = "invoke_messages" + CONVERSE = "converse" + + +# Core whitelist of beta headers supported by ALL Bedrock APIs +BEDROCK_CORE_SUPPORTED_BETAS: Set[str] = { + "computer-use-2024-10-22", # Legacy computer use + "computer-use-2025-01-24", # Current computer use (Claude 3.7 Sonnet) + "computer-use-2025-11-24", # Latest computer use (Claude Opus 4.5+) + "token-efficient-tools-2025-02-19", # Tool use (Claude 3.7+ and Claude 4+) + "interleaved-thinking-2025-05-14", # Interleaved thinking (Claude 4+) + "output-128k-2025-02-19", # 128K output tokens (Claude 3.7 Sonnet) + "dev-full-thinking-2025-05-14", # Developer mode for raw thinking (Claude 4+) + "context-1m-2025-08-07", # 1 million tokens (Claude Sonnet 4) + "context-management-2025-06-27", # Context management (Claude Sonnet/Haiku 4.5) + "effort-2025-11-24", # Effort parameter (Claude Opus 4.5) + "tool-search-tool-2025-10-19", # Tool search (Claude Opus 4.5) + "tool-examples-2025-10-29", # Tool use examples (Claude Opus 4.5) +} + +# API-specific exclusions (headers NOT supported by specific APIs) +BEDROCK_API_EXCLUSIONS: Dict[BedrockAPI, Set[str]] = { + BedrockAPI.CONVERSE: 
set(), # No additional exclusions + BedrockAPI.INVOKE_CHAT: set(), # No additional exclusions + BedrockAPI.INVOKE_MESSAGES: set(), # No additional exclusions +} + +# Model version extraction regex pattern +# Matches Bedrock model IDs in both formats: +# New: claude-{family}-{major}-{minor}-{date} (e.g., claude-opus-4-5-20250514-v1:0) +# Legacy: claude-{major}-{minor}-{family}-{date} (e.g., claude-3-5-sonnet-20240620-v1:0) +# Minor version is a single digit followed by a hyphen (to avoid capturing the date). +MODEL_VERSION_PATTERN = r"claude-(?:(?:opus|sonnet|haiku)-)?(\d+)(?:-(\d)-)?" + +# Minimum model version required for each beta header (major.minor format) +# Default behavior: If a beta header is NOT in this dict, it's supported by ALL Anthropic models +# This approach is future-proof - new models automatically support all headers unless excluded +BETA_HEADER_MINIMUM_VERSION: Dict[str, float] = { + # Extended thinking features require Claude 4.0+ + "interleaved-thinking-2025-05-14": 4.0, + "dev-full-thinking-2025-05-14": 4.0, + # 1M context requires Claude 4.0+ + "context-1m-2025-08-07": 4.0, + # Context management requires Claude 4.5+ + "context-management-2025-06-27": 4.5, + # Effort parameter requires Claude 4.5+ (but only Opus 4.5, see family restrictions) + "effort-2025-11-24": 4.5, + # Tool search requires Claude 4.5+ + "tool-search-tool-2025-10-19": 4.5, + "tool-examples-2025-10-29": 4.5, +} + +# Model family restrictions for specific beta headers +# Only enforced if the version requirement is met +# Example: "effort-2025-11-24" requires Claude 4.5+ AND Opus family +BETA_HEADER_FAMILY_RESTRICTIONS: Dict[str, List[str]] = { + "effort-2025-11-24": ["opus"], # Only Opus 4.5+ supports effort + # Tool search works on Opus 4.5+ and Sonnet 4.5+, but not Haiku + "tool-search-tool-2025-10-19": ["opus", "sonnet"], + "tool-examples-2025-10-29": ["opus", "sonnet"], +} + +# Beta headers that should be translated for backward compatibility +# Maps input header 
pattern to output headers +# Uses version-based approach for future-proofing +BETA_HEADER_TRANSLATIONS: Dict[str, Dict] = { + "advanced-tool-use": { + "target_headers": ["tool-search-tool-2025-10-19", "tool-examples-2025-10-29"], + "minimum_version": 4.5, # Requires Claude 4.5+ + "allowed_families": ["opus", "sonnet"], # Not available on Haiku + }, +} + + +class BedrockBetaHeaderFilter: + """ + Centralized filter for anthropic-beta headers across all Bedrock APIs. + + Uses a whitelist-based approach to ensure only supported headers are sent to AWS. + """ + + def __init__(self, api_type: BedrockAPI): + """ + Initialize the filter for a specific Bedrock API. + + Args: + api_type: The Bedrock API type (Invoke Chat, Invoke Messages, or Converse) + """ + self.api_type = api_type + self.supported_betas = self._get_supported_betas() + + def _get_supported_betas(self) -> Set[str]: + """Get the set of supported beta headers for this API type.""" + # Start with core supported headers + supported = BEDROCK_CORE_SUPPORTED_BETAS.copy() + + # Remove API-specific exclusions + exclusions = BEDROCK_API_EXCLUSIONS.get(self.api_type, set()) + supported -= exclusions + + return supported + + def _extract_model_version(self, model: str) -> Optional[float]: + """ + Extract Claude model version from Bedrock model ID. 
+ + Args: + model: Bedrock model ID (e.g., "anthropic.claude-opus-4-5-20250514-v1:0") + + Returns: + Version as float (e.g., 4.5), or None if unable to parse + + Examples: + "anthropic.claude-opus-4-5-20250514-v1:0" -> 4.5 + "anthropic.claude-sonnet-4-20250514-v1:0" -> 4.0 + "anthropic.claude-3-5-sonnet-20240620-v1:0" -> 3.5 + "anthropic.claude-3-sonnet-20240229-v1:0" -> 3.0 + """ + match = re.search(MODEL_VERSION_PATTERN, model) + if not match: + return None + + major = int(match.group(1)) + minor = int(match.group(2)) if match.group(2) else 0 + + return float(f"{major}.{minor}") + + def _extract_model_family(self, model: str) -> Optional[str]: + """ + Extract Claude model family (opus, sonnet, haiku) from Bedrock model ID. + + Args: + model: Bedrock model ID + + Returns: + Family name (opus/sonnet/haiku) or None if unable to parse + + Examples: + "anthropic.claude-opus-4-5-20250514-v1:0" -> "opus" + "anthropic.claude-3-5-sonnet-20240620-v1:0" -> "sonnet" + """ + model_lower = model.lower() + if "opus" in model_lower: + return "opus" + elif "sonnet" in model_lower: + return "sonnet" + elif "haiku" in model_lower: + return "haiku" + return None + + def _model_supports_beta(self, model: str, beta: str) -> bool: + """ + Check if a model supports a specific beta header. + + Uses a future-proof approach: + 1. If beta has no version requirement -> ALLOW (supports all models) + 2. If beta has version requirement -> Extract model version and compare + 3. If beta has family restriction -> Check model family + + This means NEW models automatically support all beta headers unless explicitly + restricted by version/family requirements. 
+ + Args: + model: The Bedrock model ID (e.g., "anthropic.claude-sonnet-4-20250514-v1:0") + beta: The beta header to check + + Returns: + True if the model supports the beta header, False otherwise + """ + # Default: If no version requirement specified, ALL Anthropic models support it + # This makes the system future-proof for new models + if beta not in BETA_HEADER_MINIMUM_VERSION: + return True + + # Extract model version + model_version = self._extract_model_version(model) + if model_version is None: + # If we can't parse version, be conservative and reject + # (This should rarely happen with well-formed Bedrock model IDs) + return False + + # Check minimum version requirement + required_version = BETA_HEADER_MINIMUM_VERSION[beta] + if model_version < required_version: + return False # Model version too old + + # Check family restrictions (if any) + if beta in BETA_HEADER_FAMILY_RESTRICTIONS: + model_family = self._extract_model_family(model) + if model_family is None: + # Can't determine family, be conservative + return False + + allowed_families = BETA_HEADER_FAMILY_RESTRICTIONS[beta] + if model_family not in allowed_families: + return False # Wrong family + + # All checks passed + return True + + def _translate_beta_headers(self, beta_headers: Set[str], model: str) -> Set[str]: + """ + Translate beta headers for backward compatibility. + + Uses version-based checks to determine if model supports translation. + Future-proof: new models at the required version automatically support translations. 
+ + Args: + beta_headers: Set of beta headers to translate + model: The Bedrock model ID + + Returns: + Set of translated beta headers + """ + translated = beta_headers.copy() + + for input_pattern, translation_info in BETA_HEADER_TRANSLATIONS.items(): + # Check if any beta header matches the input pattern + matching_headers = [h for h in beta_headers if input_pattern in h.lower()] + + if matching_headers: + # Check if model supports the translation using version-based logic + model_version = self._extract_model_version(model) + if model_version is None: + continue # Can't determine version, skip translation + + # Check minimum version + required_version = translation_info.get("minimum_version") + if required_version and model_version < required_version: + continue # Model too old for this translation + + # Check family restrictions (if any) + allowed_families = translation_info.get("allowed_families") + if allowed_families: + model_family = self._extract_model_family(model) + if model_family not in allowed_families: + continue # Wrong family + + # Model supports translation - apply it + for header in matching_headers: + translated.discard(header) + verbose_logger.debug( + f"Bedrock {self.api_type.value}: Translating beta header '{header}' for model {model}" + ) + + for target_header in translation_info["target_headers"]: + translated.add(target_header) + verbose_logger.debug( + f"Bedrock {self.api_type.value}: Added translated header '{target_header}'" + ) + + return translated + + def filter_beta_headers( + self, beta_headers: List[str], model: str, translate: bool = True + ) -> List[str]: + """ + Filter and translate beta headers for Bedrock. + + This is the main entry point for filtering beta headers. 
+ + Args: + beta_headers: List of beta headers from user request + model: The Bedrock model ID + translate: Whether to apply header translations (default: True) + + Returns: + Filtered and translated list of beta headers + """ + if not beta_headers: + return [] + + # Convert to set for efficient operations + beta_set = set(beta_headers) + + # Apply translations if enabled + if translate: + beta_set = self._translate_beta_headers(beta_set, model) + + # Filter: Keep only whitelisted headers + filtered = set() + for beta in beta_set: + # Check if header is in whitelist + if beta not in self.supported_betas: + verbose_logger.debug( + f"Bedrock {self.api_type.value}: Filtered out unsupported beta header: {beta}" + ) + continue + + # Check if model supports this header + if not self._model_supports_beta(model, beta): + verbose_logger.debug( + f"Bedrock {self.api_type.value}: Filtered out beta header '{beta}' (not supported on model {model})" + ) + continue + + filtered.add(beta) + + verbose_logger.debug( + f"Bedrock {self.api_type.value}: Final beta headers for {model}: {sorted(filtered)}" + ) + return sorted(list(filtered)) # Sort for deterministic output + + +def get_bedrock_beta_filter(api_type: BedrockAPI) -> BedrockBetaHeaderFilter: + """ + Factory function to get a beta header filter for a specific API. 
+ + Args: + api_type: The Bedrock API type + + Returns: + BedrockBetaHeaderFilter instance + """ + return BedrockBetaHeaderFilter(api_type) diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index daac3e6a008..1389d997b7c 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -33,6 +33,10 @@ ) from litellm.llms.anthropic.chat.transformation import AnthropicConfig from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.llms.bedrock.beta_headers_config import ( + BedrockAPI, + get_bedrock_beta_filter, +) from litellm.types.llms.bedrock import * from litellm.types.llms.openai import ( AllMessageValues, @@ -61,6 +65,7 @@ add_dummy_tool, any_assistant_message_has_thinking_blocks, has_tool_call_blocks, + last_assistant_message_has_no_thinking_blocks, last_assistant_with_tool_calls_has_no_thinking_blocks, supports_reasoning, ) @@ -81,13 +86,7 @@ "text_editor_", ] -# Beta header patterns that are not supported by Bedrock Converse API -# These will be filtered out to prevent errors -UNSUPPORTED_BEDROCK_CONVERSE_BETA_PATTERNS = [ - "advanced-tool-use", # Bedrock Converse doesn't support advanced-tool-use beta headers - "prompt-caching", # Prompt caching not supported in Converse API - "compact-2026-01-12", # The compact beta feature is not currently supported on the Converse and ConverseStream APIs -] +# Beta header filtering is now handled by centralized beta_headers_config module # Models that support Bedrock's native structured outputs API (outputConfig.textFormat) # Uses substring matching against the Bedrock model ID @@ -1211,6 +1210,10 @@ def _prepare_request_params( # These are LiteLLM internal parameters, not API parameters additional_request_params = filter_internal_params(additional_request_params) + # Remove Anthropic-specific body params that Bedrock doesn't support + # (these 
features are enabled via anthropic-beta headers instead) + additional_request_params.pop("context_management", None) + # Filter out non-serializable objects (exceptions, callables, logging objects, etc.) # from additional_request_params to prevent JSON serialization errors # This filters: Exception objects, callable objects (functions), Logging objects, etc. @@ -1328,11 +1331,16 @@ def _process_tools_and_beta( # Append pre-formatted tools (systemTool etc.) after transformation bedrock_tools.extend(pre_formatted_tools) - # Set anthropic_beta in additional_request_params if we have any beta features - # ONLY apply to Anthropic/Claude models - other models (e.g., Qwen, Llama) don't support this field + # Filter beta headers using centralized whitelist with model-specific support + # This handles version/family restrictions and unsupported beta patterns base_model = BedrockModelInfo.get_base_model(model) if anthropic_beta_list and base_model.startswith("anthropic"): - additional_request_params["anthropic_beta"] = anthropic_beta_list + beta_filter = get_bedrock_beta_filter(BedrockAPI.CONVERSE) + filtered_betas = beta_filter.filter_beta_headers( + anthropic_beta_list, model, translate=True + ) + if filtered_betas: + additional_request_params["anthropic_beta"] = filtered_betas return bedrock_tools, anthropic_beta_list @@ -1365,7 +1373,9 @@ def _transform_request_helper( ) # Drop thinking param if thinking is enabled but thinking_blocks are missing - # This prevents the error: "Expected thinking or redacted_thinking, but found tool_use" + # This prevents Anthropic errors: + # - "Expected thinking or redacted_thinking, but found tool_use" (assistant with tool_calls) + # - "Expected thinking or redacted_thinking, but found text" (assistant with text content) # # IMPORTANT: Only drop thinking if NO assistant messages have thinking_blocks. 
# If any message has thinking_blocks, we must keep thinking enabled, otherwise @@ -1373,7 +1383,10 @@ def _transform_request_helper( if ( optional_params.get("thinking") is not None and messages is not None - and last_assistant_with_tool_calls_has_no_thinking_blocks(messages) + and ( + last_assistant_with_tool_calls_has_no_thinking_blocks(messages) + or last_assistant_message_has_no_thinking_blocks(messages) + ) and not any_assistant_message_has_thinking_blocks(messages) ): if litellm.modify_params: diff --git a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py index dfab81123fd..a6a87b2e348 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py @@ -7,6 +7,10 @@ AmazonInvokeConfig, ) from litellm.llms.bedrock.common_utils import get_anthropic_beta_from_headers +from litellm.llms.bedrock.beta_headers_config import ( + BedrockAPI, + get_bedrock_beta_filter, +) from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import ModelResponse @@ -105,6 +109,9 @@ def transform_request( _anthropic_request.pop("stream", None) # Bedrock Invoke doesn't support output_format parameter _anthropic_request.pop("output_format", None) + # Bedrock doesn't support context_management as a body param; + # the feature is enabled via the anthropic-beta header instead + _anthropic_request.pop("context_management", None) if "anthropic_version" not in _anthropic_request: _anthropic_request["anthropic_version"] = self.anthropic_version @@ -132,10 +139,17 @@ def transform_request( if "opus-4" in model.lower() or "opus_4" in model.lower(): beta_set.add("tool-search-tool-2025-10-19") - # Filter out beta headers that Bedrock Invoke doesn't support - 
# Uses centralized configuration from anthropic_beta_headers_config.json - beta_list = list(beta_set) - _anthropic_request["anthropic_beta"] = beta_list + # Filter beta headers using centralized whitelist with model-specific support + # AWS Bedrock only supports a specific whitelist of beta flags + # Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html + beta_filter = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + beta_list = beta_filter.filter_beta_headers( + list(beta_set), model, translate=False + ) + beta_set = set(beta_list) + + if beta_set: + _anthropic_request["anthropic_beta"] = list(beta_set) return _anthropic_request diff --git a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py index 03885ff2080..e4acd01c2c7 100644 --- a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py @@ -23,6 +23,10 @@ from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import ( AmazonInvokeConfig, ) +from litellm.llms.bedrock.beta_headers_config import ( + BedrockAPI, + get_bedrock_beta_filter, +) from litellm.llms.bedrock.common_utils import ( get_anthropic_beta_from_headers, is_claude_4_5_on_bedrock, @@ -53,8 +57,7 @@ class AmazonAnthropicClaudeMessagesConfig( DEFAULT_BEDROCK_ANTHROPIC_API_VERSION = "bedrock-2023-05-31" - # Beta header patterns that are not supported by Bedrock Invoke API - # These will be filtered out to prevent 400 "invalid beta flag" errors + # Beta header filtering is now handled by centralized beta_headers_config module def __init__(self, **kwargs): BaseAnthropicMessagesConfig.__init__(self, **kwargs) @@ -402,6 +405,10 @@ def transform_anthropic_messages_request( 
anthropic_messages_request=anthropic_messages_request, ) + # 5b. Remove `context_management` from request body (Bedrock doesn't support it as a body param; + # the feature is enabled via the anthropic-beta header instead) + anthropic_messages_request.pop("context_management", None) + # 6. AUTO-INJECT beta headers based on features used anthropic_model_info = AnthropicModelInfo() tools = anthropic_messages_optional_request_params.get("tools") @@ -433,11 +440,15 @@ def transform_anthropic_messages_request( beta_set=beta_set, ) - if "tool-search-tool-2025-10-19" in beta_set: - beta_set.add("tool-examples-2025-10-29") - - if beta_set: - anthropic_messages_request["anthropic_beta"] = list(beta_set) + # Filter beta headers using centralized whitelist with model-specific support and translation + # This handles advanced-tool-use translation and version/family restrictions + beta_filter = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + filtered_betas = beta_filter.filter_beta_headers( + list(beta_set), model, translate=True + ) + + if filtered_betas: + anthropic_messages_request["anthropic_beta"] = filtered_betas return anthropic_messages_request diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 6800dff55ac..f523d4f8685 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -90,13 +90,13 @@ def _redact_pii_matches(response_json: dict) -> dict: # Redact PII entities in sensitive information policy sensitive_info_policy = assessment.get("sensitiveInformationPolicy") if sensitive_info_policy: - pii_entities = sensitive_info_policy.get("piiEntities", []) + pii_entities = sensitive_info_policy.get("piiEntities", []) or [] for pii_entity in pii_entities: if "match" in pii_entity: pii_entity["match"] = "[REDACTED]" # Redact regex matches - regexes = sensitive_info_policy.get("regexes", 
[]) + regexes = sensitive_info_policy.get("regexes", []) or [] for regex_match in regexes: if "match" in regex_match: regex_match["match"] = "[REDACTED]" @@ -104,12 +104,12 @@ def _redact_pii_matches(response_json: dict) -> dict: # Redact custom word matches in word policy word_policy = assessment.get("wordPolicy") if word_policy: - custom_words = word_policy.get("customWords", []) + custom_words = word_policy.get("customWords", []) or [] for custom_word in custom_words: if "match" in custom_word: custom_word["match"] = "[REDACTED]" - managed_words = word_policy.get("managedWordLists", []) + managed_words = word_policy.get("managedWordLists", []) or [] for managed_word in managed_words: if "match" in managed_word: managed_word["match"] = "[REDACTED]" @@ -689,7 +689,7 @@ def _should_raise_guardrail_blocked_exception( # Check topic policy topic_policy = assessment.get("topicPolicy") if topic_policy: - topics = topic_policy.get("topics", []) + topics = topic_policy.get("topics", []) or [] for topic in topics: if topic.get("action") == "BLOCKED": return True @@ -697,7 +697,7 @@ def _should_raise_guardrail_blocked_exception( # Check content policy content_policy = assessment.get("contentPolicy") if content_policy: - filters = content_policy.get("filters", []) + filters = content_policy.get("filters", []) or [] for filter_item in filters: if filter_item.get("action") == "BLOCKED": return True @@ -705,11 +705,11 @@ def _should_raise_guardrail_blocked_exception( # Check word policy word_policy = assessment.get("wordPolicy") if word_policy: - custom_words = word_policy.get("customWords", []) + custom_words = word_policy.get("customWords", []) or [] for custom_word in custom_words: if custom_word.get("action") == "BLOCKED": return True - managed_words = word_policy.get("managedWordLists", []) + managed_words = word_policy.get("managedWordLists", []) or [] for managed_word in managed_words: if managed_word.get("action") == "BLOCKED": return True @@ -717,21 +717,19 @@ def 
_should_raise_guardrail_blocked_exception( # Check sensitive information policy sensitive_info_policy = assessment.get("sensitiveInformationPolicy") if sensitive_info_policy: - pii_entities = sensitive_info_policy.get("piiEntities", []) - if pii_entities: - for pii_entity in pii_entities: - if pii_entity.get("action") == "BLOCKED": - return True - regexes = sensitive_info_policy.get("regexes", []) - if regexes: - for regex in regexes: - if regex.get("action") == "BLOCKED": - return True + pii_entities = sensitive_info_policy.get("piiEntities", []) or [] + for pii_entity in pii_entities: + if pii_entity.get("action") == "BLOCKED": + return True + regexes = sensitive_info_policy.get("regexes", []) or [] + for regex in regexes: + if regex.get("action") == "BLOCKED": + return True # Check contextual grounding policy contextual_grounding_policy = assessment.get("contextualGroundingPolicy") if contextual_grounding_policy: - grounding_filters = contextual_grounding_policy.get("filters", []) + grounding_filters = contextual_grounding_policy.get("filters", []) or [] for grounding_filter in grounding_filters: if grounding_filter.get("action") == "BLOCKED": return True diff --git a/litellm/router.py b/litellm/router.py index 9c821b11fbd..f5f7414ce4f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -6297,6 +6297,11 @@ def _add_deployment(self, deployment: Deployment) -> Deployment: ): raise Exception(f"Unsupported provider - {custom_llm_provider}") + # Store custom_llm_provider in litellm_params so it's available to callbacks + # after alias resolution (e.g., websearch_interception pre-call hooks) + if custom_llm_provider: + deployment.litellm_params.custom_llm_provider = custom_llm_provider + #### DEPLOYMENT NAMES INIT ######## self.deployment_names.append(deployment.litellm_params.model) ############ Users can either pass tpm/rpm as a litellm_param or a router param ########### diff --git a/litellm/utils.py b/litellm/utils.py index 6a18fcc9e35..85d5e2d1fb0 100644 
--- a/litellm/utils.py +++ b/litellm/utils.py @@ -7512,6 +7512,33 @@ def has_tool_call_blocks(messages: List[AllMessageValues]) -> bool: return False +def _message_has_thinking_blocks(message: AllMessageValues) -> bool: + """ + Check if a single assistant message has thinking blocks. + + Checks both the 'thinking_blocks' field (LiteLLM/OpenAI format) and + the 'content' array for thinking/redacted_thinking blocks (Anthropic format). + """ + # Check thinking_blocks field (LiteLLM/OpenAI format) + thinking_blocks = message.get("thinking_blocks") + if thinking_blocks is not None and ( + not hasattr(thinking_blocks, "__len__") or len(thinking_blocks) > 0 + ): + return True + + # Check content array for thinking blocks (Anthropic format) + content = message.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") in ( + "thinking", + "redacted_thinking", + ): + return True + + return False + + def any_assistant_message_has_thinking_blocks( messages: List[AllMessageValues], ) -> bool: @@ -7527,10 +7554,7 @@ def any_assistant_message_has_thinking_blocks( """ for message in messages: if message.get("role") == "assistant": - thinking_blocks = message.get("thinking_blocks") - if thinking_blocks is not None and ( - not hasattr(thinking_blocks, "__len__") or len(thinking_blocks) > 0 - ): + if _message_has_thinking_blocks(message): return True return False @@ -7562,11 +7586,40 @@ def last_assistant_with_tool_calls_has_no_thinking_blocks( if last_assistant_with_tools is None: return False - # Check if it has thinking_blocks - thinking_blocks = last_assistant_with_tools.get("thinking_blocks") - return thinking_blocks is None or ( - hasattr(thinking_blocks, "__len__") and len(thinking_blocks) == 0 - ) + return not _message_has_thinking_blocks(last_assistant_with_tools) + + +def last_assistant_message_has_no_thinking_blocks( + messages: List[AllMessageValues], +) -> bool: + """ + Returns true if the last assistant 
message has content but no thinking_blocks. + + This is used to detect when thinking param should be dropped to avoid + Anthropic error: "Expected thinking or redacted_thinking, but found text" + + When thinking is enabled, ALL assistant messages must start with thinking_blocks. + If the client didn't preserve thinking_blocks, we need to drop the thinking param. + + IMPORTANT: This should only be used in conjunction with + any_assistant_message_has_thinking_blocks() to ensure we don't drop thinking + when other messages in the conversation contain thinking blocks. + """ + # Find the last assistant message + last_assistant = None + for message in messages: + if message.get("role") == "assistant": + last_assistant = message + + if last_assistant is None: + return False + + # Only flag if message has content (empty messages aren't an issue) + content = last_assistant.get("content") + if not content: + return False + + return not _message_has_thinking_blocks(last_assistant) def add_dummy_tool(custom_llm_provider: str) -> List[ChatCompletionToolParam]: diff --git a/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_handler.py b/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_handler.py index 5abecb46c99..69428a510e0 100644 --- a/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_handler.py +++ b/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_handler.py @@ -4,16 +4,110 @@ Tests the WebSearchInterceptionLogger class and helper functions. 
""" +import asyncio from unittest.mock import Mock -import pytest +from litellm.constants import LITELLM_WEB_SEARCH_TOOL_NAME from litellm.integrations.websearch_interception.handler import ( WebSearchInterceptionLogger, ) +from litellm.integrations.websearch_interception.tools import ( + get_litellm_web_search_tool, + is_web_search_tool, +) +from litellm.litellm_core_utils.core_helpers import filter_internal_params from litellm.types.utils import LlmProviders +class TestIsWebSearchTool: + """Tests for is_web_search_tool() helper function""" + + def test_litellm_standard_tool(self): + """Should detect LiteLLM standard web search tool""" + tool = {"name": LITELLM_WEB_SEARCH_TOOL_NAME} + assert is_web_search_tool(tool) is True + + def test_anthropic_native_web_search(self): + """Should detect Anthropic native web_search_* type""" + tool = {"type": "web_search_20250305", "name": "web_search"} + assert is_web_search_tool(tool) is True + + def test_anthropic_native_future_version(self): + """Should detect future versions of Anthropic web_search type""" + tool = {"type": "web_search_20260101", "name": "web_search"} + assert is_web_search_tool(tool) is True + + def test_claude_code_web_search(self): + """Should detect Claude Code's web_search with type field""" + tool = {"name": "web_search", "type": "web_search_20250305"} + assert is_web_search_tool(tool) is True + + def test_legacy_websearch_format(self): + """Should detect legacy WebSearch format""" + tool = {"name": "WebSearch"} + assert is_web_search_tool(tool) is True + + def test_non_websearch_tool(self): + """Should not detect non-web-search tools""" + assert is_web_search_tool({"name": "calculator"}) is False + assert is_web_search_tool({"name": "read_file"}) is False + assert is_web_search_tool({"type": "function", "name": "search"}) is False + + def test_web_search_name_without_type(self): + """Should NOT detect 'web_search' name without type field (could be custom tool)""" + tool = {"name": "web_search"} # No 
type field + assert is_web_search_tool(tool) is False + + +class TestGetLitellmWebSearchTool: + """Tests for get_litellm_web_search_tool() helper function""" + + def test_returns_valid_tool_definition(self): + """Should return a valid tool definition""" + tool = get_litellm_web_search_tool() + + assert tool["name"] == LITELLM_WEB_SEARCH_TOOL_NAME + assert "description" in tool + assert "input_schema" in tool + assert tool["input_schema"]["type"] == "object" + assert "query" in tool["input_schema"]["properties"] + + +class TestWebSearchInterceptionLoggerInit: + """Tests for WebSearchInterceptionLogger initialization""" + + def test_default_initialization(self): + """Test default initialization with no parameters""" + logger = WebSearchInterceptionLogger() + + # Default should have bedrock enabled + assert "bedrock" in logger.enabled_providers + assert logger.search_tool_name is None + + def test_custom_providers(self): + """Test initialization with custom providers""" + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock", "vertex_ai", "openai"]) + + assert "bedrock" in logger.enabled_providers + assert "vertex_ai" in logger.enabled_providers + assert "openai" in logger.enabled_providers + + def test_custom_search_tool_name(self): + """Test initialization with custom search tool name""" + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"], search_tool_name="custom-search-tool") + + assert logger.search_tool_name == "custom-search-tool" + + def test_llm_providers_enum_conversion(self): + """Test that LlmProviders enum values are converted to strings""" + logger = WebSearchInterceptionLogger(enabled_providers=[LlmProviders.BEDROCK, LlmProviders.VERTEX_AI]) + + # Should be stored as string values + assert "bedrock" in logger.enabled_providers + assert "vertex_ai" in logger.enabled_providers + + def test_initialize_from_proxy_config(): """Test initialization from proxy config with litellm_settings""" litellm_settings = { @@ -34,50 +128,221 
@@ def test_initialize_from_proxy_config(): assert logger.search_tool_name == "my-search" -@pytest.mark.asyncio -async def test_async_should_run_agentic_loop(): - """Test that agentic loop is NOT triggered for wrong provider or missing WebSearch tool""" - logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) - - # Test 1: Wrong provider (not in enabled_providers) - response = Mock() - should_run, tools_dict = await logger.async_should_run_agentic_loop( - response=response, - model="gpt-4", - messages=[], - tools=[{"name": "WebSearch"}], - stream=False, - custom_llm_provider="openai", # Not in enabled_providers - kwargs={}, - ) +def test_initialize_from_proxy_config_defaults(): + """Test initialization from proxy config with defaults when params missing""" + litellm_settings = {} + callback_specific_params = {} - assert should_run is False - assert tools_dict == {} - - # Test 2: No WebSearch tool in request - should_run, tools_dict = await logger.async_should_run_agentic_loop( - response=response, - model="bedrock/claude", - messages=[], - tools=[{"name": "SomeOtherTool"}], # No WebSearch - stream=False, - custom_llm_provider="bedrock", - kwargs={}, + logger = WebSearchInterceptionLogger.initialize_from_proxy_config( + litellm_settings=litellm_settings, + callback_specific_params=callback_specific_params, ) - assert should_run is False - assert tools_dict == {} + # Should use default bedrock provider + assert "bedrock" in logger.enabled_providers + + +def test_async_should_run_agentic_loop_wrong_provider(): + """Test that agentic loop is NOT triggered for wrong provider""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + response = Mock() + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="gpt-4", + messages=[], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + stream=False, + custom_llm_provider="openai", # Not in enabled_providers + 
kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + asyncio.run(_test()) + + +def test_async_should_run_agentic_loop_no_websearch_tool(): + """Test that agentic loop is NOT triggered when no WebSearch tool in request""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + response = Mock() + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=[{"name": "calculator"}], # No WebSearch tool + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + asyncio.run(_test()) + + +def test_async_should_run_agentic_loop_no_websearch_in_response(): + """Test that agentic loop is NOT triggered when response has no WebSearch tool_use""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + # Response with text only, no tool_use + response = {"content": [{"type": "text", "text": "I don't need to search for this."}]} + + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + asyncio.run(_test()) + + +def test_async_should_run_agentic_loop_positive_case(): + """Test that agentic loop IS triggered when WebSearch tool_use in response""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + # Response with WebSearch tool_use + response = { + "content": [ + { + "type": "tool_use", + "id": "tool_123", + "name": "WebSearch", + "input": {"query": "weather in SF"}, + } + ] + } + + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + 
model="bedrock/us.anthropic.claude-opus-4-5-20251101-v1:0", + messages=[], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is True + assert "tool_calls" in tools_dict + assert len(tools_dict["tool_calls"]) == 1 + assert tools_dict["tool_calls"][0]["id"] == "tool_123" + assert tools_dict["tool_type"] == "websearch" + + asyncio.run(_test()) + + +def test_async_should_run_agentic_loop_includes_thinking_blocks(): + """Test that thinking blocks are captured in tools_dict""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + # Response with thinking block and WebSearch tool_use + response = { + "content": [ + { + "type": "thinking", + "thinking": "Let me search for the weather...", + }, + { + "type": "tool_use", + "id": "tool_456", + "name": "WebSearch", + "input": {"query": "current weather SF"}, + }, + ] + } + + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is True + assert "thinking_blocks" in tools_dict + assert len(tools_dict["thinking_blocks"]) == 1 + assert tools_dict["thinking_blocks"][0]["type"] == "thinking" + + asyncio.run(_test()) + + +def test_async_should_run_agentic_loop_empty_tools_list(): + """Test with empty tools list""" + + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + response = Mock() + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=[], # Empty tools list + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + asyncio.run(_test()) + +def 
test_async_should_run_agentic_loop_none_tools(): + """Test with None tools""" -@pytest.mark.asyncio -async def test_internal_flags_filtered_from_followup_kwargs(): + async def _test(): + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + response = Mock() + should_run, tools_dict = await logger.async_should_run_agentic_loop( + response=response, + model="bedrock/claude", + messages=[], + tools=None, # None tools + stream=False, + custom_llm_provider="bedrock", + kwargs={}, + ) + + assert should_run is False + assert tools_dict == {} + + asyncio.run(_test()) + + +def test_internal_flags_filtered_from_followup_kwargs(): """Test that internal _websearch_interception flags are filtered from follow-up request kwargs. Regression test for bug where _websearch_interception_converted_stream was passed to the follow-up LLM request, causing "Extra inputs are not permitted" errors from providers like Bedrock that use strict parameter validation. """ - logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) # Simulate kwargs that would be passed during agentic loop execution kwargs_with_internal_flags = { @@ -87,11 +352,8 @@ async def test_internal_flags_filtered_from_followup_kwargs(): "max_tokens": 1024, } - # Apply the same filtering logic used in _execute_agentic_loop - kwargs_for_followup = { - k: v for k, v in kwargs_with_internal_flags.items() - if not k.startswith('_websearch_interception') - } + # Use the actual filter function from the codebase + kwargs_for_followup = filter_internal_params(kwargs_with_internal_flags) # Verify internal flags are filtered out assert "_websearch_interception_converted_stream" not in kwargs_for_followup @@ -100,3 +362,276 @@ async def test_internal_flags_filtered_from_followup_kwargs(): # Verify regular kwargs are preserved assert kwargs_for_followup["temperature"] == 0.7 assert kwargs_for_followup["max_tokens"] == 1024 + + +def test_duplicate_kwargs_filtered_from_followup(): + """Test that kwargs already 
in optional_params are deduplicated before follow-up request. + + Regression test for bug where context_management appeared in both + optional_params and kwargs, causing: "got multiple values for keyword argument 'context_management'" + """ + + optional_params_without_max_tokens = { + "thinking": {"type": "enabled", "budget_tokens": 5000}, + "context_management": {"type": "automatic", "max_context_tokens": 50000}, + "temperature": 0.7, + } + + kwargs_for_followup = { + "context_management": {"type": "automatic", "max_context_tokens": 50000}, + "some_other_kwarg": "value", + "max_tokens": 1024, + "model": "claude-opus-4-6", + "messages": [{"role": "user", "content": "test"}], + } + + # Apply the same dedup logic used in _execute_agentic_loop + explicit_keys = {"max_tokens", "messages", "model"} + kwargs_for_followup = { + k: v for k, v in kwargs_for_followup.items() + if k not in optional_params_without_max_tokens and k not in explicit_keys + } + + # context_management should be removed (already in optional_params) + assert "context_management" not in kwargs_for_followup + # Explicit keys should be removed + assert "max_tokens" not in kwargs_for_followup + assert "model" not in kwargs_for_followup + assert "messages" not in kwargs_for_followup + # Non-duplicate kwargs should be preserved + assert kwargs_for_followup["some_other_kwarg"] == "value" + + +class TestExecuteSearchApiKeyExtraction: + """Tests for API key extraction from router's search_tools configuration. + + Verifies that _execute_search() correctly loads search_provider, api_key, + and api_base from the router's search_tools config. 
+ """ + + def test_extracts_credentials_from_named_search_tool(self): + """Should extract api_key, api_base, search_provider from a named search tool.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + logger = WebSearchInterceptionLogger( + enabled_providers=["bedrock"], + search_tool_name="my-perplexity", + ) + + mock_router = MagicMock() + mock_router.search_tools = [ + { + "search_tool_name": "my-perplexity", + "litellm_params": { + "search_provider": "perplexity", + "api_key": "pplx-secret-key", + "api_base": "https://custom.perplexity.ai", + }, + } + ] + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.litellm.asearch", + new_callable=AsyncMock, + ) as mock_asearch: + mock_asearch.return_value = MagicMock(results=[]) + + mock_proxy = MagicMock() + mock_proxy.llm_router = mock_router + with patch.dict("sys.modules", {"litellm.proxy.proxy_server": mock_proxy}): + await logger._execute_search("test query") + + mock_asearch.assert_called_once_with( + query="test query", + search_provider="perplexity", + api_key="pplx-secret-key", + api_base="https://custom.perplexity.ai", + ) + + asyncio.run(_test()) + + def test_falls_back_to_first_search_tool(self): + """Should fall back to first available search tool when named tool not found.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + logger = WebSearchInterceptionLogger( + enabled_providers=["bedrock"], + search_tool_name="nonexistent-tool", + ) + + mock_router = MagicMock() + mock_router.search_tools = [ + { + "search_tool_name": "default-search", + "litellm_params": { + "search_provider": "tavily", + "api_key": "tvly-fallback-key", + }, + } + ] + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.litellm.asearch", + new_callable=AsyncMock, + ) as mock_asearch: + mock_asearch.return_value = MagicMock(results=[]) + + mock_proxy = MagicMock() + mock_proxy.llm_router = mock_router + 
with patch.dict("sys.modules", {"litellm.proxy.proxy_server": mock_proxy}): + await logger._execute_search("test query") + + mock_asearch.assert_called_once_with( + query="test query", + search_provider="tavily", + api_key="tvly-fallback-key", + api_base=None, + ) + + asyncio.run(_test()) + + def test_falls_back_to_perplexity_when_no_router(self): + """Should fall back to perplexity when router has no search_tools.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + mock_router = MagicMock() + mock_router.search_tools = [] + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.litellm.asearch", + new_callable=AsyncMock, + ) as mock_asearch: + mock_asearch.return_value = MagicMock(results=[]) + + mock_proxy = MagicMock() + mock_proxy.llm_router = mock_router + with patch.dict("sys.modules", {"litellm.proxy.proxy_server": mock_proxy}): + await logger._execute_search("test query") + + mock_asearch.assert_called_once_with( + query="test query", + search_provider="perplexity", + api_key=None, + api_base=None, + ) + + asyncio.run(_test()) + + +class TestMaxTokensThinkingBudgetAdjustment: + """Tests for max_tokens adjustment when thinking budget is configured. + + Verifies that _execute_agentic_loop() adjusts max_tokens to satisfy + Anthropic's requirement: max_tokens > thinking.budget_tokens. 
+ """ + + def test_adjusts_max_tokens_when_less_than_budget(self): + """Should adjust max_tokens when max_tokens <= budget_tokens.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.constants import DEFAULT_MAX_TOKENS + + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.anthropic_messages.acreate", + new_callable=AsyncMock, + ) as mock_acreate: + mock_acreate.return_value = MagicMock() + + await logger._execute_agentic_loop( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "search for X"}], + tool_calls=[{"id": "t1", "name": "WebSearch", "input": {"query": "X"}}], + thinking_blocks=[], + anthropic_messages_optional_request_params={ + "max_tokens": 1024, + "thinking": {"type": "enabled", "budget_tokens": 5000}, + }, + logging_obj=None, + stream=False, + kwargs={}, + ) + + # max_tokens should be adjusted to budget_tokens + DEFAULT_MAX_TOKENS + call_kwargs = mock_acreate.call_args + assert call_kwargs.kwargs["max_tokens"] == 5000 + DEFAULT_MAX_TOKENS + + asyncio.run(_test()) + + def test_preserves_max_tokens_when_greater_than_budget(self): + """Should NOT adjust max_tokens when max_tokens > budget_tokens.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.anthropic_messages.acreate", + new_callable=AsyncMock, + ) as mock_acreate: + mock_acreate.return_value = MagicMock() + + await logger._execute_agentic_loop( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "search for X"}], + tool_calls=[{"id": "t1", "name": "WebSearch", "input": {"query": "X"}}], + thinking_blocks=[], + anthropic_messages_optional_request_params={ + "max_tokens": 16000, + "thinking": {"type": "enabled", "budget_tokens": 5000}, 
+ }, + logging_obj=None, + stream=False, + kwargs={}, + ) + + # max_tokens should remain unchanged + call_kwargs = mock_acreate.call_args + assert call_kwargs.kwargs["max_tokens"] == 16000 + + asyncio.run(_test()) + + def test_no_adjustment_without_thinking(self): + """Should NOT adjust max_tokens when thinking is not enabled.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock, patch + + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + async def _test(): + with patch( + "litellm.integrations.websearch_interception.handler.anthropic_messages.acreate", + new_callable=AsyncMock, + ) as mock_acreate: + mock_acreate.return_value = MagicMock() + + await logger._execute_agentic_loop( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "search for X"}], + tool_calls=[{"id": "t1", "name": "WebSearch", "input": {"query": "X"}}], + thinking_blocks=[], + anthropic_messages_optional_request_params={ + "max_tokens": 1024, + }, + logging_obj=None, + stream=False, + kwargs={}, + ) + + # max_tokens should remain at original value + call_kwargs = mock_acreate.call_args + assert call_kwargs.kwargs["max_tokens"] == 1024 + + asyncio.run(_test()) diff --git a/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_transformation.py b/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_transformation.py new file mode 100644 index 00000000000..8c88cffb32e --- /dev/null +++ b/tests/test_litellm/integrations/websearch_interception/test_websearch_interception_transformation.py @@ -0,0 +1,427 @@ +""" +Unit tests for WebSearch Interception Transformation + +Tests the WebSearchTransformation class methods: +- transform_request: Extract WebSearch tool calls and thinking blocks from responses +- transform_response: Build follow-up messages with tool_use and tool_result blocks +- format_search_response: Format search results for tool_result content +""" + +from unittest.mock import 
Mock + + +from litellm.constants import LITELLM_WEB_SEARCH_TOOL_NAME +from litellm.integrations.websearch_interception.transformation import ( + WebSearchTransformation, +) + + +class TestTransformRequest: + """Tests for WebSearchTransformation.transform_request()""" + + def test_streaming_response_returns_empty(self): + """Streaming responses should return empty result (we handle by converting to non-streaming)""" + response = {"content": [{"type": "tool_use", "name": "WebSearch"}]} + + result = WebSearchTransformation.transform_request( + response=response, + stream=True, + ) + + assert result.has_websearch is False + assert result.tool_calls == [] + assert result.thinking_blocks == [] + + def test_dict_response_with_websearch_tool(self): + """Dict response with WebSearch tool_use should be detected""" + response = { + "content": [ + { + "type": "tool_use", + "id": "tool_123", + "name": "WebSearch", + "input": {"query": "weather in SF"}, + } + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["id"] == "tool_123" + assert result.tool_calls[0]["name"] == "WebSearch" + assert result.tool_calls[0]["input"]["query"] == "weather in SF" + + def test_object_response_with_websearch_tool(self): + """Object response (with attributes) with WebSearch tool_use should be detected""" + tool_block = Mock() + tool_block.type = "tool_use" + tool_block.id = "tool_456" + tool_block.name = "WebSearch" + tool_block.input = {"query": "latest news"} + + response = Mock() + response.content = [tool_block] + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["id"] == "tool_456" + assert result.tool_calls[0]["input"]["query"] == "latest news" + + def 
test_detects_litellm_web_search_tool_name(self): + """Should detect the LiteLLM standard web search tool name""" + response = { + "content": [ + { + "type": "tool_use", + "id": "tool_789", + "name": LITELLM_WEB_SEARCH_TOOL_NAME, + "input": {"query": "test query"}, + } + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert result.tool_calls[0]["name"] == LITELLM_WEB_SEARCH_TOOL_NAME + + def test_detects_web_search_lowercase(self): + """Should detect 'web_search' tool name (lowercase variant)""" + response = { + "content": [ + { + "type": "tool_use", + "id": "tool_abc", + "name": "web_search", + "input": {"query": "another query"}, + } + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert result.tool_calls[0]["name"] == "web_search" + + def test_captures_thinking_blocks(self): + """Should capture thinking blocks from response""" + response = { + "content": [ + { + "type": "thinking", + "thinking": "Let me search for this information...", + }, + { + "type": "tool_use", + "id": "tool_def", + "name": "WebSearch", + "input": {"query": "AI news"}, + }, + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert len(result.thinking_blocks) == 1 + assert result.thinking_blocks[0]["type"] == "thinking" + + def test_captures_redacted_thinking_blocks(self): + """Should capture redacted_thinking blocks from response""" + response = { + "content": [ + { + "type": "redacted_thinking", + "data": "base64redacteddata", + }, + { + "type": "tool_use", + "id": "tool_ghi", + "name": "WebSearch", + "input": {"query": "sensitive query"}, + }, + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert 
len(result.thinking_blocks) == 1 + assert result.thinking_blocks[0]["type"] == "redacted_thinking" + + def test_thinking_blocks_normalized_to_dict_from_sdk_objects(self): + """SDK object thinking blocks should be normalized to dicts for JSON serialization""" + # Create mock SDK objects (not dicts) + thinking_block = Mock() + thinking_block.type = "thinking" + thinking_block.thinking = "Let me search for this..." + thinking_block.signature = "sig123" + + redacted_block = Mock() + redacted_block.type = "redacted_thinking" + redacted_block.data = "base64redacteddata" + + tool_block = Mock() + tool_block.type = "tool_use" + tool_block.id = "tool_xyz" + tool_block.name = "WebSearch" + tool_block.input = {"query": "test"} + + response = Mock() + response.content = [thinking_block, redacted_block, tool_block] + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert len(result.thinking_blocks) == 2 + + # Verify blocks were normalized to dicts (not Mock objects) + assert isinstance(result.thinking_blocks[0], dict) + assert isinstance(result.thinking_blocks[1], dict) + + # Verify thinking block content including signature + assert result.thinking_blocks[0]["type"] == "thinking" + assert result.thinking_blocks[0]["thinking"] == "Let me search for this..." 
+ assert result.thinking_blocks[0]["signature"] == "sig123" + + # Verify redacted_thinking block content + assert result.thinking_blocks[1]["type"] == "redacted_thinking" + assert result.thinking_blocks[1]["data"] == "base64redacteddata" + + def test_multiple_tool_calls(self): + """Should handle multiple WebSearch tool_use blocks""" + response = { + "content": [ + { + "type": "tool_use", + "id": "tool_1", + "name": "WebSearch", + "input": {"query": "query 1"}, + }, + { + "type": "tool_use", + "id": "tool_2", + "name": "WebSearch", + "input": {"query": "query 2"}, + }, + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is True + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["input"]["query"] == "query 1" + assert result.tool_calls[1]["input"]["query"] == "query 2" + + def test_no_websearch_in_response(self): + """Response without WebSearch tool should return has_websearch=False""" + response = { + "content": [ + {"type": "text", "text": "Here is a response"}, + { + "type": "tool_use", + "id": "tool_other", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + ] + } + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is False + assert result.tool_calls == [] + + def test_empty_content_returns_empty_result(self): + """Empty content should return empty result""" + response = {"content": []} + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + assert result.has_websearch is False + assert result.tool_calls == [] + assert result.thinking_blocks == [] + + def test_response_without_content_attribute(self): + """Response object without content attribute should return empty result""" + response = Mock(spec=[]) # Mock with no attributes + + result = WebSearchTransformation.transform_request( + response=response, + stream=False, + ) + + 
assert result.has_websearch is False + + +class TestTransformResponse: + """Tests for WebSearchTransformation.transform_response()""" + + def test_builds_messages_without_thinking_blocks(self): + """Should build correct messages without thinking blocks""" + tool_calls = [ + { + "id": "tool_1", + "name": "WebSearch", + "input": {"query": "test query"}, + } + ] + search_results = ["Title: Result\nURL: http://example.com\nSnippet: Test snippet"] + thinking_blocks = [] + + assistant_msg, user_msg = WebSearchTransformation.transform_response( + tool_calls=tool_calls, + search_results=search_results, + thinking_blocks=thinking_blocks, + ) + + # Check assistant message + assert assistant_msg["role"] == "assistant" + assert len(assistant_msg["content"]) == 1 + assert assistant_msg["content"][0]["type"] == "tool_use" + assert assistant_msg["content"][0]["id"] == "tool_1" + assert assistant_msg["content"][0]["name"] == "WebSearch" + + # Check user message + assert user_msg["role"] == "user" + assert len(user_msg["content"]) == 1 + assert user_msg["content"][0]["type"] == "tool_result" + assert user_msg["content"][0]["tool_use_id"] == "tool_1" + assert "Test snippet" in user_msg["content"][0]["content"] + + def test_builds_messages_with_thinking_blocks(self): + """Should include thinking blocks at start of assistant message""" + tool_calls = [ + { + "id": "tool_2", + "name": "WebSearch", + "input": {"query": "another query"}, + } + ] + search_results = ["Search result text"] + thinking_blocks = [{"type": "thinking", "thinking": "I need to search for this..."}] + + assistant_msg, user_msg = WebSearchTransformation.transform_response( + tool_calls=tool_calls, + search_results=search_results, + thinking_blocks=thinking_blocks, + ) + + # Check assistant message has thinking block first, then tool_use + assert len(assistant_msg["content"]) == 2 + assert assistant_msg["content"][0]["type"] == "thinking" + assert assistant_msg["content"][1]["type"] == "tool_use" + + def 
test_multiple_tool_calls_and_results(self): + """Should handle multiple tool calls and their results""" + tool_calls = [ + {"id": "tool_a", "name": "WebSearch", "input": {"query": "q1"}}, + {"id": "tool_b", "name": "WebSearch", "input": {"query": "q2"}}, + ] + search_results = ["Result A", "Result B"] + thinking_blocks = [] + + assistant_msg, user_msg = WebSearchTransformation.transform_response( + tool_calls=tool_calls, + search_results=search_results, + thinking_blocks=thinking_blocks, + ) + + # Check tool_use blocks in assistant message + assert len(assistant_msg["content"]) == 2 + assert assistant_msg["content"][0]["id"] == "tool_a" + assert assistant_msg["content"][1]["id"] == "tool_b" + + # Check tool_result blocks in user message + assert len(user_msg["content"]) == 2 + assert user_msg["content"][0]["tool_use_id"] == "tool_a" + assert user_msg["content"][0]["content"] == "Result A" + assert user_msg["content"][1]["tool_use_id"] == "tool_b" + assert user_msg["content"][1]["content"] == "Result B" + + +class TestFormatSearchResponse: + """Tests for WebSearchTransformation.format_search_response()""" + + def test_formats_search_response_with_results(self): + """Should format SearchResponse with results into readable text""" + # Create mock SearchResponse + result1 = Mock() + result1.title = "First Result" + result1.url = "https://example.com/1" + result1.snippet = "This is the first snippet." + + result2 = Mock() + result2.title = "Second Result" + result2.url = "https://example.com/2" + result2.snippet = "This is the second snippet." + + search_response = Mock() + search_response.results = [result1, result2] + + formatted = WebSearchTransformation.format_search_response(search_response) + + assert "Title: First Result" in formatted + assert "URL: https://example.com/1" in formatted + assert "Snippet: This is the first snippet." 
in formatted + assert "Title: Second Result" in formatted + + def test_formats_empty_results(self): + """Should handle SearchResponse with no results""" + search_response = Mock() + search_response.results = [] + + formatted = WebSearchTransformation.format_search_response(search_response) + + # Should fallback to str(result) + assert formatted # Not empty + + def test_formats_response_without_results_attribute(self): + """Should fallback to str() for responses without results attribute""" + + # Create a simple class without 'results' attribute that converts to string + class SimpleResponse: + def __str__(self): + return "Fallback string representation" + + search_response = SimpleResponse() + + formatted = WebSearchTransformation.format_search_response(search_response) + + # Should use str() fallback since no results attribute + assert formatted == "Fallback string representation" diff --git a/tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py b/tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py index 074a319a603..897a3d6c056 100644 --- a/tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py +++ b/tests/test_litellm/llms/bedrock/test_anthropic_beta_support.py @@ -51,15 +51,15 @@ def test_invoke_transformation_anthropic_beta(self): """Test that Invoke API transformation includes anthropic_beta in request.""" config = AmazonAnthropicClaudeConfig() headers = {"anthropic-beta": "context-1m-2025-08-07,computer-use-2024-10-22"} - + result = config.transform_request( - model="anthropic.claude-3-5-sonnet-20241022-v2:0", + model="anthropic.claude-opus-4-5-20250514-v1:0", messages=[{"role": "user", "content": "Test"}], optional_params={}, litellm_params={}, headers=headers ) - + assert "anthropic_beta" in result # Beta flags are stored as sets, so order may vary assert set(result["anthropic_beta"]) == {"context-1m-2025-08-07", "computer-use-2024-10-22"} @@ -68,15 +68,15 @@ def test_converse_transformation_anthropic_beta(self): """Test that 
Converse API transformation includes anthropic_beta in additionalModelRequestFields.""" config = AmazonConverseConfig() headers = {"anthropic-beta": "context-1m-2025-08-07,interleaved-thinking-2025-05-14"} - + result = config._transform_request_helper( - model="anthropic.claude-3-5-sonnet-20241022-v2:0", + model="anthropic.claude-opus-4-5-20250514-v1:0", system_content_blocks=[], optional_params={}, messages=[{"role": "user", "content": "Test"}], headers=headers ) - + assert "additionalModelRequestFields" in result additional_fields = result["additionalModelRequestFields"] assert "anthropic_beta" in additional_fields @@ -104,7 +104,7 @@ def test_converse_computer_use_compatibility(self): """Test that user anthropic_beta headers work with computer use tools.""" config = AmazonConverseConfig() headers = {"anthropic-beta": "context-1m-2025-08-07"} - + # Computer use tools should automatically add computer-use-2024-10-22 tools = [ { @@ -114,28 +114,29 @@ def test_converse_computer_use_compatibility(self): "display_height_px": 768 } ] - + result = config._transform_request_helper( - model="anthropic.claude-3-5-sonnet-20241022-v2:0", + model="anthropic.claude-opus-4-5-20250514-v1:0", system_content_blocks=[], optional_params={"tools": tools}, messages=[{"role": "user", "content": "Test"}], headers=headers ) - + additional_fields = result["additionalModelRequestFields"] betas = additional_fields["anthropic_beta"] - + # Should contain both user-provided and auto-added beta headers assert "context-1m-2025-08-07" in betas - assert "computer-use-2024-10-22" in betas + # Opus 4.5 gets computer-use-2025-11-24 (not the older 2024-10-22) + assert "computer-use-2025-11-24" in betas assert len(betas) == 2 # No duplicates def test_no_anthropic_beta_headers(self): """Test that transformations work correctly when no anthropic_beta headers are provided.""" config = AmazonConverseConfig() headers = {} - + result = config._transform_request_helper( 
model="anthropic.claude-3-5-sonnet-20241022-v2:0", system_content_blocks=[], @@ -143,10 +144,323 @@ def test_no_anthropic_beta_headers(self): messages=[{"role": "user", "content": "Test"}], headers=headers ) - + additional_fields = result.get("additionalModelRequestFields", {}) assert "anthropic_beta" not in additional_fields + +class TestBedrockBetaHeaderFiltering: + """Test centralized beta header filtering across all Bedrock APIs.""" + + def test_invoke_chat_filters_unsupported_headers(self): + """Test that Invoke Chat API filters out unsupported beta headers.""" + config = AmazonAnthropicClaudeConfig() + headers = { + "anthropic-beta": "computer-use-2025-01-24,unknown-beta-2099-01-01,context-1m-2025-08-07" + } + + result = config.transform_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + + assert "anthropic_beta" in result + beta_set = set(result["anthropic_beta"]) + + # Should keep supported headers + assert "computer-use-2025-01-24" in beta_set + assert "context-1m-2025-08-07" in beta_set + + # Should filter out unsupported header + assert "unknown-beta-2099-01-01" not in beta_set + + def test_converse_filters_unsupported_headers(self): + """Test that Converse API filters out unsupported beta headers.""" + config = AmazonConverseConfig() + headers = { + "anthropic-beta": "interleaved-thinking-2025-05-14,unknown-beta-2099-01-01" + } + + result = config._transform_request_helper( + model="anthropic.claude-opus-4-5-20250514-v1:0", + system_content_blocks=[], + optional_params={}, + messages=[{"role": "user", "content": "Test"}], + headers=headers, + ) + + additional_fields = result["additionalModelRequestFields"] + beta_list = additional_fields["anthropic_beta"] + + # Should keep supported header + assert "interleaved-thinking-2025-05-14" in beta_list + + # Should filter out unsupported header + assert "unknown-beta-2099-01-01" not in 
beta_list + + def test_messages_filters_unsupported_headers(self): + """Test that Messages API filters out unsupported beta headers.""" + config = AmazonAnthropicClaudeMessagesConfig() + headers = { + "anthropic-beta": "output-128k-2025-02-19,unknown-beta-2099-01-01" + } + + result = config.transform_anthropic_messages_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers, + ) + + assert "anthropic_beta" in result + beta_list = result["anthropic_beta"] + + # Should keep supported header + assert "output-128k-2025-02-19" in beta_list + + # Should filter out unsupported header + assert "unknown-beta-2099-01-01" not in beta_list + + def test_version_based_filtering_thinking_headers(self): + """Test that thinking headers are filtered based on model version.""" + config = AmazonAnthropicClaudeConfig() + headers = {"anthropic-beta": "interleaved-thinking-2025-05-14"} + + # Claude 4.5 should support thinking + result_45 = config.transform_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + assert "anthropic_beta" in result_45 + assert "interleaved-thinking-2025-05-14" in result_45["anthropic_beta"] + + # Claude 3.5 should NOT support thinking + result_35 = config.transform_request( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + # Should either not have anthropic_beta or not contain thinking header + if "anthropic_beta" in result_35: + assert "interleaved-thinking-2025-05-14" not in result_35["anthropic_beta"] + + def test_family_restriction_effort_opus_only(self): + """Test that effort parameter only works on Opus 4.5+.""" + config = AmazonAnthropicClaudeConfig() + 
headers = {"anthropic-beta": "effort-2025-11-24"} + + # Opus 4.5 should support effort + result_opus = config.transform_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + assert "anthropic_beta" in result_opus + assert "effort-2025-11-24" in result_opus["anthropic_beta"] + + # Sonnet 4.5 should NOT support effort (wrong family) + result_sonnet = config.transform_request( + model="anthropic.claude-sonnet-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + # Should either not have anthropic_beta or not contain effort + if "anthropic_beta" in result_sonnet: + assert "effort-2025-11-24" not in result_sonnet["anthropic_beta"] + + def test_tool_search_family_restriction(self): + """Test that tool search works on Opus and Sonnet 4.5+, but not Haiku.""" + config = AmazonAnthropicClaudeConfig() + headers = {"anthropic-beta": "tool-search-tool-2025-10-19"} + + # Opus 4.5 should support tool search + result_opus = config.transform_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + assert "tool-search-tool-2025-10-19" in result_opus["anthropic_beta"] + + # Sonnet 4.5 should support tool search + result_sonnet = config.transform_request( + model="anthropic.claude-sonnet-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + assert "tool-search-tool-2025-10-19" in result_sonnet["anthropic_beta"] + + # Haiku 4.5 should NOT support tool search (wrong family) + result_haiku = config.transform_request( + model="anthropic.claude-haiku-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) 
+ # Should either not have anthropic_beta or not contain tool search + if "anthropic_beta" in result_haiku: + assert "tool-search-tool-2025-10-19" not in result_haiku["anthropic_beta"] + + def test_messages_advanced_tool_use_translation(self): + """Test that Messages API translates advanced-tool-use to tool search headers.""" + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + # Opus 4.5 should translate advanced-tool-use + result = config.transform_anthropic_messages_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers, + ) + + assert "anthropic_beta" in result + beta_list = result["anthropic_beta"] + + # Should translate to tool search headers + assert "tool-search-tool-2025-10-19" in beta_list + assert "tool-examples-2025-10-29" in beta_list + + # Should NOT contain original advanced-tool-use header + assert "advanced-tool-use-2025-11-20" not in beta_list + + def test_messages_advanced_tool_use_no_translation_old_model(self): + """Test that advanced-tool-use is NOT translated on older models.""" + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + # Claude 4.0 should NOT translate (too old) + result = config.transform_anthropic_messages_request( + model="anthropic.claude-opus-4-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers, + ) + + # Should not have anthropic_beta or should be empty + # (advanced-tool-use is not in whitelist and shouldn't translate) + if "anthropic_beta" in result: + assert len(result["anthropic_beta"]) == 0 + + def test_messages_advanced_tool_use_no_translation_haiku(self): + """Test that advanced-tool-use is NOT translated on Haiku 
(wrong family).""" + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + # Haiku 4.5 should NOT translate (wrong family) + result = config.transform_anthropic_messages_request( + model="anthropic.claude-haiku-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers, + ) + + # Should not have anthropic_beta or should be empty + if "anthropic_beta" in result: + assert len(result["anthropic_beta"]) == 0 + + def test_cross_api_consistency(self): + """Test that same headers work consistently across all three APIs.""" + headers = {"anthropic-beta": "computer-use-2025-01-24,context-1m-2025-08-07"} + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + # Invoke Chat + config_invoke = AmazonAnthropicClaudeConfig() + result_invoke = config_invoke.transform_request( + model=model, + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + + # Converse + config_converse = AmazonConverseConfig() + result_converse = config_converse._transform_request_helper( + model=model, + system_content_blocks=[], + optional_params={}, + messages=[{"role": "user", "content": "Test"}], + headers=headers, + ) + + # Messages + config_messages = AmazonAnthropicClaudeMessagesConfig() + result_messages = config_messages.transform_anthropic_messages_request( + model=model, + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers, + ) + + # All should have the same beta headers + invoke_betas = set(result_invoke["anthropic_beta"]) + converse_betas = set( + result_converse["additionalModelRequestFields"]["anthropic_beta"] + ) + messages_betas = set(result_messages["anthropic_beta"]) + + assert invoke_betas == converse_betas == messages_betas + assert 
"computer-use-2025-01-24" in invoke_betas + assert "context-1m-2025-08-07" in invoke_betas + + def test_backward_compatibility_existing_headers(self): + """Test that all previously supported headers still work after migration.""" + config = AmazonAnthropicClaudeConfig() + + # Test all 11 core supported beta headers + all_headers = [ + "computer-use-2024-10-22", + "computer-use-2025-01-24", + "token-efficient-tools-2025-02-19", + "interleaved-thinking-2025-05-14", + "output-128k-2025-02-19", + "dev-full-thinking-2025-05-14", + "context-1m-2025-08-07", + "context-management-2025-06-27", + "effort-2025-11-24", + "tool-search-tool-2025-10-19", + "tool-examples-2025-10-29", + ] + + headers = {"anthropic-beta": ",".join(all_headers)} + + result = config.transform_request( + model="anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={}, + litellm_params={}, + headers=headers, + ) + + assert "anthropic_beta" in result + result_betas = set(result["anthropic_beta"]) + + # All headers should be present (Opus 4.5 supports all of them) + for header in all_headers: + assert header in result_betas, f"Header {header} was filtered out unexpectedly" + def test_anthropic_beta_all_supported_features(self): """Test that all documented beta features are properly handled.""" supported_features = [ @@ -158,18 +472,19 @@ def test_anthropic_beta_all_supported_features(self): "output-128k-2025-02-19", "dev-full-thinking-2025-05-14" ] - + config = AmazonAnthropicClaudeConfig() headers = {"anthropic-beta": ",".join(supported_features)} - + + # Use Claude 4.5+ model since several features require 4.0+ result = config.transform_request( - model="anthropic.claude-3-5-sonnet-20241022-v2:0", + model="anthropic.claude-opus-4-5-20250514-v1:0", messages=[{"role": "user", "content": "Test"}], optional_params={}, litellm_params={}, headers=headers ) - + assert "anthropic_beta" in result # Beta flags are stored as sets, so order may vary assert 
set(result["anthropic_beta"]) == set(supported_features) @@ -355,8 +670,8 @@ def test_converse_nova_model_no_anthropic_beta(self): def test_converse_anthropic_model_gets_anthropic_beta(self): """Test that Anthropic models DO get anthropic_beta in additionalModelRequestFields.""" config = AmazonConverseConfig() - headers = {"anthropic-beta": "context-1m-2025-08-07"} - + headers = {"anthropic-beta": "computer-use-2025-01-24"} + result = config._transform_request_helper( model="anthropic.claude-3-5-sonnet-20241022-v2:0", system_content_blocks=[], @@ -364,18 +679,18 @@ def test_converse_anthropic_model_gets_anthropic_beta(self): messages=[{"role": "user", "content": "Test"}], headers=headers ) - + additional_fields = result.get("additionalModelRequestFields", {}) assert "anthropic_beta" in additional_fields, ( "anthropic_beta SHOULD be added for Anthropic models." ) - assert "context-1m-2025-08-07" in additional_fields["anthropic_beta"] + assert "computer-use-2025-01-24" in additional_fields["anthropic_beta"] def test_converse_anthropic_model_with_cross_region_prefix(self): """Test that Anthropic models with cross-region prefix still get anthropic_beta.""" config = AmazonConverseConfig() - headers = {"anthropic-beta": "context-1m-2025-08-07"} - + headers = {"anthropic-beta": "computer-use-2025-01-24"} + # Model with 'us.' cross-region prefix result = config._transform_request_helper( model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", @@ -384,9 +699,174 @@ def test_converse_anthropic_model_with_cross_region_prefix(self): messages=[{"role": "user", "content": "Test"}], headers=headers ) - + additional_fields = result.get("additionalModelRequestFields", {}) assert "anthropic_beta" in additional_fields, ( "anthropic_beta SHOULD be added for Anthropic models with cross-region prefix." 
) - assert "context-1m-2025-08-07" in additional_fields["anthropic_beta"] \ No newline at end of file + assert "computer-use-2025-01-24" in additional_fields["anthropic_beta"] + + def test_messages_advanced_tool_use_translation_opus_4_5(self): + """Test that advanced-tool-use header is translated to Bedrock-specific headers for Opus 4.5. + + Regression test for: Claude Code sends advanced-tool-use-2025-11-20 header which needs + to be translated to tool-search-tool-2025-10-19 and tool-examples-2025-10-29 for + Bedrock Invoke API on Claude Opus 4.5. + + Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html + """ + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + result = config.transform_anthropic_messages_request( + model="us.anthropic.claude-opus-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers + ) + + assert "anthropic_beta" in result + beta_headers = result["anthropic_beta"] + + # advanced-tool-use should be removed + assert "advanced-tool-use-2025-11-20" not in beta_headers, ( + "advanced-tool-use-2025-11-20 should be removed for Bedrock Invoke API" + ) + + # Bedrock-specific headers should be added for Opus 4.5 + assert "tool-search-tool-2025-10-19" in beta_headers, ( + "tool-search-tool-2025-10-19 should be added for Opus 4.5" + ) + assert "tool-examples-2025-10-29" in beta_headers, ( + "tool-examples-2025-10-29 should be added for Opus 4.5" + ) + + def test_messages_advanced_tool_use_translation_sonnet_4_5(self): + """Test that advanced-tool-use header is translated to Bedrock-specific headers for Sonnet 4.5. 
+ + Regression test for: Claude Code sends advanced-tool-use-2025-11-20 header which needs + to be translated to tool-search-tool-2025-10-19 and tool-examples-2025-10-29 for + Bedrock Invoke API on Claude Sonnet 4.5. + + Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool + """ + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + result = config.transform_anthropic_messages_request( + model="us.anthropic.claude-sonnet-4-5-20250929-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers + ) + + assert "anthropic_beta" in result + beta_headers = result["anthropic_beta"] + + # advanced-tool-use should be removed + assert "advanced-tool-use-2025-11-20" not in beta_headers, ( + "advanced-tool-use-2025-11-20 should be removed for Bedrock Invoke API" + ) + + # Bedrock-specific headers should be added for Sonnet 4.5 + assert "tool-search-tool-2025-10-19" in beta_headers, ( + "tool-search-tool-2025-10-19 should be added for Sonnet 4.5" + ) + assert "tool-examples-2025-10-29" in beta_headers, ( + "tool-examples-2025-10-29 should be added for Sonnet 4.5" + ) + + def test_messages_advanced_tool_use_filtered_unsupported_model(self): + """Test that advanced-tool-use header is filtered out for models that don't support tool search. + + The translation to Bedrock-specific headers should only happen for models that + support tool search on Bedrock (Opus 4.5, Sonnet 4.5). + For other models, the advanced-tool-use header should just be removed. 
+ """ + config = AmazonAnthropicClaudeMessagesConfig() + headers = {"anthropic-beta": "advanced-tool-use-2025-11-20"} + + # Test with Claude 3.5 Sonnet (does NOT support tool search on Bedrock) + result = config.transform_anthropic_messages_request( + model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={"max_tokens": 100}, + litellm_params={}, + headers=headers + ) + + beta_headers = result.get("anthropic_beta", []) + + # advanced-tool-use should be removed + assert "advanced-tool-use-2025-11-20" not in beta_headers + + # Bedrock-specific headers should NOT be added for unsupported models + assert "tool-search-tool-2025-10-19" not in beta_headers + assert "tool-examples-2025-10-29" not in beta_headers + + +class TestContextManagementBodyParamStripping: + """Test that context_management is stripped from request body for Bedrock APIs. + + Bedrock doesn't support context_management as a request body parameter. + The feature is enabled via the anthropic-beta header instead. If left in the body, + Bedrock returns: 'context_management: Extra inputs are not permitted'. 
+ """ + + def test_messages_api_strips_context_management(self): + """Test that Messages API removes context_management from request body.""" + config = AmazonAnthropicClaudeMessagesConfig() + headers = {} + + result = config.transform_anthropic_messages_request( + model="anthropic.claude-sonnet-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + anthropic_messages_optional_request_params={ + "max_tokens": 100, + "context_management": {"type": "automatic", "max_context_tokens": 50000}, + }, + litellm_params={}, + headers=headers, + ) + + # context_management must NOT be in the request body + assert "context_management" not in result + + def test_invoke_chat_api_strips_context_management(self): + """Test that Invoke Chat API removes context_management from request body.""" + config = AmazonAnthropicClaudeConfig() + headers = {} + + result = config.transform_request( + model="anthropic.claude-sonnet-4-5-20250514-v1:0", + messages=[{"role": "user", "content": "Test"}], + optional_params={ + "context_management": {"type": "automatic", "max_context_tokens": 50000}, + }, + litellm_params={}, + headers=headers, + ) + + # context_management must NOT be in the request body + assert "context_management" not in result + + def test_converse_api_strips_context_management(self): + """Test that Converse API doesn't pass context_management in additionalModelRequestFields.""" + config = AmazonConverseConfig() + headers = {} + + result = config._transform_request_helper( + model="anthropic.claude-sonnet-4-5-20250514-v1:0", + system_content_blocks=[], + optional_params={ + "context_management": {"type": "automatic", "max_context_tokens": 50000}, + }, + messages=[{"role": "user", "content": "Test"}], + headers=headers, + ) + + additional_fields = result.get("additionalModelRequestFields", {}) + # context_management must NOT leak into additionalModelRequestFields + assert "context_management" not in additional_fields diff --git 
a/tests/test_litellm/llms/bedrock/test_beta_headers_config.py b/tests/test_litellm/llms/bedrock/test_beta_headers_config.py new file mode 100644 index 00000000000..59fa77fbb04 --- /dev/null +++ b/tests/test_litellm/llms/bedrock/test_beta_headers_config.py @@ -0,0 +1,482 @@ +""" +Comprehensive tests for Bedrock beta headers configuration. + +Tests the centralized whitelist-based filtering with version-based model support. +""" + +import pytest + +from litellm.llms.bedrock.beta_headers_config import ( + BEDROCK_CORE_SUPPORTED_BETAS, + BedrockAPI, + BedrockBetaHeaderFilter, + get_bedrock_beta_filter, +) + + +class TestBedrockBetaHeaderFilter: + """Test the BedrockBetaHeaderFilter class.""" + + def test_factory_function(self): + """Test factory function returns correct filter instance.""" + filter_chat = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + assert isinstance(filter_chat, BedrockBetaHeaderFilter) + assert filter_chat.api_type == BedrockAPI.INVOKE_CHAT + + filter_converse = get_bedrock_beta_filter(BedrockAPI.CONVERSE) + assert isinstance(filter_converse, BedrockBetaHeaderFilter) + assert filter_converse.api_type == BedrockAPI.CONVERSE + + def test_whitelist_filtering_basic(self): + """Test basic whitelist filtering keeps supported headers.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + # Supported header should pass through + result = filter_obj.filter_beta_headers( + ["computer-use-2025-01-24"], model, translate=False + ) + assert result == ["computer-use-2025-01-24"] + + def test_whitelist_filtering_blocks_unsupported(self): + """Test whitelist filtering blocks unsupported headers.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + # Unsupported header should be filtered out + result = filter_obj.filter_beta_headers( + ["unknown-beta-2099-01-01"], model, translate=False + ) + assert result == [] + + def 
test_whitelist_filtering_mixed_headers(self): + """Test filtering with mix of supported and unsupported headers.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + [ + "computer-use-2025-01-24", # Supported + "unknown-beta-2099-01-01", # Unsupported + "effort-2025-11-24", # Supported + ], + model, + translate=False, + ) + # Should only keep supported headers + assert set(result) == {"computer-use-2025-01-24", "effort-2025-11-24"} + + def test_empty_headers_list(self): + """Test filtering with empty headers list.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers([], model) + assert result == [] + + def test_all_supported_betas_in_whitelist(self): + """Test that all core supported betas are in whitelist.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + assert len(filter_obj.supported_betas) == len(BEDROCK_CORE_SUPPORTED_BETAS) + + +class TestModelVersionExtraction: + """Test model version extraction logic.""" + + def test_extract_version_opus_4_5(self): + """Test version extraction for Claude Opus 4.5.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-opus-4-5-20250514-v1:0" + ) + assert version == 4.5 + + def test_extract_version_sonnet_4(self): + """Test version extraction for Claude Sonnet 4.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-sonnet-4-20250514-v1:0" + ) + assert version == 4.0 + + def test_extract_version_legacy_3_5_sonnet(self): + """Test version extraction for legacy Claude 3.5 Sonnet.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-3-5-sonnet-20240620-v1:0" + 
) + assert version == 3.5 + + def test_extract_version_legacy_3_sonnet(self): + """Test version extraction for legacy Claude 3 Sonnet.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-3-sonnet-20240229-v1:0" + ) + assert version == 3.0 + + def test_extract_version_haiku_4_5(self): + """Test version extraction for Claude Haiku 4.5.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-haiku-4-5-20250514-v1:0" + ) + assert version == 4.5 + + def test_extract_version_invalid_format(self): + """Test version extraction with invalid model format.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version("invalid-model-format") + assert version is None + + def test_extract_version_future_opus_5(self): + """Test version extraction for future Claude Opus 5.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + version = filter_obj._extract_model_version( + "anthropic.claude-opus-5-20270101-v1:0" + ) + assert version == 5.0 + + +class TestModelFamilyExtraction: + """Test model family extraction logic.""" + + def test_extract_family_opus(self): + """Test family extraction for Opus models.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + family = filter_obj._extract_model_family( + "anthropic.claude-opus-4-5-20250514-v1:0" + ) + assert family == "opus" + + def test_extract_family_sonnet(self): + """Test family extraction for Sonnet models.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + family = filter_obj._extract_model_family( + "anthropic.claude-sonnet-4-20250514-v1:0" + ) + assert family == "sonnet" + + def test_extract_family_haiku(self): + """Test family extraction for Haiku models.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + family = filter_obj._extract_model_family( + 
"anthropic.claude-haiku-4-5-20250514-v1:0" + ) + assert family == "haiku" + + def test_extract_family_legacy_sonnet(self): + """Test family extraction for legacy Sonnet naming.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + family = filter_obj._extract_model_family( + "anthropic.claude-3-5-sonnet-20240620-v1:0" + ) + assert family == "sonnet" + + def test_extract_family_invalid(self): + """Test family extraction with invalid model format.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + family = filter_obj._extract_model_family("invalid-model-format") + assert family is None + + +class TestVersionBasedFiltering: + """Test version-based filtering for beta headers.""" + + def test_thinking_headers_require_claude_4(self): + """Test that thinking headers require Claude 4.0+.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + + # Claude 4.5 should support thinking + result = filter_obj.filter_beta_headers( + ["interleaved-thinking-2025-05-14"], + "anthropic.claude-opus-4-5-20250514-v1:0", + translate=False, + ) + assert "interleaved-thinking-2025-05-14" in result + + # Claude 3.5 should NOT support thinking + result = filter_obj.filter_beta_headers( + ["interleaved-thinking-2025-05-14"], + "anthropic.claude-3-5-sonnet-20240620-v1:0", + translate=False, + ) + assert "interleaved-thinking-2025-05-14" not in result + + def test_context_management_requires_claude_4_5(self): + """Test that context management requires Claude 4.5+.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + + # Claude 4.5 should support context management + result = filter_obj.filter_beta_headers( + ["context-management-2025-06-27"], + "anthropic.claude-sonnet-4-5-20250514-v1:0", + translate=False, + ) + assert "context-management-2025-06-27" in result + + # Claude 4.0 should NOT support context management + result = filter_obj.filter_beta_headers( + ["context-management-2025-06-27"], + "anthropic.claude-sonnet-4-20250514-v1:0", 
+ translate=False, + ) + assert "context-management-2025-06-27" not in result + + def test_computer_use_works_on_all_versions(self): + """Test that computer-use works on all Claude versions (no version requirement).""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + + # Claude 3.5 should support computer use + result = filter_obj.filter_beta_headers( + ["computer-use-2025-01-24"], + "anthropic.claude-3-5-sonnet-20240620-v1:0", + translate=False, + ) + assert "computer-use-2025-01-24" in result + + # Claude 4.5 should also support computer use + result = filter_obj.filter_beta_headers( + ["computer-use-2025-01-24"], + "anthropic.claude-opus-4-5-20250514-v1:0", + translate=False, + ) + assert "computer-use-2025-01-24" in result + + def test_future_model_supports_all_headers(self): + """Test that future Claude 5.0 automatically supports all 4.0+ headers.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-5-20270101-v1:0" + + # Claude 5 should support all headers requiring 4.0+ + result = filter_obj.filter_beta_headers( + [ + "interleaved-thinking-2025-05-14", # Requires 4.0+ + "context-management-2025-06-27", # Requires 4.5+ + "context-1m-2025-08-07", # Requires 4.0+ + ], + model, + translate=False, + ) + assert len(result) == 3 + assert "interleaved-thinking-2025-05-14" in result + assert "context-management-2025-06-27" in result + assert "context-1m-2025-08-07" in result + + +class TestFamilyRestrictions: + """Test model family restrictions for beta headers.""" + + def test_effort_only_on_opus(self): + """Test that effort parameter only works on Opus 4.5+.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + + # Opus 4.5 should support effort + result = filter_obj.filter_beta_headers( + ["effort-2025-11-24"], + "anthropic.claude-opus-4-5-20250514-v1:0", + translate=False, + ) + assert "effort-2025-11-24" in result + + # Sonnet 4.5 should NOT support effort (wrong family) + result = 
filter_obj.filter_beta_headers( + ["effort-2025-11-24"], + "anthropic.claude-sonnet-4-5-20250514-v1:0", + translate=False, + ) + assert "effort-2025-11-24" not in result + + # Haiku 4.5 should NOT support effort (wrong family) + result = filter_obj.filter_beta_headers( + ["effort-2025-11-24"], + "anthropic.claude-haiku-4-5-20250514-v1:0", + translate=False, + ) + assert "effort-2025-11-24" not in result + + def test_tool_search_on_opus_and_sonnet(self): + """Test that tool search works on Opus and Sonnet 4.5+, but not Haiku.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + + # Opus 4.5 should support tool search + result = filter_obj.filter_beta_headers( + ["tool-search-tool-2025-10-19"], + "anthropic.claude-opus-4-5-20250514-v1:0", + translate=False, + ) + assert "tool-search-tool-2025-10-19" in result + + # Sonnet 4.5 should support tool search + result = filter_obj.filter_beta_headers( + ["tool-search-tool-2025-10-19"], + "anthropic.claude-sonnet-4-5-20250514-v1:0", + translate=False, + ) + assert "tool-search-tool-2025-10-19" in result + + # Haiku 4.5 should NOT support tool search (wrong family) + result = filter_obj.filter_beta_headers( + ["tool-search-tool-2025-10-19"], + "anthropic.claude-haiku-4-5-20250514-v1:0", + translate=False, + ) + assert "tool-search-tool-2025-10-19" not in result + + +class TestBetaHeaderTranslation: + """Test beta header translation for backward compatibility.""" + + def test_advanced_tool_use_translation_opus_4_5(self): + """Test advanced-tool-use translates to tool search headers on Opus 4.5.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["advanced-tool-use-2025-11-20"], model, translate=True + ) + + # Should translate to tool search headers + assert "tool-search-tool-2025-10-19" in result + assert "tool-examples-2025-10-29" in result + assert "advanced-tool-use-2025-11-20" not in result + 
+ def test_advanced_tool_use_translation_sonnet_4_5(self): + """Test advanced-tool-use translates on Sonnet 4.5.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + model = "anthropic.claude-sonnet-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["advanced-tool-use-2025-11-20"], model, translate=True + ) + + # Should translate to tool search headers + assert "tool-search-tool-2025-10-19" in result + assert "tool-examples-2025-10-29" in result + + def test_advanced_tool_use_no_translation_claude_4(self): + """Test advanced-tool-use does NOT translate on Claude 4.0 (too old).""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + model = "anthropic.claude-opus-4-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["advanced-tool-use-2025-11-20"], model, translate=True + ) + + # Should not translate (version too old) + assert result == [] + + def test_advanced_tool_use_no_translation_haiku(self): + """Test advanced-tool-use does NOT translate on Haiku (wrong family).""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + model = "anthropic.claude-haiku-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["advanced-tool-use-2025-11-20"], model, translate=True + ) + + # Should not translate (wrong family) + assert result == [] + + def test_translation_disabled(self): + """Test that translation can be disabled.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["advanced-tool-use-2025-11-20"], model, translate=False + ) + + # Should not translate when disabled + # advanced-tool-use is not in whitelist, so should be filtered out + assert result == [] + + +class TestCrossAPIConsistency: + """Test that filtering is consistent across all three APIs.""" + + def test_same_headers_work_on_all_apis(self): + """Test that supported headers work consistently across all 
APIs.""" + model = "anthropic.claude-opus-4-5-20250514-v1:0" + headers = ["computer-use-2025-01-24", "effort-2025-11-24"] + + filter_chat = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + filter_messages = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + filter_converse = get_bedrock_beta_filter(BedrockAPI.CONVERSE) + + result_chat = set(filter_chat.filter_beta_headers(headers, model, translate=False)) + result_messages = set( + filter_messages.filter_beta_headers(headers, model, translate=False) + ) + result_converse = set( + filter_converse.filter_beta_headers(headers, model, translate=False) + ) + + # All APIs should return the same results + assert result_chat == result_messages == result_converse + + def test_unsupported_headers_filtered_on_all_apis(self): + """Test that unsupported headers are filtered consistently.""" + model = "anthropic.claude-opus-4-5-20250514-v1:0" + headers = ["unknown-beta-2099-01-01"] + + filter_chat = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + filter_messages = get_bedrock_beta_filter(BedrockAPI.INVOKE_MESSAGES) + filter_converse = get_bedrock_beta_filter(BedrockAPI.CONVERSE) + + result_chat = filter_chat.filter_beta_headers(headers, model) + result_messages = filter_messages.filter_beta_headers(headers, model) + result_converse = filter_converse.filter_beta_headers(headers, model) + + # All should filter out unsupported headers + assert result_chat == result_messages == result_converse == [] + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_none_model_version_blocks_versioned_headers(self): + """Test that unparseable model version blocks headers with version requirements.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "invalid-model-format" + + # Headers with version requirements should be blocked + result = filter_obj.filter_beta_headers( + ["interleaved-thinking-2025-05-14"], model, translate=False + ) + assert result == [] + + # Headers without version 
requirements should still work + result = filter_obj.filter_beta_headers( + ["computer-use-2025-01-24"], model, translate=False + ) + assert "computer-use-2025-01-24" in result + + def test_duplicate_headers_deduplicated(self): + """Test that duplicate headers are deduplicated.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + [ + "computer-use-2025-01-24", + "computer-use-2025-01-24", + "computer-use-2025-01-24", + ], + model, + translate=False, + ) + assert result == ["computer-use-2025-01-24"] + + def test_output_is_sorted(self): + """Test that output is sorted for deterministic results.""" + filter_obj = get_bedrock_beta_filter(BedrockAPI.INVOKE_CHAT) + model = "anthropic.claude-opus-4-5-20250514-v1:0" + + result = filter_obj.filter_beta_headers( + ["effort-2025-11-24", "computer-use-2025-01-24", "context-1m-2025-08-07"], + model, + translate=False, + ) + # Should be alphabetically sorted + assert result == sorted(result) diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py index 84d320a0a27..9a04a6ac22a 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py @@ -1189,3 +1189,185 @@ async def test_bedrock_guardrail_blocked_content_with_masking_enabled(): print("✅ BLOCKED content with masking enabled raises exception correctly") + +@pytest.mark.asyncio +async def test__redact_pii_matches_with_null_list_fields(): + """Test that _redact_pii_matches handles None/null list fields without crashing. + + Bedrock API can return null for fields like regexes, customWords, managedWordLists, + and piiEntities. 
The .get("key", []) pattern returns None (not []) when the key + exists with a null value, which previously caused 'NoneType' object is not iterable. + """ + + # Real-world response from Bedrock where regexes is null + response_with_null_regexes = { + "action": "NONE", + "actionReason": "No action.", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": [ + { + "action": "NONE", + "detected": True, + "match": "joebloggs@gmail.com", + "type": "EMAIL", + } + ], + "regexes": None, # null from Bedrock API + }, + "wordPolicy": None, # entire policy is null + "topicPolicy": None, + "contentPolicy": None, + "contextualGroundingPolicy": None, + } + ], + } + + # Should not raise any exception + redacted = _redact_pii_matches(response_with_null_regexes) + + # PII entity match should be redacted + pii_entities = redacted["assessments"][0]["sensitiveInformationPolicy"][ + "piiEntities" + ] + assert pii_entities[0]["match"] == "[REDACTED]" + assert pii_entities[0]["type"] == "EMAIL" + + # Test with null piiEntities and non-null regexes + response_with_null_pii = { + "action": "NONE", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, # null + "regexes": [ + { + "name": "CUSTOM", + "match": "secret-pattern", + "action": "BLOCKED", + } + ], + }, + } + ], + } + + redacted = _redact_pii_matches(response_with_null_pii) + regexes = redacted["assessments"][0]["sensitiveInformationPolicy"]["regexes"] + assert regexes[0]["match"] == "[REDACTED]" + + # Test with null customWords and managedWordLists in wordPolicy + response_with_null_word_lists = { + "action": "NONE", + "assessments": [ + { + "wordPolicy": { + "customWords": None, # null + "managedWordLists": None, # null + }, + } + ], + } + + # Should not raise any exception + redacted = _redact_pii_matches(response_with_null_word_lists) + assert redacted["assessments"][0]["wordPolicy"]["customWords"] is None + assert redacted["assessments"][0]["wordPolicy"]["managedWordLists"] is None + 
+ +@pytest.mark.asyncio +async def test_should_raise_guardrail_blocked_exception_with_null_list_fields(): + """Test that _should_raise_guardrail_blocked_exception handles None/null list fields. + + Same issue as _redact_pii_matches: Bedrock API returns null for list fields + like topics, filters, customWords, etc. which causes iteration over None. + """ + + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + # Response where all policy sub-lists are null + response_all_null_lists = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": None, # null + }, + "contentPolicy": { + "filters": None, # null + }, + "wordPolicy": { + "customWords": None, # null + "managedWordLists": None, # null + }, + "sensitiveInformationPolicy": { + "piiEntities": None, # null + "regexes": None, # null + }, + "contextualGroundingPolicy": { + "filters": None, # null + }, + } + ], + } + + # Should not raise any exception and should return False + # (no BLOCKED actions found since all lists are null) + result = guardrail._should_raise_guardrail_blocked_exception(response_all_null_lists) + assert result is False + + # Response with a mix of null lists and a BLOCKED action + response_mixed_null_with_blocked = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": None, # null - should not crash + }, + "contentPolicy": { + "filters": [ + { + "type": "HATE", + "confidence": "HIGH", + "action": "BLOCKED", + } + ], + }, + "wordPolicy": { + "customWords": None, # null + "managedWordLists": None, # null + }, + "sensitiveInformationPolicy": { + "piiEntities": None, # null + "regexes": None, # null + }, + "contextualGroundingPolicy": None, # entire policy is null + } + ], + } + + # Should return True because there's a BLOCKED content filter + result = guardrail._should_raise_guardrail_blocked_exception( + response_mixed_null_with_blocked + ) + assert result is True + + # 
Response with null lists but action is not GUARDRAIL_INTERVENED + response_no_intervention = { + "action": "NONE", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, + "regexes": None, + }, + } + ], + } + + result = guardrail._should_raise_guardrail_blocked_exception(response_no_intervention) + assert result is False + diff --git a/tests/test_litellm/test_utils.py b/tests/test_litellm/test_utils.py index 7b29a4d90aa..b1938f0a3fb 100644 --- a/tests/test_litellm/test_utils.py +++ b/tests/test_litellm/test_utils.py @@ -3117,6 +3117,136 @@ def test_last_assistant_with_tool_calls_has_no_thinking_blocks_issue_18926(): assert should_drop_thinking is False +def test_last_assistant_message_has_no_thinking_blocks_text_only(): + """ + Test the scenario where the last assistant message has text content but no + thinking_blocks. This triggers the error: + "Expected thinking or redacted_thinking, but found text" + """ + from litellm.utils import ( + any_assistant_message_has_thinking_blocks, + last_assistant_message_has_no_thinking_blocks, + last_assistant_with_tool_calls_has_no_thinking_blocks, + ) + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "What's 2+2?"}, + {"role": "assistant", "content": "4"}, + {"role": "user", "content": "Thanks"}, + ] + + # No assistant has tool_calls, so the old check returns False (doesn't detect issue) + assert last_assistant_with_tool_calls_has_no_thinking_blocks(messages) is False + + # New check detects the missing thinking blocks + assert last_assistant_message_has_no_thinking_blocks(messages) is True + + # No assistant has thinking_blocks + assert any_assistant_message_has_thinking_blocks(messages) is False + + # With the new check, we correctly detect thinking should be dropped + should_drop_thinking = ( + last_assistant_with_tool_calls_has_no_thinking_blocks(messages) + or 
last_assistant_message_has_no_thinking_blocks(messages) + ) and not any_assistant_message_has_thinking_blocks(messages) + assert should_drop_thinking is True + + +def test_last_assistant_message_has_no_thinking_blocks_with_content_list(): + """ + Test detection when assistant message has content as a list of blocks (Anthropic format). + """ + from litellm.utils import last_assistant_message_has_no_thinking_blocks + + messages = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": [{"type": "text", "text": "Hi there!"}], + }, + ] + + assert last_assistant_message_has_no_thinking_blocks(messages) is True + + +def test_last_assistant_message_has_thinking_in_content(): + """ + Test that function returns False when thinking blocks are in content array + (Anthropic format) rather than in the thinking_blocks field. + """ + from litellm.utils import ( + any_assistant_message_has_thinking_blocks, + last_assistant_message_has_no_thinking_blocks, + ) + + messages = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think..."}, + {"type": "text", "text": "The answer is 42."}, + ], + }, + ] + + # Content has thinking blocks, so should return False + assert last_assistant_message_has_no_thinking_blocks(messages) is False + + # any_assistant check should also detect thinking blocks in content + assert any_assistant_message_has_thinking_blocks(messages) is True + + +def test_last_assistant_message_no_content(): + """ + Test that function returns False when last assistant has no content. + """ + from litellm.utils import last_assistant_message_has_no_thinking_blocks + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None}, + ] + + assert last_assistant_message_has_no_thinking_blocks(messages) is False + + +def test_no_assistant_messages(): + """ + Test that function returns False when there are no assistant messages. 
+ """ + from litellm.utils import last_assistant_message_has_no_thinking_blocks + + messages = [ + {"role": "user", "content": "Hello"}, + ] + + assert last_assistant_message_has_no_thinking_blocks(messages) is False + + +def test_thinking_blocks_field_detected_by_any_check(): + """ + Test that any_assistant_message_has_thinking_blocks detects thinking blocks + carried in the content array; the thinking_blocks field is not exercised here. + """ + from litellm.utils import any_assistant_message_has_thinking_blocks + + # Thinking in content array (Anthropic format) + messages_content = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": [ + {"type": "redacted_thinking", "data": "xxx"}, + {"type": "text", "text": "answer"}, + ], + }, + ] + assert any_assistant_message_has_thinking_blocks(messages_content) is True + + class TestAdditionalDropParamsForNonOpenAIProviders: """ Test additional_drop_params functionality for non-OpenAI providers.