Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion litellm/llms/bedrock/chat/converse_handler.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this test modify bedrock?

Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_sync_call(
model_response=model_response, json_mode=json_mode
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=stream_chunk_size))

# LOGGING
Expand Down
119 changes: 85 additions & 34 deletions litellm/llms/bedrock/chat/converse_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import copy
import json
import time
import types
from typing import List, Literal, Optional, Tuple, Union, cast, overload
Expand Down Expand Up @@ -1552,6 +1553,82 @@ def _translate_message_content(

return content_str, tools, reasoningContentBlocks, citationsContentBlocks

@staticmethod
def _filter_json_mode_tools(
    json_mode: Optional[bool],
    tools: List[ChatCompletionToolCallChunk],
    chat_completion_message: ChatCompletionResponseMessage,
) -> Optional[List[ChatCompletionToolCallChunk]]:
    """
    When json_mode is True, Bedrock may return the internal `json_tool_call`
    tool alongside real user-defined tools. This method handles 3 scenarios:

    1. Only json_tool_call present -> convert to text content, return None
    2. Mixed json_tool_call + real -> filter out json_tool_call, return real tools
    3. No json_tool_call / no json_mode -> return tools as-is

    Args:
        json_mode: whether the request was made with response_format json_mode.
        tools: tool-call chunks parsed from the Bedrock response.
        chat_completion_message: message dict mutated in place — the
            json_tool_call arguments are written to its "content" key.

    Returns:
        The user-defined tool calls to emit, or None when nothing remains.
    """

    def _unwrap_properties(raw: str) -> str:
        # Bedrock wraps the response_format payload in a single
        # {"properties": {...}} object; unwrap it so callers receive the
        # actual structured output. Invalid JSON passes through untouched.
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            return raw
        if isinstance(data, dict) and len(data) == 1 and "properties" in data:
            data = data["properties"]
        return json.dumps(data)

    if not json_mode or not tools:
        return tools if tools else None

    json_tool_indices = [
        i
        for i, t in enumerate(tools)
        if t["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
    ]

    if not json_tool_indices:
        # No json_tool_call found, return tools unchanged
        return tools

    if len(json_tool_indices) == len(tools):
        # All tools are json_tool_call — convert the first one to content
        verbose_logger.debug(
            "Processing JSON tool call response for response_format"
        )
        json_mode_content_str: Optional[str] = tools[0]["function"].get(
            "arguments"
        )
        if json_mode_content_str is not None:
            chat_completion_message["content"] = _unwrap_properties(
                json_mode_content_str
            )
        return None

    # Mixed: filter out json_tool_call, keep real tools.
    # Preserve the json_tool_call content as message text so the structured
    # output from response_format is not silently lost. Only the first
    # json_tool_call that actually carries arguments is surfaced.
    for idx in json_tool_indices:
        json_mode_args = tools[idx]["function"].get("arguments")
        if json_mode_args is None:
            continue
        unwrapped = _unwrap_properties(json_mode_args)
        existing = chat_completion_message.get("content") or ""
        chat_completion_message["content"] = existing + unwrapped
        break

    drop = set(json_tool_indices)
    real_tools = [t for i, t in enumerate(tools) if i not in drop]
    return real_tools if real_tools else None

def _transform_response( # noqa: PLR0915
self,
model: str,
Expand All @@ -1574,7 +1651,7 @@ def _transform_response( # noqa: PLR0915
additional_args={"complete_input_dict": data},
)

json_mode: Optional[bool] = optional_params.pop("json_mode", None)
json_mode: Optional[bool] = optional_params.get("json_mode", None)
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
Expand Down Expand Up @@ -1658,39 +1735,13 @@ def _transform_response( # noqa: PLR0915
"thinking_blocks"
] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if (
json_mode is True
and tools is not None
and len(tools) == 1
and tools[0]["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
):
verbose_logger.debug(
"Processing JSON tool call response for response_format"
)
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
if json_mode_content_str is not None:
import json

# Bedrock returns the response wrapped in a "properties" object
# We need to extract the actual content from this wrapper
try:
response_data = json.loads(json_mode_content_str)

# If Bedrock wrapped the response in "properties", extract the content
if (
isinstance(response_data, dict)
and "properties" in response_data
and len(response_data) == 1
):
response_data = response_data["properties"]
json_mode_content_str = json.dumps(response_data)
except json.JSONDecodeError:
# If parsing fails, use the original response
pass

chat_completion_message["content"] = json_mode_content_str
else:
chat_completion_message["tool_calls"] = tools
filtered_tools = self._filter_json_mode_tools(
json_mode=json_mode,
tools=tools,
chat_completion_message=chat_completion_message,
)
if filtered_tools:
chat_completion_message["tool_calls"] = filtered_tools

## CALCULATING USAGE - bedrock returns usage in the headers
usage = self._transform_usage(completion_response["usage"])
Expand Down
64 changes: 48 additions & 16 deletions litellm/llms/bedrock/chat/invoke_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from litellm import verbose_logger
from litellm._uuid import uuid
from litellm.caching.caching import InMemoryCache
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
Expand Down Expand Up @@ -252,7 +253,7 @@ async def make_call(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -346,7 +347,7 @@ def make_sync_call(
response.iter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(
response.iter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -1282,14 +1283,16 @@ def get_response_stream_shape():


class AWSEventStreamDecoder:
def __init__(self, model: str) -> None:
def __init__(self, model: str, json_mode: Optional[bool] = False) -> None:
    """Decode a Bedrock event-stream response into completion chunks.

    Args:
        model: the Bedrock model id the stream belongs to.
        json_mode: when True, the internal response_format tool call
            (`json_tool_call`) is converted to plain text content in the
            delta handlers instead of being emitted as a tool call.
    """
    # Imported lazily so this module can be imported without botocore.
    from botocore.parsers import EventStreamJSONParser

    self.model = model
    self.json_mode = json_mode
    self.parser = EventStreamJSONParser()
    # Content deltas accumulated while parsing the current response.
    self.content_blocks: List[ContentBlockDeltaEvent] = []
    # Running index of the tool call being streamed; None until one starts.
    self.tool_calls_index: Optional[int] = None
    self.response_id: Optional[str] = None
    # Name of the tool block currently open — used to recognize and
    # suppress the internal json_tool_call while json_mode is active.
    self._current_tool_name: Optional[str] = None

def check_empty_tool_call_args(self) -> bool:
"""
Expand Down Expand Up @@ -1391,6 +1394,16 @@ def _handle_converse_start_event(
response_tool_name = get_bedrock_tool_name(
response_tool_name=_response_tool_name
)
self._current_tool_name = response_tool_name

# When json_mode is True, suppress the internal json_tool_call
# and convert its content to text in delta events instead
if (
self.json_mode is True
and response_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
return tool_use, provider_specific_fields, thinking_blocks

self.tool_calls_index = (
0 if self.tool_calls_index is None else self.tool_calls_index + 1
)
Expand Down Expand Up @@ -1445,19 +1458,27 @@ def _handle_converse_delta_event(
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
# When json_mode is True and this is the internal json_tool_call,
# convert tool input to text content instead of tool call arguments
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
text = delta_obj["toolUse"]["input"]
else:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
elif "reasoningContent" in delta_obj:
provider_specific_fields = {
"reasoningContent": delta_obj["reasoningContent"],
Expand Down Expand Up @@ -1494,6 +1515,17 @@ def _handle_converse_stop_event(
) -> Optional[ChatCompletionToolCallChunk]:
"""Handle stop/contentBlockIndex event in converse chunk parsing."""
tool_use: Optional[ChatCompletionToolCallChunk] = None

# If the ending block was the internal json_tool_call, skip emitting
# the empty-args tool chunk and reset tracking state
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
self._current_tool_name = None
return tool_use

self._current_tool_name = None
is_empty = self.check_empty_tool_call_args()
if is_empty:
tool_use = {
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions tests/mcp_tests/test_mcp_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_acompletion_mcp_auto_exec(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_acompletion_mcp_respects_manual_approval(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -170,7 +170,7 @@ async def test_completion_mcp_with_streaming_no_timeout_error(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -470,7 +470,7 @@ async def test_mcp_metadata_in_streaming_final_chunk(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down Expand Up @@ -793,7 +793,7 @@ async def test_mcp_streaming_metadata_ordering(monkeypatch):
inputSchema={"type": "object", "properties": {}},
)

async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy):
async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy, **kwargs):
return [dummy_tool], {"local_search": "local"}

async def fake_execute(**kwargs):
Expand Down
11 changes: 6 additions & 5 deletions tests/test_litellm/integrations/cloudzero/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def test_create_cbf_record(self):
transformer = CBFTransformer()
with patch.object(transformer.czrn_generator, 'create_from_litellm_data') as mock_czrn, \
patch.object(transformer.czrn_generator, 'extract_components') as mock_extract:

mock_czrn.return_value = 'test-czrn'
mock_extract.return_value = ('service', 'provider', 'region', 'account', 'resource', 'local_id')

row = {
'date': '2025-01-19',
'spend': 10.5,
Expand All @@ -107,14 +107,15 @@ def test_create_cbf_record(self):
'successful_requests': 5,
'failed_requests': 0
}

result = transformer._create_cbf_record(row)

assert isinstance(result, CBFRecord)
assert result['cost/cost'] == 10.5
assert result['usage/amount'] == 150 # 100 + 50
assert result['usage/units'] == 'tokens'
assert result['resource/id'] == 'test-czrn'
# resource/id is set to model name per implementation (line 144 in transform.py)
assert result['resource/id'] == 'gpt-4'

def test_create_cbf_record_adds_user_email_tag(self):
"""Test that user_email field is emitted as a resource tag when present."""
Expand Down
Loading
Loading