Merged
16 changes: 11 additions & 5 deletions python/packages/core/agent_framework/_clients.py
@@ -101,7 +101,7 @@ async def get_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -160,7 +160,7 @@ def get_streaming_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -501,7 +501,7 @@ async def get_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -596,7 +596,7 @@ async def get_streaming_response(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -688,12 +688,18 @@ def _prepare_tool_choice(self, chat_options: ChatOptions) -> None:
chat_options: The chat options to prepare.
"""
chat_tool_mode = chat_options.tool_choice
if chat_tool_mode is None or chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none":
# Explicitly disabled: clear tools and set to NONE
if chat_tool_mode == ToolMode.NONE or chat_tool_mode == "none":
chat_options.tools = None
chat_options.tool_choice = ToolMode.NONE
return
# No tools available: set to NONE regardless of requested mode
if not chat_options.tools:
chat_options.tool_choice = ToolMode.NONE
# Tools available but no explicit mode: default to AUTO
elif chat_tool_mode is None:
chat_options.tool_choice = ToolMode.AUTO
# Tools available with explicit mode: preserve the mode
else:
chat_options.tool_choice = chat_tool_mode

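
As a quick reference, here is a minimal standalone sketch of the resolution rules the reworked _prepare_tool_choice applies. resolve_tool_choice is a hypothetical helper written for illustration only, and plain strings stand in for ToolMode.NONE / ToolMode.AUTO / ToolMode.REQUIRED_ANY:

def resolve_tool_choice(tool_choice, tools):
    # Explicitly disabled: clear tools and report NONE
    if tool_choice == "none":
        return "none", None
    # No tools available: NONE regardless of the requested mode
    if not tools:
        return "none", tools
    # Tools available but no explicit mode: default to AUTO
    if tool_choice is None:
        return "auto", tools
    # Tools available with an explicit mode: preserve it
    return tool_choice, tools

assert resolve_tool_choice(None, ["get_weather"]) == ("auto", ["get_weather"])
assert resolve_tool_choice("required", ["get_weather"]) == ("required", ["get_weather"])
assert resolve_tool_choice(None, []) == ("none", [])
assert resolve_tool_choice("none", ["get_weather"]) == ("none", None)

The behavioral shift worth noting: with the signature default changed from "auto" to None, an unset or explicitly None tool_choice no longer clears the tools; it resolves to AUTO whenever tools are present.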
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/openai/_chat_client.py
@@ -205,8 +205,8 @@ def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options:
run_options.pop("tools", None)
run_options.pop("parallel_tool_calls", None)
run_options.pop("tool_choice", None)
# tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value
if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1:
# tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode
if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice:
run_options["tool_choice"] = tool_choice["mode"]

# response format
(separate file; name not shown in the captured diff)
@@ -435,8 +435,8 @@ async def _prepare_options(
else:
run_options.pop("parallel_tool_calls", None)
run_options.pop("tool_choice", None)
# tool choice when `tool_choice` is a dict with single key `mode`, extract the mode value
if (tool_choice := run_options.get("tool_choice")) and len(tool_choice.keys()) == 1:
# tool_choice: ToolMode serializes to {"type": "tool_mode", "mode": "..."}, extract mode
if (tool_choice := run_options.get("tool_choice")) and isinstance(tool_choice, dict) and "mode" in tool_choice:
run_options["tool_choice"] = tool_choice["mode"]

# additional properties
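
Both client hunks apply the same normalization: a ToolMode stored on ChatOptions serializes to a dict of the shape {"type": "tool_mode", "mode": "..."}, and the clients collapse it to the bare mode string the API expects. A small sketch of that step, assuming the serialized shape stated in the updated comment:

# run_options as it might look after serialization (illustrative values)
run_options = {"tool_choice": {"type": "tool_mode", "mode": "required"}}

# Collapse the serialized ToolMode dict to the plain mode string
tool_choice = run_options.get("tool_choice")
if tool_choice and isinstance(tool_choice, dict) and "mode" in tool_choice:
    run_options["tool_choice"] = tool_choice["mode"]

assert run_options["tool_choice"] == "required"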
90 changes: 90 additions & 0 deletions python/packages/core/tests/core/test_agents.py
@@ -632,3 +632,93 @@ def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnk
assert result.text == "done"
assert captured.get("has_thread") is True
assert captured.get("has_message_store") is True


async def test_chat_agent_tool_choice_run_level_overrides_agent_level(
chat_client_base: Any, ai_function_tool: Any
) -> None:
"""Verify that tool_choice passed to run() overrides agent-level tool_choice."""
from agent_framework import ChatOptions, ToolMode

captured_options: list[ChatOptions] = []

# Store the original inner method
original_inner = chat_client_base._inner_get_response

async def capturing_inner(
*, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> ChatResponse:
captured_options.append(chat_options)
return await original_inner(messages=messages, chat_options=chat_options, **kwargs)

chat_client_base._inner_get_response = capturing_inner

# Create agent with agent-level tool_choice="auto" and a tool (tools required for tool_choice to be meaningful)
agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool])

# Run with run-level tool_choice="required"
await agent.run("Hello", tool_choice="required")

# Verify the client received tool_choice="required", not "auto"
assert len(captured_options) >= 1
assert captured_options[0].tool_choice == "required"
assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY


async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specified(
chat_client_base: Any, ai_function_tool: Any
) -> None:
"""Verify that agent-level tool_choice is used when run() doesn't specify one."""
from agent_framework import ChatOptions, ToolMode

captured_options: list[ChatOptions] = []

original_inner = chat_client_base._inner_get_response

async def capturing_inner(
*, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> ChatResponse:
captured_options.append(chat_options)
return await original_inner(messages=messages, chat_options=chat_options, **kwargs)

chat_client_base._inner_get_response = capturing_inner

# Create agent with agent-level tool_choice="required" and a tool
agent = ChatAgent(chat_client=chat_client_base, tool_choice="required", tools=[ai_function_tool])

# Run without specifying tool_choice
await agent.run("Hello")

# Verify the client received tool_choice="required" from agent-level
assert len(captured_options) >= 1
assert captured_options[0].tool_choice == "required"
assert captured_options[0].tool_choice == ToolMode.REQUIRED_ANY


async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(
chat_client_base: Any, ai_function_tool: Any
) -> None:
"""Verify that tool_choice=None at run() uses agent-level default."""
from agent_framework import ChatOptions

captured_options: list[ChatOptions] = []

original_inner = chat_client_base._inner_get_response

async def capturing_inner(
*, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> ChatResponse:
captured_options.append(chat_options)
return await original_inner(messages=messages, chat_options=chat_options, **kwargs)

chat_client_base._inner_get_response = capturing_inner

# Create agent with agent-level tool_choice="auto" and a tool
agent = ChatAgent(chat_client=chat_client_base, tool_choice="auto", tools=[ai_function_tool])

# Run with explicitly passing None (same as not specifying)
await agent.run("Hello", tool_choice=None)

# Verify the client received tool_choice="auto" from agent-level
assert len(captured_options) >= 1
assert captured_options[0].tool_choice == "auto"
48 changes: 48 additions & 0 deletions python/packages/core/tests/core/test_types.py
@@ -844,6 +844,54 @@ def test_chat_options_and(ai_function_tool, ai_tool) -> None:
assert options3.additional_properties.get("p") == 1


def test_chat_options_and_tool_choice_override() -> None:
"""Test that tool_choice from other takes precedence in ChatOptions merge."""
# Agent-level defaults to "auto"
agent_options = ChatOptions(model_id="gpt-4o", tool_choice="auto")
# Run-level specifies "required"
run_options = ChatOptions(tool_choice="required")

merged = agent_options & run_options

# Run-level should override agent-level
assert merged.tool_choice == "required"
assert merged.model_id == "gpt-4o" # Other fields preserved


def test_chat_options_and_tool_choice_none_in_other_uses_self() -> None:
"""Test that when other.tool_choice is None, self.tool_choice is used."""
agent_options = ChatOptions(tool_choice="auto")
run_options = ChatOptions(model_id="gpt-4.1") # tool_choice is None

merged = agent_options & run_options

# Should keep agent-level tool_choice since run-level is None
assert merged.tool_choice == "auto"
assert merged.model_id == "gpt-4.1"


def test_chat_options_and_tool_choice_with_tool_mode() -> None:
"""Test ChatOptions merge with ToolMode objects."""
agent_options = ChatOptions(tool_choice=ToolMode.AUTO)
run_options = ChatOptions(tool_choice=ToolMode.REQUIRED_ANY)

merged = agent_options & run_options

assert merged.tool_choice == ToolMode.REQUIRED_ANY
assert merged.tool_choice == "required" # ToolMode equality with string


def test_chat_options_and_tool_choice_required_specific_function() -> None:
"""Test ChatOptions merge with required specific function."""
agent_options = ChatOptions(tool_choice="auto")
run_options = ChatOptions(tool_choice=ToolMode.REQUIRED(function_name="get_weather"))

merged = agent_options & run_options

assert merged.tool_choice == "required"
assert merged.tool_choice.required_function_name == "get_weather"


# region Agent Response Fixtures

