Commit 6086541

feat: Add warning for providers that don't support toolChoice
1 parent d691ea9 commit 6086541

20 files changed (+302, -53 lines)

src/strands/models/_config_validation.py renamed to src/strands/models/_validation.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,8 @@
 
 from typing_extensions import get_type_hints
 
+from ..types.tools import ToolChoice
+
 
 def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None:
     """Validate that config keys match the TypedDict fields.
@@ -25,3 +27,11 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) ->
             f"\nSee https://github.com/strands-agents/sdk-python/issues/815",
             stacklevel=4,
         )
+
+
+def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None:
+    if tool_choice:
+        warnings.warn(
+            "A ToolChoice was provided to this provider but is not supported and will be ignored",
+            stacklevel=4,
+        )
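A minimal sketch (not part of the commit) of how the new helper behaves, using the renamed module path above; since the guard is a plain truthiness check, any non-None ToolChoice triggers the warning:

import warnings

from strands.models._validation import warn_on_tool_choice_not_supported

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_tool_choice_not_supported({"any": {}})  # truthy ToolChoice -> warns
    warn_on_tool_choice_not_supported(None)         # None -> silently accepted

assert len(caught) == 1
assert "will be ignored" in str(caught[0].message)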

src/strands/models/anthropic.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -199,7 +199,7 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format an Anthropic streaming request.
 
@@ -356,7 +356,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Anthropic model.

src/strands/models/bedrock.py

Lines changed: 5 additions & 5 deletions
@@ -24,7 +24,7 @@
 )
 from ..types.streaming import CitationsDelta, StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -195,7 +195,7 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format a Bedrock converse stream request.
 
@@ -226,7 +226,7 @@ def format_request(
                         else []
                     ),
                 ],
-                **({"toolChoice": tool_choice} if tool_choice else {}),
+                **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
             }
         }
         if tool_specs
@@ -418,7 +418,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Bedrock model.
@@ -467,7 +467,7 @@ def _stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> None:
         """Stream conversation with the Bedrock model.
 
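The behavioral change in format_request above: previously toolChoice was omitted from the request when the caller passed None; now it is always present, defaulting to {"auto": {}}. A sketch of the expression in isolation (function and tool names are hypothetical):

def tool_choice_entry(tool_choice):
    # Mirrors the dict splat on the new line: a missing ToolChoice
    # now falls back to {"auto": {}} instead of being dropped.
    return {**({"toolChoice": tool_choice if tool_choice else {"auto": {}}})}

assert tool_choice_entry(None) == {"toolChoice": {"auto": {}}}
assert tool_choice_entry({"tool": {"name": "get_weather"}}) == {
    "toolChoice": {"tool": {"name": "get_weather"}}
}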

src/strands/models/litellm.py

Lines changed: 4 additions & 5 deletions
@@ -15,7 +15,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys
 from .openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
@@ -114,7 +114,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LiteLLM model.
@@ -123,15 +123,14 @@ async def stream(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
-            tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for
-                interface consistency but is currently ignored for this model provider.**
+            tool_choice: Selection strategy for tool invocation.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt)
+        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")

src/strands/models/llamaapi.py

Lines changed: 4 additions & 2 deletions
@@ -19,7 +19,7 @@
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StreamEvent, Usage
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -330,7 +330,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the LlamaAPI model.
@@ -349,6 +349,8 @@ async def stream(
         Raises:
             ModelThrottledException: When the model service is throttling requests from the client.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)

src/strands/models/mistral.py

Lines changed: 4 additions & 2 deletions
@@ -16,7 +16,7 @@
 from ..types.exceptions import ModelThrottledException
 from ..types.streaming import StopReason, StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -397,7 +397,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Mistral model.
@@ -416,6 +416,8 @@ async def stream(
         Raises:
             ModelThrottledException: When the model service is throttling requests.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)

src/strands/models/model.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncIterable[StreamEvent]:
         """Stream conversation with the model.

src/strands/models/ollama.py

Lines changed: 4 additions & 2 deletions
@@ -14,7 +14,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StopReason, StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -287,7 +287,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the Ollama model.
@@ -303,6 +303,8 @@ async def stream(
         Yields:
             Formatted message chunks from the model.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("request=<%s>", request)

src/strands/models/openai.py

Lines changed: 12 additions & 9 deletions
@@ -17,7 +17,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -175,7 +175,7 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
         }
 
     @classmethod
-    def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[str, Any]]:
+    def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]:
         """Format a tool choice for OpenAI compatibility.
 
         Args:
@@ -184,16 +184,19 @@ def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[
         Returns:
             OpenAI compatible tool choice format.
         """
+        if not tool_choice:
+            return {}
+
         match tool_choice:
             case {"auto": _}:
-                return "auto"  # OpenAI SDK doesn't define constants for these values
+                return {"tool_choice": "auto"}  # OpenAI SDK doesn't define constants for these values
             case {"any": _}:
-                return "required"
+                return {"tool_choice": "required"}
             case {"tool": {"name": tool_name}}:
-                return {"type": "function", "function": {"name": tool_name}}
+                return {"tool_choice": {"type": "function", "function": {"name": tool_name}}}
             case _:
                 # This should not happen with proper typing, but handle gracefully
-                return "auto"
+                return {"tool_choice": "auto"}
 
     @classmethod
     def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
@@ -241,7 +244,7 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format an OpenAI compatible chat streaming request.
 
@@ -274,7 +277,7 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
-            **({"tool_choice": self.format_request_tool_choice(tool_choice)} if tool_choice else {}),
+            **(self._format_request_tool_choice(tool_choice)),
             **cast(dict[str, Any], self.config.get("params", {})),
         }
 
@@ -356,7 +359,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the OpenAI model.

src/strands/models/sagemaker.py

Lines changed: 6 additions & 3 deletions
@@ -15,7 +15,7 @@
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
 from .openai import OpenAIModel
 
 T = TypeVar("T", bound=BaseModel)
@@ -201,7 +201,7 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format an Amazon SageMaker chat streaming request.
 
@@ -292,7 +292,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the SageMaker model.
@@ -308,11 +308,14 @@ async def stream(
         Yields:
             Formatted message chunks from the model.
         """
+        warn_on_tool_choice_not_supported(tool_choice)
+
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt)
         logger.debug("formatted request=<%s>", request)
 
         logger.debug("invoking model")
+
         try:
             if self.payload_config.get("stream", True):
                 response = self.client.invoke_endpoint_with_response_stream(**request)
