diff --git a/Cargo.lock b/Cargo.lock
index b4f3c4aae229..a1fb6dc850ba 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3282,7 +3282,7 @@ dependencies = [
 
 [[package]]
 name = "goose"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "ahash",
  "anyhow",
@@ -3356,7 +3356,7 @@ dependencies = [
 
 [[package]]
 name = "goose-bench"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "anyhow",
  "async-trait",
@@ -3379,7 +3379,7 @@ dependencies = [
 
 [[package]]
 name = "goose-cli"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "anstream",
  "anyhow",
@@ -3430,7 +3430,7 @@ dependencies = [
 
 [[package]]
 name = "goose-mcp"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "anyhow",
  "async-trait",
@@ -3482,7 +3482,7 @@ dependencies = [
 
 [[package]]
 name = "goose-server"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "anyhow",
  "async-trait",
@@ -3519,7 +3519,7 @@ dependencies = [
 
 [[package]]
 name = "goose-test"
-version = "1.3.0"
+version = "1.4.1"
 dependencies = [
  "clap",
  "serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index 35e756ec7f00..2689ba7cff2c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,7 @@ resolver = "2"
 
 [workspace.package]
 edition = "2021"
-version = "1.3.0"
+version = "1.4.1"
 authors = ["Block "]
 license = "Apache-2.0"
 repository = "https://github.com/block/goose"
diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs
index 91bfcb514e51..6f175d57e3ae 100644
--- a/crates/goose-cli/src/commands/web.rs
+++ b/crates/goose-cli/src/commands/web.rs
@@ -620,7 +620,7 @@ async fn process_message_streaming(
 
                 // For now, auto-summarize in web mode
                 // TODO: Implement proper UI for context handling
-                let (summarized_messages, _) =
+                let (summarized_messages, _, _) =
                     agent.summarize_context(messages.messages()).await?;
                 {
                     let mut session_msgs = session_messages.lock().await;
diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs
index 5b9b75f7f922..9ab1fb971381 100644
--- a/crates/goose-cli/src/session/mod.rs
+++ b/crates/goose-cli/src/session/mod.rs
@@ -162,7 +162,7 @@ impl Session {
         message_suffix: &str,
     ) -> Result<()> {
         // Summarize messages to fit within context length
-        let (summarized_messages, _) = agent.summarize_context(messages.messages()).await?;
+        let (summarized_messages, _, _) = agent.summarize_context(messages.messages()).await?;
         let msg = format!("Context maxed out\n{}\n{}", "-".repeat(50), message_suffix);
         output::render_text(&msg, Some(Color::Yellow), true);
         *messages = summarized_messages;
@@ -719,7 +719,7 @@ impl Session {
         let provider = self.agent.provider().await?;
 
         // Call the summarize_context method which uses the summarize_messages function
-        let (summarized_messages, _) = self
+        let (summarized_messages, _token_counts, summarization_usage) = self
            .agent
            .summarize_context(self.messages.messages())
            .await?;
@@ -727,7 +727,7 @@ impl Session {
         // Update the session messages with the summarized ones
         self.messages = summarized_messages;
 
-        // Persist the summarized messages
+        // Persist the summarized messages and update session metadata with new token counts
         if let Some(session_file) = &self.session_file {
             let working_dir = std::env::current_dir().ok();
             session::persist_messages_with_schedule_id(
@@ -738,6 +738,46 @@
                 working_dir,
             )
             .await?;
+
+            // Update session metadata with the new token counts from summarization
+            if let Some(usage) = summarization_usage {
+                let session_file_path = session::storage::get_path(
+                    session::storage::Identifier::Path(session_file.to_path_buf()),
+                )?;
+                let mut metadata =
+                    session::storage::read_metadata(&session_file_path)?;
+
+                // Update token counts with the summarization usage
+                // Use output tokens as total since that's what's actually in the context going forward
+                let summary_tokens = usage.usage.output_tokens.unwrap_or(0);
+                metadata.total_tokens = Some(summary_tokens);
+                metadata.input_tokens = None; // Clear input tokens since we now have a summary
+                metadata.output_tokens = Some(summary_tokens);
+                metadata.message_count = self.messages.len();
+
+                // Update accumulated tokens (add the summarization cost)
+                let accumulate = |a: Option<i32>, b: Option<i32>| -> Option<i32> {
+                    match (a, b) {
+                        (Some(x), Some(y)) => Some(x + y),
+                        _ => a.or(b),
+                    }
+                };
+                metadata.accumulated_total_tokens = accumulate(
+                    metadata.accumulated_total_tokens,
+                    usage.usage.total_tokens,
+                );
+                metadata.accumulated_input_tokens = accumulate(
+                    metadata.accumulated_input_tokens,
+                    usage.usage.input_tokens,
+                );
+                metadata.accumulated_output_tokens = accumulate(
+                    metadata.accumulated_output_tokens,
+                    usage.usage.output_tokens,
+                );
+
+                session::storage::update_metadata(&session_file_path, &metadata)
+                    .await?;
+            }
         }
 
         output::hide_thinking();
diff --git a/crates/goose-server/src/routes/context.rs b/crates/goose-server/src/routes/context.rs
index eeff7fa452db..e64f3fa6b635 100644
--- a/crates/goose-server/src/routes/context.rs
+++ b/crates/goose-server/src/routes/context.rs
@@ -67,7 +67,7 @@ async fn manage_context(
             .await
             .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
     } else if request.manage_action == "summarize" {
-        (processed_messages, token_counts) = agent
+        (processed_messages, token_counts, _) = agent
             .summarize_context(&request.messages)
             .await
             .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml
index dbc74d009c32..b553cbcc2836 100644
--- a/crates/goose/Cargo.toml
+++ b/crates/goose/Cargo.toml
@@ -122,11 +122,3 @@ path = "examples/agent.rs"
 
 [[example]]
 name = "databricks_oauth"
 path = "examples/databricks_oauth.rs"
-
-[[example]]
-name = "async_token_counter_demo"
-path = "examples/async_token_counter_demo.rs"
-
-[[bench]]
-name = "tokenization_benchmark"
-harness = false
diff --git a/crates/goose/benches/tokenization_benchmark.rs b/crates/goose/benches/tokenization_benchmark.rs
deleted file mode 100644
index 85be9a0ea32f..000000000000
--- a/crates/goose/benches/tokenization_benchmark.rs
+++ /dev/null
@@ -1,70 +0,0 @@
-use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use goose::token_counter::TokenCounter;
-
-fn benchmark_tokenization(c: &mut Criterion) {
-    let lengths = [1_000, 5_000, 10_000, 50_000, 100_000, 124_000, 200_000];
-
-    // Create a single token counter using the fixed o200k_base encoding
-    let counter = TokenCounter::new(); // Uses fixed o200k_base encoding
-
-    for &length in &lengths {
-        let text = "hello ".repeat(length);
-        c.bench_function(&format!("o200k_base_{}_tokens", length), |b| {
-            b.iter(|| counter.count_tokens(black_box(&text)))
-        });
-    }
-}
-
-fn benchmark_async_tokenization(c: &mut Criterion) {
-    let rt = tokio::runtime::Runtime::new().unwrap();
-    let lengths = [1_000, 5_000, 10_000, 50_000, 100_000, 124_000, 200_000];
-
-    // Create an async token counter
-    let counter = rt.block_on(async {
-        goose::token_counter::create_async_token_counter()
-            .await
-            .unwrap()
-    });
-
-    for &length in &lengths {
-        let text = "hello ".repeat(length);
-        c.bench_function(&format!("async_o200k_base_{}_tokens", length), |b| {
-            b.iter(|| counter.count_tokens(black_box(&text)))
-        });
-    }
-}
-
-fn benchmark_cache_performance(c: &mut Criterion) {
-    let rt = tokio::runtime::Runtime::new().unwrap();
-
-    // Create an async token counter for cache testing
-    let counter = rt.block_on(async {
-        goose::token_counter::create_async_token_counter()
-            .await
-            .unwrap()
-    });
-
-    let test_texts = vec![
-        "This is a test sentence for cache performance.",
-        "Another different sentence to test caching.",
-        "A third unique sentence for the benchmark.",
-        "This is a test sentence for cache performance.", // Repeat first one
-        "Another different sentence to test caching.", // Repeat second one
-    ];
-
-    c.bench_function("cache_hit_miss_pattern", |b| {
-        b.iter(|| {
-            for text in &test_texts {
-                counter.count_tokens(black_box(text));
-            }
-        })
-    });
-}
-
-criterion_group!(
-    benches,
-    benchmark_tokenization,
-    benchmark_async_tokenization,
-    benchmark_cache_performance
-);
-criterion_main!(benches);
diff --git a/crates/goose/examples/async_token_counter_demo.rs b/crates/goose/examples/async_token_counter_demo.rs
deleted file mode 100644
index 45aee116a505..000000000000
--- a/crates/goose/examples/async_token_counter_demo.rs
+++ /dev/null
@@ -1,98 +0,0 @@
-/// Demo showing the async token counter improvement
-///
-/// This example demonstrates the key improvement: no blocking runtime creation
-///
-/// BEFORE (blocking):
-/// ```rust
-/// let content = tokio::runtime::Runtime::new()?.block_on(async {
-///     let response = reqwest::get(&file_url).await?;
-///     // ... download logic
-/// })?;
-/// ```
-///
-/// AFTER (async):
-/// ```rust
-/// let client = reqwest::Client::new();
-/// let response = client.get(&file_url).send().await?;
-/// let bytes = response.bytes().await?;
-/// tokio::fs::write(&file_path, bytes).await?;
-/// ```
-use goose::token_counter::{create_async_token_counter, TokenCounter};
-use std::time::Instant;
-
-#[tokio::main]
-async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("πŸš€ Async Token Counter Demo");
-    println!("===========================");
-
-    // Test text samples
-    let samples = vec![
-        "Hello, world!",
-        "This is a longer text sample for tokenization testing.",
-        "The quick brown fox jumps over the lazy dog.",
-        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-        "async/await patterns eliminate blocking operations",
-    ];
-
-    println!("\nπŸ“Š Performance Comparison");
-    println!("-------------------------");
-
-    // Test original TokenCounter
-    let start = Instant::now();
-    let sync_counter = TokenCounter::new();
-    let sync_init_time = start.elapsed();
-
-    let start = Instant::now();
-    let mut sync_total = 0;
-    for sample in &samples {
-        sync_total += sync_counter.count_tokens(sample);
-    }
-    let sync_count_time = start.elapsed();
-
-    println!("πŸ”΄ Synchronous TokenCounter:");
-    println!("   Init time: {:?}", sync_init_time);
-    println!("   Count time: {:?}", sync_count_time);
-    println!("   Total tokens: {}", sync_total);
-
-    // Test AsyncTokenCounter
-    let start = Instant::now();
-    let async_counter = create_async_token_counter().await?;
-    let async_init_time = start.elapsed();
-
-    let start = Instant::now();
-    let mut async_total = 0;
-    for sample in &samples {
-        async_total += async_counter.count_tokens(sample);
-    }
-    let async_count_time = start.elapsed();
-
-    println!("\n🟒 Async TokenCounter:");
-    println!("   Init time: {:?}", async_init_time);
-    println!("   Count time: {:?}", async_count_time);
-    println!("   Total tokens: {}", async_total);
-    println!("   Cache size: {}", async_counter.cache_size());
-
-    // Test caching benefit
-    let start = Instant::now();
-    let mut cached_total = 0;
-    for sample in &samples {
-        cached_total += async_counter.count_tokens(sample); // Should hit cache
-    }
-    let cached_time = start.elapsed();
-
-    println!("\n⚑ Cached TokenCounter (2nd run):");
-    println!("   Count time: {:?}", cached_time);
-    println!("   Total tokens: {}", cached_total);
-    println!("   Cache size: {}", async_counter.cache_size());
-
-    // Verify same results
-    assert_eq!(sync_total, async_total);
-    assert_eq!(async_total, cached_total);
-
-    println!(
-        "   Token result caching: {}x faster on cached text",
-        async_count_time.as_nanos() / cached_time.as_nanos().max(1)
-    );
-
-    Ok(())
-}
diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs
index 927e011e3f01..a093cf314eb6 100644
--- a/crates/goose/src/agents/agent.rs
+++ b/crates/goose/src/agents/agent.rs
@@ -843,7 +843,13 @@ impl Agent {
         &self,
         messages: &[Message],
         session: &Option<SessionConfig>,
-    ) -> Result<Option<(Conversation, String)>> {
+    ) -> Result<
+        Option<(
+            Conversation,
+            String,
+            Option<ProviderUsage>,
+        )>,
+    > {
         // Try to get session metadata for more accurate token counts
         let session_metadata = if let Some(session_config) = session {
             match session::storage::get_path(session_config.id.clone()) {
@@ -865,21 +871,23 @@ impl Agent {
         if compact_result.compacted {
             let compacted_messages = compact_result.messages;
 
-            // Create compaction notification message
-            let compaction_msg = if let (Some(before), Some(after)) =
-                (compact_result.tokens_before, compact_result.tokens_after)
-            {
-                format!(
-                    "Auto-compacted context: {} β†’ {} tokens ({:.0}% reduction)\n\n",
-                    before,
-                    after,
-                    (1.0 - (after as f64 / before as f64)) * 100.0
-                )
-            } else {
-                "Auto-compacted context to reduce token usage\n\n".to_string()
-            };
+            // Get threshold from config to include in message
+            let config = crate::config::Config::global();
+            let threshold = config
+                .get_param::<f64>("GOOSE_AUTO_COMPACT_THRESHOLD")
+                .unwrap_or(0.8); // Default to 80%
+            let threshold_percentage = (threshold * 100.0) as u32;
+
+            let compaction_msg = format!(
+                "Exceeded auto-compact threshold of {}%. Context has been summarized and reduced.\n\n",
+                threshold_percentage
+            );
 
-            return Ok(Some((compacted_messages, compaction_msg)));
+            return Ok(Some((
+                compacted_messages,
+                compaction_msg,
+                compact_result.summarization_usage,
+            )));
         }
 
         Ok(None)
@@ -893,16 +901,16 @@ impl Agent {
         cancel_token: Option<CancellationToken>,
    ) -> Result>> {
         // Handle auto-compaction before processing
-        let (messages, compaction_msg) = match self
+        let (messages, compaction_msg, _summarization_usage) = match self
            .handle_auto_compaction(unfixed_conversation.messages(), &session)
            .await?
        {
-            Some((compacted_messages, msg)) => (compacted_messages, Some(msg)),
+            Some((compacted_messages, msg, usage)) => (compacted_messages, Some(msg), usage),
             None => {
                 let context = self
                     .prepare_reply_context(unfixed_conversation, &session)
                     .await?;
-                (context.messages, None)
+                (context.messages, None, None)
             }
         };
 
diff --git a/crates/goose/src/agents/context.rs b/crates/goose/src/agents/context.rs
index 9bce12353028..c9aa367c9c50 100644
--- a/crates/goose/src/agents/context.rs
+++ b/crates/goose/src/agents/context.rs
@@ -4,7 +4,7 @@ use crate::conversation::message::Message;
 use crate::conversation::Conversation;
 use crate::token_counter::create_async_token_counter;
 
-use crate::context_mgmt::summarize::summarize_messages_async;
+use crate::context_mgmt::summarize::summarize_messages;
 use crate::context_mgmt::truncate::{truncate_messages, OldestFirstTruncation};
 use crate::context_mgmt::{estimate_target_context_limit, get_messages_token_counts_async};
 
@@ -49,40 +49,53 @@ impl Agent {
     }
 
     /// Public API to summarize the conversation so that its token count is within the allowed context limit.
+    /// Returns the summarized messages, token counts, and the ProviderUsage from summarization
     pub async fn summarize_context(
         &self,
         messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
-    ) -> Result<(Conversation, Vec<usize>), anyhow::Error> {
+    ) -> Result<
+        (
+            Conversation,
+            Vec<usize>,
+            Option<ProviderUsage>,
+        ),
+        anyhow::Error,
+    > {
         let provider = self.provider().await?;
-        let token_counter = create_async_token_counter()
-            .await
-            .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?;
-        let target_context_limit = estimate_target_context_limit(provider.clone());
+        let summary_result = summarize_messages(provider.clone(), messages).await?;
 
-        let (mut new_messages, mut new_token_counts) =
-            summarize_messages_async(provider, messages, &token_counter, target_context_limit)
-                .await?;
+        let (mut new_messages, mut new_token_counts, summarization_usage) = match summary_result {
+            Some((summary_message, provider_usage)) => {
+                // For token counting purposes, we use the output tokens (the actual summary content)
+                // since that's what will be in the context going forward
+                let total_tokens = provider_usage.usage.output_tokens.unwrap_or(0) as usize;
+                (
+                    vec![summary_message],
+                    vec![total_tokens],
+                    Some(provider_usage),
+                )
+            }
+            None => {
+                // No summary was generated (empty input)
+                tracing::warn!("Summarization failed. Returning empty messages.");
+                return Ok((Conversation::empty(), vec![], None));
+            }
+        };
 
-        // If the summarized messages only contains one message, it means no tool request and response message in the summarized messages,
         // Add an assistant message to the summarized messages to ensure the assistant's response is included in the context.
         if new_messages.len() == 1 {
             let assistant_message = Message::assistant().with_text(
-                "I had run into a context length exceeded error so I summarized our conversation.",
+                "I ran into a context length exceeded error so I summarized our conversation.",
             );
-            let assistant_tokens =
-                token_counter.count_chat_tokens("", &[assistant_message.clone()], &[]);
-
-            let current_total: usize = new_token_counts.iter().sum();
-            if current_total + assistant_tokens <= target_context_limit {
-                new_messages.push(assistant_message);
-                new_token_counts.push(assistant_tokens);
-            } else {
-                // If we can't fit the assistant message, at least log what happened
-                tracing::warn!("Cannot add summarization notice message due to context limits. Current: {}, Assistant: {}, Limit: {}",
-                    current_total, assistant_tokens, target_context_limit);
-            }
+            let assistant_message_tokens: usize = 14;
+            new_messages.push(assistant_message);
+            new_token_counts.push(assistant_message_tokens);
         }
 
-        Ok((new_messages, new_token_counts))
+        Ok((
+            Conversation::new_unvalidated(new_messages),
+            new_token_counts,
+            summarization_usage,
+        ))
     }
 }
diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs
index e27c09dedb95..369727c469e7 100644
--- a/crates/goose/src/agents/extension_manager.rs
+++ b/crates/goose/src/agents/extension_manager.rs
@@ -109,6 +109,7 @@ async fn child_process_client(
     mut command: Command,
     timeout: &Option,
 ) -> ExtensionResult {
+    #[cfg(unix)]
     command.process_group(0);
     let (transport, mut stderr) = TokioChildProcess::builder(command)
         .stderr(Stdio::piped())
diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs
index 085e00b1430e..db302add6ec8 100644
--- a/crates/goose/src/agents/reply_parts.rs
+++ b/crates/goose/src/agents/reply_parts.rs
@@ -15,6 +15,7 @@ use crate::providers::toolshim::{
     augment_message_with_tool_calls, convert_tool_messages_to_text,
     modify_system_prompt_for_tool_json, OllamaInterpreter,
 };
+
 use crate::session;
 use rmcp::model::Tool;
 
@@ -131,10 +132,20 @@ impl Agent {
         };
 
         // Call the provider to get a response
-        let (mut response, usage) = provider
+        let (mut response, mut usage) = provider
             .complete(system_prompt, messages_for_provider.messages(), tools)
             .await?;
 
+        // Ensure we have token counts, estimating if necessary
+        usage
+            .ensure_tokens(
+                system_prompt,
+                messages_for_provider.messages(),
+                &response,
+                tools,
+            )
+            .await?;
+
         crate::providers::base::set_current_model(&usage.model);
 
         if config.toolshim {
@@ -177,13 +188,24 @@ impl Agent {
             )
             .await?
         } else {
-            let (message, usage) = provider
+            let (message, mut usage) = provider
                 .complete(
                     system_prompt.as_str(),
                     messages_for_provider.messages(),
                     &tools,
                 )
                 .await?;
+
+            // Ensure we have token counts for non-streaming case
+            usage
+                .ensure_tokens(
+                    system_prompt.as_str(),
+                    messages_for_provider.messages(),
+                    &message,
+                    &tools,
+                )
+                .await?;
+
             stream_from_single_message(message, usage)
         };
diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs
index 6e31aecf9ac8..25e427da4598 100644
--- a/crates/goose/src/context_mgmt/auto_compact.rs
+++ b/crates/goose/src/context_mgmt/auto_compact.rs
@@ -1,12 +1,7 @@
 use crate::conversation::message::Message;
 use crate::conversation::Conversation;
 use crate::{
-    agents::Agent,
-    config::Config,
-    context_mgmt::{
-        common::{SYSTEM_PROMPT_TOKEN_OVERHEAD, TOOLS_TOKEN_OVERHEAD},
-        get_messages_token_counts_async,
-    },
+    agents::Agent, config::Config, context_mgmt::get_messages_token_counts_async,
     token_counter::create_async_token_counter,
 };
 use anyhow::Result;
@@ -19,10 +14,9 @@ pub struct AutoCompactResult {
     pub compacted: bool,
     /// The messages after potential compaction
     pub messages: Conversation,
-    /// Token count before compaction (if compaction occurred)
-    pub tokens_before: Option<usize>,
-    /// Token count after compaction (if compaction occurred)
-    pub tokens_after: Option<usize>,
+    /// Provider usage from summarization (if compaction occurred)
+    /// This contains the actual token counts after compaction
+    pub summarization_usage: Option<ProviderUsage>,
 }
 
 /// Result of checking if compaction is needed
@@ -126,47 +120,6 @@ pub async fn check_compaction_needed(
     })
 }
 
-/// Perform compaction on messages
-///
-/// This function performs the actual compaction using the agent's summarization
-/// capabilities. It assumes compaction is needed and should be called after
-/// `check_compaction_needed` confirms it's necessary.
-///
-/// # Arguments
-/// * `agent` - The agent to use for context management
-/// * `messages` - The current message history to compact
-///
-/// # Returns
-/// * Tuple of (compacted_messages, tokens_before, tokens_after)
-pub async fn perform_compaction(
-    agent: &Agent,
-    messages: &[Message],
-) -> Result<(Conversation, usize, usize)> {
-    // Get token counter to measure before/after
-    let token_counter = create_async_token_counter()
-        .await
-        .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?;
-
-    // Calculate tokens before compaction
-    let token_counts_before = get_messages_token_counts_async(&token_counter, messages);
-    let tokens_before: usize = token_counts_before.iter().sum();
-
-    info!("Performing compaction on {} tokens", tokens_before);
-
-    // Perform compaction
-    let (compacted_messages, compacted_token_counts) = agent.summarize_context(messages).await?;
-    let tokens_after: usize = compacted_token_counts.iter().sum();
-
-    info!(
-        "Compaction complete: {} tokens -> {} tokens ({:.1}% reduction)",
-        tokens_before,
-        tokens_after,
-        (1.0 - (tokens_after as f64 / tokens_before as f64)) * 100.0
-    );
-
-    Ok((compacted_messages, tokens_before, tokens_after))
-}
-
 /// Check if messages need compaction and compact them if necessary
 ///
 /// This is a convenience wrapper function that combines checking and compaction.
@@ -201,8 +154,7 @@ pub async fn check_and_compact_messages(
         return Ok(AutoCompactResult {
             compacted: false,
             messages: Conversation::new_unvalidated(messages.to_vec()),
-            tokens_before: None,
-            tokens_after: None,
+            summarization_usage: None,
         });
     }
 
@@ -225,8 +177,8 @@ pub async fn check_and_compact_messages(
     };
 
     // Perform the compaction on messages excluding the preserved user message
-    let (mut compacted_messages, tokens_before, tokens_after) =
-        perform_compaction(agent, messages_to_compact).await?;
+    let (mut compacted_messages, _, summarization_usage) =
+        agent.summarize_context(messages_to_compact).await?;
 
     // Add back the preserved user message if it exists
     if let Some(user_message) = preserved_user_message {
@@ -236,8 +188,7 @@ pub async fn check_and_compact_messages(
     Ok(AutoCompactResult {
         compacted: true,
         messages: compacted_messages,
-        tokens_before: Some(tokens_before + SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD),
-        tokens_after: Some(tokens_after + SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD),
+        summarization_usage,
    })
}
@@ -326,7 +277,7 @@ mod tests {
         let mock_provider = Arc::new(MockProvider {
             model_config: ModelConfig::new("test-model")
                 .unwrap()
-                .with_context_limit(100_000.into()),
+                .with_context_limit(Some(100_000)),
         });
 
         let agent = Agent::new();
@@ -352,7 +303,7 @@ mod tests {
         let mock_provider = Arc::new(MockProvider {
             model_config: ModelConfig::new("test-model")
                 .unwrap()
-                .with_context_limit(100_000.into()),
+                .with_context_limit(Some(100_000)),
         });
 
         let agent = Agent::new();
@@ -375,40 +326,12 @@ mod tests {
         assert!(!result.needs_compaction);
     }
 
-    #[tokio::test]
-    async fn test_perform_compaction() {
-        let mock_provider = Arc::new(MockProvider {
-            model_config: ModelConfig::new("test-model")
-                .unwrap()
-                .with_context_limit(50_000.into()),
-        });
-
-        let agent = Agent::new();
-        let _ = agent.update_provider(mock_provider).await;
-
-        // Create some messages to compact
-        let messages = vec![
-            create_test_message("First message"),
-            create_test_message("Second message"),
-            create_test_message("Third message"),
-        ];
-
-        let (compacted_messages, tokens_before, tokens_after) =
-            perform_compaction(&agent, &messages).await.unwrap();
-
-        assert!(tokens_before > 0);
-        assert!(tokens_after > 0);
-        // Note: The mock provider returns a fixed summary, which might not always be smaller
-        // In real usage, compaction should reduce tokens, but for testing we just verify it works
-        assert!(!compacted_messages.is_empty());
-    }
-
     #[tokio::test]
     async fn test_auto_compact_disabled() {
         let mock_provider = Arc::new(MockProvider {
             model_config: ModelConfig::new("test-model")
                 .unwrap()
-                .with_context_limit(10_000.into()),
+                .with_context_limit(Some(10_000)),
         });
 
         let agent = Agent::new();
@@ -423,8 +346,7 @@ mod tests {
 
         assert!(!result.compacted);
         assert_eq!(result.messages.len(), messages.len());
-        assert!(result.tokens_before.is_none());
-        assert!(result.tokens_after.is_none());
+        assert!(result.summarization_usage.is_none());
 
         // Test with threshold 1.0 (disabled)
         let result = check_and_compact_messages(&agent, &messages, Some(1.0), None)
@@ -439,7 +361,7 @@ mod tests {
         let mock_provider = Arc::new(MockProvider {
             model_config: ModelConfig::new("test-model")
                 .unwrap()
-                .with_context_limit(100_000.into()), // Increased to ensure overhead doesn't dominate
+                .with_context_limit(Some(100_000)), // Increased to ensure overhead doesn't dominate
         });
 
         let agent = Agent::new();
@@ -499,14 +421,15 @@ mod tests {
         }
 
         assert!(result.compacted);
-        assert!(result.tokens_before.is_some());
-        assert!(result.tokens_after.is_some());
+        assert!(result.summarization_usage.is_some());
 
-        // Should have fewer tokens after compaction
-        if let (Some(before), Some(after)) = (result.tokens_before, result.tokens_after) {
+        // Verify that summarization usage contains token counts
+        if let Some(usage) = &result.summarization_usage {
+            assert!(usage.usage.total_tokens.is_some());
+            let after = usage.usage.total_tokens.unwrap_or(0) as usize;
             assert!(
-                after < before,
-                "Token count should decrease after compaction"
+                after > 0,
+                "Token count after compaction should be greater than 0"
             );
         }
 
@@ -519,7 +442,7 @@ mod tests {
         let mock_provider = Arc::new(MockProvider {
             model_config: ModelConfig::new("test-model")
                 .unwrap()
-                .with_context_limit(30_000.into()), // Smaller context limit to make threshold easier to hit
+                .with_context_limit(Some(30_000)), // Smaller context limit to make threshold easier to hit
         });
 
         let agent = Agent::new();
@@ -553,19 +476,7 @@ mod tests {
 
         // Debug info if not compacted
         if !result.compacted {
-            let provider = agent.provider().await.unwrap();
-            let token_counter = create_async_token_counter().await.unwrap();
-            let token_counts = get_messages_token_counts_async(&token_counter, &messages);
-            let total_tokens: usize = token_counts.iter().sum();
-            let context_limit = provider.get_model_config().context_limit();
-            let usage_ratio = total_tokens as f64 / context_limit as f64;
-
-            eprintln!(
-                "Config test not compacted - tokens: {} / {} ({:.1}%)",
-                total_tokens,
-                context_limit,
-                usage_ratio * 100.0
-            );
+            eprintln!("Test failed - compaction not triggered");
         }
 
         // With such a low threshold (10%), it should compact
@@ -701,8 +612,7 @@ mod tests {
 
         // Should have triggered compaction
         assert!(result.compacted);
-        assert!(result.tokens_before.is_some());
-        assert!(result.tokens_after.is_some());
+        assert!(result.summarization_usage.is_some());
 
         // Verify the compacted messages are returned
         assert!(!result.messages.is_empty());
diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs
index 0462d55fa4ff..68947cfa98f7 100644
--- a/crates/goose/src/context_mgmt/summarize.rs
+++ b/crates/goose/src/context_mgmt/summarize.rs
@@ -1,70 +1,26 @@
-use super::common::get_messages_token_counts_async;
-use crate::context_mgmt::get_messages_token_counts;
 use crate::conversation::message::Message;
-use crate::conversation::Conversation;
 use crate::prompt_template::render_global_file;
 use crate::providers::base::Provider;
-use crate::token_counter::{AsyncTokenCounter, TokenCounter};
+
 use anyhow::Result;
 use rmcp::model::Role;
 use serde::Serialize;
 use std::sync::Arc;
 
-// Constants for the summarization prompt and a follow-up user message.
-const SUMMARY_PROMPT: &str = "You are good at summarizing conversations";
-
 #[derive(Serialize)]
 struct SummarizeContext {
     messages: String,
 }
 
-/// Summarize the combined messages from the accumulated summary and the current chunk.
-///
-/// This method builds the summarization request, sends it to the provider, and returns the summarized response.
-async fn summarize_combined_messages(
-    provider: &Arc<dyn Provider>,
-    accumulated_summary: &[Message],
-    current_chunk: &[Message],
-) -> Result<Conversation> {
-    // Combine the accumulated summary and current chunk into a single batch.
-    let combined_messages = Conversation::new_unvalidated(
-        accumulated_summary
-            .iter()
-            .cloned()
-            .chain(current_chunk.iter().cloned())
-            .collect::<Vec<Message>>(),
-    );
-
-    // Format the batch as a summarization request.
-    let request_text = format!(
-        "Please summarize the following conversation history, preserving the key points. This summarization will be used for the later conversations.\n\n```\n{:?}\n```",
-        combined_messages
-    );
-    let summarization_request = vec![Message::user().with_text(&request_text)];
-
-    // Send the request to the provider and fetch the response.
-    let mut response = provider
-        .complete(SUMMARY_PROMPT, &summarization_request, &[])
-        .await?
-        .0;
-    // Set role to user as it will be used in following conversation as user content.
-    response.role = Role::User;
-
-    // Return the summary as the new accumulated summary.
-    Ok(Conversation::new_unvalidated(vec![response]))
-}
+use crate::providers::base::ProviderUsage;
 
-// Summarization steps:
-// Using a single tailored prompt, summarize the entire conversation history.
-pub async fn summarize_messages_oneshot(
+/// Summarization function that uses the detailed prompt from the markdown template
+pub async fn summarize_messages(
     provider: Arc<dyn Provider>,
     messages: &[Message],
-    token_counter: &TokenCounter,
-    _context_limit: usize,
-) -> Result<(Conversation, Vec<usize>), anyhow::Error> {
+) -> Result<Option<(Message, ProviderUsage)>, anyhow::Error> {
     if messages.is_empty() {
-        // If no messages to summarize, return empty
-        return Ok((Conversation::empty(), vec![]));
+        return Ok(None);
     }
 
     // Format all messages as a single string for the summarization prompt
@@ -86,176 +42,21 @@
         .with_text("Please summarize the conversation history provided in the system prompt.");
     let summarization_request = vec![user_message];
 
-    // Send the request to the provider and fetch the response.
-    let mut response = provider
+    // Send the request to the provider and fetch the response
+    let (mut response, mut provider_usage) = provider
         .complete(&system_prompt, &summarization_request, &[])
-        .await?
-        .0;
+        .await?;
 
-    // Set role to user as it will be used in following conversation as user content.
+    // Set role to user as it will be used in following conversation as user content
     response.role = Role::User;
 
-    // Return just the summary without any tool response preservation
-    let final_summary = Conversation::new_unvalidated([response].into_iter());
-    let counts = get_messages_token_counts(token_counter, final_summary.messages());
-
-    Ok((final_summary, counts))
-}
-
-// Summarization steps:
-// 1. Break down large text into smaller chunks (roughly 30% of the model’s context window).
-// 2. For each chunk:
-//    a. Combine it with the previous summary (or leave blank for the first iteration).
-//    b. Summarize the combined text, focusing on extracting only the information we need.
-// 3. Generate a final summary using a tailored prompt.
-pub async fn summarize_messages_chunked(
-    provider: Arc<dyn Provider>,
-    messages: &[Message],
-    token_counter: &TokenCounter,
-    context_limit: usize,
-) -> Result<(Conversation, Vec<usize>), anyhow::Error> {
-    let chunk_size = context_limit / 3; // 33% of the context window.
-    let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT);
-    let mut accumulated_summary = Conversation::empty();
-
-    // Get token counts for each message.
-    let token_counts = get_messages_token_counts(token_counter, messages);
-
-    // Tokenize and break messages into chunks.
-    let mut current_chunk: Vec<Message> = Vec::new();
-    let mut current_chunk_tokens = 0;
-
-    for (message, message_tokens) in messages.iter().zip(token_counts.iter()) {
-        if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens {
-            // Summarize the current chunk with the accumulated summary.
-            accumulated_summary = summarize_combined_messages(
-                &provider,
-                accumulated_summary.messages(),
-                &current_chunk,
-            )
-            .await?;
-
-            // Reset for the next chunk.
-            current_chunk.clear();
-            current_chunk_tokens = 0;
-        }
-
-        // Add message to the current chunk.
-        current_chunk.push(message.clone());
-        current_chunk_tokens += message_tokens;
-    }
-
-    // Summarize the final chunk if it exists.
-    if !current_chunk.is_empty() {
-        accumulated_summary =
-            summarize_combined_messages(&provider, accumulated_summary.messages(), &current_chunk)
-                .await?;
-    }
-
-    // Return just the summary without any tool response preservation
-    Ok((
-        accumulated_summary.clone(),
-        get_messages_token_counts(token_counter, accumulated_summary.messages()),
-    ))
-}
-
-/// Main summarization function that chooses the best algorithm based on context size.
-///
-/// This function will:
-/// 1. First try the one-shot summarization if there's enough context window available
-/// 2. Fall back to the chunked approach if the one-shot fails or if context is too limited
-/// 3. Choose the algorithm based on absolute token requirements rather than percentages
-pub async fn summarize_messages(
-    provider: Arc<dyn Provider>,
-    messages: &[Message],
-    token_counter: &TokenCounter,
-    context_limit: usize,
-) -> Result<(Conversation, Vec<usize>), anyhow::Error> {
-    // Calculate total tokens in messages
-    let total_tokens: usize = get_messages_token_counts(token_counter, messages)
-        .iter()
-        .sum();
-
-    // Calculate absolute token requirements (future-proof for large context models)
-    let system_prompt_overhead = 1000; // Conservative estimate for the summarization prompt
-    let response_overhead = 4000; // Generous buffer for response generation
-    let safety_buffer = 1000; // Small safety margin for tokenization variations
-    let total_required = total_tokens + system_prompt_overhead + response_overhead + safety_buffer;
-
-    // Use one-shot if we have enough absolute space (no percentage-based limits)
-    if total_required <= context_limit {
-        match summarize_messages_oneshot(
-            Arc::clone(&provider),
-            messages,
-            token_counter,
-            context_limit,
-        )
-        .await
-        {
-            Ok(result) => return Ok(result),
-            Err(e) => {
-                // Log the error but continue to fallback
-                tracing::warn!(
-                    "One-shot summarization failed, falling back to chunked approach: {}",
-                    e
-                );
-            }
-        }
-    }
-
-    // Fall back to the chunked approach
-    summarize_messages_chunked(provider, messages, token_counter, context_limit).await
-}
-
-/// Async version using AsyncTokenCounter for better performance
-pub async fn summarize_messages_async(
-    provider: Arc<dyn Provider>,
-    messages: &[Message],
-    token_counter: &AsyncTokenCounter,
-    context_limit: usize,
-) -> Result<(Conversation, Vec<usize>), anyhow::Error> {
-    let chunk_size = context_limit / 3; // 33% of the context window.
-    let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT);
-    let mut accumulated_summary = Conversation::empty();
-
-    // Get token counts for each message.
-    let token_counts = get_messages_token_counts_async(token_counter, messages);
-
-    // Tokenize and break messages into chunks.
-    let mut current_chunk = Vec::new();
-    let mut current_chunk_tokens = 0;
-
-    for (message, message_tokens) in messages.iter().zip(token_counts.iter()) {
-        if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens {
-            // Summarize the current chunk with the accumulated summary.
-            accumulated_summary = summarize_combined_messages(
-                &provider,
-                accumulated_summary.messages(),
-                &current_chunk,
-            )
-            .await?;
-
-            // Reset for the next chunk.
-            current_chunk.clear();
-            current_chunk_tokens = 0;
-        }
-
-        // Add message to the current chunk.
-        current_chunk.push(message.clone());
-        current_chunk_tokens += message_tokens;
-    }
+    // Ensure we have token counts, estimating if necessary
+    provider_usage
+        .ensure_tokens(&system_prompt, &summarization_request, &response, &[])
+        .await
-
-    // Summarize the final chunk if it exists.
-    if !current_chunk.is_empty() {
-        accumulated_summary =
-            summarize_combined_messages(&provider, accumulated_summary.messages(), &current_chunk)
-                .await?;
-    }
+        .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;
-
-    let count = get_messages_token_counts_async(token_counter, accumulated_summary.messages());
-
-    // Return just the summary without any tool response preservation
-    Ok((accumulated_summary.clone(), count))
+
+    Ok(Some((response, provider_usage)))
 }
 
 #[cfg(test)]
@@ -303,13 +104,20 @@ mod tests {
                     .no_annotation(),
                 )],
             ),
-            ProviderUsage::new("mock".to_string(), Usage::default()),
+            ProviderUsage::new(
+                "mock".to_string(),
+                Usage {
+                    input_tokens: Some(100),
+                    output_tokens: Some(50),
+                    total_tokens: Some(150),
+                },
+            ),
         ))
     }
 }
 
 fn create_mock_provider() -> Result<Arc<dyn Provider>> {
-    let mock_model_config = ModelConfig::new("test-model")?.with_context_limit(200_000.into());
+    let mock_model_config = ModelConfig::new("test-model")?.with_context_limit(Some(200_000));
 
     Ok(Arc::new(MockProvider {
         model_config: mock_model_config,
@@ -329,366 +137,49 @@
     }
 
     #[tokio::test]
-    async fn test_summarize_messages_single_chunk() {
+    async fn test_summarize_messages_basic() {
         let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 10_000; // Higher limit to avoid underflow
         let messages = create_test_messages();
 
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
+        let result = summarize_messages(Arc::clone(&provider), &messages).await;
 
         assert!(result.is_ok(), "The function should return Ok.");
-        let (summarized_messages, token_counts) = result.unwrap();
+        let summary_result = result.unwrap();
 
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "The summary should contain one message."
-        );
-        assert_eq!(
-            summarized_messages.first().unwrap().role,
-            Role::User,
-            "The summarized message should be from the user."
-        );
-
-        assert_eq!(
-            token_counts.len(),
-            1,
-            "Token counts should match the number of summarized messages."
+        assert!(
+            summary_result.is_some(),
+            "The summary should contain a result."
         );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_multiple_chunks() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 10_000; // Higher limit to avoid underflow
-        let messages = create_test_messages();
-
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
+        let (summarized_message, provider_usage) = summary_result.unwrap();
 
-        assert!(result.is_ok(), "The function should return Ok.");
-        let (summarized_messages, token_counts) = result.unwrap();
-
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "There should be one final summarized message."
-        );
         assert_eq!(
-            summarized_messages.first().unwrap().role,
+            summarized_message.role,
             Role::User,
             "The summarized message should be from the user."
         );
-
-        assert_eq!(
-            token_counts.len(),
-            1,
-            "Token counts should match the number of summarized messages."
-        );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_empty_input() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 10_000; // Higher limit to avoid underflow
-        let messages: Vec<Message> = Vec::new();
-
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
-
-        assert!(result.is_ok(), "The function should return Ok.");
-        let (summarized_messages, token_counts) = result.unwrap();
-
-        assert_eq!(
-            summarized_messages.len(),
-            0,
-            "The summary should be empty for an empty input."
-        );
         assert!(
-            token_counts.is_empty(),
-            "Token counts should be empty for an empty input."
-        );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_uses_oneshot_for_small_context() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 100_000; // Large context limit
-        let messages = create_test_messages(); // Small message set
-
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
-
-        assert!(result.is_ok(), "The function should return Ok.");
-        let (summarized_messages, _) = result.unwrap();
-
-        // Should use one-shot and return a single summarized message
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "Should use one-shot summarization for small context."
+            provider_usage.usage.input_tokens.unwrap_or(0) > 0,
+            "Should have input token count"
         );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_uses_chunked_for_large_context() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 10_000; // Higher limit to avoid underflow
-        let messages = create_test_messages();
-
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
-
-        assert!(result.is_ok(), "The function should return Ok.");
-        let (summarized_messages, _) = result.unwrap();
-
-        // Should fall back to chunked approach
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "Should use chunked summarization for large context."
-        );
-    }
-
-    // Mock provider that fails on one-shot but succeeds on chunked
-    #[derive(Clone)]
-    struct FailingOneshotProvider {
-        model_config: ModelConfig,
-        call_count: Arc>,
-    }
-
-    #[async_trait::async_trait]
-    impl Provider for FailingOneshotProvider {
-        fn metadata() -> ProviderMetadata {
-            ProviderMetadata::empty()
-        }
-
-        fn get_model_config(&self) -> ModelConfig {
-            self.model_config.clone()
-        }
-
-        async fn complete(
-            &self,
-            system: &str,
-            _messages: &[Message],
-            _tools: &[Tool],
-        ) -> Result<(Message, ProviderUsage), ProviderError> {
-            let mut count = self.call_count.lock().unwrap();
-            *count += 1;
-
-            // Fail if this looks like a one-shot request
-            if system.contains("reasoning in `` tags") {
-                return Err(ProviderError::RateLimitExceeded(
-                    "Simulated one-shot failure".to_string(),
-                ));
-            }
-
-            // Succeed for chunked requests (uses the old SUMMARY_PROMPT)
-            Ok((
-                Message::new(
-                    Role::Assistant,
-                    Utc::now().timestamp(),
-                    vec![MessageContent::Text(
-                        RawTextContent {
-                            text: "Chunked summary".to_string(),
-                        }
-                        .no_annotation(),
-                    )],
-                ),
-                ProviderUsage::new("mock".to_string(), Usage::default()),
-            ))
-        }
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_fallback_on_oneshot_failure() {
-        let call_count = Arc::new(std::sync::Mutex::new(0));
-        let provider = Arc::new(FailingOneshotProvider {
-            model_config: ModelConfig::new("test-model")
-                .unwrap()
-                .with_context_limit(200_000.into()),
-            call_count: Arc::clone(&call_count),
-        });
-        let token_counter = TokenCounter::new();
-        let context_limit = 100_000; // Large enough to try one-shot first
-        let messages = create_test_messages();
-
-        let result = summarize_messages(provider, &messages, &token_counter, context_limit).await;
-
-        assert!(
-            result.is_ok(),
-            "The function should return Ok after fallback."
-        );
-        let (summarized_messages, _) = result.unwrap();
-
-        // Should have fallen back to chunked approach
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "Should successfully fall back to chunked approach."
-        );
-
-        // Verify the content comes from the chunked approach
-        if let MessageContent::Text(text_content) = &summarized_messages.first().unwrap().content[0]
-        {
-            assert_eq!(text_content.text, "Chunked summary");
-        } else {
-            panic!("Expected text content");
-        }
-
-        // Should have made multiple calls (one-shot attempt + chunked calls)
-        let final_count = *call_count.lock().unwrap();
-        assert!(
-            final_count > 1,
-            "Should have made multiple provider calls during fallback"
-        );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_oneshot_direct_call() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 100_000;
-        let messages = create_test_messages();
-
-        let result = summarize_messages_oneshot(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
-
-        assert!(
-            result.is_ok(),
-            "One-shot summarization should work directly."
-        );
-        let (summarized_messages, token_counts) = result.unwrap();
-
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "One-shot should return a single summary message."
-        );
-        assert_eq!(
-            summarized_messages.first().unwrap().role,
-            Role::User,
-            "Summary should be from user role for context."
-        );
-        assert_eq!(
-            token_counts.len(),
-            1,
-            "Should have token count for the summary."
-        );
-    }
-
-    #[tokio::test]
-    async fn test_summarize_messages_chunked_direct_call() {
-        let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-        let context_limit = 10_000; // Higher limit to avoid underflow
-        let messages = create_test_messages();
-
-        let result = summarize_messages_chunked(
-            Arc::clone(&provider),
-            &messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
 
         assert!(
-            result.is_ok(),
-            "Chunked summarization should work directly."
-        );
-        let (summarized_messages, token_counts) = result.unwrap();
-
-        assert_eq!(
-            summarized_messages.len(),
-            1,
-            "Chunked should return a single final summary."
-        );
-        assert_eq!(
-            summarized_messages.first().unwrap().role,
-            Role::User,
-            "Summary should be from user role for context."
-        );
-        assert_eq!(
-            token_counts.len(),
-            1,
-            "Should have token count for the summary."
+            provider_usage.usage.output_tokens.unwrap_or(0) > 0,
+            "Should have output token count"
         );
     }
 
     #[tokio::test]
-    async fn test_absolute_token_threshold_calculation() {
+    async fn test_summarize_messages_empty_input() {
         let provider = create_mock_provider().expect("failed to create mock provider");
-        let token_counter = TokenCounter::new();
-
-        // Test with a context limit where absolute token calculation matters
-        let context_limit = 10_000;
-        let system_prompt_overhead = 1000;
-        let response_overhead = 4000;
-        let safety_buffer = 1000;
-        let max_message_tokens =
-            context_limit - system_prompt_overhead - response_overhead - safety_buffer; // 4000 tokens
-
-        // Create messages that are just under the absolute threshold
-        let mut large_messages = Vec::new();
-        let base_message = set_up_text_message("x".repeat(50).as_str(), Role::User);
-
-        // Add enough messages to approach but not exceed the absolute threshold
-        let message_tokens = token_counter.count_tokens(&format!("{:?}", base_message));
-        let num_messages = (max_message_tokens / message_tokens).saturating_sub(1);
+        let messages: Vec<Message> = Vec::new();
 
-        for i in 0..num_messages {
-            large_messages.push(set_up_text_message(&format!("Message {}", i), Role::User));
-        }
+        let result = summarize_messages(Arc::clone(&provider), &messages).await;
 
-        let result = summarize_messages(
-            Arc::clone(&provider),
-            &large_messages,
-            &token_counter,
-            context_limit,
-        )
-        .await;
+        assert!(result.is_ok(), "The function should return Ok.");
+        let summary_result = result.unwrap();
 
         assert!(
-            result.is_ok(),
-            "Should handle absolute threshold calculation correctly."
+            summary_result.is_none(),
+            "The summary should be None for empty input."
         );
-        let (summarized_messages, _) = result.unwrap();
-        assert_eq!(summarized_messages.len(), 1, "Should produce a summary.");
     }
 }
diff --git a/crates/goose/src/prompts/summarize_oneshot.md b/crates/goose/src/prompts/summarize_oneshot.md
index 8e621f2058aa..b0170517bebb 100644
--- a/crates/goose/src/prompts/summarize_oneshot.md
+++ b/crates/goose/src/prompts/summarize_oneshot.md
@@ -1,9 +1,15 @@
-## Summary Task
-Generate detailed summary of conversation to date.
-Include user requests, your responses, and all technical content.
+## Task Context
+- An LLM context limit was reached when a user was in a working session with an agent (you)
+- Generate a version of the messages below with only the most verbose parts removed
+- Include user requests, your responses, all technical content, and as much of the original context as possible
+- This will be used to let the user continue the working session
+- Use framing and tone knowing the content will be read by an agent (you) on the next exchange to allow for continuation of the session
+
+**Conversation History:**
+{{ messages }}
 
 Wrap reasoning in `` tags:
-- Review conversation chronologically 
+- Review conversation chronologically
 - For each part, log:
   - User goals and requests
   - Your method and solution
   - Key decisions, tech concepts, code details
   - File names, code, signatures, errors, fixes
 - Highlight user feedback and revisions
 - Confirm completeness and accuracy
+- This summary will only be read by you so it is ok to make it much longer than a normal summary you would show to a human
+- Do not exclude any information that might be important to continuing a session working with you
 
-### Summary Must Include the Following Sections:
+### Include the Following Sections:
 1. **User Intent** – All goals and requests
 2. **Technical Concepts** – All discussed tools, methods
 3. **Files + Code** – Viewed/edited files, full code, change justifications
 4. **Errors + Fixes** – Bugs, resolutions, user-driven changes
 5. **Problem Solving** – Issues solved or in progress
-6. **User Messages** – All user messages, exclude tool output
+6. **User Messages** – All user messages including tool calls, but truncate long tool call arguments or results
 7. **Pending Tasks** – All unresolved user requests
 8. **Current Work** – Active work at summary request time: filenames, code, alignment to latest instruction
 9. **Next Step** – *Include only if* directly continues user instruction
diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs
index 02fe7801e061..60623abb3a3e 100644
--- a/crates/goose/src/providers/base.rs
+++ b/crates/goose/src/providers/base.rs
@@ -217,6 +217,34 @@ impl ProviderUsage {
     pub fn new(model: String, usage: Usage) -> Self {
         Self { model, usage }
     }
+
+    /// Ensures this ProviderUsage has token counts, estimating them if necessary
+    pub async fn ensure_tokens(
+        &mut self,
+        system_prompt: &str,
+        request_messages: &[Message],
+        response: &Message,
+        tools: &[Tool],
+    ) -> Result<(), ProviderError> {
+        crate::providers::usage_estimator::ensure_usage_tokens(
+            self,
+            system_prompt,
+            request_messages,
+            response,
+            tools,
+        )
+        .await
+        .map_err(|e| ProviderError::ExecutionError(format!("Failed to ensure usage tokens: {}", e)))
+    }
+
+    /// Combine this ProviderUsage with another, adding their token counts
+    /// Uses the model from this ProviderUsage
+    pub fn combine_with(&self, other: &ProviderUsage) -> ProviderUsage {
+        ProviderUsage {
+            model: self.model.clone(),
+            usage: self.usage + other.usage,
+        }
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index 3e04fba896ee..60386d2171ca 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -28,6 +28,7 @@ pub mod sagemaker_tgi;
 pub mod snowflake;
 pub mod testprovider;
 pub mod toolshim;
+pub mod usage_estimator;
 pub mod utils;
 pub mod utils_universal_openai_stream;
 pub mod venice;
diff --git a/crates/goose/src/providers/usage_estimator.rs b/crates/goose/src/providers/usage_estimator.rs
new file mode 100644
index 000000000000..e819b6dc8349
--- /dev/null
+++ b/crates/goose/src/providers/usage_estimator.rs
@@ -0,0 +1,128 @@
+use crate::conversation::message::Message;
+use crate::providers::base::ProviderUsage;
+use crate::token_counter::create_async_token_counter;
+use anyhow::Result;
+use rmcp::model::Tool;
+
+/// Ensures that ProviderUsage has token counts, estimating them if necessary.
+/// This provides a single place to handle the fallback logic for providers that don't return usage data.
+pub async fn ensure_usage_tokens(
+    provider_usage: &mut ProviderUsage,
+    system_prompt: &str,
+    request_messages: &[Message],
+    response: &Message,
+    tools: &[Tool],
+) -> Result<()> {
+    if provider_usage.usage.input_tokens.is_some() && provider_usage.usage.output_tokens.is_some() {
+        return Ok(());
+    }
+
+    let token_counter = create_async_token_counter()
+        .await
+        .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?;
+
+    if provider_usage.usage.input_tokens.is_none() {
+        let input_count = token_counter.count_chat_tokens(system_prompt, request_messages, tools);
+        provider_usage.usage.input_tokens = Some(input_count as i32);
+    }
+
+    if provider_usage.usage.output_tokens.is_none() {
+        let response_text = response
+            .content
+            .iter()
+            .map(|c| format!("{}", c))
+            .collect::<Vec<_>>()
+            .join(" ");
+        let output_count = token_counter.count_tokens(&response_text);
+        provider_usage.usage.output_tokens = Some(output_count as i32);
+    }
+
+    if let (Some(input), Some(output)) = (
+        provider_usage.usage.input_tokens,
+        provider_usage.usage.output_tokens,
+    ) {
+        provider_usage.usage.total_tokens = Some(input + output);
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::conversation::message::Message;
+    use crate::providers::base::Usage;
+
+    #[tokio::test]
+    async fn test_ensure_usage_tokens_already_complete() {
+        let mut usage = ProviderUsage::new(
+            "test-model".to_string(),
+            Usage::new(Some(100), Some(50), Some(150)),
+        );
+
+        let response = Message::assistant().with_text("Test response");
+
+        ensure_usage_tokens(&mut usage, "system", &[], &response, &[])
+            .await
+            .unwrap();
+
+        // Should remain unchanged
+        assert_eq!(usage.usage.input_tokens, Some(100));
+        assert_eq!(usage.usage.output_tokens, Some(50));
+        assert_eq!(usage.usage.total_tokens, Some(150));
+    }
+
+    #[tokio::test]
+    async fn test_ensure_usage_tokens_missing_all() {
+        let mut usage = ProviderUsage::new("test-model".to_string(), Usage::default());
+
+        let response = Message::assistant().with_text("Test response");
+        let messages = vec![Message::user().with_text("Hello")];
+
+        ensure_usage_tokens(
+            &mut usage,
+            "You are a helpful assistant",
+            &messages,
+            &response,
+            &[],
+        )
+        .await
+        .unwrap();
+
+        // Should have estimated values
+        assert!(usage.usage.input_tokens.is_some());
+        assert!(usage.usage.output_tokens.is_some());
+        assert!(usage.usage.total_tokens.is_some());
+
+        // Basic sanity checks
+        assert!(usage.usage.input_tokens.unwrap() > 0);
+        assert!(usage.usage.output_tokens.unwrap() > 0);
+        assert_eq!(
+            usage.usage.total_tokens.unwrap(),
+            usage.usage.input_tokens.unwrap() + usage.usage.output_tokens.unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_ensure_usage_tokens_partial() {
+        let mut usage =
+            ProviderUsage::new("test-model".to_string(), Usage::new(Some(100), None, None));
+
+        let response = Message::assistant().with_text("Test response");
+
+        ensure_usage_tokens(&mut usage, "system", &[], &response, &[])
+            .await
+            .unwrap();
+
+        // Input should remain unchanged
+        assert_eq!(usage.usage.input_tokens, Some(100));
+        // Output should be estimated
+        assert!(usage.usage.output_tokens.is_some());
+        assert!(usage.usage.output_tokens.unwrap() > 0);
+        // Total should be calculated
+        assert_eq!(
+            usage.usage.total_tokens.unwrap(),
+            usage.usage.input_tokens.unwrap() + usage.usage.output_tokens.unwrap()
+        );
+    }
+}
diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json
index 589b3987dbb2..01aad7cd7271 100644
--- a/ui/desktop/openapi.json
+++ b/ui/desktop/openapi.json
@@ -10,7 +10,7 @@
     "license": {
       "name": "Apache-2.0"
     },
-    "version": "1.3.0"
+    "version": "1.4.0"
   },
   "paths": {
     "/agent/add_sub_recipes": {
diff --git a/ui/desktop/package-lock.json b/ui/desktop/package-lock.json
index ef74dac64631..16bcead982a7 100644
--- a/ui/desktop/package-lock.json
+++ b/ui/desktop/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "goose-app",
-  "version": "1.3.0",
+  "version": "1.4.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "goose-app",
-      "version": "1.3.0",
+      "version": "1.4.1",
       "license": "Apache-2.0",
       "dependencies": {
         "@ai-sdk/openai": "^0.0.72",
diff --git a/ui/desktop/package.json b/ui/desktop/package.json
index 1456c0017d97..6f612f5bb02e 100644
--- a/ui/desktop/package.json
+++ b/ui/desktop/package.json
@@ -1,7 +1,7 @@
 {
   "name": "goose-app",
   "productName": "Goose",
-  "version": "1.3.0",
+  "version": "1.4.1",
   "description": "Goose App",
   "engines": {
     "node": "^22.17.1"
diff --git a/ui/desktop/src/components/MarkdownContent.tsx b/ui/desktop/src/components/MarkdownContent.tsx
index 3566c6bcccee..15c9b293a983 100644
--- a/ui/desktop/src/components/MarkdownContent.tsx
+++ b/ui/desktop/src/components/MarkdownContent.tsx
@@ -1,4 +1,4 @@
-import React, { useState, useEffect } from 'react';
+import React, { useState, useEffect, useRef } from 'react';
 import ReactMarkdown from 'react-markdown';
 import remarkGfm from 'remark-gfm';
 import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -17,16 +17,31 @@ interface MarkdownContentProps {
 
 const CodeBlock = ({ language, children }: { language: string; children: string }) => {
   const [copied, setCopied] = useState(false);
+  const timeoutRef = useRef<number | null>(null);
+
   const handleCopy = async () => {
     try {
       await navigator.clipboard.writeText(children);
       setCopied(true);
-      setTimeout(() => setCopied(false), 2000); // Reset after 2 seconds
+
+      if (timeoutRef.current) {
+        window.clearTimeout(timeoutRef.current);
+      }
+
+      timeoutRef.current = window.setTimeout(() => setCopied(false), 2000);
     } catch (err) {
       console.error('Failed to copy text: ', err);
    }
  };
 
+  useEffect(() => {
+    return () => {
+      if (timeoutRef.current) {
+        window.clearTimeout(timeoutRef.current);
+      }
+    };
+  }, []);
+
   return (