Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,7 @@ impl Agent {
).await?;

let mut no_tools_called = true;
let mut messages_to_add = Conversation::default();
let mut tools_updated = false;
let mut did_recovery_compact_this_iteration = false;
let mut exit_chat = false;
Expand Down Expand Up @@ -1241,8 +1242,7 @@ impl Agent {
if !text.is_empty() {
last_assistant_text = text;
}
session_manager.add_message(&session_config.id, &response).await?;
conversation.push(response);
messages_to_add.push(response);
continue;
}

Expand Down Expand Up @@ -1438,8 +1438,7 @@ impl Agent {
response.created,
thinking_content,
).with_id(format!("msg_{}", Uuid::new_v4()));
session_manager.add_message(&session_config.id, &thinking_msg).await?;
conversation.push(thinking_msg);
messages_to_add.push(thinking_msg);
}

// Collect reasoning content to attach to tool request messages
Expand Down Expand Up @@ -1467,14 +1466,11 @@ impl Agent {
request.metadata.as_ref(),
request.tool_meta.clone(),
);
messages_to_add.push(request_msg);
let final_response = tool_response_messages[idx]
.lock().await.clone();
// Persist the tool request and response as a pair
session_manager.add_message(&session_config.id, &request_msg).await?;
session_manager.add_message(&session_config.id, &final_response).await?;
conversation.push(request_msg);
conversation.push(final_response.clone());
yield AgentEvent::Message(final_response);
yield AgentEvent::Message(final_response.clone());
messages_to_add.push(final_response);
} else {
error!(
"Tool call could not be parsed: {}",
Expand Down Expand Up @@ -1618,14 +1614,12 @@ impl Agent {
Some(None) => {
warn!("Final output tool has not been called yet. Continuing agent loop.");
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
session_manager.add_message(&session_config.id, &message).await?;
conversation.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
}
Some(Some(output)) => {
let message = Message::assistant().with_text(output);
session_manager.add_message(&session_config.id, &message).await?;
conversation.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
exit_chat = true;
}
Expand All @@ -1637,6 +1631,7 @@ impl Agent {
Ok(should_retry) => {
if should_retry {
info!("Retry logic triggered, restarting agent loop");
messages_to_add = Conversation::default();
session_manager.replace_conversation(&session_config.id, &conversation).await?;
yield AgentEvent::HistoryReplaced(conversation.clone());
} else {
Expand Down Expand Up @@ -1680,14 +1675,17 @@ impl Agent {
}).await?;
}
conversation = Conversation::new_unvalidated(updated_messages);
session_manager.add_message(&session_config.id, &summary_msg).await?;
conversation.push(summary_msg);
messages_to_add.push(summary_msg);
} else {
warn!("Expected a tool request/reply pair, but found {} matching messages",
matching.len());
}
}

for msg in &messages_to_add {
session_manager.add_message(&session_config.id, msg).await?;
}
conversation.extend(messages_to_add);
Comment on lines +1685 to +1688
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Drop buffered messages when retrying from initial conversation

When handle_retry_logic returns should_retry = true, it resets conversation back to initial_messages, but this block still persists and re-appends messages_to_add from the failed attempt. In retry-enabled runs where the model produced non-tool assistant output before failing success checks, those failed outputs are written to session history anyway and included in the next turn, so retries no longer start from a clean state and can drift or bloat context across attempts.

Useful? React with 👍 / 👎.

if exit_chat {
break;
}
Expand Down
270 changes: 270 additions & 0 deletions crates/goose/tests/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,4 +591,274 @@ mod tests {
);
}
}

/// Tests asserting that streaming text deltas are buffered and persisted as a
/// single session message per model response, never one entry per token —
/// including when the stream is cancelled part-way through.
#[cfg(test)]
mod streaming_persistence_tests {
    use super::*;
    use async_trait::async_trait;
    use goose::agents::{AgentConfig, SessionConfig};
    use goose::config::permission::PermissionManager;
    use goose::config::GooseMode;
    use goose::conversation::message::Message;
    use goose::model::ModelConfig;
    use goose::providers::base::{
        MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
    };
    use goose::providers::errors::ProviderError;
    use goose::session::session_manager::SessionType;
    use goose::session::SessionManager;
    use rmcp::model::{CallToolRequestParams, Role, Tool};
    use rmcp::object;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio_util::sync::CancellationToken;

