Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion litellm/llms/bedrock/chat/converse_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_sync_call(
model_response=model_response, json_mode=json_mode
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=stream_chunk_size))

# LOGGING
Expand Down
126 changes: 94 additions & 32 deletions litellm/llms/bedrock/chat/converse_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,92 @@ def _translate_message_content(self, content_blocks: List[ContentBlock]) -> Tupl

return content_str, tools, reasoningContentBlocks, citationsContentBlocks

@staticmethod
def _unwrap_bedrock_properties(json_str: str) -> str:
    """
    Unwrap Bedrock's response_format JSON structure.

    If the JSON is an object with a single "properties" key, extract and
    re-serialize its value. Otherwise return the original string untouched
    so that content we did not modify keeps its exact byte representation
    (re-dumping unmodified JSON can change whitespace and numeric reprs,
    e.g. "1.00" -> "1.0").

    Args:
        json_str: JSON string to unwrap

    Returns:
        Unwrapped JSON string, or the original string when no unwrapping
        was needed or the input is not valid JSON.
    """
    try:
        response_data = json.loads(json_str)
    except json.JSONDecodeError:
        # Not JSON at all — pass through unchanged.
        return json_str
    if (
        isinstance(response_data, dict)
        and "properties" in response_data
        and len(response_data) == 1
    ):
        # Bedrock wrapped the payload in a lone "properties" object.
        return json.dumps(response_data["properties"])
    return json_str

