diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py
index 6e3aca2283..346b16116b 100644
--- a/haystack/components/agents/agent.py
+++ b/haystack/components/agents/agent.py
@@ -156,6 +156,7 @@ def __init__(
         exit_conditions: list[str] | None = None,
         state_schema: dict[str, Any] | None = None,
         max_agent_steps: int = 100,
+        final_answer_on_max_steps: bool = True,
         streaming_callback: StreamingCallbackT | None = None,
         raise_on_tool_invocation_failure: bool = False,
         tool_invoker_kwargs: dict[str, Any] | None = None,
@@ -172,6 +173,10 @@ def __init__(
         :param state_schema: The schema for the runtime state used by the tools.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
             If the agent exceeds this number of steps, it will stop and return the current state.
+        :param final_answer_on_max_steps: If True, generates a final text response when max_agent_steps
+            is reached and the last message is a tool result. This ensures the agent always returns a
+            natural language response instead of raw tool output. Adds one additional LLM call that doesn't
+            count toward max_agent_steps. Defaults to True.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
         :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
@@ -214,6 +219,7 @@ def __init__(
         self.system_prompt = system_prompt
         self.exit_conditions = exit_conditions
         self.max_agent_steps = max_agent_steps
+        self.final_answer_on_max_steps = final_answer_on_max_steps
         self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
         self.streaming_callback = streaming_callback
 
@@ -269,6 +275,7 @@ def to_dict(self) -> dict[str, Any]:
             # We serialize the original state schema, not the resolved one to reflect the original user input
             state_schema=_schema_to_dict(self._state_schema),
             max_agent_steps=self.max_agent_steps,
+            final_answer_on_max_steps=self.final_answer_on_max_steps,
             streaming_callback=serialize_callable(self.streaming_callback) if self.streaming_callback else None,
             raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
             tool_invoker_kwargs=self.tool_invoker_kwargs,
@@ -474,6 +481,90 @@ def _runtime_checks(self, break_point: AgentBreakpoint | None) -> None:
         if break_point and isinstance(break_point.break_point, ToolBreakpoint):
             _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools)
 
+    def _generate_final_answer(self, exe_context: _ExecutionContext, span) -> None:
+        """Generate a final text response when max steps is reached with a tool result as the last message."""
+        if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"):
+            return
+
+        last_msg = exe_context.state.data["messages"][-1]
+        if not last_msg.tool_call_result:
+            return
+
+        try:
+            logger.info("Generating final text response after max steps reached.")
+
+            # Add system message for context
+            final_prompt = ChatMessage.from_system(
+                "You have reached the maximum number of reasoning steps. "
+                "Based on the information gathered so far, provide a final answer "
+                "to the user's question. Tools are no longer available."
+            )
+
+            # Make final call with tools disabled
+            final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"}
+            final_result = self.chat_generator.run(
+                messages=exe_context.state.data["messages"] + [final_prompt], tools=[], **final_inputs
+            )
+
+            # Append final response
+            if final_result and "replies" in final_result:
+                for msg in final_result["replies"]:
+                    exe_context.state.data["messages"].append(msg)
+
+            span.set_tag("haystack.agent.final_answer_generated", True)
+
+        except Exception as e:
+            logger.warning(
+                "Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e)
+            )
+            span.set_tag("haystack.agent.final_answer_failed", True)
+
+    async def _generate_final_answer_async(self, exe_context: _ExecutionContext, span) -> None:
+        """
+        Async version: generate a final text response when max steps is reached with a tool result as the last message.
+        """
+        if not self.final_answer_on_max_steps or not exe_context.state.data.get("messages"):
+            return
+
+        last_msg = exe_context.state.data["messages"][-1]
+        if not last_msg.tool_call_result:
+            return
+
+        try:
+            logger.info("Generating final text response after max steps reached.")
+
+            # Add system message for context
+            final_prompt = ChatMessage.from_system(
+                "You have reached the maximum number of reasoning steps. "
+                "Based on the information gathered so far, provide a final answer "
+                "to the user's question. Tools are no longer available."
+            )
+
+            # Make final call with tools disabled using AsyncPipeline; force tools to an empty list
+            final_inputs = {k: v for k, v in exe_context.chat_generator_inputs.items() if k != "tools"}
+            final_inputs["tools"] = []
+
+            final_result = await AsyncPipeline._run_component_async(
+                component_name="chat_generator",
+                component={"instance": self.chat_generator},
+                component_inputs={"messages": exe_context.state.data["messages"] + [final_prompt], **final_inputs},
+                component_visits=exe_context.component_visits,
+                parent_span=span,
+            )
+
+            # Append final response
+            if final_result and "replies" in final_result:
+                for msg in final_result["replies"]:
+                    exe_context.state.data["messages"].append(msg)
+
+            span.set_tag("haystack.agent.final_answer_generated", True)
+
+        except Exception as e:
+            logger.warning(
+                "Failed to generate final answer: {error}. Returning with tool result as last message.", error=str(e)
+            )
+            span.set_tag("haystack.agent.final_answer_failed", True)
+
     def run(  # noqa: PLR0915
         self,
         messages: list[ChatMessage],
@@ -517,8 +608,12 @@ def run(  # noqa: PLR0915
         agent_inputs = {
             "messages": messages,
             "streaming_callback": streaming_callback,
+            "generation_kwargs": generation_kwargs,
             "break_point": break_point,
             "snapshot": snapshot,
+            "system_prompt": system_prompt,
+            "tools": tools,
+            "snapshot_callback": snapshot_callback,
             **kwargs,
         }
         self._runtime_checks(break_point=break_point)
@@ -660,6 +755,8 @@ def run(  # noqa: PLR0915
                     "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                     max_agent_steps=self.max_agent_steps,
                 )
+                self._generate_final_answer(exe_context, span)
+
             span.set_content_tag("haystack.agent.output", exe_context.state.data)
             span.set_tag("haystack.agent.steps_taken", exe_context.counter)
@@ -714,8 +811,12 @@ async def run_async(
         agent_inputs = {
             "messages": messages,
             "streaming_callback": streaming_callback,
+            "generation_kwargs": generation_kwargs,
             "break_point": break_point,
             "snapshot": snapshot,
+            "system_prompt": system_prompt,
+            "tools": tools,
+            "snapshot_callback": snapshot_callback,
             **kwargs,
         }
         self._runtime_checks(break_point=break_point)
@@ -845,6 +946,8 @@ async def run_async(
                     "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                     max_agent_steps=self.max_agent_steps,
                 )
+                await self._generate_final_answer_async(exe_context, span)
+
             span.set_content_tag("haystack.agent.output", exe_context.state.data)
             span.set_tag("haystack.agent.steps_taken", exe_context.counter)
diff --git a/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml b/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml
new file mode 100644
index 0000000000..fa3afa5edf
--- /dev/null
+++ b/releasenotes/notes/agent-final-answer-on-max-steps-a1b2c3d4e5f6g7h8.yaml
@@ -0,0 +1,8 @@
+---
+enhancements:
+  - |
+    Add the `final_answer_on_max_steps` parameter to the Agent component. When enabled (default: True),
+    the agent generates a final natural language response if it reaches max_agent_steps with
+    a tool result as the last message. This ensures the agent always returns a user-friendly text
+    response instead of raw tool output, improving user experience when step limits are reached.
+    The feature adds one additional LLM call that doesn't count toward max_agent_steps.
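A minimal usage sketch of the new flag for reviewers (the `get_weather` tool below is hypothetical and not part of this patch; any tool-calling chat generator works the same way):

    from haystack.components.agents import Agent
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool

    def get_weather(location: str) -> str:
        # Hypothetical tool used only for illustration.
        return f"Sunny, 20 degrees Celsius in {location}"

    weather_tool = Tool(
        name="weather_tool",
        description="Get the current weather for a location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=get_weather,
    )

    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=[weather_tool],
        max_agent_steps=3,
        final_answer_on_max_steps=True,  # default; adds one extra LLM call if the loop ends on a tool result
    )
    agent.warm_up()
    result = agent.run([ChatMessage.from_user("What's the weather in Berlin?")])
    print(result["last_message"].text)  # natural-language answer even when the step budget ran out
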
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py
index aa89c946e9..08cde6a1e7 100644
--- a/test/components/agents/test_agent.py
+++ b/test/components/agents/test_agent.py
@@ -751,6 +751,72 @@ def test_exceed_max_steps(self, monkeypatch, weather_tool, caplog):
         agent.run([ChatMessage.from_user("Hello")])
         assert "Agent reached maximum agent steps" in caplog.text
 
+    def test_final_answer_on_max_steps_enabled(self, monkeypatch, weather_tool):
+        """Test that a final answer is generated when max steps is reached and the last message is a tool result."""
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+        generator = OpenAIChatGenerator()
+
+        # Mock responses: the first call returns a tool call; after the tool runs, we hit max steps
+        agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=True)
+        agent.warm_up()
+
+        call_count = 0
+
+        def mock_run(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                # First call: LLM wants to call the tool
+                return {
+                    "replies": [
+                        ChatMessage.from_assistant(
+                            tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
+                        )
+                    ]
+                }
+            else:
+                # Final answer call (no tools available)
+                return {"replies": [ChatMessage.from_assistant("Based on the weather data, it's 20C in Berlin.")]}
+
+        agent.chat_generator.run = mock_run
+
+        result = agent.run([ChatMessage.from_user("What's the weather in Berlin?")])
+
+        # Last message should be a text response, not a tool result
+        assert result["last_message"].text
+        assert "Berlin" in result["last_message"].text
+
+    def test_final_answer_on_max_steps_disabled(self, monkeypatch, weather_tool):
+        """Test that no final answer is generated when final_answer_on_max_steps=False."""
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+        generator = OpenAIChatGenerator()
+
+        agent = Agent(
+            chat_generator=generator, tools=[weather_tool], max_agent_steps=1, final_answer_on_max_steps=False
+        )
+        agent.warm_up()
+
+        call_count = 0
+
+        def mock_run(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            # Always return a tool call so the run ends with a tool result
+            return {
+                "replies": [
+                    ChatMessage.from_assistant(
+                        tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
+                    )
+                ]
+            }
+
+        agent.chat_generator.run = mock_run
+
+        agent.run([ChatMessage.from_user("What's the weather?")])
+
+        # Should have ended without the final answer call (only 1 LLM call, not 2)
+        assert call_count == 1
+
     def test_exit_conditions_checked_across_all_llm_messages(self, monkeypatch, weather_tool):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
         generator = OpenAIChatGenerator()
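The async path mirrors this behavior; a sketch, assuming the same agent as in the usage example above:

    import asyncio

    async def main():
        result = await agent.run_async([ChatMessage.from_user("What's the weather in Berlin?")])
        # With final_answer_on_max_steps=True the last message is assistant text;
        # with False it can be the raw tool result when the step limit is hit.
        print(result["last_message"].text)

    asyncio.run(main())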