    /// Mock provider that scripts three responses keyed off an internal call
    /// counter: (0) a tool call, (1) a five-token text stream, (2) a text
    /// stream that fires `cancel_token` mid-stream to simulate interruption.
    struct MultiStepProvider {
        /// Number of `stream` calls seen so far; selects the scripted response.
        call_count: AtomicUsize,
        /// Token the third scripted response cancels while streaming.
        cancel_token: CancellationToken,
    }

    impl MultiStepProvider {
        /// Creates the provider with a zeroed call counter and the token it
        /// will cancel during its third response.
        fn new(cancel_token: CancellationToken) -> Self {
            Self {
                call_count: AtomicUsize::new(0),
                cancel_token,
            }
        }
    }

    impl ProviderDef for MultiStepProvider {
        type Provider = Self;

        /// Static metadata for the mock; values are placeholders and never
        /// shown to a user in these tests.
        fn metadata() -> ProviderMetadata {
            ProviderMetadata {
                name: "multi-step-mock".to_string(),
                display_name: "Multi-Step Mock".to_string(),
                description: "Mock provider for streaming persistence tests".to_string(),
                default_model: "mock-model".to_string(),
                known_models: vec![],
                model_doc_link: "".to_string(),
                config_keys: vec![],
                setup_steps: vec![],
            }
        }

        /// Never constructed from the environment in tests; the test builds
        /// the provider directly via `MultiStepProvider::new`.
        fn from_env(
            _model: ModelConfig,
            _extensions: Vec<goose::config::ExtensionConfig>,
        ) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> {
            unimplemented!()
        }
    }

