2323 ModelThrottledException ,
2424)
2525from ..types .streaming import CitationsDelta , StreamEvent
26- from ..types .tools import ToolResult , ToolSpec
26+ from ..types .tools import ToolChoice , ToolResult , ToolSpec
2727from ._config_validation import validate_config_keys
2828from .model import Model
2929
@@ -195,13 +195,15 @@ def format_request(
195195 messages : Messages ,
196196 tool_specs : Optional [list [ToolSpec ]] = None ,
197197 system_prompt : Optional [str ] = None ,
198+ tool_choice : Optional [ToolChoice ] = None ,
198199 ) -> dict [str , Any ]:
199200 """Format a Bedrock converse stream request.
200201
201202 Args:
202203 messages: List of message objects to be processed by the model.
203204 tool_specs: List of tool specifications to make available to the model.
204205 system_prompt: System prompt to provide context to the model.
206+ tool_choice: Selection strategy for tool invocation.
205207
206208 Returns:
207209 A Bedrock converse stream request.
@@ -224,7 +226,7 @@ def format_request(
224226 else []
225227 ),
226228 ],
227- "toolChoice" : { "auto" : {}} ,
229+ ** ({ "toolChoice" : tool_choice } if tool_choice else {}) ,
228230 }
229231 }
230232 if tool_specs
@@ -416,6 +418,7 @@ async def stream(
416418 messages : Messages ,
417419 tool_specs : Optional [list [ToolSpec ]] = None ,
418420 system_prompt : Optional [str ] = None ,
421+ tool_choice : Optional [ToolChoice ] = None ,
419422 ** kwargs : Any ,
420423 ) -> AsyncGenerator [StreamEvent , None ]:
421424 """Stream conversation with the Bedrock model.
@@ -427,6 +430,7 @@ async def stream(
427430 messages: List of message objects to be processed by the model.
428431 tool_specs: List of tool specifications to make available to the model.
429432 system_prompt: System prompt to provide context to the model.
433+ tool_choice: Selection strategy for tool invocation.
430434 **kwargs: Additional keyword arguments for future extensibility.
431435
432436 Yields:
@@ -445,7 +449,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
445449 loop = asyncio .get_event_loop ()
446450 queue : asyncio .Queue [Optional [StreamEvent ]] = asyncio .Queue ()
447451
448- thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt )
452+ thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt , tool_choice )
449453 task = asyncio .create_task (thread )
450454
451455 while True :
@@ -463,6 +467,7 @@ def _stream(
463467 messages : Messages ,
464468 tool_specs : Optional [list [ToolSpec ]] = None ,
465469 system_prompt : Optional [str ] = None ,
470+ tool_choice : Optional [ToolChoice ] = None ,
466471 ) -> None :
467472 """Stream conversation with the Bedrock model.
468473
@@ -474,14 +479,15 @@ def _stream(
474479 messages: List of message objects to be processed by the model.
475480 tool_specs: List of tool specifications to make available to the model.
476481 system_prompt: System prompt to provide context to the model.
482+ tool_choice: Selection strategy for tool invocation.
477483
478484 Raises:
479485 ContextWindowOverflowException: If the input exceeds the model's context window.
480486 ModelThrottledException: If the model service is throttling requests.
481487 """
482488 try :
483489 logger .debug ("formatting request" )
484- request = self .format_request (messages , tool_specs , system_prompt )
490+ request = self .format_request (messages , tool_specs , system_prompt , tool_choice )
485491 logger .debug ("request=<%s>" , request )
486492
487493 logger .debug ("invoking model" )
@@ -738,6 +744,7 @@ async def structured_output(
738744 messages = prompt ,
739745 tool_specs = [tool_spec ],
740746 system_prompt = system_prompt ,
747+ tool_choice = cast (ToolChoice , {"any" : {}}),
741748 ** kwargs ,
742749 )
743750 async for event in streaming .process_stream (response ):
0 commit comments