Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion litellm/llms/bedrock/chat/converse_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_sync_call(
model_response=model_response, json_mode=json_mode
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=stream_chunk_size))

# LOGGING
Expand Down
118 changes: 84 additions & 34 deletions litellm/llms/bedrock/chat/converse_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import copy
import json
import time
import types
from typing import List, Literal, Optional, Tuple, Union, cast, overload
Expand Down Expand Up @@ -1552,6 +1553,81 @@ def _translate_message_content(

return content_str, tools, reasoningContentBlocks, citationsContentBlocks

@staticmethod
def _filter_json_mode_tools(
    json_mode: Optional[bool],
    tools: List[ChatCompletionToolCallChunk],
    chat_completion_message: ChatCompletionResponseMessage,
) -> Optional[List[ChatCompletionToolCallChunk]]:
    """
    Filter Bedrock's internal ``json_tool_call`` out of a tool-call list.

    When ``json_mode`` is True, Bedrock may return the internal
    ``json_tool_call`` tool (injected for ``response_format``) alongside
    real user-defined tools. Three scenarios are handled:

    1. Only json_tool_call present  -> its arguments become the message
       text content; returns None (no tool calls surfaced).
    2. Mixed json_tool_call + real  -> json_tool_call arguments are
       appended to the message text (so the structured output is not
       silently lost); returns only the real tools.
    3. No json_tool_call / no json_mode -> tools returned as-is.

    Args:
        json_mode: Whether the request used response_format json mode.
        tools: Tool-call chunks parsed from the Bedrock response.
        chat_completion_message: Message dict mutated in place when the
            json_tool_call payload is converted to text content.

    Returns:
        The tool calls to surface to the caller, or None when none remain.
    """
    if not json_mode or not tools:
        return tools if tools else None

    def _unwrap_properties(raw_args: str) -> str:
        # Bedrock wraps the structured output in a single top-level
        # "properties" object; unwrap it so callers receive the actual
        # response_format payload. Malformed JSON is passed through
        # unchanged; valid JSON is re-serialized either way (matching
        # the previous behavior of this method).
        try:
            parsed = json.loads(raw_args)
        except json.JSONDecodeError:
            return raw_args
        if (
            isinstance(parsed, dict)
            and "properties" in parsed
            and len(parsed) == 1
        ):
            parsed = parsed["properties"]
        return json.dumps(parsed)

    json_tool_indices = [
        i
        for i, t in enumerate(tools)
        if t["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
    ]

    if not json_tool_indices:
        # Scenario 3: no internal tool present — return unchanged.
        return tools

    if len(json_tool_indices) == len(tools):
        # Scenario 1: every tool is json_tool_call — convert the first
        # one's arguments into plain text content.
        verbose_logger.debug(
            "Processing JSON tool call response for response_format"
        )
        json_mode_content_str: Optional[str] = tools[0]["function"].get(
            "arguments"
        )
        if json_mode_content_str is not None:
            chat_completion_message["content"] = _unwrap_properties(
                json_mode_content_str
            )
        return None

    # Scenario 2: drop json_tool_call entries, keep real tools, and
    # preserve the first json_tool_call payload as message text.
    json_mode_args = tools[json_tool_indices[0]]["function"].get("arguments")
    if json_mode_args is not None:
        unwrapped = _unwrap_properties(json_mode_args)
        existing = chat_completion_message.get("content") or ""
        chat_completion_message["content"] = (
            existing + unwrapped if existing else unwrapped
        )

    real_tools = [t for i, t in enumerate(tools) if i not in json_tool_indices]
    return real_tools if real_tools else None

def _transform_response( # noqa: PLR0915
self,
model: str,
Expand All @@ -1574,7 +1650,7 @@ def _transform_response( # noqa: PLR0915
additional_args={"complete_input_dict": data},
)

json_mode: Optional[bool] = optional_params.pop("json_mode", None)
json_mode: Optional[bool] = optional_params.get("json_mode", None)
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
Expand Down Expand Up @@ -1658,39 +1734,13 @@ def _transform_response( # noqa: PLR0915
"thinking_blocks"
] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if (
json_mode is True
and tools is not None
and len(tools) == 1
and tools[0]["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
):
verbose_logger.debug(
"Processing JSON tool call response for response_format"
)
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
if json_mode_content_str is not None:
import json

# Bedrock returns the response wrapped in a "properties" object
# We need to extract the actual content from this wrapper
try:
response_data = json.loads(json_mode_content_str)

# If Bedrock wrapped the response in "properties", extract the content
if (
isinstance(response_data, dict)
and "properties" in response_data
and len(response_data) == 1
):
response_data = response_data["properties"]
json_mode_content_str = json.dumps(response_data)
except json.JSONDecodeError:
# If parsing fails, use the original response
pass

chat_completion_message["content"] = json_mode_content_str
else:
chat_completion_message["tool_calls"] = tools
filtered_tools = self._filter_json_mode_tools(
json_mode=json_mode,
tools=tools,
chat_completion_message=chat_completion_message,
)
if filtered_tools:
chat_completion_message["tool_calls"] = filtered_tools

## CALCULATING USAGE - bedrock returns usage in the headers
usage = self._transform_usage(completion_response["usage"])
Expand Down
64 changes: 48 additions & 16 deletions litellm/llms/bedrock/chat/invoke_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from litellm import verbose_logger
from litellm._uuid import uuid
from litellm.caching.caching import InMemoryCache
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
Expand Down Expand Up @@ -252,7 +253,7 @@ async def make_call(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -346,7 +347,7 @@ def make_sync_call(
response.iter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(
response.iter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -1282,14 +1283,16 @@ def get_response_stream_shape():


class AWSEventStreamDecoder:
def __init__(self, model: str) -> None:
def __init__(self, model: str, json_mode: Optional[bool] = False) -> None:
    """Decode a Bedrock event-stream response into chat-completion chunks.

    Args:
        model: Bedrock model id for this request.
        json_mode: When True, the internal ``json_tool_call`` tool emitted
            for ``response_format`` is converted to text content in delta
            events instead of being surfaced as a tool call.
    """
    # Imported lazily so importing this module does not require botocore.
    from botocore.parsers import EventStreamJSONParser

    self.model = model
    self.parser = EventStreamJSONParser()
    # Content-block delta events accumulated for the current response.
    self.content_blocks: List[ContentBlockDeltaEvent] = []
    # Index of the most recently started tool call; None until one starts.
    self.tool_calls_index: Optional[int] = None
    self.response_id: Optional[str] = None
    self.json_mode = json_mode
    # Name of the tool currently streaming; used to recognize (and, in
    # json_mode, suppress) the internal response_format tool.
    self._current_tool_name: Optional[str] = None

def check_empty_tool_call_args(self) -> bool:
"""
Expand Down Expand Up @@ -1391,6 +1394,16 @@ def _handle_converse_start_event(
response_tool_name = get_bedrock_tool_name(
response_tool_name=_response_tool_name
)
self._current_tool_name = response_tool_name

# When json_mode is True, suppress the internal json_tool_call
# and convert its content to text in delta events instead
if (
self.json_mode is True
and response_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
return tool_use, provider_specific_fields, thinking_blocks

self.tool_calls_index = (
0 if self.tool_calls_index is None else self.tool_calls_index + 1
)
Expand Down Expand Up @@ -1445,19 +1458,27 @@ def _handle_converse_delta_event(
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
# When json_mode is True and this is the internal json_tool_call,
# convert tool input to text content instead of tool call arguments
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
text = delta_obj["toolUse"]["input"]
else:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
elif "reasoningContent" in delta_obj:
provider_specific_fields = {
"reasoningContent": delta_obj["reasoningContent"],
Expand Down Expand Up @@ -1494,6 +1515,17 @@ def _handle_converse_stop_event(
) -> Optional[ChatCompletionToolCallChunk]:
"""Handle stop/contentBlockIndex event in converse chunk parsing."""
tool_use: Optional[ChatCompletionToolCallChunk] = None

# If the ending block was the internal json_tool_call, skip emitting
# the empty-args tool chunk and reset tracking state
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
self._current_tool_name = None
return tool_use

self._current_tool_name = None
is_empty = self.check_empty_tool_call_args()
if is_empty:
tool_use = {
Expand Down
22 changes: 20 additions & 2 deletions litellm/proxy/policy_engine/policy_resolve_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,26 @@
"""

import json

from fastapi import APIRouter, Depends, HTTPException, Query
from typing import TYPE_CHECKING

# FastAPI imports - may not be available in all environments (e.g. during testing)
try:
    from fastapi import APIRouter, Depends, HTTPException, Query
except ImportError:
    # FastAPI is unavailable (e.g. minimal test environments): fall back
    # to inert stand-ins so this module can still be imported.
    if TYPE_CHECKING:
        from fastapi import APIRouter, Depends, HTTPException, Query
    else:
        # Runtime stubs — never exercised as real routes.
        # NOTE(review): the stub APIRouter only implements `post`; if this
        # module ever registers GET/PUT/DELETE routes the fallback will
        # fail — confirm against the route definitions below.
        class APIRouter:  # type: ignore[no-redef]
            def post(self, *args, **kwargs):
                def decorator(func):
                    return func
                return decorator

        def Depends(func): return func  # type: ignore[misc]
        HTTPException = Exception  # type: ignore[misc,assignment]
        def Query(*args, **kwargs): return None  # type: ignore[misc]

from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_POLICY_ESTIMATE_IMPACT_ROWS
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions tests/mcp_tests/test_mcp_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_acompletion_mcp_auto_exec(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_acompletion_mcp_respects_manual_approval(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -170,7 +170,7 @@ async def test_completion_mcp_with_streaming_no_timeout_error(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -470,7 +470,7 @@ async def test_mcp_metadata_in_streaming_final_chunk(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -793,7 +793,7 @@ async def test_mcp_streaming_metadata_ordering(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down
Loading
Loading