    #[async_trait]
    impl Provider for MultiStepProvider {
        /// Returns the next scripted stream; every prompt/tool/message input
        /// is ignored — only the call counter determines the response.
        async fn stream(
            &self,
            _model_config: &ModelConfig,
            _session_id: &str,
            _system_prompt: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<MessageStream, ProviderError> {
            // fetch_add returns the pre-increment value: 0 on the first call.
            let call = self.call_count.fetch_add(1, Ordering::SeqCst);
            let usage = ProviderUsage::new(
                "mock-model".to_string(),
                Usage::new(Some(10), Some(5), Some(15)),
            );

            match call {
                // Call 0: a single tool-request message, so the agent runs a
                // tool and loops for another provider call.
                0 => {
                    let tool_call = CallToolRequestParams::new("test_tool")
                        .with_arguments(object!({"param": "value"}));
                    let message =
                        Message::assistant().with_tool_request("call_1", Ok(tool_call));
                    let stream =
                        futures::stream::once(async move { Ok((Some(message), Some(usage))) });
                    Ok(Box::pin(stream))
                }
                // Call 1: five text deltas sharing one message id (the agent
                // should coalesce them); usage arrives on the final delta.
                1 => {
                    let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
                    let tokens = vec!["Hello", " world", ", how", " are", " you?"];
                    let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
                        move |(i, token)| {
                            let msg = Message::assistant()
                                .with_text(token)
                                .with_id(msg_id.clone());
                            let u = if i == 4 { Some(usage.clone()) } else { None };
                            Ok((Some(msg), u))
                        },
                    ));
                    Ok(Box::pin(stream))
                }
                // Call 2+: same shape as call 1, but fires the cancellation
                // token while emitting the second delta so the agent is
                // interrupted mid-stream.
                _ => {
                    let cancel = self.cancel_token.clone();
                    let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
                    let tokens = vec!["This ", "should ", "be ", "cancelled ", "soon."];
                    let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
                        move |(i, token)| {
                            if i == 1 {
                                cancel.cancel();
                            }
                            let msg = Message::assistant()
                                .with_text(token)
                                .with_id(msg_id.clone());
                            let u = if i == 4 { Some(usage.clone()) } else { None };
                            Ok((Some(msg), u))
                        },
                    ));
                    Ok(Box::pin(stream))
                }
            }
        }

        fn get_model_config(&self) -> ModelConfig {
            ModelConfig::new("mock-model").unwrap()
        }

        fn get_name(&self) -> &str {
            "multi-step-mock"
        }
    }

    /// End-to-end check that the agent persists one message per response:
    /// reply 1 exercises a tool call followed by a streamed text answer;
    /// reply 2 exercises cancellation mid-stream. After each reply the
    /// session is reloaded and message counts are asserted.
    #[tokio::test]
    async fn test_streaming_text_not_persisted_per_token() -> Result<()> {
        let cancel_token = CancellationToken::new();
        let temp_dir = tempfile::tempdir()?;
        let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
        let config = AgentConfig::new(
            session_manager.clone(),
            PermissionManager::instance(),
            None,
            GooseMode::Auto,
            true, // disable session naming so it doesn't consume a provider call
            GoosePlatform::GooseCli,
        );
        let agent = Agent::with_config(config);
        let provider = Arc::new(MultiStepProvider::new(cancel_token.clone()));

        let session = session_manager
            .create_session(
                PathBuf::default(),
                "streaming-test".to_string(),
                SessionType::Hidden,
                GooseMode::default(),
            )
            .await?;

        let session_id = session.id.clone();
        agent.update_provider(provider, &session_id).await?;

        // ── Provider script: tool call (call 0) → text stream (call 1) → cancelled text (call 2)
        // call 0: tool call → agent executes tool, loops
        // call 1: 5 text deltas → no tools called, agent exits loop
        // call 2: 5 text deltas, cancel token fired on the 2nd → agent interrupted
        //
        // max_turns=2 covers the two provider calls (0 and 1) made by reply 1.
        // Because call 1 ends the agent loop (no_tools_called=true → exit),
        // call 2 is NOT reached in the same reply. We issue a second reply()
        // with the cancel token so the provider triggers cancellation.
        let session_config = SessionConfig {
            id: session_id.clone(),
            schedule_id: None,
            max_turns: Some(2),
            retry_config: None,
        };

        let reply_stream = agent
            .reply(
                Message::user().with_text("Do something then say hello"),
                session_config,
                None,
            )
            .await?;
        tokio::pin!(reply_stream);

        // Drain the stream; any provider/agent error fails the test.
        while let Some(event) = reply_stream.next().await {
            match event {
                Ok(AgentEvent::Message(_)) => {}
                Ok(_) => {}
                Err(e) => return Err(e),
            }
        }

        // ── Check persisted state after reply 1 ─────────────────
        let reloaded = session_manager.get_session(&session_id, true).await?;
        let messages = reloaded
            .conversation
            .expect("should have conversation")
            .messages()
            .to_vec();

        let user_count = messages.iter().filter(|m| m.role == Role::User).count();
        let asst_count = messages
            .iter()
            .filter(|m| m.role == Role::Assistant)
            .count();

        // Expected: user(prompt) + assistant(tool-req) + user(tool-resp) + assistant(text)
        assert_eq!(
            user_count, 2,
            "Expected 2 user messages (prompt + tool response), got {user_count}",
        );
        assert_eq!(
            asst_count, 2,
            "Expected 2 assistant messages (tool request + text reply), got {asst_count} \
             — streaming text deltas are being persisted as separate messages",
        );

        // ── Reply 2: text stream with provider-triggered cancellation (call 2)
        let session_config2 = SessionConfig {
            id: session_id.clone(),
            schedule_id: None,
            max_turns: Some(2),
            retry_config: None,
        };

        let reply_stream2 = agent
            .reply(
                Message::user().with_text("Tell me more"),
                session_config2,
                Some(cancel_token),
            )
            .await?;
        tokio::pin!(reply_stream2);

        while let Some(event) = reply_stream2.next().await {
            match event {
                Ok(_) => {}
                Err(e) => return Err(e),
            }
        }

        // ── Check persisted state after cancellation ────────────
        let reloaded2 = session_manager.get_session(&session_id, true).await?;
        let messages2 = reloaded2
            .conversation
            .expect("should have conversation")
            .messages()
            .to_vec();

        let user_count2 = messages2.iter().filter(|m| m.role == Role::User).count();
        let asst_count2 = messages2
            .iter()
            .filter(|m| m.role == Role::Assistant)
            .count();

        // Reply 2 added 1 user message. The cancelled stream should
        // have persisted at most 1 (partial) assistant message.
        assert_eq!(
            user_count2, 3,
            "Expected 3 user messages (2 from reply 1 + follow-up), got {user_count2}",
        );
        assert!(
            asst_count2 <= 3,
            "Expected at most 3 assistant messages (2 from reply 1 + at most 1 partial \
             from cancelled reply 2), got {asst_count2} \
             — streaming deltas are leaking into persistence",
        );

        Ok(())
    }
}
}
Loading