Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,24 +837,6 @@ impl Agent {

let num_tool_requests = frontend_requests.len() + remaining_requests.len();
if num_tool_requests == 0 {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
messages.push(message.clone());
yield AgentEvent::Message(message);
continue;
} else {
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
yield AgentEvent::Message(message);
// Set added_message to true and continue to end the current iteration
added_message = true;
push_message(&mut messages, response);
continue;
}
}
// If there's no final output tool and no tool requests, continue the loop
continue;
}

Expand Down Expand Up @@ -1039,10 +1021,14 @@ impl Agent {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
messages.push(message.clone());
yield AgentEvent::Message(message);
continue;
} else {
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
yield AgentEvent::Message(message);
}
}
break;
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/final_output_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde_json::Value;

pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output";
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str =
"I see I MUST call the `final_output` tool NOW with the final output for the user.";
"You MUST call the `final_output` tool NOW with the final output for the user.";

pub struct FinalOutputTool {
pub response: Response,
Expand Down
124 changes: 123 additions & 1 deletion crates/goose/tests/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,11 @@ mod schedule_tool_tests {
#[cfg(test)]
mod final_output_tool_tests {
use super::*;
use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME;
use futures::stream;
use goose::agents::final_output_tool::{
FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME,
};
use goose::providers::base::MessageStream;
use goose::recipe::Response;

#[tokio::test]
Expand Down Expand Up @@ -637,6 +641,124 @@ mod final_output_tool_tests {

Ok(())
}

#[tokio::test]
async fn test_when_final_output_not_called_in_reply() -> Result<()> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

use async_trait::async_trait;
use goose::model::ModelConfig;
use goose::providers::base::{Provider, ProviderUsage};
use goose::providers::errors::ProviderError;
use mcp_core::tool::Tool;

#[derive(Clone)]
struct MockProvider {
model_config: ModelConfig,
}

#[async_trait]
impl Provider for MockProvider {
fn metadata() -> goose::providers::base::ProviderMetadata {
goose::providers::base::ProviderMetadata::empty()
}

fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}

fn supports_streaming(&self) -> bool {
true
}

async fn stream(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let deltas = vec![
Ok((Some(Message::assistant().with_text("Hello")), None)),
Ok((Some(Message::assistant().with_text("Hi!")), None)),
Ok((
Some(Message::assistant().with_text("What is the final output?")),
None,
)),
];

let stream = stream::iter(deltas.into_iter());
Ok(Box::pin(stream))
}

async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Err(ProviderError::NotImplemented("Not implemented".to_string()))
}
}

let agent = Agent::new();

let model_config = ModelConfig::new("test-model".to_string());
let mock_provider = Arc::new(MockProvider { model_config });
agent.update_provider(mock_provider).await?;

let response = Response {
json_schema: Some(serde_json::json!({
"type": "object",
"properties": {
"result": {"type": "string"}
},
"required": ["result"]
})),
};
agent.add_final_output_tool(response).await;

// Simulate the reply stream being called.
let reply_stream = agent.reply(&vec![], None).await?;
tokio::pin!(reply_stream);

let mut responses = Vec::new();
let mut count = 0;
while let Some(response_result) = reply_stream.next().await {
match response_result {
Ok(AgentEvent::Message(response)) => {
responses.push(response);
count += 1;
if count >= 4 {
// Limit to 4 messages to avoid infinite loop due to mock provider
break;
}
}
Ok(_) => {}
Err(e) => return Err(e),
}
}

assert!(!responses.is_empty(), "Should have received responses");
println!("Responses: {:?}", responses);
let last_message = responses.last().unwrap();

// Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE
for (i, response) in responses.iter().take(3).enumerate() {
let message_text = response.as_concat_text();
assert_ne!(
message_text,
FINAL_OUTPUT_CONTINUATION_MESSAGE,
"Message {} should not be the continuation message, got: '{}'",
i + 1,
message_text
);
}

// Check that the last message after the llm stream is the message directing the agent to continue
assert_eq!(last_message.role, mcp_core::role::Role::User);
let message_text = last_message.as_concat_text();
assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE);

Ok(())
}
}

#[cfg(test)]
Expand Down
Loading