diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index e0a05c023c7e..1c68512c3f8b 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -369,7 +369,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> { let models_res = { let temp_model_config = goose::model::ModelConfig::new(&provider_meta.default_model)?; let temp_provider = create(provider_name, temp_model_config)?; - temp_provider.fetch_supported_models_async().await + temp_provider.fetch_supported_models().await }; spin.stop(style("Model fetch complete").green()); diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 9aa97f687c45..da41e8120f1e 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -294,17 +294,18 @@ pub fn render_prompts(prompts: &HashMap<String, Vec<Prompt>>) { pub fn render_prompt_info(info: &PromptInfo) { println!(); - if let Some(ext) = &info.extension { println!(" {}: {}", style("Extension").green(), ext); } - println!(" Prompt: {}", style(&info.name).cyan().bold()); - if let Some(desc) = &info.description { println!("\n {}", desc); } + render_arguments(info); + println!(); +} +fn render_arguments(info: &PromptInfo) { if let Some(args) = &info.arguments { println!("\n Arguments:"); for arg in args { @@ -323,7 +324,6 @@ pub fn render_prompt_info(info: &PromptInfo) { ); } } - println!(); } pub fn render_extension_success(name: &str) { @@ -491,6 +491,23 @@ fn get_tool_params_max_length() -> usize { .unwrap_or(40) } +fn print_value(value: &Value, debug: bool) { + let formatted = match value { + Value::String(s) => { + if !debug && s.len() > get_tool_params_max_length() { + style(format!("[REDACTED: {} chars]", s.len())).yellow() + } else { + style(s.to_string()).green() + } + } + Value::Number(n) => style(n.to_string()).yellow(), + Value::Bool(b) => style(b.to_string()).yellow(), + Value::Null => style("null".to_string()).dim(), + _ => unreachable!(), + }; + println!("{}", formatted); +} + fn print_params(value: &Value, depth: usize, debug: bool) { let indent = INDENT.repeat(depth); @@ -509,21 +526,9 @@ fn print_params(value: &Value, depth: usize, debug: bool) { print_params(item, depth + 2, debug); } } - Value::String(s) => { - if !debug && s.len() > get_tool_params_max_length() { - println!("{}{}: {}", indent, style(key).dim(), style("...").dim()); - } else { - println!("{}{}: {}", indent, style(key).dim(), style(s).green()); - } - } - Value::Number(n) => { - println!("{}{}: {}", indent, style(key).dim(), style(n).blue()); - } - Value::Bool(b) => { - println!("{}{}: {}", indent, style(key).dim(), style(b).blue()); - } - Value::Null => { - println!("{}{}: {}", indent, style(key).dim(), style("null").dim()); + _ => { + print!("{}{}: ", indent, style(key).dim()); + print_value(val, debug); } } } @@ -534,26 +539,7 @@ print_params(item, depth + 1, debug); } } - Value::String(s) => { - if !debug && s.len() > get_tool_params_max_length() { - println!( - "{}{}", - indent, - style(format!("[REDACTED: {} chars]", s.len())).yellow() - ); - } else { - println!("{}{}", indent, style(s).green()); - } - } - Value::Number(n) => { - println!("{}{}", indent, style(n).yellow()); - } - Value::Bool(b) => { - println!("{}{}", indent, style(b).yellow()); - } - Value::Null => { - println!("{}{}", indent, style("null").dim()); + _ => print_value(value, debug), } } diff --git 
a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index bb416993b3cc..8ce58b00b588 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -226,6 +226,8 @@ async fn get_tools( path = "/agent/update_provider", responses( (status = 200, description = "Update provider completed", body = String), + (status = 400, description = "Bad request - missing or invalid parameters"), + (status = 401, description = "Unauthorized - invalid secret key"), (status = 500, description = "Internal server error") ) )] @@ -234,15 +236,7 @@ async fn update_agent_provider( headers: HeaderMap, Json(payload): Json, ) -> Result { - // Verify secret key - let secret_key = headers - .get("X-Secret-Key") - .and_then(|value| value.to_str().ok()) - .ok_or(StatusCode::UNAUTHORIZED)?; - - if secret_key != state.secret_key { - return Err(StatusCode::UNAUTHORIZED); - } + verify_secret_key(&headers, &state)?; let agent = state .get_agent() @@ -250,13 +244,18 @@ async fn update_agent_provider( .map_err(|_| StatusCode::PRECONDITION_FAILED)?; let config = Config::global(); - let model = payload.model.unwrap_or_else(|| { - config - .get_param("GOOSE_MODEL") - .expect("Did not find a model on payload or in env to update provider with") - }); - let model_config = ModelConfig::new(&model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let new_provider = create(&payload.provider, model_config).unwrap(); + let model = match payload + .model + .or_else(|| config.get_param("GOOSE_MODEL").ok()) + { + Some(m) => m, + None => return Err(StatusCode::BAD_REQUEST), + }; + + let model_config = ModelConfig::new(&model).map_err(|_| StatusCode::BAD_REQUEST)?; + + let new_provider = + create(&payload.provider, model_config).map_err(|_| StatusCode::BAD_REQUEST)?; agent .update_provider(new_provider) .await diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index f66c5a35e94d..2a5ad01925cb 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -1,29 +1,28 @@ use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; -use axum::http::HeaderMap; use futures::TryStreamExt; -use reqwest::{Client, StatusCode}; +use reqwest::StatusCode; use serde_json::Value; use std::io; -use std::time::Duration; use tokio::pin; - use tokio_util::io::StreamReader; +use super::api_client::{ApiClient, ApiResponse, AuthMethod}; use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; -use super::utils::{emit_debug_trace, get_model}; +use super::utils::{emit_debug_trace, get_model, map_http_error_to_provider_error}; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::retry::ProviderRetry; use rmcp::model::Tool; -pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest"; -pub const ANTHROPIC_KNOWN_MODELS: &[&str] = &[ +const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-0"; +const ANTHROPIC_KNOWN_MODELS: &[&str] = &[ "claude-sonnet-4-0", "claude-sonnet-4-20250514", "claude-opus-4-0", @@ -35,15 +34,13 @@ pub const ANTHROPIC_KNOWN_MODELS: &[&str] = &[ "claude-3-opus-latest", ]; -pub const ANTHROPIC_DOC_URL: &str = "https://docs.anthropic.com/en/docs/about-claude/models"; -pub const ANTHROPIC_API_VERSION: &str = 
"2023-06-01"; +const ANTHROPIC_DOC_URL: &str = "https://docs.anthropic.com/en/docs/about-claude/models"; +const ANTHROPIC_API_VERSION: &str = "2023-06-01"; #[derive(serde::Serialize)] pub struct AnthropicProvider { #[serde(skip)] - client: Client, - host: String, - api_key: String, + api_client: ApiClient, model: ModelConfig, } @@ -57,69 +54,67 @@ impl AnthropicProvider { .get_param("ANTHROPIC_HOST") .unwrap_or_else(|_| "https://api.anthropic.com".to_string()); - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth = AuthMethod::ApiKey { + header_name: "x-api-key".to_string(), + key: api_key, + }; + + let api_client = + ApiClient::new(host, auth)?.with_header("anthropic-version", ANTHROPIC_API_VERSION)?; - Ok(Self { - client, - host, - api_key, - model, - }) + Ok(Self { api_client, model }) } - async fn post(&self, headers: HeaderMap, payload: &Value) -> Result { - let base_url = url::Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("v1/messages").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + fn get_conditional_headers(&self) -> Vec<(&str, &str)> { + let mut headers = Vec::new(); - let response = self - .client - .post(url) - .headers(headers) - .json(payload) - .send() - .await?; + let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); + if self.model.model_name.starts_with("claude-3-7-sonnet-") { + if is_thinking_enabled { + headers.push(("anthropic-beta", "output-128k-2025-02-19")); + } + headers.push(("anthropic-beta", "token-efficient-tools-2025-02-19")); + } - let status = response.status(); - let payload: Option = response.json().await.ok(); + headers + } - // https://docs.anthropic.com/en/api/errors - match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::BAD_REQUEST => { - let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { - if let Some(error) = payload.get("error") { - tracing::debug!("Bad Request Error: {error:?}"); - error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string(); - if error_msg.to_lowercase().contains("too long") || error_msg.to_lowercase().contains("too many") { - return Err(ProviderError::ContextLengthExceeded(error_msg.to_string())); - } - }} - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } + async fn post(&self, payload: &Value) -> Result { + let mut request = self.api_client.request("v1/messages"); + + for (key, value) in self.get_conditional_headers() { + request = request.header(key, value)?; + } + + Ok(request.api_post(payload).await?) 
+ + fn anthropic_api_call_result(response: ApiResponse) -> Result<Value, ProviderError> { + match response.status { + StatusCode::OK => response.payload.ok_or_else(|| { + ProviderError::RequestFailed("Response body is not valid JSON".to_string()) + }), _ => { + if response.status == StatusCode::BAD_REQUEST { + if let Some(error_msg) = response + .payload + .as_ref() + .and_then(|p| p.get("error")) + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + { + let msg = error_msg.to_string(); + if msg.to_lowercase().contains("too long") + || msg.to_lowercase().contains("too many") + { + return Err(ProviderError::ContextLengthExceeded(msg)); + } + } + } + Err(map_http_error_to_provider_error( + response.status, + response.payload, + )) } } } @@ -128,24 +123,17 @@ impl Provider for AnthropicProvider { fn metadata() -> ProviderMetadata { + let models: Vec<ModelInfo> = ANTHROPIC_KNOWN_MODELS + .iter() + .map(|&model_name| ModelInfo::new(model_name, 200_000)) + .collect(); + ProviderMetadata::with_models( "anthropic", "Anthropic", "Claude and other models from Anthropic", ANTHROPIC_DEFAULT_MODEL, - vec![ - ModelInfo::new("claude-sonnet-4-0", 200000), - ModelInfo::new("claude-sonnet-4-20250514", 200000), - ModelInfo::new("claude-opus-4-0", 200000), - ModelInfo::new("claude-opus-4-20250514", 200000), - ModelInfo::new("claude-3-7-sonnet-latest", 200000), - ModelInfo::new("claude-3-7-sonnet-20250219", 200000), - ModelInfo::new("claude-3-5-sonnet-20241022", 200000), - ModelInfo::new("claude-3-5-haiku-20241022", 200000), - ModelInfo::new("claude-3-opus-20240229", 200000), - ModelInfo::new("claude-3-sonnet-20240229", 200000), - ModelInfo::new("claude-3-haiku-20240307", 200000), - ], + models, ANTHROPIC_DOC_URL, vec![ ConfigKey::new("ANTHROPIC_API_KEY", true, true, None), @@ -175,35 +163,19 @@ ) -> Result<(Message, ProviderUsage), ProviderError> { let payload = create_request(&self.model, system, messages, tools)?; - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("x-api-key", self.api_key.parse().unwrap()); - headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap()); - - let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); - if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled { - // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta - headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap()); - } - - if self.model.model_name.starts_with("claude-3-7-sonnet-") { - // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use - headers.insert( - "anthropic-beta", - "token-efficient-tools-2025-02-19".parse().unwrap(), - ); - } + let response = self + .with_retry(|| async { self.post(&payload).await }) + .await?; - // Make request - let response = self.post(headers, &payload).await?; + let json_response = Self::anthropic_api_call_result(response)?; - // Parse response - let message = response_to_message(&response)?; - let usage = get_usage(&response)?; - tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + let message = response_to_message(&json_response)?; + let usage = get_usage(&json_response)?; + 
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", usage.input_tokens, usage.output_tokens, usage.total_tokens); - let model = get_model(&response); - emit_debug_trace(&self.model, &payload, &response, &usage); + let model = get_model(&json_response); + emit_debug_trace(&self.model, &payload, &json_response, &usage); let provider_usage = ProviderUsage::new(model, usage); tracing::debug!( "🔍 Anthropic non-streaming returning ProviderUsage: {:?}", @@ -212,22 +184,22 @@ impl Provider for AnthropicProvider { Ok((message, provider_usage)) } - /// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present - async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { - let url = format!("{}/v1/models", self.host); - let response = self - .client - .get(&url) - .header("anthropic-version", ANTHROPIC_API_VERSION) - .header("x-api-key", self.api_key.clone()) - .send() - .await?; - let json: serde_json::Value = response.json().await?; - // if 'models' key missing, return None + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self.api_client.api_get("v1/models").await?; + + if response.status != StatusCode::OK { + return Err(map_http_error_to_provider_error( + response.status, + response.payload, + )); + } + + let json = response.payload.unwrap_or_default(); let arr = match json.get("models").and_then(|v| v.as_array()) { Some(arr) => arr, None => return Ok(None), }; + let mut models: Vec = arr .iter() .filter_map(|m| { @@ -251,59 +223,28 @@ impl Provider for AnthropicProvider { tools: &[Tool], ) -> Result { let mut payload = create_request(&self.model, system, messages, tools)?; - - // Add stream parameter payload .as_object_mut() .unwrap() .insert("stream".to_string(), Value::Bool(true)); - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("x-api-key", self.api_key.parse().unwrap()); - headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap()); - - let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); - if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled { - // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta - headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap()); - } + let mut request = self.api_client.request("v1/messages"); - if self.model.model_name.starts_with("claude-3-7-sonnet-") { - // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use - headers.insert( - "anthropic-beta", - "token-efficient-tools-2025-02-19".parse().unwrap(), - ); + for (key, value) in self.get_conditional_headers() { + request = request.header(key, value)?; } - let base_url = url::Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("v1/messages").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let response = self - .client - .post(url) - .headers(headers) - .json(&payload) - .send() - .await?; - + let response = request.response_post(&payload).await?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(ProviderError::RequestFailed(format!( - "Streaming request failed with status: {}. 
Error: {}", - status, error_text - ))); + let error_json = serde_json::from_str::(&error_text).ok(); + return Err(map_http_error_to_provider_error(status, error_json)); } - // Map reqwest error to io::Error let stream = response.bytes_stream().map_err(io::Error::other); let model_config = self.model.clone(); - // Wrap in a line decoder and yield lines inside the stream Ok(Box::pin(try_stream! { let stream_reader = StreamReader::new(stream); let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from); @@ -312,7 +253,7 @@ impl Provider for AnthropicProvider { pin!(message_stream); while let Some(message) = futures::StreamExt::next(&mut message_stream).await { let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; - super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); yield (message, usage); } })) diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs new file mode 100644 index 000000000000..434451e8770a --- /dev/null +++ b/crates/goose/src/providers/api_client.rs @@ -0,0 +1,225 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqwest::{ + header::{HeaderMap, HeaderName, HeaderValue}, + Client, Response, StatusCode, +}; +use serde_json::Value; +use std::fmt; +use std::time::Duration; + +pub struct ApiClient { + client: Client, + host: String, + auth: AuthMethod, + default_headers: HeaderMap, + timeout: Duration, +} + +pub enum AuthMethod { + BearerToken(String), + ApiKey { + header_name: String, + key: String, + }, + #[allow(dead_code)] + OAuth(OAuthConfig), + Custom(Box), +} + +pub struct OAuthConfig { + pub host: String, + pub client_id: String, + pub redirect_url: String, + pub scopes: Vec, +} + +#[async_trait] +pub trait AuthProvider: Send + Sync { + async fn get_auth_header(&self) -> Result<(String, String)>; +} + +pub struct ApiResponse { + pub status: StatusCode, + pub payload: Option, +} + +impl fmt::Debug for AuthMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthMethod::BearerToken(_) => f.debug_tuple("BearerToken").field(&"[hidden]").finish(), + AuthMethod::ApiKey { header_name, .. 
+impl fmt::Debug for AuthMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthMethod::BearerToken(_) => f.debug_tuple("BearerToken").field(&"[hidden]").finish(), + AuthMethod::ApiKey { header_name, .. } => f .debug_struct("ApiKey") .field("header_name", header_name) .field("key", &"[hidden]") .finish(), + AuthMethod::OAuth(_) => f.debug_tuple("OAuth").field(&"[config]").finish(), + AuthMethod::Custom(_) => f.debug_tuple("Custom").field(&"[provider]").finish(), + } + } +} + +impl ApiResponse { + pub async fn from_response(response: Response) -> Result<Self> { + let status = response.status(); + let payload = response.json().await.ok(); + Ok(Self { status, payload }) + } +} + +pub struct ApiRequestBuilder<'a> { + client: &'a ApiClient, + path: &'a str, + headers: HeaderMap, +} + +impl ApiClient { + pub fn new(host: String, auth: AuthMethod) -> Result<Self> { + Self::with_timeout(host, auth, Duration::from_secs(600)) + } + + pub fn with_timeout(host: String, auth: AuthMethod, timeout: Duration) -> Result<Self> { + Ok(Self { + client: Client::builder().timeout(timeout).build()?, + host, + auth, + default_headers: HeaderMap::new(), + timeout, + }) + } + + pub fn with_headers(mut self, headers: HeaderMap) -> Result<Self> { + self.default_headers = headers; + self.client = Client::builder() + .timeout(self.timeout) + .default_headers(self.default_headers.clone()) + .build()?; + Ok(self) + } + + pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> { + let header_name = HeaderName::from_bytes(key.as_bytes())?; + let header_value = HeaderValue::from_str(value)?; + self.default_headers.insert(header_name, header_value); + self.client = Client::builder() + .timeout(self.timeout) + .default_headers(self.default_headers.clone()) + .build()?; + Ok(self) + } + + pub fn request<'a>(&'a self, path: &'a str) -> ApiRequestBuilder<'a> { + ApiRequestBuilder { + client: self, + path, + headers: HeaderMap::new(), + } + } + + pub async fn api_post(&self, path: &str, payload: &Value) -> Result<ApiResponse> { + self.request(path).api_post(payload).await + } + + pub async fn response_post(&self, path: &str, payload: &Value) -> Result<Response> { + self.request(path).response_post(payload).await + } + + pub async fn api_get(&self, path: &str) -> Result<ApiResponse> { + self.request(path).api_get().await + } + + pub async fn response_get(&self, path: &str) -> Result<Response> { + self.request(path).response_get().await + } + + fn build_url(&self, path: &str) -> Result<url::Url> { + use url::Url; + let base_url = + Url::parse(&self.host).map_err(|e| anyhow::anyhow!("Invalid base URL: {}", e))?; + base_url + .join(path) + .map_err(|e| anyhow::anyhow!("Failed to construct URL: {}", e)) + } + + async fn get_oauth_token(&self, config: &OAuthConfig) -> Result<String> { + super::oauth::get_oauth_token_async( + &config.host, + &config.client_id, + &config.redirect_url, + &config.scopes, + ) + .await + } +} + +impl<'a> ApiRequestBuilder<'a> { + pub fn header(mut self, key: &str, value: &str) -> Result<Self> { + let header_name = HeaderName::from_bytes(key.as_bytes())?; + let header_value = HeaderValue::from_str(value)?; + self.headers.insert(header_name, header_value); + Ok(self) + } + + #[allow(dead_code)] + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.headers.extend(headers); + self + } + + pub async fn api_post(self, payload: &Value) -> Result<ApiResponse> { + let response = self.response_post(payload).await?; + ApiResponse::from_response(response).await + } + + pub async fn response_post(self, payload: &Value) -> Result<Response> { + let request = self.send_request(|url, client| client.post(url)).await?; + Ok(request.json(payload).send().await?) + }
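// Sketch: the two POST flavors on ApiClient. api_post buffers the body and
// parses it into an ApiResponse (status plus optional JSON), which fits the
// non-streaming providers; response_post hands back the raw reqwest::Response
// so streaming callers can drive bytes_stream() themselves. Both helpers come
// from this diff; the path and payload here are placeholders.
let parsed: ApiResponse = client.api_post("v1/endpoint", &payload).await?;
let raw: reqwest::Response = client.response_post("v1/endpoint", &payload).await?;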
+ + pub async fn api_get(self) -> Result<ApiResponse> { + let response = self.response_get().await?; + ApiResponse::from_response(response).await + } + + pub async fn response_get(self) -> Result<Response> { + let request = self.send_request(|url, client| client.get(url)).await?; + Ok(request.send().await?) + } + + async fn send_request<F>(&self, request_builder: F) -> Result<reqwest::RequestBuilder> + where + F: FnOnce(url::Url, &Client) -> reqwest::RequestBuilder, + { + let url = self.client.build_url(self.path)?; + let mut request = request_builder(url, &self.client.client); + request = request.headers(self.headers.clone()); + + request = match &self.client.auth { + AuthMethod::BearerToken(token) => { + request.header("Authorization", format!("Bearer {}", token)) + } + AuthMethod::ApiKey { header_name, key } => request.header(header_name.as_str(), key), + AuthMethod::OAuth(config) => { + let token = self.client.get_oauth_token(config).await?; + request.header("Authorization", format!("Bearer {}", token)) + } + AuthMethod::Custom(provider) => { + let (header_name, header_value) = provider.get_auth_header().await?; + request.header(header_name, header_value) + } + }; + + Ok(request) + } +} + +impl fmt::Debug for ApiClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ApiClient") + .field("host", &self.host) + .field("auth", &"[auth method]") + .field("timeout", &self.timeout) + .field("default_headers", &self.default_headers) + .finish_non_exhaustive() + } +} diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index b4122ffb0380..0a0a2236e9d3 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -1,15 +1,14 @@ use anyhow::Result; use async_trait::async_trait; -use reqwest::Client; use serde::Serialize; use serde_json::Value; -use std::time::Duration; -use tokio::time::sleep; -use super::azureauth::AzureAuth; +use super::api_client::{ApiClient, AuthMethod, AuthProvider}; +use super::azureauth::{AuthError, AzureAuth}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; @@ -22,17 +21,9 @@ pub const AZURE_DOC_URL: &str = pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21"; pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"]; -// Default retry configuration -const DEFAULT_MAX_RETRIES: usize = 5; -const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1000; // Start with 1 second -const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 32000; // Max 32 seconds -const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; - #[derive(Debug)] pub struct AzureProvider { - client: Client, - auth: AzureAuth, - endpoint: String, + api_client: ApiClient, deployment_name: String, api_version: String, model: ModelConfig, } impl Serialize for AzureProvider { S: serde::Serializer, { use serde::ser::SerializeStruct; - let mut state = serializer.serialize_struct("AzureProvider", 3)?; - state.serialize_field("endpoint", &self.endpoint)?; + let mut state = serializer.serialize_struct("AzureProvider", 2)?; state.serialize_field("deployment_name", &self.deployment_name)?; state.serialize_field("api_version", &self.api_version)?; state.end() } } +// Custom auth provider that wraps AzureAuth 
+struct AzureAuthProvider { + auth: AzureAuth, +} + +#[async_trait] +impl AuthProvider for AzureAuthProvider { + async fn get_auth_header(&self) -> Result<(String, String)> { + let auth_token = self + .auth + .get_token() + .await + .map_err(|e| anyhow::anyhow!("Failed to get authentication token: {}", e))?; + + match self.auth.credential_type() { + super::azureauth::AzureCredentials::ApiKey(_) => { + Ok(("api-key".to_string(), auth_token.token_value)) + } + super::azureauth::AzureCredentials::DefaultCredential => Ok(( + "Authorization".to_string(), + format!("Bearer {}", auth_token.token_value), + )), + } + } +} + impl_provider_default!(AzureProvider); impl AzureProvider { @@ -67,16 +83,16 @@ .get_secret("AZURE_OPENAI_API_KEY") .ok() .filter(|key: &String| !key.is_empty()); - let auth = AzureAuth::new(api_key)?; + let auth = AzureAuth::new(api_key).map_err(|e| match e { + AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg), + AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg), + })?; - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth_provider = AzureAuthProvider { auth }; + let api_client = ApiClient::new(endpoint, AuthMethod::Custom(Box::new(auth_provider)))?; Ok(Self { - client, - endpoint, - auth, + api_client, deployment_name, api_version, model, }) } @@ -84,130 +100,14 @@ } async fn post(&self, payload: &Value) -> Result<Value, ProviderError> { - let mut base_url = url::Url::parse(&self.endpoint) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - - // Get the existing path without trailing slashes - let existing_path = base_url.path().trim_end_matches('/'); - let new_path = if existing_path.is_empty() { - format!( - "/openai/deployments/{}/chat/completions", - self.deployment_name - ) - } else { - format!( - "{}/openai/deployments/{}/chat/completions", - existing_path, self.deployment_name - ) - }; - - base_url.set_path(&new_path); - base_url.set_query(Some(&format!("api-version={}", self.api_version))); - - let mut attempts = 0; - let mut last_error = None; - let mut current_delay = DEFAULT_INITIAL_RETRY_INTERVAL_MS; - - loop { - // Check if we've exceeded max retries - if attempts > DEFAULT_MAX_RETRIES { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting", - DEFAULT_MAX_RETRIES - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - - // Get a fresh auth token for each attempt - let auth_token = self.auth.get_token().await.map_err(|e| { - tracing::error!("Authentication error: {:?}", e); - ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e)) - })?; - - let mut request_builder = self.client.post(base_url.clone()); - let token_value = auth_token.token_value.clone(); - - // Set the correct header based on authentication type - match self.auth.credential_type() { - super::azureauth::AzureCredentials::ApiKey(_) => { - request_builder = request_builder.header("api-key", token_value.clone()); - } - super::azureauth::AzureCredentials::DefaultCredential => { - request_builder = request_builder - .header("Authorization", format!("Bearer {}", token_value.clone())); - } - } - - let response_result = request_builder.json(payload).send().await; - - match response_result { - Ok(response) => match handle_response_openai_compat(response).await { - Ok(result) => { - return Ok(result); - } - 
Err(ProviderError::RateLimitExceeded(msg)) => { - attempts += 1; - last_error = Some(ProviderError::RateLimitExceeded(msg.clone())); - - let retry_after = - if let Some(secs) = msg.to_lowercase().find("try again in ") { - msg[secs..] - .split_whitespace() - .nth(3) - .and_then(|s| s.parse::<u64>().ok()) - .unwrap_or(0) - } else { - 0 - }; - - let delay = if retry_after > 0 { - Duration::from_secs(retry_after) - } else { - let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); - current_delay = - (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; - Duration::from_millis(delay) - }; - - sleep(delay).await; - continue; - } - Err(e) => { - tracing::error!( - "Error response from Azure OpenAI (attempt {}): {:?}", - attempts + 1, - e - ); - return Err(e); - } - }, - Err(e) => { - tracing::error!( - "Request failed (attempt {}): {:?}\nIs timeout: {}\nIs connect: {}\nIs request: {}", - attempts + 1, - e, - e.is_timeout(), - e.is_connect(), - e.is_request(), - ); - - // For timeout errors, we should retry - if e.is_timeout() { - attempts += 1; - let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); - current_delay = (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; - sleep(Duration::from_millis(delay)).await; - continue; - } - - return Err(ProviderError::RequestFailed(format!( - "Request failed: {}", - e - ))); - } - } - } + // Build the path for Azure OpenAI + let path = format!( + "openai/deployments/{}/chat/completions?api-version={}", + self.deployment_name, self.api_version + ); + + let response = self.api_client.response_post(&path, payload).await?; + handle_response_openai_compat(response).await } } @@ -245,7 +145,12 @@ tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; - let response = self.post(&payload).await?; + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await?; let message = response_to_message(&response)?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 6366deeea571..4498e6b4a617 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -3,6 +3,7 @@ use futures::Stream; use serde::{Deserialize, Serialize}; use super::errors::ProviderError; +use super::retry::RetryConfig; use crate::message::Message; use crate::model::ModelConfig; use crate::utils::safe_truncate; @@ -286,8 +287,12 @@ pub trait Provider: Send + Sync { /// Get the model config from the provider fn get_model_config(&self) -> ModelConfig; - /// Optional hook to fetch supported models asynchronously. - async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> { + fn retry_config(&self) -> RetryConfig { + RetryConfig::default() + }
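// Sketch: how a provider combines the two new hooks. The with_retry call
// mirrors the usage elsewhere in this diff; the ProviderRetry trait itself
// lives in providers/retry.rs, outside this section, so the backoff behavior
// (jittered exponential delay up to retry_config().max_retries) is an
// assumption based on the per-provider retry loops being removed here.
let response = self
    .with_retry(|| async { self.post(&payload).await })
    .await?;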
+ + /// Optional hook to fetch supported models. + async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> { Ok(None) } diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index ac823dff6808..4fc23b60e4c0 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use std::time::Duration; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; +use super::retry::ProviderRetry; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; @@ -14,7 +14,6 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use rmcp::model::Tool; use serde_json::Value; -use tokio::time::sleep; // Import the migrated helper functions from providers/formats/bedrock.rs use super::formats::bedrock::{ @@ -68,6 +67,68 @@ impl BedrockProvider { Ok(Self { client, model }) } + + async fn converse( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(bedrock::Message, Option<bedrock::TokenUsage>), ProviderError> { + let model_name = &self.model.model_name; + + let mut request = self + .client + .converse() + .system(bedrock::SystemContentBlock::Text(system.to_string())) + .model_id(model_name.to_string()) + .set_messages(Some( + messages + .iter() + .map(to_bedrock_message) + .collect::<Result<Vec<_>, _>>()?, + )); + + if !tools.is_empty() { + request = request.tool_config(to_bedrock_tool_config(tools)?); + } + + let response = request + .send() + .await + .map_err(|err| match err.into_service_error() { + ConverseError::ThrottlingException(throttle_err) => { + ProviderError::RateLimitExceeded(format!( + "Bedrock throttling error: {:?}", + throttle_err + )) + } + ConverseError::AccessDeniedException(err) => { + ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err)) + } + ConverseError::ValidationException(err) + if err + .message() + .unwrap_or_default() + .contains("Input is too long for requested model.") => + { + ProviderError::ContextLengthExceeded(format!( + "Failed to call Bedrock: {:?}", + err + )) + } + ConverseError::ModelErrorException(err) => { + ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err)) + } + err => ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err)), + })?; + + match response.output { + Some(bedrock::ConverseOutput::Message(message)) => Ok((message, response.usage)), + _ => Err(ProviderError::RequestFailed( + "No output from Bedrock".to_string(), + )), + } + } } impl_provider_default!(BedrockProvider); @@ -102,132 +163,31 @@ ) -> Result<(Message, ProviderUsage), ProviderError> { let model_name = &self.model.model_name; - let mut request = self - .client - .converse() - .system(bedrock::SystemContentBlock::Text(system.to_string())) - .model_id(model_name.to_string()) - .set_messages(Some( - messages - .iter() - .map(to_bedrock_message) - .collect::<Result<Vec<_>, _>>()?, - )); - - if !tools.is_empty() { - request = request.tool_config(to_bedrock_tool_config(tools)?); - } - - // Retry configuration - const MAX_RETRIES: u32 = 10; - const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds - const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes) - - let mut attempts = 0; - let mut backoff_ms = INITIAL_BACKOFF_MS; - - loop { - attempts += 1; - - match request.clone().send().await { - Ok(response) => { - // Successful response, process it and return - return match response.output { - 
Some(bedrock::ConverseOutput::Message(message)) => { - let usage = response - .usage - .as_ref() - .map(from_bedrock_usage) - .unwrap_or_default(); - - let message = from_bedrock_message(&message)?; - - // Add debug trace with input context - let debug_payload = serde_json::json!({ - "system": system, - "messages": messages, - "tools": tools - }); - emit_debug_trace( - &self.model, - &debug_payload, - &serde_json::to_value(&message).unwrap_or_default(), - &usage, - ); - - let provider_usage = ProviderUsage::new(model_name.to_string(), usage); - Ok((message, provider_usage)) - } - _ => Err(ProviderError::RequestFailed( - "No output from Bedrock".to_string(), - )), - }; - } - Err(err) => { - match err.into_service_error() { - ConverseError::ThrottlingException(throttle_err) => { - if attempts > MAX_RETRIES { - // We've exhausted our retries - tracing::error!( - "Failed after {MAX_RETRIES} retries: {:?}", - throttle_err - ); - return Err(ProviderError::RateLimitExceeded(format!( - "Failed to call Bedrock after {MAX_RETRIES} retries: {:?}", - throttle_err - ))); - } - - // Log retry attempt - tracing::warn!( - "Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}", - attempts, - MAX_RETRIES, - backoff_ms, - throttle_err - ); - - // Wait before retry with exponential backoff - sleep(Duration::from_millis(backoff_ms)).await; - - // Calculate next backoff with exponential growth, capped at max - backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); - - // Continue to the next retry attempt - continue; - } - ConverseError::AccessDeniedException(err) => { - return Err(ProviderError::Authentication(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - ConverseError::ValidationException(err) - if err - .message() - .unwrap_or_default() - .contains("Input is too long for requested model.") => - { - return Err(ProviderError::ContextLengthExceeded(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - ConverseError::ModelErrorException(err) => { - return Err(ProviderError::ExecutionError(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - err => { - return Err(ProviderError::ServerError(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - } - } - } - } + let (bedrock_message, bedrock_usage) = self + .with_retry(|| self.converse(system, messages, tools)) + .await?; + + let usage = bedrock_usage + .as_ref() + .map(from_bedrock_usage) + .unwrap_or_default(); + + let message = from_bedrock_message(&bedrock_message)?; + + // Add debug trace with input context + let debug_payload = serde_json::json!({ + "system": system, + "messages": messages, + "tools": tools + }); + emit_debug_trace( + &self.model, + &debug_payload, + &serde_json::to_value(&message).unwrap_or_default(), + &usage, + ); + + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + Ok((message, provider_usage)) } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index cb75a5585f29..40261695bd79 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -2,7 +2,6 @@ use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; use futures::TryStreamExt; -use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io; @@ -10,43 +9,34 @@ use std::time::Duration; use tokio::pin; use tokio_util::io::StreamReader; +use super::api_client::{ApiClient, AuthMethod, AuthProvider}; use super::base::{ConfigKey, MessageStream, Provider, 
ProviderMetadata, ProviderUsage, Usage}; use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::formats::databricks::{create_request, response_to_message}; use super::oauth; -use super::utils::{get_model, ImageFormat}; +use super::retry::ProviderRetry; +use super::utils::{get_model, handle_response_openai_compat, ImageFormat}; use crate::config::ConfigError; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{get_usage, response_to_streaming_message}; +use crate::providers::retry::{ + RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS, + DEFAULT_MAX_RETRIES, DEFAULT_MAX_RETRY_INTERVAL_MS, +}; use rmcp::model::Tool; use serde_json::json; -use tokio::time::sleep; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; -use url::Url; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; -// "offline_access" scope is used to request an OAuth 2.0 Refresh Token -// https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; - -/// Default timeout for API requests in seconds const DEFAULT_TIMEOUT_SECS: u64 = 600; -/// Default initial interval for retry (in milliseconds) -const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; -/// Default maximum number of retries -const DEFAULT_MAX_RETRIES: usize = 6; -/// Default retry backoff multiplier -const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; -/// Default maximum interval for retry (in milliseconds) -const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000; pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet"; -// Databricks can passthrough to a wide range of models, we only provide the default pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ "databricks-meta-llama-3-3-70b-instruct", "databricks-meta-llama-3-1-405b-instruct", @@ -57,53 +47,6 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ pub const DATABRICKS_DOC_URL: &str = "https://docs.databricks.com/en/generative-ai/external-models/index.html"; -/// Retry configuration for handling rate limit errors -#[derive(Debug, Clone)] -struct RetryConfig { - /// Maximum number of retry attempts - max_retries: usize, - /// Initial interval between retries in milliseconds - initial_interval_ms: u64, - /// Multiplier for backoff (exponential) - backoff_multiplier: f64, - /// Maximum interval between retries in milliseconds - max_interval_ms: u64, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_retries: DEFAULT_MAX_RETRIES, - initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, - backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, - max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, - } - } -} - -impl RetryConfig { - /// Calculate the delay for a specific retry attempt (with jitter) - fn delay_for_attempt(&self, attempt: usize) -> Duration { - if attempt == 0 { - return Duration::from_millis(0); - } - - // Calculate exponential backoff - let exponent = (attempt - 1) as u32; - let base_delay_ms = (self.initial_interval_ms as f64 - * self.backoff_multiplier.powi(exponent as i32)) as u64; - - // Apply max limit - let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms); - - // Add jitter (+/-20% randomness) to avoid thundering herd problem - let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2 - let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) 
as u64; - - Duration::from_millis(jittered_delay_ms) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DatabricksAuth { Token(String), @@ -116,7 +59,6 @@ pub enum DatabricksAuth { } impl DatabricksAuth { - /// Create a new OAuth configuration with default values pub fn oauth(host: String) -> Self { Self::OAuth { host, @@ -125,16 +67,36 @@ impl DatabricksAuth { scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(), } } + pub fn token(token: String) -> Self { Self::Token(token) } } +struct DatabricksAuthProvider { + auth: DatabricksAuth, +} + +#[async_trait] +impl AuthProvider for DatabricksAuthProvider { + async fn get_auth_header(&self) -> Result<(String, String)> { + let token = match &self.auth { + DatabricksAuth::Token(token) => token.clone(), + DatabricksAuth::OAuth { + host, + client_id, + redirect_url, + scopes, + } => oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?, + }; + Ok(("Authorization".to_string(), format!("Bearer {}", token))) + } +} + #[derive(Debug, serde::Serialize)] pub struct DatabricksProvider { #[serde(skip)] - client: Client, - host: String, + api_client: ApiClient, auth: DatabricksAuth, model: ModelConfig, image_format: ImageFormat, @@ -148,8 +110,6 @@ impl DatabricksProvider { pub fn from_env(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); - // For compatibility for now we check both config and secret for databricks host - // but it is not actually a secret value let mut host: Result<String, ConfigError> = config.get_param("DATABRICKS_HOST"); if host.is_err() { host = config.get_secret("DATABRICKS_HOST") } @@ -163,38 +123,29 @@ } let host = host?; + let retry_config = Self::load_retry_config(config); - let client = Client::builder() - .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) - .build()?; + let auth = if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { + DatabricksAuth::token(api_key) + } else { + DatabricksAuth::oauth(host.clone()) + }; - // Load optional retry configuration from environment - let retry_config = Self::load_retry_config(config); + let auth_method = + AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() })); - // If we find a databricks token we prefer that - if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { - return Ok(Self { - client, - host, - auth: DatabricksAuth::token(api_key), - model, - image_format: ImageFormat::OpenAi, - retry_config, - }); - } + let api_client = + ApiClient::with_timeout(host, auth_method, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?; - // Otherwise use Oauth flow Ok(Self { - client, - auth: DatabricksAuth::oauth(host.clone()), - host, + api_client, + auth, model, image_format: ImageFormat::OpenAi, retry_config, }) }
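// Sketch: the shared defaults the Databricks provider falls back to. The
// RetryConfig::new signature and the DEFAULT_* constants come from
// providers::retry in this diff; that RetryConfig::default() wires up exactly
// these values is an assumption, inferred from the removed local impl above.
let _defaults = RetryConfig::new(
    DEFAULT_MAX_RETRIES,
    DEFAULT_INITIAL_RETRY_INTERVAL_MS,
    DEFAULT_BACKOFF_MULTIPLIER,
    DEFAULT_MAX_RETRY_INTERVAL_MS,
);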
- /// Loads retry configuration from environment variables or uses defaults. fn load_retry_config(config: &crate::config::Config) -> RetryConfig { let max_retries = config .get_param("DATABRICKS_MAX_RETRIES") .ok() .and_then(|v: String| v.parse::<usize>().ok()) @@ -228,184 +179,36 @@ } } - /// Create a new DatabricksProvider with the specified host and token - /// - /// # Arguments - /// - /// * `host` - The Databricks host URL - /// * `token` - The Databricks API token - /// - /// # Returns - /// - /// Returns a Result containing the new DatabricksProvider instance pub fn from_params(host: String, api_key: String, model: ModelConfig) -> Result<Self> { - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth = DatabricksAuth::token(api_key); + let auth_method = + AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() })); + + let api_client = ApiClient::with_timeout(host, auth_method, Duration::from_secs(600))?; Ok(Self { - client, - host, - auth: DatabricksAuth::token(api_key), + api_client, + auth, model, image_format: ImageFormat::OpenAi, retry_config: RetryConfig::default(), }) } - async fn ensure_auth_header(&self) -> Result<String> { - match &self.auth { - DatabricksAuth::Token(token) => Ok(format!("Bearer {}", token)), - DatabricksAuth::OAuth { - host, - client_id, - redirect_url, - scopes, - } => { - let token = - oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?; - Ok(format!("Bearer {}", token)) - } - } - } - - async fn post(&self, payload: &Value) -> Result<Value, ProviderError> { - // Check if this is an embedding request by looking at the payload structure - let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); - let path = if is_embedding { - // For embeddings, use the embeddings endpoint - format!("serving-endpoints/{}/invocations", "text-embedding-3-small") + fn get_endpoint_path(&self, is_embedding: bool) -> String { + if is_embedding { + "serving-endpoints/text-embedding-3-small/invocations".to_string() } else { - // For chat completions, use the model name in the path format!("serving-endpoints/{}/invocations", self.model.model_name) - }; - - match self.post_with_retry(path.as_str(), payload).await { - Ok(res) => res.json().await.map_err(|_| { - ProviderError::RequestFailed("Response body is not valid JSON".to_string()) - }), - Err(e) => Err(e), + } } - async fn post_with_retry( - &self, - path: &str, - payload: &Value, - ) -> Result<reqwest::Response, ProviderError> { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join(path).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let mut attempts = 0; - loop { - let auth_header = self.ensure_auth_header().await?; - let response = self - .client - .post(url.clone()) - .header("Authorization", auth_header) - .json(payload) - .send() - .await?; - - let status = response.status(); + async fn post(&self, payload: Value) -> Result<Value, ProviderError> { + let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); + let path = self.get_endpoint_path(is_embedding); - break match status { - StatusCode::OK => Ok(response), - StatusCode::TOO_MANY_REQUESTS - | StatusCode::INTERNAL_SERVER_ERROR - | StatusCode::SERVICE_UNAVAILABLE => { - if attempts < self.retry_config.max_retries { - attempts += 1; - tracing::warn!( - "{}: retrying ({}/{})", - status, - attempts, - self.retry_config.max_retries - ); - - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; 
- - continue; - } - - Err(match status { - StatusCode::TOO_MANY_REQUESTS => { - ProviderError::RateLimitExceeded("Rate limit exceeded".to_string()) - } - _ => ProviderError::ServerError("Server error".to_string()), - }) - } - StatusCode::BAD_REQUEST => { - // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific - // We try to extract the error message from the payload and check for phrases that indicate context length exceeded - let bytes = response.bytes().await?; - let payload_str = String::from_utf8_lossy(&bytes).to_lowercase(); - let check_phrases = [ - "too long", - "context length", - "context_length_exceeded", - "reduce the length", - "token count", - "exceeds", - "exceed context limit", - "input length", - "max_tokens", - "decrease input length", - "context limit", - ]; - if check_phrases.iter().any(|c| payload_str.contains(c)) { - return Err(ProviderError::ContextLengthExceeded(payload_str)); - } - - let mut error_msg = "Unknown error".to_string(); - if let Ok(response_json) = serde_json::from_slice::<Value>(&bytes) { - // try to convert message to string, if that fails use external_model_message - error_msg = response_json - .get("message") - .and_then(|m| m.as_str()) - .or_else(|| { - response_json - .get("external_model_message") - .and_then(|ext| ext.get("message")) - .and_then(|m| m.as_str()) - }) - .unwrap_or("Unknown error") - .to_string(); - } - - tracing::debug!( - "{}", - format!( - "Provider request failed with status: {}. Payload: {:?}", - status, payload_str - ) - ); - return Err(ProviderError::RequestFailed(format!( - "Request failed with status: {}. Message: {}", - status, error_msg - ))); - } - _ => { - tracing::debug!( - "{}", - format!( - "Provider request failed with status: {}. 
Payload: {:?}", - status, - response.text().await.ok().unwrap_or_default() - ) - ); - return Err(ProviderError::RequestFailed(format!( - "Request failed with status: {}", - status - ))); - } - }; - } + let response = self.api_client.response_post(&path, &payload).await?; + handle_response_openai_compat(response).await } } @@ -426,6 +229,10 @@ impl Provider for DatabricksProvider { ) } + fn retry_config(&self) -> RetryConfig { + self.retry_config.clone() + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } @@ -441,15 +248,13 @@ impl Provider for DatabricksProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?; - // Remove the model key which is part of the url with databricks payload .as_object_mut() .expect("payload should have model key") .remove("model"); - let response = self.post(&payload).await?; + let response = self.with_retry(|| self.post(payload.clone())).await?; - // Parse response let message = response_to_message(&response)?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { tracing::debug!("Failed to get usage data"); @@ -468,7 +273,6 @@ impl Provider for DatabricksProvider { tools: &[Tool], ) -> Result { let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?; - // Remove the model key which is part of the url with databricks payload .as_object_mut() .expect("payload should have model key") @@ -479,18 +283,24 @@ impl Provider for DatabricksProvider { .unwrap() .insert("stream".to_string(), Value::Bool(true)); + let path = self.get_endpoint_path(false); let response = self - .post_with_retry( - format!("serving-endpoints/{}/invocations", self.model.model_name).as_str(), - &payload, - ) + .with_retry(|| async { + let resp = self.api_client.response_post(&path, &payload).await?; + if !resp.status().is_success() { + return Err(ProviderError::RequestFailed(format!( + "HTTP {}: {}", + resp.status(), + resp.text().await.unwrap_or_default() + ))); + } + Ok(resp) + }) .await?; - // Map reqwest error to io::Error let stream = response.bytes_stream().map_err(io::Error::other); - let model_config = self.model.clone(); - // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! 
{ let stream_reader = StreamReader::new(stream); let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); @@ -519,32 +329,16 @@ .map_err(|e| ProviderError::ExecutionError(e.to_string())) } - async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("api/2.0/serving-endpoints").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let auth_header = match self.ensure_auth_header().await { - Ok(header) => header, - Err(e) => { - tracing::warn!("Failed to authorize with Databricks: {}", e); - return Ok(None); // Return None to fall back to manual input - } - }; - + async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> { let response = match self - .client - .get(url) - .header("Authorization", auth_header) - .send() + .api_client + .response_get("api/2.0/serving-endpoints") .await { Ok(resp) => resp, Err(e) => { tracing::warn!("Failed to fetch Databricks models: {}", e); - return Ok(None); // Return None to fall back to manual input + return Ok(None); } }; @@ -559,7 +353,7 @@ } else { tracing::warn!("Failed to fetch Databricks models: {}", status); } - return Ok(None); // Return None to fall back to manual input + return Ok(None); } let json: Value = match response.json().await { @@ -610,12 +404,11 @@ impl EmbeddingCapable for DatabricksProvider { return Ok(vec![]); } - // Create request in Databricks format for embeddings let request = json!({ "input": texts, }); - let response = self.post(&request).await?; + let response = self.with_retry(|| self.post(request.clone())).await?; let embeddings = response["data"] .as_array() diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 3ff7d1880ca0..0060f6a6f764 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -30,13 +30,16 @@ pub enum ProviderError { impl From<anyhow::Error> for ProviderError { fn from(error: anyhow::Error) -> Self { + if let Some(reqwest_err) = error.downcast_ref::<reqwest::Error>() { + return ProviderError::RequestFailed(reqwest_err.to_string()); + } ProviderError::ExecutionError(error.to_string()) } } impl From<reqwest::Error> for ProviderError { fn from(error: reqwest::Error) -> Self { - ProviderError::ExecutionError(error.to_string()) + ProviderError::RequestFailed(error.to_string()) } } diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index e5d58f303275..d038127358eb 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -21,6 +21,7 @@ use crate::providers::formats::gcpvertexai::{ use crate::impl_provider_default; use crate::providers::formats::gcpvertexai::GcpLocation::Iowa; use crate::providers::gcpauth::GcpAuth; +use crate::providers::retry::RetryConfig; use crate::providers::utils::emit_debug_trace; use rmcp::model::Tool; @@ -52,69 +53,6 @@ enum GcpVertexAIError { AuthError(String), } -/// Retry configuration for handling rate limit errors -#[derive(Debug, Clone)] -struct RetryConfig { - /// Maximum number of retry attempts for 429 errors - max_rate_limit_retries: usize, - /// Maximum number of retry attempts for 529 errors - max_overloaded_retries: usize, - /// Initial interval between retries in milliseconds - initial_interval_ms: u64, - /// 
Multiplier for backoff (exponential) - backoff_multiplier: f64, - /// Maximum interval between retries in milliseconds - max_interval_ms: u64, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_rate_limit_retries: DEFAULT_MAX_RETRIES, - max_overloaded_retries: DEFAULT_MAX_RETRIES, - initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, - backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, - max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, - } - } -} - -impl RetryConfig { - /// Calculate the delay for a specific retry attempt (with jitter) - fn delay_for_attempt(&self, attempt: usize) -> Duration { - if attempt == 0 { - return Duration::from_millis(0); - } - - // Calculate exponential backoff - let exponent = (attempt - 1) as u32; - let base_delay_ms = (self.initial_interval_ms as f64 - * self.backoff_multiplier.powi(exponent as i32)) as u64; - - // Apply max limit - let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms); - - // Add jitter (+/-20% randomness) to avoid thundering herd problem - let jitter_factor = 0.8 + (rand::random::() * 0.4); // Between 0.8 and 1.2 - let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64; - - Duration::from_millis(jittered_delay_ms) - } - - /// Get max retries for a specific error type - #[allow(dead_code)] // Used in tests - fn max_retries_for_status(&self, status: StatusCode) -> usize { - if status == StatusCode::TOO_MANY_REQUESTS { - self.max_rate_limit_retries - } else if status == *STATUS_API_OVERLOADED { - self.max_overloaded_retries - } else { - // Default to rate limit retries for any other status code - self.max_rate_limit_retries - } - } -} - /// Provider implementation for Google Cloud Platform's Vertex AI service. /// /// This provider enables interaction with various AI models hosted on GCP Vertex AI, @@ -194,31 +132,10 @@ impl GcpVertexAIProvider { /// Loads retry configuration from environment variables or uses defaults. fn load_retry_config(config: &crate::config::Config) -> RetryConfig { // Load max retries for 429 rate limit errors - let max_rate_limit_retries = config - .get_param("GCP_MAX_RATE_LIMIT_RETRIES") + let max_retries = config + .get_param("GCP_MAX_RETRIES") .ok() .and_then(|v: String| v.parse::().ok()) - .or_else(|| { - // Fall back to generic GCP_MAX_RETRIES if specific one isn't set - config - .get_param("GCP_MAX_RETRIES") - .ok() - .and_then(|v: String| v.parse::().ok()) - }) - .unwrap_or(DEFAULT_MAX_RETRIES); - - // Load max retries for 529 API overloaded errors - let max_overloaded_retries = config - .get_param("GCP_MAX_OVERLOADED_RETRIES") - .ok() - .and_then(|v: String| v.parse::().ok()) - .or_else(|| { - // Fall back to generic GCP_MAX_RETRIES if specific one isn't set - config - .get_param("GCP_MAX_RETRIES") - .ok() - .and_then(|v: String| v.parse::().ok()) - }) .unwrap_or(DEFAULT_MAX_RETRIES); let initial_interval_ms = config @@ -239,13 +156,12 @@ impl GcpVertexAIProvider { .and_then(|v: String| v.parse::().ok()) .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); - RetryConfig { - max_rate_limit_retries, - max_overloaded_retries, + RetryConfig::new( + max_retries, initial_interval_ms, backoff_multiplier, max_interval_ms, - } + ) } /// Determines the appropriate GCP location for model deployment. 
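// --- Editor's sketch (not part of the patch) --------------------------------
// The unified RetryConfig above replaces the per-status GCP knobs with a
// single schedule. A worked example of the delay bounds it produces, assuming
// the defaults introduced in retry.rs later in this diff (initial 1000 ms,
// multiplier 2.0, cap 30_000 ms, +/-20% jitter). `expected_backoff_bounds_ms`
// is a hypothetical helper written only for this illustration.
fn expected_backoff_bounds_ms(attempt: u32) -> (u64, u64) {
    let (initial_ms, multiplier, max_ms) = (1_000u64, 2.0f64, 30_000u64);
    // Mirrors RetryConfig::delay_for_attempt with the randomness factored out:
    // base = initial * multiplier^(attempt - 1), capped, then jittered 0.8..1.2.
    let base = (initial_ms as f64 * multiplier.powi(attempt as i32 - 1)) as u64;
    let capped = base.min(max_ms);
    ((capped as f64 * 0.8) as u64, (capped as f64 * 1.2) as u64)
}

#[test]
fn unified_backoff_schedule() {
    assert_eq!(expected_backoff_bounds_ms(1), (800, 1_200)); // ~1 s
    assert_eq!(expected_backoff_bounds_ms(2), (1_600, 2_400)); // ~2 s
    assert_eq!(expected_backoff_bounds_ms(3), (3_200, 4_800)); // ~4 s
    assert_eq!(expected_backoff_bounds_ms(6), (24_000, 36_000)); // capped at 30 s
}
// -----------------------------------------------------------------------------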
@@ -335,6 +251,18 @@ impl GcpVertexAIProvider { let mut last_error = None; loop { + // Check if we've exceeded max retries + if rate_limit_attempts > self.retry_config.max_retries + && overloaded_attempts > self.retry_config.max_retries + { + let error_msg = format!( + "Exceeded maximum retry attempts ({}) for rate limiting errors", + self.retry_config.max_retries + ); + tracing::error!("{}", error_msg); + return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); + } + // Get a fresh auth token for each attempt let auth_header = self .get_auth_header() @@ -358,10 +286,10 @@ impl GcpVertexAIProvider { status if status == StatusCode::TOO_MANY_REQUESTS => { rate_limit_attempts += 1; - if rate_limit_attempts > self.retry_config.max_rate_limit_retries { + if rate_limit_attempts > self.retry_config.max_retries { let error_msg = format!( "Exceeded maximum retry attempts ({}) for rate limiting (429) errors", - self.retry_config.max_rate_limit_retries + self.retry_config.max_retries ); tracing::error!("{}", error_msg); return Err( @@ -386,7 +314,7 @@ impl GcpVertexAIProvider { tracing::warn!( "Rate limit exceeded error (429) (attempt {}/{}): {}. Retrying after backoff...", rate_limit_attempts, - self.retry_config.max_rate_limit_retries, + self.retry_config.max_retries, error_message ); @@ -401,10 +329,10 @@ impl GcpVertexAIProvider { status if status == *STATUS_API_OVERLOADED => { overloaded_attempts += 1; - if overloaded_attempts > self.retry_config.max_overloaded_retries { + if overloaded_attempts > self.retry_config.max_retries { let error_msg = format!( "Exceeded maximum retry attempts ({}) for API overloaded (529) errors", - self.retry_config.max_overloaded_retries + self.retry_config.max_retries ); tracing::error!("{}", error_msg); return Err( @@ -421,7 +349,7 @@ impl GcpVertexAIProvider { tracing::warn!( "API overloaded error (529) (attempt {}/{}): {}. 
Retrying after backoff...", overloaded_attempts, - self.retry_config.max_overloaded_retries, + self.retry_config.max_retries, error_message ); @@ -549,18 +477,6 @@ impl Provider for GcpVertexAIProvider { vec![ ConfigKey::new("GCP_PROJECT_ID", true, false, None), ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())), - ConfigKey::new( - "GCP_MAX_RATE_LIMIT_RETRIES", - false, - false, - Some(&DEFAULT_MAX_RETRIES.to_string()), - ), - ConfigKey::new( - "GCP_MAX_OVERLOADED_RETRIES", - false, - false, - Some(&DEFAULT_MAX_RETRIES.to_string()), - ), ConfigKey::new( "GCP_MAX_RETRIES", false, @@ -634,13 +550,7 @@ mod tests { #[test] fn test_retry_config_delay_calculation() { - let config = RetryConfig { - max_rate_limit_retries: 5, - max_overloaded_retries: 5, - initial_interval_ms: 1000, - backoff_multiplier: 2.0, - max_interval_ms: 32000, - }; + let config = RetryConfig::new(5, 1000, 2.0, 32000); // First attempt has no delay let delay0 = config.delay_for_attempt(0); @@ -659,27 +569,6 @@ mod tests { assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter) } - #[test] - fn test_max_retries_for_status() { - let config = RetryConfig { - max_rate_limit_retries: 5, - max_overloaded_retries: 10, - initial_interval_ms: 1000, - backoff_multiplier: 2.0, - max_interval_ms: 32000, - }; - - // Check that we get the right max retries for each error type - assert_eq!( - config.max_retries_for_status(StatusCode::TOO_MANY_REQUESTS), - 5 - ); - assert_eq!(config.max_retries_for_status(*STATUS_API_OVERLOADED), 10); - - // For any other status code, we should get the rate limit retries - assert_eq!(config.max_retries_for_status(StatusCode::BAD_REQUEST), 5); - } - #[test] fn test_status_overloaded_code() { // Test that we correctly handle the 529 status code @@ -742,7 +631,7 @@ mod tests { assert!(model_names.contains(&"claude-3-5-sonnet-v2@20241022".to_string())); assert!(model_names.contains(&"gemini-1.5-pro-002".to_string())); assert!(model_names.contains(&"gemini-2.5-pro".to_string())); - // Should contain the original 2 config keys plus 6 new retry-related ones - assert_eq!(metadata.config_keys.len(), 8); + // Should contain the original 2 config keys plus 4 new retry-related ones + assert_eq!(metadata.config_keys.len(), 6); } } diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 899ce0cb4033..202aebf9d331 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -14,6 +14,7 @@ use std::time::Duration; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::retry::ProviderRetry; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::config::{Config, ConfigError}; @@ -404,11 +405,15 @@ impl Provider for GithubCopilotProvider { messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { - let mut payload = - create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; - // Make request - let response = self.post(&mut payload).await?; + // Make request with retry + let response = self + .with_retry(|| async { + let mut payload_clone = payload.clone(); + self.post(&mut payload_clone).await + }) + .await?; // Parse response let message = 
response_to_message(&response)?;
@@ -422,7 +427,7 @@ impl Provider for GithubCopilotProvider { }
/// Fetch supported models from GitHub Copilot; returns Err on failure, Ok(None) if not present
- async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
+ async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
let (endpoint, token) = self.get_api_info().await?;
let url = format!("{}/models", endpoint);
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index a499b08c51a4..fa401880bb44 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -1,20 +1,16 @@
+use super::api_client::{ApiClient, AuthMethod};
use super::errors::ProviderError;
+use super::retry::ProviderRetry;
+use super::utils::{emit_debug_trace, handle_response_google_compat, unescape_json_values};
use crate::impl_provider_default;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use crate::providers::formats::google::{create_request, get_usage, response_to_message};
-use crate::providers::utils::{
- emit_debug_trace, handle_response_google_compat, unescape_json_values,
-};
use anyhow::Result;
use async_trait::async_trait;
-use axum::http::HeaderMap;
-use reqwest::Client;
use rmcp::model::Tool;
use serde_json::Value;
-use std::time::Duration;
-use url::Url;
pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-flash";
@@ -51,8 +47,7 @@ pub const GOOGLE_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs/models";
#[derive(Debug, serde::Serialize)]
pub struct GoogleProvider { #[serde(skip)]
- client: Client,
- host: String,
+ api_client: ApiClient,
model: ModelConfig, }
@@ -66,77 +61,21 @@ impl GoogleProvider {
.get_param("GOOGLE_HOST")
.unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
- let mut headers = HeaderMap::new();
- headers.insert("CONTENT_TYPE", "application/json".parse()?);
- headers.insert("x-goog-api-key", api_key.parse()?);
+ let auth = AuthMethod::ApiKey {
+ header_name: "x-goog-api-key".to_string(),
+ key: api_key,
+ };
- let client = Client::builder()
- .timeout(Duration::from_secs(600))
- .default_headers(headers)
- .build()?;
+ let api_client =
+ ApiClient::new(host, auth)?.with_header("Content-Type", "application/json")?;
- Ok(Self {
- client,
- host,
- model,
- })
+ Ok(Self { api_client, model })
}
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
- let base_url = Url::parse(&self.host)
- .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-
- let url = base_url
- .join(&format!(
- "v1beta/models/{}:generateContent",
- self.model.model_name
- ))
- .map_err(|e| {
- ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
- })?;
-
- let max_retries = 3;
- let mut retries = 0;
- let base_delay = Duration::from_secs(2);
-
- loop {
- let response = self
- .client
- .post(url.clone()) // Clone the URL for each retry
- .json(&payload)
- .send()
- .await;
-
- match response {
- Ok(res) => {
- match handle_response_google_compat(res).await {
- Ok(result) => return Ok(result),
- Err(ProviderError::RateLimitExceeded(_)) => {
- retries += 1;
- if retries > max_retries {
- return Err(ProviderError::RateLimitExceeded(
- "Max retries exceeded for rate limit error".to_string(),
- ));
- }
-
- let delay = 2u64.pow(retries);
- let total_delay = Duration::from_secs(delay) + base_delay;
-
- println!("Rate
limit hit. Retrying in {:?}", total_delay); - tokio::time::sleep(total_delay).await; - continue; - } - Err(err) => return Err(err), // Other errors - } - } - Err(err) => { - return Err(ProviderError::RequestFailed(format!( - "Request failed: {}", - err - ))); - } - } - } + let path = format!("v1beta/models/{}:generateContent", self.model.model_name); + let response = self.api_client.response_post(&path, payload).await?; + handle_response_google_compat(response).await } } @@ -174,7 +113,12 @@ impl Provider for GoogleProvider { let payload = create_request(&self.model, system, messages, tools)?; // Make request - let response = self.post(&payload).await?; + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await?; // Parse response let message = response_to_message(unescape_json_values(&response))?; @@ -189,12 +133,9 @@ impl Provider for GoogleProvider { } /// Fetch supported models from Google Generative Language API; returns Err on failure, Ok(None) if not present - async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { - // List models via the v1beta/models endpoint - let url = format!("{}/v1beta/models", self.host); - let response = self.client.get(&url).send().await?; + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self.api_client.response_get("v1beta/models").await?; let json: serde_json::Value = response.json().await?; - // If 'models' field missing, return None let arr = match json.get("models").and_then(|v| v.as_array()) { Some(arr) => arr, None => return Ok(None), diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 3a9d48a787f3..d840830dc699 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,17 +1,16 @@ +use super::api_client::{ApiClient, AuthMethod}; use super::errors::ProviderError; +use super::retry::ProviderRetry; +use super::utils::{get_model, handle_response_openai_compat}; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; -use crate::providers::utils::get_model; use anyhow::Result; use async_trait::async_trait; -use reqwest::{Client, StatusCode}; use rmcp::model::Tool; use serde_json::Value; -use std::time::Duration; -use url::Url; pub const GROQ_API_HOST: &str = "https://api.groq.com"; pub const GROQ_DEFAULT_MODEL: &str = "moonshotai/kimi-k2-instruct"; @@ -27,9 +26,7 @@ pub const GROQ_DOC_URL: &str = "https://console.groq.com/docs/models"; #[derive(serde::Serialize)] pub struct GroqProvider { #[serde(skip)] - client: Client, - host: String, - api_key: String, + api_client: ApiClient, model: ModelConfig, } @@ -43,58 +40,18 @@ impl GroqProvider { .get_param("GROQ_HOST") .unwrap_or_else(|_| GROQ_API_HOST.to_string()); - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth = AuthMethod::BearerToken(api_key); + let api_client = ApiClient::new(host, auth)?; - Ok(Self { - client, - host, - api_key, - model, - }) + Ok(Self { api_client, model }) } - async fn post(&self, payload: &Value) -> anyhow::Result { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("openai/v1/chat/completions").map_err(|e| { - 
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
- })?;
-
+ async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let response = self
- .client
- .post(url)
- .header("Authorization", format!("Bearer {}", self.api_key))
- .json(payload)
- .send()
+ .api_client
+ .response_post("openai/v1/chat/completions", &payload)
.await?;
-
- let status = response.status();
- let response_payload: Option<Value> = response.json().await.ok();
- let formatted_payload = format!("{:?}", response_payload);
-
- match status {
- StatusCode::OK => response_payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
- StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
- Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
- Status: {}. Response: {:?}", status, response_payload)))
- }
- StatusCode::PAYLOAD_TOO_LARGE => {
- Err(ProviderError::ContextLengthExceeded(formatted_payload))
- }
- StatusCode::TOO_MANY_REQUESTS => {
- Err(ProviderError::RateLimitExceeded(formatted_payload))
- }
- StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
- Err(ProviderError::ServerError(formatted_payload))
- }
- _ => {
- let error_msg = format!("Provider request failed with status: {}. Payload: {:?}", status, response_payload);
- tracing::debug!(error_msg);
- Err(ProviderError::RequestFailed(error_msg))
- }
- }
+ handle_response_openai_compat(response).await
} }
@@ -128,7 +85,7 @@ impl Provider for GroqProvider {
system: &str,
messages: &[Message],
tools: &[Tool],
- ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
+ ) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(
&self.model,
system,
@@ -137,7 +94,7 @@ impl Provider for GroqProvider {
&super::utils::ImageFormat::OpenAi,
)?;
- let response = self.post(&payload).await?;
+ let response = self.with_retry(|| self.post(payload.clone())).await?;
let message = response_to_message(&response)?;
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
@@ -150,57 +107,27 @@ impl Provider for GroqProvider {
}
/// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found
- async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
- // Construct the Groq models endpoint
- let base_url = url::Url::parse(&self.host)
- .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?;
- let url = base_url.join("openai/v1/models").map_err(|e| {
- ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {}", e))
- })?;
-
- // Build the request with required headers
- let request = self
- .client
- .get(url)
- .bearer_auth(&self.api_key)
- .header("Content-Type", "application/json");
-
- // Send request
- let response = request.send().await?;
- let status = response.status();
- let payload: serde_json::Value = response.json().await.map_err(|_| {
- ProviderError::RequestFailed("Response body is not valid JSON".to_string())
- })?;
-
- // Check for error response from API
- if let Some(err_obj) = payload.get("error") {
- let msg = err_obj
- .get("message")
- .and_then(|v| v.as_str())
- .unwrap_or("unknown error");
- return Err(ProviderError::Authentication(msg.to_string()));
- }
-
- // Extract model names
- if status == StatusCode::OK {
- let data = payload
- .get("data")
- .and_then(|v| v.as_array())
- .ok_or_else(|| {
- ProviderError::UsageError("Missing or invalid `data` field in response".into())
- })?;
-
- let mut model_names: Vec<String> = data
- .iter()
- .filter_map(|m| m.get("id").and_then(Value::as_str).map(String::from))
- .collect();
- model_names.sort();
- Ok(Some(model_names))
- } else {
- Err(ProviderError::RequestFailed(format!(
- "Groq API returned error status: {}. Payload: {:?}",
- status, payload
- )))
- }
+ async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
+ let response = self
+ .api_client
+ .request("openai/v1/models")
+ .header("Content-Type", "application/json")?
+ .response_get()
+ .await?;
+ let response = handle_response_openai_compat(response).await?;
+
+ let data = response
+ .get("data")
+ .and_then(|v| v.as_array())
+ .ok_or_else(|| {
+ ProviderError::UsageError("Missing or invalid `data` field in response".into())
+ })?;
+
+ let mut model_names: Vec<String> = data
+ .iter()
+ .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(String::from))
+ .collect();
+ model_names.sort();
+ Ok(Some(model_names))
} }
diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs
index 4ddea247048d..6f1a5cd0e6c9 100644
--- a/crates/goose/src/providers/lead_worker.rs
+++ b/crates/goose/src/providers/lead_worker.rs
@@ -409,10 +409,10 @@ impl Provider for LeadWorkerProvider {
final_result }
- async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
+ async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
// Combine models from both providers
- let lead_models = self.lead_provider.fetch_supported_models_async().await?;
- let worker_models = self.worker_provider.fetch_supported_models_async().await?;
+ let lead_models = self.lead_provider.fetch_supported_models().await?;
+ let worker_models = self.worker_provider.fetch_supported_models().await?;
match (lead_models, worker_models) { (Some(lead), Some(worker)) => {
diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs
index 6991d823dfbf..54a84ced522d 100644
--- a/crates/goose/src/providers/litellm.rs
+++ b/crates/goose/src/providers/litellm.rs
@@ -1,14 +1,13 @@
use anyhow::Result;
use async_trait::async_trait;
-use reqwest::Client;
use serde_json::{json, Value};
use std::collections::HashMap;
-use std::time::Duration;
-use url::Url;
+use super::api_client::{ApiClient, AuthMethod};
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
+use super::retry::ProviderRetry;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
use crate::impl_provider_default;
use crate::message::Message;
@@ -21,12 +20,9 @@ pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";
#[derive(Debug, serde::Serialize)]
pub struct LiteLLMProvider { #[serde(skip)]
- client: Client,
- host: String,
+ api_client: ApiClient,
base_path: String,
- api_key: String,
model: ModelConfig,
- custom_headers: Option<HashMap<String, String>>,
}
impl_provider_default!(LiteLLMProvider);
@@ -49,44 +45,35 @@ impl LiteLLMProvider {
.ok()
.map(parse_custom_headers);
let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600);
- let client = Client::builder()
- .timeout(Duration::from_secs(timeout_secs))
- .build()?;
+
+ let auth = if api_key.is_empty() {
+ AuthMethod::Custom(Box::new(NoAuth))
+ } else {
+ AuthMethod::BearerToken(api_key)
+ };
+
+ let mut api_client =
+ ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
+
+ if let Some(headers) = custom_headers {
+ let mut header_map =
reqwest::header::HeaderMap::new(); + for (key, value) in headers { + let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?; + let header_value = reqwest::header::HeaderValue::from_str(&value)?; + header_map.insert(header_name, header_value); + } + api_client = api_client.with_headers(header_map)?; + } Ok(Self { - client, - host, + api_client, base_path, - api_key, model, - custom_headers, }) } - fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - if let Some(custom_headers) = &self.custom_headers { - for (key, value) in custom_headers { - request = request.header(key, value); - } - } - - request - } - async fn fetch_models(&self) -> Result, ProviderError> { - let models_url = format!("{}/model/info", self.host); - - let mut req = self - .client - .get(&models_url) - .header("Authorization", format!("Bearer {}", self.api_key)); - - req = self.add_headers(req); - - let response = req - .send() - .await - .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; + let response = self.api_client.response_get("model/info").await?; if !response.status().is_success() { return Err(ProviderError::RequestFailed(format!( @@ -125,22 +112,22 @@ impl LiteLLMProvider { } async fn post(&self, payload: &Value) -> Result { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join(&self.base_path).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let request = self - .client - .post(url) - .header("Authorization", format!("Bearer {}", self.api_key)); - - let request = self.add_headers(request); + let response = self + .api_client + .response_post(&self.base_path, payload) + .await?; + handle_response_openai_compat(response).await + } +} - let response = request.json(payload).send().await?; +// No authentication provider for LiteLLM when API key is not provided +struct NoAuth; - handle_response_openai_compat(response).await +#[async_trait] +impl super::api_client::AuthProvider for NoAuth { + async fn get_auth_header(&self) -> Result<(String, String)> { + // Return a dummy header that won't be used + Ok(("X-No-Auth".to_string(), "true".to_string())) } } @@ -192,7 +179,12 @@ impl Provider for LiteLLMProvider { payload = update_request_for_cache_control(&payload); } - let response = self.post(&payload).await?; + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await?; let message = super::formats::openai::response_to_message(&response)?; let usage = super::formats::openai::get_usage(&response); @@ -217,7 +209,7 @@ impl Provider for LiteLLMProvider { self.model.model_name.to_lowercase().contains("claude") } - async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + async fn fetch_supported_models(&self) -> Result>, ProviderError> { match self.fetch_models().await { Ok(models) => { let model_names: Vec = models.into_iter().map(|m| m.name).collect(); @@ -234,8 +226,6 @@ impl Provider for LiteLLMProvider { #[async_trait] impl EmbeddingCapable for LiteLLMProvider { async fn create_embeddings(&self, texts: Vec) -> Result>, anyhow::Error> { - let endpoint = format!("{}/v1/embeddings", self.host); - let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL") .unwrap_or_else(|_| "text-embedding-3-small".to_string()); @@ -245,16 +235,10 @@ impl EmbeddingCapable for LiteLLMProvider { 
"encoding_format": "float" }); - let mut req = self - .client - .post(&endpoint) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&payload); - - req = self.add_headers(req); - - let response = req.send().await?; + let response = self + .api_client + .response_post("v1/embeddings", &payload) + .await?; let response_text = response.text().await?; let response_json: Value = serde_json::from_str(&response_text)?; diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 38c810d4d171..3e04fba896ee 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,4 +1,5 @@ pub mod anthropic; +mod api_client; pub mod azure; pub mod azureauth; pub mod base; @@ -22,6 +23,7 @@ pub mod ollama; pub mod openai; pub mod openrouter; pub mod pricing; +mod retry; pub mod sagemaker_tgi; pub mod snowflake; pub mod testprovider; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 1fa9300a2457..0b16e21dc889 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,5 +1,7 @@ +use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; use crate::impl_provider_default; use crate::message::Message; @@ -9,7 +11,6 @@ use crate::utils::safe_truncate; use anyhow::Result; use async_trait::async_trait; use regex::Regex; -use reqwest::Client; use rmcp::model::Tool; use serde_json::Value; use std::time::Duration; @@ -26,8 +27,7 @@ pub const OLLAMA_DOC_URL: &str = "https://ollama.com/library"; #[derive(serde::Serialize)] pub struct OllamaProvider { #[serde(skip)] - client: Client, - host: String, + api_client: ApiClient, model: ModelConfig, } @@ -43,54 +43,53 @@ impl OllamaProvider { let timeout: Duration = Duration::from_secs(config.get_param("OLLAMA_TIMEOUT").unwrap_or(OLLAMA_TIMEOUT)); - let client = Client::builder().timeout(timeout).build()?; - - Ok(Self { - client, - host, - model, - }) - } - - /// Get the base URL for Ollama API calls - fn get_base_url(&self) -> Result { // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme - let base = if self.host.starts_with("http://") || self.host.starts_with("https://") { - &self.host + let base = if host.starts_with("http://") || host.starts_with("https://") { + host.clone() } else { - &format!("http://{}", self.host) + format!("http://{}", host) }; - let mut base_url = Url::parse(base) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let mut base_url = + Url::parse(&base).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?; // Set the default port if missing // Don't add default port if: // 1. URL explicitly ends with standard ports (:80 or :443) // 2. 
URL uses HTTPS (which implicitly uses port 443) - let explicit_default_port = self.host.ends_with(":80") || self.host.ends_with(":443"); + let explicit_default_port = host.ends_with(":80") || host.ends_with(":443"); let is_https = base_url.scheme() == "https"; if base_url.port().is_none() && !explicit_default_port && !is_https { - base_url.set_port(Some(OLLAMA_DEFAULT_PORT)).map_err(|_| { - ProviderError::RequestFailed("Failed to set default port".to_string()) - })?; + base_url + .set_port(Some(OLLAMA_DEFAULT_PORT)) + .map_err(|_| anyhow::anyhow!("Failed to set default port"))?; } - Ok(base_url) + // No authentication for Ollama + let auth = AuthMethod::Custom(Box::new(NoAuth)); + let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?; + + Ok(Self { api_client, model }) } async fn post(&self, payload: &Value) -> Result { - // TODO: remove this later when the UI handles provider config refresh - let base_url = self.get_base_url()?; - - let url = base_url.join("v1/chat/completions").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + let response = self + .api_client + .response_post("v1/chat/completions", payload) + .await?; + handle_response_openai_compat(response).await + } +} - let response = self.client.post(url).json(payload).send().await?; +// No authentication provider for Ollama +struct NoAuth; - handle_response_openai_compat(response).await +#[async_trait] +impl super::api_client::AuthProvider for NoAuth { + async fn get_auth_header(&self) -> Result<(String, String)> { + // Return a dummy header that won't be used + Ok(("X-No-Auth".to_string(), "true".to_string())) } } @@ -141,8 +140,13 @@ impl Provider for OllamaProvider { filtered_tools, &super::utils::ImageFormat::OpenAi, )?; - let response = self.post(&payload).await?; - let message = response_to_message(&response)?; + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await?; + let message = response_to_message(&response.clone())?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { tracing::debug!("Failed to get usage data"); diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index e57e9ae46286..a874f421c4aa 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -2,38 +2,42 @@ use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; use futures::TryStreamExt; -use reqwest::{Client, Response}; +use reqwest::StatusCode; use serde_json::{json, Value}; use std::collections::HashMap; use std::io; -use std::time::Duration; use tokio::pin; use tokio_stream::StreamExt; use tokio_util::codec::{FramedRead, LinesCodec}; use tokio_util::io::StreamReader; +use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; -use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use super::utils::{ + emit_debug_trace, get_model, handle_response_openai_compat, handle_status_openai_compat, + ImageFormat, +}; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::MessageStream; use 
crate::providers::formats::openai::response_to_streaming_message; -use crate::providers::utils::handle_status_openai_compat; use rmcp::model::Tool; pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; -pub const OPEN_AI_KNOWN_MODELS: &[&str] = &[ - "gpt-4o", - "gpt-4o-mini", - "gpt-4-turbo", - "gpt-3.5-turbo", - "o1", - "o3", - "o4-mini", +pub const OPEN_AI_KNOWN_MODELS: &[(&str, usize)] = &[ + ("gpt-4o", 128_000), + ("gpt-4o-mini", 128_000), + ("gpt-4.1", 128_000), + ("gpt-4.1-mini", 128_000), + ("o1", 200_000), + ("o3", 200_000), + ("gpt-3.5-turbo", 16_385), + ("gpt-4-turbo", 128_000), + ("o4-mini", 128_000), ]; pub const OPEN_AI_DOC_URL: &str = "https://platform.openai.com/docs/models"; @@ -41,10 +45,8 @@ pub const OPEN_AI_DOC_URL: &str = "https://platform.openai.com/docs/models"; #[derive(Debug, serde::Serialize)] pub struct OpenAiProvider { #[serde(skip)] - client: Client, - host: String, + api_client: ApiClient, base_path: String, - api_key: String, organization: Option, project: Option, model: ModelConfig, @@ -71,79 +73,61 @@ impl OpenAiProvider { .ok() .map(parse_custom_headers); let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600); - let client = Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build()?; - Ok(Self { - client, - host, - base_path, - api_key, - organization, - project, - model, - custom_headers, - }) - } + let auth = AuthMethod::BearerToken(api_key); + let mut api_client = + ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?; - /// Helper function to add OpenAI-specific headers to a request - fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - // Add organization header if present - if let Some(org) = &self.organization { - request = request.header("OpenAI-Organization", org); + if let Some(org) = &organization { + api_client = api_client.with_header("OpenAI-Organization", org)?; } - // Add project header if present - if let Some(project) = &self.project { - request = request.header("OpenAI-Project", project); + if let Some(project) = &project { + api_client = api_client.with_header("OpenAI-Project", project)?; } - // Add custom headers if present - if let Some(custom_headers) = &self.custom_headers { - for (key, value) in custom_headers { - request = request.header(key, value); + if let Some(headers) = &custom_headers { + let mut header_map = reqwest::header::HeaderMap::new(); + for (key, value) in headers { + let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?; + let header_value = reqwest::header::HeaderValue::from_str(value)?; + header_map.insert(header_name, header_value); } + api_client = api_client.with_headers(header_map)?; } - request + Ok(Self { + api_client, + base_path, + organization, + project, + model, + custom_headers, + }) } - async fn post(&self, payload: &Value) -> Result { - let base_url = url::Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join(&self.base_path).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let request = self - .client - .post(url) - .header("Authorization", format!("Bearer {}", self.api_key)); - - let request = self.add_headers(request); - - Ok(request.json(&payload).send().await?) 
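// --- Editor's sketch (not part of the patch) --------------------------------
// The providers in this diff all converge on the same ApiClient surface from
// super::api_client: pick an AuthMethod, optionally add headers, then call
// response_get/response_post and feed the raw response to the shared handlers
// in super::utils. A minimal sketch of that pattern; the host, key, and header
// values are placeholders, and the `?` conversions assume the
// From<anyhow::Error> for ProviderError impl added in errors.rs above.
use std::time::Duration;

async fn openai_compat_post(
    payload: &serde_json::Value,
) -> Result<serde_json::Value, ProviderError> {
    // Builder-style construction: auth strategy first, then extra headers,
    // mirroring OpenAiProvider::from_env in the hunk above.
    let api_client = ApiClient::with_timeout(
        "https://api.openai.com".to_string(),          // placeholder host
        AuthMethod::BearerToken("sk-...".to_string()), // placeholder key
        Duration::from_secs(600),
    )?
    .with_header("OpenAI-Organization", "org-placeholder")?;

    // response_post returns the raw reqwest::Response; the shared
    // handle_response_openai_compat maps HTTP status codes to ProviderError.
    let response = api_client
        .response_post("v1/chat/completions", payload)
        .await?;
    handle_response_openai_compat(response).await
}
// -----------------------------------------------------------------------------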
+ async fn post(&self, payload: &Value) -> Result { + let response = self + .api_client + .response_post(&self.base_path, payload) + .await?; + handle_response_openai_compat(response).await } } #[async_trait] impl Provider for OpenAiProvider { fn metadata() -> ProviderMetadata { + let models = OPEN_AI_KNOWN_MODELS + .iter() + .map(|(name, limit)| ModelInfo::new(*name, *limit)) + .collect(); ProviderMetadata::with_models( "openai", "OpenAI", "GPT-4 and other OpenAI models, including OpenAI compatible ones", OPEN_AI_DEFAULT_MODEL, - vec![ - ModelInfo::new("gpt-4o", 128000), - ModelInfo::new("gpt-4o-mini", 128000), - ModelInfo::new("gpt-4-turbo", 128000), - ModelInfo::new("gpt-3.5-turbo", 16385), - ModelInfo::new("o1", 200000), - ModelInfo::new("o3", 200000), - ModelInfo::new("o4-mini", 128000), - ], + models, OPEN_AI_DOC_URL, vec![ ConfigKey::new("OPENAI_API_KEY", true, true, None), @@ -173,42 +157,25 @@ impl Provider for OpenAiProvider { ) -> Result<(Message, ProviderUsage), ProviderError> { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; - // Make request - let response = handle_response_openai_compat(self.post(&payload).await?).await?; - - // Parse response - let message = response_to_message(&response)?; - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - let model = get_model(&response); - emit_debug_trace(&self.model, &payload, &response, &usage); + let json_response = self.post(&payload).await?; + + let message = response_to_message(&json_response)?; + let usage = json_response + .get("usage") + .map(get_usage) + .unwrap_or_else(|| { + tracing::debug!("Failed to get usage data"); + Usage::default() + }); + let model = get_model(&json_response); + emit_debug_trace(&self.model, &payload, &json_response, &usage); Ok((message, ProviderUsage::new(model, usage))) } - /// Fetch supported models from OpenAI; returns Err on any failure, Ok(None) if no data - async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { - // List available models via OpenAI API - let base_url = - url::Url::parse(&self.host).map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - let url = base_url - .join(&self.base_path.replace("v1/chat/completions", "v1/models")) - .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - let mut request = self.client.get(url).bearer_auth(&self.api_key); - if let Some(org) = &self.organization { - request = request.header("OpenAI-Organization", org); - } - if let Some(project) = &self.project { - request = request.header("OpenAI-Project", project); - } - if let Some(headers) = &self.custom_headers { - for (key, value) in headers { - request = request.header(key, value); - } - } - let response = request.send().await?; - let json: serde_json::Value = response.json().await?; + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let models_path = self.base_path.replace("v1/chat/completions", "v1/models"); + let response = self.api_client.response_get(&models_path).await?; + let json = handle_response_openai_compat(response).await?; if let Some(err_obj) = json.get("error") { let msg = err_obj .get("message") @@ -216,6 +183,7 @@ impl Provider for OpenAiProvider { .unwrap_or("unknown error"); return Err(ProviderError::Authentication(msg.to_string())); } + let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| { ProviderError::UsageError("Missing data field in JSON response".into()) })?; @@ -254,12 
+222,16 @@ impl Provider for OpenAiProvider { "include_usage": true, }); - let response = handle_status_openai_compat(self.post(&payload).await?).await?; + let response = self + .api_client + .response_post(&self.base_path, &payload) + .await?; + let response = handle_status_openai_compat(response).await?; let stream = response.bytes_stream().map_err(io::Error::other); let model_config = self.model.clone(); - // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { let stream_reader = StreamReader::new(stream); let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from); @@ -268,7 +240,7 @@ impl Provider for OpenAiProvider { pin!(message_stream); while let Some(message) = message_stream.next().await { let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; - super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); yield (message, usage); } })) @@ -293,7 +265,6 @@ impl EmbeddingCapable for OpenAiProvider { return Ok(vec![]); } - // Get embedding model from env var or use default let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL") .unwrap_or_else(|_| "text-embedding-3-small".to_string()); @@ -302,35 +273,25 @@ impl EmbeddingCapable for OpenAiProvider { model: embedding_model, }; - // Construct embeddings endpoint URL - let base_url = - url::Url::parse(&self.host).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?; - let url = base_url - .join("v1/embeddings") - .map_err(|e| anyhow::anyhow!("Failed to construct embeddings URL: {e}"))?; - - let req = self - .client - .post(url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&request); - - let req = self.add_headers(req); - - let response = req - .send() - .await - .map_err(|e| anyhow::anyhow!("Failed to send embedding request: {e}"))?; - - if !response.status().is_success() { - let error_text = response.text().await.unwrap_or_default(); + let response = self + .api_client + .api_post("v1/embeddings", &serde_json::to_value(request)?) 
+ .await?; + + if response.status != StatusCode::OK { + let error_text = response + .payload + .as_ref() + .and_then(|p| p.as_str()) + .unwrap_or("Unknown error"); return Err(anyhow::anyhow!("Embedding API error: {}", error_text)); } - let embedding_response: EmbeddingResponse = response - .json() - .await - .map_err(|e| anyhow::anyhow!("Failed to parse embedding response: {e}"))?; + let embedding_response: EmbeddingResponse = serde_json::from_value( + response + .payload + .ok_or_else(|| anyhow::anyhow!("Empty response body"))?, + )?; Ok(embedding_response .data diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 68633d63da33..cc23fbfa786d 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -1,11 +1,11 @@ use anyhow::{Error, Result}; use async_trait::async_trait; -use reqwest::Client; use serde_json::{json, Value}; -use std::time::Duration; +use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; +use super::retry::ProviderRetry; use super::utils::{ emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, is_google_model, @@ -15,7 +15,6 @@ use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use rmcp::model::Tool; -use url::Url; pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet"; pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic"; @@ -35,9 +34,7 @@ pub const OPENROUTER_DOC_URL: &str = "https://openrouter.ai/models"; #[derive(serde::Serialize)] pub struct OpenRouterProvider { #[serde(skip)] - client: Client, - host: String, - api_key: String, + api_client: ApiClient, model: ModelConfig, } @@ -51,34 +48,18 @@ impl OpenRouterProvider { .get_param("OPENROUTER_HOST") .unwrap_or_else(|_| "https://openrouter.ai".to_string()); - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth = AuthMethod::BearerToken(api_key); + let api_client = ApiClient::new(host, auth)? + .with_header("HTTP-Referer", "https://block.github.io/goose")? 
+ .with_header("X-Title", "Goose")?; - Ok(Self { - client, - host, - api_key, - model, - }) + Ok(Self { api_client, model }) } async fn post(&self, payload: &Value) -> Result { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("api/v1/chat/completions").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - let response = self - .client - .post(url) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("HTTP-Referer", "https://block.github.io/goose") - .header("X-Title", "Goose") - .json(payload) - .send() + .api_client + .response_post("api/v1/chat/completions", payload) .await?; // Handle Google-compatible model responses differently @@ -264,7 +245,12 @@ impl Provider for OpenRouterProvider { let payload = create_request_based_on_model(self, system, messages, tools)?; // Make request - let response = self.post(&payload).await?; + let response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await?; // Parse response let message = response_to_message(&response)?; @@ -278,24 +264,10 @@ impl Provider for OpenRouterProvider { } /// Fetch supported models from OpenRouter API (only models with tool support) - async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("api/v1/models").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct models URL: {e}")) - })?; - + async fn fetch_supported_models(&self) -> Result>, ProviderError> { // Handle request failures gracefully // If the request fails, fall back to manual entry - let response = match self - .client - .get(url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("HTTP-Referer", "https://block.github.io/goose") - .header("X-Title", "Goose") - .send() - .await - { + let response = match self.api_client.response_get("api/v1/models").await { Ok(response) => response, Err(e) => { tracing::warn!("Failed to fetch models from OpenRouter API: {}, falling back to manual model entry", e); diff --git a/crates/goose/src/providers/retry.rs b/crates/goose/src/providers/retry.rs new file mode 100644 index 000000000000..b95f77bd2cd2 --- /dev/null +++ b/crates/goose/src/providers/retry.rs @@ -0,0 +1,117 @@ +use super::errors::ProviderError; +use crate::providers::base::Provider; +use async_trait::async_trait; +use std::future::Future; +use std::time::Duration; +use tokio::time::sleep; + +pub const DEFAULT_MAX_RETRIES: usize = 3; +pub const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1000; +pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; +pub const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 30_000; + +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts + pub(crate) max_retries: usize, + /// Initial interval between retries in milliseconds + pub(crate) initial_interval_ms: u64, + /// Multiplier for backoff (exponential) + pub(crate) backoff_multiplier: f64, + /// Maximum interval between retries in milliseconds + pub(crate) max_interval_ms: u64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: DEFAULT_MAX_RETRIES, + initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, + backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, + 
max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
+ }
+ }
+}
+
+impl RetryConfig {
+ pub fn new(
+ max_retries: usize,
+ initial_interval_ms: u64,
+ backoff_multiplier: f64,
+ max_interval_ms: u64,
+ ) -> Self {
+ Self {
+ max_retries,
+ initial_interval_ms,
+ backoff_multiplier,
+ max_interval_ms,
+ }
+ }
+
+ pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
+ if attempt == 0 {
+ return Duration::from_millis(0);
+ }
+
+ let exponent = (attempt - 1) as u32;
+ let base_delay_ms = (self.initial_interval_ms as f64
+ * self.backoff_multiplier.powi(exponent as i32)) as u64;
+
+ let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);
+
+ let jitter_factor_to_avoid_thundering_herd = 0.8 + (rand::random::<f64>() * 0.4);
+ let jitter_delay_ms =
+ (capped_delay_ms as f64 * jitter_factor_to_avoid_thundering_herd) as u64;
+
+ Duration::from_millis(jitter_delay_ms)
+ }
+}
+
+/// Trait for retry functionality to keep Provider dyn-compatible
+#[async_trait]
+pub trait ProviderRetry {
+ fn retry_config(&self) -> RetryConfig {
+ RetryConfig::default()
+ }
+
+ async fn with_retry<F, Fut, T>(&self, operation: F) -> Result<T, ProviderError>
+ where
+ F: Fn() -> Fut + Send,
+ Fut: Future<Output = Result<T, ProviderError>> + Send,
+ T: Send,
+ {
+ let mut attempts = 0;
+ let config = self.retry_config();
+
+ loop {
+ return match operation().await {
+ Ok(result) => Ok(result),
+ Err(error) => {
+ let should_retry = matches!(
+ error,
+ ProviderError::RateLimitExceeded(_) | ProviderError::ServerError(_)
+ );
+
+ if should_retry && attempts < config.max_retries {
+ attempts += 1;
+ tracing::warn!(
+ "Request failed, retrying ({}/{}): {:?}",
+ attempts,
+ config.max_retries,
+ error
+ );
+
+ let delay = config.delay_for_attempt(attempts);
+ tracing::info!("Backing off for {:?} before retry", delay);
+ sleep(delay).await;
+ continue;
+ }
+
+ Err(error)
+ }
+ };
+ }
+ }
+}
+
+impl<P: Provider> ProviderRetry for P {}
diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs
index d2656aad6f21..9b68ab656f67 100644
--- a/crates/goose/src/providers/sagemaker_tgi.rs
+++ b/crates/goose/src/providers/sagemaker_tgi.rs
@@ -8,10 +8,10 @@ use aws_sdk_bedrockruntime::config::ProvideCredentials;
use aws_sdk_sagemakerruntime::Client as SageMakerClient;
use rmcp::model::Tool;
use serde_json::{json, Value};
-use tokio::time::sleep;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
+use super::retry::ProviderRetry;
use super::utils::emit_debug_trace;
use crate::impl_provider_default;
use crate::message::{Message, MessageContent};
@@ -295,63 +295,33 @@ impl Provider for SageMakerTgiProvider {
ProviderError::RequestFailed(format!("Failed to create request: {}", e)) })?;
- // Retry configuration
- const MAX_RETRIES: u32 = 3;
- const INITIAL_BACKOFF_MS: u64 = 1000; // 1 second
- const MAX_BACKOFF_MS: u64 = 30000; // 30 seconds
-
- let mut attempts = 0;
- let mut backoff_ms = INITIAL_BACKOFF_MS;
-
- loop {
- attempts += 1;
-
- match self.invoke_endpoint(request_payload.clone()).await {
- Ok(response) => {
- let message = self.parse_tgi_response(response)?;
-
- // TGI doesn't provide usage statistics, so we estimate
- let usage = Usage {
- input_tokens: Some(0), // Would need to tokenize input to get accurate count
- output_tokens: Some(0), // Would need to tokenize output to get accurate count
- total_tokens: Some(0),
- };
-
- // Add debug trace
- let debug_payload = serde_json::json!({
- "system": system,
- "messages": messages,
- "tools": tools
- });
- emit_debug_trace(
&self.model, - &debug_payload, - &serde_json::to_value(&message).unwrap_or_default(), - &usage, - ); - - let provider_usage = ProviderUsage::new(model_name.to_string(), usage); - return Ok((message, provider_usage)); - } - Err(err) => { - if attempts > MAX_RETRIES { - return Err(err); - } + let response = self + .with_retry(|| self.invoke_endpoint(request_payload.clone())) + .await?; - // Log retry attempt - tracing::warn!( - "SageMaker TGI request failed (attempt {}/{}), retrying in {} ms: {:?}", - attempts, - MAX_RETRIES, - backoff_ms, - err - ); - - // Wait before retry - sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); - } - } - } + let message = self.parse_tgi_response(response)?; + + // TGI doesn't provide usage statistics, so we estimate + let usage = Usage { + input_tokens: Some(0), // Would need to tokenize input to get accurate count + output_tokens: Some(0), // Would need to tokenize output to get accurate count + total_tokens: Some(0), + }; + + // Add debug trace + let debug_payload = serde_json::json!({ + "system": system, + "messages": messages, + "tools": tools + }); + emit_debug_trace( + &self.model, + &debug_payload, + &serde_json::to_value(&message).unwrap_or_default(), + &usage, + ); + + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + Ok((message, provider_usage)) } } diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index 3e52310ee54b..f0b643b7748c 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -1,20 +1,19 @@ use anyhow::Result; use async_trait::async_trait; -use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::time::Duration; +use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::formats::snowflake::{create_request, get_usage, response_to_message}; -use super::utils::{get_model, ImageFormat}; +use super::retry::ProviderRetry; +use super::utils::{get_model, map_http_error_to_provider_error, ImageFormat}; use crate::config::ConfigError; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use rmcp::model::Tool; -use url::Url; pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-3-7-sonnet"; pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &["claude-3-7-sonnet", "claude-3-5-sonnet"]; @@ -36,9 +35,7 @@ impl SnowflakeAuth { #[derive(Debug, serde::Serialize)] pub struct SnowflakeProvider { #[serde(skip)] - client: Client, - host: String, - auth: SnowflakeAuth, + api_client: ApiClient, model: ModelConfig, image_format: ImageFormat, } @@ -82,57 +79,33 @@ impl SnowflakeProvider { .into()); } - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + // Ensure host has https:// prefix + let base_url = if !host.starts_with("https://") && !host.starts_with("http://") { + format!("https://{}", host) + } else { + host + }; + + let auth = AuthMethod::BearerToken(token?); + let api_client = ApiClient::new(base_url, auth)?.with_header("User-Agent", "Goose")?; - // Use token-based authentication - let api_key = token?; Ok(Self { - client, - host, - auth: SnowflakeAuth::token(api_key), + api_client, model, image_format: ImageFormat::OpenAi, }) } - async fn ensure_auth_header(&self) -> Result { - match &self.auth { - // 
https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-a-programmatic-access-token-pat - SnowflakeAuth::Token(token) => Ok(format!("Bearer {}", token)), - } - } - async fn post(&self, payload: &Value) -> Result { - let base_url_str = - if !self.host.starts_with("https://") && !self.host.starts_with("http://") { - format!("https://{}", self.host) - } else { - self.host.clone() - }; - let base_url = Url::parse(&base_url_str) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let path = "api/v2/cortex/inference:complete"; - let url = base_url.join(path).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - let auth_header = self.ensure_auth_header().await?; let response = self - .client - .post(url) - .header("Authorization", auth_header) - .header("User-Agent", "Goose") - .json(&payload) - .send() + .api_client + .response_post("api/v2/cortex/inference:complete", payload) .await?; let status = response.status(); - let payload_text: String = response.text().await.ok().unwrap_or_default(); - if status == StatusCode::OK { + if status.is_success() { if let Ok(payload) = serde_json::from_str::(&payload_text) { if payload.get("code").is_some() { let code = payload @@ -295,96 +268,11 @@ impl SnowflakeProvider { "content_list": content_list }); - match status { - StatusCode::OK => Ok(answer_payload), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - // Extract a clean error message from the response if available - let error_msg = payload_text - .lines() - .find(|line| line.contains("\"message\"")) - .and_then(|line| { - let json_str = line.strip_prefix("data: ").unwrap_or(line); - serde_json::from_str::(json_str).ok() - }) - .and_then(|json| { - json.get("message") - .and_then(|m| m.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_else(|| "Invalid credentials".to_string()); - - Err(ProviderError::Authentication(format!( - "Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}", - error_msg - ))) - } - StatusCode::BAD_REQUEST => { - // Snowflake provides a generic 'error' but also includes 'external_model_message' which is provider specific - // We try to extract the error message from the payload and check for phrases that indicate context length exceeded - let payload_str = payload_text.to_lowercase(); - let check_phrases = [ - "too long", - "context length", - "context_length_exceeded", - "reduce the length", - "token count", - "exceeds", - "exceed context limit", - "input length", - "max_tokens", - "decrease input length", - "context limit", - ]; - if check_phrases.iter().any(|c| payload_str.contains(c)) { - return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string())); - } - - // Try to parse a clean error message from the response - let error_msg = if let Ok(json) = serde_json::from_str::(&payload_text) { - json.get("message") - .and_then(|m| m.as_str()) - .map(|s| s.to_string()) - .or_else(|| { - json.get("external_model_message") - .and_then(|ext| ext.get("message")) - .and_then(|m| m.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_else(|| "Bad request".to_string()) - } else { - "Bad request".to_string() - }; - - tracing::debug!( - "Provider request failed with status: {}. 
@@ -295,96 +268,11 @@ impl SnowflakeProvider {
             "content_list": content_list
         });
 
-        match status {
-            StatusCode::OK => Ok(answer_payload),
-            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
-                // Extract a clean error message from the response if available
-                let error_msg = payload_text
-                    .lines()
-                    .find(|line| line.contains("\"message\""))
-                    .and_then(|line| {
-                        let json_str = line.strip_prefix("data: ").unwrap_or(line);
-                        serde_json::from_str::<Value>(json_str).ok()
-                    })
-                    .and_then(|json| {
-                        json.get("message")
-                            .and_then(|m| m.as_str())
-                            .map(|s| s.to_string())
-                    })
-                    .unwrap_or_else(|| "Invalid credentials".to_string());
-
-                Err(ProviderError::Authentication(format!(
-                    "Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}",
-                    error_msg
-                )))
-            }
-            StatusCode::BAD_REQUEST => {
-                // Snowflake provides a generic 'error' but also includes 'external_model_message' which is provider specific
-                // We try to extract the error message from the payload and check for phrases that indicate context length exceeded
-                let payload_str = payload_text.to_lowercase();
-                let check_phrases = [
-                    "too long",
-                    "context length",
-                    "context_length_exceeded",
-                    "reduce the length",
-                    "token count",
-                    "exceeds",
-                    "exceed context limit",
-                    "input length",
-                    "max_tokens",
-                    "decrease input length",
-                    "context limit",
-                ];
-                if check_phrases.iter().any(|c| payload_str.contains(c)) {
-                    return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string()));
-                }
-
-                // Try to parse a clean error message from the response
-                let error_msg = if let Ok(json) = serde_json::from_str::<Value>(&payload_text) {
-                    json.get("message")
-                        .and_then(|m| m.as_str())
-                        .map(|s| s.to_string())
-                        .or_else(|| {
-                            json.get("external_model_message")
-                                .and_then(|ext| ext.get("message"))
-                                .and_then(|m| m.as_str())
-                                .map(|s| s.to_string())
-                        })
-                        .unwrap_or_else(|| "Bad request".to_string())
-                } else {
-                    "Bad request".to_string()
-                };
-
-                tracing::debug!(
-                    "Provider request failed with status: {}. Response: {}",
-                    status,
-                    payload_text
-                );
-                Err(ProviderError::RequestFailed(format!(
-                    "Request failed: {}",
-                    error_msg
-                )))
-            }
-            StatusCode::TOO_MANY_REQUESTS => Err(ProviderError::RateLimitExceeded(
-                "Rate limit exceeded. Please try again later.".to_string(),
-            )),
-            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
-                Err(ProviderError::ServerError(
-                    "Snowflake service is temporarily unavailable. Please try again later."
-                        .to_string(),
-                ))
-            }
-            _ => {
-                tracing::debug!(
-                    "Provider request failed with status: {}. Response: {}",
-                    status,
-                    payload_text
-                );
-                Err(ProviderError::RequestFailed(format!(
-                    "Request failed with status: {}",
-                    status
-                )))
-            }
+        if status.is_success() {
+            Ok(answer_payload)
+        } else {
+            let error_json = serde_json::from_str::<Value>(&payload_text).ok();
+            Err(map_http_error_to_provider_error(status, error_json))
         }
     }
 }
@@ -422,7 +310,12 @@ impl Provider for SnowflakeProvider {
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         let payload = create_request(&self.model, system, messages, tools)?;
 
-        let response = self.post(&payload).await?;
+        let response = self
+            .with_retry(|| async {
+                let payload_clone = payload.clone();
+                self.post(&payload_clone).await
+            })
+            .await?;
 
         // Parse response
         let message = response_to_message(&response)?;
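[Reviewer note] The `complete` change above wraps `post` in `with_retry`; because the closure can be invoked once per attempt, it must produce a fresh future each time. The inner `payload_clone` binding looks redundant, though: the future can simply borrow `payload`, which is the tighter form the venice change later in this patch uses. A behavior-equivalent sketch, assuming `with_retry` accepts an `FnMut` closure returning a future:

    // Equivalent, without the intermediate async block and clone:
    let response = self.with_retry(|| self.post(&payload)).await?;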
diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs
index abfa9f5a4555..7f4a1297b84c 100644
--- a/crates/goose/src/providers/utils.rs
+++ b/crates/goose/src/providers/utils.rs
@@ -44,6 +44,74 @@ pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value
     }
 }
 
+fn check_context_length_exceeded(text: &str) -> bool {
+    let check_phrases = [
+        "too long",
+        "context length",
+        "context_length_exceeded",
+        "reduce the length",
+        "token count",
+        "exceeds",
+        "exceed context limit",
+        "input length",
+        "max_tokens",
+        "decrease input length",
+        "context limit",
+    ];
+    let text_lower = text.to_lowercase();
+    check_phrases
+        .iter()
+        .any(|phrase| text_lower.contains(phrase))
+}
+
+#[allow(clippy::cognitive_complexity)]
+pub fn map_http_error_to_provider_error(
+    status: StatusCode,
+    payload: Option<Value>,
+) -> ProviderError {
+    match status {
+        StatusCode::OK => unreachable!("Should not call this function with OK status"),
+        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
+            ProviderError::Authentication(format!(
+                "Authentication failed. Please ensure your API keys are valid and have the required permissions. \
+                Status: {}. Response: {:?}", status, payload
+            ))
+        }
+        StatusCode::BAD_REQUEST => {
+            let mut error_msg = "Unknown error".to_string();
+            if let Some(payload) = &payload {
+                let payload_str = payload.to_string();
+                if check_context_length_exceeded(&payload_str) {
+                    return ProviderError::ContextLengthExceeded(payload_str);
+                }
+
+                if let Some(error) = payload.get("error") {
+                    error_msg = error.get("message")
+                        .and_then(|m| m.as_str())
+                        .unwrap_or("Unknown error")
+                        .to_string();
+                }
+            }
+            tracing::debug!(
+                "Provider request failed with status: {}. Payload: {:?}", status, payload
+            );
+            ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))
+        }
+        StatusCode::TOO_MANY_REQUESTS => {
+            ProviderError::RateLimitExceeded(format!("{:?}", payload))
+        }
+        StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
+            ProviderError::ServerError(format!("{:?}", payload))
+        }
+        _ => {
+            tracing::debug!(
+                "Provider request failed with status: {}. Payload: {:?}", status, payload
+            );
+            ProviderError::RequestFailed(format!("Request failed with status: {}", status))
+        }
+    }
+}
+
 /// Handle response from OpenAI compatible endpoints
 /// Error codes: https://platform.openai.com/docs/guides/error-codes
 /// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543
@@ -54,36 +122,31 @@ pub async fn handle_status_openai_compat(response: Response) -> Result<Response, ProviderError>
         StatusCode::OK => Ok(response),
         _ => {
             let body = response.json::<Value>().await;
-            match (body, status) {
-                (Err(e), _) => Err(ProviderError::RequestFailed(e.to_string())),
-                (Ok(body), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) => {
-                    Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
-                        Status: {}. Response: {:?}", status, body)))
-                }
-                (Ok(body), StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) => {
-                    tracing::debug!(
-                        "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body)
-                    );
-                    if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(body) {
-                        let err = err_resp.error;
-                        if err.is_context_length_exceeded() {
-                            return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string())));
+            match body {
+                Err(e) => Err(ProviderError::RequestFailed(e.to_string())),
+                Ok(body) => {
+                    let error = if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND)
+                    {
+                        if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(body.clone()) {
+                            let err = err_resp.error;
+                            if err.is_context_length_exceeded() {
+                                ProviderError::ContextLengthExceeded(
+                                    err.message.unwrap_or("Unknown error".to_string()),
+                                )
+                            } else {
+                                ProviderError::RequestFailed(format!(
+                                    "{} (status {})",
+                                    err,
+                                    status.as_u16()
+                                ))
+                            }
+                        } else {
+                            map_http_error_to_provider_error(status, Some(body))
                         }
-                        return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16())));
-                    }
-                    Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status)))
-                }
-                (Ok(body), StatusCode::TOO_MANY_REQUESTS) => {
-                    Err(ProviderError::RateLimitExceeded(format!("{:?}", body)))
-                }
-                (Ok(body), StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE) => {
-                    Err(ProviderError::ServerError(format!("{:?}", body)))
-                }
-                (Ok(body), _) => {
-                    tracing::debug!(
-                        "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, body)
-                    );
-                    Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
+                    } else {
+                        map_http_error_to_provider_error(status, Some(body))
+                    };
+                    Err(error)
                 }
             }
         }
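[Reviewer note] `map_http_error_to_provider_error` treats `StatusCode::OK` as unreachable, so callers must branch on success before calling it; note also that non-200 2xx codes would fall through to the generic `RequestFailed` arm. The intended call-site pattern, mirroring the snowflake and venice sites in this patch (`ProviderError` and the mapper are assumed to be in scope):

    // Sketch of the expected call-site shape.
    async fn handle(response: reqwest::Response) -> Result<serde_json::Value, ProviderError> {
        let status = response.status();
        let body_text = response.text().await.ok().unwrap_or_default();
        if status.is_success() {
            serde_json::from_str(&body_text)
                .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)))
        } else {
            // Mapper expects the error body, if any, pre-parsed as JSON.
            let error_json = serde_json::from_str::<serde_json::Value>(&body_text).ok();
            Err(map_http_error_to_provider_error(status, error_json))
        }
    }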
diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs
index 67acd4483f09..18b08aa4f26c 100644
--- a/crates/goose/src/providers/venice.rs
+++ b/crates/goose/src/providers/venice.rs
@@ -1,13 +1,14 @@
 use anyhow::Result;
 use async_trait::async_trait;
 use chrono::Utc;
-use reqwest::{Client, Response};
-use serde::{Deserialize, Serialize};
+use serde::Serialize;
 use serde_json::{json, Value};
-use std::time::Duration;
 
+use super::api_client::{ApiClient, AuthMethod};
 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
+use super::retry::ProviderRetry;
+use super::utils::map_http_error_to_provider_error;
 use crate::impl_provider_default;
 use crate::message::{Message, MessageContent};
 use crate::model::ModelConfig;
@@ -70,14 +71,12 @@ const FALLBACK_MODELS: [&str; 3] = [
     "mistral-31-24b", // Another model with function calling
 ];
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize)]
 pub struct VeniceProvider {
     #[serde(skip)]
-    client: Client,
-    host: String,
+    api_client: ApiClient,
     base_path: String,
     models_path: String,
-    api_key: String,
     model: ModelConfig,
 }
 
@@ -100,47 +99,21 @@ impl VeniceProvider {
         // Ensure we only keep the bare model id internally
         model.model_name = strip_flags(&model.model_name).to_string();
 
-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        let auth = AuthMethod::BearerToken(api_key);
+        let api_client = ApiClient::new(host, auth)?;
 
         let instance = Self {
-            client,
-            host,
+            api_client,
             base_path,
             models_path,
-            api_key,
             model,
         };
 
         Ok(instance)
     }
 
-    async fn post(&self, path: &str, body: &str) -> Result<Response, ProviderError> {
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let url = base_url
-            .join(path)
-            .map_err(|e| ProviderError::RequestFailed(format!("Failed to construct URL: {e}")))?;
-        // Choose GET for models endpoint, POST otherwise
-        let method = if path.contains("models") {
-            tracing::debug!("Using GET method for models endpoint");
-            self.client.get(url.clone())
-        } else {
-            tracing::debug!("Using POST method for completions endpoint");
-            self.client.post(url.clone())
-        };
-
-        // Log the request details
-        tracing::debug!("Venice request URL: {}", url);
-        tracing::debug!("Venice request body: {}", body);
-
-        let response = method
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .header("Content-Type", "application/json")
-            .body(body.to_string())
-            .send()
-            .await?;
+    async fn post(&self, path: &str, payload: &Value) -> Result<Value, ProviderError> {
+        let response = self.api_client.response_post(path, payload).await?;
 
         let status = response.status();
         tracing::debug!("Venice response status: {}", status);
@@ -193,24 +166,20 @@ impl VeniceProvider {
                 }
             }
         }
-
-        // General error extraction
-        if let Some(error_msg) = json.get("error").and_then(|e| e.as_str()) {
-            return Err(ProviderError::RequestFailed(format!(
-                "Venice API error: {}",
-                error_msg
-            )));
-        }
         }
 
-        // Fallback for unparseable errors
-        return Err(ProviderError::RequestFailed(format!(
-            "Venice API request failed with status code {}",
-            status
-        )));
+        // Use the common error mapping function
+        let error_json = serde_json::from_str::<Value>(&error_body).ok();
+        return Err(map_http_error_to_provider_error(status, error_json));
         }
 
-        Ok(response)
+        let response_text = response.text().await?;
+        serde_json::from_str(&response_text).map_err(|e| {
+            ProviderError::RequestFailed(format!(
+                "Failed to parse JSON: {}\nResponse: {}",
+                e, response_text
+            ))
+        })
     }
 }
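[Reviewer note] `post` now returns a parsed `serde_json::Value` instead of a raw `reqwest::Response`, so callers work directly with JSON. A hypothetical call inside this provider, with the model id taken from the `FALLBACK_MODELS` list above:

    let body = json!({
        "model": "mistral-31-24b",
        "messages": [{ "role": "user", "content": "ping" }],
    });
    let response_json = self.post(&self.base_path, &body).await?;
    if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
        tracing::debug!("Venice returned {} choices", choices.len());
    }

One consequence worth noting: since `post` now parses as well as sends, a malformed success body surfaces as `RequestFailed` and may be retried by `with_retry` alongside transport errors, depending on the retry policy.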
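[Reviewer note] The hunk just below renames `fetch_supported_models_async` to `fetch_supported_models` and delegates transport to `ApiClient::response_get`. A minimal caller sketch, assuming the trait method shape shown in that hunk:

    if let Some(mut models) = provider.fetch_supported_models().await? {
        models.sort();
        for id in models {
            println!("{}", id);
        }
    }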
@@ -247,28 +216,9 @@ impl Provider for VeniceProvider {
         self.model.clone()
     }
 
-    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
-        // Fetch supported models via Venice API
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?;
-        let models_url = base_url.join(&self.models_path).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct models URL: {}", e))
-        })?;
-        let response = self
-            .client
-            .get(models_url)
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .send()
-            .await?;
-        if !response.status().is_success() {
-            return Err(ProviderError::RequestFailed(format!(
-                "Venice API request failed with status {}",
-                response.status()
-            )));
-        }
-        let body = response.text().await?;
-        let json: serde_json::Value = serde_json::from_str(&body)
-            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)))?;
+    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
+        let response = self.api_client.response_get(&self.models_path).await?;
+        let json: serde_json::Value = response.json().await?;
 
         // Print legend once so users know what flags mean
         println!(
@@ -471,17 +421,13 @@ impl Provider for VeniceProvider {
         tracing::debug!("Sending request to Venice API");
         tracing::debug!("Venice request payload: {}", payload.to_string());
 
-        // Send request
-        let response = self.post(&self.base_path, &payload.to_string()).await?;
+        // Send request with retry
+        let response = self
+            .with_retry(|| self.post(&self.base_path, &payload))
+            .await?;
 
-        // Parse the response
-        let response_text = response.text().await?;
-        let response_json: Value = serde_json::from_str(&response_text).map_err(|e| {
-            ProviderError::RequestFailed(format!(
-                "Failed to parse JSON: {}\nResponse: {}",
-                e, response_text
-            ))
-        })?;
+        // Parse the response - response is already a Value from our post method
+        let response_json = response;
 
         // Handle tool calls from the response if present
         let tool_calls = response_json["choices"]
diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs
index 3b8596a63284..e1462f71664a 100644
--- a/crates/goose/src/providers/xai.rs
+++ b/crates/goose/src/providers/xai.rs
@@ -1,28 +1,27 @@
+use super::api_client::{ApiClient, AuthMethod};
 use super::errors::ProviderError;
+use super::retry::ProviderRetry;
+use super::utils::{get_model, handle_response_openai_compat};
 use crate::impl_provider_default;
 use crate::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
-use crate::providers::utils::get_model;
 use anyhow::Result;
 use async_trait::async_trait;
-use reqwest::{Client, StatusCode};
 use rmcp::model::Tool;
 use serde_json::Value;
-use std::time::Duration;
-use url::Url;
 
 pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
 pub const XAI_DEFAULT_MODEL: &str = "grok-3";
 pub const XAI_KNOWN_MODELS: &[&str] = &[
+    "grok-4-0709",
     "grok-3",
     "grok-3-fast",
     "grok-3-mini",
     "grok-3-mini-fast",
     "grok-2-vision-1212",
     "grok-2-image-1212",
-    "grok-2-1212",
     "grok-3-latest",
"grok-3-fast-latest", "grok-3-mini-latest", @@ -40,9 +39,7 @@ pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview"; #[derive(serde::Serialize)] pub struct XaiProvider { #[serde(skip)] - client: Client, - host: String, - api_key: String, + api_client: ApiClient, model: ModelConfig, } @@ -56,67 +53,21 @@ impl XaiProvider { .get_param("XAI_HOST") .unwrap_or_else(|_| XAI_API_HOST.to_string()); - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + let auth = AuthMethod::BearerToken(api_key); + let api_client = ApiClient::new(host, auth)?; - Ok(Self { - client, - host, - api_key, - model, - }) + Ok(Self { api_client, model }) } - async fn post(&self, payload: &Value) -> anyhow::Result { - // Ensure the host ends with a slash for proper URL joining - let host = if self.host.ends_with('/') { - self.host.clone() - } else { - format!("{}/", self.host) - }; - let base_url = Url::parse(&host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("chat/completions").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; - - tracing::debug!("xAI API URL: {}", url); + async fn post(&self, payload: Value) -> Result { tracing::debug!("xAI request model: {:?}", self.model.model_name); let response = self - .client - .post(url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&payload) - .send() + .api_client + .response_post("chat/completions", &payload) .await?; - let status = response.status(); - let payload: Option = response.json().await.ok(); - - match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::PAYLOAD_TOO_LARGE => { - Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) - } - } + handle_response_openai_compat(response).await } } @@ -150,7 +101,7 @@ impl Provider for XaiProvider { system: &str, messages: &[Message], tools: &[Tool], - ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { + ) -> Result<(Message, ProviderUsage), ProviderError> { let payload = create_request( &self.model, system, @@ -159,7 +110,7 @@ impl Provider for XaiProvider { &super::utils::ImageFormat::OpenAi, )?; - let response = self.post(&payload).await?; + let response = self.with_retry(|| self.post(payload.clone())).await?; let message = response_to_message(&response)?; let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {