diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 916ddb7e83..dc41a4a50d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -211,9 +211,13 @@ def build( tool_output_type = output_type if tool_name is None: - tool_name = default_tool_name - if multiple: - tool_name += f'_{tool_output_type.__name__}' + tool_name = ( + tool_output_type.__name__ + if inspect.isfunction(tool_output_type) + else f'{default_tool_name}_{tool_output_type.__name__}' + if multiple + else default_tool_name + ) i = 1 original_tool_name = tool_name diff --git a/tests/test_agent.py b/tests/test_agent.py index d598481551..5cb5b725a9 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -632,7 +632,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_tools == snapshot( [ ToolDefinition( - name='final_result', + name='get_weather', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, @@ -671,7 +671,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_tools == snapshot( [ ToolDefinition( - name='final_result', + name='get_weather', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, @@ -801,7 +801,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelResponse( parts=[ ToolCallPart( - tool_name='final_result', + tool_name='get_weather', args='{"city": "New York City"}', tool_call_id=IsStr(), ) @@ -814,7 +814,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: parts=[ RetryPromptPart( content='City not found, I only know Mexico City', - tool_name='final_result', + tool_name='get_weather', tool_call_id=IsStr(), timestamp=IsDatetime(), ) @@ -823,7 +823,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelResponse( parts=[ ToolCallPart( - tool_name='final_result', + tool_name='get_weather', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) @@ -835,7 +835,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', + tool_name='get_weather', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), @@ -871,7 +871,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_tools == snapshot( [ ToolDefinition( - name='final_result', + name='get_weather', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, @@ -947,7 +947,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_tools == snapshot( [ ToolDefinition( - name='final_result_get_weather', + name='get_weather', description='get_weather: The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, @@ -1017,7 +1017,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ModelResponse( parts=[ ToolCallPart( - tool_name='final_result', + tool_name='handoff', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) @@ -1029,7 +1029,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', + tool_name='handoff', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), @@ -1052,7 +1052,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ModelResponse( parts=[ ToolCallPart( - tool_name='final_result', + tool_name='get_weather', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) @@ -1064,7 +1064,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', + tool_name='get_weather', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), diff --git a/tests/test_examples.py b/tests/test_examples.py index ad377bedbf..5bef09b30d 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -406,23 +406,23 @@ async def list_tools() -> list[None]: tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ), 'Select the names and countries of all capitals': ToolCallPart( - tool_name='final_result_hand_off_to_sql_agent', + tool_name='hand_off_to_sql_agent', args={'query': 'SELECT name, country FROM capitals;'}, ), 'SELECT name, country FROM capitals;': ToolCallPart( - tool_name='final_result_run_sql_query', + tool_name='run_sql_query', args={'query': 'SELECT name, country FROM capitals;'}, ), 'SELECT * FROM capital_cities;': ToolCallPart( - tool_name='final_result_run_sql_query', + tool_name='run_sql_query', args={'query': 'SELECT * FROM capital_cities;'}, ), 'Select all pets': ToolCallPart( - tool_name='final_result_hand_off_to_sql_agent', + tool_name='hand_off_to_sql_agent', args={'query': 'SELECT * FROM pets;'}, ), 'SELECT * FROM pets;': ToolCallPart( - tool_name='final_result_run_sql_query', + tool_name='run_sql_query', args={'query': 'SELECT * FROM pets;'}, ), 'How do I fly from Amsterdam to Mexico City?': ToolCallPart( @@ -614,13 +614,13 @@ async def model_logic( # noqa: C901 ) elif ( isinstance(m, RetryPromptPart) - and m.tool_name == 'final_result_run_sql_query' + and m.tool_name == 'run_sql_query' and m.content == "Only 'SELECT *' is supported, you'll have to do column filtering manually." ): return ModelResponse( parts=[ ToolCallPart( - tool_name='final_result_run_sql_query', + tool_name='run_sql_query', args={'query': 'SELECT * FROM capitals;'}, tool_call_id='pyd_ai_tool_call_id', ) @@ -628,14 +628,14 @@ async def model_logic( # noqa: C901 ) elif ( isinstance(m, RetryPromptPart) - and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.tool_name == 'hand_off_to_sql_agent' and m.content == "SQL agent failed: Unknown table 'capitals' in query 'SELECT * FROM capitals;'. Available tables: capital_cities." ): return ModelResponse( parts=[ ToolCallPart( - tool_name='final_result_hand_off_to_sql_agent', + tool_name='hand_off_to_sql_agent', args={'query': 'SELECT * FROM capital_cities;'}, tool_call_id='pyd_ai_tool_call_id', ) @@ -643,7 +643,7 @@ async def model_logic( # noqa: C901 ) elif ( isinstance(m, RetryPromptPart) - and m.tool_name == 'final_result_run_sql_query' + and m.tool_name == 'run_sql_query' and m.content == "Unknown table 'pets' in query 'SELECT * FROM pets;'. Available tables: capital_cities." ): return ModelResponse( @@ -660,7 +660,7 @@ async def model_logic( # noqa: C901 # SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available. elif ( isinstance(m, RetryPromptPart) - and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.tool_name == 'hand_off_to_sql_agent' and m.content == "SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available." ):