
Commit d691ea9

zastrowm and Shang Liu committed
feat: Add support for toolChoice to providers
For structured output, so that some providers can force tool calls.

Co-authored-by: Shang Liu <[email protected]>
1 parent d66fcdb commit d691ea9

File tree

14 files changed: +361 −30 lines changed
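Across providers, tool_choice follows the Bedrock converse ToolChoice shape. A minimal sketch of the three variants a caller can pass, assuming strands.types.tools is the public counterpart of the relative imports in the diffs below (the tool name is illustrative; the cast mirrors the one used in the change itself):

    from typing import cast

    from strands.types.tools import ToolChoice

    auto_choice = cast(ToolChoice, {"auto": {}})    # the model decides whether to call a tool
    any_choice = cast(ToolChoice, {"any": {}})      # the model must call some tool
    named_choice = cast(ToolChoice, {"tool": {"name": "get_weather"}})  # must call this specific tool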

src/strands/models/anthropic.py

Lines changed: 18 additions & 4 deletions

@@ -18,7 +18,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -195,14 +195,19 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
         return formatted_messages

     def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> dict[str, Any]:
         """Format an Anthropic streaming request.

         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.

         Returns:
             An Anthropic streaming request.

@@ -223,6 +228,7 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
+            **({"tool_choice": tool_choice} if tool_choice else {}),
             **({"system": system_prompt} if system_prompt else {}),
             **(self.config.get("params") or {}),
         }

@@ -350,6 +356,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Anthropic model.

@@ -358,6 +365,7 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

@@ -368,7 +376,7 @@
             ModelThrottledException: If the request is throttled by Anthropic.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("request=<%s>", request)

         logger.debug("invoking model")

@@ -410,7 +418,13 @@ async def structured_output(
         """
         tool_spec = convert_pydantic_to_tool_spec(output_model)

-        response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
+        response = self.stream(
+            messages=prompt,
+            tool_specs=[tool_spec],
+            system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
+            **kwargs,
+        )
         async for event in process_stream(response):
             yield event
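Net effect for Anthropic: structured_output no longer merely offers the conversion tool, it pins tool_choice to {"any": {}} so the model must emit a tool call. A hedged sketch of the same pattern applied through stream directly (model, messages, and tool_spec are assumed to exist already; the function name is illustrative):

    from typing import cast

    from strands.types.tools import ToolChoice

    async def consume_forced_tool_stream() -> None:
        # Mirrors what structured_output now does internally.
        response = model.stream(
            messages=messages,
            tool_specs=[tool_spec],
            tool_choice=cast(ToolChoice, {"any": {}}),  # the model must call some tool
        )
        async for event in response:
            ...  # handle StreamEvent chunks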

src/strands/models/bedrock.py

Lines changed: 11 additions & 4 deletions

@@ -23,7 +23,7 @@
     ModelThrottledException,
 )
 from ..types.streaming import CitationsDelta, StreamEvent
-from ..types.tools import ToolResult, ToolSpec
+from ..types.tools import ToolChoice, ToolResult, ToolSpec
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -195,13 +195,15 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> dict[str, Any]:
         """Format a Bedrock converse stream request.

         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.

         Returns:
             A Bedrock converse stream request.

@@ -224,7 +226,7 @@ def format_request(
                         else []
                     ),
                 ],
-                "toolChoice": {"auto": {}},
+                **({"toolChoice": tool_choice} if tool_choice else {}),
             }
         }
         if tool_specs

@@ -416,6 +418,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Bedrock model.

@@ -427,6 +430,7 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

@@ -445,7 +449,7 @@ def callback(event: Optional[StreamEvent] = None) -> None:
         loop = asyncio.get_event_loop()
         queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()

-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
+        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice)
         task = asyncio.create_task(thread)

         while True:

@@ -463,6 +467,7 @@ def _stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> None:
         """Stream conversation with the Bedrock model.

@@ -474,14 +479,15 @@ def _stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.

         Raises:
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the model service is throttling requests.
         """
         try:
             logger.debug("formatting request")
-            request = self.format_request(messages, tool_specs, system_prompt)
+            request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
             logger.debug("request=<%s>", request)

             logger.debug("invoking model")

@@ -738,6 +744,7 @@ async def structured_output(
             messages=prompt,
             tool_specs=[tool_spec],
             system_prompt=system_prompt,
+            tool_choice=cast(ToolChoice, {"any": {}}),
             **kwargs,
         )
         async for event in streaming.process_stream(response):
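One behavioral note on Bedrock: format_request previously hard-coded "toolChoice": {"auto": {}} whenever tools were configured; after this change the key is emitted only when a caller supplies a choice, leaving the service default in effect otherwise. A small sketch of the conditional-splat pattern (the helper name is hypothetical):

    from typing import Any, Optional

    def tool_config_fragment(tools: list[dict[str, Any]], tool_choice: Optional[dict[str, Any]]) -> dict[str, Any]:
        # Same pattern as the diff: "toolChoice" is present only when supplied.
        return {
            "tools": tools,
            **({"toolChoice": tool_choice} if tool_choice else {}),
        }

    # tool_config_fragment(tools, None)        -> {"tools": [...]}
    # tool_config_fragment(tools, {"any": {}}) -> {"tools": [...], "toolChoice": {"any": {}}}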

src/strands/models/litellm.py

Lines changed: 4 additions & 1 deletion

@@ -14,7 +14,7 @@

 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from ._config_validation import validate_config_keys
 from .openai import OpenAIModel

@@ -114,6 +114,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LiteLLM model.

@@ -122,6 +123,8 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

src/strands/models/llamaapi.py

Lines changed: 4 additions & 1 deletion

@@ -18,7 +18,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StreamEvent, Usage
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -330,6 +330,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LlamaAPI model.

@@ -338,6 +339,8 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

src/strands/models/mistral.py

Lines changed: 4 additions & 1 deletion

@@ -15,7 +15,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -397,6 +397,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Mistral model.

@@ -405,6 +406,8 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

src/strands/models/model.py

Lines changed: 3 additions & 1 deletion

@@ -8,7 +8,7 @@

 from ..types.content import Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec

 logger = logging.getLogger(__name__)

@@ -70,6 +70,7 @@ def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncIterable[StreamEvent]:
         """Stream conversation with the model.

@@ -84,6 +85,7 @@ def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
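Because the abstract Model.stream signature now declares tool_choice, third-party providers must accept the parameter even if they ignore it, as several built-in providers currently do. A minimal sketch of a conforming override (the class is hypothetical and the other abstract methods are omitted):

    from typing import Any, AsyncGenerator, Optional

    from strands.models.model import Model
    from strands.types.content import Messages
    from strands.types.streaming import StreamEvent
    from strands.types.tools import ToolChoice, ToolSpec

    class MyProviderModel(Model):  # hypothetical provider
        async def stream(
            self,
            messages: Messages,
            tool_specs: Optional[list[ToolSpec]] = None,
            system_prompt: Optional[str] = None,
            tool_choice: Optional[ToolChoice] = None,  # accept even when unused
            **kwargs: Any,
        ) -> AsyncGenerator[StreamEvent, None]:
            # A real provider would translate tool_choice into its request
            # format (or document that it is ignored) and yield StreamEvents.
            raise NotImplementedError
            yield  # unreachable; marks this function as an async generator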

src/strands/models/ollama.py

Lines changed: 4 additions & 1 deletion

@@ -13,7 +13,7 @@

 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StopReason, StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolChoice, ToolSpec
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -287,6 +287,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Ollama model.

@@ -295,6 +296,8 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
+                interface consistency but is currently ignored for this model provider.**
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:

src/strands/models/openai.py

Lines changed: 32 additions & 3 deletions

@@ -16,7 +16,7 @@

 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from ._config_validation import validate_config_keys
 from .model import Model

@@ -174,6 +174,27 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
             "content": [cls.format_request_message_content(content) for content in contents],
         }

+    @classmethod
+    def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[str, Any]]:
+        """Format a tool choice for OpenAI compatibility.
+
+        Args:
+            tool_choice: Tool choice configuration in Bedrock format.
+
+        Returns:
+            OpenAI compatible tool choice format.
+        """
+        match tool_choice:
+            case {"auto": _}:
+                return "auto"  # OpenAI SDK doesn't define constants for these values
+            case {"any": _}:
+                return "required"
+            case {"tool": {"name": tool_name}}:
+                return {"type": "function", "function": {"name": tool_name}}
+            case _:
+                # This should not happen with proper typing, but handle gracefully
+                return "auto"
+
     @classmethod
     def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
         """Format an OpenAI compatible messages array.

@@ -216,14 +237,19 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

     def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+        self,
+        messages: Messages,
+        tool_specs: Optional[list[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
     ) -> dict[str, Any]:
         """Format an OpenAI compatible chat streaming request.

         Args:
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.

         Returns:
             An OpenAI compatible chat streaming request.

@@ -248,6 +274,7 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
+            **({"tool_choice": self.format_request_tool_choice(tool_choice)} if tool_choice else {}),
             **cast(dict[str, Any], self.config.get("params", {})),
         }

@@ -329,6 +356,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
+        tool_choice: Optional[ToolChoice] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the OpenAI model.

@@ -337,13 +365,14 @@
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.

         Yields:
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("formatted request=<%s>", request)

         logger.debug("invoking model")
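For reference, the mapping format_request_tool_choice performs from the Bedrock-style dict to the values the OpenAI chat completions API expects (the tool name is illustrative; the import path assumes the public package layout):

    from typing import cast

    from strands.models.openai import OpenAIModel
    from strands.types.tools import ToolChoice

    OpenAIModel.format_request_tool_choice(cast(ToolChoice, {"auto": {}}))
    # -> "auto"
    OpenAIModel.format_request_tool_choice(cast(ToolChoice, {"any": {}}))
    # -> "required"
    OpenAIModel.format_request_tool_choice(cast(ToolChoice, {"tool": {"name": "get_weather"}}))
    # -> {"type": "function", "function": {"name": "get_weather"}}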
