Compaction: take `&dyn Provider` instead of `&Agent`, retry on context overflow

What this patch does:
 * agents/agent.rs: imports `check_if_compaction_needed` / `compact_messages`
   directly and passes `self.provider().await?.as_ref()` instead of `self`.
   The boolean argument changes meaning (see next item): the first call site
   now passes `is_manual_compact`, the second passes `false`.
 * context_mgmt/mod.rs:
     - `compact_messages(provider, conversation, manual_compact)` replaces the
       `preserve_last_user_message` flag. For non-manual compactions the most
       recent agent-visible, text-only user message is re-appended after the
       summary; for manual compactions nothing is re-appended.
     - The continuation prompt is chosen from three constants (manual /
       conversation / tool loop) and merged with the summary message through
       `merge_consecutive_messages`.
     - `do_compact` retries when the provider returns
       `ProviderError::ContextLengthExceeded`, dropping 0, 10, 20, 50 and then
       100 percent of the tool-response messages via `filter_tool_responses`,
       which removes from the middle of the conversation outwards.
     - Adds a `MockProvider` and two tests.
 * conversation/mod.rs: `merge_consecutive_messages` becomes `pub`.

Review notes:
 * `filter_tool_responses` rounds the midpoint up (`(len + 1) / 2`). With
   `len / 2` an odd number of tool responses could never be removed completely
   (the last index was unreachable, and a single tool response was never
   removed at all), so the final "100 %" attempt could still send tool output.
   The `index` hashes below describe the post-image without this one-line change.
 * NOTE(review): `MockProvider::complete_with_model` counts tool responses in
   `messages`, but `do_compact` passes only the one-message summarization
   request there and renders the conversation into the system prompt. The
   limit in `test_progressive_removal_on_context_exceeded` therefore looks like
   it can never trip - confirm how `complete_fast` forwards its arguments.
 * Leading whitespace of context lines was reconstructed in rustfmt style; if a
   hunk is rejected on context, retry with `git apply --ignore-whitespace`.

diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs
index c9f84e9571ba..59c55422ad61 100644
--- a/crates/goose/src/agents/agent.rs
+++ b/crates/goose/src/agents/agent.rs
@@ -30,7 +30,9 @@ use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
 use crate::agents::types::SessionConfig;
 use crate::agents::types::{FrontendTool, SharedProvider, ToolResultReceiver};
 use crate::config::{get_enabled_extensions, Config, GooseMode};
-use crate::context_mgmt::DEFAULT_COMPACTION_THRESHOLD;
+use crate::context_mgmt::{
+    check_if_compaction_needed, compact_messages, DEFAULT_COMPACTION_THRESHOLD,
+};
 use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation};
 use crate::mcp_utils::ToolResult;
 use crate::permission::permission_inspector::PermissionInspector;
@@ -794,8 +796,13 @@ impl Agent {
             .ok_or_else(|| anyhow::anyhow!("Session {} has no conversation", session_config.id))?;
 
         let needs_auto_compact = !is_manual_compact
-            && crate::context_mgmt::check_if_compaction_needed(self, &conversation, None, &session)
-                .await?;
+            && check_if_compaction_needed(
+                self.provider().await?.as_ref(),
+                &conversation,
+                None,
+                &session,
+            )
+            .await?;
 
         let conversation_to_compact = conversation.clone();
 
@@ -830,7 +837,7 @@ impl Agent {
                 )
             );
 
