diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 306d63b77d0..a0f2f65fb7f 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -511,6 +511,7 @@ def get_supported_openai_params(self, model: str) -> List[str]: "response_format", "requestMetadata", "service_tier", + "parallel_tool_calls", ] if ( @@ -913,6 +914,13 @@ def map_openai_params( ) if _tool_choice_value is not None: optional_params["tool_choice"] = _tool_choice_value + if param == "parallel_tool_calls": + disable_parallel = not value + optional_params["_parallel_tool_use_config"] = { + "tool_choice": { + "disable_parallel_tool_use": disable_parallel + } + } if param == "thinking": optional_params["thinking"] = value elif param == "reasoning_effort" and isinstance(value, str): diff --git a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py index 11d6bb028d8..56c1cfa8515 100644 --- a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py +++ b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py @@ -430,19 +430,36 @@ async def apply_guardrail(self, inputs, request_data, input_type, logging_obj=No streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj) await streaming.backend_to_client_send_messages() - # ASSERT 1: no response.create was sent to backend (injection blocked). + # ASSERT 1: the guardrail blocked the normal auto-response and instead + # injected a conversation.item.create + response.create to voice the + # violation message. There should be exactly ONE response.create (the + # guardrail-triggered one), preceded by a response.cancel and a + # conversation.item.create carrying the violation text. sent_to_backend = [ json.loads(c.args[0]) for c in backend_ws.send.call_args_list if c.args ] - response_creates = [ + response_cancels = [ + e for e in sent_to_backend if e.get("type") == "response.cancel" + ] + assert len(response_cancels) == 1, ( + f"Guardrail should send response.cancel, got: {response_cancels}" + ) + guardrail_items = [ e for e in sent_to_backend - if e.get("type") == "response.create" + if e.get("type") == "conversation.item.create" ] - assert len(response_creates) == 0, ( - f"Guardrail should prevent response.create for injected content, " - f"but got: {response_creates}" + assert len(guardrail_items) == 1, ( + f"Guardrail should inject a conversation.item.create with violation message, " + f"got: {guardrail_items}" + ) + response_creates = [ + e for e in sent_to_backend if e.get("type") == "response.create" + ] + assert len(response_creates) == 1, ( + f"Guardrail should send exactly one response.create to voice the violation, " + f"got: {response_creates}" ) # ASSERT 2: error event was sent directly to the client WebSocket @@ -595,14 +612,26 @@ async def apply_guardrail(self, inputs, request_data, input_type, logging_obj=No assert len(error_events) == 1, f"Expected one error event, got: {sent_texts}" assert error_events[0]["error"]["type"] == "guardrail_violation" - # ASSERT: blocked item was NOT forwarded to the backend + # ASSERT: the original blocked item was NOT forwarded to the backend. + # The guardrail handler injects its own conversation.item.create with + # the violation message — only that one should be present, not the + # original user message. sent_to_backend = [c.args[0] for c in backend_ws.send.call_args_list if c.args] forwarded_items = [ json.loads(m) for m in sent_to_backend if isinstance(m, str) and json.loads(m).get("type") == "conversation.item.create" ] - assert len(forwarded_items) == 0, ( - f"Blocked item should not be forwarded to backend, got: {forwarded_items}" + # Filter out guardrail-injected items (contain "Say exactly the following message") + original_items = [ + item for item in forwarded_items + if not any( + "Say exactly the following message" in c.get("text", "") + for c in item.get("item", {}).get("content", []) + if isinstance(c, dict) + ) + ] + assert len(original_items) == 0, ( + f"Blocked item should not be forwarded to backend, got: {original_items}" ) litellm.callbacks = [] # cleanup