From 72ae50f2563841794990f7bacba43b336cdd1959 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 6 Jan 2026 16:31:48 +0900 Subject: [PATCH] fix: tool_choice parameter not being honored when passed to agent.run() --- .../packages/core/agent_framework/_clients.py | 16 ++-- .../agent_framework/openai/_chat_client.py | 4 +- .../openai/_responses_client.py | 4 +- .../packages/core/tests/core/test_agents.py | 90 +++++++++++++++++++ python/packages/core/tests/core/test_types.py | 48 ++++++++++ 5 files changed, 153 insertions(+), 9 deletions(-) diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index bfb2c3f7d4..6743902475 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -101,7 +101,7 @@ async def get_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -160,7 +160,7 @@ def get_streaming_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -501,7 +501,7 @@ async def get_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -596,7 +596,7 @@ async def get_streaming_response( stop: str | Sequence[str] | None = None, store: bool | None = None, temperature: float | None = None, - tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, tools: ToolProtocol | Callable[..., Any] | MutableMapping[str, Any] @@ -688,12 +688,18 @@ def _prepare_tool_choice(self, chat_options: ChatOptions) -> None: chat_options: The chat options to prepare. """ chat_tool_mode = chat_options.tool_choice - if chat_tool_mode is None or chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none": + # Explicitly disabled: clear tools and set to NONE + if chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none": chat_options.tools = None chat_options.tool_choice = ToolMode.NONE return + # No tools available: set to NONE regardless of requested mode if not chat_options.tools: chat_options.tool_choice = ToolMode.NONE + # Tools available but no explicit mode: default to AUTO + elif chat_tool_mode is None: + chat_options.tool_choice = ToolMode.AUTO + # Tools available with explicit mode: preserve the mode else: chat_options.tool_choice = chat_tool_mode diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 305757356d..a2365b58f2 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -205,8 +205,8 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: run_options.pop("tools", None) run_options.pop("parallel_tool_calls", None) run_options.pop("tool_choice", None) - # tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value - if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1: + # tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode + if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice: run_options["tool_choice"] = tool_choice["mode"] # response format diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 54a0f5544b..9e50677fae 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -435,8 +435,8 @@ async def _prepare_options( else: run_options.pop("parallel_tool_calls", None) run_options.pop("tool_choice", None) - # tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value - if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1: + # tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode + if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice: run_options["tool_choice"] = tool_choice["mode"] # additional properties diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index a6df07cbbe..7611df0cb0 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -632,3 +632,93 @@ def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnk assert result.text == "done" assert captured.get("has_thread") is True assert captured.get("has_message_store") is True + + +async def test_chat_agent_tool_choice_run_level_overrides_agent_level( + chat_client_base: Any, ai_function_tool: Any +) -> None: + """Verify that tool_choice passed to run() overrides agent-level tool_choice.""" + from agent_framework import ChatOptions, ToolMode + + captured_options: list[ChatOptions] = [] + + # Store the original inner method + original_inner = chat_client_base._inner_get_response + + async def capturing_inner( + *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> ChatResponse: + captured_options.append(chat_options) + return await original_inner(messages=messages, chat_options=chat_options, **kwargs) + + chat_client_base._inner_get_response = capturing_inner + + # Create agent with agent-level tool_choice="auto" and a tool (tools required for tool_choice to be meaningful) + agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool]) + + # Run with run-level tool_choice="required" + await agent.run("Hello", tool_choice="required") + + # Verify the client received tool_choice="required", not "auto" + assert len(captured_options) >= 1 + assert captured_options[0].tool_choice == "required" + assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY + + +async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specified( + chat_client_base: Any, ai_function_tool: Any +) -> None: + """Verify that agent-level tool_choice is used when run() doesn't specify one.""" + from agent_framework import ChatOptions, ToolMode + + captured_options: list[ChatOptions] = [] + + original_inner = chat_client_base._inner_get_response + + async def capturing_inner( + *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> ChatResponse: + captured_options.append(chat_options) + return await original_inner(messages=messages, chat_options=chat_options, **kwargs) + + chat_client_base._inner_get_response = capturing_inner + + # Create agent with agent-level tool_choice="required" and a tool + agent = ChatAgent(chat_client=chat_client_base, tool_choice="required", tools=[ai_function_tool]) + + # Run without specifying tool_choice + await agent.run("Hello") + + # Verify the client received tool_choice="required" from agent-level + assert len(captured_options) >= 1 + assert captured_options[0].tool_choice == "required" + assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY + + +async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level( + chat_client_base: Any, ai_function_tool: Any +) -> None: + """Verify that tool_choice=None at run() uses agent-level default.""" + from agent_framework import ChatOptions + + captured_options: list[ChatOptions] = [] + + original_inner = chat_client_base._inner_get_response + + async def capturing_inner( + *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> ChatResponse: + captured_options.append(chat_options) + return await original_inner(messages=messages, chat_options=chat_options, **kwargs) + + chat_client_base._inner_get_response = capturing_inner + + # Create agent with agent-level tool_choice="auto" and a tool + agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool]) + + # Run with explicitly passing None (same as not specifying) + await agent.run("Hello", tool_choice=None) + + # Verify the client received tool_choice="auto" from agent-level + assert len(captured_options) >= 1 + assert captured_options[0].tool_choice == "auto" diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 81242147d2..4d52d81d22 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -844,6 +844,54 @@ def test_chat_options_and(ai_function_tool, ai_tool) -> None: assert options3.additional_properties.get("p") == 1 +def test_chat_options_and_tool_choice_override() -> None: + """Test that tool_choice from other takes precedence in ChatOptions merge.""" + # Agent-level defaults to "auto" + agent_options = ChatOptions(model_id="gpt-4o", tool_choice="auto") + # Run-level specifies "required" + run_options = ChatOptions(tool_choice="required") + + merged = agent_options & run_options + + # Run-level should override agent-level + assert merged.tool_choice == "required" + assert merged.model_id == "gpt-4o" # Other fields preserved + + +def test_chat_options_and_tool_choice_none_in_other_uses_self() -> None: + """Test that when other.tool_choice is None, self.tool_choice is used.""" + agent_options = ChatOptions(tool_choice="auto") + run_options = ChatOptions(model_id="gpt-4.1") # tool_choice is None + + merged = agent_options & run_options + + # Should keep agent-level tool_choice since run-level is None + assert merged.tool_choice == "auto" + assert merged.model_id == "gpt-4.1" + + +def test_chat_options_and_tool_choice_with_tool_mode() -> None: + """Test ChatOptions merge with ToolMode objects.""" + agent_options = ChatOptions(tool_choice=ToolMode.AUTO) + run_options = ChatOptions(tool_choice=ToolMode.REQUIRED_ANY) + + merged = agent_options & run_options + + assert merged.tool_choice == ToolMode.REQUIRED_ANY + assert merged.tool_choice == "required" # ToolMode equality with string + + +def test_chat_options_and_tool_choice_required_specific_function() -> None: + """Test ChatOptions merge with required specific function.""" + agent_options = ChatOptions(tool_choice="auto") + run_options = ChatOptions(tool_choice=ToolMode.REQUIRED(function_name="get_weather")) + + merged = agent_options & run_options + + assert merged.tool_choice == "required" + assert merged.tool_choice.required_function_name == "get_weather" + + # region Agent Response Fixtures