@staticmethod
def _filter_json_mode_tools(
    json_mode: Optional[bool],
    tools: List[ChatCompletionToolCallChunk],
    chat_completion_message: ChatCompletionResponseMessage,
) -> Optional[List[ChatCompletionToolCallChunk]]:
    """
    When json_mode is True, Bedrock may return the internal `json_tool_call`
    tool alongside real user-defined tools. This method handles 3 scenarios:

    1. Only json_tool_call present -> convert to text content, return None
    2. Mixed json_tool_call + real -> filter out json_tool_call, return real tools
    3. No json_tool_call / no json_mode -> return tools as-is

    Args:
        json_mode: True when the request used response_format (JSON mode).
        tools: Tool-call chunks parsed from the Bedrock response.
        chat_completion_message: Message dict mutated in place when the
            json_tool_call arguments are converted to text content.

    Returns:
        The tool calls to surface to the caller, or None when none remain.
    """
    if not json_mode or not tools:
        return tools if tools else None

    json_tool_indices = [
        i
        for i, t in enumerate(tools)
        if t["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
    ]

    if not json_tool_indices:
        # No json_tool_call found, return tools unchanged
        return tools

    if len(json_tool_indices) == len(tools):
        # All tools are json_tool_call — convert first one to content
        verbose_logger.debug(
            "Processing JSON tool call response for response_format"
        )
        json_mode_content_str: Optional[str] = tools[0]["function"].get(
            "arguments"
        )
        # Only overwrite message content when the tool actually carried
        # arguments; never clobber previously-set content with None.
        if json_mode_content_str is not None:
            chat_completion_message["content"] = (
                AmazonConverseConfig._unwrap_bedrock_properties(
                    json_mode_content_str
                )
            )
        return None

    # Mixed: filter out json_tool_call, keep real tools.
    # Preserve the json_tool_call content as message text so the structured
    # output from response_format is not silently lost.
    first_idx = json_tool_indices[0]
    json_mode_args = tools[first_idx]["function"].get("arguments")
    if json_mode_args is not None:
        # Guard the content write as well: concatenating/assigning a None
        # value would either raise TypeError or erase existing content.
        json_mode_args = AmazonConverseConfig._unwrap_bedrock_properties(
            json_mode_args
        )
        existing = chat_completion_message.get("content") or ""
        chat_completion_message["content"] = (
            existing + json_mode_args if existing else json_mode_args
        )

    real_tools = [t for i, t in enumerate(tools) if i not in json_tool_indices]
    return real_tools if real_tools else None

def _transform_response( # noqa: PLR0915
self,
model: str,
Expand All @@ -1801,7 +1887,7 @@ def _transform_response( # noqa: PLR0915
additional_args={"complete_input_dict": data},
)

json_mode: Optional[bool] = optional_params.pop("json_mode", None)
json_mode: Optional[bool] = optional_params.get("json_mode", None)
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
Expand Down Expand Up @@ -1885,37 +1971,13 @@ def _transform_response( # noqa: PLR0915
self._transform_thinking_blocks(reasoningContentBlocks)
)
chat_completion_message["content"] = content_str
if (
json_mode is True
and tools is not None
and len(tools) == 1
and tools[0]["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME
):
verbose_logger.debug(
"Processing JSON tool call response for response_format"
)
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
if json_mode_content_str is not None:
# Bedrock returns the response wrapped in a "properties" object
# We need to extract the actual content from this wrapper
try:
response_data = json.loads(json_mode_content_str)

# If Bedrock wrapped the response in "properties", extract the content
if (
isinstance(response_data, dict)
and "properties" in response_data
and len(response_data) == 1
):
response_data = response_data["properties"]
json_mode_content_str = json.dumps(response_data)
except json.JSONDecodeError:
# If parsing fails, use the original response
pass

chat_completion_message["content"] = json_mode_content_str
elif tools:
chat_completion_message["tool_calls"] = tools
filtered_tools = self._filter_json_mode_tools(
json_mode=json_mode,
tools=tools,
chat_completion_message=chat_completion_message,
)
if filtered_tools:
chat_completion_message["tool_calls"] = filtered_tools

## CALCULATING USAGE - bedrock returns usage in the headers
usage = self._transform_usage(completion_response["usage"])
Expand Down
64 changes: 48 additions & 16 deletions litellm/llms/bedrock/chat/invoke_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from litellm import verbose_logger
from litellm._uuid import uuid
from litellm.caching.caching import InMemoryCache
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
Expand Down Expand Up @@ -252,7 +253,7 @@ async def make_call(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -346,7 +347,7 @@ def make_sync_call(
response.iter_bytes(chunk_size=stream_chunk_size)
)
else:
decoder = AWSEventStreamDecoder(model=model)
decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
completion_stream = decoder.iter_bytes(
response.iter_bytes(chunk_size=stream_chunk_size)
)
Expand Down Expand Up @@ -1282,14 +1283,16 @@ def get_response_stream_shape():


class AWSEventStreamDecoder:
def __init__(self, model: str) -> None:
def __init__(self, model: str, json_mode: Optional[bool] = False) -> None:
    """Stateful decoder turning AWS event-stream bytes into chat chunks.

    Args:
        model: Bedrock model id; used when resolving tool names from
            stream start events.
        json_mode: When True, the internal response_format tool call
            (``RESPONSE_FORMAT_TOOL_NAME``) is suppressed and its streamed
            arguments are emitted as plain text deltas instead of tool-call
            chunks.
    """
    from botocore.parsers import EventStreamJSONParser

    self.model = model
    self.parser = EventStreamJSONParser()
    # Accumulated content-block delta events for the in-flight response.
    self.content_blocks: List[ContentBlockDeltaEvent] = []
    # Running index assigned to tool calls as start events arrive.
    self.tool_calls_index: Optional[int] = None
    self.response_id: Optional[str] = None
    self.json_mode = json_mode
    # Name of the tool currently being streamed; lets delta/stop handlers
    # detect the internal json_tool_call so it can be converted to text.
    self._current_tool_name: Optional[str] = None

def check_empty_tool_call_args(self) -> bool:
"""
Expand Down Expand Up @@ -1391,6 +1394,16 @@ def _handle_converse_start_event(
response_tool_name = get_bedrock_tool_name(
response_tool_name=_response_tool_name
)
self._current_tool_name = response_tool_name

# When json_mode is True, suppress the internal json_tool_call
# and convert its content to text in delta events instead
if (
self.json_mode is True
and response_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
return tool_use, provider_specific_fields, thinking_blocks

self.tool_calls_index = (
0 if self.tool_calls_index is None else self.tool_calls_index + 1
)
Expand Down Expand Up @@ -1445,19 +1458,27 @@ def _handle_converse_delta_event(
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
# When json_mode is True and this is the internal json_tool_call,
# convert tool input to text content instead of tool call arguments
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
text = delta_obj["toolUse"]["input"]
else:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
"index": (
self.tool_calls_index
if self.tool_calls_index is not None
else index
),
}
elif "reasoningContent" in delta_obj:
provider_specific_fields = {
"reasoningContent": delta_obj["reasoningContent"],
Expand Down Expand Up @@ -1494,6 +1515,17 @@ def _handle_converse_stop_event(
) -> Optional[ChatCompletionToolCallChunk]:
"""Handle stop/contentBlockIndex event in converse chunk parsing."""
tool_use: Optional[ChatCompletionToolCallChunk] = None

# If the ending block was the internal json_tool_call, skip emitting
# the empty-args tool chunk and reset tracking state
if (
self.json_mode is True
and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME
):
self._current_tool_name = None
return tool_use

self._current_tool_name = None
is_empty = self.check_empty_tool_call_args()
if is_empty:
tool_use = {
Expand Down
Loading
Loading