diff --git a/.envrc b/.envrc
deleted file mode 100644
index af0cc9383e34..000000000000
--- a/.envrc
+++ /dev/null
@@ -1,2 +0,0 @@
-#!/usr/bin/env bash
-use flake
diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs
index 78b2c83b0d9b..407031560cfd 100644
--- a/crates/goose-acp/src/server.rs
+++ b/crates/goose-acp/src/server.rs
@@ -1289,20 +1289,14 @@ print(\"hello, world\")
         "mock"
     }

-    async fn complete_with_model(
+    async fn stream(
        &self,
-        _session_id: Option<&str>,
        _model_config: &goose::model::ModelConfig,
+        _session_id: &str,
        _system: &str,
        _messages: &[goose::conversation::message::Message],
        _tools: &[rmcp::model::Tool],
-    ) -> Result<
-        (
-            goose::conversation::message::Message,
-            goose::providers::base::ProviderUsage,
-        ),
-        ProviderError,
-    > {
+    ) -> Result<goose::providers::base::MessageStream, ProviderError> {
        unimplemented!()
    }
diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs
index ef7bcb01e33f..f6056fcbcb9b 100644
--- a/crates/goose-cli/src/commands/configure.rs
+++ b/crates/goose-cli/src/commands/configure.rs
@@ -1673,11 +1673,11 @@ pub async fn handle_openrouter_auth() -> anyhow::Result<()> {
     match create("openrouter", model_config, Vec::new()).await {
         Ok(provider) => {
-            let model_config = provider.get_model_config();
+            let provider_model_config = provider.get_model_config();
             let test_result = provider
-                .complete_with_model(
-                    None,
-                    &model_config,
+                .complete(
+                    &provider_model_config,
+                    "",
                     "You are goose, an AI assistant.",
                     &[Message::user().with_text("Say 'Configuration test successful!'")],
                     &[],
diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs
index 853e63986594..bfa5c8700cb0 100644
--- a/crates/goose-cli/src/session/mod.rs
+++ b/crates/goose-cli/src/session/mod.rs
@@ -211,8 +211,10 @@ pub async fn classify_planner_response(
     let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions.
 Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}");
     let message = Message::user().with_text(&prompt);

+    let model_config = provider.get_model_config();
     let (result, _usage) = provider
         .complete(
+            &model_config,
             session_id,
             "Reply only with the classification label: \"plan\" or \"clarifying questions\"",
             &[message],
@@ -840,8 +842,10 @@ impl CliSession {
     ) -> Result<(), anyhow::Error> {
         let plan_prompt = self.agent.get_plan_prompt(&self.session_id).await?;
         output::show_thinking();
+        let model_config = reasoner.get_model_config();
         let (plan_response, _usage) = reasoner
             .complete(
+                &model_config,
                 &self.session_id,
                 &plan_prompt,
                 plan_messages.messages(),
diff --git a/crates/goose/examples/databricks_oauth.rs b/crates/goose/examples/databricks_oauth.rs
index a263fa8d1d69..ce7b53842c1a 100644
--- a/crates/goose/examples/databricks_oauth.rs
+++ b/crates/goose/examples/databricks_oauth.rs
@@ -8,21 +8,18 @@ use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL;
 async fn main() -> Result<()> {
     dotenv().ok();

-    // Clear any token to force OAuth
     std::env::remove_var("DATABRICKS_TOKEN");

-    // Create the provider
     let provider =
         create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL, Vec::new()).await?;

-    // Create a simple message
     let message = Message::user().with_text("Tell me a short joke about programming.");

-    // Get a response
+    let model_config = provider.get_model_config();
     let (response, usage) = provider
-        .complete_with_model(
-            None,
-            &provider.get_model_config(),
+        .complete(
+            &model_config,
+            "",
             "You are a helpful assistant.",
             &[message],
             &[],
diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs
index 1fe66f92e576..df7c2e6cdfba 100644
--- a/crates/goose/examples/image_tool.rs
+++ b/crates/goose/examples/image_tool.rs
@@ -62,10 +62,11 @@ async fn main() -> Result<()> {
             },
         }
     });
+    let model_config = provider.get_model_config();
     let (response, usage) = provider
-        .complete_with_model(
-            None,
-            &provider.get_model_config(),
+        .complete(
+            &model_config,
+            "",
             "You are a helpful assistant. Please describe any text you see in the image.",
             &messages,
             &[Tool::new("view_image", "View an image", input_schema)],
diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs
index 9ed9f8838731..18ec20187c1c 100644
--- a/crates/goose/src/agents/agent.rs
+++ b/crates/goose/src/agents/agent.rs
@@ -1777,6 +1777,15 @@ impl Agent {
         );

         tracing::info!("Calling provider to generate recipe content");
+        let model_config = {
+            let provider_guard = self.provider.lock().await;
+            let provider = provider_guard.as_ref().ok_or_else(|| {
+                let error = anyhow!("Provider not available during recipe creation");
+                tracing::error!("{}", error);
+                error
+            })?;
+            provider.get_model_config()
+        };
         let (result, _usage) = self
             .provider
             .lock()
@@ -1787,7 +1796,13 @@ impl Agent {
                 tracing::error!("{}", error);
                 error
             })?
-            .complete(session_id, &system_prompt, messages.messages(), &tools)
+            .complete(
+                &model_config,
+                session_id,
+                &system_prompt,
+                messages.messages(),
+                &tools,
+            )
             .await
             .map_err(|e| {
                 tracing::error!("Provider completion failed during recipe creation: {}", e);
diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs
index fb585c2b8bb4..5359eca52d79 100644
--- a/crates/goose/src/agents/mcp_client.rs
+++ b/crates/goose/src/agents/mcp_client.rs
@@ -242,10 +242,11 @@ impl ClientHandler for GooseClient {
             .as_deref()
             .unwrap_or("You are a general-purpose AI agent called goose");

+        let model_config = provider.get_model_config();
         let (response, usage) = provider
-            .complete_with_model(
-                session_id.as_deref(),
-                &provider.get_model_config(),
+            .complete(
+                &model_config,
+                session_id.as_deref().unwrap_or(""),
                 system_prompt,
                 &provider_ready_messages,
                 &[],
diff --git a/crates/goose/src/agents/platform_extensions/apps.rs b/crates/goose/src/agents/platform_extensions/apps.rs
index 46ea1cb94226..1bba33cf10c5 100644
--- a/crates/goose/src/agents/platform_extensions/apps.rs
+++ b/crates/goose/src/agents/platform_extensions/apps.rs
@@ -294,13 +294,7 @@ impl AppsManagerClient {
         model_config.max_tokens = Some(16384);

         let (response, _usage) = provider
-            .complete_with_model(
-                Some(session_id),
-                &model_config,
-                &system_prompt,
-                &messages,
-                &tools,
-            )
+            .complete(&model_config, session_id, &system_prompt, &messages, &tools)
             .await
             .map_err(|e| format!("LLM call failed: {}", e))?;

@@ -334,13 +328,7 @@ impl AppsManagerClient {
         model_config.max_tokens = Some(16384);

         let (response, _usage) = provider
-            .complete_with_model(
-                Some(session_id),
-                &model_config,
-                &system_prompt,
-                &messages,
-                &tools,
-            )
+            .complete(&model_config, session_id, &system_prompt, &messages, &tools)
             .await
             .map_err(|e| format!("LLM call failed: {}", e))?;
diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs
index 54e6d81ae8ed..4dae82a889fb 100644
--- a/crates/goose/src/agents/reply_parts.rs
+++ b/crates/goose/src/agents/reply_parts.rs
@@ -11,7 +11,9 @@ use super::super::agents::Agent;
 use crate::agents::platform_extensions::code_execution;
 use crate::conversation::message::{Message, MessageContent, ToolRequest};
 use crate::conversation::Conversation;
-use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
+#[cfg(test)]
+use crate::providers::base::stream_from_single_message;
+use crate::providers::base::{MessageStream, Provider, ProviderUsage};
 use crate::providers::errors::ProviderError;
 use crate::providers::toolshim::{
     augment_message_with_tool_calls, convert_tool_messages_to_text,
@@ -229,35 +231,18 @@ impl Agent {

         // Capture errors during stream creation and return them as part of the stream
         // so they can be handled by the existing error handling logic in the agent
-        let stream_result = if provider.supports_streaming() {
-            debug!("WAITING_LLM_STREAM_START");
-            let result = provider
-                .stream(
-                    session_id,
-                    system_prompt.as_str(),
-                    messages_for_provider.messages(),
-                    &tools,
-                )
-                .await;
-            debug!("WAITING_LLM_STREAM_END");
-            result
-        } else {
-            debug!("WAITING_LLM_START");
-            let complete_result = provider
-                .complete(
-                    session_id,
-                    system_prompt.as_str(),
-                    messages_for_provider.messages(),
-                    &tools,
-                )
-                .await;
-            debug!("WAITING_LLM_END");
-
-            match complete_result {
-                Ok((message, usage)) => Ok(stream_from_single_message(message, usage)),
-                Err(e) => Err(e),
-            }
-        };
+        let model_config = provider.get_model_config();
+        debug!("WAITING_LLM_STREAM_START");
+        let stream_result = provider
+            .stream(
+                &model_config,
+                session_id,
+                system_prompt.as_str(),
+                messages_for_provider.messages(),
+                &tools,
+            )
+            .await;
+        debug!("WAITING_LLM_STREAM_END");

         // If there was an error creating the stream, return a stream that yields that error
         let mut stream = match stream_result {
@@ -462,18 +447,17 @@ mod tests {
             self.model_config.clone()
         }

-        async fn complete_with_model(
+        async fn stream(
             &self,
-            _session_id: Option<&str>,
             _model_config: &ModelConfig,
+            _session_id: &str,
             _system: &str,
             _messages: &[Message],
             _tools: &[Tool],
-        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
-            Ok((
-                Message::assistant().with_text("ok"),
-                ProviderUsage::new("mock".to_string(), Usage::default()),
-            ))
+        ) -> Result<MessageStream, ProviderError> {
+            let message = Message::assistant().with_text("ok");
+            let usage = ProviderUsage::new("mock".to_string(), Usage::default());
+            Ok(stream_from_single_message(message, usage))
         }
     }
diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs
index 2041f032e5ec..d8bd486ee8e2 100644
--- a/crates/goose/src/context_mgmt/mod.rs
+++ b/crates/goose/src/context_mgmt/mod.rs
@@ -2,6 +2,8 @@ use crate::conversation::message::{ActionRequiredData, MessageMetadata};
 use crate::conversation::message::{Message, MessageContent};
 use crate::conversation::{merge_consecutive_messages, Conversation};
 use crate::prompt_template::render_template;
+#[cfg(test)]
+use crate::providers::base::{stream_from_single_message, MessageStream};
 use crate::providers::base::{Provider, ProviderUsage};
 use crate::providers::errors::ProviderError;
 use crate::{config::Config, token_counter::create_token_counter};
@@ -568,14 +570,14 @@ mod tests {
             "mock"
         }

-        async fn complete_with_model(
+        async fn stream(
             &self,
-            _session_id: Option<&str>,
             _model_config: &ModelConfig,
+            _session_id: &str,
             _system: &str,
             messages: &[Message],
             _tools: &[Tool],
-        ) -> Result<(Message, ProviderUsage), ProviderError> {
+        ) -> Result<MessageStream, ProviderError> {
             // If max_tool_responses is set, fail if we have too many
             if let Some(max) = self.max_tool_responses {
                 let tool_response_count = messages
@@ -595,10 +597,9 @@ mod tests {
                 }
             }

-            Ok((
-                self.message.clone(),
-                ProviderUsage::new("mock-model".to_string(), Usage::default()),
-            ))
+            let message = self.message.clone();
+            let usage = ProviderUsage::new("mock-model".to_string(), Usage::default());
+            Ok(stream_from_single_message(message, usage))
         }

         fn get_model_config(&self) -> ModelConfig {
diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs
index d6f848bdf725..fa2d4797a076 100644
--- a/crates/goose/src/permission/permission_judge.rs
+++ b/crates/goose/src/permission/permission_judge.rs
@@ -144,8 +144,10 @@ pub async fn detect_read_only_tools(
     let system_prompt = render_template("permission_judge.md", &context)
         .unwrap_or_else(|_| "You are a good analyst and can detect operations whether they have read-only operations.".to_string());

+    let model_config = provider.get_model_config();
     let res = provider
         .complete(
+            &model_config,
             session_id,
             &system_prompt,
             check_messages.messages(),
diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs
index 2bec2118edea..fd2cd814b856 100644
--- a/crates/goose/src/providers/anthropic.rs
+++ b/crates/goose/src/providers/anthropic.rs
@@ -8,21 +8,15 @@ use std::io;
 use tokio::pin;
 use tokio_util::io::StreamReader;

-use super::api_client::{ApiClient, ApiResponse, AuthMethod};
-use super::base::{
-    ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
-};
+use super::api_client::{ApiClient, AuthMethod};
+use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata};
 use super::errors::ProviderError;
-use super::formats::anthropic::{
-    create_request, get_usage, response_to_message, response_to_streaming_message,
-};
+use super::formats::anthropic::{create_request, response_to_streaming_message};
 use super::openai_compatible::handle_status_openai_compat;
 use super::openai_compatible::map_http_error_to_provider_error;
-use super::utils::get_model;
 use crate::config::declarative_providers::DeclarativeProviderConfig;
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::retry::ProviderRetry;
 use crate::providers::utils::RequestLog;
 use futures::future::BoxFuture;
 use rmcp::model::Tool;
@@ -110,10 +104,19 @@ impl AnthropicProvider {
             api_client = api_client.with_headers(header_map)?;
         }

+        let supports_streaming = config.supports_streaming.unwrap_or(true);
+
+        if !supports_streaming {
+            return Err(anyhow::anyhow!(
+                "Anthropic provider does not support non-streaming mode. All Claude models support streaming. \
+                 Please remove 'supports_streaming: false' from your provider configuration."
+            ));
+        }
+
         Ok(Self {
             api_client,
             model,
-            supports_streaming: config.supports_streaming.unwrap_or(true),
+            supports_streaming,
             name: config.name.clone(),
         })
     }
@@ -131,50 +134,6 @@ impl AnthropicProvider {

         headers
     }
-
-    async fn post(
-        &self,
-        session_id: Option<&str>,
-        payload: &Value,
-    ) -> Result<ApiResponse, ProviderError> {
-        let mut request = self.api_client.request(session_id, "v1/messages");
-
-        for (key, value) in self.get_conditional_headers() {
-            request = request.header(key, value)?;
-        }
-
-        Ok(request.api_post(payload).await?)
-    }
-
-    fn anthropic_api_call_result(response: ApiResponse) -> Result<Value, ProviderError> {
-        match response.status {
-            StatusCode::OK => response.payload.ok_or_else(|| {
-                ProviderError::RequestFailed("Response body is not valid JSON".to_string())
-            }),
-            _ => {
-                if response.status == StatusCode::BAD_REQUEST {
-                    if let Some(error_msg) = response
-                        .payload
-                        .as_ref()
-                        .and_then(|p| p.get("error"))
-                        .and_then(|e| e.get("message"))
-                        .and_then(|m| m.as_str())
-                    {
-                        let msg = error_msg.to_string();
-                        if msg.to_lowercase().contains("too long")
-                            || msg.to_lowercase().contains("too many")
-                        {
-                            return Err(ProviderError::ContextLengthExceeded(msg));
-                        }
-                    }
-                }
-                Err(map_http_error_to_provider_error(
-                    response.status,
-                    response.payload,
-                ))
-            }
-        }
-    }
 }

 impl ProviderDef for AnthropicProvider {
@@ -223,42 +182,6 @@ impl Provider for AnthropicProvider {
         self.model.clone()
     }

-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(model_config, system, messages, tools)?;
-
-        let response = self
-            .with_retry(|| async { self.post(session_id, &payload).await })
-            .await?;
-
-        let json_response = Self::anthropic_api_call_result(response)?;
-
-        let message = response_to_message(&json_response)?;
-        let usage = get_usage(&json_response)?;
-        tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
-            usage.input_tokens, usage.output_tokens, usage.total_tokens);
-
-        let response_model = get_model(&json_response);
-        let mut log = RequestLog::start(&self.model, &payload)?;
-        log.write(&json_response, Some(&usage))?;
-        let provider_usage = ProviderUsage::new(response_model, usage);
-        tracing::debug!(
-            "🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
-            provider_usage
-        );
-        Ok((message, provider_usage))
-    }
-
     async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
         let response = self.api_client.request(None, "v1/models").api_get().await?;
@@ -286,19 +209,20 @@ impl Provider for AnthropicProvider {

     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<MessageStream, ProviderError> {
-        let mut payload = create_request(&self.model, system, messages, tools)?;
+        let mut payload = create_request(model_config, system, messages, tools)?;
         payload
             .as_object_mut()
             .unwrap()
             .insert("stream".to_string(), Value::Bool(true));

         let mut request = self.api_client.request(Some(session_id), "v1/messages");
-        let mut log = RequestLog::start(&self.model, &payload)?;
+        let mut log = RequestLog::start(model_config, &payload)?;

         for (key, value) in self.get_conditional_headers() {
             request = request.header(key, value)?;
@@ -326,8 +250,4 @@ impl Provider for AnthropicProvider {
             }
         }))
     }
-
-    fn supports_streaming(&self) -> bool {
-        self.supports_streaming
-    }
 }
diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs
index 173dddf66002..bc3033373b1c 100644
--- a/crates/goose/src/providers/base.rs
+++ b/crates/goose/src/providers/base.rs
@@ -9,7 +9,7 @@ use super::errors::ProviderError;
 use super::retry::RetryConfig;
 use crate::config::base::ConfigValue;
 use crate::config::ExtensionConfig;
-use crate::conversation::message::Message;
+use crate::conversation::message::{Message, MessageContent};
 use crate::conversation::Conversation;
 use crate::model::ModelConfig;
 use crate::utils::safe_truncate;
@@ -379,34 +379,32 @@ pub trait Provider: Send + Sync {
     /// Get the name of this provider instance
     fn get_name(&self) -> &str;

-    // Internal implementation of complete, used by complete_fast and complete
-    // Providers should override this to implement their actual completion logic
-    //
-    /// # Parameters
-    /// - `session_id`: Use `None` only for configuration or pre-session tasks.
-    async fn complete_with_model(
+    /// Primary streaming method that all providers must implement.
+    async fn stream(
         &self,
-        session_id: Option<&str>,
         model_config: &ModelConfig,
+        session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError>;
+    ) -> Result<MessageStream, ProviderError>;

-    // Default implementation: use the provider's configured model
+    /// Complete with a specific model config.
     async fn complete(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let model_config = self.get_model_config();
-        self.complete_with_model(Some(session_id), &model_config, system, messages, tools)
-            .await
+        let stream = self
+            .stream(model_config, session_id, system, messages, tools)
+            .await?;
+        collect_stream(stream).await
     }

-    // Check if a fast model is configured, otherwise fall back to regular model
+    /// Try fast model first, fall back to regular model on failure.
     async fn complete_fast(
         &self,
         session_id: &str,
@@ -417,11 +415,12 @@
         let model_config = self.get_model_config();
         let fast_config = model_config.use_fast_model();

-        match self
-            .complete_with_model(Some(session_id), &fast_config, system, messages, tools)
-            .await
-        {
-            Ok(result) => Ok(result),
+        let result = self
+            .complete(&fast_config, session_id, system, messages, tools)
+            .await;
+
+        match result {
+            Ok(response) => Ok(response),
             Err(e) => {
                 if fast_config.model_name != model_config.model_name {
                     tracing::warn!(
@@ -430,14 +429,8 @@
                         e,
                         model_config.model_name
                     );
-                    self.complete_with_model(
-                        Some(session_id),
-                        &model_config,
-                        system,
-                        messages,
-                        tools,
-                    )
-                    .await
+                    self.complete(&model_config, session_id, system, messages, tools)
+                        .await
                 } else {
                     Err(e)
                 }
@@ -553,22 +546,6 @@ pub trait Provider: Send + Sync {
         None
     }

-    async fn stream(
-        &self,
-        _session_id: &str,
-        _system: &str,
-        _messages: &[Message],
-        _tools: &[Tool],
-    ) -> Result<MessageStream, ProviderError> {
-        Err(ProviderError::NotImplemented(
-            "streaming not implemented".to_string(),
-        ))
-    }
-
-    fn supports_streaming(&self) -> bool {
-        false
-    }
-
     /// Get the currently active model name
     /// For regular providers, this returns the configured model
     /// For LeadWorkerProvider, this returns the currently active model (lead or worker)
@@ -664,18 +641,149 @@ pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> Mes
     Box::pin(stream)
 }

+/// Collect all chunks from a MessageStream into a single Message and ProviderUsage
+pub async fn collect_stream(
+    mut stream: MessageStream,
+) -> Result<(Message, ProviderUsage), ProviderError> {
+    use futures::StreamExt;
+
+    let mut final_message: Option<Message> = None;
+    let mut final_usage: Option<ProviderUsage> = None;
+
+    while let Some(result) = stream.next().await {
+        let (msg_opt, usage_opt) = result?;
+
+        if let Some(msg) = msg_opt {
+            final_message = Some(match final_message {
+                Some(mut prev) => {
+                    for new_content in msg.content {
+                        match (&mut prev.content.last_mut(), &new_content) {
+                            // Coalesce consecutive text blocks
+                            (
+                                Some(MessageContent::Text(last_text)),
+                                MessageContent::Text(new_text),
+                            ) => {
+                                last_text.text.push_str(&new_text.text);
+                            }
+                            _ => {
+                                prev.content.push(new_content);
+                            }
+                        }
+                    }
+                    prev
+                }
+                None => msg,
+            });
+        }
+
+        if let Some(usage) = usage_opt {
+            final_usage = Some(usage);
+        }
+    }
+
+    match final_message {
+        Some(msg) => {
+            let usage = final_usage
+                .unwrap_or_else(|| ProviderUsage::new("unknown".to_string(), Usage::default()));
+            Ok((msg, usage))
+        }
+        None => Err(ProviderError::ExecutionError(
+            "Stream yielded no message".to_string(),
+        )),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use std::collections::HashMap;
+    use test_case::test_case;

     use serde_json::json;

-    #[test]
-    fn test_usage_creation() {
-        let usage = Usage::new(Some(10), Some(20), Some(30));
-        assert_eq!(usage.input_tokens, Some(10));
-        assert_eq!(usage.output_tokens, Some(20));
-        assert_eq!(usage.total_tokens, Some(30));
+    fn content_from_str(s: String) -> MessageContent {
+        if let Some(img_data) = s.strip_prefix("*img:") {
+            MessageContent::image(format!("http://example.com/{}", img_data), "image/png")
+        } else if let Some(tool_name) = s.strip_prefix("*tool:") {
+            let tool_call = Ok(rmcp::model::CallToolRequestParams {
+                meta: None,
+                task: None,
+                name: tool_name.to_string().into(),
+                arguments: Some(serde_json::Map::new()),
+            });
+            MessageContent::tool_request(format!("tool_{}", tool_name), tool_call)
+        } else {
+            MessageContent::text(s)
+        }
+    }
+
+    fn create_test_stream(
+        items: Vec<String>,
+    ) -> impl Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> {
+        use futures::stream;
+        stream::iter(items.into_iter().map(|item| {
+            let content = content_from_str(item);
+            let message = Message::new(
+                rmcp::model::Role::Assistant,
+                chrono::Utc::now().timestamp(),
+                vec![content],
+            );
+            Ok((Some(message), None))
+        }))
+    }
+
+    fn content_to_strings(msg: &Message) -> Vec<String> {
+        msg.content
+            .iter()
+            .map(|c| match c {
+                MessageContent::Text(t) => t.text.clone(),
+                MessageContent::Image(_) => "*img".to_string(),
+                MessageContent::ToolRequest(tr) => {
+                    if let Ok(call) = &tr.tool_call {
+                        format!("*tool:{}", call.name)
+                    } else {
+                        "*tool:error".to_string()
+                    }
+                }
+                _ => "*other".to_string(),
+            })
+            .collect()
+    }
+
+    #[test_case(
+        vec!["Hello", " ", "world"],
+        vec!["Hello world"]
+        ; "consecutive text coalesces"
+    )]
+    #[test_case(
+        vec!["Hello", "*img:pic1", "world"],
+        vec!["Hello", "*img", "world"]
+        ; "non-text breaks coalescing"
+    )]
+    #[test_case(
+        vec!["A", "B", "*img:pic1", "C", "D", "*tool:read", "E", "F"],
+        vec!["AB", "*img", "CD", "*tool:read", "EF"]
+        ; "multiple text groups"
+    )]
+    #[test_case(
+        vec!["Text1", "*img:pic", "Text2"],
+        vec!["Text1", "*img", "Text2"]
+        ; "mixed content in chunk"
+    )]
+    #[tokio::test]
+    async fn test_collect_stream_coalescing(input_items: Vec<&str>, expected: Vec<&str>) {
+        let items: Vec<String> = input_items.into_iter().map(|s| s.to_string()).collect();
+        let stream = create_test_stream(items);
+        let (msg, _) = collect_stream(Box::pin(stream)).await.unwrap();
+        assert_eq!(content_to_strings(&msg), expected);
+    }
+
+    #[tokio::test]
+    async fn test_collect_stream_defaults_usage() {
+        // Should not error when usage is missing
+        let stream = create_test_stream(vec!["Hello".to_string()]);
+        let (msg, usage) = collect_stream(Box::pin(stream)).await.unwrap();
+        assert_eq!(content_to_strings(&msg), vec!["Hello"]);
+        assert_eq!(usage.model, "unknown"); // Default usage
     }

     #[test]
diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs
index e8889fbefc24..07f390ee760e 100644
--- a/crates/goose/src/providers/bedrock.rs
+++ b/crates/goose/src/providers/bedrock.rs
@@ -1,6 +1,8 @@
 use std::collections::HashMap;

-use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage};
+use super::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
+};
 use super::errors::ProviderError;
 use super::retry::{ProviderRetry, RetryConfig};
 use crate::conversation::message::Message;
@@ -312,14 +314,19 @@ impl Provider for BedrockProvider {
         skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        session_id: Option<&str>,
         model_config: &ModelConfig,
+        session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
+        let session_id = if session_id.is_empty() {
+            None
+        } else {
+            Some(session_id)
+        };
         let model_name = model_config.model_name.clone();

         let (bedrock_message, bedrock_usage) = self
@@ -346,7 +353,10 @@ impl Provider for BedrockProvider {
         )?;

         let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
-        Ok((message, provider_usage))
+        Ok(super::base::stream_from_single_message(
+            message,
+            provider_usage,
+        ))
     }
 }
diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs
index fcd63a59a67f..b636c35e66c9 100644
--- a/crates/goose/src/providers/chatgpt_codex.rs
+++ b/crates/goose/src/providers/chatgpt_codex.rs
@@ -2,9 +2,7 @@ use crate::config::paths::Paths;
 use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::providers::api_client::AuthProvider;
-use crate::providers::base::{
-    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
-};
+use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
 use crate::providers::errors::ProviderError;
 use crate::providers::formats::openai_responses::responses_api_to_streaming_message;
 use crate::providers::openai_compatible::handle_status_openai_compat;
@@ -886,76 +884,15 @@ impl Provider for ChatGptCodexProvider {
         self.model.clone()
     }

-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        // ChatGPT Codex API requires streaming - collect the stream into a single response
-        let mut payload = create_codex_request(model_config, system, messages, tools)
-            .map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
-        payload["stream"] = serde_json::Value::Bool(true);
-
-        let response = self
-            .with_retry(|| async {
-                let payload_clone = payload.clone();
-                self.post_streaming(session_id, &payload_clone).await
-            })
-            .await?;
-
-        let stream = response.bytes_stream().map_err(io::Error::other);
-        let stream_reader = StreamReader::new(stream);
-        let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
-
-        let message_stream = responses_api_to_streaming_message(framed);
-        pin!(message_stream);
-
-        let mut final_message: Option<Message> = None;
-        let mut final_usage: Option<ProviderUsage> = None;
-
-        while let Some(result) = message_stream.next().await {
-            let (message, usage) = result
-                .map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
-            if let Some(msg) = message {
-                final_message = Some(msg);
-            }
-            if let Some(u) = usage {
-                final_usage = Some(u);
-            }
-        }
-
-        let message = final_message.ok_or_else(|| {
-            ProviderError::ExecutionError("No message received from stream".to_string())
-        })?;
-        let usage = final_usage.unwrap_or_else(|| {
-            ProviderUsage::new(
-                model_config.model_name.clone(),
-                crate::providers::base::Usage::default(),
-            )
-        });
-
-        Ok((message, usage))
-    }
-
-    fn supports_streaming(&self) -> bool {
-        true
-    }
-
     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<MessageStream, ProviderError> {
-        let mut payload = create_codex_request(&self.model, system, messages, tools)
+        let mut payload = create_codex_request(model_config, system, messages, tools)
             .map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
         payload["stream"] = serde_json::Value::Bool(true);
diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs
index a00d21cd69a9..d7ca8179d056 100644
--- a/crates/goose/src/providers/claude_code.rs
+++ b/crates/goose/src/providers/claude_code.rs
@@ -17,7 +17,7 @@ use super::base::{
     ProviderUsage, Usage,
 };
 use super::errors::ProviderError;
-use super::utils::{filter_extensions_from_system_prompt, RequestLog};
+use super::utils::filter_extensions_from_system_prompt;
 use crate::config::base::ClaudeCodeCommand;
 use crate::config::paths::Paths;
 use crate::config::search_path::SearchPaths;
@@ -307,73 +307,6 @@ impl ClaudeCodeProvider {
         Ok(())
     }

-    fn parse_claude_response(
-        &self,
-        json_lines: &[String],
-    ) -> Result<(Message, Usage), ProviderError> {
-        let mut all_text_content = Vec::new();
-        let mut usage = Usage::default();
-
-        for line in json_lines {
-            if let Ok(parsed) = serde_json::from_str::<Value>(line) {
-                match parsed.get("type").and_then(|t| t.as_str()) {
-                    Some("assistant") => {
-                        if let Some(message) = parsed.get("message") {
-                            if let Some(content) = message.get("content").and_then(|c| c.as_array())
-                            {
-                                for item in content {
-                                    if item.get("type").and_then(|t| t.as_str()) == Some("text") {
-                                        if let Some(text) =
-                                            item.get("text").and_then(|t| t.as_str())
-                                        {
-                                            all_text_content.push(text.to_string());
-                                        }
-                                    }
-                                }
-                            }
-
-                            if let Some(usage_info) = message.get("usage") {
-                                usage = extract_usage_tokens(usage_info);
-                            }
-                        }
-                    }
-                    Some("result") => {
-                        if let Some(result_usage) = parsed.get("usage") {
-                            let new = extract_usage_tokens(result_usage);
-                            usage = Usage::new(
-                                usage.input_tokens.or(new.input_tokens),
-                                usage.output_tokens.or(new.output_tokens),
-                                None,
-                            );
-                        }
-                    }
-                    Some("error") => {
-                        return Err(error_from_event("Claude CLI", &parsed));
-                    }
-                    Some("system") => {}
-                    _ => {}
-                }
-            }
-        }
-
-        let combined_text = all_text_content.join("\n\n");
-        if combined_text.is_empty() {
-            return Err(ProviderError::RequestFailed(
-                "No text content found in response".to_string(),
-            ));
-        }
-
-        let message_content = vec![MessageContent::text(combined_text)];
-
-        let response_message = Message::new(
-            Role::Assistant,
-            chrono::Utc::now().timestamp(),
-            message_content,
-        );
-
-        Ok((response_message, usage))
-    }
-
     fn spawn_process(&self, filtered_system: &str) -> Result<ClaudeProcess, ProviderError> {
         let mut cmd = self.build_stream_json_command();
@@ -440,103 +373,6 @@ impl ClaudeCodeProvider {
         })
         .await
     }
-
-    async fn execute_command(
-        &self,
-        system: &str,
-        messages: &[Message],
-        _tools: &[Tool],
-        session_id: &str,
-        model: &str,
-    ) -> Result<Vec<String>, ProviderError> {
-        let filtered_system = filter_extensions_from_system_prompt(system);
-
-        tracing::debug!(
-            command = ?self.command,
-            system_prompt_len = system.len(),
-            filtered_system_prompt_len = filtered_system.len(),
-            "Executing Claude CLI command"
-        );
-
-        let process_mutex = self.get_or_init_process(&filtered_system).await?;
-        let mut process = process_mutex.lock().await;
-
-        // Drain any pending response from a cancelled stream
-        process.drain_pending_response().await;
-
-        // Switch model if it differs from what the CLI is currently using.
-        process.send_set_model(model).await?;
-
-        let blocks = self.last_user_content_blocks(messages);
-
-        // Write NDJSON line to stdin
-        let ndjson_line = build_stream_json_input(&blocks, session_id);
-        process
-            .stdin
-            .write_all(ndjson_line.as_bytes())
-            .await
-            .map_err(|e| {
-                ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e))
-            })?;
-        process.stdin.write_all(b"\n").await.map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e))
-        })?;
-
-        // Read lines until we see a "result" or "error" event
-        let mut lines = Vec::new();
-        let mut line = String::new();
-
-        loop {
-            line.clear();
-            match process.reader.read_line(&mut line).await {
-                Ok(0) => {
-                    return Err(ProviderError::RequestFailed(
-                        "Claude CLI process terminated unexpectedly".to_string(),
-                    ));
-                }
-                Ok(_) => {
-                    let trimmed = line.trim();
-                    if trimmed.is_empty() {
-                        continue;
-                    }
-
-                    if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
-                        match parsed.get("type").and_then(|t| t.as_str()) {
-                            Some("stream_event") => continue,
-                            Some("result") | Some("error") => {
-                                lines.push(trimmed.to_string());
-                                break;
-                            }
-                            // The system init with the resolved model arrives here,
-                            // not in send_set_model (which only sees control_response).
-                            Some("system") if process.log_model_update => {
-                                if let Some(resolved) = parsed.get("model").and_then(|m| m.as_str())
-                                {
-                                    tracing::debug!(
-                                        from = %process.current_model,
-                                        to = %resolved,
-                                        "set_model resolved"
-                                    );
-                                }
-                                process.log_model_update = false;
-                            }
-                            _ => {}
-                        }
-                    }
-                    lines.push(trimmed.to_string());
-                }
-                Err(e) => {
-                    return Err(ProviderError::RequestFailed(format!(
-                        "Failed to read output: {}",
-                        e
-                    )));
-                }
-            }
-        }
-
-        tracing::debug!("Command executed successfully, got {} lines", lines.len());
-        Ok(lines)
-    }
 }

 /// Extract model aliases from the CLI's initialize control_response.
@@ -755,60 +591,9 @@ impl Provider for ClaudeCodeProvider {
         Ok(parse_models_from_lines(&lines))
     }

-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        if super::cli_common::is_session_description_request(system) {
-            return super::cli_common::generate_simple_session_description(
-                &model_config.model_name,
-                messages,
-            );
-        }
-
-        // session_id is None before a session is created (e.g. model listing).
-        let sid = session_id.unwrap_or("default");
-        let json_lines = self
-            .execute_command(system, messages, tools, sid, &model_config.model_name)
-            .await?;
-
-        let (message, usage) = self.parse_claude_response(&json_lines)?;
-
-        let payload = json!({
-            "command": self.command,
-            "model": model_config.model_name,
-            "system": system,
-            "messages": messages.len()
-        });
-        let mut log = RequestLog::start(model_config, &payload)?;
-
-        let response = json!({
-            "lines": json_lines.len(),
-            "usage": usage
-        });
-
-        log.write(&response, Some(&usage))?;
-
-        Ok((
-            message,
-            ProviderUsage::new(model_config.model_name.clone(), usage),
-        ))
-    }
-
-    fn supports_streaming(&self) -> bool {
-        true
-    }
-
     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
@@ -816,7 +601,7 @@ impl Provider for ClaudeCodeProvider {
     ) -> Result<MessageStream, ProviderError> {
         if super::cli_common::is_session_description_request(system) {
             let (message, usage) = super::cli_common::generate_simple_session_description(
-                &self.model.model_name,
+                &model_config.model_name,
                 messages,
             )?;
             return Ok(stream_from_single_message(message, usage));
@@ -828,7 +613,7 @@ impl Provider for ClaudeCodeProvider {

         // Prepare the payload outside the lock — these don't need the process.
         let blocks = self.last_user_content_blocks(messages);
         let ndjson_line = build_stream_json_input(&blocks, session_id);
-        let model_name = self.model.model_name.clone();
+        let model_name = model_config.model_name.clone();
         let message_id = uuid::Uuid::new_v4().to_string();

         Ok(Box::pin(try_stream! {
@@ -1118,87 +903,6 @@ mod tests {
         assert_eq!(parsed, expected);
     }

-    #[test_case(
-        &[
-            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello! How can I help you today?"}],"usage":{"input_tokens":3,"output_tokens":3}}}"#,
-            r#"{"type":"result","usage":{"input_tokens":3,"output_tokens":16}}"#,
-        ],
-        "Hello! How can I help you today?",
-        Some(3), Some(3)
-        ; "assistant_with_usage"
-    )]
-    #[test_case(
-        &[
-            r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}}"#,
-        ],
-        "First\n\nSecond",
-        None, None
-        ; "multiple_text_blocks"
-    )]
-    #[test_case(
-        &[
-            r#"{"type":"system","model":"claude-opus-4-6"}"#,
-            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello!"}],"usage":{"input_tokens":3,"output_tokens":3}}}"#,
-            r#"{"type":"result","usage":{"input_tokens":3,"output_tokens":16}}"#,
-        ],
-        "Hello!",
-        Some(3), Some(3)
-        ; "system_init_filtered"
-    )]
-    #[test_case(
-        &[
-            r#"{"type":"stream_event","event":{"type":"content_block_delta","delta":{"type":"text_delta","text":"He"}}}"#,
-            r#"{"type":"stream_event","event":{"type":"content_block_delta","delta":{"type":"text_delta","text":"llo"}}}"#,
-            r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":50,"output_tokens":10}}}"#,
-            r#"{"type":"result","subtype":"success","result":"Hello","session_id":"abc"}"#,
-        ],
-        "Hello",
-        Some(50), Some(10)
-        ; "streaming_events_ignored_by_parse"
-    )]
-    fn test_parse_claude_response_ok(
-        lines: &[&str],
-        expected_text: &str,
-        expected_input: Option<i32>,
-        expected_output: Option<i32>,
-    ) {
-        let provider = make_provider();
-        let lines: Vec<String> = lines.iter().map(|s| s.to_string()).collect();
-        let (message, usage) = provider.parse_claude_response(&lines).unwrap();
-        assert_eq!(message.role, Role::Assistant);
-        if let MessageContent::Text(t) = &message.content[0] {
-            assert_eq!(t.text, expected_text);
-        } else {
-            panic!("expected text content");
-        }
-        assert_eq!(usage.input_tokens, expected_input);
-        assert_eq!(usage.output_tokens, expected_output);
-    }
-
-    #[test_case(
-        &[],
-        ProviderError::RequestFailed("No text content found in response".into())
-        ; "empty_lines"
-    )]
-    #[test_case(
-        &[r#"{"type":"error","error":"context window exceeded"}"#],
-        ProviderError::ContextLengthExceeded("context window exceeded".into())
-        ; "context_length"
-    )]
-    #[test_case(
-        &[r#"{"type":"error","error":"Model not supported"}"#],
-        ProviderError::RequestFailed("Claude CLI error: Model not supported".into())
-        ; "generic_error"
-    )]
-    fn test_parse_claude_response_err(lines: &[&str], expected: ProviderError) {
-        let provider = make_provider();
-        let lines: Vec<String> = lines.iter().map(|s| s.to_string()).collect();
-        assert_eq!(
-            provider.parse_claude_response(&lines).unwrap_err(),
-            expected
-        );
-    }
-
     #[test_case(
         &[
             r#"{"type":"control_response","response":{"subtype":"success","request_id":"model_list","response":{"models":[{"value":"default","displayName":"Default (recommended)","description":"Opus 4.6 · Most capable for complex work"},{"value":"sonnet","displayName":"Sonnet","description":"Sonnet 4.5 · Best for everyday tasks"},{"value":"haiku","displayName":"Haiku","description":"Haiku 4.5 · Fastest for quick answers"}]}}}"#,
diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs
index 41589362f13a..7da73c4c0a6d 100644
--- a/crates/goose/src/providers/codex.rs
+++ b/crates/goose/src/providers/codex.rs
@@ -10,7 +10,9 @@ use tempfile::NamedTempFile;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;

-use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::{CodexCommand, CodexReasoningEffort, CodexSkipGitCheck};
@@ -671,19 +673,23 @@ impl Provider for CodexProvider {
         skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        _session_id: Option<&str>, // CLI has no external session-id flag to propagate.
         model_config: &ModelConfig,
+        _session_id: &str, // CLI has no external session-id flag to propagate.
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
         if super::cli_common::is_session_description_request(system) {
-            return super::cli_common::generate_simple_session_description(
+            let (message, provider_usage) = super::cli_common::generate_simple_session_description(
                 &model_config.model_name,
                 messages,
-            );
+            )?;
+            return Ok(super::base::stream_from_single_message(
+                message,
+                provider_usage,
+            ));
         }

         let lines = self.execute_command(system, messages, tools).await?;
@@ -712,9 +718,10 @@ impl Provider for CodexProvider {
             ProviderError::RequestFailed(format!("Failed to write request log: {}", e))
         })?;

-        Ok((
+        let provider_usage = ProviderUsage::new(model_config.model_name.clone(), usage);
+        Ok(super::base::stream_from_single_message(
             message,
-            ProviderUsage::new(model_config.model_name.clone(), usage),
+            provider_usage,
         ))
     }
diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs
index 21df99b69124..aeab6136cddc 100644
--- a/crates/goose/src/providers/cursor_agent.rs
+++ b/crates/goose/src/providers/cursor_agent.rs
@@ -7,7 +7,10 @@ use std::process::Stdio;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;

-use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    stream_from_single_message, ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata,
+    ProviderUsage, Usage,
+};
 use super::errors::ProviderError;
 use super::utils::{filter_extensions_from_system_prompt, RequestLog};
 use crate::config::base::CursorAgentCommand;
@@ -171,7 +174,6 @@ impl CursorAgentProvider {
             message_content,
         );
         let usage = Usage::default();
-
         Ok((response_message, usage))
     }

@@ -325,19 +327,20 @@ impl Provider for CursorAgentProvider {
         skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        _session_id: Option<&str>, // CLI has no external session-id flag to propagate.
         model_config: &ModelConfig,
+        _session_id: &str, // CLI has no external session-id flag to propagate.
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
         if super::cli_common::is_session_description_request(system) {
-            return super::cli_common::generate_simple_session_description(
+            let (message, provider_usage) = super::cli_common::generate_simple_session_description(
                 &model_config.model_name,
                 messages,
-            );
+            )?;
+            return Ok(stream_from_single_message(message, provider_usage));
         }

         let lines = self.execute_command(system, messages, tools).await?;
@@ -360,9 +363,7 @@ impl Provider for CursorAgentProvider {
         let mut log = RequestLog::start(&self.model, &payload)?;
         log.write(&response, Some(&usage))?;

-        Ok((
-            message,
-            ProviderUsage::new(model_config.model_name.clone(), usage),
-        ))
+        let provider_usage = ProviderUsage::new(model_config.model_name.clone(), usage);
+        Ok(stream_from_single_message(message, provider_usage))
     }
 }
diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs
index c5509fced491..34f4239e3f81 100644
--- a/crates/goose/src/providers/databricks.rs
+++ b/crates/goose/src/providers/databricks.rs
@@ -6,22 +6,19 @@ use serde_json::Value;
 use std::time::Duration;

 use super::api_client::{ApiClient, AuthMethod, AuthProvider};
-use super::base::{
-    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
-};
+use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
 use super::embedding::EmbeddingCapable;
 use super::errors::ProviderError;
-use super::formats::databricks::{create_request, response_to_message};
+use super::formats::databricks::create_request;
 use super::oauth;
 use super::openai_compatible::{
     handle_response_openai_compat, map_http_error_to_provider_error, stream_openai_compat,
 };
 use super::retry::ProviderRetry;
-use super::utils::{get_model, ImageFormat, RequestLog};
+use super::utils::{ImageFormat, RequestLog};
 use crate::config::ConfigError;
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::formats::openai::get_usage;
 use crate::providers::retry::{
     RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS, DEFAULT_MAX_RETRIES,
     DEFAULT_MAX_RETRY_INTERVAL_MS,
@@ -277,70 +274,16 @@ impl Provider for DatabricksProvider {
         self.model.clone()
     }

-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let mut payload =
-            create_request(model_config, system, messages, tools, &self.image_format)?;
-        payload
-            .as_object_mut()
-            .expect("payload should have model key")
-            .remove("model");
-
-        let mut log = RequestLog::start(&self.model, &payload)?;
-
-        // Use fast retry config if this is the fast model
-        let is_fast_model = self
-            .model
-            .fast_model_config
-            .as_ref()
-            .map(|fast| fast.model_name == model_config.model_name)
-            .unwrap_or(false);
-
-        let retry_config = if is_fast_model {
-            self.fast_retry_config.clone()
-        } else {
-            self.retry_config.clone()
-        };
-
-        let response = self
-            .with_retry_config(
-                || self.post(session_id, payload.clone(), Some(&model_config.model_name)),
-                retry_config,
-            )
-            .await?;
-
-        let message = response_to_message(&response)?;
-        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
-            tracing::debug!("Failed to get usage data");
-            Usage::default()
-        });
-        let response_model = get_model(&response);
-        log.write(&response, Some(&usage))?;
-
-        Ok((message, ProviderUsage::new(response_model, usage)))
-    }
-
     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<MessageStream, ProviderError> {
-        let model_config = self.model.clone();
-
         let mut payload =
-            create_request(&model_config, system, messages, tools, &self.image_format)?;
+            create_request(model_config, system, messages, tools, &self.image_format)?;
         payload
             .as_object_mut()
             .expect("payload should have model key")
@@ -352,7 +295,7 @@ impl Provider for DatabricksProvider {
             .insert("stream".to_string(), Value::Bool(true));

         let path = self.get_endpoint_path(&model_config.model_name, false);
-        let mut log = RequestLog::start(&self.model, &payload)?;
+        let mut log = RequestLog::start(model_config, &payload)?;
         let response = self
             .with_retry(|| async {
                 let resp = self
@@ -377,10 +320,6 @@ impl Provider for DatabricksProvider {
         stream_openai_compat(response, log)
     }

-    fn supports_streaming(&self) -> bool {
-        true
-    }
-
     fn supports_embeddings(&self) -> bool {
         true
     }
diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs
index b7d2a626d106..a089e65bdd60 100644
--- a/crates/goose/src/providers/gcpvertexai.rs
+++ b/crates/goose/src/providers/gcpvertexai.rs
@@ -16,14 +16,12 @@ use url::Url;

 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::base::{
-    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage,
-};
+use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
 use crate::providers::errors::ProviderError;
 use crate::providers::formats::gcpvertexai::{
-    create_request, get_usage, response_to_message, response_to_streaming_message, GcpLocation,
-    ModelProvider, RequestContext, DEFAULT_MODEL, KNOWN_MODELS,
+    create_request, response_to_streaming_message, GcpLocation, ModelProvider, RequestContext,
+    DEFAULT_MODEL, KNOWN_MODELS,
 };
 use crate::providers::gcpauth::GcpAuth;
 use crate::providers::openai_compatible::map_http_error_to_provider_error;
@@ -358,58 +356,6 @@ impl GcpVertexAIProvider {
         }
     }

-    async fn post_with_location(
-        &self,
-        session_id: Option<&str>,
-        payload: &Value,
-        context: &RequestContext,
-        location: &str,
-    ) -> Result<Value, ProviderError> {
-        let url = self
-            .build_request_url(context.provider(), location, false)
-            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
-
-        let response = self
-            .send_request_with_retry(session_id, url, payload)
-            .await?;
-
-        response
-            .json::<Value>()
-            .await
-            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))
-    }
-
-    async fn post(
-        &self,
-        session_id: Option<&str>,
-        payload: &Value,
-        context: &RequestContext,
-    ) -> Result<Value, ProviderError> {
-        let result = self
-            .post_with_location(session_id, payload, context, &self.location)
-            .await;
-
-        if self.location == context.model.known_location().to_string() || result.is_ok() {
-            return result;
-        }
-
-        match &result {
-            Err(ProviderError::RequestFailed(msg)) => {
-                let model_name = context.model.to_string();
-                let configured_location = &self.location;
-                let known_location = context.model.known_location().to_string();
-
-                tracing::warn!(
-                    "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
-                );
-
-                self.post_with_location(session_id, payload, context, &known_location)
-                    .await
-            }
-            _ => result,
-        }
-    }
-
     async fn post_stream_with_location(
         &self,
         session_id: Option<&str>,
@@ -613,53 +559,20 @@ impl Provider for GcpVertexAIProvider {
     /// * `system` - System prompt or context
     /// * `messages` - Array of previous messages in the conversation
     /// * `tools` - Array of available tools for the model
-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        // Create request and context
-        let (request, context) = create_request(model_config, system, messages, tools)?;
-
-        // Send request and process response
-        let response = self.post(session_id, &request, &context).await?;
-        let usage = get_usage(&response, &context)?;
-
-        let mut log = RequestLog::start(model_config, &request)?;
-        log.write(&response, Some(&usage))?;
-
-        // Convert response to message
-        let message = response_to_message(response, context)?;
-        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
-
-        Ok((message, provider_usage))
-    }
-
     /// Returns the current model configuration.
     fn get_model_config(&self) -> ModelConfig {
         self.model.clone()
     }

-    fn supports_streaming(&self) -> bool {
-        true
-    }
-
     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<MessageStream, ProviderError> {
-        let model_config = self.get_model_config();
-        let (mut request, context) = create_request(&model_config, system, messages, tools)?;
+        let (mut request, context) = create_request(model_config, system, messages, tools)?;

         if matches!(context.provider(), ModelProvider::Anthropic) {
             if let Some(obj) = request.as_object_mut() {
@@ -667,7 +580,7 @@ impl Provider for GcpVertexAIProvider {
             }
         }

-        let mut log = RequestLog::start(&model_config, &request)?;
+        let mut log = RequestLog::start(model_config, &request)?;

         let response = self
             .post_stream(Some(session_id), &request, &context)
diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs
index 33ffd3e61673..746e1d6b73f8 100644
--- a/crates/goose/src/providers/gemini_cli.rs
+++ b/crates/goose/src/providers/gemini_cli.rs
@@ -21,7 +21,6 @@ use crate::conversation::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use crate::providers::base::ConfigKey;
 use crate::subprocess::configure_subprocess;
-use async_stream::try_stream;
 use futures::future::BoxFuture;
 use rmcp::model::Role;
 use rmcp::model::Tool;
@@ -341,19 +340,20 @@ impl Provider for GeminiCliProvider {
         skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        _session_id: Option<&str>,
         model_config: &ModelConfig,
+        _session_id: &str, // CLI has no external session-id flag to propagate.
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
         if super::cli_common::is_session_description_request(system) {
-            return super::cli_common::generate_simple_session_description(
+            let (message, provider_usage) = super::cli_common::generate_simple_session_description(
                 &model_config.model_name,
                 messages,
-            );
+            )?;
+            return Ok(stream_from_single_message(message, provider_usage));
         }

         let payload = json!({
@@ -381,133 +381,8 @@ impl Provider for GeminiCliProvider {
             ProviderError::RequestFailed(format!("Failed to write request log: {e}"))
         })?;

-        Ok((
-            message,
-            ProviderUsage::new(model_config.model_name.clone(), usage),
-        ))
-    }
-
-    fn supports_streaming(&self) -> bool {
-        true
-    }
-
-    async fn stream(
-        &self,
-        _session_id: &str,
-        system: &str,
-        messages: &[Message],
-        _tools: &[Tool],
-    ) -> Result<MessageStream, ProviderError> {
-        if super::cli_common::is_session_description_request(system) {
-            let (message, usage) = super::cli_common::generate_simple_session_description(
-                &self.model.model_name,
-                messages,
-            )?;
-            return Ok(stream_from_single_message(message, usage));
-        }
-
-        let (mut child, mut reader) =
-            self.spawn_command(system, messages, &self.model.model_name)?;
-        let session_id_lock = Arc::clone(&self.cli_session_id);
-        let model_name = self.model.model_name.clone();
-        let message_id = uuid::Uuid::new_v4().to_string();
-
-        // Drain stderr concurrently to avoid pipe deadlock
-        let stderr = child.stderr.take();
-        let stderr_drain = tokio::spawn(async move {
-            let mut buf = String::new();
-            if let Some(mut stderr) = stderr {
-                let _ = AsyncReadExt::read_to_string(&mut stderr, &mut buf).await;
-            }
-            buf
-        });
-
-        Ok(Box::pin(try_stream! {
-            let mut line = String::new();
-            let mut accumulated_usage = Usage::default();
-            let stream_timestamp = chrono::Utc::now().timestamp();
-
-            loop {
-                line.clear();
-                match reader.read_line(&mut line).await {
-                    Ok(0) => break,
-                    Ok(_) => {
-                        let trimmed = line.trim();
-                        if trimmed.is_empty() {
-                            continue;
-                        }
-
-                        if let Ok(parsed) = serde_json::from_str::<Value>(trimmed) {
-                            match parsed.get("type").and_then(|t| t.as_str()) {
-                                Some("init") => {
-                                    if let Some(sid) =
-                                        parsed.get("session_id").and_then(|s| s.as_str())
-                                    {
-                                        let _ = session_id_lock.set(sid.to_string());
-                                    }
-                                }
-                                Some("message") => {
-                                    let is_assistant = parsed.get("role").and_then(|r| r.as_str())
-                                        == Some("assistant");
-                                    let content = parsed
-                                        .get("content")
-                                        .and_then(|c| c.as_str())
-                                        .unwrap_or("");
-                                    if is_assistant && !content.is_empty() {
-                                        let mut partial = Message::new(
-                                            Role::Assistant,
-                                            stream_timestamp,
-                                            vec![MessageContent::text(content)],
-                                        );
-                                        partial.id = Some(message_id.clone());
-                                        yield (Some(partial), None);
-                                    }
-                                }
-                                Some("result") => {
-                                    if let Some(stats) = parsed.get("stats") {
-                                        accumulated_usage = extract_usage_tokens(stats);
-                                    }
-                                    break;
-                                }
-                                Some("error") => {
-                                    let _ = child.wait().await;
-                                    Err(error_from_event("Gemini CLI", &parsed))?;
-                                }
-                                _ => {}
-                            }
-                        } else {
-                            tracing::warn!(line = trimmed, "Non-JSON line in stream-json output");
-                        }
-                    }
-                    Err(e) => {
-                        let _ = child.wait().await;
-                        Err(ProviderError::RequestFailed(format!(
-                            "Failed to read streaming output: {e}"
-                        )))?;
-                    }
-                }
-            }
-
-            let stderr_text = stderr_drain.await.unwrap_or_default();
-            let exit_status = child.wait().await.map_err(|e| {
-                ProviderError::RequestFailed(format!("Failed to wait for command: {e}"))
-            })?;
-
-            if !exit_status.success() {
-                let stderr_snippet = stderr_text.trim();
-                let detail = if stderr_snippet.is_empty() {
-                    format!("exit code {:?}", exit_status.code())
-                } else {
-                    format!("exit code {:?}: {stderr_snippet}", exit_status.code())
-                };
-                Err(ProviderError::RequestFailed(format!(
-                    "Gemini CLI command failed ({detail})"
-                )))?;
-            }
-
-            let provider_usage = ProviderUsage::new(model_name, accumulated_usage);
-            yield (None, Some(provider_usage));
-        }))
+        let provider_usage = ProviderUsage::new(model_config.model_name.clone(), usage);
+        Ok(stream_from_single_message(message, provider_usage))
     }
 }
diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs
index d5a51d51b4a2..e7f767e4ba08 100644
--- a/crates/goose/src/providers/githubcopilot.rs
+++ b/crates/goose/src/providers/githubcopilot.rs
@@ -417,85 +417,85 @@ impl Provider for GithubCopilotProvider {
         self.model.clone()
     }

-    fn supports_streaming(&self) -> bool {
-        GITHUB_COPILOT_STREAM_MODELS
-            .iter()
-            .any(|prefix| self.model.model_name.starts_with(prefix))
-    }
-
-    #[tracing::instrument(
-        skip(self, model_config, system, messages, tools),
-        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
-    )]
-    async fn complete_with_model(
-        &self,
-        session_id: Option<&str>,
-        model_config: &ModelConfig,
-        system: &str,
-        messages: &[Message],
-        tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
-        let payload = create_request(
-            model_config,
-            system,
-            messages,
-            tools,
-            &ImageFormat::OpenAi,
-            false,
-        )?;
-        let mut log = RequestLog::start(model_config, &payload)?;
-
-        // Make request with retry
-        let response = self
-            .with_retry(|| async {
-                let mut payload_clone = payload.clone();
-                self.post(session_id, &mut payload_clone).await
-            })
-            .await?;
-        let response = handle_response_openai_compat(response).await?;
-
-        let response = promote_tool_choice(response);
-
-        // Parse response
-        let message = response_to_message(&response)?;
-        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
-            tracing::debug!("Failed to get usage data");
-            Usage::default()
-        });
-        let response_model = get_model(&response);
-        log.write(&response, Some(&usage))?;
-        Ok((message, ProviderUsage::new(response_model, usage)))
-    }
-
     async fn stream(
         &self,
+        model_config: &ModelConfig,
         session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<MessageStream, ProviderError> {
-        let payload = create_request(
-            &self.model,
-            system,
-            messages,
-            tools,
-            &ImageFormat::OpenAi,
-            true,
-        )?;
-        let mut log = RequestLog::start(&self.model, &payload)?;
-
-        let response = self
-            .with_retry(|| async {
-                let mut payload_clone = payload.clone();
-                let resp = self.post(Some(session_id), &mut payload_clone).await?;
-                handle_status_openai_compat(resp).await
-            })
-            .await
-            .inspect_err(|e| {
-                let _ = log.error(e);
-            })?;
-
-        stream_openai_compat(response, log)
+        // Check if this model supports streaming
+        let supports_streaming = GITHUB_COPILOT_STREAM_MODELS
+            .iter()
+            .any(|prefix| model_config.model_name.starts_with(prefix));
+
+        if supports_streaming {
+            // Use streaming API
+            let payload = create_request(
+                model_config,
+                system,
+                messages,
+                tools,
+                &ImageFormat::OpenAi,
+                true,
+            )?;
+            let mut log = RequestLog::start(model_config, &payload)?;
+
+            let response = self
+                .with_retry(|| async {
+                    let mut payload_clone = payload.clone();
+                    let resp = self.post(Some(session_id), &mut payload_clone).await?;
+                    handle_status_openai_compat(resp).await
+                })
+                .await
+                .inspect_err(|e| {
+                    let _ = log.error(e);
+                })?;
+
+            stream_openai_compat(response, log)
+        } else {
+            // Use non-streaming API
and wrap result + let session_id_opt = if session_id.is_empty() { + None + } else { + Some(session_id) + }; + let payload = create_request( + model_config, + system, + messages, + tools, + &ImageFormat::OpenAi, + false, + )?; + let mut log = RequestLog::start(model_config, &payload)?; + + // Make request with retry + let response = self + .with_retry(|| async { + let mut payload_clone = payload.clone(); + self.post(session_id_opt, &mut payload_clone).await + }) + .await?; + let response = handle_response_openai_compat(response).await?; + + let response = promote_tool_choice(response); + + // Parse response + let message = response_to_message(&response)?; + let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); + let response_model = get_model(&response); + log.write(&response, Some(&usage))?; + + Ok(super::base::stream_from_single_message( + message, + ProviderUsage::new(response_model, usage), + )) + } } async fn fetch_supported_models(&self) -> Result, ProviderError> { diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 217aa1905ec2..769cc54200ff 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,14 +3,12 @@ use super::base::MessageStream; use super::errors::ProviderError; use super::openai_compatible::handle_status_openai_compat; use super::retry::ProviderRetry; -use super::utils::{handle_response_google_compat, unescape_json_values, RequestLog}; +use super::utils::RequestLog; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; -use crate::providers::formats::google::{ - create_request, get_usage, response_to_message, response_to_streaming_message, -}; +use crate::providers::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata}; +use crate::providers::formats::google::{create_request, response_to_streaming_message}; use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; @@ -92,20 +90,6 @@ impl GoogleProvider { }) } - async fn post( - &self, - session_id: Option<&str>, - model_name: &str, - payload: &Value, - ) -> Result { - let path = format!("v1beta/models/{}:generateContent", model_name); - let response = self - .api_client - .response_post(session_id, &path, payload) - .await?; - handle_response_google_compat(response).await - } - async fn post_stream( &self, session_id: Option<&str>, @@ -157,39 +141,6 @@ impl Provider for GoogleProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request(model_config, system, messages, tools)?; - let mut log = RequestLog::start(model_config, &payload)?; - - let response = self - .with_retry(|| async { - self.post(session_id, &model_config.model_name, &payload) - .await - }) - .await?; - - let message = response_to_message(unescape_json_values(&response))?; - let usage = get_usage(&response)?; - let response_model = match response.get("modelVersion") { - Some(model_version) => model_version.as_str().unwrap_or_default().to_string(), - None => 
model_config.model_name.clone(), - }; - log.write(&response, Some(&usage))?; - let provider_usage = ProviderUsage::new(response_model, usage); - Ok((message, provider_usage)) - } - async fn fetch_supported_models(&self) -> Result, ProviderError> { let response = self .api_client @@ -214,23 +165,20 @@ impl Provider for GoogleProvider { Ok(models) } - fn supports_streaming(&self) -> bool { - true - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], tools: &[Tool], ) -> Result { - let payload = create_request(&self.model, system, messages, tools)?; - let mut log = RequestLog::start(&self.model, &payload)?; + let payload = create_request(model_config, system, messages, tools)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { - self.post_stream(Some(session_id), &self.model.model_name, &payload) + self.post_stream(Some(session_id), &model_config.model_name, &payload) .await }) .await diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs index ed1fddbd284c..7674620ed5ae 100644 --- a/crates/goose/src/providers/lead_worker.rs +++ b/crates/goose/src/providers/lead_worker.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use tokio::sync::Mutex; use super::base::{ - LeadWorkerProviderTrait, Provider, ProviderDef, ProviderMetadata, ProviderUsage, + collect_stream, stream_from_single_message, LeadWorkerProviderTrait, MessageStream, Provider, + ProviderDef, ProviderMetadata, ProviderUsage, }; use super::errors::ProviderError; use crate::conversation::message::{Message, MessageContent}; @@ -356,14 +357,14 @@ impl Provider for LeadWorkerProvider { self.lead_provider.get_model_config() } - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, _model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { // Get the active provider let provider = self.get_active_provider().await; @@ -410,9 +411,13 @@ impl Provider for LeadWorkerProvider { // Make the completion request let model_config = provider.get_model_config(); - let result = provider - .complete_with_model(session_id, &model_config, system, messages, tools) + let stream_result = provider + .stream(&model_config, session_id, system, messages, tools) .await; + let result = match stream_result { + Ok(stream) => collect_stream(stream).await, + Err(e) => Err(e), + }; // For technical failures, try with default model (lead provider) instead let final_result = match &result { @@ -421,10 +426,14 @@ impl Provider for LeadWorkerProvider { // Try with lead provider as the default/fallback for technical failures let model_config = self.lead_provider.get_model_config(); - let default_result = self + let default_stream_result = self .lead_provider - .complete_with_model(session_id, &model_config, system, messages, tools) + .stream(&model_config, session_id, system, messages, tools) .await; + let default_result = match default_stream_result { + Ok(stream) => collect_stream(stream).await, + Err(e) => Err(e), + }; match &default_result { Ok(_) => { @@ -445,7 +454,10 @@ impl Provider for LeadWorkerProvider { // Handle the result and update tracking (only for successful completions) self.handle_completion_result(&final_result).await; - final_result + match final_result { + Ok((message, usage)) => Ok(stream_from_single_message(message, usage)), + Err(e) => Err(e), + } } async fn 
fetch_supported_models(&self) -> Result, ProviderError> { @@ -514,28 +526,27 @@ mod tests { self.model_config.clone() } - async fn complete_with_model( + async fn stream( &self, - _session_id: Option<&str>, _model_config: &ModelConfig, + _session_id: &str, _system: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - Ok(( - Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text( - RawTextContent { - text: format!("Response from {}", self.name), - meta: None, - } - .no_annotation(), - )], - ), - ProviderUsage::new(self.name.clone(), Usage::default()), - )) + ) -> Result { + let message = Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text( + RawTextContent { + text: format!("Response from {}", self.name), + meta: None, + } + .no_annotation(), + )], + ); + let usage = ProviderUsage::new(self.name.clone(), Usage::default()); + Ok(stream_from_single_message(message, usage)) } } @@ -552,11 +563,12 @@ mod tests { }); let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3)); + let model_config = provider.get_model_config(); // First three turns should use lead provider for i in 0..3 { let (_message, usage) = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await .unwrap(); assert_eq!(usage.model, "lead"); @@ -567,7 +579,7 @@ mod tests { // Subsequent turns should use worker provider for i in 3..6 { let (_message, usage) = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await .unwrap(); assert_eq!(usage.model, "worker"); @@ -582,7 +594,7 @@ mod tests { assert!(!provider.is_in_fallback_mode().await); let (_message, usage) = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await .unwrap(); assert_eq!(usage.model, "lead"); @@ -603,11 +615,12 @@ mod tests { }); let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2)); + let model_config = provider.get_model_config(); // First two turns use lead (should succeed) for _i in 0..2 { let result = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); @@ -615,8 +628,9 @@ mod tests { } // Next turn uses worker (will fail, but should retry with lead and succeed) + let model_config = provider.get_model_config(); let result = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await; assert!(result.is_ok()); // Should succeed because lead provider is used as fallback assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider @@ -624,8 +638,9 @@ mod tests { assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode // Another turn - should still try worker first, then retry with lead + let model_config = provider.get_model_config(); let result = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await; assert!(result.is_ok()); // Should succeed because lead provider is used as fallback assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider @@ -663,16 +678,18 @@ mod tests { } // Should use lead provider in fallback mode + let 
model_config = provider.get_model_config(); let result = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); assert!(provider.is_in_fallback_mode().await); // One more fallback turn + let model_config = provider.get_model_config(); let result = provider - .complete("test-session-id", "system", &[], &[]) + .complete(&model_config, "test-session-id", "system", &[], &[]) .await; assert!(result.is_ok()); assert_eq!(result.unwrap().1.model, "lead"); @@ -696,33 +713,32 @@ mod tests { self.model_config.clone() } - async fn complete_with_model( + async fn stream( &self, - _session_id: Option<&str>, _model_config: &ModelConfig, + _session_id: &str, _system: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { if self.should_fail { Err(ProviderError::ExecutionError( "Simulated failure".to_string(), )) } else { - Ok(( - Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text( - RawTextContent { - text: format!("Response from {}", self.name), - meta: None, - } - .no_annotation(), - )], - ), - ProviderUsage::new(self.name.clone(), Usage::default()), - )) + let message = Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text( + RawTextContent { + text: format!("Response from {}", self.name), + meta: None, + } + .no_annotation(), + )], + ); + let usage = ProviderUsage::new(self.name.clone(), Usage::default()); + Ok(stream_from_single_message(message, usage)) } } } diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index 1df8b5c290db..9f6c8306d942 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -5,7 +5,9 @@ use serde_json::{json, Value}; use std::collections::HashMap; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; +use super::base::{ + ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage, +}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::openai_compatible::handle_response_openai_compat; @@ -176,14 +178,19 @@ impl Provider for LiteLLMProvider { } #[tracing::instrument(skip_all, name = "provider_complete")] - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { + let session_id = if session_id.is_empty() { + None + } else { + Some(session_id) + }; let mut payload = super::formats::openai::create_request( model_config, system, @@ -209,7 +216,11 @@ impl Provider for LiteLLMProvider { let response_model = get_model(&response); let mut log = RequestLog::start(model_config, &payload)?; log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(response_model, usage))) + let provider_usage = ProviderUsage::new(response_model, usage); + Ok(super::base::stream_from_single_message( + message, + provider_usage, + )) } fn supports_embeddings(&self) -> bool { diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index ff61091b3ce4..2e18ea82384c 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs 
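Aside: litellm above, and the sagemaker_tgi, snowflake, and venice hunks below, all reduce their one-shot responses to the same pair of adapters in providers::base. The following is a minimal sketch of what that pair plausibly looks like; the names, and the (Option<Message>, Option<ProviderUsage>) item shape, are taken from call sites in this diff (the `yield (Some(partial), None)` / `yield (None, Some(provider_usage))` pattern in the streaming providers), while the bodies here are illustrative assumptions rather than the repo's actual implementation. Message, ProviderUsage, and ProviderError are the goose crate's own types.

use std::pin::Pin;
use futures::{Stream, StreamExt};

// Assumed item shape: partial/full message, then a terminal usage record.
type StreamItem = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>;
pub type MessageStream = Pin<Box<dyn Stream<Item = StreamItem> + Send>>;

// Wrap one already-complete response as a two-item stream:
// first the full message, then the usage record that ends the stream.
pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
    let items: Vec<StreamItem> = vec![
        Ok((Some(message), None)),
        Ok((None, Some(usage))),
    ];
    Box::pin(futures::stream::iter(items))
}

// Drain a stream back into the old (Message, ProviderUsage) pair.
pub async fn collect_stream(
    mut stream: MessageStream,
) -> Result<(Message, ProviderUsage), ProviderError> {
    let mut last_message = None;
    let mut usage = None;
    while let Some(item) = stream.next().await {
        let (message, maybe_usage) = item?;
        if message.is_some() {
            last_message = message; // assumption: the real impl may merge partial deltas instead
        }
        if maybe_usage.is_some() {
            usage = maybe_usage;
        }
    }
    last_message.zip(usage).ok_or_else(|| {
        ProviderError::ExecutionError("stream ended without a message and usage".to_string())
    })
}

Under those assumptions, a non-streaming provider simply returns stream_from_single_message(message, usage), and any caller that still wants a single response awaits collect_stream(stream), which is exactly how lead_worker.rs and testprovider.rs bridge the two worlds in the hunks above.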
@@ -1,19 +1,15 @@ use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::errors::ProviderError; -use super::openai_compatible::{handle_response_openai_compat, handle_status_openai_compat}; +use super::openai_compatible::handle_status_openai_compat; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat, RequestLog}; +use super::utils::{ImageFormat, RequestLog}; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::config::GooseMode; use crate::conversation::message::Message; use crate::conversation::Conversation; use crate::model::ModelConfig; -use crate::providers::formats::ollama::{ - create_request, get_usage, response_to_message, response_to_streaming_message_ollama, -}; +use crate::providers::formats::ollama::{create_request, response_to_streaming_message_ollama}; use crate::utils::safe_truncate; use anyhow::{Error, Result}; use async_stream::try_stream; @@ -132,25 +128,22 @@ impl OllamaProvider { api_client = api_client.with_headers(header_map)?; } + let supports_streaming = config.supports_streaming.unwrap_or(true); + + if !supports_streaming { + return Err(anyhow::anyhow!( + "Ollama provider does not support non-streaming mode. All Ollama models support streaming. \ + Please remove 'supports_streaming: false' from your provider configuration." + )); + } + Ok(Self { api_client, model, - supports_streaming: config.supports_streaming.unwrap_or(true), + supports_streaming, name: config.name.clone(), }) } - - async fn post( - &self, - session_id: Option<&str>, - payload: &Value, - ) -> Result { - let response = self - .api_client - .response_post(session_id, "v1/chat/completions", payload) - .await?; - handle_response_openai_compat(response).await - } } impl ProviderDef for OllamaProvider { @@ -194,57 +187,6 @@ impl Provider for OllamaProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let config = crate::config::Config::global(); - let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); - let filtered_tools = if goose_mode == GooseMode::Chat { - &[] - } else { - tools - }; - - let payload = create_request( - model_config, - system, - messages, - filtered_tools, - &ImageFormat::OpenAi, - false, - )?; - - let mut log = RequestLog::start(model_config, &payload)?; - let response = self - .with_retry(|| async { - let payload_clone = payload.clone(); - self.post(session_id, &payload_clone).await - }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - - let message = response_to_message(&response)?; - - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - let response_model = get_model(&response); - log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(response_model, usage))) - } - async fn generate_session_name( &self, session_id: &str, @@ -252,8 +194,10 @@ impl Provider for OllamaProvider { ) -> Result { let context = self.get_initial_user_messages(messages); 
let message = Message::user().with_text(self.create_session_name_prompt(&context)); + let model_config = self.get_model_config(); let result = self .complete( + &model_config, session_id, "You are a title generator. Output only the requested title of 4 words or less, with no additional text, reasoning, or explanations.", &[message], @@ -267,12 +211,9 @@ impl Provider for OllamaProvider { Ok(safe_truncate(&description, 100)) } - fn supports_streaming(&self) -> bool { - self.supports_streaming - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], @@ -287,14 +228,14 @@ impl Provider for OllamaProvider { }; let payload = create_request( - &self.model, + model_config, system, messages, filtered_tools, &ImageFormat::OpenAi, true, )?; - let mut log = RequestLog::start(&self.model, &payload)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 40fc1ee6d3b3..a27dffff717d 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -1,7 +1,5 @@ use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ - ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; +use super::base::{ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata}; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; @@ -13,7 +11,7 @@ use super::openai_compatible::{ handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, }; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat}; +use super::utils::ImageFormat; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use anyhow::Result; @@ -22,7 +20,6 @@ use async_trait::async_trait; use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; use reqwest::StatusCode; -use serde_json::Value; use std::collections::HashMap; use std::io; use tokio::pin; @@ -269,34 +266,6 @@ impl OpenAiProvider { fallback.to_string() } } - - async fn post( - &self, - session_id: Option<&str>, - payload: &Value, - ) -> Result { - let response = self - .api_client - .response_post(session_id, &self.base_path, payload) - .await?; - handle_response_openai_compat(response).await - } - - async fn post_responses( - &self, - session_id: Option<&str>, - payload: &Value, - ) -> Result { - let response = self - .api_client - .response_post( - session_id, - &Self::map_base_path(&self.base_path, "responses", OPEN_AI_DEFAULT_RESPONSES_PATH), - payload, - ) - .await?; - handle_response_openai_compat(response).await - } } impl ProviderDef for OpenAiProvider { @@ -344,82 +313,6 @@ impl Provider for OpenAiProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - if Self::should_use_responses_api(&model_config.model_name, &self.base_path) { - let payload = create_responses_request(model_config, system, messages, tools)?; - let mut log = 
RequestLog::start(&self.model, &payload)?; - - let json_response = self - .with_retry(|| async { - let payload_clone = payload.clone(); - self.post_responses(session_id, &payload_clone).await - }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - - let responses_api_response: ResponsesApiResponse = - serde_json::from_value(json_response.clone()).map_err(|e| { - ProviderError::ExecutionError(format!( - "Failed to parse responses API response: {}", - e - )) - })?; - - let message = responses_api_to_message(&responses_api_response)?; - let usage = get_responses_usage(&responses_api_response); - let model = responses_api_response.model.clone(); - - log.write(&json_response, Some(&usage))?; - Ok((message, ProviderUsage::new(model, usage))) - } else { - let payload = create_request( - model_config, - system, - messages, - tools, - &ImageFormat::OpenAi, - false, - )?; - - let mut log = RequestLog::start(&self.model, &payload)?; - let json_response = self - .with_retry(|| async { - let payload_clone = payload.clone(); - self.post(session_id, &payload_clone).await - }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - - let message = response_to_message(&json_response)?; - let usage = json_response - .get("usage") - .map(get_usage) - .unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - - let model = get_model(&json_response); - log.write(&json_response, Some(&usage))?; - Ok((message, ProviderUsage::new(model, usage))) - } - } - async fn fetch_supported_models(&self) -> Result, ProviderError> { let models_path = Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH); @@ -462,22 +355,19 @@ impl Provider for OpenAiProvider { .map_err(|e| ProviderError::ExecutionError(e.to_string())) } - fn supports_streaming(&self) -> bool { - self.supports_streaming - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], tools: &[Tool], ) -> Result { - if Self::should_use_responses_api(&self.model.model_name, &self.base_path) { - let mut payload = create_responses_request(&self.model, system, messages, tools)?; - payload["stream"] = serde_json::Value::Bool(true); + if Self::should_use_responses_api(&model_config.model_name, &self.base_path) { + let mut payload = create_responses_request(model_config, system, messages, tools)?; + payload["stream"] = serde_json::Value::Bool(self.supports_streaming); - let mut log = RequestLog::start(&self.model, &payload)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { @@ -501,30 +391,56 @@ impl Provider for OpenAiProvider { let _ = log.error(e); })?; - let stream = response.bytes_stream().map_err(io::Error::other); - - Ok(Box::pin(try_stream! { - let stream_reader = StreamReader::new(stream); - let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + if self.supports_streaming { + let stream = response.bytes_stream().map_err(io::Error::other); + + Ok(Box::pin(try_stream! 
{ + let stream_reader = StreamReader::new(stream); + let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = responses_api_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = message_stream.next().await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?; + yield (message, usage); + } + })) + } else { + let json: serde_json::Value = response.json().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)) + })?; - let message_stream = responses_api_to_streaming_message(framed); - pin!(message_stream); - while let Some(message) = message_stream.next().await { - let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; - log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?; - yield (message, usage); - } - })) + let responses_api_response: ResponsesApiResponse = + serde_json::from_value(json.clone()).map_err(|e| { + ProviderError::ExecutionError(format!( + "Failed to parse responses API response: {}", + e + )) + })?; + + let message = responses_api_to_message(&responses_api_response)?; + let usage_data = get_responses_usage(&responses_api_response); + let usage = + super::base::ProviderUsage::new(model_config.model_name.clone(), usage_data); + + log.write( + &serde_json::to_value(&message).unwrap_or_default(), + Some(&usage_data), + )?; + + Ok(super::base::stream_from_single_message(message, usage)) + } } else { let payload = create_request( - &self.model, + model_config, system, messages, tools, &ImageFormat::OpenAi, - true, + self.supports_streaming, )?; - let mut log = RequestLog::start(&self.model, &payload)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { @@ -539,7 +455,28 @@ impl Provider for OpenAiProvider { let _ = log.error(e); })?; - stream_openai_compat(response, log) + if self.supports_streaming { + stream_openai_compat(response, log) + } else { + let json: serde_json::Value = response.json().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)) + })?; + + let message = response_to_message(&json).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse message: {}", e)) + })?; + + let usage_data = get_usage(json.get("usage").unwrap_or(&serde_json::Value::Null)); + let usage = + super::base::ProviderUsage::new(model_config.model_name.clone(), usage_data); + + log.write( + &serde_json::to_value(&message).unwrap_or_default(), + Some(&usage_data), + )?; + + Ok(super::base::stream_from_single_message(message, usage)) + } } } } diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs index 3e703c54dc8b..0eb1c811ef98 100644 --- a/crates/goose/src/providers/openai_compatible.rs +++ b/crates/goose/src/providers/openai_compatible.rs @@ -9,15 +9,13 @@ use tokio_util::codec::{FramedRead, LinesCodec}; use tokio_util::io::StreamReader; use super::api_client::ApiClient; -use super::base::{MessageStream, Provider, ProviderUsage, Usage}; +use super::base::{MessageStream, Provider}; use super::errors::ProviderError; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat, RequestLog}; +use super::utils::{ImageFormat, RequestLog}; use crate::conversation::message::Message; use crate::model::ModelConfig; 
-use crate::providers::formats::openai::{ - create_request, get_usage, response_to_message, response_to_streaming_message, -}; +use crate::providers::formats::openai::{create_request, response_to_streaming_message}; use rmcp::model::Tool; pub struct OpenAiCompatibleProvider { @@ -74,44 +72,6 @@ impl Provider for OpenAiCompatibleProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = self.build_request(model_config, system, messages, tools, false)?; - let mut log = RequestLog::start(model_config, &payload)?; - - let completions_path = format!("{}chat/completions", self.completions_prefix); - let response = self - .with_retry(|| async { - let resp = self - .api_client - .response_post(session_id, &completions_path, &payload) - .await?; - handle_response_openai_compat(resp).await - }) - .await?; - - let response_model = get_model(&response); - let message = response_to_message(&response) - .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - log.write(&response, Some(&usage))?; - - Ok((message, ProviderUsage::new(response_model, usage))) - } - async fn fetch_supported_models(&self) -> Result, ProviderError> { let response = self .api_client @@ -139,19 +99,16 @@ impl Provider for OpenAiCompatibleProvider { Ok(models) } - fn supports_streaming(&self) -> bool { - true - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], tools: &[Tool], ) -> Result { - let payload = self.build_request(&self.model, system, messages, tools, true)?; - let mut log = RequestLog::start(&self.model, &payload)?; + let payload = self.build_request(model_config, system, messages, tools, true)?; + let mut log = RequestLog::start(model_config, &payload)?; let completions_path = format!("{}chat/completions", self.completions_prefix); let response = self diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 26d08983d436..39fd3149ff92 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -4,18 +4,14 @@ use futures::future::BoxFuture; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::errors::ProviderError; -use super::openai_compatible::{ - handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, -}; +use super::openai_compatible::{handle_status_openai_compat, stream_openai_compat}; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat, RequestLog}; +use super::utils::{ImageFormat, RequestLog}; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::openai::{create_request, get_usage}; +use crate::providers::formats::openai::create_request; use crate::providers::formats::openrouter as openrouter_format; use rmcp::model::Tool; 
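Aside: taken together, the hunks so far imply the reshaped Provider trait that the rest of this diff implements. Below is a hedged sketch of that shape, inferred from the call sites: `stream` becomes the one required method, an empty `session_id` stands in for the old `None` (hence the `is_empty` checks in several providers), and `complete` is presumably a provided method that drains the stream via `collect_stream`. The default body shown here is an assumption, not a copy of providers::base.

use async_trait::async_trait;
use rmcp::model::Tool;

#[async_trait]
pub trait Provider: Send + Sync {
    fn get_model_config(&self) -> ModelConfig;

    // The single required entry point after this refactor.
    async fn stream(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError>;

    // Assumed provided method: every call site in this diff now calls
    // complete(&model_config, session_id, ...) and still gets back a
    // (Message, ProviderUsage) pair.
    async fn complete(
        &self,
        model_config: &ModelConfig,
        session_id: &str,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let stream = self
            .stream(model_config, session_id, system, messages, tools)
            .await?;
        collect_stream(stream).await
    }
}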
@@ -71,50 +67,6 @@ impl OpenRouterProvider { name: OPENROUTER_PROVIDER_NAME.to_string(), }) } - - async fn post( - &self, - session_id: Option<&str>, - payload: &Value, - ) -> Result { - let response = self - .api_client - .response_post(session_id, "api/v1/chat/completions", payload) - .await?; - - let response_body = handle_response_openai_compat(response) - .await - .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?; - - if let Some(error_obj) = response_body.get("error") { - let error_message = error_obj - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("Unknown OpenRouter error"); - - let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0); - - if error_code == 400 && error_message.contains("maximum context length") { - return Err(ProviderError::ContextLengthExceeded( - error_message.to_string(), - )); - } - - match error_code { - 401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())), - 429 => { - return Err(ProviderError::RateLimitExceeded { - details: error_message.to_string(), - retry_delay: None, - }) - } - 500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())), - _ => return Err(ProviderError::RequestFailed(error_message.to_string())), - } - } - - Ok(response_body) - } } /// Update the request when using anthropic model. @@ -194,43 +146,6 @@ fn is_gemini_model(model_name: &str) -> bool { model_name.starts_with("google/") } -async fn create_request_based_on_model( - provider: &OpenRouterProvider, - session_id: Option<&str>, - system: &str, - messages: &[Message], - tools: &[Tool], -) -> Result { - let mut payload = create_request( - &provider.model, - system, - messages, - tools, - &ImageFormat::OpenAi, - false, - )?; - - if let Some(session_id) = session_id.filter(|id| !id.is_empty()) { - if let Some(obj) = payload.as_object_mut() { - obj.insert("user".to_string(), Value::String(session_id.to_string())); - } - } - - if provider.supports_cache_control().await { - payload = update_request_for_anthropic(&payload); - } - - if is_gemini_model(&provider.model.model_name) { - openrouter_format::add_reasoning_details_to_request(&mut payload, messages); - } - - if let Some(obj) = payload.as_object_mut() { - obj.insert("transforms".to_string(), json!(["middle-out"])); - } - - Ok(payload) -} - impl ProviderDef for OpenRouterProvider { type Provider = Self; @@ -273,44 +188,7 @@ impl Provider for OpenRouterProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = - create_request_based_on_model(self, session_id, system, messages, tools).await?; - let mut log = RequestLog::start(model_config, &payload)?; - - let response = self - .with_retry(|| async { - let payload_clone = payload.clone(); - self.post(session_id, &payload_clone).await - }) - .await?; - - let response_model = get_model(&response); - let message = if is_gemini_model(&self.model.model_name) { - openrouter_format::response_to_message(&response)? - } else { - crate::providers::formats::openai::response_to_message(&response)? 
- }; - - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(response_model, usage))) - } - + /// Fetch supported models from OpenRouter API (only models with tool support) async fn fetch_supported_models(&self) -> Result, ProviderError> { let response = self .api_client @@ -364,19 +242,16 @@ impl Provider for OpenRouterProvider { .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC) } - fn supports_streaming(&self) -> bool { - self.supports_streaming - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], tools: &[Tool], ) -> Result { let mut payload = create_request( - &self.model, + model_config, system, messages, tools, @@ -384,11 +259,18 @@ impl Provider for OpenRouterProvider { true, )?; + // Add user field for OpenRouter attribution/rate-limiting + if !session_id.is_empty() { + if let Some(obj) = payload.as_object_mut() { + obj.insert("user".to_string(), Value::String(session_id.to_string())); + } + } + if self.supports_cache_control().await { payload = update_request_for_anthropic(&payload); } - if is_gemini_model(&self.model.model_name) { + if is_gemini_model(&model_config.model_name) { openrouter_format::add_reasoning_details_to_request(&mut payload, messages); } @@ -396,7 +278,7 @@ impl Provider for OpenRouterProvider { obj.insert("transforms".to_string(), json!(["middle-out"])); } - let mut log = RequestLog::start(&self.model, &payload)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { diff --git a/crates/goose/src/providers/provider_test.rs b/crates/goose/src/providers/provider_test.rs index 3283b9b7f44c..33450bd978c0 100644 --- a/crates/goose/src/providers/provider_test.rs +++ b/crates/goose/src/providers/provider_test.rs @@ -26,8 +26,10 @@ pub async fn test_provider_configuration( vec![] }; + let provider_model_config = provider.get_model_config(); let _result = provider .complete( + &provider_model_config, "test-session-id", "You are an AI agent called goose. 
You use tools of connected extensions to solve problems.", &messages, diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index e75d89f05248..06b0fdc24ef9 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -9,7 +9,9 @@ use aws_sdk_sagemakerruntime::Client as SageMakerClient; use rmcp::model::Tool; use serde_json::{json, Value}; -use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ + ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, +}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::RequestLog; @@ -311,14 +313,19 @@ impl Provider for SageMakerTgiProvider { skip(self, model_config, system, messages, tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result<MessageStream, ProviderError> { + let session_id = if session_id.is_empty() { + None + } else { + Some(session_id) + }; let model_name = &model_config.model_name; let request_payload = self.create_tgi_request(system, messages).map_err(|e| { @@ -351,6 +358,9 @@ )?; let provider_usage = ProviderUsage::new(model_name.to_string(), usage); - Ok((message, provider_usage)) + Ok(super::base::stream_from_single_message( + message, + provider_usage, + )) } } diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 16ca9a4ffcd8..1153fa8539dd 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; +use super::base::{ + ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, +}; use super::errors::ProviderError; use super::formats::snowflake::{create_request, get_usage, response_to_message}; use super::openai_compatible::map_http_error_to_provider_error; @@ -342,14 +344,19 @@ impl Provider for SnowflakeProvider { skip(self, model_config, system, messages, tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result<MessageStream, ProviderError> { + let session_id = if session_id.is_empty() { + None + } else { + Some(session_id) + }; let payload = create_request(model_config, system, messages, tools)?; let mut log = RequestLog::start(&self.model, &payload)?; @@ -367,6 +374,10 @@ log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(response_model, usage))) + let provider_usage = ProviderUsage::new(response_model, usage); + Ok(super::base::stream_from_single_message( + message, + provider_usage, + )) } } diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index 8ffdac628235..af5d54f74aeb 100644 ---
a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -7,7 +7,9 @@ use std::fs; use std::path::Path; use std::sync::{Arc, Mutex}; -use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage}; +#[cfg(test)] +use super::base::stream_from_single_message; +use super::base::{MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -153,21 +155,22 @@ impl Provider for TestProvider { &self.name } - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, - _model_config: &ModelConfig, + model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { let hash = Self::hash_input(messages); if let Some(inner) = &self.inner { - let model_config = inner.get_model_config(); - let (message, usage) = inner - .complete_with_model(session_id, &model_config, system, messages, tools) + // Call inner provider's stream and collect it + let stream = inner + .stream(model_config, session_id, system, messages, tools) .await?; + let (message, usage) = super::base::collect_stream(stream).await?; let record = TestRecord { input: TestInput { @@ -186,11 +189,13 @@ impl Provider for TestProvider { records.insert(hash, record); } - Ok((message, usage)) + Ok(super::base::stream_from_single_message(message, usage)) } else { let records = self.records.lock().unwrap(); if let Some(record) = records.get(&hash) { - Ok((record.output.message.clone(), record.output.usage.clone())) + let message = record.output.message.clone(); + let usage = record.output.usage.clone(); + Ok(super::base::stream_from_single_message(message, usage)) } else { Err(ProviderError::ExecutionError(format!( "No recorded response found for input hash: {}", @@ -226,28 +231,27 @@ mod tests { "mock-testprovider" } - async fn complete_with_model( + async fn stream( &self, - _session_id: Option<&str>, _model_config: &ModelConfig, + _session_id: &str, _system: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - Ok(( - Message::new( - Role::Assistant, - Utc::now().timestamp(), - vec![MessageContent::Text(TextContent { - raw: RawTextContent { - text: self.response.clone(), - meta: None, - }, - annotations: None, - })], - ), - ProviderUsage::new("mock-model".to_string(), Usage::default()), - )) + ) -> Result { + let message = Message::new( + Role::Assistant, + Utc::now().timestamp(), + vec![MessageContent::Text(TextContent { + raw: RawTextContent { + text: self.response.clone(), + meta: None, + }, + annotations: None, + })], + ); + let usage = ProviderUsage::new("mock-model".to_string(), Usage::default()); + Ok(stream_from_single_message(message, usage)) } fn get_model_config(&self) -> ModelConfig { @@ -270,9 +274,16 @@ mod tests { { let test_provider = TestProvider::new_recording(mock, &temp_file); + let model_config = test_provider.get_model_config(); let result = test_provider - .complete("test-session-id", "You are helpful", &[], &[]) + .complete( + &model_config, + "test-session-id", + "You are helpful", + &[], + &[], + ) .await; assert!(result.is_ok()); @@ -288,9 +299,16 @@ mod tests { { let replay_provider = TestProvider::new_replaying(&temp_file).unwrap(); + let model_config = replay_provider.get_model_config(); let result = replay_provider - .complete("test-session-id", "You are helpful", 
&[], &[]) + .complete( + &model_config, + "test-session-id", + "You are helpful", + &[], + &[], + ) .await; assert!(result.is_ok()); @@ -313,9 +331,16 @@ mod tests { ); let replay_provider = TestProvider::new_replaying(&temp_file).unwrap(); + let model_config = replay_provider.get_model_config(); let result = replay_provider - .complete("test-session-id", "Different system prompt", &[], &[]) + .complete( + &model_config, + "test-session-id", + "Different system prompt", + &[], + &[], + ) .await; assert!(result.is_err()); diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index b8ad1803c04f..1eff1cd3ef30 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -1,22 +1,17 @@ use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::errors::ProviderError; -use super::openai_compatible::{ - handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat, -}; +use super::openai_compatible::{handle_status_openai_compat, stream_openai_compat}; use super::retry::ProviderRetry; -use super::utils::{get_model, handle_response_google_compat, is_google_model, RequestLog}; +use super::utils::RequestLog; use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL; use crate::conversation::message::Message; use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; -use serde_json::Value; use crate::model::ModelConfig; -use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; +use crate::providers::formats::openai::create_request; use rmcp::model::Tool; const TETRATE_PROVIDER_NAME: &str = "tetrate"; @@ -66,68 +61,6 @@ impl TetrateProvider { name: TETRATE_PROVIDER_NAME.to_string(), }) } - - async fn post( - &self, - session_id: Option<&str>, - payload: &Value, - ) -> Result { - let response = self - .api_client - .response_post(session_id, "v1/chat/completions", payload) - .await?; - - // Handle Google-compatible model responses differently - if is_google_model(payload) { - return handle_response_google_compat(response).await; - } - - // For OpenAI-compatible models, parse the response body to JSON - let response_body = handle_response_openai_compat(response) - .await - .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?; - - let _debug = format!( - "Tetrate Agent Router Service request with payload: {} and response: {}", - serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()), - serde_json::to_string_pretty(&response_body) - .unwrap_or_else(|_| "Invalid JSON".to_string()) - ); - - // Tetrate Agent Router Service can return errors in 200 OK responses, so we have to check for errors explicitly - if let Some(error_obj) = response_body.get("error") { - // If there's an error object, extract the error message and code - let error_message = error_obj - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("Unknown Tetrate Agent Router Service error"); - - let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0); - - // Check for context length errors in the error message - if error_code == 400 && error_message.contains("maximum context length") { - return Err(ProviderError::ContextLengthExceeded( - error_message.to_string(), - )); - } - - // Return appropriate error based on 
the error code - match error_code { - 401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())), - 429 => { - return Err(ProviderError::RateLimitExceeded { - details: error_message.to_string(), - retry_delay: None, - }) - } - 500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())), - _ => return Err(ProviderError::RequestFailed(error_message.to_string())), - } - } - - // No error detected, return the response body - Ok(response_body) - } } impl ProviderDef for TetrateProvider { @@ -171,56 +104,16 @@ impl Provider for TetrateProvider { self.model.clone() } - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - session_id: Option<&str>, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request( - model_config, - system, - messages, - tools, - &super::utils::ImageFormat::OpenAi, - false, - )?; - let mut log = RequestLog::start(model_config, &payload)?; - - // Make request - let response = self - .with_retry(|| async { - let payload_clone = payload.clone(); - self.post(session_id, &payload_clone).await - }) - .await?; - - // Parse response - let message = response_to_message(&response)?; - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - let model = get_model(&response); - log.write(&response, Some(&usage))?; - Ok((message, ProviderUsage::new(model, usage))) - } - async fn stream( &self, + model_config: &ModelConfig, session_id: &str, system: &str, messages: &[Message], tools: &[Tool], ) -> Result { let payload = create_request( - &self.model, + model_config, system, messages, tools, @@ -228,7 +121,7 @@ impl Provider for TetrateProvider { true, )?; - let mut log = RequestLog::start(&self.model, &payload)?; + let mut log = RequestLog::start(model_config, &payload)?; let response = self .with_retry(|| async { @@ -311,8 +204,4 @@ impl Provider for TetrateProvider { models.sort(); Ok(models) } - - fn supports_streaming(&self) -> bool { - self.supports_streaming - } } diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs index 7a26d47e0db5..8ab86cb72e45 100644 --- a/crates/goose/src/providers/venice.rs +++ b/crates/goose/src/providers/venice.rs @@ -5,7 +5,9 @@ use serde::Serialize; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ + ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, +}; use super::errors::ProviderError; use super::openai_compatible::map_http_error_to_provider_error; use super::retry::ProviderRetry; @@ -271,14 +273,19 @@ impl Provider for VeniceProvider { skip(self, model_config, system, messages, tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] - async fn complete_with_model( + async fn stream( &self, - session_id: Option<&str>, model_config: &ModelConfig, + session_id: &str, system: &str, messages: &[Message], tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { + let session_id = if session_id.is_empty() { + None + } else { + Some(session_id) + }; // Create properly formatted messages for 
Venice API let mut formatted_messages = Vec::new(); @@ -502,12 +509,13 @@ impl Provider for VeniceProvider { message = message.with_content(item); } - return Ok(( + let provider_usage = ProviderUsage::new( + strip_flags(&model_config.model_name).to_string(), + Usage::default(), + ); + return Ok(super::base::stream_from_single_message( message, - ProviderUsage::new( - strip_flags(&model_config.model_name).to_string(), - Usage::default(), - ), + provider_usage, )); } } @@ -534,9 +542,12 @@ impl Provider for VeniceProvider { usage_data["total_tokens"].as_i64().map(|v| v as i32), ); - Ok(( - Message::new(Role::Assistant, Utc::now().timestamp(), content), - ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage), + let message = Message::new(Role::Assistant, Utc::now().timestamp(), content); + let provider_usage = + ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage); + Ok(super::base::stream_from_single_message( + message, + provider_usage, )) } } diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 26f82f548bf0..370d15b82927 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -343,7 +343,8 @@ mod tests { use goose::conversation::message::{Message, MessageContent}; use goose::model::ModelConfig; use goose::providers::base::{ - Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, + stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, + ProviderUsage, Usage, }; use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; @@ -385,13 +386,14 @@ mod tests { #[async_trait] impl Provider for MockToolProvider { - async fn complete( + async fn stream( &self, + _model_config: &ModelConfig, _session_id: &str, _system_prompt: &str, _messages: &[Message], _tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { + ) -> Result { let tool_call = CallToolRequestParams { meta: None, task: None, @@ -405,21 +407,7 @@ mod tests { Usage::new(Some(10), Some(5), Some(15)), ); - Ok((message, usage)) - } - - async fn complete_with_model( - &self, - session_id: Option<&str>, - _model_config: &ModelConfig, - system_prompt: &str, - messages: &[Message], - tools: &[Tool], - ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { - // Test-only: coerce missing session_id to empty so complete() can be reused. 
diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs
index 7a26d47e0db5..8ab86cb72e45 100644
--- a/crates/goose/src/providers/venice.rs
+++ b/crates/goose/src/providers/venice.rs
@@ -5,7 +5,9 @@ use serde::Serialize;
 use serde_json::{json, Value};

 use super::api_client::{ApiClient, AuthMethod};
-use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{
+    ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+};
 use super::errors::ProviderError;
 use super::openai_compatible::map_http_error_to_provider_error;
 use super::retry::ProviderRetry;
@@ -271,14 +273,19 @@ impl Provider for VeniceProvider {
         skip(self, model_config, system, messages, tools),
         fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
     )]
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        session_id: Option<&str>,
         model_config: &ModelConfig,
+        session_id: &str,
         system: &str,
         messages: &[Message],
         tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
+        let session_id = if session_id.is_empty() {
+            None
+        } else {
+            Some(session_id)
+        };
         // Create properly formatted messages for Venice API
         let mut formatted_messages = Vec::new();

@@ -502,12 +509,13 @@ impl Provider for VeniceProvider {
             message = message.with_content(item);
         }

-        return Ok((
+        let provider_usage = ProviderUsage::new(
+            strip_flags(&model_config.model_name).to_string(),
+            Usage::default(),
+        );
+        return Ok(super::base::stream_from_single_message(
             message,
-            ProviderUsage::new(
-                strip_flags(&model_config.model_name).to_string(),
-                Usage::default(),
-            ),
+            provider_usage,
         ));
     }
 }
@@ -534,9 +542,12 @@ impl Provider for VeniceProvider {
             usage_data["total_tokens"].as_i64().map(|v| v as i32),
         );

-        Ok((
-            Message::new(Role::Assistant, Utc::now().timestamp(), content),
-            ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
+        let message = Message::new(Role::Assistant, Utc::now().timestamp(), content);
+        let provider_usage =
+            ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage);
+        Ok(super::base::stream_from_single_message(
+            message,
+            provider_usage,
         ))
     }
 }
diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs
index 26f82f548bf0..370d15b82927 100644
--- a/crates/goose/tests/agent.rs
+++ b/crates/goose/tests/agent.rs
@@ -343,7 +343,8 @@ mod tests {
     use goose::conversation::message::{Message, MessageContent};
     use goose::model::ModelConfig;
     use goose::providers::base::{
-        Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
+        stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
+        ProviderUsage, Usage,
     };
     use goose::providers::errors::ProviderError;
     use goose::session::session_manager::SessionType;
@@ -385,13 +386,14 @@ mod tests {

     #[async_trait]
     impl Provider for MockToolProvider {
-        async fn complete(
+        async fn stream(
             &self,
+            _model_config: &ModelConfig,
             _session_id: &str,
             _system_prompt: &str,
             _messages: &[Message],
             _tools: &[Tool],
-        ) -> Result<(Message, ProviderUsage), ProviderError> {
+        ) -> Result<MessageStream, ProviderError> {
            let tool_call = CallToolRequestParams {
                 meta: None,
                 task: None,
@@ -405,21 +407,7 @@ mod tests {
                 Usage::new(Some(10), Some(5), Some(15)),
             );

-            Ok((message, usage))
-        }
-
-        async fn complete_with_model(
-            &self,
-            session_id: Option<&str>,
-            _model_config: &ModelConfig,
-            system_prompt: &str,
-            messages: &[Message],
-            tools: &[Tool],
-        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
-            // Test-only: coerce missing session_id to empty so complete() can be reused.
-            let session_id = session_id.unwrap_or("");
-            self.complete(session_id, system_prompt, messages, tools)
-                .await
+            Ok(stream_from_single_message(message, usage))
         }

         fn get_model_config(&self) -> ModelConfig {
diff --git a/crates/goose/tests/compaction.rs b/crates/goose/tests/compaction.rs
index 29755c81c0f1..e6e421e86329 100644
--- a/crates/goose/tests/compaction.rs
+++ b/crates/goose/tests/compaction.rs
@@ -5,7 +5,10 @@ use goose::agents::{Agent, AgentEvent, SessionConfig};
 use goose::conversation::message::{Message, MessageContent};
 use goose::conversation::Conversation;
 use goose::model::ModelConfig;
-use goose::providers::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
+use goose::providers::base::{
+    stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
+    ProviderUsage, Usage,
+};
 use goose::providers::errors::ProviderError;
 use goose::session::session_manager::SessionType;
 use goose::session::Session;
@@ -94,14 +97,14 @@ impl MockCompactionProvider {

 #[async_trait]
 impl Provider for MockCompactionProvider {
-    async fn complete_with_model(
+    async fn stream(
         &self,
-        _session_id: Option<&str>,
         _model_config: &ModelConfig,
+        _session_id: &str,
         system_prompt: &str,
         messages: &[Message],
         _tools: &[Tool],
-    ) -> Result<(Message, ProviderUsage), ProviderError> {
+    ) -> Result<MessageStream, ProviderError> {
         // Check if this is a compaction call (message contains "summarize")
         let is_compaction = messages.iter().any(|msg| {
             msg.content.iter().any(|content| {
@@ -163,7 +166,7 @@ impl Provider for MockCompactionProvider {
             ),
         );

-        Ok((message, usage))
+        Ok(stream_from_single_message(message, usage))
     }

     fn get_model_config(&self) -> ModelConfig {
diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs
index 0281ce528d6d..3a9aa7319bdd 100644
--- a/crates/goose/tests/mcp_integration_test.rs
+++ b/crates/goose/tests/mcp_integration_test.rs
@@ -19,7 +19,10 @@ use test_case::test_case;

 use async_trait::async_trait;
 use goose::conversation::message::Message;
-use goose::providers::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
+use goose::providers::base::{
+    stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
+    ProviderUsage, Usage,
+};
 use goose::providers::errors::ProviderError;
 use once_cell::sync::Lazy;
 use std::process::Command;
@@ -69,18 +72,17 @@ impl Provider for MockProvider {
         "mock"
     }

-    async fn complete_with_model(
+    async fn stream(
         &self,
-        _session_id: Option<&str>,
         _model_config: &ModelConfig,
+        _session_id: &str,
         _system: &str,
         _messages: &[Message],
         _tools: &[Tool],
-    ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
-        Ok((
-            Message::assistant().with_text("\"So we beat on, boats against the current, borne back ceaselessly into the past.\" — F. Scott Fitzgerald, The Great Gatsby (1925)"),
-            ProviderUsage::new("mock".to_string(), Usage::default()),
-        ))
+    ) -> Result<MessageStream, ProviderError> {
+        let message = Message::assistant().with_text("\"So we beat on, boats against the current, borne back ceaselessly into the past.\" — F. Scott Fitzgerald, The Great Gatsby (1925)");
+        let usage = ProviderUsage::new("mock".to_string(), Usage::default());
+        Ok(stream_from_single_message(message, usage))
     }

     fn get_model_config(&self) -> ModelConfig {
diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs
index 0d3ff3a46eeb..dac199ee975b 100644
--- a/crates/goose/tests/providers.rs
+++ b/crates/goose/tests/providers.rs
@@ -127,9 +127,16 @@ impl ProviderTester {
             .build();

         let message = Message::user().with_text(prompt);
+        let model_config = self.provider.get_model_config();
         let (response1, _) = self
             .provider
-            .complete(session_id, &system, std::slice::from_ref(&message), &tools)
+            .complete(
+                &model_config,
+                session_id,
+                &system,
+                std::slice::from_ref(&message),
+                &tools,
+            )
             .await?;

         // Agentic CLI providers (claude-code, codex) call tools internally and
@@ -163,6 +170,7 @@ impl ProviderTester {
         let (response2, _) = self
             .provider
             .complete(
+                &model_config,
                 session_id,
                 &system,
                 &[message, response1, tool_response],
@@ -174,10 +182,17 @@ impl ProviderTester {

     async fn test_basic_response(&self, session_id: &str) -> Result<()> {
         let message = Message::user().with_text("Just say hello!");
+        let model_config = self.provider.get_model_config();

         let (response, _) = self
             .provider
-            .complete(session_id, "You are a helpful assistant.", &[message], &[])
+            .complete(
+                &model_config,
+                session_id,
+                "You are a helpful assistant.",
+                &[message],
+                &[],
+            )
             .await?;

         assert!(
@@ -227,10 +242,17 @@ impl ProviderTester {
         };

         let messages = vec![Message::user().with_text(&large_message_content)];
+        let model_config = self.provider.get_model_config();

         let result = self
             .provider
-            .complete(session_id, "You are a helpful assistant.", &messages, &[])
+            .complete(
+                &model_config,
+                session_id,
+                "You are a helpful assistant.",
+                &messages,
+                &[],
+            )
             .await;

         println!("=== {}::context_length_exceeded_error ===", self.name);
@@ -288,9 +310,9 @@ impl ProviderTester {
         let message = Message::user().with_text("Just say hello!");
         let (response, _) = self
             .provider
-            .complete_with_model(
-                Some(session_id),
+            .complete(
                 &alt_config,
+                session_id,
                 "You are a helpful assistant.",
                 &[message],
                 &[],
diff --git a/crates/goose/tests/session_id_propagation_test.rs b/crates/goose/tests/session_id_propagation_test.rs
index 5142e18fc5d7..1ace2e6ae79d 100644
--- a/crates/goose/tests/session_id_propagation_test.rs
+++ b/crates/goose/tests/session_id_propagation_test.rs
@@ -54,24 +54,33 @@ async fn setup_mock_server() -> (MockServer, HeaderCapture, Box<dyn Provider>) {
         .and(path("/v1/chat/completions"))
         .respond_with(move |req: &Request| {
             capture_clone.capture_session_header(req);
-            ResponseTemplate::new(200).set_body_json(json!({
-                "choices": [{
-                    "finish_reason": "stop",
-                    "index": 0,
-                    "message": {
-                        "content": "Hi there! How can I help you today?",
-                        "role": "assistant"
+            // Return SSE streaming format
+            let sse_response = format!(
+                "data: {}\n\ndata: {}\n\ndata: [DONE]\n\n",
+                json!({
+                    "choices": [{
+                        "delta": {
+                            "content": "Hi there! How can I help you today?",
+                            "role": "assistant"
+                        },
+                        "index": 0
+                    }],
+                    "created": 1755133833,
+                    "id": "chatcmpl-test",
+                    "model": "gpt-5-nano"
+                }),
+                json!({
+                    "choices": [],
+                    "usage": {
+                        "completion_tokens": 10,
+                        "prompt_tokens": 8,
+                        "total_tokens": 18
                     }
-                }],
-                "created": 1755133833,
-                "id": "chatcmpl-test",
-                "model": "gpt-5-nano",
-                "usage": {
-                    "completion_tokens": 10,
-                    "prompt_tokens": 8,
-                    "total_tokens": 18
-                }
-            }))
+                })
+            );
+            ResponseTemplate::new(200)
+                .set_body_string(sse_response)
+                .insert_header("content-type", "text/event-stream")
         })
         .mount(&mock_server)
         .await;
@@ -82,8 +91,15 @@

 async fn make_request(provider: &dyn Provider, session_id: &str) {
     let message = Message::user().with_text("test message");
+    let model_config = provider.get_model_config();
     let _ = provider
-        .complete(session_id, "You are a helpful assistant.", &[message], &[])
+        .complete(
+            &model_config,
+            session_id,
+            "You are a helpful assistant.",
+            &[message],
+            &[],
+        )
         .await
         .unwrap();
 }
diff --git a/crates/goose/tests/tetrate_streaming.rs b/crates/goose/tests/tetrate_streaming.rs
index 32c4c2028c27..04e7370c7313 100644
--- a/crates/goose/tests/tetrate_streaming.rs
+++ b/crates/goose/tests/tetrate_streaming.rs
@@ -27,9 +27,11 @@ mod tetrate_streaming_tests {
         let provider = create_test_provider().await?;

         let messages = vec![Message::user().with_text("Count from 1 to 5, one number at a time.")];
+        let model_config = provider.get_model_config();

         let mut stream = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant that counts numbers.",
                 &messages,
@@ -99,9 +101,11 @@
         );

         let messages = vec![Message::user().with_text("What's the weather in San Francisco?")];
+        let model_config = provider.get_model_config();

         let mut stream = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant with access to weather information.",
                 &messages,
@@ -147,9 +151,11 @@

         // This might result in a very short or empty response
         let messages = vec![Message::user().with_text("")];
+        let model_config = provider.get_model_config();

         let mut stream = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant.",
                 &messages,
@@ -182,9 +188,11 @@
         let messages = vec![Message::user().with_text(
             "Write a detailed 3-paragraph essay about the importance of streaming in modern APIs.",
         )];
+        let model_config = provider.get_model_config();

         let mut stream = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant that writes detailed essays.",
                 &messages,
@@ -243,9 +251,11 @@
         let provider = TetrateProvider::from_env(model_config).await?;

         let messages = vec![Message::user().with_text("Hello")];
+        let model_config = provider.get_model_config();

         let result = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant.",
                 &messages,
@@ -271,9 +281,11 @@
         // Create multiple concurrent streams
         let messages1 = vec![Message::user().with_text("Say 'Stream 1'")];
         let messages2 = vec![Message::user().with_text("Say 'Stream 2'")];
+        let model_config = provider.get_model_config();

         let stream1 = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant.",
                 &messages1,
@@ -283,6 +295,7 @@

         let stream2 = provider
             .stream(
+                &model_config,
                 "test-session-id",
                 "You are a helpful assistant.",
                 &messages2,
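The test mocks above all land on the same recipe: build one assistant message and one usage record, then wrap them. Factored out, the shared shape is roughly the following (sketch only; `canned_stream` is not a helper in the codebase, and the token counts mirror the agent.rs mock):

    // Mirrors the mock bodies above: one canned assistant message with fixed
    // token counts, wrapped into a one-item stream.
    fn canned_stream(text: &str) -> Result<MessageStream, ProviderError> {
        let message = Message::assistant().with_text(text);
        let usage = ProviderUsage::new("mock".to_string(), Usage::new(Some(10), Some(5), Some(15)));
        Ok(stream_from_single_message(message, usage))
    }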
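The final providers.rs hunk above is the replacement for the old per-call model override: `complete_with_model(Some(session_id), &alt_config, ...)` becomes `complete(&alt_config, session_id, ...)`. As a free-standing sketch (this assumes `complete` remains available as a provided trait method that drains `stream` internally, which the updated tests suggest but this diff does not show; `say_hello_with` is an invented name):

    async fn say_hello_with(
        provider: &dyn Provider,
        alt_config: &ModelConfig,
        session_id: &str,
    ) -> Result<Message, ProviderError> {
        let (response, _usage) = provider
            .complete(
                alt_config, // any ModelConfig can now be supplied per call
                session_id,
                "You are a helpful assistant.",
                &[Message::user().with_text("Just say hello!")],
                &[],
            )
            .await?;
        Ok(response)
    }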
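For reference, the `format!` call above renders to OpenAI-style SSE frames like the following: one delta chunk, one usage-only chunk, then the `[DONE]` sentinel (JSON shown compact as `serde_json` emits it; key order may differ):

    data: {"choices":[{"delta":{"content":"Hi there! How can I help you today?","role":"assistant"},"index":0}],"created":1755133833,"id":"chatcmpl-test","model":"gpt-5-nano"}

    data: {"choices":[],"usage":{"completion_tokens":10,"prompt_tokens":8,"total_tokens":18}}

    data: [DONE]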