Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,13 @@ def build(
tool_output_type = output_type

if tool_name is None:
tool_name = default_tool_name
if multiple:
tool_name += f'_{tool_output_type.__name__}'
tool_name = (
tool_output_type.__name__
if inspect.isfunction(tool_output_type)
else f'{default_tool_name}_{tool_output_type.__name__}'
if multiple
else default_tool_name
)

i = 1
original_tool_name = tool_name
Expand Down
24 changes: 12 additions & 12 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert output_tools == snapshot(
[
ToolDefinition(
name='final_result',
name='get_weather',
description='The final response which ends this conversation',
parameters_json_schema={
'additionalProperties': False,
Expand Down Expand Up @@ -671,7 +671,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert output_tools == snapshot(
[
ToolDefinition(
name='final_result',
name='get_weather',
description='The final response which ends this conversation',
parameters_json_schema={
'additionalProperties': False,
Expand Down Expand Up @@ -801,7 +801,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
tool_name='get_weather',
args='{"city": "New York City"}',
tool_call_id=IsStr(),
)
Expand All @@ -814,7 +814,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
parts=[
RetryPromptPart(
content='City not found, I only know Mexico City',
tool_name='final_result',
tool_name='get_weather',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
Expand All @@ -823,7 +823,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
tool_name='get_weather',
args='{"city": "Mexico City"}',
tool_call_id=IsStr(),
)
Expand All @@ -835,7 +835,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
tool_name='get_weather',
content='Final result processed.',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
Expand Down Expand Up @@ -871,7 +871,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert output_tools == snapshot(
[
ToolDefinition(
name='final_result',
name='get_weather',
description='The final response which ends this conversation',
parameters_json_schema={
'additionalProperties': False,
Expand Down Expand Up @@ -947,7 +947,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert output_tools == snapshot(
[
ToolDefinition(
name='final_result_get_weather',
name='get_weather',
description='get_weather: The final response which ends this conversation',
parameters_json_schema={
'additionalProperties': False,
Expand Down Expand Up @@ -1017,7 +1017,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
tool_name='handoff',
args='{"city": "Mexico City"}',
tool_call_id=IsStr(),
)
Expand All @@ -1029,7 +1029,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
tool_name='handoff',
content='Final result processed.',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
Expand All @@ -1052,7 +1052,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
tool_name='get_weather',
args='{"city": "Mexico City"}',
tool_call_id=IsStr(),
)
Expand All @@ -1064,7 +1064,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
tool_name='get_weather',
content='Final result processed.',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
Expand Down
22 changes: 11 additions & 11 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,23 +406,23 @@ async def list_tools() -> list[None]:
tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b',
),
'Select the names and countries of all capitals': ToolCallPart(
tool_name='final_result_hand_off_to_sql_agent',
tool_name='hand_off_to_sql_agent',
args={'query': 'SELECT name, country FROM capitals;'},
),
'SELECT name, country FROM capitals;': ToolCallPart(
tool_name='final_result_run_sql_query',
tool_name='run_sql_query',
args={'query': 'SELECT name, country FROM capitals;'},
),
'SELECT * FROM capital_cities;': ToolCallPart(
tool_name='final_result_run_sql_query',
tool_name='run_sql_query',
args={'query': 'SELECT * FROM capital_cities;'},
),
'Select all pets': ToolCallPart(
tool_name='final_result_hand_off_to_sql_agent',
tool_name='hand_off_to_sql_agent',
args={'query': 'SELECT * FROM pets;'},
),
'SELECT * FROM pets;': ToolCallPart(
tool_name='final_result_run_sql_query',
tool_name='run_sql_query',
args={'query': 'SELECT * FROM pets;'},
),
'How do I fly from Amsterdam to Mexico City?': ToolCallPart(
Expand Down Expand Up @@ -614,36 +614,36 @@ async def model_logic( # noqa: C901
)
elif (
isinstance(m, RetryPromptPart)
and m.tool_name == 'final_result_run_sql_query'
and m.tool_name == 'run_sql_query'
and m.content == "Only 'SELECT *' is supported, you'll have to do column filtering manually."
):
return ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result_run_sql_query',
tool_name='run_sql_query',
args={'query': 'SELECT * FROM capitals;'},
tool_call_id='pyd_ai_tool_call_id',
)
]
)
elif (
isinstance(m, RetryPromptPart)
and m.tool_name == 'final_result_hand_off_to_sql_agent'
and m.tool_name == 'hand_off_to_sql_agent'
and m.content
== "SQL agent failed: Unknown table 'capitals' in query 'SELECT * FROM capitals;'. Available tables: capital_cities."
):
return ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result_hand_off_to_sql_agent',
tool_name='hand_off_to_sql_agent',
args={'query': 'SELECT * FROM capital_cities;'},
tool_call_id='pyd_ai_tool_call_id',
)
]
)
elif (
isinstance(m, RetryPromptPart)
and m.tool_name == 'final_result_run_sql_query'
and m.tool_name == 'run_sql_query'
and m.content == "Unknown table 'pets' in query 'SELECT * FROM pets;'. Available tables: capital_cities."
):
return ModelResponse(
Expand All @@ -660,7 +660,7 @@ async def model_logic( # noqa: C901
# SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available.
elif (
isinstance(m, RetryPromptPart)
and m.tool_name == 'final_result_hand_off_to_sql_agent'
and m.tool_name == 'hand_off_to_sql_agent'
and m.content
== "SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available."
):
Expand Down