diff --git a/src/strands/models/_config_validation.py b/src/strands/models/_validation.py similarity index 66% rename from src/strands/models/_config_validation.py rename to src/strands/models/_validation.py index 085449bb8..9eabe28a1 100644 --- a/src/strands/models/_config_validation.py +++ b/src/strands/models/_validation.py @@ -5,6 +5,8 @@ from typing_extensions import get_type_hints +from ..types.tools import ToolChoice + def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: """Validate that config keys match the TypedDict fields. @@ -25,3 +27,16 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> f"\nSee https://github.com/strands-agents/sdk-python/issues/815", stacklevel=4, ) + + +def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None: + """Emits a warning if a tool choice is provided but not supported by the provider. + + Args: + tool_choice: the tool_choice provided to the provider + """ + if tool_choice: + warnings.warn( + "A ToolChoice was provided to this provider but is not supported and will be ignored", + stacklevel=4, + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 06dc816f2..4afc8e3dc 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -18,8 +18,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -195,7 +195,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -203,6 +207,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An Anthropic streaming request. @@ -223,10 +228,25 @@ def format_request( } for tool_spec in tool_specs or [] ], + **(self._format_tool_choice(tool_choice)), **({"system": system_prompt} if system_prompt else {}), **(self.config.get("params") or {}), } + @staticmethod + def _format_tool_choice(tool_choice: ToolChoice | None) -> dict: + if tool_choice is None: + return {} + + if "any" in tool_choice: + return {"tool_choice": {"type": "any"}} + elif "auto" in tool_choice: + return {"tool_choice": {"type": "auto"}} + elif "tool" in tool_choice: + return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}} + else: + return {} + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: """Format the Anthropic response events into standardized message chunks. 
@@ -350,6 +370,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Anthropic model. @@ -358,6 +379,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -368,7 +390,7 @@ async def stream( ModelThrottledException: If the request is throttled by Anthropic. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -410,7 +432,13 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), + **kwargs, + ) async for event in process_stream(response): yield event diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index f18422191..d75817ec4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -23,8 +23,8 @@ ModelThrottledException, ) from ..types.streaming import CitationsDelta, StreamEvent -from ..types.tools import ToolResult, ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -195,6 +195,7 @@ def format_request( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -202,6 +203,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: A Bedrock converse stream request. @@ -224,7 +226,7 @@ def format_request( else [] ), ], - "toolChoice": {"auto": {}}, + **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } } if tool_specs @@ -416,6 +418,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -427,6 +430,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -445,7 +449,7 @@ def callback(event: Optional[StreamEvent] = None) -> None: loop = asyncio.get_event_loop() queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() - thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt) + thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice) task = asyncio.create_task(thread) while True: @@ -463,6 +467,7 @@ def _stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -474,6 +479,7 @@ def _stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -481,7 +487,7 @@ def _stream( """ try: logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -738,6 +744,7 @@ async def structured_output( messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), **kwargs, ) async for event in streaming.process_stream(response): diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 36b385281..6bcc1359e 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,8 +14,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -114,6 +114,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -122,13 +123,14 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 57ff85c66..4e801026c 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -18,8 +18,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -330,6 +330,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LlamaAPI model. @@ -338,6 +339,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -346,6 +349,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests from the client. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 401dde98e..90cd1b5d8 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -15,8 +15,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -397,6 +397,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Mistral model. @@ -405,6 +406,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -413,6 +416,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/model.py b/src/strands/models/model.py index cb24b704d..7a8b4d4cc 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -8,7 +8,7 @@ from ..types.content import Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec logger = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -84,6 +85,7 @@ def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 4025dc062..c29772215 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -13,8 +13,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -287,6 +287,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Ollama model. @@ -295,11 +296,15 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 16eb4defe..fd75ea175 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -16,8 +16,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -174,6 +174,30 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: "content": [cls.format_request_message_content(content) for content in contents], } + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Format a tool choice for OpenAI compatibility. + + Args: + tool_choice: Tool choice configuration in Bedrock format. + + Returns: + OpenAI compatible tool choice format. + """ + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} # OpenAI SDK doesn't define constants for these values + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "function": {"name": tool_name}}} + case _: + # This should not happen with proper typing, but handle gracefully + return {"tool_choice": "auto"} + @classmethod def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. @@ -216,7 +240,11 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -224,6 +252,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An OpenAI compatible chat streaming request. @@ -248,6 +277,7 @@ def format_request( } for tool_spec in tool_specs or [] ], + **(self._format_request_tool_choice(tool_choice)), **cast(dict[str, Any], self.config.get("params", {})), } @@ -329,6 +359,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the OpenAI model. @@ -337,13 +368,14 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. 
**kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("formatted request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 74069b895..f635acce2 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -14,8 +14,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -197,7 +197,11 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i @override def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -205,6 +209,8 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** Returns: An Amazon SageMaker chat streaming request. @@ -286,6 +292,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the SageMaker model. @@ -294,16 +301,21 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("formatted request=<%s>", request) logger.debug("invoking model") + try: if self.payload_config.get("stream", True): response = self.client.invoke_endpoint_with_response_stream(**request) diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 9bcdaad42..07119a21a 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -16,8 +16,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -355,6 +355,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Writer model. @@ -363,6 +364,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -371,6 +374,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests from the client. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 1e0f4b841..e8d5531b2 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -145,10 +145,15 @@ class ToolContext: invocation_state: dict[str, Any] +# Individual ToolChoice type aliases +ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] +ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] +ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] + ToolChoice = Union[ - dict[Literal["auto"], ToolChoiceAuto], - dict[Literal["any"], ToolChoiceAny], - dict[Literal["tool"], ToolChoiceTool], + ToolChoiceAutoDict, + ToolChoiceAnyDict, + ToolChoiceToolDict, ] """ Configuration for how the model should choose tools. 
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9a7a4be11..74bbb8d45 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -417,6 +417,72 @@ def test_format_request_with_empty_content(model, model_id, max_tokens): assert tru_request == exp_request +def test_format_request_tool_choice_auto(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"auto": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"type": "auto"}, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"any": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"type": "any"}, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"tool": {"name": "test_tool"}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"name": "test_tool", "type": "tool"}, + } + + assert tru_request == exp_request + + def test_format_chunk_message_start(model): event = {"type": "message_start"} @@ -785,3 +851,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 624eec6e9..9bdbb2b0f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -400,6 +400,57 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): assert tru_request == exp_request +def 
test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, tool_spec): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, tool_spec): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): model.update_config(cache_prompt=cache_type, cache_tools=cache_type) tru_request = model.format_request(messages, [tool_spec]) @@ -1463,3 +1514,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, tool_spec, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, [tool_spec], tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 4f9f48b92..f345ba003 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import call import pydantic import pytest @@ -219,15 +220,16 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, assert tru_events == exp_events - expected_request = { - "api_key": api_key, - "model": model_id, - "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [], - } - litellm_acompletion.assert_called_once_with(**expected_request) + assert litellm_acompletion.call_args_list == [ + call( + api_key=api_key, + messages=[{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}], + model=model_id, + stream=True, + stream_options={"include_usage": True}, + tools=[], + ) + ] @pytest.mark.asyncio @@ -303,3 +305,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in 
str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 712ef8b7a..a6bbf5673 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -379,3 +379,38 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + mock_chunk = unittest.mock.Mock() + mock_chunk.event.event_type = "start" + mock_chunk.event.stop_reason = "stop" + + mock_create.return_value = [mock_chunk] + + response = model.stream(messages, tool_choice=tool_choice) + await alist(response) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + + +@pytest.mark.asyncio +async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist): + """Test that None toolChoice doesn't emit warning.""" + with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + mock_chunk = unittest.mock.Mock() + mock_chunk.event.event_type = "start" + mock_chunk.event.stop_reason = "stop" + + mock_create.return_value = [mock_chunk] + + response = model.stream(messages, tool_choice=None) + await alist(response) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 9b3f62a31..7808336f2 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -437,7 +437,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(mistral_client, model, agenerator, alist): +async def test_stream(mistral_client, model, agenerator, alist, captured_warnings): mock_usage = unittest.mock.Mock() mock_usage.prompt_tokens = 100 mock_usage.completion_tokens = 50 @@ -472,6 +472,41 @@ async def test_stream(mistral_client, model, agenerator, alist): mistral_client.chat.stream_async.assert_called_once_with(**expected_request) + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings): + tool_choice = {"auto": {}} + + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_event = unittest.mock.Mock( + data=unittest.mock.Mock( + choices=[ + unittest.mock.Mock( + delta=unittest.mock.Mock(content="test stream", tool_calls=None), + 
finish_reason="end_turn", + ) + ] + ), + usage=mock_usage, + ) + + mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages, None, None, tool_choice=tool_choice) + + # Consume the response + await alist(response) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_rate_limit_error(mistral_client, model, alist): diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 9a63a3214..14db63a24 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -414,7 +414,7 @@ def test_format_chunk_other(model): @pytest.mark.asyncio -async def test_stream(ollama_client, model, agenerator, alist): +async def test_stream(ollama_client, model, agenerator, alist, captured_warnings): mock_event = unittest.mock.Mock() mock_event.message.tool_calls = None mock_event.message.content = "Hello" @@ -453,6 +453,31 @@ async def test_stream(ollama_client, model, agenerator, alist): } ollama_client.chat.assert_called_once_with(**expected_request) + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator, alist, captured_warnings): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + mock_event = unittest.mock.Mock() + mock_event.message.tool_calls = None + mock_event.message.content = "Hello" + mock_event.done_reason = "stop" + mock_event.eval_count = 10 + mock_event.prompt_eval_count = 5 + mock_event.total_duration = 1000000 # 1ms in nanoseconds + + ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + await alist(model.stream(messages, tool_choice=tool_choice)) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 00cae7447..64da3cac2 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -179,6 +179,30 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_choice_auto(): + tool_choice = {"auto": {}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": "auto"} + assert tru_result == exp_result + + +def test_format_request_tool_choice_any(): + tool_choice = {"any": {}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": "required"} + assert tru_result == exp_result + + +def test_format_request_tool_choice_tool(): + tool_choice = {"tool": {"name": "test_tool"}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": {"type": "function", "function": {"name": "test_tool"}}} + assert tru_result == exp_result + + def test_format_request_messages(system_prompt): messages = [ { @@ -278,6 +302,123 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == 
exp_request +def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "auto", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_any(model, messages, tool_specs, system_prompt): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "required", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_tool(model, messages, tool_specs, system_prompt): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": {"type": "function", "function": {"name": "test_tool"}}, + "max_tokens": 1, + } + assert tru_request == exp_request + + @pytest.mark.parametrize( ("event", "exp_chunk"), [ @@ -601,3 +742,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index a9071c7e2..a5662ecdc 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -372,7 +372,7 @@ async def 
test_stream_with_tool_calls(self, sagemaker_client, model, messages): assert tool_use_data["name"] == "get_weather" @pytest.mark.asyncio - async def test_stream_with_partial_json(self, sagemaker_client, model, messages): + async def test_stream_with_partial_json(self, sagemaker_client, model, messages, captured_warnings): """Test streaming response with partial JSON chunks.""" # Mock the response from SageMaker with split JSON mock_response = { @@ -404,6 +404,29 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages) text_delta = content_delta["contentBlockDelta"]["delta"]["text"] assert text_delta == "Paris is the capital of France." + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + @pytest.mark.asyncio + async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, messages, captured_warnings, alist): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + # Mock the response from SageMaker with split JSON + mock_response = { + "Body": [ + {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + await alist(model.stream(messages, tool_choice=tool_choice)) + + # Ensure toolChoice parameter warning + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_non_streaming(self, sagemaker_client, model, messages): """Test non-streaming response.""" diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 75896ca68..8cf64a39a 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -353,7 +353,7 @@ async def test_stream_empty(writer_client, model, model_id): @pytest.mark.asyncio -async def test_stream_with_empty_choices(writer_client, model, model_id): +async def test_stream_with_empty_choices(writer_client, model, model_id, captured_warnings): mock_delta = unittest.mock.Mock(content="content", tool_calls=None) mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -381,6 +381,43 @@ async def test_stream_with_empty_choices(writer_client, model, model_id): } writer_client.chat.chat.assert_called_once_with(**expected_request) + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(writer_client, model, model_id, captured_warnings, alist): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + mock_event_1 = unittest.mock.Mock(spec=[]) + mock_event_2 = unittest.mock.Mock(choices=[]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + writer_client.chat.chat.return_value = mock_streaming_response( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response
= model.stream(messages, None, None, tool_choice={"auto": {}}) + + # Consume the response + await alist(response) + + expected_request = { + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "test", "type": "text"}]}], + "stream": True, + "stream_options": {"include_usage": True}, + } + writer_client.chat.chat.assert_called_once_with(**expected_request) + + # Ensure the expected warning is emitted + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + def test_config_validation_warns_on_unknown_keys(writer_client, captured_warnings): """Test that unknown config keys emit a warning.""" diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index d9875bc07..eaef1eb88 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -1,7 +1,11 @@ +from unittest import SkipTest + import pytest +from pydantic import BaseModel +from strands import Agent from strands.models import Model -from tests_integ.models.providers import ProviderInfo, all_providers +from tests_integ.models.providers import ProviderInfo, all_providers, cohere, llama, mistral def get_models(): @@ -20,11 +24,39 @@ def provider_info(request) -> ProviderInfo: return request.param +@pytest.fixture() +def skip_for(provider_info: ProviderInfo): + """A fixture which provides a function to skip the test if the provider is one of the providers specified.""" + + def skip_for_any_provider_in_list(providers: list[ProviderInfo], description: str): + """Skips the current test if the provider is one of those provided.""" + if provider_info in providers: + raise SkipTest(f"Skipping test for {provider_info.id}: {description}") + + return skip_for_any_provider_in_list + + @pytest.fixture() def model(provider_info): return provider_info.create_model() -def test_model_can_be_constructed(model: Model): +def test_model_can_be_constructed(model: Model, skip_for): assert model is not None pass + + +def test_structured_output_is_forced(skip_for, model): + """Tests that structured_output is always forced to return a value even if the model doesn't have any information.""" + skip_for([mistral, cohere, llama], "structured_output is not forced for provider") + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model) + + result = agent.structured_output(Weather, "How are you?") + + assert len(result.time) > 0 + assert len(result.weather) > 0
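End to end, the new parameter flows from Model.stream into each provider's format_request. A minimal usage sketch under assumptions (AWS credentials and Bedrock access are configured; the model id and tool spec are illustrative, not taken from this diff):

import asyncio

from strands.models.bedrock import BedrockModel

model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")  # illustrative model id
tool_spec = {
    "name": "get_weather",
    "description": "Get the weather for a city",
    "inputSchema": {"json": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}},
}
messages = [{"role": "user", "content": [{"text": "What is the weather in Paris?"}]}]


async def main():
    # Force a call to the named tool; providers that do not support tool_choice
    # emit the warning added in _validation.py and ignore the argument.
    async for event in model.stream(messages, [tool_spec], tool_choice={"tool": {"name": "get_weather"}}):
        print(event)


asyncio.run(main())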