diff --git a/Cargo.lock b/Cargo.lock index 9541e0b7dbf..f53e9203127 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3486,6 +3486,7 @@ dependencies = [ "tokio", "tokio-cron-scheduler", "tokio-stream", + "tokio-util", "tracing", "tracing-subscriber", "url", @@ -8604,9 +8605,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", diff --git a/Cargo.toml b/Cargo.toml index ac8f47a6cdc..da1c0cd2b29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,4 +15,4 @@ uninlined_format_args = "allow" # Patch for Windows cross-compilation issue with crunchy [patch.crates-io] -crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" } \ No newline at end of file +crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index bec2a2db938..8338cdb85f7 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -10,6 +10,7 @@ pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; use goose::agents::AgentEvent; +use goose::message::push_message; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::Permission; use goose::permission::PermissionConfirmation; @@ -317,7 +318,7 @@ impl Session { /// Process a single message and get the response async fn process_message(&mut self, message: String) -> Result<()> { - self.messages.push(Message::user().with_text(&message)); + self.push_message(Message::user().with_text(&message)); // Get the provider from the agent for description generation let provider = self.agent.provider().await?; @@ -417,7 +418,7 @@ impl Session { RunMode::Normal => { save_history(&mut editor); - self.messages.push(Message::user().with_text(&content)); + self.push_message(Message::user().with_text(&content)); // Track the current directory and last instruction in projects.json let session_id = self @@ -740,7 +741,7 @@ impl Session { self.messages.clear(); // add the plan response as a user message let plan_message = Message::user().with_text(plan_response.as_concat_text()); - self.messages.push(plan_message); + self.push_message(plan_message); // act on the plan output::show_thinking(); self.process_agent_response(true).await?; @@ -755,13 +756,13 @@ impl Session { } else { // add the plan response (assistant message) & carry the conversation forward // in the next round, the user might wanna slightly modify the plan - self.messages.push(plan_response); + self.push_message(plan_response); } } PlannerResponseType::ClarifyingQuestions => { // add the plan response (assistant message) & carry the conversation forward // in the next round, the user will answer the clarifying questions - self.messages.push(plan_response); + self.push_message(plan_response); } } @@ -833,7 +834,7 @@ impl Session { confirmation.id.clone(), Err(ToolError::ExecutionError("Tool call cancelled by user".to_string())) )); - self.messages.push(response_message); + push_message(&mut self.messages, response_message); if let Some(session_file) = &self.session_file { session::persist_messages_with_schedule_id( session_file, @@ -930,7 +931,7 @@ impl Session { } // otherwise we have a model/tool to render else { - self.messages.push(message.clone()); + push_message(&mut self.messages, message.clone()); // No need to update description on assistant messages if let Some(session_file) = &self.session_file { @@ -946,7 +947,6 @@ impl Session { if interactive {output::hide_thinking()}; let _ = progress_bars.hide(); output::render_message(&message, self.debug); - if interactive {output::show_thinking()}; } } Some(Ok(AgentEvent::McpNotification((_id, message)))) => { @@ -1094,6 +1094,7 @@ impl Session { } } } + println!(); Ok(()) } @@ -1137,7 +1138,7 @@ impl Session { Err(ToolError::ExecutionError(notification.clone())), )); } - self.messages.push(response_message); + self.push_message(response_message); // No need for description update here if let Some(session_file) = &self.session_file { @@ -1154,7 +1155,7 @@ impl Session { "The existing call to {} was interrupted. How would you like to proceed?", last_tool_name ); - self.messages.push(Message::assistant().with_text(&prompt)); + self.push_message(Message::assistant().with_text(&prompt)); // No need for description update here if let Some(session_file) = &self.session_file { @@ -1176,7 +1177,7 @@ impl Session { Some(MessageContent::ToolResponse(_)) => { // Interruption occurred after a tool had completed but not assistant reply let prompt = "The tool calling loop was interrupted. How would you like to proceed?"; - self.messages.push(Message::assistant().with_text(prompt)); + self.push_message(Message::assistant().with_text(prompt)); // No need for description update here if let Some(session_file) = &self.session_file { @@ -1364,7 +1365,7 @@ impl Session { if msg.role == mcp_core::Role::User { output::render_message(&msg, self.debug); } - self.messages.push(msg); + self.push_message(msg); } if valid { @@ -1422,6 +1423,10 @@ impl Session { Ok(path) } + + fn push_message(&mut self, message: Message) { + push_message(&mut self.messages, message); + } } fn get_reasoner() -> Result, anyhow::Error> { diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index deeea706aed..4d4c1fefc07 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -8,7 +8,7 @@ use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; -use std::io::Error; +use std::io::{Error, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; @@ -164,7 +164,8 @@ pub fn render_message(message: &Message, debug: bool) { } } } - println!(); + + let _ = std::io::stdout().flush(); } pub fn render_text(text: &str, color: Option, dim: bool) { diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 5aae8e1f3bc..dccbcf2457b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -225,6 +225,7 @@ async fn handler( return; } }; + let saved_message_count = all_messages.len(); loop { tokio::select! { @@ -242,16 +243,6 @@ async fn handler( ).await; break; } - - - let session_path = session_path.clone(); - let messages = all_messages.clone(); - let provider = Arc::clone(provider.as_ref().unwrap()); - tokio::spawn(async move { - if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await { - tracing::error!("Failed to store session history: {:?}", e); - } - }); } Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { @@ -303,6 +294,17 @@ async fn handler( } } + if all_messages.len() > saved_message_count { + let provider = Arc::clone(provider.as_ref().unwrap()); + tokio::spawn(async move { + if let Err(e) = + session::persist_messages(&session_path, &all_messages, Some(provider)).await + { + tracing::error!("Failed to store session history: {:?}", e); + } + }); + } + let _ = stream_event( MessageEvent::Finish { reason: "stop".to_string(), diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index c8a9929895a..8761857096a 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -81,6 +81,7 @@ fs2 = "0.4.3" tokio-stream = "0.1.17" dashmap = "6.1" ahash = "0.8" +tokio-util = "0.7.15" # Vector database for tool selection lancedb = "0.13" diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs index 2b49cc03a86..bf36d0a6a8e 100644 --- a/crates/goose/examples/databricks_oauth.rs +++ b/crates/goose/examples/databricks_oauth.rs @@ -2,8 +2,12 @@ use anyhow::Result; use dotenv::dotenv; use goose::{ message::Message, - providers::{base::Provider, databricks::DatabricksProvider}, + providers::{ + base::{Provider, Usage}, + databricks::DatabricksProvider, + }, }; +use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<()> { @@ -20,21 +24,24 @@ async fn main() -> Result<()> { let message = Message::user().with_text("Tell me a short joke about programming."); // Get a response - let (response, usage) = provider - .complete("You are a helpful assistant.", &[message], &[]) + let mut stream = provider + .stream("You are a helpful assistant.", &[message], &[]) .await?; - // Print the response and usage statistics println!("\nResponse from AI:"); println!("---------------"); - for content in response.content { - dbg!(content); + let mut usage = Usage::default(); + while let Some(Ok((msg, usage_part))) = stream.next().await { + dbg!(msg); + usage_part.map(|u| { + usage += u.usage; + }); } println!("\nToken Usage:"); println!("------------"); - println!("Input tokens: {:?}", usage.usage.input_tokens); - println!("Output tokens: {:?}", usage.usage.output_tokens); - println!("Total tokens: {:?}", usage.usage.total_tokens); + println!("Input tokens: {:?}", usage.input_tokens); + println!("Output tokens: {:?}", usage.output_tokens); + println!("Total tokens: {:?}", usage.total_tokens); Ok(()) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 38a488a9781..3adee02309a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -14,7 +14,7 @@ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ }; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; -use crate::message::Message; +use crate::message::{push_message, Message}; use crate::permission::permission_judge::check_tool_permissions; use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; @@ -722,6 +722,16 @@ impl Agent { }); loop { + // Check for final output before incrementing turns or checking max_turns + // This ensures that if we have a final output ready, we return it immediately + // without being blocked by the max_turns limit - this is needed for streaming cases + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { + if final_output_tool.final_output.is_some() { + yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + break; + } + } + turns_taken += 1; if turns_taken > max_turns { yield AgentEvent::Message(Message::assistant().with_text( @@ -752,262 +762,291 @@ impl Agent { } } - match Self::generate_response_from_provider( + let mut stream = Self::stream_response_from_provider( self.provider().await?, &system_prompt, &messages, &tools, &toolshim_tools, - ).await { - Ok((response, usage)) => { - // Emit model change event if provider is lead-worker - let provider = self.provider().await?; - if let Some(lead_worker) = provider.as_lead_worker() { - // The actual model used is in the usage - let active_model = usage.model.clone(); - let (lead_model, worker_model) = lead_worker.get_model_info(); - let mode = if active_model == lead_model { - "lead" - } else if active_model == worker_model { - "worker" - } else { - "unknown" - }; - - yield AgentEvent::ModelChange { - model: active_model, - mode: mode.to_string(), - }; - } - - // record usage for the session in the session file - if let Some(session_config) = session.clone() { - Self::update_session_metrics(session_config, &usage, messages.len()).await?; - } + ).await?; + + let mut added_message = false; + while let Some(next) = stream.next().await { + match next { + Ok((response, usage)) => { + // Emit model change event if provider is lead-worker + let provider = self.provider().await?; + if let Some(lead_worker) = provider.as_lead_worker() { + if let Some(ref usage) = usage { + // The actual model used is in the usage + let active_model = usage.model.clone(); + let (lead_model, worker_model) = lead_worker.get_model_info(); + let mode = if active_model == lead_model { + "lead" + } else if active_model == worker_model { + "worker" + } else { + "unknown" + }; + + yield AgentEvent::ModelChange { + model: active_model, + mode: mode.to_string(), + }; + } + } - // categorize the type of requests we need to handle - let (frontend_requests, - remaining_requests, - filtered_response) = - self.categorize_tool_requests(&response).await; - - // Record tool calls in the router selector - let selector = self.router_tool_selector.lock().await.clone(); - if let Some(selector) = selector { - // Record frontend tool calls - for request in &frontend_requests { - if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - tracing::error!("Failed to record frontend tool call: {}", e); - } + // record usage for the session in the session file + if let Some(session_config) = session.clone() { + if let Some(ref usage) = usage { + Self::update_session_metrics(session_config, usage, messages.len()).await?; } } - // Record remaining tool calls - for request in &remaining_requests { - if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - tracing::error!("Failed to record tool call: {}", e); + + if let Some(response) = response { + // categorize the type of requests we need to handle + let (frontend_requests, + remaining_requests, + filtered_response) = + self.categorize_tool_requests(&response).await; + + // Record tool calls in the router selector + let selector = self.router_tool_selector.lock().await.clone(); + if let Some(selector) = selector { + // Record frontend tool calls + for request in &frontend_requests { + if let Ok(tool_call) = &request.tool_call { + if let Err(e) = selector.record_tool_call(&tool_call.name).await { + tracing::error!("Failed to record frontend tool call: {}", e); + } + } + } + // Record remaining tool calls + for request in &remaining_requests { + if let Ok(tool_call) = &request.tool_call { + if let Err(e) = selector.record_tool_call(&tool_call.name).await { + tracing::error!("Failed to record tool call: {}", e); + } + } } } - } - } - // Yield the assistant's response with frontend tool requests filtered out - yield AgentEvent::Message(filtered_response.clone()); - - tokio::task::yield_now().await; - - let num_tool_requests = frontend_requests.len() + remaining_requests.len(); - if num_tool_requests == 0 { - if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { - if final_output_tool.final_output.is_none() { - tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); - let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); - messages.push(message.clone()); - yield AgentEvent::Message(message); + // Yield the assistant's response with frontend tool requests filtered out + yield AgentEvent::Message(filtered_response.clone()); + + tokio::task::yield_now().await; + + let num_tool_requests = frontend_requests.len() + remaining_requests.len(); + if num_tool_requests == 0 { + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { + if final_output_tool.final_output.is_none() { + tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); + let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); + messages.push(message.clone()); + yield AgentEvent::Message(message); + continue; + } else { + let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); + messages.push(message.clone()); + yield AgentEvent::Message(message); + // Set added_message to true and continue to end the current iteration + added_message = true; + push_message(&mut messages, response); + continue; + } + } + // If there's no final output tool and no tool requests, continue the loop continue; - } else { - let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); - messages.push(message.clone()); - yield AgentEvent::Message(message); } - } - break; - } - // Process tool requests depending on frontend tools and then goose_mode - let message_tool_response = Arc::new(Mutex::new(Message::user())); + // Process tool requests depending on frontend tools and then goose_mode + let message_tool_response = Arc::new(Mutex::new(Message::user())); - // First handle any frontend tool requests - let mut frontend_tool_stream = self.handle_frontend_tool_requests( - &frontend_requests, - message_tool_response.clone() - ); - - // we have a stream of frontend tools to handle, inside the stream - // execution is yeield back to this reply loop, and is of the same Message - // type, so we can yield that back up to be handled - while let Some(msg) = frontend_tool_stream.try_next().await? { - yield AgentEvent::Message(msg); - } - - // Clone goose_mode once before the match to avoid move issues - let mode = goose_mode.clone(); - if mode.as_str() == "chat" { - // Skip all tool calls in chat mode - for request in remaining_requests { - let mut response = message_tool_response.lock().await; - *response = response.clone().with_tool_response( - request.id.clone(), - Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]), + // First handle any frontend tool requests + let mut frontend_tool_stream = self.handle_frontend_tool_requests( + &frontend_requests, + message_tool_response.clone() ); - } - } else { - // At this point, we have handled the frontend tool requests and know goose_mode != "chat" - // What remains is handling the remaining tool requests (enable extension, - // regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"] - let mut permission_manager = PermissionManager::default(); - let (permission_check_result, enable_extension_request_ids) = check_tool_permissions( - &remaining_requests, - &mode, - tools_with_readonly_annotation.clone(), - tools_without_annotation.clone(), - &mut permission_manager, - self.provider().await?).await; - - // Handle pre-approved and read-only tools in parallel - let mut tool_futures: Vec<(String, ToolStream)> = Vec::new(); - - // Skip the confirmation for approved tools - for request in &permission_check_result.approved { - if let Ok(tool_call) = request.tool_call.clone() { - let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await; - - tool_futures.push((req_id, match tool_result { - Ok(result) => tool_stream( - result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())), - result.result, - ), - Err(e) => tool_stream( - Box::new(stream::empty()), - futures::future::ready(Err(e)), - ), - })); + + // we have a stream of frontend tools to handle, inside the stream + // execution is yeield back to this reply loop, and is of the same Message + // type, so we can yield that back up to be handled + while let Some(msg) = frontend_tool_stream.try_next().await? { + yield AgentEvent::Message(msg); } - } - for request in &permission_check_result.denied { - let mut response = message_tool_response.lock().await; - *response = response.clone().with_tool_response( - request.id.clone(), - Ok(vec![Content::text(DECLINED_RESPONSE)]), - ); - } + // Clone goose_mode once before the match to avoid move issues + let mode = goose_mode.clone(); + if mode.as_str() == "chat" { + // Skip all tool calls in chat mode + for request in remaining_requests { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request.id.clone(), + Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]), + ); + } + } else { + // At this point, we have handled the frontend tool requests and know goose_mode != "chat" + // What remains is handling the remaining tool requests (enable extension, + // regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"] + let mut permission_manager = PermissionManager::default(); + let (permission_check_result, enable_extension_request_ids) = check_tool_permissions( + &remaining_requests, + &mode, + tools_with_readonly_annotation.clone(), + tools_without_annotation.clone(), + &mut permission_manager, + self.provider().await?).await; + + // Handle pre-approved and read-only tools in parallel + let mut tool_futures: Vec<(String, ToolStream)> = Vec::new(); + + // Skip the confirmation for approved tools + for request in &permission_check_result.approved { + if let Ok(tool_call) = request.tool_call.clone() { + let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await; + + tool_futures.push((req_id, match tool_result { + Ok(result) => tool_stream( + result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())), + result.result, + ), + Err(e) => tool_stream( + Box::new(stream::empty()), + futures::future::ready(Err(e)), + ), + })); + } + } - // We need interior mutability in handle_approval_tool_requests - let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); - - // Process tools requiring approval (enable extension, regular tool calls) - let mut tool_approval_stream = self.handle_approval_tool_requests( - &permission_check_result.needs_approval, - tool_futures_arc.clone(), - &mut permission_manager, - message_tool_response.clone() - ); - - // We have a stream of tool_approval_requests to handle - // Execution is yielded back to this reply loop, and is of the same Message - // type, so we can yield the Message back up to be handled and grab any - // confirmations or denials - while let Some(msg) = tool_approval_stream.try_next().await? { - yield AgentEvent::Message(msg); - } + for request in &permission_check_result.denied { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request.id.clone(), + Ok(vec![Content::text(DECLINED_RESPONSE)]), + ); + } - tool_futures = { - // Lock the mutex asynchronously - let mut futures_lock = tool_futures_arc.lock().await; - // Drain the vector and collect into a new Vec - futures_lock.drain(..).collect::>() - }; - - let with_id = tool_futures - .into_iter() - .map(|(request_id, stream)| { - stream.map(move |item| (request_id.clone(), item)) - }) - .collect::>(); - - let mut combined = stream::select_all(with_id); - - let mut all_install_successful = true; - - while let Some((request_id, item)) = combined.next().await { - match item { - ToolStreamItem::Result(output) => { - if enable_extension_request_ids.contains(&request_id) && output.is_err(){ - all_install_successful = false; + // We need interior mutability in handle_approval_tool_requests + let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); + + // Process tools requiring approval (enable extension, regular tool calls) + let mut tool_approval_stream = self.handle_approval_tool_requests( + &permission_check_result.needs_approval, + tool_futures_arc.clone(), + &mut permission_manager, + message_tool_response.clone() + ); + + // We have a stream of tool_approval_requests to handle + // Execution is yielded back to this reply loop, and is of the same Message + // type, so we can yield the Message back up to be handled and grab any + // confirmations or denials + while let Some(msg) = tool_approval_stream.try_next().await? { + yield AgentEvent::Message(msg); + } + + tool_futures = { + // Lock the mutex asynchronously + let mut futures_lock = tool_futures_arc.lock().await; + // Drain the vector and collect into a new Vec + futures_lock.drain(..).collect::>() + }; + + let with_id = tool_futures + .into_iter() + .map(|(request_id, stream)| { + stream.map(move |item| (request_id.clone(), item)) + }) + .collect::>(); + + let mut combined = stream::select_all(with_id); + + let mut all_install_successful = true; + + while let Some((request_id, item)) = combined.next().await { + match item { + ToolStreamItem::Result(output) => { + if enable_extension_request_ids.contains(&request_id) && output.is_err(){ + all_install_successful = false; + } + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response(request_id, output); + }, + ToolStreamItem::Message(msg) => { + yield AgentEvent::McpNotification((request_id, msg)) + } } - let mut response = message_tool_response.lock().await; - *response = response.clone().with_tool_response(request_id, output); - }, - ToolStreamItem::Message(msg) => { - yield AgentEvent::McpNotification((request_id, msg)) + } + + // Update system prompt and tools if installations were successful + if all_install_successful { + (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; } } - } - // Update system prompt and tools if installations were successful - if all_install_successful { - (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; + let final_message_tool_resp = message_tool_response.lock().await.clone(); + yield AgentEvent::Message(final_message_tool_resp.clone()); + + added_message = true; + push_message(&mut messages, response); + push_message(&mut messages, final_message_tool_resp); + + // Check for MCP notifications from subagents again before next iteration + // Note: These are already handled as McpNotification events above, + // so we don't need to convert them to assistant messages here. + // This was causing duplicate plain-text notifications. + // let mcp_notifications = self.get_mcp_notifications().await; + // for notification in mcp_notifications { + // // Extract subagent info from the notification data for assistant messages + // if let JsonRpcMessage::Notification(ref notif) = notification { + // if let Some(params) = ¬if.params { + // if let Some(data) = params.get("data") { + // if let (Some(subagent_id), Some(message)) = ( + // data.get("subagent_id").and_then(|v| v.as_str()), + // data.get("message").and_then(|v| v.as_str()) + // ) { + // yield AgentEvent::Message( + // Message::assistant().with_text( + // format!("Subagent {}: {}", subagent_id, message) + // ) + // ); + // } + // } + // } + // } + // } } + }, + Err(ProviderError::ContextLengthExceeded(_)) => { + // At this point, the last message should be a user message + // because call to provider led to context length exceeded error + // Immediately yield a special message and break + yield AgentEvent::Message(Message::assistant().with_context_length_exceeded( + "The context length of the model has been exceeded. Please start a new session and try again.", + )); + break; + }, + Err(e) => { + // Create an error message & terminate the stream + error!("Error: {}", e); + yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."))); + break; } - - let final_message_tool_resp = message_tool_response.lock().await.clone(); - yield AgentEvent::Message(final_message_tool_resp.clone()); - - messages.push(response); - messages.push(final_message_tool_resp); - - // Check for MCP notifications from subagents again before next iteration - // Note: These are already handled as McpNotification events above, - // so we don't need to convert them to assistant messages here. - // This was causing duplicate plain-text notifications. - // let mcp_notifications = self.get_mcp_notifications().await; - // for notification in mcp_notifications { - // // Extract subagent info from the notification data for assistant messages - // if let JsonRpcMessage::Notification(ref notif) = notification { - // if let Some(params) = ¬if.params { - // if let Some(data) = params.get("data") { - // if let (Some(subagent_id), Some(message)) = ( - // data.get("subagent_id").and_then(|v| v.as_str()), - // data.get("message").and_then(|v| v.as_str()) - // ) { - // yield AgentEvent::Message( - // Message::assistant().with_text( - // format!("Subagent {}: {}", subagent_id, message) - // ) - // ); - // } - // } - // } - // } - // } - }, - Err(ProviderError::ContextLengthExceeded(_)) => { - // At this point, the last message should be a user message - // because call to provider led to context length exceeded error - // Immediately yield a special message and break - yield AgentEvent::Message(Message::assistant().with_context_length_exceeded( - "The context length of the model has been exceeded. Please start a new session and try again.", - )); - break; - }, - Err(e) => { - // Create an error message & terminate the stream - error!("Error: {}", e); - yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."))); - break; } } + if !added_message { + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { + if final_output_tool.final_output.is_none() { + tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); + yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE)); + continue; + } else { + yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + } + } + break; + } // Yield control back to the scheduler to prevent blocking tokio::task::yield_now().await; diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 573e9a93b11..486b09fbbff 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -2,10 +2,13 @@ use anyhow::Result; use std::collections::HashSet; use std::sync::Arc; +use async_stream::try_stream; +use futures::stream::StreamExt; + use crate::agents::router_tool_selector::RouterToolSelectionStrategy; use crate::config::Config; use crate::message::{Message, MessageContent, ToolRequest}; -use crate::providers::base::{Provider, ProviderUsage}; +use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; use crate::providers::errors::ProviderError; use crate::providers::toolshim::{ augment_message_with_tool_calls, convert_tool_messages_to_text, @@ -16,6 +19,19 @@ use mcp_core::tool::Tool; use super::super::agents::Agent; +async fn toolshim_postprocess( + response: Message, + toolshim_tools: &[Tool], +) -> Result { + let interpreter = OllamaInterpreter::new().map_err(|e| { + ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e)) + })?; + + augment_message_with_tool_calls(&interpreter, response, toolshim_tools) + .await + .map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e))) +} + impl Agent { /// Prepares tools and system prompt for a provider request pub(crate) async fn prepare_tools_and_prompt( @@ -128,25 +144,67 @@ impl Agent { .complete(system_prompt, &messages_for_provider, tools) .await?; - // Store the model information in the global store crate::providers::base::set_current_model(&usage.model); - // Post-process / structure the response only if tool interpretation is enabled if config.toolshim { - let interpreter = OllamaInterpreter::new().map_err(|e| { - ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e)) - })?; - - response = augment_message_with_tool_calls(&interpreter, response, toolshim_tools) - .await - .map_err(|e| { - ProviderError::ExecutionError(format!("Failed to augment message: {}", e)) - })?; + response = toolshim_postprocess(response, toolshim_tools).await?; } Ok((response, usage)) } + /// Stream a response from the LLM provider. + /// Handles toolshim transformations if needed + pub(crate) async fn stream_response_from_provider( + provider: Arc, + system_prompt: &str, + messages: &[Message], + tools: &[Tool], + toolshim_tools: &[Tool], + ) -> Result { + let config = provider.get_model_config(); + + // Convert tool messages to text if toolshim is enabled + let messages_for_provider = if config.toolshim { + convert_tool_messages_to_text(messages) + } else { + messages.to_vec() + }; + + // Clone owned data to move into the async stream + let system_prompt = system_prompt.to_owned(); + let tools = tools.to_owned(); + let toolshim_tools = toolshim_tools.to_owned(); + let provider = provider.clone(); + + let mut stream = if provider.supports_streaming() { + provider + .stream(system_prompt.as_str(), &messages_for_provider, &tools) + .await? + } else { + let (message, usage) = provider + .complete(system_prompt.as_str(), &messages_for_provider, &tools) + .await?; + stream_from_single_message(message, usage) + }; + + Ok(Box::pin(try_stream! { + while let Some(Ok((mut message, usage))) = stream.next().await { + // Store the model information in the global store + if let Some(usage) = usage.as_ref() { + crate::providers::base::set_current_model(&usage.model); + } + + // Post-process / structure the response only if tool interpretation is enabled + if message.is_some() && config.toolshim { + message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?); + } + + yield (message, usage); + } + })) + } + /// Categorize tool requests from the response into different types /// Returns: /// - frontend_requests: Tool requests that should be handled by the frontend @@ -191,6 +249,7 @@ impl Agent { } let filtered_message = Message { + id: response.id.clone(), role: response.role.clone(), created: response.created, content: filtered_content, diff --git a/crates/goose/src/context_mgmt/summarize.rs b/crates/goose/src/context_mgmt/summarize.rs index 75fe05b534c..84ea104bb9b 100644 --- a/crates/goose/src/context_mgmt/summarize.rs +++ b/crates/goose/src/context_mgmt/summarize.rs @@ -247,14 +247,14 @@ mod tests { _tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: "Summarized content".to_string(), annotations: None, })], - }, + ), ProviderUsage::new("mock".to_string(), Usage::default()), )) } @@ -277,30 +277,26 @@ mod tests { } fn set_up_text_message(text: &str, role: Role) -> Message { - Message { - role, - created: 0, - content: vec![MessageContent::text(text.to_string())], - } + Message::new(role, 0, vec![MessageContent::text(text.to_string())]) } fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message { - Message { - role: Role::Assistant, - created: 0, - content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], - } + Message::new( + Role::Assistant, + 0, + vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], + ) } fn set_up_tool_response_message(id: &str, tool_response: Vec) -> Message { - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::tool_response( + Message::new( + Role::User, + 0, + vec![MessageContent::tool_response( id.to_string(), Ok(tool_response), )], - } + ) } #[tokio::test] @@ -448,14 +444,14 @@ mod tests { #[tokio::test] async fn test_reintegrate_removed_messages() { - let summarized_messages = vec![Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + let summarized_messages = vec![Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: "Summary".to_string(), annotations: None, })], - }]; + )]; let arguments = json!({ "param1": "value1" }); diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 87c4ae9a247..ef207eb1d01 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -303,15 +303,46 @@ impl From for Message { /// A message to or from an LLM #[serde(rename_all = "camelCase")] pub struct Message { + pub id: Option, pub role: Role, pub created: i64, pub content: Vec, } +pub fn push_message(messages: &mut Vec, message: Message) { + if let Some(last) = messages + .last_mut() + .filter(|m| m.id.is_some() && m.id == message.id) + { + match (last.content.last_mut(), message.content.last()) { + (Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new))) + if message.content.len() == 1 => + { + last.text.push_str(&new.text); + } + (_, _) => { + last.content.extend(message.content); + } + } + } else { + messages.push(message); + } +} + impl Message { + pub fn new(role: Role, created: i64, content: Vec) -> Self { + Message { + id: None, + role, + created, + content, + } + } + /// Create a new user message with the current timestamp pub fn user() -> Self { Message { + id: None, role: Role::User, created: Utc::now().timestamp(), content: Vec::new(), @@ -321,6 +352,7 @@ impl Message { /// Create a new assistant message with the current timestamp pub fn assistant() -> Self { Message { + id: None, role: Role::Assistant, created: Utc::now().timestamp(), content: Vec::new(), diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 368ce82fb43..b8b5110e182 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -81,10 +81,10 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { }) .collect(); let mut check_messages = vec![]; - check_messages.push(Message { - role: mcp_core::Role::User, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + check_messages.push(Message::new( + mcp_core::Role::User, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: format!( "Here are the tool requests: {:?}\n\nAnalyze the tool requests and list the tools that perform read-only operations. \ \n\nGuidelines for Read-Only Operations: \ @@ -96,7 +96,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec { ), annotations: None, })], - }); + )); check_messages } @@ -296,10 +296,10 @@ mod tests { _tools: &[Tool], ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::ToolRequest(ToolRequest { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::ToolRequest(ToolRequest { id: "mock_tool_request".to_string(), tool_call: ToolResult::Ok(ToolCall { name: "platform__tool_by_tool_permission".to_string(), @@ -308,7 +308,7 @@ mod tests { }), }), })], - }, + ), ProviderUsage::new("mock".to_string(), Usage::default()), )) } @@ -354,10 +354,10 @@ mod tests { #[test] fn test_extract_read_only_tools() { - let message = Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::ToolRequest(ToolRequest { + let message = Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::ToolRequest(ToolRequest { id: "tool_2".to_string(), tool_call: ToolResult::Ok(ToolCall { name: "platform__tool_by_tool_permission".to_string(), @@ -366,7 +366,7 @@ mod tests { }), }), })], - }; + ); let result = extract_read_only_tools(&message); assert!(result.is_some()); diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 09a5ef08e40..c3510fe8800 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use futures::Stream; use serde::{Deserialize, Serialize}; use super::errors::ProviderError; @@ -8,6 +9,8 @@ use mcp_core::tool::Tool; use utoipa::ToSchema; use once_cell::sync::Lazy; +use std::ops::{Add, AddAssign}; +use std::pin::Pin; use std::sync::Mutex; /// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias @@ -184,13 +187,43 @@ impl ProviderUsage { } } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)] pub struct Usage { pub input_tokens: Option, pub output_tokens: Option, pub total_tokens: Option, } +fn sum_optionals(a: Option, b: Option) -> Option +where + T: Add + Default, +{ + match (a, b) { + (Some(x), Some(y)) => Some(x + y), + (Some(x), None) => Some(x + T::default()), + (None, Some(y)) => Some(T::default() + y), + (None, None) => None, + } +} + +impl Add for Usage { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self { + input_tokens: sum_optionals(self.input_tokens, other.input_tokens), + output_tokens: sum_optionals(self.output_tokens, other.output_tokens), + total_tokens: sum_optionals(self.total_tokens, other.total_tokens), + } + } +} + +impl AddAssign for Usage { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + impl Usage { pub fn new( input_tokens: Option, @@ -270,6 +303,21 @@ pub trait Provider: Send + Sync { None } + async fn stream( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result { + Err(ProviderError::NotImplemented( + "streaming not implemented".to_string(), + )) + } + + fn supports_streaming(&self) -> bool { + false + } + /// Get the currently active model name /// For regular providers, this returns the configured model /// For LeadWorkerProvider, this returns the currently active model (lead or worker) @@ -282,6 +330,18 @@ pub trait Provider: Send + Sync { } } +/// A message stream yields partial text content but complete tool calls, all within the Message object +/// So a message with text will contain potentially just a word of a longer response, but tool calls +/// messages will only be yielded once concatenated. +pub type MessageStream = Pin< + Box, Option), ProviderError>> + Send>, +>; + +pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream { + let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) }); + Box::pin(stream) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index 823afb6a71a..0ee20336017 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -219,11 +219,11 @@ impl ClaudeCodeProvider { annotations: None, })]; - let response_message = Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: message_content, - }; + let response_message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + message_content, + ); Ok((response_message, usage)) } @@ -353,14 +353,14 @@ impl ClaudeCodeProvider { println!("================================"); } - let message = Message { - role: mcp_core::Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![MessageContent::Text(mcp_core::content::TextContent { + let message = Message::new( + mcp_core::Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::Text(mcp_core::content::TextContent { text: description.clone(), annotations: None, })], - }; + ); let usage = Usage::default(); diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index fbfd22a0a63..e074d9393af 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,4 +1,16 @@ -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use anyhow::Result; +use async_stream::try_stream; +use async_trait::async_trait; +use futures::TryStreamExt; +use reqwest::{Client, StatusCode}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::io; +use std::time::Duration; +use tokio::pin; +use tokio_util::io::StreamReader; + +use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::formats::databricks::{create_request, get_usage, response_to_message}; @@ -7,17 +19,13 @@ use super::utils::{get_model, ImageFormat}; use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::formats::databricks::response_to_streaming_message; use mcp_core::tool::Tool; use serde_json::json; -use url::Url; - -use anyhow::Result; -use async_trait::async_trait; -use reqwest::{Client, StatusCode}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::time::Duration; use tokio::time::sleep; +use tokio_stream::StreamExt; +use tokio_util::codec::{FramedRead, LinesCodec}; +use url::Url; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; @@ -266,9 +274,6 @@ impl DatabricksProvider { } async fn post(&self, payload: Value) -> Result { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - // Check if this is an embedding request by looking at the payload structure let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); let path = if is_embedding { @@ -279,56 +284,71 @@ impl DatabricksProvider { format!("serving-endpoints/{}/invocations", self.model.model_name) }; - let url = base_url.join(&path).map_err(|e| { + match self.post_with_retry(path.as_str(), &payload).await { + Ok(res) => res.json().await.map_err(|_| { + ProviderError::RequestFailed("Response body is not valid JSON".to_string()) + }), + Err(e) => Err(e), + } + } + + async fn post_with_retry( + &self, + path: &str, + payload: &Value, + ) -> Result { + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join(path).map_err(|e| { ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) })?; - // Initialize retry counter let mut attempts = 0; - let mut last_error = None; - loop { - // Check if we've exceeded max retries - if attempts > 0 && attempts > self.retry_config.max_retries { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting (429)", - self.retry_config.max_retries - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - let auth_header = self.ensure_auth_header().await?; let response = self .client .post(url.clone()) .header("Authorization", auth_header) - .json(&payload) + .json(payload) .send() .await?; let status = response.status(); - let payload: Option = response.json().await.ok(); - match status { - StatusCode::OK => { - return payload.ok_or_else(|| { - ProviderError::RequestFailed("Response body is not valid JSON".to_string()) - }); - } - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - return Err(ProviderError::Authentication(format!( - "Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", - status, payload - ))); + break match status { + StatusCode::OK => Ok(response), + StatusCode::TOO_MANY_REQUESTS + | StatusCode::INTERNAL_SERVER_ERROR + | StatusCode::SERVICE_UNAVAILABLE => { + if attempts < self.retry_config.max_retries { + attempts += 1; + tracing::warn!( + "{}: retrying ({}/{})", + status, + attempts, + self.retry_config.max_retries + ); + + let delay = self.retry_config.delay_for_attempt(attempts); + tracing::info!("Backing off for {:?} before retry", delay); + sleep(delay).await; + + continue; + } + + Err(match status { + StatusCode::TOO_MANY_REQUESTS => { + ProviderError::RateLimitExceeded("Rate limit exceeded".to_string()) + } + _ => ProviderError::ServerError("Server error".to_string()), + }) } StatusCode::BAD_REQUEST => { // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific // We try to extract the error message from the payload and check for phrases that indicate context length exceeded - let payload_str = serde_json::to_string(&payload) - .unwrap_or_default() - .to_lowercase(); + let bytes = response.bytes().await?; + let payload_str = String::from_utf8_lossy(&bytes).to_lowercase(); let check_phrases = [ "too long", "context length", @@ -347,13 +367,13 @@ impl DatabricksProvider { } let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { + if let Ok(response_json) = serde_json::from_slice::(&bytes) { // try to convert message to string, if that fails use external_model_message - error_msg = payload + error_msg = response_json .get("message") .and_then(|m| m.as_str()) .or_else(|| { - payload + response_json .get("external_model_message") .and_then(|ext| ext.get("message")) .and_then(|m| m.as_str()) @@ -366,7 +386,7 @@ impl DatabricksProvider { "{}", format!( "Provider request failed with status: {}. Payload: {:?}", - status, payload + status, payload_str ) ); return Err(ProviderError::RequestFailed(format!( @@ -374,50 +394,13 @@ impl DatabricksProvider { status, error_msg ))); } - StatusCode::TOO_MANY_REQUESTS => { - attempts += 1; - let error_msg = format!( - "Rate limit exceeded (attempt {}/{}): {:?}", - attempts, self.retry_config.max_retries, payload - ); - tracing::warn!("{}. Retrying after backoff...", error_msg); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::RateLimitExceeded(error_msg)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; - - // Continue to the next retry attempt - continue; - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - attempts += 1; - let error_msg = format!( - "Server error (attempt {}/{}): {:?}", - attempts, self.retry_config.max_retries, payload - ); - tracing::warn!("{}. Retrying after backoff...", error_msg); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::ServerError(error_msg)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; - - // Continue to the next retry attempt - continue; - } _ => { tracing::debug!( "{}", format!( "Provider request failed with status: {}. Payload: {:?}", - status, payload + status, + response.text().await.ok().unwrap_or_default() ) ); return Err(ProviderError::RequestFailed(format!( @@ -425,7 +408,7 @@ impl DatabricksProvider { status ))); } - } + }; } } } @@ -472,13 +455,12 @@ impl Provider for DatabricksProvider { // Parse response let message = response_to_message(response.clone())?; - let usage = match get_usage(&response) { - Ok(usage) => usage, - Err(ProviderError::UsageError(e)) => { - tracing::debug!("Failed to get usage data: {}", e); + let usage = match response.get("usage").map(get_usage) { + Some(usage) => usage, + None => { + tracing::debug!("Failed to get usage data"); Usage::default() } - Err(e) => return Err(e), }; let model = get_model(&response); super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); @@ -486,6 +468,54 @@ impl Provider for DatabricksProvider { Ok((message, ProviderUsage::new(model, usage))) } + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?; + // Remove the model key which is part of the url with databricks + payload + .as_object_mut() + .expect("payload should have model key") + .remove("model"); + + payload + .as_object_mut() + .unwrap() + .insert("stream".to_string(), Value::Bool(true)); + + let response = self + .post_with_retry( + format!("serving-endpoints/{}/invocations", self.model.model_name).as_str(), + &payload, + ) + .await?; + + // Map reqwest error to io::Error + let stream = response.bytes_stream().map_err(io::Error::other); + + let model_config = self.model.clone(); + // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } + + fn supports_streaming(&self) -> bool { + true + } + fn supports_embeddings(&self) -> bool { true } diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index c9f867c41f0..3ff7d1880ca 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -23,6 +23,9 @@ pub enum ProviderError { #[error("Usage data error: {0}")] UsageError(String), + + #[error("Unsupported operation: {0}")] + NotImplemented(String), } impl From for ProviderError { diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 6c6f0f9b605..e627505b657 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -212,17 +212,17 @@ mod tests { _tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: format!( "Response from {} with model {}", self.name, self.model_config.model_name ), annotations: None, })], - }, + ), ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()), )) } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 29b3491585d..9a1651b57fc 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -260,11 +260,7 @@ pub fn from_bedrock_message(message: &bedrock::Message) -> Result { .collect::>>()?; let created = Utc::now().timestamp(); - Ok(Message { - role, - content, - created, - }) + Ok(Message::new(role, created, content)) } pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result { diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index c74c4cbbe03..8c462b624e9 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -1,14 +1,16 @@ use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; +use crate::providers::base::{ProviderUsage, Usage}; use crate::providers::utils::{ convert_image, detect_image_path, is_valid_function_name, load_image_file, sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; +use async_stream::try_stream; +use futures::Stream; use mcp_core::ToolError; use mcp_core::{Content, Role, Tool, ToolCall}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; /// Convert internal Message format to Databricks' API message specification @@ -358,18 +360,162 @@ pub fn response_to_message(response: Value) -> anyhow::Result { } } - Ok(Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), + Ok(Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), content, - }) + )) } -pub fn get_usage(data: &Value) -> Result { - let usage = data - .get("usage") - .ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?; +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCallFunction { + name: Option, + arguments: String, // chunk of encoded JSON, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DeltaToolCall { + id: Option, + function: DeltaToolCallFunction, + index: Option, + r#type: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct Delta { + content: Option, + role: Option, + tool_calls: Option>, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChoice { + delta: Delta, + index: Option, + finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +struct StreamingChunk { + choices: Vec, + created: Option, + id: Option, + usage: Option, + model: String, +} + +fn strip_data_prefix(line: &str) -> Option<&str> { + line.strip_prefix("data: ").map(|s| s.trim()) +} + +pub fn response_to_streaming_message( + mut stream: S, +) -> impl Stream, Option)>> + 'static +where + S: Stream> + Unpin + Send + 'static, +{ + try_stream! { + use futures::StreamExt; + + 'outer: while let Some(response) = stream.next().await { + if response.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response?; + let line = strip_data_prefix(&response_str); + + if line.is_none() || line.is_some_and(|l| l.is_empty()) { + continue + } + + let chunk: StreamingChunk = serde_json::from_str(line + .ok_or_else(|| anyhow!("unexpected stream format"))?) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let model = chunk.model.clone(); + + let usage = chunk.usage.as_ref().map(|u| { + ProviderUsage { + usage: get_usage(u), + model, + } + }); + + if chunk.choices.is_empty() { + yield (None, usage) + } else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls { + let tool_call = &tool_calls[0]; + let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?; + let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?; + let mut arguments = tool_call.function.arguments.clone(); + + while let Some(response_chunk) = stream.next().await { + if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") { + break 'outer; + } + let response_str = response_chunk?; + if let Some(line) = strip_data_prefix(&response_str) { + let tool_chunk: StreamingChunk = serde_json::from_str(line) + .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref() + .and_then(|calls| calls.first()) + .map(|call| call.function.arguments.as_str()); + if let Some(more_args) = more_args { + arguments.push_str(more_args); + } else { + break; + } + } + } + + let parsed = if arguments.is_empty() { + Ok(json!({})) + } else { + serde_json::from_str::(&arguments) + }; + + let content = match parsed { + Ok(params) => MessageContent::tool_request( + id, + Ok(ToolCall::new(function_name, params)), + ), + Err(e) => { + let error = ToolError::InvalidParameters(format!( + "Could not interpret tool use parameters for id {}: {}", + id, e + )); + MessageContent::tool_request(id, Err(error)) + } + }; + + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![content], + }), + usage, + ) + } else if let Some(text) = &chunk.choices[0].delta.content { + yield ( + Some(Message { + id: chunk.id, + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text(text)], + }), + if chunk.choices[0].finish_reason.is_some() { + usage + } else { + None + }, + ) + } + } + } +} +pub fn get_usage(usage: &Value) -> Usage { let input_tokens = usage .get("prompt_tokens") .and_then(|v| v.as_i64()) @@ -389,7 +535,7 @@ pub fn get_usage(data: &Value) -> Result { _ => None, }); - Ok(Usage::new(input_tokens, output_tokens, total_tokens)) + Usage::new(input_tokens, output_tokens, total_tokens) } /// Validates and fixes tool schemas to ensure they have proper parameter structure. diff --git a/crates/goose/src/providers/formats/google.rs b/crates/goose/src/providers/formats/google.rs index 47b774df5a6..6c801d50453 100644 --- a/crates/goose/src/providers/formats/google.rs +++ b/crates/goose/src/providers/formats/google.rs @@ -209,11 +209,7 @@ pub fn response_to_message(response: Value) -> Result { let role = Role::Assistant; let created = chrono::Utc::now().timestamp(); if candidate.is_none() { - return Ok(Message { - role, - created, - content, - }); + return Ok(Message::new(role, created, content)); } let candidate = candidate.unwrap(); let parts = candidate @@ -252,11 +248,7 @@ pub fn response_to_message(response: Value) -> Result { } } } - Ok(Message { - role, - created, - content, - }) + Ok(Message::new(role, created, content)) } /// Extract usage information from Google's API response @@ -324,43 +316,39 @@ mod tests { use serde_json::json; fn set_up_text_message(text: &str, role: Role) -> Message { - Message { - role, - created: 0, - content: vec![MessageContent::text(text.to_string())], - } + Message::new(role, 0, vec![MessageContent::text(text.to_string())]) } fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message { - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], - } + Message::new( + Role::User, + 0, + vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))], + ) } fn set_up_tool_confirmation_message(id: &str, tool_call: ToolCall) -> Message { - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::tool_confirmation_request( + Message::new( + Role::User, + 0, + vec![MessageContent::tool_confirmation_request( id.to_string(), tool_call.name.clone(), tool_call.arguments.clone(), Some("Goose would like to call the above tool. Allow? (y/n):".to_string()), )], - } + ) } fn set_up_tool_response_message(id: &str, tool_response: Vec) -> Message { - Message { - role: Role::Assistant, - created: 0, - content: vec![MessageContent::tool_response( + Message::new( + Role::Assistant, + 0, + vec![MessageContent::tool_response( id.to_string(), Ok(tool_response), )], - } + ) } fn set_up_tool(name: &str, description: &str, params: Value) -> Tool { diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 402e09cfee9..ce929253405 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -274,11 +274,11 @@ pub fn response_to_message(response: Value) -> anyhow::Result { } } - Ok(Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), + Ok(Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), content, - }) + )) } pub fn get_usage(data: &Value) -> Result { diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 6308f902405..fc696bab21a 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -169,14 +169,14 @@ impl GeminiCliProvider { )); } - let message = Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + let message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: response_text, annotations: None, })], - }; + ); let usage = Usage::default(); // No usage info available for gemini CLI @@ -214,14 +214,14 @@ impl GeminiCliProvider { println!("================================"); } - let message = Message { - role: Role::Assistant, - created: chrono::Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + let message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: description.clone(), annotations: None, })], - }; + ); let usage = Usage::default(); diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index 8e4552d5e18..ea892342c8d 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -480,14 +480,14 @@ mod tests { _tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: format!("Response from {}", self.name), annotations: None, })], - }, + ), ProviderUsage::new(self.name.clone(), Usage::default()), )) } @@ -643,14 +643,14 @@ mod tests { )) } else { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: format!("Response from {}", self.name), annotations: None, })], - }, + ), ProviderUsage::new(self.name.clone(), Usage::default()), )) } diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index 19b95deb83f..c2ced2a57fd 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -203,14 +203,14 @@ impl SageMakerTgiProvider { // Strip any HTML tags that might have been generated let clean_text = self.strip_html_tags(generated_text); - Ok(Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Ok(Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: clean_text, annotations: None, })], - }) + )) } /// Strip HTML tags from text to ensure clean output diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index 2827caa71d9..0647d0a06e7 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -359,11 +359,7 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec { } if has_tool_content { - Message { - role: message.role.clone(), - content: new_content, - created: message.created, - } + Message::new(message.role.clone(), message.created, new_content) } else { message.clone() } diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index c73023b42e7..7bcc172b645 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -319,12 +319,15 @@ pub fn unescape_json_values(value: &Value) -> Value { } } -pub fn emit_debug_trace( +pub fn emit_debug_trace( model_config: &ModelConfig, - payload: &Value, - response: &Value, + payload: &T1, + response: &T2, usage: &Usage, -) { +) where + T1: ?Sized + Serialize, + T2: ?Sized + Serialize, +{ tracing::debug!( model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(), input = %serde_json::to_string_pretty(payload).unwrap_or_default(), diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 81ee9c0b85a..9046cdd58c9 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -557,11 +557,7 @@ impl Provider for VeniceProvider { }; Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content, - }, + Message::new(Role::Assistant, Utc::now().timestamp(), content), ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage), )) } diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 64648722545..20c455deb9b 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1370,14 +1370,14 @@ mod tests { _tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { Ok(( - Message { - role: Role::Assistant, - created: Utc::now().timestamp(), - content: vec![MessageContent::Text(TextContent { + Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { text: "Mocked scheduled response".to_string(), annotations: None, })], - }, + ), ProviderUsage::new("mock-scheduler-test".to_string(), Usage::default()), )) } diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index b176511f300..689006fa990 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -1526,7 +1526,7 @@ mod tests { "]}}\"\\n\\\"{[", "Edge case: } ] some text", "{\"foo\": \"} ]\"}", - "}]", + "}]", ]; let mut messages = Vec::new(); diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index c83ab813243..6db01dd3ced 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -407,25 +407,25 @@ mod tests { "You are a helpful assistant that can answer questions about the weather."; let messages = vec![ - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::text( + Message::new( + Role::User, + 0, + vec![MessageContent::text( "What's the weather like in San Francisco?", )], - }, - Message { - role: Role::Assistant, - created: 1, - content: vec![MessageContent::text( + ), + Message::new( + Role::Assistant, + 1, + vec![MessageContent::text( "Looks like it's 60 degrees Fahrenheit in San Francisco.", )], - }, - Message { - role: Role::User, - created: 2, - content: vec![MessageContent::text("How about New York?")], - }, + ), + Message::new( + Role::User, + 2, + vec![MessageContent::text("How about New York?")], + ), ]; let tools = vec![Tool { @@ -505,25 +505,25 @@ mod tests { "You are a helpful assistant that can answer questions about the weather."; let messages = vec![ - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::text( + Message::new( + Role::User, + 0, + vec![MessageContent::text( "What's the weather like in San Francisco?", )], - }, - Message { - role: Role::Assistant, - created: 1, - content: vec![MessageContent::text( + ), + Message::new( + Role::Assistant, + 1, + vec![MessageContent::text( "Looks like it's 60 degrees Fahrenheit in San Francisco.", )], - }, - Message { - role: Role::User, - created: 2, - content: vec![MessageContent::text("How about New York?")], - }, + ), + Message::new( + Role::User, + 2, + vec![MessageContent::text("How about New York?")], + ), ]; let tools = vec![Tool { diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 6c34d069c53..4bb80d1879c 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -786,7 +786,7 @@ function ChatContent({ {filteredMessages.map((message, index) => (
diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index 12568d28ea8..cb09e3da127 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -130,7 +130,7 @@ export default function GooseMessage({ ]); return ( -
+
{/* Chain-of-Thought (hidden by default) */} {cotText && ( diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index 3e38238631a..325c9630b81 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -1,4 +1,4 @@ -import { useState, useCallback, useEffect, useRef, useId } from 'react'; +import { useState, useCallback, useEffect, useRef, useId, useReducer } from 'react'; import useSWR from 'swr'; import { getSecretKey } from '../config'; import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message'; @@ -235,6 +235,9 @@ export function useMessageStream({ }; }, [headers, body]); + // TODO: not this? + const [, forceUpdate] = useReducer((x) => x + 1, 0); + // Process the SSE stream from the server const processMessageStream = useCallback( async (response: Response, currentMessages: Message[]) => { @@ -284,8 +287,23 @@ export function useMessageStream({ : parsedEvent.message.sendToLLM, }; + console.log('New message:', JSON.stringify(newMessage, null, 2)); + // Update messages with the new message - currentMessages = [...currentMessages, newMessage]; + + if ( + newMessage.id && + currentMessages.length > 0 && + currentMessages[currentMessages.length - 1].id === newMessage.id + ) { + // If the last message has the same ID, update it instead of adding a new one + const lastMessage = currentMessages[currentMessages.length - 1]; + lastMessage.content = [...lastMessage.content, ...newMessage.content]; + forceUpdate(); + } else { + currentMessages = [...currentMessages, newMessage]; + } + mutate(currentMessages, false); break; } @@ -373,7 +391,7 @@ export function useMessageStream({ return currentMessages; }, - [mutate, onFinish, onError] + [mutate, onFinish, onError, forceUpdate] ); // Send a request to the server diff --git a/ui/desktop/src/types/message.ts b/ui/desktop/src/types/message.ts index 4e52b6cf785..1a83938c897 100644 --- a/ui/desktop/src/types/message.ts +++ b/ui/desktop/src/types/message.ts @@ -201,7 +201,7 @@ export function getTextContent(message: Message): string { } return ''; }) - .join('\n'); + .join(''); } export function getToolRequests(message: Message): ToolRequestMessageContent[] {