-            match crate::context_mgmt::compact_messages(self, &conversation_to_compact, false).await {
+            match compact_messages(self.provider().await?.as_ref(), &conversation_to_compact, is_manual_compact).await {
                 Ok((compacted_conversation, summarization_usage)) => {
                     SessionManager::replace_conversation(&session_config.id, &compacted_conversation).await?;
                     Self::update_session_metrics(&session_config, &summarization_usage, true).await?;
@@ -1151,7 +1158,7 @@ impl Agent {
                                 )
                             );
 
-                            match crate::context_mgmt::compact_messages(self, &conversation, true).await {
+                            match compact_messages(self.provider().await?.as_ref(), &conversation, false).await {
                                 Ok((compacted_conversation, usage)) => {
                                     SessionManager::replace_conversation(&session_config.id, &compacted_conversation).await?;
                                     Self::update_session_metrics(&session_config, &usage, true).await?;
diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs
index e0ac4368aa81..1b5deb3ed369 100644
--- a/crates/goose/src/context_mgmt/mod.rs
+++ b/crates/goose/src/context_mgmt/mod.rs
@@ -1,17 +1,32 @@
 use crate::conversation::message::MessageMetadata;
 use crate::conversation::message::{Message, MessageContent};
-use crate::conversation::Conversation;
+use crate::conversation::{merge_consecutive_messages, Conversation};
 use crate::prompt_template::render_global_file;
 use crate::providers::base::{Provider, ProviderUsage};
-use crate::{agents::Agent, config::Config, token_counter::create_token_counter};
+use crate::providers::errors::ProviderError;
+use crate::{config::Config, token_counter::create_token_counter};
 use anyhow::Result;
 use rmcp::model::Role;
 use serde::Serialize;
-use std::sync::Arc;
 use tracing::{debug, info};
 
 pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.8;
 
+const CONVERSATION_CONTINUATION_TEXT: &str =
+    "The previous message contains a summary that was prepared because a context limit was reached.
+Do not mention that you read a summary or that conversation summarization occurred.
+Just continue the conversation naturally based on the summarized context";
+
+const TOOL_LOOP_CONTINUATION_TEXT: &str =
+    "The previous message contains a summary that was prepared because a context limit was reached.
+Do not mention that you read a summary or that conversation summarization occurred.
+Continue calling tools as necessary to complete the task.";
+
+const MANUAL_COMPACT_CONTINUATION_TEXT: &str =
+    "The previous message contains a summary that was prepared at the user's request.
+Do not mention that you read a summary or that conversation summarization occurred.
+Just continue the conversation naturally based on the summarized context";
+
 #[derive(Serialize)]
 struct SummarizeContext {
     messages: String,
@@ -24,18 +39,18 @@ struct SummarizeContext {
 /// first to determine if compaction is necessary.
 ///
 /// # Arguments
-/// * `agent` - The agent to use for context management
+/// * `provider` - The provider to use for summarization
 /// * `conversation` - The current conversation history
-/// * `preserve_last_user_message` - If true and last message is not a user message, copy the most recent user message to the end
+/// * `manual_compact` - If true, this is a manual compaction (don't preserve user message)
 ///
 /// # Returns
 /// * A tuple containing:
 ///   - `Conversation`: The compacted messages
 ///   - `ProviderUsage`: Provider usage from summarization
 pub async fn compact_messages(
-    agent: &Agent,
+    provider: &dyn Provider,
     conversation: &Conversation,
-    preserve_last_user_message: bool,
+    manual_compact: bool,
 ) -> Result<(Conversation, ProviderUsage)> {
     info!("Performing message compaction");
 
@@ -55,7 +70,6 @@ pub async fn compact_messages(
         has_text && !has_tool_content
     };
 
-    // Helper function to extract text content from a message
     let extract_text = |msg: &Message| -> Option<String> {
         let text_parts: Vec<String> = msg
             .content
@@ -76,62 +90,72 @@ pub async fn compact_messages(
         }
     };
 
-    // Check if the most recent message is a user message with text content only
-    let (messages_to_compact, preserved_user_text) = if let Some(last_message) = messages.last() {
-        if matches!(last_message.role, rmcp::model::Role::User) && has_text_only(last_message) {
-            // Remove the last user message before compaction and preserve its text
-            (&messages[..messages.len() - 1], extract_text(last_message))
-        } else if preserve_last_user_message {
-            // Last message is not a user message with text only, but we want to preserve the most recent user message with text only
-            // Find the most recent user message with text content only and extract its text
-            let preserved_text = messages
-                .iter()
-                .rev()
-                .find(|msg| matches!(msg.role, rmcp::model::Role::User) && has_text_only(msg))
-                .and_then(extract_text);
-            (messages.as_slice(), preserved_text)
+    // Find and preserve the most recent user message for non-manual compacts
+    let (preserved_user_message, is_most_recent) = if !manual_compact {
+        let found_msg = messages.iter().enumerate().rev().find(|(_, msg)| {
+            msg.is_agent_visible()
+                && matches!(msg.role, rmcp::model::Role::User)
+                && has_text_only(msg)
+        });
+
+        if let Some((idx, msg)) = found_msg {
+            let is_last = idx == messages.len() - 1;
+            (Some(msg.clone()), is_last)
         } else {
-            (messages.as_slice(), None)
+            (None, false)
         }
     } else {
-        (messages.as_slice(), None)
+        (None, false)
     };
 
-    let provider = agent.provider().await?;
-    let (summary_message, summarization_usage) =
-        do_compact(provider.clone(), messages_to_compact).await?;
+    let messages_to_compact = messages.as_slice();
+
+    let (summary_message, summarization_usage) = do_compact(provider, messages_to_compact).await?;
 
     // Create the final message list with updated visibility metadata:
     // 1. Original messages become user_visible but not agent_visible
     // 2. Summary message becomes agent_visible but not user_visible
-    // 3. Assistant messages to continue the conversation remain both user_visible and agent_visible
-
+    // 3. Assistant messages to continue the conversation are also agent_visible but not user_visible
     let mut final_messages = Vec::new();
 
-    // Add all original messages with updated visibility (preserve user_visible, set agent_visible=false)
-    for msg in messages_to_compact.iter().cloned() {
-        let updated_metadata = msg.metadata.with_agent_invisible();
-        let updated_msg = msg.with_metadata(updated_metadata);
+    for (idx, msg) in messages_to_compact.iter().enumerate() {
+        let updated_metadata = if is_most_recent
+            && idx == messages_to_compact.len() - 1
+            && preserved_user_message.is_some()
+        {
+            // This is the most recent message and we're preserving it by adding a fresh copy
+            MessageMetadata::invisible()
+        } else {
+            msg.metadata.with_agent_invisible()
+        };
+        let updated_msg = msg.clone().with_metadata(updated_metadata);
         final_messages.push(updated_msg);
     }
 
-    // Add the summary message (agent_visible=true, user_visible=false)
     let summary_msg = summary_message.with_metadata(MessageMetadata::agent_only());
-    final_messages.push(summary_msg);
-
-    // Add an assistant message to continue the conversation (agent_visible=true, user_visible=false)
-    let assistant_message = Message::assistant()
-        .with_text(
-            "The previous message contains a summary that was prepared because a context limit was reached.
-Do not mention that you read a summary or that conversation summarization occurred
-Just continue the conversation naturally based on the summarized context"
-        )
+
+    let mut continuation_messages = vec![summary_msg];
+
+    let continuation_text = if manual_compact {
+        MANUAL_COMPACT_CONTINUATION_TEXT
+    } else if is_most_recent {
+        CONVERSATION_CONTINUATION_TEXT
+    } else {
+        TOOL_LOOP_CONTINUATION_TEXT
+    };
+
+    let continuation_msg = Message::assistant()
+        .with_text(continuation_text)
         .with_metadata(MessageMetadata::agent_only());
-    final_messages.push(assistant_message);
+    continuation_messages.push(continuation_msg);
+
+    let (merged_continuation, _issues) = merge_consecutive_messages(continuation_messages);
+    final_messages.extend(merged_continuation);
 
-    // Add back the preserved user message if it exists
-    if let Some(user_text) = preserved_user_text {
-        final_messages.push(Message::user().with_text(&user_text));
+    if let Some(user_msg) = preserved_user_message {
+        if let Some(text) = extract_text(&user_msg) {
+            final_messages.push(Message::user().with_text(&text));
+        }
     }
 
     Ok((
@@ -142,7 +166,7 @@ Just continue the conversation naturally based on the summarized context"
 
 /// Check if messages exceed the auto-compaction threshold
 pub async fn check_if_compaction_needed(
-    agent: &Agent,
+    provider: &dyn Provider,
     conversation: &Conversation,
     threshold_override: Option<f64>,
     session: &crate::session::Session,
@@ -155,7 +179,6 @@ pub async fn check_if_compaction_needed(
             .unwrap_or(DEFAULT_COMPACTION_THRESHOLD)
     });
 
-    let provider = agent.provider().await?;
     let context_limit = provider.get_model_config().context_limit();
 
     let (current_tokens, token_source) = match session.total_tokens {
@@ -196,8 +219,58 @@ pub async fn check_if_compaction_needed(
     Ok(needs_compaction)
 }
 
+fn filter_tool_responses<'a>(messages: &[&'a Message], remove_percent: u32) -> Vec<&'a Message> {
+    fn has_tool_response(msg: &Message) -> bool {
+        msg.content
+            .iter()
+            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
+    }
+
+    if remove_percent == 0 {
+        return messages.to_vec();
+    }
+
+    let tool_indices: Vec<usize> = messages
+        .iter()
+        .enumerate()
+        .filter(|(_, msg)| has_tool_response(msg))
+        .map(|(i, _)| i)
+        .collect();
+
+    if tool_indices.is_empty() {
+        return messages.to_vec();
+    }
+
+    let num_to_remove = ((tool_indices.len() * remove_percent as usize) / 100).max(1);
+
+    let middle = (tool_indices.len() + 1) / 2; // round up so every index stays reachable
+    let mut indices_to_remove = Vec::new();
+
+    // Middle out: alternate between the left and right of the midpoint
+    for i in 0..num_to_remove {
+        if i % 2 == 0 {
+            let offset = i / 2;
+            if middle > offset {
+                indices_to_remove.push(tool_indices[middle - offset - 1]);
+            }
+        } else {
+            let offset = i / 2;
+            if middle + offset < tool_indices.len() {
+                indices_to_remove.push(tool_indices[middle + offset]);
+            }
+        }
+    }
+
+    messages
+        .iter()
+        .enumerate()
+        .filter(|(i, _)| !indices_to_remove.contains(i))
+        .map(|(_, msg)| *msg)
+        .collect()
+}
+
 async fn do_compact(
-    provider: Arc<dyn Provider>,
+    provider: &dyn Provider,
     messages: &[Message],
 ) -> Result<(Message, ProviderUsage), anyhow::Error> {
     let agent_visible_messages: Vec<&Message> = messages
@@ -205,34 +278,61 @@ async fn do_compact(
         .filter(|msg| msg.is_agent_visible())
         .collect();
 
-    let messages_text = agent_visible_messages
-        .iter()
-        .map(|&msg| format_message_for_compacting(msg))
-        .collect::<Vec<_>>()
-        .join("\n");
+    // Try progressively removing more tool response messages from the middle to reduce context length
+    let removal_percentages = [0, 10, 20, 50, 100];
 
-    let context = SummarizeContext {
-        messages: messages_text,
-    };
+    for (attempt, &remove_percent) in removal_percentages.iter().enumerate() {
+        let filtered_messages = filter_tool_responses(&agent_visible_messages, remove_percent);
+
+        let messages_text = filtered_messages
+            .iter()
+            .map(|&msg| format_message_for_compacting(msg))
+            .collect::<Vec<_>>()
+            .join("\n");
+
+        let context = SummarizeContext {
+            messages: messages_text,
+        };
 
-    let system_prompt = render_global_file("summarize_oneshot.md", &context)?;
+        let system_prompt = render_global_file("summarize_oneshot.md", &context)?;
 
-    let user_message = Message::user()
-        .with_text("Please summarize the conversation history provided in the system prompt.");
-    let summarization_request = vec![user_message];
+        let user_message = Message::user()
+            .with_text("Please summarize the conversation history provided in the system prompt.");
+        let summarization_request = vec![user_message];
 
-    let (mut response, mut provider_usage) = provider
-        .complete_fast(&system_prompt, &summarization_request, &[])
-        .await?;
+        match provider
+            .complete_fast(&system_prompt, &summarization_request, &[])
+            .await
+        {
+            Ok((mut response, mut provider_usage)) => {
+                response.role = Role::User;
 
-    response.role = Role::User;
+                provider_usage
+                    .ensure_tokens(&system_prompt, &summarization_request, &response, &[])
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;
 
-    provider_usage
-        .ensure_tokens(&system_prompt, &summarization_request, &response, &[])
-        .await
-        .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;
+                return Ok((response, provider_usage));
+            }
+            Err(e) => {
+                if matches!(e, ProviderError::ContextLengthExceeded(_)) {
+                    if attempt < removal_percentages.len() - 1 {
+                        continue;
+                    } else {
+                        return Err(anyhow::anyhow!(
+                            "Failed to compact messages: context length still exceeded after {} attempts with maximum removal",
+                            removal_percentages.len()
+                        ));
+                    }
+                }
+                return Err(e.into());
+            }
+        }
+    }
 
-    Ok((response, provider_usage))
+    Err(anyhow::anyhow!(
+        "Unexpected: exhausted all attempts without returning"
+    ))
 }
 
 fn format_message_for_compacting(msg: &Message) -> String {
@@ -301,3 +401,158 @@ fn format_message_for_compacting(msg: &Message) -> String {
         format!("[{}]: {}", role_str, content_parts.join("\n"))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        model::ModelConfig,
+        providers::{
+            base::{ProviderMetadata, Usage},
+            errors::ProviderError,
+        },
+    };
+    use async_trait::async_trait;
+    use rmcp::model::{AnnotateAble, CallToolRequestParam, RawContent, Tool};
+
+    struct MockProvider {
+        message: Message,
+        config: ModelConfig,
+        max_tool_responses: Option<usize>,
+    }
+
+    impl MockProvider {
+        fn new(message: Message, context_limit: usize) -> Self {
+            Self {
+                message,
+                config: ModelConfig {
+                    model_name: "test".to_string(),
+                    context_limit: Some(context_limit),
+                    temperature: None,
+                    max_tokens: None,
+                    toolshim: false,
+                    toolshim_model: None,
+                    fast_model: None,
+                },
+                max_tool_responses: None,
+            }
+        }
+
+        fn with_max_tool_responses(mut self, max: usize) -> Self {
+            self.max_tool_responses = Some(max);
+            self
+        }
+    }
+
+    #[async_trait]
+    impl Provider for MockProvider {
+        fn metadata() -> ProviderMetadata {
+            ProviderMetadata::new("mock", "", "", "", vec![""], "", vec![])
+        }
+
+        fn get_name(&self) -> &str {
+            "mock"
+        }
+
+        async fn complete_with_model(
+            &self,
+            _model_config: &ModelConfig,
+            _system: &str,
+            messages: &[Message],
+            _tools: &[Tool],
+        ) -> Result<(Message, ProviderUsage), ProviderError> {
+            // If max_tool_responses is set, fail if we have too many
+            if let Some(max) = self.max_tool_responses {
+                let tool_response_count = messages
+                    .iter()
+                    .filter(|m| {
+                        m.content
+                            .iter()
+                            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
+                    })
+                    .count();
+
+                if tool_response_count > max {
+                    return Err(ProviderError::ContextLengthExceeded(format!(
+                        "Too many tool responses: {} > {}",
+                        tool_response_count, max
+                    )));
+                }
+            }
+
+            Ok((
+                self.message.clone(),
+                ProviderUsage::new("mock-model".to_string(), Usage::default()),
+            ))
+        }
+
+        fn get_model_config(&self) -> ModelConfig {
+            self.config.clone()
+        }
+    }
+
+    #[tokio::test]
+    async fn test_keeps_tool_request() {
+        let response_message = Message::assistant().with_text("");
+        let provider = MockProvider::new(response_message, 1);
+        let basic_conversation = vec![
+            Message::user().with_text("read hello.txt"),
+            Message::assistant().with_tool_request(
+                "tool_0",
+                Ok(CallToolRequestParam {
+                    name: "read_file".into(),
+                    arguments: None,
+                }),
+            ),
+            Message::user().with_tool_response(
+                "tool_0",
+                Ok(vec![RawContent::text("hello, world").no_annotation()]),
+            ),
+        ];
+
+        let conversation = Conversation::new_unvalidated(basic_conversation);
+        let (compacted_conversation, _usage) = compact_messages(&provider, &conversation, false)
+            .await
+            .unwrap();
+
+        let agent_conversation = compacted_conversation.agent_visible_messages();
+
+        let _ = Conversation::new(agent_conversation)
+            .expect("compaction should produce a valid conversation");
+    }
+
+    #[tokio::test]
+    async fn test_progressive_removal_on_context_exceeded() {
+        let response_message = Message::assistant().with_text("");
+        // Set max to 2 tool responses - will trigger progressive removal
+        let provider = MockProvider::new(response_message, 1000).with_max_tool_responses(2);
+
+        // Create a conversation with many tool responses
+        let mut messages = vec![Message::user().with_text("start")];
+        for i in 0..10 {
+            messages.push(Message::assistant().with_tool_request(
+                format!("tool_{}", i),
+                Ok(CallToolRequestParam {
+                    name: "read_file".into(),
+                    arguments: None,
+                }),
+            ));
+            messages.push(Message::user().with_tool_response(
+                format!("tool_{}", i),
+                Ok(vec![
+                    RawContent::text(format!("response{}", i)).no_annotation(),
+                ]),
+            ));
+        }
+
+        let conversation = Conversation::new_unvalidated(messages);
+        let result = compact_messages(&provider, &conversation, false).await;
+
+        // Should succeed after progressive removal
+        assert!(
+            result.is_ok(),
+            "Should succeed with progressive removal: {:?}",
+            result.err()
+        );
+    }
+}
diff --git a/crates/goose/src/conversation/mod.rs b/crates/goose/src/conversation/mod.rs
index b0ae237faa9c..a083dbe86722 100644
--- a/crates/goose/src/conversation/mod.rs
+++ b/crates/goose/src/conversation/mod.rs
@@ -371,7 +371,7 @@ fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
     (messages, issues)
 }
 
-fn merge_consecutive_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
+pub fn merge_consecutive_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
     let mut issues = Vec::new();
     let mut merged_messages: Vec<Message> = Vec::new();
 