diff --git a/crates/goose-bench/src/bench_session.rs b/crates/goose-bench/src/bench_session.rs index 6c8a4f7cedf6..30dc1dd7cea9 100644 --- a/crates/goose-bench/src/bench_session.rs +++ b/crates/goose-bench/src/bench_session.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use goose::message::Message; +use goose::conversation::Conversation; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -19,7 +19,7 @@ pub struct BenchAgentError { pub trait BenchBaseSession: Send + Sync { async fn headless(&mut self, message: String) -> anyhow::Result<()>; fn session_file(&self) -> Option; - fn message_history(&self) -> Vec; + fn message_history(&self) -> Conversation; fn get_total_token_usage(&self) -> anyhow::Result>; } // struct for managing agent-session-access. to be passed to evals for benchmarking @@ -34,7 +34,7 @@ impl BenchAgent { Self { session, errors } } - pub(crate) async fn prompt(&mut self, p: String) -> anyhow::Result> { + pub(crate) async fn prompt(&mut self, p: String) -> anyhow::Result { // Clear previous errors { let mut errors = self.errors.lock().await; diff --git a/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs b/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs index b42243dfd2d7..cb50c3d6055a 100644 --- a/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs +++ b/crates/goose-bench/src/eval_suites/core/computercontroller/script.rs @@ -8,7 +8,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs b/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs index 04fba60f7586..77ae2105fa26 100644 --- a/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs +++ b/crates/goose-bench/src/eval_suites/core/computercontroller/web_scrape.rs @@ -8,7 +8,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/developer/create_file.rs b/crates/goose-bench/src/eval_suites/core/developer/create_file.rs index 154319c38316..8a7a2587125e 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/create_file.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/create_file.rs @@ -8,7 +8,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/developer/list_files.rs b/crates/goose-bench/src/eval_suites/core/developer/list_files.rs index 8aea32cc5b14..5eca6589dfca 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/list_files.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/list_files.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs b/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs index ffa8541e6191..bbc6afc699a5 100644 --- a/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs +++ b/crates/goose-bench/src/eval_suites/core/developer/simple_repo_clone_test.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/developer_image/image.rs b/crates/goose-bench/src/eval_suites/core/developer_image/image.rs index 2ac8a8ce88f8..771b550d052b 100644 --- a/crates/goose-bench/src/eval_suites/core/developer_image/image.rs +++ b/crates/goose-bench/src/eval_suites/core/developer_image/image.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs b/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs index 4e3184e42af6..f5d01d9d8154 100644 --- a/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs +++ b/crates/goose-bench/src/eval_suites/core/memory/save_fact.rs @@ -8,7 +8,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-bench/src/eval_suites/metrics.rs b/crates/goose-bench/src/eval_suites/metrics.rs index d21d557d5275..1a424a21a34c 100644 --- a/crates/goose-bench/src/eval_suites/metrics.rs +++ b/crates/goose-bench/src/eval_suites/metrics.rs @@ -1,6 +1,7 @@ use crate::bench_session::BenchAgent; use crate::eval_suites::EvalMetricValue; -use goose::message::{Message, MessageContent}; +use goose::conversation::message::{Message, MessageContent}; +use goose::conversation::Conversation; use std::collections::HashMap; use std::time::Instant; @@ -8,7 +9,7 @@ use std::time::Instant; pub async fn collect_baseline_metrics( agent: &mut BenchAgent, prompt: String, -) -> (Vec, HashMap) { +) -> (Conversation, HashMap) { // Initialize metrics map let mut metrics = HashMap::new(); @@ -23,7 +24,7 @@ pub async fn collect_baseline_metrics( "prompt_error".to_string(), EvalMetricValue::String(format!("Error: {}", e)), ); - Vec::new() + Conversation::new_unvalidated(Vec::new()) } }; @@ -35,7 +36,7 @@ pub async fn collect_baseline_metrics( ); // Count tool calls - let (total_tool_calls, tool_calls_by_name) = count_tool_calls(&messages); + let (total_tool_calls, tool_calls_by_name) = count_tool_calls(messages.messages()); metrics.insert( "total_tool_calls".to_string(), EvalMetricValue::Integer(total_tool_calls), diff --git a/crates/goose-bench/src/eval_suites/utils.rs b/crates/goose-bench/src/eval_suites/utils.rs index 880e457b00cd..e5bbd15c9533 100644 --- a/crates/goose-bench/src/eval_suites/utils.rs +++ b/crates/goose-bench/src/eval_suites/utils.rs @@ -1,6 +1,6 @@ use crate::bench_work_dir::BenchmarkWorkDir; use anyhow::{Context, Result}; -use goose::message::Message; +use goose::conversation::message::Message; use std::fs::File; use std::io::Write; use std::path::PathBuf; diff --git a/crates/goose-bench/src/eval_suites/vibes/blog_summary.rs b/crates/goose-bench/src/eval_suites/vibes/blog_summary.rs index de2f22ef0b34..f3e0f22d55b1 100644 --- a/crates/goose-bench/src/eval_suites/vibes/blog_summary.rs +++ b/crates/goose-bench/src/eval_suites/vibes/blog_summary.rs @@ -37,7 +37,7 @@ impl Evaluation for BlogSummary { // Write response to file and get the text content let response_text = - match write_response_to_file(&response, run_loc, "blog_summary_output.txt") { + match write_response_to_file(response.messages(), run_loc, "blog_summary_output.txt") { Ok(text) => text, Err(e) => { println!("Warning: Failed to write blog summary output: {}", e); @@ -59,7 +59,7 @@ impl Evaluation for BlogSummary { )); // Check if the fetch tool was used - let used_fetch_tool = crate::eval_suites::used_tool(&response, "fetch"); + let used_fetch_tool = crate::eval_suites::used_tool(response.messages(), "fetch"); metrics.push(( "used_fetch_tool".to_string(), EvalMetricValue::Boolean(used_fetch_tool), diff --git a/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs b/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs index edd2f4a52424..ddd076417c49 100644 --- a/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs +++ b/crates/goose-bench/src/eval_suites/vibes/flappy_bird.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; use std::fs; diff --git a/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs b/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs index 2609584cf890..62a4efc542fc 100644 --- a/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs +++ b/crates/goose-bench/src/eval_suites/vibes/goose_wiki.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; use std::fs; diff --git a/crates/goose-bench/src/eval_suites/vibes/restaurant_research.rs b/crates/goose-bench/src/eval_suites/vibes/restaurant_research.rs index 5d6a55736b19..92f3333c954d 100644 --- a/crates/goose-bench/src/eval_suites/vibes/restaurant_research.rs +++ b/crates/goose-bench/src/eval_suites/vibes/restaurant_research.rs @@ -50,17 +50,20 @@ Present the information in order of significance or quality. Focus specifically ).await; // Write response to file and get the text content - let response_text = - match write_response_to_file(&response, run_loc, "restaurant_research_output.txt") { - Ok(text) => text, - Err(e) => { - println!("Warning: Failed to write restaurant research output: {}", e); - // If file write fails, still continue with the evaluation - response - .last() - .map_or_else(String::new, |msg| msg.as_concat_text()) - } - }; + let response_text = match write_response_to_file( + response.messages(), + run_loc, + "restaurant_research_output.txt", + ) { + Ok(text) => text, + Err(e) => { + println!("Warning: Failed to write restaurant research output: {}", e); + // If file write fails, still continue with the evaluation + response + .last() + .map_or_else(String::new, |msg| msg.as_concat_text()) + } + }; // Convert HashMap to Vec for our metrics let mut metrics = metrics_hashmap_to_vec(perf_metrics); @@ -79,7 +82,7 @@ Present the information in order of significance or quality. Focus specifically )); // Check if the fetch tool was used - let used_fetch_tool = crate::eval_suites::used_tool(&response, "fetch"); + let used_fetch_tool = crate::eval_suites::used_tool(response.messages(), "fetch"); metrics.push(( "used_fetch_tool".to_string(), EvalMetricValue::Boolean(used_fetch_tool), diff --git a/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs b/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs index fd628a6544bb..ad66471d4c93 100644 --- a/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs +++ b/crates/goose-bench/src/eval_suites/vibes/squirrel_census.rs @@ -6,7 +6,7 @@ use crate::eval_suites::{ }; use crate::register_evaluation; use async_trait::async_trait; -use goose::message::MessageContent; +use goose::conversation::message::MessageContent; use rmcp::model::Role; use serde_json::{self, Value}; diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs index 02d3c026ecbf..982c74aafa93 100644 --- a/crates/goose-cli/src/commands/bench.rs +++ b/crates/goose-cli/src/commands/bench.rs @@ -2,7 +2,7 @@ use crate::session::build_session; use crate::session::SessionBuilderConfig; use crate::{logging, session, Session}; use async_trait::async_trait; -use goose::message::Message; +use goose::conversation::Conversation; use goose_bench::bench_session::{BenchAgent, BenchBaseSession}; use goose_bench::eval_suites::ExtensionRequirements; use std::path::PathBuf; @@ -18,7 +18,7 @@ impl BenchBaseSession for Session { fn session_file(&self) -> Option { self.session_file() } - fn message_history(&self) -> Vec { + fn message_history(&self) -> Conversation { self.message_history() } fn get_total_token_usage(&self) -> anyhow::Result> { diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index fcfbb8ef5483..ff218f408e02 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -13,7 +13,7 @@ use goose::config::{ Config, ConfigError, ExperimentManager, ExtensionConfigManager, ExtensionEntry, PermissionManager, }; -use goose::message::Message; +use goose::conversation::message::Message; use goose::providers::{create, providers}; use rmcp::model::{Tool, ToolAnnotations}; use rmcp::object; @@ -1551,7 +1551,7 @@ pub fn configure_max_turns_dialog() -> Result<(), Box> { /// Handle OpenRouter authentication pub async fn handle_openrouter_auth() -> Result<(), Box> { use goose::config::{configure_openrouter, signup_openrouter::OpenRouterAuth}; - use goose::message::Message; + use goose::conversation::message::Message; use goose::providers::create; // Use the OpenRouter authentication flow diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 20a0f3d2f378..c37e57994b51 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -195,7 +195,8 @@ pub fn handle_session_export(identifier: Identifier, output_path: Option, + messages: Vec, session_file: &Path, session_name_override: Option<&str>, ) -> String { @@ -242,10 +243,12 @@ fn export_session_to_markdown( for message in &messages { // Check if this is a User message containing only ToolResponses let is_only_tool_response = message.role == rmcp::model::Role::User - && message - .content - .iter() - .all(|content| matches!(content, goose::message::MessageContent::ToolResponse(_))); + && message.content.iter().all(|content| { + matches!( + content, + goose::conversation::message::MessageContent::ToolResponse(_) + ) + }); // If the previous message had tool requests and this one is just tool responses, // don't create a new User section - we'll attach the responses to the tool calls @@ -274,11 +277,12 @@ fn export_session_to_markdown( markdown_output.push_str("\n\n---\n\n"); // Check if this message has any tool requests, to handle the next message differently - if message - .content - .iter() - .any(|content| matches!(content, goose::message::MessageContent::ToolRequest(_))) - { + if message.content.iter().any(|content| { + matches!( + content, + goose::conversation::message::MessageContent::ToolRequest(_) + ) + }) { skip_next_if_tool_response = true; } } diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 82834607adee..91bfcb514e51 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -10,7 +10,8 @@ use axum::{ }; use futures::{sink::SinkExt, stream::StreamExt}; use goose::agents::{Agent, AgentEvent}; -use goose::message::Message as GooseMessage; +use goose::conversation::message::Message as GooseMessage; +use goose::conversation::Conversation; use goose::session; use serde::{Deserialize, Serialize}; use std::{net::SocketAddr, sync::Arc}; @@ -18,7 +19,7 @@ use tokio::sync::{Mutex, RwLock}; use tower_http::cors::{Any, CorsLayer}; use tracing::error; -type SessionStore = Arc>>>>>; +type SessionStore = Arc>>>>; type CancellationStore = Arc>>; #[derive(Clone)] @@ -319,8 +320,8 @@ async fn handle_socket(socket: WebSocket, state: AppState) { let mut sessions = state.sessions.write().await; // Load existing messages from JSONL file if it exists - let existing_messages = session::read_messages(&session_file) - .unwrap_or_else(|_| Vec::new()); + let existing_messages = + session::read_messages(&session_file).unwrap_or_default(); let new_session = Arc::new(Mutex::new(existing_messages)); sessions.insert(session_id.clone(), new_session.clone()); @@ -435,21 +436,21 @@ async fn handle_socket(socket: WebSocket, state: AppState) { async fn process_message_streaming( agent: &Agent, - session_messages: Arc>>, + session_messages: Arc>, session_file: std::path::PathBuf, content: String, sender: Arc>>, ) -> Result<()> { use futures::StreamExt; use goose::agents::SessionConfig; - use goose::message::MessageContent; + use goose::conversation::message::MessageContent; use goose::session; // Create a user message let user_message = GooseMessage::user().with_text(content.clone()); // Messages will be auto-compacted in agent.reply() if needed - let messages = { + let messages: Conversation = { let mut session_msgs = session_messages.lock().await; session_msgs.push(user_message.clone()); session_msgs.clone() @@ -493,7 +494,10 @@ async fn process_message_streaming( retry_config: None, }; - match agent.reply(&messages, Some(session_config), None).await { + match agent + .reply(messages.clone(), Some(session_config), None) + .await + { Ok(mut stream) => { while let Some(result) = stream.next().await { match result { @@ -617,7 +621,7 @@ async fn process_message_streaming( // For now, auto-summarize in web mode // TODO: Implement proper UI for context handling let (summarized_messages, _) = - agent.summarize_context(&messages).await?; + agent.summarize_context(messages.messages()).await?; { let mut session_msgs = session_messages.lock().await; *session_msgs = summarized_messages; @@ -633,7 +637,7 @@ async fn process_message_streaming( // Replace the session's message history with the compacted messages { let mut session_msgs = session_messages.lock().await; - *session_msgs = new_messages; + *session_msgs = Conversation::new_unvalidated(new_messages); } // Persist the updated messages to the JSONL file diff --git a/crates/goose-cli/src/scenario_tests/message_generator.rs b/crates/goose-cli/src/scenario_tests/message_generator.rs index d8b463007915..7384e75927f6 100644 --- a/crates/goose-cli/src/scenario_tests/message_generator.rs +++ b/crates/goose-cli/src/scenario_tests/message_generator.rs @@ -4,7 +4,7 @@ use crate::scenario_tests::scenario_runner::SCENARIO_TESTS_DIR; use base64::engine::general_purpose; use base64::Engine; -use goose::message::Message; +use goose::conversation::message::Message; use goose::providers::base::Provider; pub type MessageGenerator<'a> = Box Message + 'a>; diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index 352ad13ed30d..4c5d9c84bb85 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -100,7 +100,7 @@ impl McpClientTrait for MockClient { is_error: None, structured_content: None, }), - Err(e) => Err(Error::UnexpectedResponse), + Err(_e) => Err(Error::UnexpectedResponse), } } else { Err(Error::UnexpectedResponse) diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 5ab1cd97a4df..b3aa3b67f9b9 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -1,4 +1,5 @@ use dotenvy::dotenv; +use goose::conversation::Conversation; use crate::scenario_tests::message_generator::MessageGenerator; use crate::scenario_tests::mock_client::weather_client; @@ -6,7 +7,6 @@ use crate::scenario_tests::provider_configs::{get_provider_configs, ProviderConf use crate::session::Session; use anyhow::Result; use goose::agents::Agent; -use goose::message::Message; use goose::model::ModelConfig; use goose::providers::{create, testprovider::TestProvider}; use std::collections::{HashMap, HashSet}; @@ -18,7 +18,7 @@ pub const SCENARIO_TESTS_DIR: &str = "src/scenario_tests"; #[derive(Debug, Clone)] pub struct ScenarioResult { - pub messages: Vec, + pub messages: Conversation, pub error: Option, } @@ -214,7 +214,7 @@ where break; } } - let updated_messages = session.message_history().to_vec(); + let updated_messages = session.message_history(); if let Some(ref err_msg) = error { if err_msg.contains("No recorded response found") { diff --git a/crates/goose-cli/src/scenario_tests/scenarios.rs b/crates/goose-cli/src/scenario_tests/scenarios.rs index 3be5668313c9..6b6c0ef99922 100644 --- a/crates/goose-cli/src/scenario_tests/scenarios.rs +++ b/crates/goose-cli/src/scenario_tests/scenarios.rs @@ -7,7 +7,7 @@ mod tests { use crate::scenario_tests::mock_client::WEATHER_TYPE; use crate::scenario_tests::scenario_runner::run_scenario; use anyhow::Result; - use goose::message::Message; + use goose::conversation::message::Message; #[tokio::test] async fn test_what_is_your_name() -> Result<()> { diff --git a/crates/goose-cli/src/session/export.rs b/crates/goose-cli/src/session/export.rs index 5c971402512f..1bbd444d9598 100644 --- a/crates/goose-cli/src/session/export.rs +++ b/crates/goose-cli/src/session/export.rs @@ -1,4 +1,4 @@ -use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use goose::conversation::message::{Message, MessageContent, ToolRequest, ToolResponse}; use goose::utils::safe_truncate; use rmcp::model::{RawContent, ResourceContents, Role}; use serde_json::Value; @@ -360,7 +360,7 @@ pub fn message_to_markdown(message: &Message, export_all_content: bool) -> Strin #[cfg(test)] mod tests { use super::*; - use goose::message::{Message, ToolRequest, ToolResponse}; + use goose::conversation::message::{Message, ToolRequest, ToolResponse}; use mcp_core::tool::ToolCall; use rmcp::model::{Content, RawTextContent, TextContent}; use serde_json::json; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 60c412222f1f..82f5328bf4f2 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -10,13 +10,13 @@ mod thinking; use crate::session::task_execution_display::{ format_task_execution_notification, TASK_EXECUTION_NOTIFICATION_TYPE, }; +use goose::conversation::Conversation; use std::io::Write; pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; use goose::agents::AgentEvent; -use goose::message::push_message; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::Permission; use goose::permission::PermissionConfirmation; @@ -31,7 +31,6 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::types::RetryConfig; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; -use goose::message::{Message, MessageContent}; use goose::providers::pricing::initialize_pricing_cache; use goose::session; use input::InputResult; @@ -39,6 +38,7 @@ use mcp_core::handler::ToolError; use rmcp::model::PromptMessage; use rmcp::model::ServerNotification; +use goose::conversation::message::{Message, MessageContent}; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; use serde_json::Value; @@ -56,7 +56,7 @@ pub enum RunMode { pub struct Session { agent: Agent, - messages: Vec, + messages: Conversation, session_file: Option, // Cache for completion data - using std::sync for thread safety without async completion_cache: Arc>, @@ -134,11 +134,11 @@ impl Session { let messages = if let Some(session_file) = &session_file { session::read_messages(session_file).unwrap_or_else(|e| { eprintln!("Warning: Failed to load message history: {}", e); - Vec::new() + Conversation::new_unvalidated(Vec::new()) }) } else { // Don't try to read messages if we're not saving sessions - Vec::new() + Conversation::new_unvalidated(Vec::new()) }; Session { @@ -157,12 +157,12 @@ impl Session { /// Helper function to summarize context messages async fn summarize_context_messages( - messages: &mut Vec, + messages: &mut Conversation, agent: &Agent, message_suffix: &str, ) -> Result<()> { // Summarize messages to fit within context length - let (summarized_messages, _) = agent.summarize_context(messages).await?; + let (summarized_messages, _) = agent.summarize_context(messages.messages()).await?; let msg = format!("Context maxed out\n{}\n{}", "-".repeat(50), message_suffix); output::render_text(&msg, Some(Color::Yellow), true); *messages = summarized_messages; @@ -719,8 +719,10 @@ impl Session { let provider = self.agent.provider().await?; // Call the summarize_context method which uses the summarize_messages function - let (summarized_messages, _) = - self.agent.summarize_context(&self.messages).await?; + let (summarized_messages, _) = self + .agent + .summarize_context(self.messages.messages()) + .await?; // Update the session messages with the summarized ones self.messages = summarized_messages; @@ -771,12 +773,14 @@ impl Session { async fn plan_with_reasoner_model( &mut self, - plan_messages: Vec, + plan_messages: Conversation, reasoner: Arc, ) -> Result<(), anyhow::Error> { let plan_prompt = self.agent.get_plan_prompt().await?; output::show_thinking(); - let (plan_response, _usage) = reasoner.complete(&plan_prompt, &plan_messages, &[]).await?; + let (plan_response, _usage) = reasoner + .complete(&plan_prompt, plan_messages.messages(), &[]) + .await?; output::render_message(&plan_response, self.debug); output::hide_thinking(); let planner_response_type = @@ -875,7 +879,11 @@ impl Session { }); let mut stream = self .agent - .reply(&self.messages, session_config.clone(), Some(cancel_token)) + .reply( + self.messages.clone(), + session_config.clone(), + Some(cancel_token), + ) .await?; let mut progress_bars = output::McpSpinners::new(); @@ -921,7 +929,7 @@ impl Session { confirmation.id.clone(), Err(ToolError::ExecutionError("Tool call cancelled by user".to_string())) )); - push_message(&mut self.messages, response_message); + self.messages.push(response_message); if let Some(session_file) = &self.session_file { let working_dir = std::env::current_dir().ok(); session::persist_messages_with_schedule_id( @@ -983,7 +991,7 @@ impl Session { } "truncate" => { // Truncate messages to fit within context length - let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?; + let (truncated_messages, _) = self.agent.truncate_context(self.messages.messages()).await?; let msg = if context_strategy == "truncate" { format!("Context maxed out - automatically truncated messages.\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50)) } else { @@ -1013,7 +1021,7 @@ impl Session { stream = self .agent .reply( - &self.messages, + self.messages.clone(), session_config.clone(), None ) @@ -1065,7 +1073,7 @@ impl Session { } } - push_message(&mut self.messages, message.clone()); + self.messages.push(message.clone()); // No need to update description on assistant messages if let Some(session_file) = &self.session_file { @@ -1193,7 +1201,7 @@ impl Session { } Some(Ok(AgentEvent::HistoryReplaced(new_messages))) => { // Replace the session's message history with the compacted messages - self.messages = new_messages; + self.messages = Conversation::new_unvalidated(new_messages); // Persist the updated messages to the session file if let Some(session_file) = &self.session_file { @@ -1414,7 +1422,7 @@ impl Session { cache.last_updated = Instant::now(); } - pub fn message_history(&self) -> Vec { + pub fn message_history(&self) -> Conversation { self.messages.clone() } @@ -1432,7 +1440,7 @@ impl Session { ); // Render each message - for message in &self.messages { + for message in self.messages.iter() { output::render_message(message, self.debug); } @@ -1612,7 +1620,7 @@ impl Session { } fn push_message(&mut self, message: Message) { - push_message(&mut self.messages, message); + self.messages.push(message); } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index c264505e0d68..dc8d70124e8f 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,7 +2,7 @@ use anstream::println; use bat::WrappingMode; use console::{style, Color}; use goose::config::Config; -use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use goose::conversation::message::{Message, MessageContent, ToolRequest, ToolResponse}; use goose::providers::pricing::get_model_pricing; use goose::providers::pricing::parse_model_id; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; @@ -16,6 +16,7 @@ use std::io::{Error, IsTerminal, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; + // Re-export theme for use in main #[derive(Clone, Copy)] pub enum Theme { diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index 38ea0e343e5c..d1c0305239c0 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -3,10 +3,6 @@ use goose::agents::extension::ToolInfo; use goose::agents::ExtensionConfig; use goose::config::permission::PermissionLevel; use goose::config::ExtensionEntry; -use goose::message::{ - ContextLengthExceeded, FrontendToolRequest, Message, MessageContent, RedactedThinkingContent, - SummarizationRequested, ThinkingContent, ToolConfirmationRequest, ToolRequest, ToolResponse, -}; use goose::permission::permission_confirmation::PrincipalType; use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata}; use goose::session::info::SessionInfo; @@ -17,6 +13,10 @@ use rmcp::model::{ }; use utoipa::{OpenApi, ToSchema}; +use goose::conversation::message::{ + ContextLengthExceeded, FrontendToolRequest, Message, MessageContent, RedactedThinkingContent, + SummarizationRequested, ThinkingContent, ToolConfirmationRequest, ToolRequest, ToolResponse, +}; use utoipa::openapi::schema::{ AdditionalProperties, AnyOfBuilder, ArrayBuilder, ObjectBuilder, OneOfBuilder, Schema, SchemaFormat, SchemaType, diff --git a/crates/goose-server/src/routes/context.rs b/crates/goose-server/src/routes/context.rs index 0630b607219f..eeff7fa452db 100644 --- a/crates/goose-server/src/routes/context.rs +++ b/crates/goose-server/src/routes/context.rs @@ -6,7 +6,7 @@ use axum::{ routing::post, Json, Router, }; -use goose::message::Message; +use goose::conversation::{message::Message, Conversation}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use utoipa::ToSchema; @@ -58,7 +58,7 @@ async fn manage_context( .await .map_err(|_| StatusCode::PRECONDITION_FAILED)?; - let mut processed_messages: Vec = vec![]; + let mut processed_messages = Conversation::new_unvalidated(vec![]); let mut token_counts: Vec = vec![]; if request.manage_action == "truncation" { @@ -74,7 +74,7 @@ async fn manage_context( } Ok(Json(ContextManageResponse { - messages: processed_messages, + messages: processed_messages.messages().clone(), token_counts, })) } diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index c2165fcfb9bf..532245159fe8 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; -use goose::message::Message; +use goose::conversation::{message::Message, Conversation}; use goose::recipe::Recipe; use goose::recipe_deeplink; use serde::{Deserialize, Serialize}; @@ -83,7 +83,9 @@ async fn create_recipe( .map_err(|_| (StatusCode::PRECONDITION_FAILED, Json(error_response)))?; // Create base recipe from agent state and messages - let recipe_result = agent.create_recipe(request.messages).await; + let recipe_result = agent + .create_recipe(Conversation::new_unvalidated(request.messages)) + .await; match recipe_result { Ok(mut recipe) => { diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index fb66e358f223..66480c6fb070 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -9,9 +9,10 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; +use goose::conversation::message::{Message, MessageContent}; +use goose::conversation::Conversation; use goose::{ agents::{AgentEvent, SessionConfig}, - message::{push_message, Message, MessageContent}, permission::permission_confirmation::PrincipalType, }; use goose::{ @@ -186,7 +187,7 @@ async fn reply_handler( let stream = ReceiverStream::new(rx); let cancel_token = CancellationToken::new(); - let messages = request.messages; + let messages = Conversation::new_unvalidated(request.messages); let session_working_dir = request.session_working_dir.clone(); let session_id = request @@ -221,12 +222,9 @@ async fn reply_handler( retry_config: None, }; - // Messages will be auto-compacted in agent.reply() if needed - let messages_to_process = messages.clone(); - let mut stream = match agent .reply( - &messages_to_process, + messages.clone(), Some(session_config), Some(task_cancel.clone()), ) @@ -279,15 +277,15 @@ async fn reply_handler( match response { Ok(Some(Ok(AgentEvent::Message(message)))) => { for content in &message.content { - track_tool_telemetry(content, &all_messages); - } + track_tool_telemetry(content, all_messages.messages()); + } - push_message(&mut all_messages, message.clone()); - stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; + all_messages.push(message.clone()); + stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; } Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { // Replace the message history with the compacted messages - all_messages = new_messages; + all_messages = Conversation::new_unvalidated(new_messages); // Note: We don't send this as a stream event since it's an internal operation // The client will see the compaction notification message that was sent before this event } @@ -518,6 +516,7 @@ pub fn routes(state: Arc) -> Router { #[cfg(test)] mod tests { use super::*; + use goose::conversation::message::Message; use goose::{ agents::Agent, model::ModelConfig, @@ -558,6 +557,7 @@ mod tests { mod integration_tests { use super::*; use axum::{body::Body, http::Request}; + use goose::conversation::message::Message; use std::sync::Arc; use tower::ServiceExt; diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index 5095cb38b309..c2c34e00574b 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -10,7 +10,7 @@ use axum::{ routing::{get, put}, Json, Router, }; -use goose::message::Message; +use goose::conversation::message::Message; use goose::session; use goose::session::info::{get_valid_sorted_sessions, SessionInfo, SortOrder}; use goose::session::SessionMetadata; @@ -137,7 +137,7 @@ async fn get_session_history( Ok(Json(SessionHistoryResponse { session_id, metadata, - messages, + messages: messages.messages().clone(), })) } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index fcd90b9b3f24..74c804c26f04 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -4,7 +4,8 @@ use dotenvy::dotenv; use futures::StreamExt; use goose::agents::{Agent, AgentEvent, ExtensionConfig}; use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT}; -use goose::message::Message; +use goose::conversation::message::Message; +use goose::conversation::Conversation; use goose::providers::databricks::DatabricksProvider; #[tokio::main] @@ -32,10 +33,11 @@ async fn main() { println!(" {}", extension); } - let messages = vec![Message::user() - .with_text("can you summarize the readme.md in this dir using just a haiku?")]; + let conversation = Conversation::new(vec![Message::user() + .with_text("can you summarize the readme.md in this dir using just a haiku?")]) + .unwrap(); - let mut stream = agent.reply(&messages, None, None).await.unwrap(); + let mut stream = agent.reply(conversation, None, None).await.unwrap(); while let Some(Ok(AgentEvent::Message(message))) = stream.next().await { println!("{}", serde_json::to_string_pretty(&message).unwrap()); println!("\n"); diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index e764007b0405..fc31e3c7efee 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -1,11 +1,9 @@ use anyhow::Result; use dotenvy::dotenv; -use goose::{ - message::Message, - providers::{ - base::{Provider, Usage}, - databricks::DatabricksProvider, - }, +use goose::conversation::message::Message; +use goose::providers::{ + base::{Provider, Usage}, + databricks::DatabricksProvider, }; use tokio_stream::StreamExt; diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index a021c7349f71..c12c1f820024 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -1,9 +1,9 @@ use anyhow::Result; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use dotenvy::dotenv; -use goose::{ - message::Message, - providers::{bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider}, +use goose::conversation::message::Message; +use goose::providers::{ + bedrock::BedrockProvider, databricks::DatabricksProvider, openai::OpenAiProvider, }; use mcp_core::tool::ToolCall; use rmcp::model::{Content, Tool}; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index a92e823f0f35..469eef7a8c26 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -34,7 +34,7 @@ use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::context_mgmt::auto_compact; -use crate::message::{push_message, Message, ToolRequest}; +use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; use crate::permission::permission_judge::{check_tool_permissions, PermissionCheckResult}; use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; @@ -56,13 +56,13 @@ use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; use crate::agents::subagent_task_config::TaskConfig; -use crate::conversation_fixer::{debug_conversation_fix, ConversationFixer}; +use crate::conversation::message::{Message, ToolRequest}; const DEFAULT_MAX_TURNS: u32 = 1000; /// Context needed for the reply function pub struct ReplyContext { - pub messages: Vec, + pub messages: Conversation, pub tools: Vec, pub toolshim_tools: Vec, pub system_prompt: String, @@ -199,7 +199,7 @@ impl Agent { /// Handle retry logic for the agent reply loop async fn handle_retry_logic( &self, - messages: &mut Vec, + messages: &mut Conversation, session: &Option, initial_messages: &[Message], ) -> Result { @@ -218,24 +218,29 @@ impl Agent { async fn prepare_reply_context( &self, - unfixed_messages: &[Message], + unfixed_conversation: Conversation, session: &Option, ) -> Result { - let (messages, issues) = ConversationFixer::fix_conversation(Vec::from(unfixed_messages)); + let unfixed_messages = unfixed_conversation.messages().clone(); + let (conversation, issues) = fix_conversation(unfixed_conversation.clone()); if !issues.is_empty() { tracing::warn!( "Conversation issue fixed: {}", - debug_conversation_fix(unfixed_messages, &messages, &issues) + debug_conversation_fix( + unfixed_messages.as_slice(), + conversation.messages(), + &issues + ) ); } - let initial_messages = messages.clone(); + let initial_messages = conversation.messages().clone(); let config = Config::global(); let (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; let goose_mode = Self::determine_goose_mode(session.as_ref(), config); Ok(ReplyContext { - messages, + messages: conversation, tools, toolshim_tools, system_prompt, @@ -760,7 +765,7 @@ impl Agent { &self, messages: &[Message], session: &Option, - ) -> Result, String)>> { + ) -> Result> { // Try to get session metadata for more accurate token counts let session_metadata = if let Some(session_config) = session { match session::storage::get_path(session_config.id.clone()) { @@ -802,22 +807,22 @@ impl Agent { Ok(None) } - #[instrument(skip(self, unfixed_messages, session), fields(user_message))] + #[instrument(skip(self, unfixed_conversation, session), fields(user_message))] pub async fn reply( &self, - unfixed_messages: &[Message], + unfixed_conversation: Conversation, session: Option, cancel_token: Option, ) -> Result>> { // Handle auto-compaction before processing let (messages, compaction_msg) = match self - .handle_auto_compaction(unfixed_messages, &session) + .handle_auto_compaction(unfixed_conversation.messages(), &session) .await? { Some((compacted_messages, msg)) => (compacted_messages, Some(msg)), None => { let context = self - .prepare_reply_context(unfixed_messages, &session) + .prepare_reply_context(unfixed_conversation, &session) .await?; (context.messages, None) } @@ -827,10 +832,10 @@ impl Agent { if let Some(compaction_msg) = compaction_msg { return Ok(Box::pin(async_stream::try_stream! { yield AgentEvent::Message(Message::assistant().with_text(compaction_msg)); - yield AgentEvent::HistoryReplaced(messages.clone()); + yield AgentEvent::HistoryReplaced(messages.messages().clone()); // Continue with normal reply processing using compacted messages - let mut reply_stream = self.reply_internal(&messages, session, cancel_token).await?; + let mut reply_stream = self.reply_internal(messages, session, cancel_token).await?; while let Some(event) = reply_stream.next().await { yield event?; } @@ -838,13 +843,13 @@ impl Agent { } // No compaction needed, proceed with normal processing - self.reply_internal(&messages, session, cancel_token).await + self.reply_internal(messages, session, cancel_token).await } /// Main reply method that handles the actual agent processing async fn reply_internal( &self, - messages: &[Message], + messages: Conversation, session: Option, cancel_token: Option, ) -> Result>> { @@ -905,7 +910,7 @@ impl Agent { let mut stream = Self::stream_response_from_provider( self.provider().await?, &system_prompt, - &messages, + messages.messages(), &tools, &toolshim_tools, ).await?; @@ -1074,8 +1079,8 @@ impl Agent { yield AgentEvent::Message(final_message_tool_resp.clone()); added_message = true; - push_message(&mut messages_to_add, response); - push_message(&mut messages_to_add, final_message_tool_resp); + messages_to_add.push(response); + messages_to_add.push(final_message_tool_resp); } } Err(ProviderError::ContextLengthExceeded(_)) => { @@ -1244,7 +1249,7 @@ impl Agent { } } - pub async fn create_recipe(&self, mut messages: Vec) -> Result { + pub async fn create_recipe(&self, mut messages: Conversation) -> Result { let extension_manager = self.extension_manager.read().await; let extensions_info = extension_manager.get_extensions_info().await; @@ -1273,7 +1278,7 @@ impl Agent { .await .as_ref() .unwrap() - .complete(&system_prompt, &messages, &tools) + .complete(&system_prompt, messages.messages(), &tools) .await?; let content = result.as_concat_text(); diff --git a/crates/goose/src/agents/context.rs b/crates/goose/src/agents/context.rs index 7ef1c267e3b7..9bce12353028 100644 --- a/crates/goose/src/agents/context.rs +++ b/crates/goose/src/agents/context.rs @@ -1,6 +1,7 @@ use anyhow::Ok; -use crate::message::Message; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::token_counter::create_async_token_counter; use crate::context_mgmt::summarize::summarize_messages_async; @@ -14,7 +15,7 @@ impl Agent { pub async fn truncate_context( &self, messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded - ) -> Result<(Vec, Vec), anyhow::Error> { + ) -> Result<(Conversation, Vec), anyhow::Error> { let provider = self.provider().await?; let token_counter = create_async_token_counter() .await @@ -51,7 +52,7 @@ impl Agent { pub async fn summarize_context( &self, messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded - ) -> Result<(Vec, Vec), anyhow::Error> { + ) -> Result<(Conversation, Vec), anyhow::Error> { let provider = self.provider().await?; let token_counter = create_async_token_counter() .await diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 46c134e0306e..085e00b1430e 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -5,8 +5,10 @@ use std::sync::Arc; use async_stream::try_stream; use futures::stream::StreamExt; +use super::super::agents::Agent; use crate::agents::router_tool_selector::RouterToolSelectionStrategy; -use crate::message::{Message, MessageContent, ToolRequest}; +use crate::conversation::message::{Message, MessageContent, ToolRequest}; +use crate::conversation::Conversation; use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; use crate::providers::errors::ProviderError; use crate::providers::toolshim::{ @@ -16,8 +18,6 @@ use crate::providers::toolshim::{ use crate::session; use rmcp::model::Tool; -use super::super::agents::Agent; - async fn toolshim_postprocess( response: Message, toolshim_tools: &[Tool], @@ -127,12 +127,12 @@ impl Agent { let messages_for_provider = if config.toolshim { convert_tool_messages_to_text(messages) } else { - messages.to_vec() + Conversation::new_unvalidated(messages.to_vec()) }; // Call the provider to get a response let (mut response, usage) = provider - .complete(system_prompt, &messages_for_provider, tools) + .complete(system_prompt, messages_for_provider.messages(), tools) .await?; crate::providers::base::set_current_model(&usage.model); @@ -159,7 +159,7 @@ impl Agent { let messages_for_provider = if config.toolshim { convert_tool_messages_to_text(messages) } else { - messages.to_vec() + Conversation::new_unvalidated(messages.to_vec()) }; // Clone owned data to move into the async stream @@ -170,11 +170,19 @@ impl Agent { let mut stream = if provider.supports_streaming() { provider - .stream(system_prompt.as_str(), &messages_for_provider, &tools) + .stream( + system_prompt.as_str(), + messages_for_provider.messages(), + &tools, + ) .await? } else { let (message, usage) = provider - .complete(system_prompt.as_str(), &messages_for_provider, &tools) + .complete( + system_prompt.as_str(), + messages_for_provider.messages(), + &tools, + ) .await?; stream_from_single_message(message, usage) }; diff --git a/crates/goose/src/agents/retry.rs b/crates/goose/src/agents/retry.rs index 20c52127ba0b..ac0e374da1fe 100644 --- a/crates/goose/src/agents/retry.rs +++ b/crates/goose/src/agents/retry.rs @@ -11,7 +11,8 @@ use crate::agents::types::{ RetryConfig, SuccessCheck, DEFAULT_ON_FAILURE_TIMEOUT_SECONDS, DEFAULT_RETRY_TIMEOUT_SECONDS, }; use crate::config::Config; -use crate::message::Message; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::tool_monitor::ToolMonitor; /// Result of a retry logic evaluation @@ -92,12 +93,11 @@ impl RetryManager { /// Reset status for retry: clear message history and final output tool state async fn reset_status_for_retry( - messages: &mut Vec, + messages: &mut Conversation, initial_messages: &[Message], final_output_tool: &Arc>>, ) { - messages.clear(); - messages.extend_from_slice(initial_messages); + *messages = Conversation::new_unvalidated(initial_messages.to_vec()); info!("Reset message history to initial state for retry"); if let Some(final_output_tool) = final_output_tool.lock().await.as_mut() { @@ -109,7 +109,7 @@ impl RetryManager { /// Handle retry logic for the agent reply loop pub async fn handle_retry_logic( &self, - messages: &mut Vec, + messages: &mut Conversation, session: &Option, initial_messages: &[Message], final_output_tool: &Arc>>, diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index 706144af245e..8dd4ac0aa89e 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use crate::agents::tool_vectordb::ToolVectorDB; -use crate::message::Message; +use crate::conversation::message::Message; use crate::model::ModelConfig; use crate::providers::{self, base::Provider}; diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 701d4aac60c0..dc8c43eb158a 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -3,7 +3,6 @@ use crate::{ agents::extension::ExtensionConfig, agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, config::ExtensionConfigManager, - message::{Message, MessageContent, ToolRequest}, prompt_template::render_global_file, providers::errors::ProviderError, }; @@ -13,6 +12,8 @@ use mcp_core::handler::ToolError; use rmcp::model::Tool; use serde::{Deserialize, Serialize}; // use serde_json::{self}; +use crate::conversation::message::{Message, MessageContent, ToolRequest}; +use crate::conversation::Conversation; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; use tokio_util::sync::CancellationToken; @@ -41,7 +42,7 @@ pub struct SubAgentProgress { /// A specialized agent that can handle specific tasks independently pub struct SubAgent { pub id: String, - pub conversation: Arc>>, + pub conversation: Arc>, pub status: Arc>, pub config: TaskConfig, pub turn_count: Arc>, @@ -80,7 +81,7 @@ impl SubAgent { let subagent = Arc::new(SubAgent { id: task_config.id.clone(), - conversation: Arc::new(Mutex::new(Vec::new())), + conversation: Arc::new(Mutex::new(Conversation::new_unvalidated(Vec::new()))), status: Arc::new(RwLock::new(SubAgentStatus::Ready)), config: task_config, turn_count: Arc::new(Mutex::new(0)), @@ -107,7 +108,7 @@ impl SubAgent { &self, message: String, task_config: TaskConfig, - ) -> Result, anyhow::Error> { + ) -> Result { debug!("Processing message for subagent {}", self.id); // Get provider from task config @@ -128,7 +129,10 @@ impl SubAgent { } // Get the current conversation for context - let mut messages = self.get_conversation().await; + let mut messages = { + let conversation = self.conversation.lock().await; + conversation.clone() + }; // Get tools from the subagent's own extension manager let tools: Vec = self @@ -156,7 +160,7 @@ impl SubAgent { match Agent::generate_response_from_provider( Arc::clone(provider), &system_prompt, - &messages, + messages.messages(), &tools, &toolshim_tools, ) @@ -264,11 +268,6 @@ impl SubAgent { conversation.push(message); } - /// Get the full conversation history - async fn get_conversation(&self) -> Vec { - self.conversation.lock().await.clone() - } - /// Build the system prompt for the subagent using the template async fn build_system_prompt(&self, available_tools: &[Tool]) -> Result { let mut context = HashMap::new(); diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index bf3b66cc4098..525000b6ea83 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -24,10 +24,10 @@ pub async fn run_complete_subagent_task( .flat_map(|message| { message.content.iter().filter_map(|content| { match content { - crate::message::MessageContent::Text(text_content) => { + crate::conversation::message::MessageContent::Text(text_content) => { Some(text_content.text.clone()) } - crate::message::MessageContent::ToolResponse(tool_response) => { + crate::conversation::message::MessageContent::ToolResponse(tool_response) => { // Extract text from tool response if let Ok(contents) = &tool_response.tool_result { let texts: Vec = contents diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index f80a03ca355c..045cdd0229dd 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -9,7 +9,6 @@ use tokio_util::sync::CancellationToken; use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; -use crate::message::{Message, ToolRequest}; use crate::permission::Permission; use mcp_core::ToolResult; use rmcp::model::{Content, ServerNotification}; @@ -32,6 +31,7 @@ impl From>> for ToolCallResult { use super::agent::{tool_stream, ToolStream}; use crate::agents::Agent; +use crate::conversation::message::{Message, ToolRequest}; pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \ DO NOT attempt to call this tool again. \ diff --git a/crates/goose/src/agents/tool_route_manager.rs b/crates/goose/src/agents/tool_route_manager.rs index 08a77157aa13..5829d582574c 100644 --- a/crates/goose/src/agents/tool_route_manager.rs +++ b/crates/goose/src/agents/tool_route_manager.rs @@ -7,7 +7,7 @@ use crate::agents::tool_execution::ToolCallResult; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::tool_vectordb::generate_table_id; use crate::config::Config; -use crate::message::ToolRequest; +use crate::conversation::message::ToolRequest; use crate::providers::base::Provider; use anyhow::{anyhow, Result}; use mcp_core::ToolError; diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs index 514089ed9f23..ffc970f9212a 100644 --- a/crates/goose/src/context_mgmt/auto_compact.rs +++ b/crates/goose/src/context_mgmt/auto_compact.rs @@ -1,3 +1,5 @@ +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::{ agents::Agent, config::Config, @@ -5,7 +7,6 @@ use crate::{ common::{SYSTEM_PROMPT_TOKEN_OVERHEAD, TOOLS_TOKEN_OVERHEAD}, get_messages_token_counts_async, }, - message::Message, token_counter::create_async_token_counter, }; use anyhow::Result; @@ -17,7 +18,7 @@ pub struct AutoCompactResult { /// Whether compaction was performed pub compacted: bool, /// The messages after potential compaction - pub messages: Vec, + pub messages: Conversation, /// Token count before compaction (if compaction occurred) pub tokens_before: Option, /// Token count after compaction (if compaction occurred) @@ -140,7 +141,7 @@ pub async fn check_compaction_needed( pub async fn perform_compaction( agent: &Agent, messages: &[Message], -) -> Result<(Vec, usize, usize)> { +) -> Result<(Conversation, usize, usize)> { // Get token counter to measure before/after let token_counter = create_async_token_counter() .await @@ -199,7 +200,7 @@ pub async fn check_and_compact_messages( ); return Ok(AutoCompactResult { compacted: false, - messages: messages.to_vec(), + messages: Conversation::new_unvalidated(messages.to_vec()), tokens_before: None, tokens_after: None, }); @@ -243,9 +244,9 @@ pub async fn check_and_compact_messages( #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::{Message, MessageContent}; use crate::{ agents::Agent, - message::{Message, MessageContent}, model::ModelConfig, providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}, providers::errors::ProviderError, diff --git a/crates/goose/src/context_mgmt/common.rs b/crates/goose/src/context_mgmt/common.rs index 49c4467ba3de..8bb5ff005279 100644 --- a/crates/goose/src/context_mgmt/common.rs +++ b/crates/goose/src/context_mgmt/common.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use rmcp::model::Tool; +use crate::conversation::message::Message; use crate::{ - message::Message, providers::base::Provider, token_counter::{AsyncTokenCounter, TokenCounter}, }; diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index c7a92fa2954a..0462d55fa4ff 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -1,5 +1,7 @@ -use super::common::{get_messages_token_counts, get_messages_token_counts_async}; -use crate::message::Message; +use super::common::get_messages_token_counts_async; +use crate::context_mgmt::get_messages_token_counts; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::prompt_template::render_global_file; use crate::providers::base::Provider; use crate::token_counter::{AsyncTokenCounter, TokenCounter}; @@ -23,13 +25,15 @@ async fn summarize_combined_messages( provider: &Arc, accumulated_summary: &[Message], current_chunk: &[Message], -) -> Result, anyhow::Error> { +) -> Result { // Combine the accumulated summary and current chunk into a single batch. - let combined_messages: Vec = accumulated_summary - .iter() - .cloned() - .chain(current_chunk.iter().cloned()) - .collect(); + let combined_messages = Conversation::new_unvalidated( + accumulated_summary + .iter() + .cloned() + .chain(current_chunk.iter().cloned()) + .collect::>(), + ); // Format the batch as a summarization request. let request_text = format!( @@ -47,7 +51,7 @@ async fn summarize_combined_messages( response.role = Role::User; // Return the summary as the new accumulated summary. - Ok(vec![response]) + Ok(Conversation::new_unvalidated(vec![response])) } // Summarization steps: @@ -57,10 +61,10 @@ pub async fn summarize_messages_oneshot( messages: &[Message], token_counter: &TokenCounter, _context_limit: usize, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { if messages.is_empty() { // If no messages to summarize, return empty - return Ok((vec![], vec![])); + return Ok((Conversation::empty(), vec![])); } // Format all messages as a single string for the summarization prompt @@ -92,12 +96,10 @@ pub async fn summarize_messages_oneshot( response.role = Role::User; // Return just the summary without any tool response preservation - let final_summary = vec![response]; + let final_summary = Conversation::new_unvalidated([response].into_iter()); + let counts = get_messages_token_counts(token_counter, final_summary.messages()); - Ok(( - final_summary.clone(), - get_messages_token_counts(token_counter, &final_summary), - )) + Ok((final_summary, counts)) } // Summarization steps: @@ -111,10 +113,10 @@ pub async fn summarize_messages_chunked( messages: &[Message], token_counter: &TokenCounter, context_limit: usize, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { let chunk_size = context_limit / 3; // 33% of the context window. let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT); - let mut accumulated_summary = Vec::new(); + let mut accumulated_summary = Conversation::empty(); // Get token counts for each message. let token_counts = get_messages_token_counts(token_counter, messages); @@ -126,9 +128,12 @@ pub async fn summarize_messages_chunked( for (message, message_tokens) in messages.iter().zip(token_counts.iter()) { if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens { // Summarize the current chunk with the accumulated summary. - accumulated_summary = - summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk) - .await?; + accumulated_summary = summarize_combined_messages( + &provider, + accumulated_summary.messages(), + ¤t_chunk, + ) + .await?; // Reset for the next chunk. current_chunk.clear(); @@ -143,13 +148,14 @@ pub async fn summarize_messages_chunked( // Summarize the final chunk if it exists. if !current_chunk.is_empty() { accumulated_summary = - summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk).await?; + summarize_combined_messages(&provider, accumulated_summary.messages(), ¤t_chunk) + .await?; } // Return just the summary without any tool response preservation Ok(( accumulated_summary.clone(), - get_messages_token_counts(token_counter, &accumulated_summary), + get_messages_token_counts(token_counter, accumulated_summary.messages()), )) } @@ -164,7 +170,7 @@ pub async fn summarize_messages( messages: &[Message], token_counter: &TokenCounter, context_limit: usize, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { // Calculate total tokens in messages let total_tokens: usize = get_messages_token_counts(token_counter, messages) .iter() @@ -207,24 +213,27 @@ pub async fn summarize_messages_async( messages: &[Message], token_counter: &AsyncTokenCounter, context_limit: usize, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { let chunk_size = context_limit / 3; // 33% of the context window. let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT); - let mut accumulated_summary = Vec::new(); + let mut accumulated_summary = Conversation::empty(); // Get token counts for each message. let token_counts = get_messages_token_counts_async(token_counter, messages); // Tokenize and break messages into chunks. - let mut current_chunk: Vec = Vec::new(); + let mut current_chunk = Vec::new(); let mut current_chunk_tokens = 0; for (message, message_tokens) in messages.iter().zip(token_counts.iter()) { if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens { // Summarize the current chunk with the accumulated summary. - accumulated_summary = - summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk) - .await?; + accumulated_summary = summarize_combined_messages( + &provider, + accumulated_summary.messages(), + ¤t_chunk, + ) + .await?; // Reset for the next chunk. current_chunk.clear(); @@ -239,22 +248,22 @@ pub async fn summarize_messages_async( // Summarize the final chunk if it exists. if !current_chunk.is_empty() { accumulated_summary = - summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk).await?; + summarize_combined_messages(&provider, accumulated_summary.messages(), ¤t_chunk) + .await?; } + let count = get_messages_token_counts_async(token_counter, accumulated_summary.messages()); + // Return just the summary without any tool response preservation - Ok(( - accumulated_summary.clone(), - get_messages_token_counts_async(token_counter, &accumulated_summary), - )) + Ok((accumulated_summary.clone(), count)) } #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent}; + use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; - use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; + use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; use chrono::Utc; use rmcp::model::Role; @@ -343,7 +352,7 @@ mod tests { "The summary should contain one message." ); assert_eq!( - summarized_messages[0].role, + summarized_messages.first().unwrap().role, Role::User, "The summarized message should be from the user." ); @@ -379,7 +388,7 @@ mod tests { "There should be one final summarized message." ); assert_eq!( - summarized_messages[0].role, + summarized_messages.first().unwrap().role, Role::User, "The summarized message should be from the user." ); @@ -551,7 +560,8 @@ mod tests { ); // Verify the content comes from the chunked approach - if let MessageContent::Text(text_content) = &summarized_messages[0].content[0] { + if let MessageContent::Text(text_content) = &summarized_messages.first().unwrap().content[0] + { assert_eq!(text_content.text, "Chunked summary"); } else { panic!("Expected text content"); @@ -592,7 +602,7 @@ mod tests { "One-shot should return a single summary message." ); assert_eq!( - summarized_messages[0].role, + summarized_messages.first().unwrap().role, Role::User, "Summary should be from user role for context." ); @@ -630,7 +640,7 @@ mod tests { "Chunked should return a single final summary." ); assert_eq!( - summarized_messages[0].role, + summarized_messages.first().unwrap().role, Role::User, "Summary should be from user role for context." ); diff --git a/crates/goose/src/context_mgmt/truncate.rs b/crates/goose/src/context_mgmt/truncate.rs index cb20500b53ea..c57e6d42f0ac 100644 --- a/crates/goose/src/context_mgmt/truncate.rs +++ b/crates/goose/src/context_mgmt/truncate.rs @@ -1,4 +1,5 @@ -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; +use crate::conversation::Conversation; use crate::utils::safe_truncate; use anyhow::{anyhow, Result}; use rmcp::model::{RawContent, ResourceContents, Role}; @@ -16,7 +17,7 @@ fn handle_oversized_messages( token_counts: &[usize], context_limit: usize, strategy: &dyn TruncationStrategy, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { let mut truncated_messages = Vec::new(); let mut truncated_token_counts = Vec::new(); let mut any_truncated = false; @@ -67,7 +68,10 @@ fn handle_oversized_messages( ); } - Ok((truncated_messages, truncated_token_counts)) + Ok(( + Conversation::new_unvalidated(truncated_messages), + truncated_token_counts, + )) } /// Truncates the content within a message while preserving its structure @@ -180,7 +184,7 @@ pub fn truncate_messages( token_counts: &[usize], context_limit: usize, strategy: &dyn TruncationStrategy, -) -> Result<(Vec, Vec), anyhow::Error> { +) -> Result<(Conversation, Vec), anyhow::Error> { let mut messages = messages.to_owned(); let mut token_counts = token_counts.to_owned(); @@ -221,7 +225,10 @@ pub fn truncate_messages( } if total_tokens <= context_limit { - return Ok((messages, token_counts)); // No truncation needed + return Ok(( + Conversation::new_unvalidated(messages.to_vec()), + token_counts.to_vec(), + )); // No truncation needed } // Step 2: Determine indices to remove based on strategy @@ -303,7 +310,10 @@ pub fn truncate_messages( } debug!("Truncation complete. Total tokens: {}", total_tokens); - Ok((messages, token_counts)) + Ok(( + Conversation::new_unvalidated(messages.to_vec()), + token_counts.to_vec(), + )) } /// Trait representing a truncation strategy @@ -378,7 +388,7 @@ impl TruncationStrategy for OldestFirstTruncation { #[cfg(test)] mod tests { use super::*; - use crate::message::Message; + use crate::conversation::message::Message; use anyhow::Result; use mcp_core::tool::ToolCall; use rmcp::model::Content; @@ -423,15 +433,13 @@ mod tests { num_pairs: usize, tokens: usize, remove_last: bool, - ) -> (Vec, Vec) { - let mut messages: Vec = (0..num_pairs) - .flat_map(|i| { - vec![ - user_text(i * 2, tokens).0, - assistant_text((i * 2) + 1, tokens).0, - ] - }) - .collect(); + ) -> (Conversation, Vec) { + let mut messages = Conversation::new_unvalidated((0..num_pairs).flat_map(|i| { + vec![ + user_text(i * 2, tokens).0, + assistant_text((i * 2) + 1, tokens).0, + ] + })); if remove_last { messages.pop(); @@ -498,13 +506,13 @@ mod tests { let context_limit = 25; let result = truncate_messages( - &messages, + &messages.messages(), &token_counts, context_limit, &OldestFirstTruncation, )?; - assert_eq!(result.0, messages); + assert_eq!(result.0.messages(), messages.messages()); assert_eq!(result.1, token_counts); Ok(()) } @@ -582,7 +590,7 @@ mod tests { let context_limit = 100; // Exactly matches total tokens let result = truncate_messages( - &messages, + &messages.messages(), &token_counts, context_limit, &OldestFirstTruncation, @@ -597,7 +605,7 @@ mod tests { token_counts.push(1); let result = truncate_messages( - &messages, + &messages.messages(), &token_counts, context_limit, &OldestFirstTruncation, @@ -702,7 +710,7 @@ mod tests { // Test impossibly small context window let (messages, token_counts) = create_messages_with_counts(1, 10, false); let result = truncate_messages( - &messages, + &messages.messages(), &token_counts, 5, // Impossibly small context &OldestFirstTruncation, diff --git a/crates/goose/src/message.rs b/crates/goose/src/conversation/message.rs similarity index 95% rename from crates/goose/src/message.rs rename to crates/goose/src/conversation/message.rs index a6194035d4d3..2206a75f60b5 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/conversation/message.rs @@ -1,18 +1,8 @@ -/// Messages which represent the content sent back and forth to LLM provider -/// -/// We use these messages in the agent code, and interfaces which interact with -/// the agent. That let's us reuse message histories across different interfaces. -/// -/// The content of the messages uses MCP types to avoid additional conversions -/// when interacting with MCP servers. use chrono::Utc; -use mcp_core::handler::ToolResult; -use mcp_core::tool::ToolCall; -use rmcp::model::ResourceContents; -use rmcp::model::Role; +use mcp_core::{ToolCall, ToolResult}; use rmcp::model::{ AnnotateAble, Content, ImageContent, PromptMessage, PromptMessageContent, PromptMessageRole, - RawContent, RawImageContent, RawTextContent, TextContent, + RawContent, RawImageContent, RawTextContent, ResourceContents, Role, TextContent, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -20,7 +10,7 @@ use std::collections::HashSet; use std::fmt; use utoipa::ToSchema; -mod tool_result_serde; +use crate::conversation::tool_result_serde; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -371,26 +361,6 @@ impl fmt::Debug for Message { } } -pub fn push_message(messages: &mut Vec, message: Message) { - if let Some(last) = messages - .last_mut() - .filter(|m| m.id.is_some() && m.id == message.id) - { - match (last.content.last_mut(), message.content.last()) { - (Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new))) - if message.content.len() == 1 => - { - last.text.push_str(&new.text); - } - (_, _) => { - last.content.extend(message.content); - } - } - } else { - messages.push(message); - } -} - fn default_created() -> i64 { 0 // old messages do not have timestamps. } @@ -585,9 +555,14 @@ impl Message { #[cfg(test)] mod tests { - use super::*; + use crate::conversation::message::{Message, MessageContent}; + use crate::conversation::*; use mcp_core::handler::ToolError; - use rmcp::model::{PromptMessage, PromptMessageContent, RawEmbeddedResource, ResourceContents}; + use mcp_core::ToolCall; + use rmcp::model::{ + AnnotateAble, PromptMessage, PromptMessageContent, PromptMessageRole, RawEmbeddedResource, + RawImageContent, ResourceContents, + }; use serde_json::{json, Value}; #[test] diff --git a/crates/goose/src/conversation/mod.rs b/crates/goose/src/conversation/mod.rs new file mode 100644 index 000000000000..f1dbfa1e75b8 --- /dev/null +++ b/crates/goose/src/conversation/mod.rs @@ -0,0 +1,534 @@ +use crate::conversation::message::{Message, MessageContent}; +use rmcp::model::Role; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use thiserror::Error; + +pub mod message; +mod tool_result_serde; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Conversation(Vec); + +#[derive(Error, Debug)] +#[error("invalid conversation: {reason}")] +pub struct InvalidConversation { + reason: String, + conversation: Conversation, +} + +impl Conversation { + pub fn new(messages: I) -> Result + where + I: IntoIterator, + { + Self::new_unvalidated(messages).validate() + } + + pub fn new_unvalidated(messages: I) -> Self + where + I: IntoIterator, + { + Self(messages.into_iter().collect()) + } + + pub fn empty() -> Self { + Self::new_unvalidated([]) + } + + pub fn messages(&self) -> &Vec { + &self.0 + } + + pub fn push(&mut self, message: Message) { + if let Some(last) = self + .0 + .last_mut() + .filter(|m| m.id.is_some() && m.id == message.id) + { + match (last.content.last_mut(), message.content.last()) { + (Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new))) + if message.content.len() == 1 => + { + last.text.push_str(&new.text); + } + (_, _) => { + last.content.extend(message.content); + } + } + } else { + self.0.push(message); + } + } + + pub fn last(&self) -> Option<&Message> { + self.0.last() + } + + pub fn first(&self) -> Option<&Message> { + self.0.first() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + for message in iter { + self.push(message); + } + } + + pub fn iter(&self) -> std::slice::Iter { + self.0.iter() + } + + pub fn pop(&mut self) -> Option { + self.0.pop() + } + + pub fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } + + pub fn clear(&mut self) { + self.0.clear(); + } + + fn validate(self) -> Result { + let (_messages, issues) = fix_messages(self.0.clone()); + if !issues.is_empty() { + let reason = issues.join("\n"); + Err(InvalidConversation { + reason, + conversation: self, + }) + } else { + Ok(self) + } + } +} + +impl Default for Conversation { + fn default() -> Self { + Self::empty() + } +} + +/// Fix a conversation that we're about to send to an LLM. So the last and first +/// messages should always be from the user. +pub fn fix_conversation(conversation: Conversation) -> (Conversation, Vec) { + let messages = conversation.messages().clone(); + let (messages, issues) = fix_messages(messages); + (Conversation::new_unvalidated(messages), issues) +} + +fn fix_messages(messages: Vec) -> (Vec, Vec) { + let (messages_1, empty_removed) = remove_empty_messages(messages); + let (messages_2, tool_calling_fixed) = fix_tool_calling(messages_1); + let (messages_3, messages_merged) = merge_consecutive_messages(messages_2); + let (messages_4, lead_trail_fixed) = fix_lead_trail(messages_3); + let (messages_5, populated_if_empty) = populate_if_empty(messages_4); + + let mut issues = Vec::new(); + issues.extend(empty_removed); + issues.extend(tool_calling_fixed); + issues.extend(messages_merged); + issues.extend(lead_trail_fixed); + issues.extend(populated_if_empty); + + (messages_5, issues) +} + +fn remove_empty_messages(messages: Vec) -> (Vec, Vec) { + let mut issues = Vec::new(); + let filtered_messages = messages + .into_iter() + .filter(|msg| { + if msg.content.is_empty() { + issues.push("Removed empty message".to_string()); + false + } else { + true + } + }) + .collect(); + (filtered_messages, issues) +} + +fn fix_tool_calling(mut messages: Vec) -> (Vec, Vec) { + let mut issues = Vec::new(); + let mut pending_tool_requests: HashSet = HashSet::new(); + + for message in &mut messages { + let mut content_to_remove = Vec::new(); + + match message.role { + Role::User => { + for (idx, content) in message.content.iter().enumerate() { + match content { + MessageContent::ToolRequest(req) => { + content_to_remove.push(idx); + issues.push(format!( + "Removed tool request '{}' from user message", + req.id + )); + } + MessageContent::ToolConfirmationRequest(req) => { + content_to_remove.push(idx); + issues.push(format!( + "Removed tool confirmation request '{}' from user message", + req.id + )); + } + MessageContent::Thinking(_) | MessageContent::RedactedThinking(_) => { + content_to_remove.push(idx); + issues.push("Removed thinking content from user message".to_string()); + } + MessageContent::ToolResponse(resp) => { + if pending_tool_requests.contains(&resp.id) { + pending_tool_requests.remove(&resp.id); + } else { + content_to_remove.push(idx); + issues + .push(format!("Removed orphaned tool response '{}'", resp.id)); + } + } + _ => {} + } + } + } + Role::Assistant => { + for (idx, content) in message.content.iter().enumerate() { + match content { + MessageContent::ToolResponse(resp) => { + content_to_remove.push(idx); + issues.push(format!( + "Removed tool response '{}' from assistant message", + resp.id + )); + } + MessageContent::FrontendToolRequest(req) => { + content_to_remove.push(idx); + issues.push(format!( + "Removed frontend tool request '{}' from assistant message", + req.id + )); + } + MessageContent::ToolRequest(req) => { + pending_tool_requests.insert(req.id.clone()); + } + _ => {} + } + } + } + } + + for &idx in content_to_remove.iter().rev() { + message.content.remove(idx); + } + } + + for message in &mut messages { + if message.role == Role::Assistant { + let mut content_to_remove = Vec::new(); + for (idx, content) in message.content.iter().enumerate() { + if let MessageContent::ToolRequest(req) = content { + if pending_tool_requests.contains(&req.id) { + content_to_remove.push(idx); + issues.push(format!("Removed orphaned tool request '{}'", req.id)); + } + } + } + for &idx in content_to_remove.iter().rev() { + message.content.remove(idx); + } + } + } + let (messages, empty_removed) = remove_empty_messages(messages); + issues.extend(empty_removed); + (messages, issues) +} + +fn merge_consecutive_messages(messages: Vec) -> (Vec, Vec) { + let mut issues = Vec::new(); + let mut merged_messages: Vec = Vec::new(); + + for message in messages { + if let Some(last) = merged_messages.last_mut() { + let effective = effective_role(&message); + if effective_role(last) == effective { + last.content.extend(message.content); + issues.push(format!("Merged consecutive {} messages", effective)); + continue; + } + } + merged_messages.push(message); + } + + (merged_messages, issues) +} + +fn has_tool_response(message: &Message) -> bool { + message + .content + .iter() + .any(|content| matches!(content, MessageContent::ToolResponse(_))) +} + +fn effective_role(message: &Message) -> String { + if message.role == Role::User && has_tool_response(message) { + "tool".to_string() + } else { + match message.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + } + } +} + +fn fix_lead_trail(mut messages: Vec) -> (Vec, Vec) { + let mut issues = Vec::new(); + + if let Some(first) = messages.first() { + if first.role == Role::Assistant { + messages.remove(0); + issues.push("Removed leading assistant message".to_string()); + } + } + + if let Some(last) = messages.last() { + if last.role == Role::Assistant { + messages.pop(); + issues.push("Removed trailing assistant message".to_string()); + } + } + + (messages, issues) +} + +const PLACEHOLDER_USER_MESSAGE: &str = "Hello"; + +fn populate_if_empty(mut messages: Vec) -> (Vec, Vec) { + let mut issues = Vec::new(); + + if messages.is_empty() { + issues.push("Added placeholder user message to empty conversation".to_string()); + messages.push(Message::user().with_text(PLACEHOLDER_USER_MESSAGE)); + } + (messages, issues) +} + +pub fn debug_conversation_fix( + messages: &[Message], + fixed: &[Message], + issues: &[String], +) -> String { + let mut output = String::new(); + + output.push_str("=== CONVERSATION FIX DEBUG ===\n\n"); + + output.push_str("BEFORE:\n"); + for (i, msg) in messages.iter().enumerate() { + output.push_str(&format!(" [{}] {}\n", i, msg.debug())); + } + + output.push_str("\nISSUES FOUND:\n"); + if issues.is_empty() { + output.push_str(" (none)\n"); + } else { + for issue in issues { + output.push_str(&format!(" - {}\n", issue)); + } + } + + output.push_str("\nAFTER:\n"); + for (i, msg) in fixed.iter().enumerate() { + output.push_str(&format!(" [{}] {}\n", i, msg.debug())); + } + + output.push_str("\n==============================\n"); + output +} + +#[cfg(test)] +mod tests { + use crate::conversation::message::Message; + use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; + use mcp_core::tool::ToolCall; + use rmcp::model::Role; + use serde_json::json; + + fn run_verify(messages: Vec) -> (Vec, Vec) { + let (fixed, issues) = fix_conversation(Conversation::new_unvalidated(messages.clone())); + + // Uncomment the following line to print the debug report + // let report = debug_conversation_fix(&messages, &fixed, &issues); + // print!("\n{}", report); + + let (_fixed, issues_with_fixed) = fix_conversation(fixed.clone()); + assert_eq!( + issues_with_fixed.len(), + 0, + "Fixed conversation should have no issues, but found: {:?}\n\n{}", + issues_with_fixed, + debug_conversation_fix(&messages, &fixed.messages(), &issues) + ); + (fixed.messages().clone(), issues) + } + + #[test] + fn test_valid_conversation() { + let all_messages = vec![ + Message::user().with_text("Can you help me search for something?"), + Message::assistant() + .with_text("I'll help you search.") + .with_tool_request( + "search_1", + Ok(ToolCall::new( + "web_search", + json!({"query": "rust programming"}), + )), + ), + Message::user().with_tool_response("search_1", Ok(vec![])), + Message::assistant().with_text("Based on the search results, here's what I found..."), + ]; + + for i in 1..=all_messages.len() { + let messages = Conversation::new_unvalidated(all_messages[..i].to_vec()); + if messages.last().unwrap().role == Role::User { + let (fixed, issues) = fix_conversation(messages.clone()); + assert_eq!( + fixed.len(), + messages.len(), + "Step {}: Length should match", + i + ); + assert!( + issues.is_empty(), + "Step {}: Should have no issues, but found: {:?}", + i, + issues + ); + assert_eq!( + fixed.messages(), + messages.messages(), + "Step {}: Messages should be unchanged", + i + ); + } + } + } + + #[test] + fn test_role_alternation_and_content_placement_issues() { + let messages = vec![ + Message::user().with_text("Hello"), + Message::user().with_text("Another user message"), + Message::assistant() + .with_text("Response") + .with_tool_response("orphan_1", Ok(vec![])), // Wrong role + Message::assistant().with_thinking("Let me think", "sig"), + Message::user() + .with_tool_request("bad_req", Ok(ToolCall::new("search", json!({})))) + .with_text("User with bad tool request"), + ]; + + let (fixed, issues) = run_verify(messages); + + assert_eq!(fixed.len(), 3); + assert_eq!(issues.len(), 4); + + assert!(issues + .iter() + .any(|i| i.contains("Merged consecutive user messages"))); + assert!(issues + .iter() + .any(|i| i.contains("Removed tool response 'orphan_1' from assistant message"))); + assert!(issues + .iter() + .any(|i| i.contains("Removed tool request 'bad_req' from user message"))); + + assert_eq!(fixed[0].role, Role::User); + assert_eq!(fixed[1].role, Role::Assistant); + assert_eq!(fixed[2].role, Role::User); + + assert_eq!(fixed[0].content.len(), 2); + } + + #[test] + fn test_orphaned_tools_and_empty_messages() { + // This conversation completely collapses. the first user message is invalid + // then we remove the empty user message and the wrong tool response + // then we collapse the assistant messages + // which we then remove because you can't end a conversation with an assistant message + let messages = vec![ + Message::assistant() + .with_text("I'll search for you") + .with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))), + Message::user(), + Message::user().with_tool_response("wrong_id", Ok(vec![])), + Message::assistant() + .with_tool_request("search_2", Ok(ToolCall::new("search", json!({})))), + ]; + + let (fixed, issues) = run_verify(messages); + + assert_eq!(fixed.len(), 1); + + assert!(issues.iter().any(|i| i.contains("Removed empty message"))); + assert!(issues + .iter() + .any(|i| i.contains("Removed orphaned tool response 'wrong_id'"))); + + assert_eq!(fixed[0].role, Role::User); + assert_eq!(fixed[0].as_concat_text(), "Hello"); + } + + #[test] + fn test_real_world_consecutive_assistant_messages() { + let conversation = Conversation::new_unvalidated(vec![ + Message::user().with_text("run ls in the current directory and then run a word count on the smallest file"), + Message::assistant() + .with_text("I'll help you run `ls` in the current directory and then perform a word count on the smallest file. Let me start by listing the directory contents.") + .with_tool_request("toolu_bdrk_018adWbP4X26CfoJU5hkhu3i", Ok(ToolCall::new("developer__shell", json!({"command": "ls -la"})))), + Message::assistant() + .with_text("Now I'll identify the smallest file by size. Looking at the output, I can see that both `slack.yaml` and `subrecipes.yaml` have a size of 0 bytes, making them the smallest files. I'll run a word count on one of them:") + .with_tool_request("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(ToolCall::new("developer__shell", json!({"command": "wc slack.yaml"})))), + Message::user() + .with_tool_response("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(vec![])), + Message::assistant() + .with_text("I ran `ls -la` in the current directory and found several files. Looking at the file sizes, I can see that both `slack.yaml` and `subrecipes.yaml` are 0 bytes (the smallest files). I ran a word count on `slack.yaml` which shows: **0 lines**, **0 words**, **0 characters**"), + Message::user().with_text("thanks!"), + ]); + + let (fixed, issues) = fix_conversation(conversation); + + assert_eq!(fixed.len(), 5); + assert_eq!(issues.len(), 2); + assert!(issues[0].contains("Removed orphaned tool request")); + assert!(issues[1].contains("Merged consecutive assistant messages")); + } + + #[test] + fn test_tool_response_effective_role() { + let messages = vec![ + Message::user().with_text("Search for something"), + Message::assistant() + .with_text("I'll search for you") + .with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))), + Message::user().with_tool_response("search_1", Ok(vec![])), + Message::user().with_text("Thanks!"), + ]; + + let (_fixed, issues) = run_verify(messages); + assert_eq!(issues.len(), 0); + } +} diff --git a/crates/goose/src/message/tool_result_serde.rs b/crates/goose/src/conversation/tool_result_serde.rs similarity index 100% rename from crates/goose/src/message/tool_result_serde.rs rename to crates/goose/src/conversation/tool_result_serde.rs diff --git a/crates/goose/src/conversation_fixer.rs b/crates/goose/src/conversation_fixer.rs deleted file mode 100644 index 1a585ae699d1..000000000000 --- a/crates/goose/src/conversation_fixer.rs +++ /dev/null @@ -1,408 +0,0 @@ -use crate::message::{Message, MessageContent}; -use rmcp::model::Role; -use std::collections::HashSet; - -pub struct ConversationFixer; - -const PLACEHOLDER_USER_MESSAGE: &str = "Hello"; - -impl ConversationFixer { - /// Fix a conversation that we're about to send to an LLM. So the last and first - /// messages should always be from the user. - pub fn fix_conversation(messages: Vec) -> (Vec, Vec) { - let (messages_1, empty_removed) = Self::remove_empty_messages(messages); - let (messages_2, tool_calling_fixed) = Self::fix_tool_calling(messages_1); - let (messages_3, messages_merged) = Self::merge_consecutive_messages(messages_2); - let (messages_4, lead_trail_fixed) = Self::fix_lead_trail(messages_3); - let (messages_5, populated_if_empty) = Self::populate_if_empty(messages_4); - - let mut issues = Vec::new(); - issues.extend(empty_removed); - issues.extend(tool_calling_fixed); - issues.extend(messages_merged); - issues.extend(lead_trail_fixed); - issues.extend(populated_if_empty); - - (messages_5, issues) - } - - fn remove_empty_messages(messages: Vec) -> (Vec, Vec) { - let mut issues = Vec::new(); - let filtered_messages = messages - .into_iter() - .filter(|msg| { - if msg.content.is_empty() { - issues.push("Removed empty message".to_string()); - false - } else { - true - } - }) - .collect(); - (filtered_messages, issues) - } - - fn fix_tool_calling(mut messages: Vec) -> (Vec, Vec) { - let mut issues = Vec::new(); - let mut pending_tool_requests: HashSet = HashSet::new(); - - for message in &mut messages { - let mut content_to_remove = Vec::new(); - - match message.role { - Role::User => { - for (idx, content) in message.content.iter().enumerate() { - match content { - MessageContent::ToolRequest(req) => { - content_to_remove.push(idx); - issues.push(format!( - "Removed tool request '{}' from user message", - req.id - )); - } - MessageContent::ToolConfirmationRequest(req) => { - content_to_remove.push(idx); - issues.push(format!( - "Removed tool confirmation request '{}' from user message", - req.id - )); - } - MessageContent::Thinking(_) | MessageContent::RedactedThinking(_) => { - content_to_remove.push(idx); - issues - .push("Removed thinking content from user message".to_string()); - } - MessageContent::ToolResponse(resp) => { - if pending_tool_requests.contains(&resp.id) { - pending_tool_requests.remove(&resp.id); - } else { - content_to_remove.push(idx); - issues.push(format!( - "Removed orphaned tool response '{}'", - resp.id - )); - } - } - _ => {} - } - } - } - Role::Assistant => { - for (idx, content) in message.content.iter().enumerate() { - match content { - MessageContent::ToolResponse(resp) => { - content_to_remove.push(idx); - issues.push(format!( - "Removed tool response '{}' from assistant message", - resp.id - )); - } - MessageContent::FrontendToolRequest(req) => { - content_to_remove.push(idx); - issues.push(format!( - "Removed frontend tool request '{}' from assistant message", - req.id - )); - } - MessageContent::ToolRequest(req) => { - pending_tool_requests.insert(req.id.clone()); - } - _ => {} - } - } - } - } - - for &idx in content_to_remove.iter().rev() { - message.content.remove(idx); - } - } - - for message in &mut messages { - if message.role == Role::Assistant { - let mut content_to_remove = Vec::new(); - for (idx, content) in message.content.iter().enumerate() { - if let MessageContent::ToolRequest(req) = content { - if pending_tool_requests.contains(&req.id) { - content_to_remove.push(idx); - issues.push(format!("Removed orphaned tool request '{}'", req.id)); - } - } - } - for &idx in content_to_remove.iter().rev() { - message.content.remove(idx); - } - } - } - let (messages, empty_removed) = Self::remove_empty_messages(messages); - issues.extend(empty_removed); - (messages, issues) - } - - fn merge_consecutive_messages(messages: Vec) -> (Vec, Vec) { - let mut issues = Vec::new(); - let mut merged_messages: Vec = Vec::new(); - - for message in messages { - if let Some(last) = merged_messages.last_mut() { - let effective = Self::effective_role(&message); - if Self::effective_role(last) == effective { - last.content.extend(message.content); - issues.push(format!("Merged consecutive {} messages", effective)); - continue; - } - } - merged_messages.push(message); - } - - (merged_messages, issues) - } - - fn has_tool_response(message: &Message) -> bool { - message - .content - .iter() - .any(|content| matches!(content, MessageContent::ToolResponse(_))) - } - - fn effective_role(message: &Message) -> String { - if message.role == Role::User && Self::has_tool_response(message) { - "tool".to_string() - } else { - match message.role { - Role::User => "user".to_string(), - Role::Assistant => "assistant".to_string(), - } - } - } - - fn fix_lead_trail(mut messages: Vec) -> (Vec, Vec) { - let mut issues = Vec::new(); - - if let Some(first) = messages.first() { - if first.role == Role::Assistant { - messages.remove(0); - issues.push("Removed leading assistant message".to_string()); - } - } - - if let Some(last) = messages.last() { - if last.role == Role::Assistant { - messages.pop(); - issues.push("Removed trailing assistant message".to_string()); - } - } - - (messages, issues) - } - - fn populate_if_empty(mut messages: Vec) -> (Vec, Vec) { - let mut issues = Vec::new(); - - if messages.is_empty() { - issues.push("Added placeholder user message to empty conversation".to_string()); - messages.push(Message::user().with_text(PLACEHOLDER_USER_MESSAGE)); - } - (messages, issues) - } -} - -pub fn debug_conversation_fix( - messages: &[Message], - fixed: &[Message], - issues: &[String], -) -> String { - let mut output = String::new(); - - output.push_str("=== CONVERSATION FIX DEBUG ===\n\n"); - - output.push_str("BEFORE:\n"); - for (i, msg) in messages.iter().enumerate() { - output.push_str(&format!(" [{}] {}\n", i, msg.debug())); - } - - output.push_str("\nISSUES FOUND:\n"); - if issues.is_empty() { - output.push_str(" (none)\n"); - } else { - for issue in issues { - output.push_str(&format!(" - {}\n", issue)); - } - } - - output.push_str("\nAFTER:\n"); - for (i, msg) in fixed.iter().enumerate() { - output.push_str(&format!(" [{}] {}\n", i, msg.debug())); - } - - output.push_str("\n==============================\n"); - output -} - -#[cfg(test)] -mod tests { - use super::*; - use mcp_core::tool::ToolCall; - use serde_json::json; - - fn run_verify(messages: Vec) -> (Vec, Vec) { - let (fixed, issues) = ConversationFixer::fix_conversation(messages.clone()); - - // Uncomment the following line to print the debug report - // let report = debug_conversation_fix(&messages, &fixed, &issues); - // print!("\n{}", report); - - let (_fixed, issues_with_fixed) = ConversationFixer::fix_conversation(fixed.clone()); - assert_eq!( - issues_with_fixed.len(), - 0, - "Fixed conversation should have no issues, but found: {:?}\n\n{}", - issues_with_fixed, - debug_conversation_fix(&messages, &fixed, &issues) - ); - (fixed, issues) - } - - #[test] - fn test_valid_conversation() { - let all_messages = vec![ - Message::user().with_text("Can you help me search for something?"), - Message::assistant() - .with_text("I'll help you search.") - .with_tool_request( - "search_1", - Ok(ToolCall::new( - "web_search", - json!({"query": "rust programming"}), - )), - ), - Message::user().with_tool_response("search_1", Ok(vec![])), - Message::assistant().with_text("Based on the search results, here's what I found..."), - ]; - - for i in 1..=all_messages.len() { - let messages = all_messages[..i].to_vec(); - if messages.last().unwrap().role == Role::User { - let (fixed, issues) = ConversationFixer::fix_conversation(messages.clone()); - assert_eq!( - fixed.len(), - messages.len(), - "Step {}: Length should match", - i - ); - assert!( - issues.is_empty(), - "Step {}: Should have no issues, but found: {:?}", - i, - issues - ); - assert_eq!(fixed, messages, "Step {}: Messages should be unchanged", i); - } - } - } - - #[test] - fn test_role_alternation_and_content_placement_issues() { - let messages = vec![ - Message::user().with_text("Hello"), - Message::user().with_text("Another user message"), - Message::assistant() - .with_text("Response") - .with_tool_response("orphan_1", Ok(vec![])), // Wrong role - Message::assistant().with_thinking("Let me think", "sig"), - Message::user() - .with_tool_request("bad_req", Ok(ToolCall::new("search", json!({})))) - .with_text("User with bad tool request"), - ]; - - let (fixed, issues) = run_verify(messages); - - assert_eq!(fixed.len(), 3); - assert_eq!(issues.len(), 4); - - assert!(issues - .iter() - .any(|i| i.contains("Merged consecutive user messages"))); - assert!(issues - .iter() - .any(|i| i.contains("Removed tool response 'orphan_1' from assistant message"))); - assert!(issues - .iter() - .any(|i| i.contains("Removed tool request 'bad_req' from user message"))); - - assert_eq!(fixed[0].role, Role::User); - assert_eq!(fixed[1].role, Role::Assistant); - assert_eq!(fixed[2].role, Role::User); - - assert_eq!(fixed[0].content.len(), 2); - } - - #[test] - fn test_orphaned_tools_and_empty_messages() { - // This conversation completely collapses. the first user message is invalid - // then we remove the empty user message and the wrong tool response - // then we collapse the assistant messages - // which we then remove because you can't end a conversation with an assistant message - let messages = vec![ - Message::assistant() - .with_text("I'll search for you") - .with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))), - Message::user(), - Message::user().with_tool_response("wrong_id", Ok(vec![])), - Message::assistant() - .with_tool_request("search_2", Ok(ToolCall::new("search", json!({})))), - ]; - - let (fixed, issues) = run_verify(messages); - - assert_eq!(fixed.len(), 1); - - assert!(issues.iter().any(|i| i.contains("Removed empty message"))); - assert!(issues - .iter() - .any(|i| i.contains("Removed orphaned tool response 'wrong_id'"))); - - assert_eq!(fixed[0].role, Role::User); - assert_eq!(fixed[0].as_concat_text(), "Hello"); - } - - #[test] - fn test_real_world_consecutive_assistant_messages() { - let messages = vec![ - Message::user().with_text("run ls in the current directory and then run a word count on the smallest file"), - Message::assistant() - .with_text("I'll help you run `ls` in the current directory and then perform a word count on the smallest file. Let me start by listing the directory contents.") - .with_tool_request("toolu_bdrk_018adWbP4X26CfoJU5hkhu3i", Ok(ToolCall::new("developer__shell", json!({"command": "ls -la"})))), - Message::assistant() - .with_text("Now I'll identify the smallest file by size. Looking at the output, I can see that both `slack.yaml` and `subrecipes.yaml` have a size of 0 bytes, making them the smallest files. I'll run a word count on one of them:") - .with_tool_request("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(ToolCall::new("developer__shell", json!({"command": "wc slack.yaml"})))), - Message::user() - .with_tool_response("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(vec![])), - Message::assistant() - .with_text("I ran `ls -la` in the current directory and found several files. Looking at the file sizes, I can see that both `slack.yaml` and `subrecipes.yaml` are 0 bytes (the smallest files). I ran a word count on `slack.yaml` which shows: **0 lines**, **0 words**, **0 characters**"), - Message::user().with_text("thanks!"), - ]; - - let (fixed, issues) = ConversationFixer::fix_conversation(messages); - - assert_eq!(fixed.len(), 5); - assert_eq!(issues.len(), 2); - assert!(issues[0].contains("Removed orphaned tool request")); - assert!(issues[1].contains("Merged consecutive assistant messages")); - } - - #[test] - fn test_tool_response_effective_role() { - let messages = vec![ - Message::user().with_text("Search for something"), - Message::assistant() - .with_text("I'll search for you") - .with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))), - Message::user().with_tool_response("search_1", Ok(vec![])), - Message::user().with_text("Thanks!"), - ]; - - let (fixed, issues) = run_verify(messages); - assert_eq!(issues.len(), 0); - } -} diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 97defbf8131c..7d774dddeddc 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -1,8 +1,7 @@ pub mod agents; pub mod config; pub mod context_mgmt; -mod conversation_fixer; -pub mod message; +pub mod conversation; pub mod model; pub mod oauth; pub mod permission; diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 4b870e30a262..f37f1589d75a 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -1,7 +1,8 @@ use crate::agents::platform_tools::PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME; use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; -use crate::message::{Message, MessageContent, ToolRequest}; +use crate::conversation::message::{Message, MessageContent, ToolRequest}; +use crate::conversation::Conversation; use crate::providers::base::Provider; use chrono::Utc; use indoc::indoc; @@ -68,7 +69,7 @@ fn create_read_only_tool() -> Tool { } /// Builds the message to be sent to the LLM for detecting read-only operations. -fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { +fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Conversation { let tool_names: Vec = tool_requests .iter() .filter_map(|req| { @@ -93,7 +94,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { tool_names.join(", "), ))], )); - check_messages + Conversation::new_unvalidated(check_messages) } /// Processes the response to extract the list of tools with read-only operations. @@ -135,7 +136,7 @@ pub async fn detect_read_only_tools( let res = provider .complete( "You are a good analyst and can detect operations whether they have read-only operations.", - &check_messages, + check_messages.messages(), &[tool.clone()], ) .await; @@ -260,7 +261,7 @@ pub async fn check_tool_permissions( #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent, ToolRequest}; + use crate::conversation::message::{Message, MessageContent, ToolRequest}; use crate::model::ModelConfig; use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::errors::ProviderError; @@ -340,7 +341,7 @@ mod tests { let messages = create_check_messages(vec![&tool_request]); assert_eq!(messages.len(), 1); - let content = &messages[0].content[0]; + let content = &messages.first().unwrap().content[0]; if let MessageContent::Text(text_content) = content { assert!(text_content .text diff --git a/crates/goose/src/permission/permission_store.rs b/crates/goose/src/permission/permission_store.rs index c4eebcf2ee9f..80a8a8128895 100644 --- a/crates/goose/src/permission/permission_store.rs +++ b/crates/goose/src/permission/permission_store.rs @@ -1,4 +1,4 @@ -use crate::message::ToolRequest; +use crate::conversation::message::ToolRequest; use anyhow::Result; use blake3::Hasher; use chrono::Utc; diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 2a5ad01925cb..aca6a4ef3896 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -15,8 +15,8 @@ use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; use super::utils::{emit_debug_trace, get_model, map_http_error_to_provider_error}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::retry::ProviderRetry; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 0a0a2236e9d3..f40993d67657 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -10,8 +10,8 @@ use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 3cd6e4e1c6bb..02fe7801e061 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize}; use super::errors::ProviderError; use super::retry::RetryConfig; -use crate::message::Message; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::model::ModelConfig; use crate::utils::safe_truncate; use rmcp::model::Tool; @@ -370,7 +371,7 @@ pub trait Provider: Send + Sync { } /// Returns the first 3 user messages as strings for session naming - fn get_initial_user_messages(&self, messages: &[Message]) -> Vec { + fn get_initial_user_messages(&self, messages: &Conversation) -> Vec { messages .iter() .filter(|m| m.role == rmcp::model::Role::User) @@ -381,7 +382,10 @@ pub trait Provider: Send + Sync { /// Generate a session name/description based on the conversation history /// Creates a prompt asking for a concise description in 4 words or less. - async fn generate_session_name(&self, messages: &[Message]) -> Result { + async fn generate_session_name( + &self, + messages: &Conversation, + ) -> Result { let context = self.get_initial_user_messages(messages); let prompt = self.create_session_name_prompt(&context); let message = Message::user().with_text(&prompt); diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 4fc23b60e4c0..7579a7de7141 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::retry::ProviderRetry; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::utils::emit_debug_trace; use anyhow::Result; diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 833fd4547aa4..3185a7961bc2 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -11,8 +11,8 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; use crate::config::Config; +use crate::conversation::message::{Message, MessageContent}; use crate::impl_provider_default; -use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 40261695bd79..c635fe589470 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -18,8 +18,8 @@ use super::oauth; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat, ImageFormat}; use crate::config::ConfigError; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; use crate::providers::retry::{ diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index b0cb696b71bf..1ebf97344938 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -174,7 +174,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result> #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent}; + use crate::conversation::message::{Message, MessageContent}; use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; use chrono::Utc; use rmcp::model::{AnnotateAble, RawTextContent, Role}; diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index 0c451dea924d..206e75c5e576 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; @@ -676,6 +676,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::Message; use rmcp::object; use serde_json::json; @@ -983,6 +984,7 @@ mod tests { #[test] fn test_tool_error_handling_maintains_pairing() { + use crate::conversation::message::Message; use mcp_core::handler::ToolError; let messages = vec![ diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index e947f347bae9..a3f108e0f966 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -11,7 +11,7 @@ use rmcp::model::{Content, RawContent, ResourceContents, Role, Tool}; use serde_json::Value; use super::super::base::Usage; -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; pub fn to_bedrock_message(message: &Message) -> Result { bedrock::Message::builder() diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 5d95ed449867..acd2ebaac3bc 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, safely_parse_json, @@ -9,6 +9,7 @@ use mcp_core::{ToolCall, ToolError}; use rmcp::model::{AnnotateAble, Content, RawContent, ResourceContents, Role, Tool}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; + #[derive(Serialize)] struct DatabricksMessage { content: Value, @@ -616,6 +617,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::Message; use rmcp::object; use serde_json::json; diff --git a/crates/goose/src/providers/formats/gcpvertexai.rs b/crates/goose/src/providers/formats/gcpvertexai.rs index e96a693b67c7..bdd8a3b6a95b 100644 --- a/crates/goose/src/providers/formats/gcpvertexai.rs +++ b/crates/goose/src/providers/formats/gcpvertexai.rs @@ -1,5 +1,5 @@ use super::{anthropic, google}; -use crate::message::Message; +use crate::conversation::message::Message; use crate::model::ModelConfig; use crate::providers::base::Usage; use anyhow::{Context, Result}; diff --git a/crates/goose/src/providers/formats/google.rs b/crates/goose/src/providers/formats/google.rs index a487084f06b5..1a2b40b446c8 100644 --- a/crates/goose/src/providers/formats/google.rs +++ b/crates/goose/src/providers/formats/google.rs @@ -1,4 +1,3 @@ -use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; @@ -8,6 +7,7 @@ use mcp_core::tool::ToolCall; use rand::{distributions::Alphanumeric, Rng}; use rmcp::model::{AnnotateAble, RawContent, Role, Tool}; +use crate::conversation::message::{Message, MessageContent}; use serde_json::{json, Map, Value}; use std::ops::Deref; @@ -335,6 +335,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::Message; use rmcp::{model::Content, object}; use serde_json::json; diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 8606067f7ad1..d74ef135f224 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ @@ -641,6 +641,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::Message; use rmcp::object; use serde_json::json; use tokio::pin; diff --git a/crates/goose/src/providers/formats/snowflake.rs b/crates/goose/src/providers/formats/snowflake.rs index 50669fe3c08b..ff101d9098ff 100644 --- a/crates/goose/src/providers/formats/snowflake.rs +++ b/crates/goose/src/providers/formats/snowflake.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; @@ -359,6 +359,7 @@ pub fn create_request( #[cfg(test)] mod tests { use super::*; + use crate::conversation::message::Message; use rmcp::object; use serde_json::json; @@ -546,6 +547,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet"," #[test] fn test_create_request_format() -> Result<()> { + use crate::conversation::message::Message; use crate::model::ModelConfig; let model_config = ModelConfig::new_or_fail("claude-3-5-sonnet"); @@ -654,6 +656,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet"," #[test] fn test_create_request_excludes_tools_for_description() -> Result<()> { + use crate::conversation::message::Message; use crate::model::ModelConfig; let model_config = ModelConfig::new_or_fail("claude-3-5-sonnet"); @@ -675,6 +678,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet"," #[test] fn test_message_formatting_skips_tool_requests() { + use crate::conversation::message::Message; use mcp_core::tool::ToolCall; // Create a conversation with text, tool requests, and tool responses diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index d038127358eb..969d7146d7e2 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -8,7 +8,7 @@ use serde_json::Value; use tokio::time::sleep; use url::Url; -use crate::message::Message; +use crate::conversation::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index afcdbdae3c5e..fcfc0f75c369 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -9,8 +9,8 @@ use tokio::process::Command; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::emit_debug_trace; +use crate::conversation::message::{Message, MessageContent}; use crate::impl_provider_default; -use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use rmcp::model::Role; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 9fef3549423e..fc1e7fb640dd 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -18,8 +18,8 @@ use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::config::{Config, ConfigError}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::ConfigKey; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index fa401880bb44..fa262f403c3a 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -2,8 +2,8 @@ use super::api_client::{ApiClient, AuthMethod}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, handle_response_google_compat, unescape_json_values}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use crate::providers::formats::google::{create_request, get_usage, response_to_message}; diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index d840830dc699..acb51d1fc75d 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -2,8 +2,8 @@ use super::api_client::{ApiClient, AuthMethod}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 6f1a5cd0e6c9..18564a2d1261 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -6,7 +6,7 @@ use tokio::sync::Mutex; use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use rmcp::model::Tool; use rmcp::model::{Content, RawContent}; @@ -454,7 +454,7 @@ impl Provider for LeadWorkerProvider { #[cfg(test)] mod tests { use super::*; - use crate::message::MessageContent; + use crate::conversation::message::{Message, MessageContent}; use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; use chrono::Utc; use rmcp::model::{AnnotateAble, RawTextContent, Role}; diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 54a84ced522d..8911341b4d42 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -9,8 +9,8 @@ use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 0b16e21dc889..7f17420fcf5d 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -3,8 +3,9 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use crate::utils::safe_truncate; @@ -159,7 +160,10 @@ impl Provider for OllamaProvider { /// Generate a session name based on the conversation history /// This override filters out reasoning tokens that some Ollama models produce - async fn generate_session_name(&self, messages: &[Message]) -> Result { + async fn generate_session_name( + &self, + messages: &Conversation, + ) -> Result { let context = self.get_initial_user_messages(messages); let message = Message::user().with_text(self.create_session_name_prompt(&context)); let result = self diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index a874f421c4aa..fa4211f128bc 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -20,8 +20,8 @@ use super::utils::{ emit_debug_trace, get_model, handle_response_openai_compat, handle_status_openai_compat, ImageFormat, }; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::MessageStream; use crate::providers::formats::openai::response_to_streaming_message; diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index cc23fbfa786d..e98c49fc1690 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -10,8 +10,8 @@ use super::utils::{ emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, is_google_model, }; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 9b68ab656f67..90a2498d0977 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -13,8 +13,8 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::emit_debug_trace; +use crate::conversation::message::{Message, MessageContent}; use crate::impl_provider_default; -use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use chrono::Utc; use rmcp::model::Role; diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index f0b643b7748c..8e8ea663b5ae 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -10,8 +10,8 @@ use super::formats::snowflake::{create_request, get_usage, response_to_message}; use super::retry::ProviderRetry; use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat}; use crate::config::ConfigError; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index c25ad0022105..eca2c87627b1 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -9,7 +9,7 @@ use std::sync::{Arc, Mutex}; use super::base::{Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; -use crate::message::Message; +use crate::conversation::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; @@ -162,7 +162,7 @@ impl Provider for TestProvider { #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent}; + use crate::conversation::message::{Message, MessageContent}; use crate::providers::base::{ProviderUsage, Usage}; use chrono::Utc; use rmcp::model::{RawTextContent, Role, TextContent}; diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index cae32e51baeb..c22255e9eca2 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -33,7 +33,8 @@ use super::errors::ProviderError; use super::ollama::OLLAMA_DEFAULT_PORT; use super::ollama::OLLAMA_HOST; -use crate::message::{Message, MessageContent}; +use crate::conversation::message::{Message, MessageContent}; +use crate::conversation::Conversation; use crate::model::ModelConfig; use crate::providers::formats::openai::create_request; use anyhow::Result; @@ -310,8 +311,8 @@ pub fn format_tool_info(tools: &[Tool]) -> String { /// Convert messages containing ToolRequest/ToolResponse to text messages for toolshim mode /// This is necessary because some providers (like Bedrock) validate that tool_use/tool_result /// blocks can only exist when tools are defined, but in toolshim mode we pass empty tools -pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec { - messages +pub fn convert_tool_messages_to_text(messages: &[Message]) -> Conversation { + let converted_messages: Vec = messages .iter() .map(|message| { let mut new_content = Vec::new(); @@ -366,7 +367,9 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec { message.clone() } }) - .collect() + .collect(); + + Conversation::new_unvalidated(converted_messages) } /// Modifies the system prompt to include tool usage instructions when tool interpretation is enabled diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 18b08aa4f26c..185587c6df6c 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -9,8 +9,8 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::map_http_error_to_provider_error; +use crate::conversation::message::{Message, MessageContent}; use crate::impl_provider_default; -use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; use mcp_core::{ToolCall, ToolResult}; use rmcp::model::{Role, Tool}; diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs index e1462f71664a..7b2aed5f15c8 100644 --- a/crates/goose/src/providers/xai.rs +++ b/crates/goose/src/providers/xai.rs @@ -2,8 +2,8 @@ use super::api_client::{ApiClient, AuthMethod}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; +use crate::conversation::message::Message; use crate::impl_provider_default; -use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 5541b2a34f4c..f65b6a6c8fec 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -15,7 +15,8 @@ use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler}; use crate::agents::AgentEvent; use crate::agents::{Agent, SessionConfig}; use crate::config::{self, Config}; -use crate::message::Message; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::providers::base::Provider as GooseProvider; // Alias to avoid conflict in test section use crate::providers::create; use crate::recipe::Recipe; @@ -1190,8 +1191,8 @@ async fn run_scheduled_job_internal( }; if let Some(prompt_text) = recipe.prompt { - let mut all_session_messages: Vec = - vec![Message::user().with_text(prompt_text.clone())]; + let mut all_session_messages = + Conversation::new_unvalidated(vec![Message::user().with_text(prompt_text.clone())]); let current_dir = match std::env::current_dir() { Ok(cd) => cd, @@ -1213,7 +1214,11 @@ async fn run_scheduled_job_internal( }; match agent - .reply(&all_session_messages, Some(session_config.clone()), None) + .reply( + all_session_messages.clone(), + Some(session_config.clone()), + None, + ) .await { Ok(mut stream) => { @@ -1314,9 +1319,11 @@ async fn run_scheduled_job_internal( message_count: 0, ..Default::default() }; - if let Err(e) = - crate::session::storage::save_messages_with_metadata(&session_file_path, &metadata, &[]) - { + if let Err(e) = crate::session::storage::save_messages_with_metadata( + &session_file_path, + &metadata, + &Conversation::new_unvalidated(vec![]), + ) { tracing::error!( "[Job {}] Failed to persist metadata for empty job: {}", job.id, @@ -1334,7 +1341,6 @@ mod tests { use super::*; use crate::recipe::Recipe; use crate::{ - message::MessageContent, model::ModelConfig, // Use the actual ModelConfig for the mock's field providers::base::{ProviderMetadata, ProviderUsage, Usage}, providers::errors::ProviderError, @@ -1345,6 +1351,7 @@ mod tests { // `read_metadata` is still used by the test itself, so keep it or its module. use crate::session::storage::read_metadata; + use crate::conversation::message::{Message, MessageContent}; use std::env; use std::fs::{self, File}; use std::io::Write; diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 5d50d1d83437..5cc3e3a457da 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -5,7 +5,8 @@ // - Backup creation // Additional debug logging can be added if needed for troubleshooting. -use crate::message::Message; +use crate::conversation::message::Message; +use crate::conversation::Conversation; use crate::providers::base::Provider; use crate::utils::safe_truncate; use anyhow::Result; @@ -399,7 +400,7 @@ pub fn generate_session_id() -> String { /// Security features: /// - Validates file paths to prevent directory traversal /// - Includes all security limits from read_messages_with_truncation -pub fn read_messages(session_file: &Path) -> Result> { +pub fn read_messages(session_file: &Path) -> Result { // Validate the path for security let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; @@ -428,7 +429,7 @@ pub fn read_messages(session_file: &Path) -> Result> { pub fn read_messages_with_truncation( session_file: &Path, max_content_size: Option, -) -> Result> { +) -> Result { // Security check: file size limit if session_file.exists() { let metadata = fs::metadata(session_file)?; @@ -626,7 +627,7 @@ pub fn read_messages_with_truncation( } } - Ok(messages) + Ok(Conversation::new_unvalidated(messages)) } /// Parse a message from JSON string with optional content truncation @@ -685,7 +686,7 @@ fn parse_message_with_truncation( /// Truncate content within a message in place fn truncate_message_content_in_place(message: &mut Message, max_content_size: usize) { - use crate::message::MessageContent; + use crate::conversation::message::MessageContent; use rmcp::model::{RawContent, ResourceContents}; for content in &mut message.content { @@ -1051,7 +1052,7 @@ pub fn read_metadata(session_file: &Path) -> Result { /// - Validates file paths to prevent directory traversal pub async fn persist_messages( session_file: &Path, - messages: &[Message], + messages: &Conversation, provider: Option>, working_dir: Option, ) -> Result<()> { @@ -1069,7 +1070,7 @@ pub async fn persist_messages( /// - Uses atomic file operations via save_messages_with_metadata pub async fn persist_messages_with_schedule_id( session_file: &Path, - messages: &[Message], + messages: &Conversation, provider: Option>, schedule_id: Option, working_dir: Option, @@ -1144,7 +1145,7 @@ pub async fn persist_messages_with_schedule_id( pub fn save_messages_with_metadata( session_file: &Path, metadata: &SessionMetadata, - messages: &[Message], + messages: &Conversation, ) -> Result<()> { use fs2::FileExt; @@ -1257,7 +1258,7 @@ pub fn save_messages_with_metadata( /// of the session based on the conversation history. pub async fn generate_description( session_file: &Path, - messages: &[Message], + messages: &Conversation, provider: Arc, working_dir: Option, ) -> Result<()> { @@ -1275,7 +1276,7 @@ pub async fn generate_description( /// - Uses secure file operations for saving pub async fn generate_description_with_schedule_id( session_file: &Path, - messages: &[Message], + messages: &Conversation, provider: Arc, schedule_id: Option, working_dir: Option, @@ -1346,7 +1347,7 @@ pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> #[cfg(test)] mod tests { use super::*; - use crate::message::MessageContent; + use crate::conversation::message::{Message, MessageContent}; use tempfile::tempdir; #[test] @@ -1428,10 +1429,10 @@ mod tests { let file_path = dir.path().join("test.jsonl"); // Create some test messages - let messages = vec![ + let messages = Conversation::new_unvalidated(vec![ Message::user().with_text("Hello"), Message::assistant().with_text("Hi there"), - ]; + ]); // Write messages persist_messages(&file_path, &messages, None, None).await?; @@ -1535,7 +1536,7 @@ mod tests { "}]", ]; - let mut messages = Vec::new(); + let mut messages = Conversation::empty(); for text in special_chars { messages.push(Message::user().with_text(text)); messages.push(Message::assistant().with_text(text)); @@ -1601,10 +1602,10 @@ mod tests { // Create a message with content larger than the 50KB truncation limit let very_large_text = "A".repeat(100_000); // 100KB of text - let messages = vec![ + let messages = Conversation::new_unvalidated(vec![ Message::user().with_text(&very_large_text), Message::assistant().with_text("Small response"), - ]; + ]); // Write messages persist_messages(&file_path, &messages, None, None).await?; @@ -1615,7 +1616,9 @@ mod tests { assert_eq!(messages.len(), read_messages.len()); // First message should be truncated - if let Some(MessageContent::Text(read_text)) = read_messages[0].content.first() { + if let Some(MessageContent::Text(read_text)) = + read_messages.first().unwrap().content.first() + { assert!( read_text.text.len() < very_large_text.len(), "Content should be truncated" @@ -1635,7 +1638,7 @@ mod tests { } // Second message should be unchanged - if let Some(MessageContent::Text(read_text)) = read_messages[1].content.first() { + if let Some(MessageContent::Text(read_text)) = read_messages.messages()[1].content.first() { assert_eq!(read_text.text, "Small response"); } else { panic!("Expected text content in second message"); @@ -1652,7 +1655,7 @@ mod tests { let mut metadata = SessionMetadata::default(); metadata.description = "Description with\nnewline and \"quotes\" and 🦆".to_string(); - let messages = vec![Message::user().with_text("test")]; + let messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); // Write with special metadata save_messages_with_metadata(&file_path, &metadata, &messages)?; @@ -1679,7 +1682,7 @@ mod tests { assert_eq!(metadata.working_dir, get_home_dir()); // Test deserialization of invalid directory - let messages = vec![Message::user().with_text("test")]; + let messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); save_messages_with_metadata(&file_path, &metadata, &messages)?; // Modify the file to include invalid directory @@ -1709,7 +1712,8 @@ mod tests { let working_dir_path = working_dir.path().to_path_buf(); // Create messages - let messages = vec![Message::user().with_text("test message")]; + let messages = + Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); // Use persist_messages_with_schedule_id to set working dir persist_messages_with_schedule_id( @@ -1728,7 +1732,10 @@ mod tests { // Verify the messages are also preserved let read_messages = read_messages(&file_path)?; assert_eq!(read_messages.len(), 1); - assert_eq!(read_messages[0].role, messages[0].role); + assert_eq!( + read_messages.first().unwrap().role, + messages.messages()[0].role + ); Ok(()) } @@ -1744,7 +1751,8 @@ mod tests { let working_dir_path = working_dir.path().to_path_buf(); // Create messages - let messages = vec![Message::user().with_text("test message")]; + let messages = + Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); // Get the home directory for comparison let home_dir = get_home_dir(); @@ -1911,10 +1919,10 @@ mod tests { let dir = tempdir()?; let file_path = dir.path().join("test_save_session.jsonl"); - let messages = vec![ + let messages = Conversation::new_unvalidated(vec![ Message::user().with_text("Hello"), Message::assistant().with_text("Hi there"), - ]; + ]); let metadata = SessionMetadata::default(); @@ -1937,10 +1945,10 @@ mod tests { let dir = tempdir()?; let file_path = dir.path().join("test_persist_no_save.jsonl"); - let messages = vec![ + let messages = Conversation::new_unvalidated(vec![ Message::user().with_text("Test message"), Message::assistant().with_text("Test response"), - ]; + ]); // Test persist_messages_with_schedule_id with working_dir parameter persist_messages_with_schedule_id( diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index daa6dca5fd53..c95493df519d 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use tiktoken_rs::CoreBPE; use tokio::sync::OnceCell; -use crate::message::Message; +use crate::conversation::message::Message; // Global tokenizer instance to avoid repeated initialization static TOKENIZER: OnceCell> = OnceCell::const_new(); @@ -380,7 +380,7 @@ pub async fn create_async_token_counter() -> Result { #[cfg(test)] mod tests { use super::*; - use crate::message::{Message, MessageContent}; + use crate::conversation::message::{Message, MessageContent}; use rmcp::model::{Role, Tool}; use rmcp::object; diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index cc77eb5f15b7..33f19ed53d6f 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use anyhow::Result; use futures::StreamExt; use goose::agents::{Agent, AgentEvent}; -use goose::message::Message; +use goose::conversation::message::Message; +use goose::conversation::Conversation; use goose::model::ModelConfig; use goose::providers::base::Provider; use goose::providers::{ @@ -118,7 +119,7 @@ async fn run_truncate_test( agent.update_provider(provider).await?; let repeat_count = context_window + 10_000; let large_message_content = "hello ".repeat(repeat_count); - let messages = vec![ + let messages = Conversation::new(vec![ Message::user().with_text("hi there. what is 2 + 2?"), Message::assistant().with_text("hey! I think it's 4."), Message::user().with_text(&large_message_content), @@ -128,9 +129,10 @@ async fn run_truncate_test( Message::user().with_text( "did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'", ), - ]; + ]) + .unwrap(); - let reply_stream = agent.reply(&messages, None, None).await?; + let reply_stream = agent.reply(messages, None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -166,11 +168,11 @@ async fn run_truncate_test( assert_eq!(responses[0].content.len(), 1); match responses[0].content[0] { - goose::message::MessageContent::Text(ref text_content) => { + goose::conversation::message::MessageContent::Text(ref text_content) => { assert!(text_content.text.to_lowercase().contains("no")); assert!(!text_content.text.to_lowercase().contains("yes")); } - goose::message::MessageContent::ContextLengthExceeded(_) => { + goose::conversation::message::MessageContent::ContextLengthExceeded(_) => { // This is an acceptable outcome for providers that don't truncate themselves // and correctly report that the context length was exceeded. println!( @@ -546,12 +548,14 @@ mod final_output_tool_tests { use goose::agents::final_output_tool::{ FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME, }; + use goose::conversation::Conversation; use goose::providers::base::MessageStream; use goose::recipe::Response; #[tokio::test] async fn test_final_output_assistant_message_in_reply() -> Result<()> { use async_trait::async_trait; + use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::base::{Provider, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; @@ -626,7 +630,7 @@ mod final_output_tool_tests { ); // Simulate the reply stream continuing after the final output tool call. - let reply_stream = agent.reply(&vec![], None, None).await?; + let reply_stream = agent.reply(Conversation::empty(), None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -652,6 +656,7 @@ mod final_output_tool_tests { #[tokio::test] async fn test_when_final_output_not_called_in_reply() -> Result<()> { use async_trait::async_trait; + use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::base::{Provider, ProviderUsage}; use goose::providers::errors::ProviderError; @@ -723,7 +728,7 @@ mod final_output_tool_tests { agent.add_final_output_tool(response).await; // Simulate the reply stream being called. - let reply_stream = agent.reply(&vec![], None, None).await?; + let reply_stream = agent.reply(Conversation::empty(), None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -773,6 +778,8 @@ mod retry_tests { use super::*; use async_trait::async_trait; use goose::agents::types::{RetryConfig, SessionConfig, SuccessCheck}; + use goose::conversation::message::Message; + use goose::conversation::Conversation; use goose::model::ModelConfig; use goose::providers::base::{Provider, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; @@ -855,10 +862,11 @@ mod retry_tests { retry_config: Some(retry_config), }; - let initial_messages = vec![Message::user().with_text("Complete this task")]; + let conversation = + Conversation::new(vec![Message::user().with_text("Complete this task")]).unwrap(); let reply_stream = agent - .reply(&initial_messages, Some(session_config), None) + .reply(conversation, Some(session_config), None) .await?; tokio::pin!(reply_stream); @@ -952,7 +960,8 @@ mod retry_tests { mod max_turns_tests { use super::*; use async_trait::async_trait; - use goose::message::MessageContent; + use goose::conversation::message::{Message, MessageContent}; + use goose::conversation::Conversation; use goose::model::ModelConfig; use goose::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use goose::providers::errors::ProviderError; @@ -1021,9 +1030,11 @@ mod max_turns_tests { max_turns: Some(1), retry_config: None, }; - let messages = vec![Message::user().with_text("Hello")]; + let conversation = Conversation::new(vec![Message::user().with_text("Hello")]).unwrap(); - let reply_stream = agent.reply(&messages, Some(session_config), None).await?; + let reply_stream = agent + .reply(conversation, Some(session_config), None) + .await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); diff --git a/crates/goose/tests/private_tests.rs b/crates/goose/tests/private_tests.rs index e23d0c09e319..6242730af2cf 100644 --- a/crates/goose/tests/private_tests.rs +++ b/crates/goose/tests/private_tests.rs @@ -782,10 +782,10 @@ async fn test_schedule_tool_session_content_action_with_real_session() { // Create test metadata and messages let metadata = create_test_session_metadata(2, "/tmp"); - let messages = vec![ - goose::message::Message::user().with_text("Hello"), - goose::message::Message::assistant().with_text("Hi there!"), - ]; + let messages = goose::conversation::Conversation::new_unvalidated(vec![ + goose::conversation::message::Message::user().with_text("Hello"), + goose::conversation::message::Message::assistant().with_text("Hi there!"), + ]); // Save the session file goose::session::storage::save_messages_with_metadata(&session_path, &metadata, &messages) diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index f0ac979edcc2..a636d4f55db0 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -1,6 +1,6 @@ use anyhow::Result; use dotenvy::dotenv; -use goose::message::{Message, MessageContent}; +use goose::conversation::message::{Message, MessageContent}; use goose::providers::base::Provider; use goose::providers::errors::ProviderError; use goose::providers::{ @@ -257,6 +257,7 @@ impl ProviderTester { async fn test_image_content_support(&self) -> Result<()> { use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; + use goose::conversation::message::Message; use std::fs; // Try to read the test image