diff --git a/Cargo.lock b/Cargo.lock index 9541e0b7dbf9..afd0ed58cdef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8696,8 +8696,10 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ + "async-compression", "bitflags 2.9.0", "bytes", + "futures-core", "futures-util", "http 1.2.0", "http-body 1.0.1", diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index eaf6c30b9384..3b54d9889f9e 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -19,7 +19,7 @@ axum = { version = "0.8.1", features = ["ws", "macros"] } tokio = { version = "1.43", features = ["full"] } chrono = "0.4" tokio-cron-scheduler = "0.14.0" -tower-http = { version = "0.5", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors", "compression-gzip", "compression-br"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" futures = "0.3" diff --git a/crates/goose-server/src/commands/agent.rs b/crates/goose-server/src/commands/agent.rs index 5fdfa89ae2ee..be13a579390e 100644 --- a/crates/goose-server/src/commands/agent.rs +++ b/crates/goose-server/src/commands/agent.rs @@ -7,6 +7,7 @@ use etcetera::{choose_app_strategy, AppStrategy}; use goose::agents::Agent; use goose::config::APP_STRATEGY; use goose::scheduler_factory::SchedulerFactory; +use tower_http::compression::CompressionLayer; use tower_http::cors::{Any, CorsLayer}; use tracing::info; @@ -50,7 +51,12 @@ pub async fn run() -> Result<()> { .allow_methods(Any) .allow_headers(Any); - let app = crate::routes::configure(app_state).layer(cors); + // Add compression middleware for gzip and brotli + let compression = CompressionLayer::new().gzip(true).br(true); + + let app = crate::routes::configure(app_state) + .layer(cors) + .layer(compression); let listener = tokio::net::TcpListener::bind(settings.socket_addr()).await?; info!("listening on {}", listener.local_addr()?); diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index e21963a062a7..15f4f3cfd5de 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -333,10 +333,18 @@ pub struct PricingResponse { pub source: String, } +#[derive(Deserialize, ToSchema)] +pub struct ModelRequest { + pub provider: String, + pub model: String, +} + #[derive(Deserialize, ToSchema)] pub struct PricingQuery { /// If true, only return pricing for configured providers. If false, return all. pub configured_only: Option<bool>, /// Specific models to fetch pricing for. If provided, only these models will be returned.
+ pub models: Option<Vec<ModelRequest>>, } #[utoipa::path( @@ -355,6 +363,7 @@ pub async fn get_pricing( verify_secret_key(&headers, &state)?; let configured_only = query.configured_only.unwrap_or(true); + let has_specific_models = query.models.is_some(); // If refresh requested (configured_only = false), refresh the cache if !configured_only { @@ -365,7 +374,49 @@ let mut pricing_data = Vec::new(); - if !configured_only { + // If specific models are requested, fetch only those + if let Some(requested_models) = query.models { + for model_req in requested_models { + // Try to get pricing from cache + if let Some(pricing) = get_model_pricing(&model_req.provider, &model_req.model).await { + pricing_data.push(PricingData { + provider: model_req.provider, + model: model_req.model, + input_token_cost: pricing.input_cost, + output_token_cost: pricing.output_cost, + currency: "$".to_string(), + context_length: pricing.context_length, + }); + } + // Check if the model has embedded pricing data from provider metadata + else if let Some(metadata) = get_providers() + .iter() + .find(|p| p.name == model_req.provider) + { + if let Some(model_info) = metadata + .known_models + .iter() + .find(|m| m.name == model_req.model) + { + if let (Some(input_cost), Some(output_cost)) = + (model_info.input_token_cost, model_info.output_token_cost) + { + pricing_data.push(PricingData { + provider: model_req.provider, + model: model_req.model, + input_token_cost: input_cost, + output_token_cost: output_cost, + currency: model_info + .currency + .clone() + .unwrap_or_else(|| "$".to_string()), + context_length: Some(model_info.context_limit as u32), + }); + } + } + } + } + } else if !configured_only { // Get ALL pricing data from the cache let all_pricing = get_all_pricing().await; @@ -425,7 +476,9 @@ pub async fn get_pricing( tracing::debug!( "Returning pricing for {} models{}", pricing_data.len(), - if configured_only { + if has_specific_models { + " (specific models requested)" + } else if configured_only { " (configured providers only)" } else { " (all cached models)" diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index c8a9929895ae..0d996b832f6b 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -115,3 +115,7 @@ path = "examples/async_token_counter_demo.rs" [[bench]] name = "tokenization_benchmark" harness = false + +[[bench]] +name = "connection_pooling" +harness = false diff --git a/crates/goose/benches/connection_pooling.rs b/crates/goose/benches/connection_pooling.rs new file mode 100644 index 000000000000..1fcde3f9a91d --- /dev/null +++ b/crates/goose/benches/connection_pooling.rs @@ -0,0 +1,102 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use goose::providers::provider_common::{create_provider_client, get_shared_client}; +use reqwest::Client; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn create_new_clients(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + c.bench_function("create_new_client", |b| { + b.iter(|| { + rt.block_on(async { + let _client = black_box(create_provider_client(Some(600)).unwrap()); + }) + }) + }); +} + +fn reuse_shared_client(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + c.bench_function("get_shared_client", |b| { + b.iter(|| { + rt.block_on(async { + let _client = black_box(get_shared_client()); + }) + }) + }); +} + +fn concurrent_requests_new_clients(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let mut group =
c.benchmark_group("concurrent_requests_new"); + for num_requests in [10, 50, 100].iter() { + group.bench_with_input( + BenchmarkId::from_parameter(num_requests), + num_requests, + |b, &num_requests| { + b.iter(|| { + rt.block_on(async { + let tasks: Vec<_> = (0..num_requests) + .map(|_| { + tokio::spawn(async move { + let client = create_provider_client(Some(600)).unwrap(); + // Simulate a request (without actually making one) + black_box(&client); + }) + }) + .collect(); + + for task in tasks { + task.await.unwrap(); + } + }) + }) + }, + ); + } + group.finish(); +} + +fn concurrent_requests_shared_client(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let mut group = c.benchmark_group("concurrent_requests_shared"); + for num_requests in [10, 50, 100].iter() { + group.bench_with_input( + BenchmarkId::from_parameter(num_requests), + num_requests, + |b, &num_requests| { + b.iter(|| { + rt.block_on(async { + let tasks: Vec<_> = (0..num_requests) + .map(|_| { + tokio::spawn(async move { + let client = get_shared_client(); + // Simulate a request (without actually making one) + black_box(&client); + }) + }) + .collect(); + + for task in tasks { + task.await.unwrap(); + } + }) + }) + }, + ); + } + group.finish(); +} + +criterion_group!( + benches, + create_new_clients, + reuse_shared_client, + concurrent_requests_new_clients, + concurrent_requests_shared_client +); +criterion_main!(benches);
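Note for reviewers: the benchmark above (and every provider below) calls into `goose::providers::provider_common`, which this diff never shows. A minimal sketch of the two client helpers, consistent with the call sites — `create_provider_client(Some(600))` building a fresh client per call, and `get_shared_client()` returning a process-wide client behind a `OnceLock` so all providers share one connection pool — might look like the following; the pool settings and error type are assumptions, not part of this PR:

```rust
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use reqwest::Client;

static SHARED_CLIENT: OnceLock<Arc<Client>> = OnceLock::new();

/// Build a fresh client per call (what the providers did before this change).
pub fn create_provider_client(timeout_secs: Option<u64>) -> anyhow::Result<Arc<Client>> {
    let client = Client::builder()
        .timeout(Duration::from_secs(timeout_secs.unwrap_or(600)))
        .build()?;
    Ok(Arc::new(client))
}

/// Return the process-wide client, building it on first use so every
/// provider shares one connection pool (and its keep-alive connections).
pub fn get_shared_client() -> Arc<Client> {
    SHARED_CLIENT
        .get_or_init(|| {
            Arc::new(
                Client::builder()
                    .timeout(Duration::from_secs(600))
                    .pool_idle_timeout(Duration::from_secs(90)) // assumed setting
                    .build()
                    .expect("failed to build shared HTTP client"),
            )
        })
        .clone()
}
```

If the module has roughly this shape, the new benchmark should compare the two paths via `cargo bench --bench connection_pooling`.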
diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 88a71b0f145c..03c13ac2426a 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -3,11 +3,15 @@ use async_trait::async_trait; use axum::http::HeaderMap; use reqwest::{Client, StatusCode}; use serde_json::Value; -use std::time::Duration; +use std::sync::Arc; use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::formats::anthropic::{create_request, get_usage, response_to_message}; +use super::provider_common::{ + build_endpoint_url, get_shared_client, retry_with_backoff, AuthType, HeaderBuilder, + ProviderConfigBuilder, RetryConfig, +}; use super::utils::{emit_debug_trace, get_model}; use crate::message::Message; use crate::model::ModelConfig; @@ -32,10 +36,12 @@ pub const ANTHROPIC_API_VERSION: &str = "2023-06-01"; #[derive(serde::Serialize)] pub struct AnthropicProvider { #[serde(skip)] - client: Client, + client: Arc<Client>, host: String, api_key: String, model: ModelConfig, + #[serde(skip)] + retry_config: RetryConfig, } impl Default for AnthropicProvider { @@ -48,76 +54,82 @@ impl AnthropicProvider { pub fn from_env(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); - let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?; - let host: String = config - .get_param("ANTHROPIC_HOST") - .unwrap_or_else(|_| "https://api.anthropic.com".to_string()); + let config_builder = ProviderConfigBuilder::new(config, "ANTHROPIC"); + + let api_key = config_builder.get_api_key()?; + let host = config_builder.get_host("https://api.anthropic.com"); - let client = Client::builder() .timeout(Duration::from_secs(600)) .build()?; + // Use shared client for better connection pooling + let client = get_shared_client(); + + // Configure retry settings + let retry_config = RetryConfig { + max_retries: 3, + initial_delay_ms: 1000, + max_delay_ms: 32000, + backoff_multiplier: 2.0, + }; Ok(Self { client, host, api_key, model, +
retry_config, }) } async fn post(&self, headers: HeaderMap, payload: Value) -> Result<Value, ProviderError> { - let base_url = url::Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("v1/messages").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + let url = build_endpoint_url(&self.host, "v1/messages")?; - let response = self - .client - .post(url) - .headers(headers) - .json(&payload) - .send() - .await?; + retry_with_backoff(&self.retry_config, || async { + let response = self + .client + .post(url.clone()) + .headers(headers.clone()) + .json(&payload) + .send() + .await?; - let status = response.status(); - let payload: Option<Value> = response.json().await.ok(); + let status = response.status(); + let payload: Option<Value> = response.json().await.ok(); - // https://docs.anthropic.com/en/api/errors - match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::BAD_REQUEST => { - let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { - if let Some(error) = payload.get("error") { - tracing::debug!("Bad Request Error: {error:?}"); - error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string(); - if error_msg.to_lowercase().contains("too long") || error_msg.to_lowercase().contains("too many") { - return Err(ProviderError::ContextLengthExceeded(error_msg.to_string())); - } - }} - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) + // https://docs.anthropic.com/en/api/errors + match status { + StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", status, payload))) + } + StatusCode::BAD_REQUEST => { + let mut error_msg = "Unknown error".to_string(); + if let Some(payload) = &payload { + if let Some(error) = payload.get("error") { + tracing::debug!("Bad Request Error: {error:?}"); + error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string(); + if error_msg.to_lowercase().contains("too long") || error_msg.to_lowercase().contains("too many") { + return Err(ProviderError::ContextLengthExceeded(error_msg.to_string())); + } + }} + tracing::debug!( + "{}", format!("Provider request failed with status: {}.
Payload: {:?}", status, payload) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))) + } + StatusCode::TOO_MANY_REQUESTS => { + Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) + } + StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { + Err(ProviderError::ServerError(format!("{:?}", payload))) + } + _ => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) + } } - } + }).await } } @@ -171,24 +183,35 @@ impl Provider for AnthropicProvider { ) -> Result<(Message, ProviderUsage), ProviderError> { let payload = create_request(&self.model, system, messages, tools)?; - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("x-api-key", self.api_key.parse().unwrap()); - headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap()); + // Build headers using the new HeaderBuilder + let mut header_builder = HeaderBuilder::new( + self.api_key.clone(), + AuthType::Custom("x-api-key".to_string()), + ); + header_builder = header_builder.add_custom_header( + "anthropic-version".to_string(), + ANTHROPIC_API_VERSION.to_string(), + ); let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); - if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled { - // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta - headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap()); - } - if self.model.model_name.starts_with("claude-3-7-sonnet-") { - // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use - headers.insert( - "anthropic-beta", - "token-efficient-tools-2025-02-19".parse().unwrap(), - ); + if is_thinking_enabled { + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta + header_builder = header_builder.add_custom_header( + "anthropic-beta".to_string(), + "output-128k-2025-02-19".to_string(), + ); + } else { + // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use + header_builder = header_builder.add_custom_header( + "anthropic-beta".to_string(), + "token-efficient-tools-2025-02-19".to_string(), + ); + } } + let headers = header_builder.build(); + // Make request let response = self.post(headers, payload.clone()).await?; diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 51a31c06b957..025892df3c55 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -3,13 +3,16 @@ use async_trait::async_trait; use reqwest::Client; use serde::Serialize; use serde_json::Value; -use std::time::Duration; -use tokio::time::sleep; +use std::sync::Arc; use super::azureauth::AzureAuth; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::provider_common::{ + get_shared_client, retry_with_backoff_and_custom_delay, AuthType, HeaderBuilder, + ProviderConfigBuilder, RetryConfig, +}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::message::Message; use crate::model::ModelConfig; @@ -29,12 +32,13 @@ const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; 
diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 51a31c06b957..025892df3c55 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -3,13 +3,16 @@ use async_trait::async_trait; use reqwest::Client; use serde::Serialize; use serde_json::Value; -use std::time::Duration; -use tokio::time::sleep; +use std::sync::Arc; use super::azureauth::AzureAuth; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::provider_common::{ + get_shared_client, retry_with_backoff_and_custom_delay, AuthType, HeaderBuilder, + ProviderConfigBuilder, RetryConfig, +}; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; use crate::message::Message; use crate::model::ModelConfig; @@ -29,12 +32,13 @@ const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
#[derive(Debug)] pub struct AzureProvider { - client: Client, + client: Arc<Client>, auth: AzureAuth, endpoint: String, deployment_name: String, api_version: String, model: ModelConfig, + retry_config: RetryConfig, } impl Serialize for AzureProvider { fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer, { use serde::ser::SerializeStruct; - let mut state = serializer.serialize_struct("AzureProvider", 3)?; + let mut state = serializer.serialize_struct("AzureProvider", 4)?; state.serialize_field("endpoint", &self.endpoint)?; state.serialize_field("deployment_name", &self.deployment_name)?; state.serialize_field("api_version", &self.api_version)?; + state.serialize_field("model", &self.model)?; state.end() } } @@ -61,19 +66,32 @@ impl Default for AzureProvider { impl AzureProvider { pub fn from_env(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); - let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?; - let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?; - let api_version: String = config - .get_param("AZURE_OPENAI_API_VERSION") - .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); + let config_builder = ProviderConfigBuilder::new(config, "AZURE_OPENAI"); + + let endpoint = config_builder + .get_param("ENDPOINT", None) + .ok_or_else(|| anyhow::anyhow!("AZURE_OPENAI_ENDPOINT is required"))?; + let deployment_name = config_builder + .get_param("DEPLOYMENT_NAME", None) + .ok_or_else(|| anyhow::anyhow!("AZURE_OPENAI_DEPLOYMENT_NAME is required"))?; + let api_version = config_builder + .get_param("API_VERSION", Some(AZURE_DEFAULT_API_VERSION)) + .unwrap_or_else(|| AZURE_DEFAULT_API_VERSION.to_string()); // Try to get API key first, if not found use Azure credential chain let api_key = config.get_secret("AZURE_OPENAI_API_KEY").ok(); let auth = AzureAuth::new(api_key)?; - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + // Use shared client for better connection pooling + let client = get_shared_client(); + + // Configure retry settings with Azure's specific requirements + let retry_config = RetryConfig { + max_retries: DEFAULT_MAX_RETRIES as u32, + initial_delay_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, + max_delay_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, + backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, + }; Ok(Self { client, @@ -82,6 +100,7 @@ impl AzureProvider { deployment_name, api_version, model, + retry_config, }) } @@ -106,110 +125,62 @@ impl AzureProvider { base_url.set_path(&new_path); base_url.set_query(Some(&format!("api-version={}", self.api_version))); - let mut attempts = 0; - let mut last_error = None; - let mut current_delay = DEFAULT_INITIAL_RETRY_INTERVAL_MS; - - loop { - // Check if we've exceeded max retries - if attempts > DEFAULT_MAX_RETRIES { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting", - DEFAULT_MAX_RETRIES - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - - // Get a fresh auth token for each attempt - let auth_token = self.auth.get_token().await.map_err(|e| { - tracing::error!("Authentication error: {:?}", e); - ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e)) - })?; - - let mut request_builder = self.client.post(base_url.clone()); - let token_value = auth_token.token_value.clone(); - - // Set the correct header based on authentication type - match self.auth.credential_type() { -
super::azureauth::AzureCredentials::ApiKey(_) => { - request_builder = request_builder.header("api-key", token_value.clone()); - } - super::azureauth::AzureCredentials::DefaultCredential => { - request_builder = request_builder - .header("Authorization", format!("Bearer {}", token_value.clone())); - } - } - - let response_result = request_builder.json(&payload).send().await; - - match response_result { - Ok(response) => match handle_response_openai_compat(response).await { - Ok(result) => { - return Ok(result); + // Use the enhanced retry logic with custom delay extraction for Azure + retry_with_backoff_and_custom_delay( + &self.retry_config, + || async { + // Get a fresh auth token for each attempt + let auth_token = self.auth.get_token().await.map_err(|e| { + tracing::error!("Authentication error: {:?}", e); + ProviderError::RequestFailed(format!( + "Failed to get authentication token: {}", + e + )) + })?; + + // Build headers using HeaderBuilder + let header_builder = match self.auth.credential_type() { + super::azureauth::AzureCredentials::ApiKey(_) => HeaderBuilder::new( + auth_token.token_value.clone(), + AuthType::Custom("api-key".to_string()), + ), + super::azureauth::AzureCredentials::DefaultCredential => { + HeaderBuilder::new(auth_token.token_value.clone(), AuthType::Bearer) } - Err(ProviderError::RateLimitExceeded(msg)) => { - attempts += 1; - last_error = Some(ProviderError::RateLimitExceeded(msg.clone())); - - let retry_after = - if let Some(secs) = msg.to_lowercase().find("try again in ") { - msg[secs..] - .split_whitespace() - .nth(3) - .and_then(|s| s.parse::<u64>().ok()) - .unwrap_or(0) - } else { - 0 - }; - - let delay = if retry_after > 0 { - Duration::from_secs(retry_after) + }; + + let headers = header_builder.build(); + + let response = self + .client + .post(base_url.clone()) + .headers(headers) + .json(&payload) + .send() + .await?; + + handle_response_openai_compat(response).await + }, + |error| { + // Extract retry-after delay from Azure error messages + match error { + ProviderError::RateLimitExceeded(msg) => { + // Look for "try again in X seconds" pattern + if let Some(pos) = msg.to_lowercase().find("try again in ") { + let rest = &msg[pos + 13..]; // Skip "try again in " + rest.split_whitespace() + .next() + .and_then(|s| s.parse::<u64>().ok()) + .map(|secs| secs * 1000) // Convert to milliseconds } else { - let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); - current_delay = - (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; - Duration::from_millis(delay) - }; - - sleep(delay).await; - continue; - } - Err(e) => { - tracing::error!( - "Error response from Azure OpenAI (attempt {}): {:?}", - attempts + 1, - e - ); - return Err(e); + None + } } - }, - Err(e) => { - tracing::error!( - "Request failed (attempt {}): {:?}\nIs timeout: {}\nIs connect: {}\nIs request: {}", - attempts + 1, - e, - e.is_timeout(), - e.is_connect(), - e.is_request(), - ); - - // For timeout errors, we should retry - if e.is_timeout() { - attempts += 1; - let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); - current_delay = (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; - sleep(Duration::from_millis(delay)).await; - continue; - } - - return Err(ProviderError::RequestFailed(format!( - "Request failed: {}", - e - ))); + _ => None, } - } - } + }, + ) + .await } } diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 31e6cf8b4363..69ba3f1813ab 100644 --- a/crates/goose/src/providers/bedrock.rs +++ 
b/crates/goose/src/providers/bedrock.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::time::Duration; use anyhow::Result; use async_trait::async_trait; @@ -8,10 +7,10 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use mcp_core::Tool; use serde_json::Value; -use tokio::time::sleep; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; +use super::provider_common::{retry_with_backoff, RetryConfig}; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::utils::emit_debug_trace; @@ -35,6 +34,8 @@ pub struct BedrockProvider { #[serde(skip)] client: Client, model: ModelConfig, + #[serde(skip)] + retry_config: RetryConfig, } impl BedrockProvider { @@ -66,7 +67,19 @@ impl BedrockProvider { )?; let client = Client::new(&sdk_config); - Ok(Self { client, model }) + // Configure retry settings for Bedrock's specific requirements + let retry_config = RetryConfig { + max_retries: 10, + initial_delay_ms: 20_000, // 20 seconds + max_delay_ms: 120_000, // 2 minutes + backoff_multiplier: 1.5, // Slower backoff for Bedrock + }; + + Ok(Self { + client, + model, + retry_config, + }) } } @@ -123,21 +136,14 @@ impl Provider for BedrockProvider { request = request.tool_config(to_bedrock_tool_config(tools)?); } - // Retry configuration - const MAX_RETRIES: u32 = 10; - const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds - const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes) - - let mut attempts = 0; - let mut backoff_ms = INITIAL_BACKOFF_MS; - - loop { - attempts += 1; + // Use retry logic for resilience + retry_with_backoff(&self.retry_config, || async { + let result = request.clone().send().await; - match request.clone().send().await { + match result { Ok(response) => { // Successful response, process it and return - return match response.output { + match response.output { Some(bedrock::ConverseOutput::Message(message)) => { let usage = response .usage @@ -166,73 +172,40 @@ impl Provider for BedrockProvider { _ => Err(ProviderError::RequestFailed( "No output from Bedrock".to_string(), )), - }; - } - Err(err) => { - match err.into_service_error() { - ConverseError::ThrottlingException(throttle_err) => { - if attempts > MAX_RETRIES { - // We've exhausted our retries - tracing::error!( - "Failed after {MAX_RETRIES} retries: {:?}", - throttle_err - ); - return Err(ProviderError::RateLimitExceeded(format!( - "Failed to call Bedrock after {MAX_RETRIES} retries: {:?}", - throttle_err - ))); - } - - // Log retry attempt - tracing::warn!( - "Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}", - attempts, - MAX_RETRIES, - backoff_ms, - throttle_err - ); - - // Wait before retry with exponential backoff - sleep(Duration::from_millis(backoff_ms)).await; - - // Calculate next backoff with exponential growth, capped at max - backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); - - // Continue to the next retry attempt - continue; - } - ConverseError::AccessDeniedException(err) => { - return Err(ProviderError::Authentication(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - ConverseError::ValidationException(err) - if err - .message() - .unwrap_or_default() - .contains("Input is too long for requested model.") => - { - return Err(ProviderError::ContextLengthExceeded(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } - ConverseError::ModelErrorException(err) => { - return Err(ProviderError::ExecutionError(format!( - 
"Failed to call Bedrock: {:?}", - err - ))); - } - err => { - return Err(ProviderError::ServerError(format!( - "Failed to call Bedrock: {:?}", - err - ))); - } } } + Err(err) => match err.into_service_error() { + ConverseError::ThrottlingException(throttle_err) => { + tracing::warn!("Bedrock throttling error: {:?}", throttle_err); + Err(ProviderError::RateLimitExceeded(format!( + "Failed to call Bedrock: {:?}", + throttle_err + ))) + } + ConverseError::AccessDeniedException(err) => Err( + ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err)), + ), + ConverseError::ValidationException(err) + if err + .message() + .unwrap_or_default() + .contains("Input is too long for requested model.") => + { + Err(ProviderError::ContextLengthExceeded(format!( + "Failed to call Bedrock: {:?}", + err + ))) + } + ConverseError::ModelErrorException(err) => Err(ProviderError::ExecutionError( + format!("Failed to call Bedrock: {:?}", err), + )), + err => Err(ProviderError::ServerError(format!( + "Failed to call Bedrock: {:?}", + err + ))), + }, } - } + }) + .await } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index fbfd22a0a634..7917eaac6f7e 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -3,21 +3,22 @@ use super::embedding::EmbeddingCapable; use super::errors::ProviderError; use super::formats::databricks::{create_request, get_usage, response_to_message}; use super::oauth; -use super::utils::{get_model, ImageFormat}; +use super::provider_common::{ + build_endpoint_url, get_shared_client, retry_with_backoff, ProviderConfigBuilder, RetryConfig, +}; +use super::utils::{emit_debug_trace, get_model, ImageFormat}; use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; use mcp_core::tool::Tool; use serde_json::json; -use url::Url; use anyhow::Result; use async_trait::async_trait; use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::time::Duration; -use tokio::time::sleep; +use std::sync::Arc; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; @@ -25,16 +26,10 @@ const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; // https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"]; -/// Default timeout for API requests in seconds -const DEFAULT_TIMEOUT_SECS: u64 = 600; -/// Default initial interval for retry (in milliseconds) -const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; -/// Default maximum number of retries -const DEFAULT_MAX_RETRIES: usize = 6; -/// Default retry backoff multiplier -const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; -/// Default maximum interval for retry (in milliseconds) -const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000; +// Databricks specific retry settings +const DATABRICKS_MAX_RETRIES: u32 = 6; +const DATABRICKS_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; +const DATABRICKS_MAX_RETRY_INTERVAL_MS: u64 = 320_000; pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet"; // Databricks can passthrough to a wide range of models, we only provide the default @@ -48,53 +43,6 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[ pub const DATABRICKS_DOC_URL: &str = "https://docs.databricks.com/en/generative-ai/external-models/index.html"; -/// Retry configuration for handling rate limit errors -#[derive(Debug, Clone)] -struct RetryConfig { 
- /// Maximum number of retry attempts - max_retries: usize, - /// Initial interval between retries in milliseconds - initial_interval_ms: u64, - /// Multiplier for backoff (exponential) - backoff_multiplier: f64, - /// Maximum interval between retries in milliseconds - max_interval_ms: u64, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_retries: DEFAULT_MAX_RETRIES, - initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, - backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, - max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, - } - } -} - -impl RetryConfig { - /// Calculate the delay for a specific retry attempt (with jitter) - fn delay_for_attempt(&self, attempt: usize) -> Duration { - if attempt == 0 { - return Duration::from_millis(0); - } - - // Calculate exponential backoff - let exponent = (attempt - 1) as u32; - let base_delay_ms = (self.initial_interval_ms as f64 - * self.backoff_multiplier.powi(exponent as i32)) as u64; - - // Apply max limit - let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms); - - // Add jitter (+/-20% randomness) to avoid thundering herd problem - let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2 - let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64; - - Duration::from_millis(jittered_delay_ms) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DatabricksAuth { Token(String), @@ -124,7 +72,7 @@ impl DatabricksAuth { #[derive(Debug, serde::Serialize)] pub struct DatabricksProvider { #[serde(skip)] - client: Client, + client: Arc<Client>, host: String, auth: DatabricksAuth, model: ModelConfig, @@ -143,6 +91,7 @@ impl Default for DatabricksProvider { impl DatabricksProvider { pub fn from_env(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); + let config_builder = ProviderConfigBuilder::new(config, "DATABRICKS"); // For compatibility for now we check both config and secret for databricks host // but it is not actually a secret value @@ -160,12 +109,25 @@ impl DatabricksProvider { let host = host?; - let client = Client::builder() - .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) - .build()?; - - // Load optional retry configuration from environment - let retry_config = Self::load_retry_config(config); + // Use shared client for better connection pooling + let client = get_shared_client(); + + // Configure retry settings with Databricks' specific requirements + let retry_config = RetryConfig { + max_retries: config_builder + .get_param("MAX_RETRIES", None) + .and_then(|v| v.parse::<u32>().ok()) + .unwrap_or(DATABRICKS_MAX_RETRIES), + initial_delay_ms: config_builder + .get_param("INITIAL_RETRY_INTERVAL_MS", None) + .and_then(|v| v.parse::<u64>().ok()) + .unwrap_or(DATABRICKS_INITIAL_RETRY_INTERVAL_MS), + max_delay_ms: config_builder + .get_param("MAX_RETRY_INTERVAL_MS", None) + .and_then(|v| v.parse::<u64>().ok()) + .unwrap_or(DATABRICKS_MAX_RETRY_INTERVAL_MS), + backoff_multiplier: 2.0, + }; // If we find a databricks token we prefer that if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") { @@ -190,40 +152,6 @@ impl DatabricksProvider { }) } - /// Loads retry configuration from environment variables or uses defaults.
- fn load_retry_config(config: &crate::config::Config) -> RetryConfig { - let max_retries = config - .get_param("DATABRICKS_MAX_RETRIES") - .ok() - .and_then(|v: String| v.parse::<usize>().ok()) - .unwrap_or(DEFAULT_MAX_RETRIES); - - let initial_interval_ms = config - .get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS") - .ok() - .and_then(|v: String| v.parse::<u64>().ok()) - .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS); - - let backoff_multiplier = config - .get_param("DATABRICKS_BACKOFF_MULTIPLIER") - .ok() - .and_then(|v: String| v.parse::<f64>().ok()) - .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER); - - let max_interval_ms = config - .get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS") - .ok() - .and_then(|v: String| v.parse::<u64>().ok()) - .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); - - RetryConfig { - max_retries, - initial_interval_ms, - backoff_multiplier, - max_interval_ms, - } - } - /// Create a new DatabricksProvider with the specified host and token /// /// # Arguments /// @@ -235,9 +163,8 @@ impl DatabricksProvider { /// /// Returns a Result containing the new DatabricksProvider instance pub fn from_params(host: String, api_key: String, model: ModelConfig) -> Result<Self> { - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + // Use shared client for better connection pooling + let client = get_shared_client(); Ok(Self { client, @@ -266,9 +193,6 @@ impl DatabricksProvider { } async fn post(&self, payload: Value) -> Result<Value, ProviderError> { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - // Check if this is an embedding request by looking at the payload structure let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none(); let path = if is_embedding { @@ -279,26 +203,16 @@ format!("serving-endpoints/{}/invocations", self.model.model_name) }; - let url = base_url.join(&path).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + let url = build_endpoint_url(&self.host, &path)?; - // Initialize retry counter - let mut attempts = 0; - let mut last_error = None; + // Use retry logic for resilience + retry_with_backoff(&self.retry_config, || async { + // Get a fresh auth token for each attempt + let auth_header = self.ensure_auth_header().await.map_err(|e| { + tracing::error!("Authentication error: {:?}", e); + ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e)) + })?; - loop { - // Check if we've exceeded max retries - if attempts > 0 && attempts > self.retry_config.max_retries { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting (429)", - self.retry_config.max_retries - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - - let auth_header = self.ensure_auth_header().await?; let response = self .client .post(url.clone()) @@ -308,25 +222,25 @@ .await?; let status = response.status(); - let payload: Option<Value> = response.json().await.ok(); + let response_body: Option<Value> = response.json().await.ok(); match status { StatusCode::OK => { - return payload.ok_or_else(|| { + response_body.ok_or_else(|| { ProviderError::RequestFailed("Response body is not valid JSON".to_string()) - }); + }) } StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - return Err(ProviderError::Authentication(format!( + Err(ProviderError::Authentication(format!( "Authentication failed. 
Please ensure your API keys are valid and have the required permissions. \ Status: {}. Response: {:?}", - status, payload - ))); + status, response_body + ))) } StatusCode::BAD_REQUEST => { // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific // We try to extract the error message from the payload and check for phrases that indicate context length exceeded - let payload_str = serde_json::to_string(&payload) + let payload_str = serde_json::to_string(&response_body) .unwrap_or_default() .to_lowercase(); let check_phrases = [ @@ -347,7 +261,7 @@ impl DatabricksProvider { } let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { + if let Some(payload) = &response_body { // try to convert message to string, if that fails use external_model_message error_msg = payload .get("message") @@ -366,67 +280,35 @@ impl DatabricksProvider { "{}", format!( "Provider request failed with status: {}. Payload: {:?}", - status, payload + status, response_body ) ); - return Err(ProviderError::RequestFailed(format!( + Err(ProviderError::RequestFailed(format!( "Request failed with status: {}. Message: {}", status, error_msg - ))); + ))) } StatusCode::TOO_MANY_REQUESTS => { - attempts += 1; - let error_msg = format!( - "Rate limit exceeded (attempt {}/{}): {:?}", - attempts, self.retry_config.max_retries, payload - ); - tracing::warn!("{}. Retrying after backoff...", error_msg); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::RateLimitExceeded(error_msg)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; - - // Continue to the next retry attempt - continue; + Err(ProviderError::RateLimitExceeded(format!("{:?}", response_body))) } StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - attempts += 1; - let error_msg = format!( - "Server error (attempt {}/{}): {:?}", - attempts, self.retry_config.max_retries, payload - ); - tracing::warn!("{}. Retrying after backoff...", error_msg); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::ServerError(error_msg)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; - - // Continue to the next retry attempt - continue; + Err(ProviderError::ServerError(format!("{:?}", response_body))) } _ => { tracing::debug!( "{}", format!( "Provider request failed with status: {}. 
Payload: {:?}", - status, payload + status, response_body ) ); - return Err(ProviderError::RequestFailed(format!( + Err(ProviderError::RequestFailed(format!( "Request failed with status: {}", status - ))); + ))) } } - } + }).await } } @@ -481,7 +363,7 @@ impl Provider for DatabricksProvider { Err(e) => return Err(e), }; let model = get_model(&response); - super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); + emit_debug_trace(&self.model, &payload, &response, &usage); Ok((message, ProviderUsage::new(model, usage))) } @@ -497,11 +379,7 @@ impl Provider for DatabricksProvider { } async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("api/2.0/serving-endpoints").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + let url = build_endpoint_url(&self.host, "api/2.0/serving-endpoints")?; let auth_header = match self.ensure_auth_header().await { Ok(header) => header, diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index c9f867c41f04..ac1aa88dd1e9 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -23,6 +23,18 @@ pub enum ProviderError { #[error("Usage data error: {0}")] UsageError(String), + + #[error("Timeout error: Request timed out after {0} seconds")] + Timeout(u64), + + #[error("Network error: {0}")] + NetworkError(String), + + #[error("Invalid response: {0}")] + InvalidResponse(String), + + #[error("Configuration error: {0}")] + ConfigurationError(String), } impl From for ProviderError { @@ -33,7 +45,18 @@ impl From for ProviderError { impl From for ProviderError { fn from(error: reqwest::Error) -> Self { - ProviderError::ExecutionError(error.to_string()) + if error.is_timeout() { + // Extract timeout duration if possible from error message + ProviderError::Timeout(600) // Default to our standard timeout + } else if error.is_connect() { + ProviderError::NetworkError(format!("Connection failed: {}", error)) + } else if error.is_decode() { + ProviderError::InvalidResponse(format!("Failed to decode response: {}", error)) + } else if error.is_builder() || error.is_request() { + ProviderError::ConfigurationError(format!("Request configuration error: {}", error)) + } else { + ProviderError::ExecutionError(error.to_string()) + } } } @@ -177,3 +200,43 @@ impl std::fmt::Display for OpenAIError { Ok(()) } } + +/// Trait for parsing provider-specific error responses +pub trait ProviderErrorParser { + /// Parse an error response into a ProviderError + fn parse_error_response(&self, status: StatusCode, response_text: &str) -> ProviderError; + + /// Check if an error indicates context length exceeded + fn is_context_length_error(&self, response_text: &str) -> bool { + response_text.to_lowercase().contains("context") + || response_text.to_lowercase().contains("too long") + || response_text.to_lowercase().contains("too many tokens") + || response_text.to_lowercase().contains("exceeds") + } +} + +/// Default implementation for providers without specific error parsing +pub struct DefaultErrorParser; + +impl ProviderErrorParser for DefaultErrorParser { + fn parse_error_response(&self, status: StatusCode, response_text: &str) -> ProviderError { + let error_msg = format!("API error ({}): {}", status, response_text); + + match status { + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + 
diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index c9f867c41f04..ac1aa88dd1e9 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -23,6 +23,18 @@ pub enum ProviderError { #[error("Usage data error: {0}")] UsageError(String), + + #[error("Timeout error: Request timed out after {0} seconds")] + Timeout(u64), + + #[error("Network error: {0}")] + NetworkError(String), + + #[error("Invalid response: {0}")] + InvalidResponse(String), + + #[error("Configuration error: {0}")] + ConfigurationError(String), } impl From<anyhow::Error> for ProviderError { @@ -33,7 +45,18 @@ impl From<reqwest::Error> for ProviderError { fn from(error: reqwest::Error) -> Self { - ProviderError::ExecutionError(error.to_string()) + if error.is_timeout() { + // Extract timeout duration if possible from error message + ProviderError::Timeout(600) // Default to our standard timeout + } else if error.is_connect() { + ProviderError::NetworkError(format!("Connection failed: {}", error)) + } else if error.is_decode() { + ProviderError::InvalidResponse(format!("Failed to decode response: {}", error)) + } else if error.is_builder() || error.is_request() { + ProviderError::ConfigurationError(format!("Request configuration error: {}", error)) + } else { + ProviderError::ExecutionError(error.to_string()) + } } } @@ -177,3 +200,43 @@ impl std::fmt::Display for OpenAIError { Ok(()) } } + +/// Trait for parsing provider-specific error responses +pub trait ProviderErrorParser { + /// Parse an error response into a ProviderError + fn parse_error_response(&self, status: StatusCode, response_text: &str) -> ProviderError; + + /// Check if an error indicates context length exceeded + fn is_context_length_error(&self, response_text: &str) -> bool { + response_text.to_lowercase().contains("context") + || response_text.to_lowercase().contains("too long") + || response_text.to_lowercase().contains("too many tokens") + || response_text.to_lowercase().contains("exceeds") + } +} + +/// Default implementation for providers without specific error parsing +pub struct DefaultErrorParser; + +impl ProviderErrorParser for DefaultErrorParser { + fn parse_error_response(&self, status: StatusCode, response_text: &str) -> ProviderError { + let error_msg = format!("API error ({}): {}", status, response_text); + + match status { + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + ProviderError::Authentication(error_msg) + } + StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded(error_msg), + StatusCode::BAD_REQUEST => { + if self.is_context_length_error(response_text) { + ProviderError::ContextLengthExceeded(error_msg) + } else { + ProviderError::RequestFailed(error_msg) + } + } + s if s.is_server_error() => ProviderError::ServerError(error_msg), + StatusCode::REQUEST_TIMEOUT => ProviderError::Timeout(600), // seconds (matches the From<reqwest::Error> default), not the 408 status code + _ => ProviderError::RequestFailed(error_msg), + } + } +}
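`ProviderErrorParser` and `DefaultErrorParser` ship with no call sites in this diff. Presumably a provider without bespoke error handling would route non-OK responses through the default parser, along these lines (the `check_response` helper is hypothetical, written only to illustrate the intended use):

```rust
use reqwest::StatusCode;

use crate::providers::errors::{DefaultErrorParser, ProviderError, ProviderErrorParser};

/// Hypothetical call site: decode an OK body, otherwise classify the failure.
async fn check_response(response: reqwest::Response) -> Result<serde_json::Value, ProviderError> {
    let status = response.status();
    if status == StatusCode::OK {
        return response
            .json()
            .await
            .map_err(|e| ProviderError::InvalidResponse(format!("Failed to decode response: {e}")));
    }
    let body = response.text().await.unwrap_or_default();
    // DefaultErrorParser maps 401/403 to Authentication, 429 to RateLimitExceeded,
    // 400 with context-length phrasing to ContextLengthExceeded, and 5xx to ServerError.
    Err(DefaultErrorParser.parse_error_response(status, &body))
}
```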
diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 6385ec299abc..4150988449bc 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -1,10 +1,9 @@ -use std::time::Duration; +use std::sync::Arc; use anyhow::Result; use async_trait::async_trait; use reqwest::{Client, StatusCode}; use serde_json::Value; -use tokio::time::sleep; use url::Url; use crate::message::Message; @@ -19,21 +18,18 @@ use crate::providers::formats::gcpvertexai::{ use crate::providers::formats::gcpvertexai::GcpLocation::Iowa; use crate::providers::gcpauth::GcpAuth; +use crate::providers::provider_common::{ + get_shared_client, retry_with_backoff, ProviderConfigBuilder, RetryConfig, +}; use crate::providers::utils::emit_debug_trace; use mcp_core::tool::Tool; /// Base URL for GCP Vertex AI documentation const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai"; -/// Default timeout for API requests in seconds -const DEFAULT_TIMEOUT_SECS: u64 = 600; -/// Default initial interval for retry (in milliseconds) -const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; -/// Default maximum number of retries -const DEFAULT_MAX_RETRIES: usize = 6; -/// Default retry backoff multiplier -const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; -/// Default maximum interval for retry (in milliseconds) -const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000; +// GCP Vertex AI specific retry settings +const GCP_MAX_RETRIES: u32 = 6; +const GCP_INITIAL_RETRY_INTERVAL_MS: u64 = 5000; +const GCP_MAX_RETRY_INTERVAL_MS: u64 = 320_000; /// Represents errors specific to GCP Vertex AI operations.
#[derive(Debug, thiserror::Error)] @@ -47,53 +43,6 @@ enum GcpVertexAIError { AuthError(String), } -/// Retry configuration for handling rate limit errors -#[derive(Debug, Clone)] -struct RetryConfig { - /// Maximum number of retry attempts - max_retries: usize, - /// Initial interval between retries in milliseconds - initial_interval_ms: u64, - /// Multiplier for backoff (exponential) - backoff_multiplier: f64, - /// Maximum interval between retries in milliseconds - max_interval_ms: u64, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_retries: DEFAULT_MAX_RETRIES, - initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS, - backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER, - max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS, - } - } -} - -impl RetryConfig { - /// Calculate the delay for a specific retry attempt (with jitter) - fn delay_for_attempt(&self, attempt: usize) -> Duration { - if attempt == 0 { - return Duration::from_millis(0); - } - - // Calculate exponential backoff - let exponent = (attempt - 1) as u32; - let base_delay_ms = (self.initial_interval_ms as f64 - * self.backoff_multiplier.powi(exponent as i32)) as u64; - - // Apply max limit - let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms); - - // Add jitter (+/-20% randomness) to avoid thundering herd problem - let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2 - let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64; - - Duration::from_millis(jittered_delay_ms) - } -} - /// Provider implementation for Google Cloud Platform's Vertex AI service. /// /// This provider enables interaction with various AI models hosted on GCP Vertex AI, @@ -103,7 +52,7 @@ impl RetryConfig { pub struct GcpVertexAIProvider { /// HTTP client for making API requests #[serde(skip)] - client: Client, + client: Arc<Client>, /// GCP authentication handler #[serde(skip)] auth: GcpAuth, @@ -146,18 +95,33 @@ impl GcpVertexAIProvider { /// * `model` - Configuration for the model to be used async fn new_async(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); + let config_builder = ProviderConfigBuilder::new(config, "GCP"); + let project_id = config.get_param("GCP_PROJECT_ID")?; let location = Self::determine_location(config)?; let host = format!("https://{}-aiplatform.googleapis.com", location); - let client = Client::builder() - .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) - .build()?; + // Use shared client for better connection pooling + let client = get_shared_client(); let auth = GcpAuth::new().await?; - // Load optional retry configuration from environment - let retry_config = Self::load_retry_config(config); + // Configure retry settings with GCP's specific requirements + let retry_config = RetryConfig { + max_retries: config_builder + .get_param("MAX_RETRIES", None) + .and_then(|v| v.parse::<u32>().ok()) + .unwrap_or(GCP_MAX_RETRIES), + initial_delay_ms: config_builder + .get_param("INITIAL_RETRY_INTERVAL_MS", None) + .and_then(|v| v.parse::<u64>().ok()) + .unwrap_or(GCP_INITIAL_RETRY_INTERVAL_MS), + max_delay_ms: config_builder + .get_param("MAX_RETRY_INTERVAL_MS", None) + .and_then(|v| v.parse::<u64>().ok()) + .unwrap_or(GCP_MAX_RETRY_INTERVAL_MS), + backoff_multiplier: 2.0, + }; Ok(Self { client, @@ -170,40 +134,6 @@ }) } - /// Loads retry configuration from environment variables or uses defaults.
- fn load_retry_config(config: &crate::config::Config) -> RetryConfig { - let max_retries = config - .get_param("GCP_MAX_RETRIES") - .ok() - .and_then(|v: String| v.parse::<usize>().ok()) - .unwrap_or(DEFAULT_MAX_RETRIES); - - let initial_interval_ms = config - .get_param("GCP_INITIAL_RETRY_INTERVAL_MS") - .ok() - .and_then(|v: String| v.parse::<u64>().ok()) - .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS); - - let backoff_multiplier = config - .get_param("GCP_BACKOFF_MULTIPLIER") - .ok() - .and_then(|v: String| v.parse::<f64>().ok()) - .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER); - - let max_interval_ms = config - .get_param("GCP_MAX_RETRY_INTERVAL_MS") - .ok() - .and_then(|v: String| v.parse::<u64>().ok()) - .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); - - RetryConfig { - max_retries, - initial_interval_ms, - backoff_multiplier, - max_interval_ms, - } - } - /// Determines the appropriate GCP location for model deployment. /// /// Location is determined in the following order: @@ -285,21 +215,7 @@ .build_request_url(context.provider(), location) .map_err(|e| ProviderError::RequestFailed(e.to_string()))?; - // Initialize retry counter - let mut attempts = 0; - let mut last_error = None; - - loop { - // Check if we've exceeded max retries - if attempts > 0 && attempts > self.retry_config.max_retries { - let error_msg = format!( - "Exceeded maximum retry attempts ({}) for rate limiting (429)", - self.retry_config.max_retries - ); - tracing::error!("{}", error_msg); - return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); - } - + retry_with_backoff(&self.retry_config, || async { // Get a fresh auth token for each attempt let auth_header = self .get_auth_header() @@ -318,61 +234,46 @@ let status = response.status(); - // If not a 429, process normally - if status != StatusCode::TOO_MANY_REQUESTS { - let response_json = response.json::<Value>().await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to parse response: {e}")) - })?; - - return match status { - StatusCode::OK => Ok(response_json), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - tracing::debug!( - "Authentication failed. Status: {status}, Payload: {payload:?}" - ); - Err(ProviderError::Authentication(format!( - "Authentication failed: {response_json:?}" - ))) - } - _ => { - tracing::debug!( - "Request failed. 
Status: {status}, Response: {response_json:?}" - ); - Err(ProviderError::RequestFailed(format!( - "Request failed with status {status}: {response_json:?}" - ))) - } + // Handle rate limits specially for GCP + if status == StatusCode::TOO_MANY_REQUESTS { + // Try to parse response for more detailed error info + let cite_gcp_vertex_429 = + "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429"; + let response_text = response.text().await.unwrap_or_default(); + let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") { + format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.") + } else { + format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.") }; + + return Err(ProviderError::RateLimitExceeded(quota_error)); } - // Handle 429 Too Many Requests - attempts += 1; - - // Try to parse response for more detailed error info - let cite_gcp_vertex_429 = - "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429"; - let response_text = response.text().await.unwrap_or_default(); - let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") { - format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.") - } else { - format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.") - }; - - tracing::warn!( - "Rate limit exceeded (attempt {}/{}): {}. Retrying after backoff...", - attempts, - self.retry_config.max_retries, - quota_error - ); - - // Store the error in case we need to return it after max retries - last_error = Some(ProviderError::RateLimitExceeded(quota_error)); - - // Calculate and apply the backoff delay - let delay = self.retry_config.delay_for_attempt(attempts); - tracing::info!("Backing off for {:?} before retry", delay); - sleep(delay).await; - } + let response_json = response.json::<Value>().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse response: {e}")) + })?; + + match status { + StatusCode::OK => Ok(response_json), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + tracing::debug!( + "Authentication failed. Status: {status}, Payload: {payload:?}" + ); + Err(ProviderError::Authentication(format!( + "Authentication failed: {response_json:?}" + ))) + } + _ => { + tracing::debug!( + "Request failed. Status: {status}, Response: {response_json:?}" + ); + Err(ProviderError::RequestFailed(format!( + "Request failed with status {status}: {response_json:?}" + ))) + } + } + }) + .await } /// Makes an authenticated POST request to the Vertex AI API with fallback for invalid locations.
@@ -460,25 +361,19 @@ "GCP_MAX_RETRIES", false, false, - Some(&DEFAULT_MAX_RETRIES.to_string()), + Some(&GCP_MAX_RETRIES.to_string()), ), ConfigKey::new( "GCP_INITIAL_RETRY_INTERVAL_MS", false, false, - Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()), - ), - ConfigKey::new( - "GCP_BACKOFF_MULTIPLIER", - false, - false, - Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()), + Some(&GCP_INITIAL_RETRY_INTERVAL_MS.to_string()), ), ConfigKey::new( "GCP_MAX_RETRY_INTERVAL_MS", false, false, - Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()), + Some(&GCP_MAX_RETRY_INTERVAL_MS.to_string()), ), ], ) @@ -526,32 +421,6 @@ mod tests { use super::*; - #[test] - fn test_retry_config_delay_calculation() { - let config = RetryConfig { - max_retries: 5, - initial_interval_ms: 1000, - backoff_multiplier: 2.0, - max_interval_ms: 32000, - }; - - // First attempt has no delay - let delay0 = config.delay_for_attempt(0); - assert_eq!(delay0.as_millis(), 0); - - // First retry should be around initial_interval with jitter - let delay1 = config.delay_for_attempt(1); - assert!(delay1.as_millis() >= 800 && delay1.as_millis() <= 1200); - - // Second retry should be around initial_interval * multiplier^1 with jitter - let delay2 = config.delay_for_attempt(2); - assert!(delay2.as_millis() >= 1600 && delay2.as_millis() <= 2400); - - // Check that max interval is respected - let delay10 = config.delay_for_attempt(10); - assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter) - } - #[test] fn test_model_provider_conversion() { assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic"); @@ -596,7 +465,7 @@ mod tests { .collect(); assert!(model_names.contains(&"claude-3-5-sonnet-v2@20241022".to_string())); assert!(model_names.contains(&"gemini-1.5-pro-002".to_string())); - // Should contain the original 2 config keys plus 4 new retry-related ones - assert_eq!(metadata.config_keys.len(), 6); + // Should contain the original 2 config keys plus 3 new retry-related ones + assert_eq!(metadata.config_keys.len(), 5); } }
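One more `provider_common` helper appears only at the Azure call site earlier in this diff: `retry_with_backoff_and_custom_delay`, which takes a second closure that extracts a server-suggested delay ("try again in N seconds") from the error message. A sketch of the shape that call site implies, written as if inside `provider_common`; which errors are retried and the fallback behavior are assumptions:

```rust
// Sketch of the custom-delay retry variant used by the Azure provider.
use std::future::Future;
use std::time::Duration;

use super::errors::ProviderError;
use super::provider_common::RetryConfig; // the struct sketched earlier

pub async fn retry_with_backoff_and_custom_delay<T, F, Fut, D>(
    config: &RetryConfig,
    mut operation: F,
    extract_delay_ms: D,
) -> Result<T, ProviderError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, ProviderError>>,
    D: Fn(&ProviderError) -> Option<u64>,
{
    let mut fallback_ms = config.initial_delay_ms;
    for attempt in 0..=config.max_retries {
        match operation().await {
            Ok(value) => return Ok(value),
            // Assumption: rate limits and timeouts are the retryable cases here.
            Err(e)
                if attempt < config.max_retries
                    && matches!(e, ProviderError::RateLimitExceeded(_) | ProviderError::Timeout(_)) =>
            {
                // Prefer the delay the service asked for; fall back to exponential backoff.
                let delay_ms = extract_delay_ms(&e).unwrap_or(fallback_ms);
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                fallback_ms = ((fallback_ms as f64 * config.backoff_multiplier) as u64).min(config.max_delay_ms);
            }
            Err(e) => return Err(e),
        }
    }
    unreachable!("the final attempt always returns above")
}
```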
         let mu = tokio::sync::Mutex::new(RefCell::new(None));
+        let retry_config = RetryConfig::default();

         Ok(Self {
             client,
             cache,
             mu,
             model,
+            retry_config,
         })
     }

@@ -154,14 +157,19 @@ impl GithubCopilotProvider {
         let (endpoint, token) = self.get_api_info().await?;
         let url = url::Url::parse(&format!("{}/chat/completions", endpoint))
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let response = self
-            .client
-            .post(url)
-            .headers(self.get_github_headers())
-            .header("Authorization", format!("Bearer {}", token))
-            .json(&payload)
-            .send()
-            .await?;
+
+        // Use retry logic for resilience
+        let response = retry_with_backoff(&self.retry_config, || async {
+            self.client
+                .post(url.clone())
+                .headers(self.get_github_headers())
+                .header("Authorization", format!("Bearer {}", token.clone()))
+                .json(&payload)
+                .send()
+                .await
+                .map_err(|e| ProviderError::RequestFailed(e.to_string()))
+        })
+        .await?;

         if stream_only_model {
             let mut collector = OAIStreamCollector::new();
             let mut stream = response.bytes_stream();
diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs
index dbe9b331d160..db2cd384b708 100644
--- a/crates/goose/src/providers/google.rs
+++ b/crates/goose/src/providers/google.rs
@@ -3,17 +3,18 @@ use crate::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
 use crate::providers::formats::google::{create_request, get_usage, response_to_message};
+use crate::providers::provider_common::{
+    build_endpoint_url, get_shared_client, retry_with_backoff, ProviderConfigBuilder, RetryConfig,
+};
 use crate::providers::utils::{
     emit_debug_trace, handle_response_google_compat, unescape_json_values,
 };
 use anyhow::Result;
 use async_trait::async_trait;
-use axum::http::HeaderMap;
 use mcp_core::tool::Tool;
 use reqwest::Client;
 use serde_json::Value;
-use std::time::Duration;
-use url::Url;
+use std::sync::Arc;

 pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
 pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-flash";
@@ -50,9 +51,13 @@ pub const GOOGLE_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs/models";
 #[derive(Debug, serde::Serialize)]
 pub struct GoogleProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
+    #[serde(skip)]
+    api_key: String,
 }

 impl Default for GoogleProvider {
@@ -65,82 +70,46 @@ impl GoogleProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
-        let host: String = config
-            .get_param("GOOGLE_HOST")
-            .unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
+        let config_builder = ProviderConfigBuilder::new(config, "GOOGLE");

-        let mut headers = HeaderMap::new();
-        headers.insert("CONTENT_TYPE", "application/json".parse()?);
-        headers.insert("x-goog-api-key", api_key.parse()?);
+        let api_key = config_builder.get_api_key()?;
+        let host = config_builder.get_host(GOOGLE_API_HOST);

-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .default_headers(headers)
-            .build()?;
+        // Use shared client for better connection pooling
+        let client = get_shared_client();
+
+        // Configure retry settings
+        let retry_config = RetryConfig::default();

         Ok(Self {
             client,
             host,
             model,
+            retry_config,
+            api_key,
         })
     }

     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let base_url = Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-
-        let url = base_url
-            .join(&format!(
-                "v1beta/models/{}:generateContent",
-                self.model.model_name
-            ))
-            .map_err(|e| {
-                ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
-            })?;
-
-        let max_retries = 3;
-        let mut retries = 0;
-        let base_delay = Duration::from_secs(2);
-
-        loop {
+        let path = format!(
+            "v1beta/models/{}:generateContent?key={}",
+            self.model.model_name, self.api_key
+        );
+        let url = build_endpoint_url(&self.host, &path)?;
+
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
             let response = self
                 .client
-                .post(url.clone()) // Clone the URL for each retry
+                .post(url.clone())
+                .header("Content-Type", "application/json")
                 .json(&payload)
                 .send()
-                .await;
-
-            match response {
-                Ok(res) => {
-                    match handle_response_google_compat(res).await {
-                        Ok(result) => return Ok(result),
-                        Err(ProviderError::RateLimitExceeded(_)) => {
-                            retries += 1;
-                            if retries > max_retries {
-                                return Err(ProviderError::RateLimitExceeded(
-                                    "Max retries exceeded for rate limit error".to_string(),
-                                ));
-                            }
-
-                            let delay = 2u64.pow(retries);
-                            let total_delay = Duration::from_secs(delay) + base_delay;
-
-                            println!("Rate limit hit. Retrying in {:?}", total_delay);
-                            tokio::time::sleep(total_delay).await;
-                            continue;
-                        }
-                        Err(err) => return Err(err), // Other errors
-                    }
-                }
-                Err(err) => {
-                    return Err(ProviderError::RequestFailed(format!(
-                        "Request failed: {}",
-                        err
-                    )));
-                }
-            }
-        }
+            .await?;
+
+            handle_response_google_compat(response).await
+        })
+        .await
     }
 }

@@ -195,14 +164,18 @@ impl Provider for GoogleProvider {
     /// Fetch supported models from Google Generative Language API; returns Err on failure, Ok(None) if not present
     async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
         // List models via the v1beta/models endpoint
-        let url = format!("{}/v1beta/models", self.host);
-        let response = self.client.get(&url).send().await?;
+        let path = format!("v1beta/models?key={}", self.api_key);
+        let url = build_endpoint_url(&self.host, &path)?;
+
+        let response = self.client.get(url).send().await?;
         let json: serde_json::Value = response.json().await?;
+
+        // If 'models' field missing, return None
         let arr = match json.get("models").and_then(|v| v.as_array()) {
             Some(arr) => arr,
             None => return Ok(None),
         };
+
         let mut models: Vec<String> = arr
             .iter()
             .filter_map(|m| m.get("name").and_then(|v| v.as_str()))
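Both the Google and Groq changes funnel URL construction through the new build_endpoint_url helper, which wraps Url::parse plus Url::join. One semantic of url::Url worth keeping in mind when reviewing these call sites: joining a relative path onto a base whose path has no trailing slash replaces the last path segment rather than extending it. A small self-contained illustration (example.com hosts are placeholders):

use url::Url;

fn main() {
    // A bare host: join appends as expected.
    let base = Url::parse("https://api.example.com").unwrap();
    assert_eq!(
        base.join("v1beta/models").unwrap().as_str(),
        "https://api.example.com/v1beta/models"
    );

    // Caveat: a base path without a trailing slash is replaced, not extended.
    let base = Url::parse("https://api.example.com/proxy").unwrap();
    assert_eq!(
        base.join("v1beta/models").unwrap().as_str(),
        "https://api.example.com/v1beta/models" // "proxy" segment dropped
    );
}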
"llama-3.3-70b-versatile"; @@ -21,10 +24,12 @@ pub const GROQ_DOC_URL: &str = "https://console.groq.com/docs/models"; #[derive(serde::Serialize)] pub struct GroqProvider { #[serde(skip)] - client: Client, + client: Arc, host: String, api_key: String, model: ModelConfig, + #[serde(skip)] + retry_config: RetryConfig, } impl Default for GroqProvider { @@ -37,63 +42,46 @@ impl Default for GroqProvider { impl GroqProvider { pub fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); - let api_key: String = config.get_secret("GROQ_API_KEY")?; - let host: String = config - .get_param("GROQ_HOST") - .unwrap_or_else(|_| GROQ_API_HOST.to_string()); + let config_builder = ProviderConfigBuilder::new(config, "GROQ"); + + let api_key = config_builder.get_api_key()?; + let host = config_builder.get_host(GROQ_API_HOST); - let client = Client::builder() - .timeout(Duration::from_secs(600)) - .build()?; + // Use shared client for better connection pooling + let client = get_shared_client(); + + // Configure retry settings + let retry_config = RetryConfig::default(); Ok(Self { client, host, api_key, model, + retry_config, }) } - async fn post(&self, payload: Value) -> anyhow::Result { - let base_url = Url::parse(&self.host) - .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; - let url = base_url.join("openai/v1/chat/completions").map_err(|e| { - ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) - })?; + async fn post(&self, payload: Value) -> Result { + let url = build_endpoint_url(&self.host, "openai/v1/chat/completions")?; - let response = self - .client - .post(url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&payload) - .send() - .await?; + // Build headers using the new HeaderBuilder + let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer).build(); - let status = response.status(); - let payload: Option = response.json().await.ok(); + // Use retry logic for resilience + retry_with_backoff(&self.retry_config, || async { + let response = self + .client + .post(url.clone()) + .headers(headers.clone()) + .json(&payload) + .send() + .await?; - match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload))) - } - StatusCode::PAYLOAD_TOO_LARGE => { - Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. 
@@ -155,21 +143,15 @@ impl Provider for GroqProvider {
     /// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found
     async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
         // Construct the Groq models endpoint
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?;
-        let url = base_url.join("openai/v1/models").map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {}", e))
-        })?;
+        let url = build_endpoint_url(&self.host, "openai/v1/models")?;

-        // Build the request with required headers
-        let request = self
-            .client
-            .get(url)
-            .bearer_auth(&self.api_key)
-            .header("Content-Type", "application/json");
+        // Build headers using HeaderBuilder
+        let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer)
+            .add_custom_header("Content-Type".to_string(), "application/json".to_string())
+            .build();

         // Send request
-        let response = request.send().await?;
+        let response = self.client.get(url).headers(headers).send().await?;
         let status = response.status();
         let payload: serde_json::Value = response.json().await.map_err(|_| {
             ProviderError::RequestFailed("Response body is not valid JSON".to_string())
         })
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index a7748044a3bd..2b510467053d 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -21,6 +21,7 @@ pub mod ollama;
 pub mod openai;
 pub mod openrouter;
 pub mod pricing;
+pub mod provider_common;
 pub mod sagemaker_tgi;
 pub mod snowflake;
 pub mod toolshim;
diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs
index 4bbf1c392dae..72c17135e19c 100644
--- a/crates/goose/src/providers/ollama.rs
+++ b/crates/goose/src/providers/ollama.rs
@@ -1,5 +1,8 @@
 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
+use super::provider_common::{
+    get_shared_client, retry_with_backoff, ProviderConfigBuilder, RetryConfig,
+};
 use super::utils::{get_model, handle_response_openai_compat};
 use crate::message::Message;
 use crate::model::ModelConfig;
@@ -9,7 +12,7 @@ use async_trait::async_trait;
 use mcp_core::tool::Tool;
 use reqwest::Client;
 use serde_json::Value;
-use std::time::Duration;
+use std::sync::Arc;
 use url::Url;

 pub const OLLAMA_HOST: &str = "localhost";
@@ -22,9 +25,11 @@ pub const OLLAMA_DOC_URL: &str = "https://ollama.com/library";
 #[derive(serde::Serialize)]
 pub struct OllamaProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }

 impl Default for OllamaProvider {
@@ -37,18 +42,21 @@ impl OllamaProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let host: String = config
-            .get_param("OLLAMA_HOST")
-            .unwrap_or_else(|_| OLLAMA_HOST.to_string());
+        let config_builder = ProviderConfigBuilder::new(config, "OLLAMA");
+
+        let host = config_builder.get_host(OLLAMA_HOST);
+
+        // Use shared client for better connection pooling
+        let client = get_shared_client();

-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        // Configure retry settings
+        let retry_config = RetryConfig::default();

         Ok(Self {
             client,
             host,
             model,
+            retry_config,
         })
     }

@@ -88,9 +96,12 @@ impl OllamaProvider {
             ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
         })?;

-        let response = self.client.post(url).json(&payload).send().await?;
-
-        handle_response_openai_compat(response).await
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self.client.post(url.clone()).json(&payload).send().await?;
+            handle_response_openai_compat(response).await
+        })
+        .await
     }
 }
diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs
index 9884d147bffc..e2a211d6cc1d 100644
--- a/crates/goose/src/providers/openai.rs
+++ b/crates/goose/src/providers/openai.rs
@@ -3,12 +3,16 @@ use async_trait::async_trait;
 use reqwest::Client;
 use serde_json::Value;
 use std::collections::HashMap;
-use std::time::Duration;
+use std::sync::Arc;

 use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
 use super::errors::ProviderError;
 use super::formats::openai::{create_request, get_usage, response_to_message};
+use super::provider_common::{
+    build_endpoint_url, create_provider_client, get_shared_client, retry_with_backoff, AuthType,
+    HeaderBuilder, ProviderConfigBuilder, RetryConfig,
+};
 use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
 use crate::message::Message;
 use crate::model::ModelConfig;
@@ -30,7 +34,7 @@ pub const OPEN_AI_DOC_URL: &str = "https://platform.openai.com/docs/models";
 #[derive(Debug, serde::Serialize)]
 pub struct OpenAiProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     base_path: String,
     api_key: String,
@@ -38,6 +42,8 @@ pub struct OpenAiProvider {
     project: Option<String>,
     model: ModelConfig,
     custom_headers: Option<HashMap<String, String>>,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }

 impl Default for OpenAiProvider {
@@ -50,24 +56,36 @@ impl OpenAiProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("OPENAI_API_KEY")?;
-        let host: String = config
-            .get_param("OPENAI_HOST")
-            .unwrap_or_else(|_| "https://api.openai.com".to_string());
-        let base_path: String = config
-            .get_param("OPENAI_BASE_PATH")
-            .unwrap_or_else(|_| "v1/chat/completions".to_string());
-        let organization: Option<String> = config.get_param("OPENAI_ORGANIZATION").ok();
-        let project: Option<String> = config.get_param("OPENAI_PROJECT").ok();
+        let config_builder = ProviderConfigBuilder::new(config, "OPENAI");
+
+        let api_key = config_builder.get_api_key()?;
+        let host = config_builder.get_host("https://api.openai.com");
+        let base_path = config_builder
+            .get_param("BASE_PATH", Some("v1/chat/completions"))
+            .unwrap_or_else(|| "v1/chat/completions".to_string());
+        let organization = config_builder.get_param("ORGANIZATION", None);
+        let project = config_builder.get_param("PROJECT", None);
+
         let custom_headers: Option<HashMap<String, String>> = config
             .get_secret("OPENAI_CUSTOM_HEADERS")
             .or_else(|_| config.get_param("OPENAI_CUSTOM_HEADERS"))
             .ok()
             .map(parse_custom_headers);
-        let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);
-        let client = Client::builder()
-            .timeout(Duration::from_secs(timeout_secs))
-            .build()?;
+
+        // Check for custom timeout configuration
+        let timeout_secs = config_builder
+            .get_param("TIMEOUT", None)
+            .and_then(|s| s.parse::<u64>().ok());
+
+        // Use provider-specific client if timeout is configured, otherwise use shared client
+        let client = if timeout_secs.is_some() {
+            create_provider_client(timeout_secs)?
+        } else {
+            get_shared_client()
+        };
+
+        // Configure retry settings
+        let retry_config = RetryConfig::default();

         Ok(Self {
             client,
@@ -78,48 +96,50 @@ impl OpenAiProvider {
             project,
             model,
             custom_headers,
+            retry_config,
         })
     }

-    /// Helper function to add OpenAI-specific headers to a request
-    fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
+    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
+        let url = build_endpoint_url(&self.host, &self.base_path)?;
+
+        // Build headers using the new HeaderBuilder
+        let mut header_builder = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer);
+
         // Add organization header if present
         if let Some(org) = &self.organization {
-            request = request.header("OpenAI-Organization", org);
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Organization".to_string(), org.clone());
         }

         // Add project header if present
         if let Some(project) = &self.project {
-            request = request.header("OpenAI-Project", project);
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Project".to_string(), project.clone());
        }

         // Add custom headers if present
         if let Some(custom_headers) = &self.custom_headers {
             for (key, value) in custom_headers {
-                request = request.header(key, value);
+                header_builder = header_builder.add_custom_header(key.clone(), value.clone());
             }
         }

-        request
-    }
-
-    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let url = base_url.join(&self.base_path).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
-        })?;
-
-        let request = self
-            .client
-            .post(url)
-            .header("Authorization", format!("Bearer {}", self.api_key));
-
-        let request = self.add_headers(request);
+        let headers = header_builder.build();

-        let response = request.json(&payload).send().await?;
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self
+                .client
+                .post(url.clone())
+                .headers(headers.clone())
+                .json(&payload)
+                .send()
+                .await?;

-        handle_response_openai_compat(response).await
+            handle_response_openai_compat(response).await
+        })
+        .await
     }
 }

@@ -190,24 +210,29 @@ impl Provider for OpenAiProvider {
     /// Fetch supported models from OpenAI; returns Err on any failure, Ok(None) if no data
     async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
         // List available models via OpenAI API
-        let base_url =
-            url::Url::parse(&self.host).map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
-        let url = base_url
-            .join("v1/models")
-            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
-        let mut request = self.client.get(url).bearer_auth(&self.api_key);
+        let url = build_endpoint_url(&self.host, "v1/models")?;
+
+        // Build headers using the same pattern as post method
+        let mut header_builder = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer);
+
         if let Some(org) = &self.organization {
-            request = request.header("OpenAI-Organization", org);
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Organization".to_string(), org.clone());
         }
+
         if let Some(project) = &self.project {
-            request = request.header("OpenAI-Project", project);
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Project".to_string(), project.clone());
         }
-        if let Some(headers) = &self.custom_headers {
-            for (key, value) in headers {
-                request = request.header(key, value);
+
+        if let Some(custom_headers) = &self.custom_headers {
+            for (key, value) in custom_headers {
+                header_builder = header_builder.add_custom_header(key.clone(), value.clone());
             }
         }
-        let response = request.send().await?;
+
+        let headers = header_builder.build();
+        let response = self.client.get(url).headers(headers).send().await?;
         let json: serde_json::Value = response.json().await?;
         if let Some(err_obj) = json.get("error") {
             let msg = err_obj
@@ -266,21 +291,34 @@ impl EmbeddingCapable for OpenAiProvider {
         };

         // Construct embeddings endpoint URL
-        let base_url =
-            url::Url::parse(&self.host).map_err(|e| anyhow::anyhow!("Invalid base URL: {e}"))?;
-        let url = base_url
-            .join("v1/embeddings")
-            .map_err(|e| anyhow::anyhow!("Failed to construct embeddings URL: {e}"))?;
+        let url = build_endpoint_url(&self.host, "v1/embeddings")
+            .map_err(|e| anyhow::anyhow!("Failed to build embeddings URL: {e}"))?;

-        let req = self
-            .client
-            .post(url)
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .json(&request);
+        // Build headers using the same pattern
+        let mut header_builder = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer);
+
+        if let Some(org) = &self.organization {
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Organization".to_string(), org.clone());
        }

-        let req = self.add_headers(req);
+        if let Some(project) = &self.project {
+            header_builder =
+                header_builder.add_custom_header("OpenAI-Project".to_string(), project.clone());
+        }
+
+        if let Some(custom_headers) = &self.custom_headers {
+            for (key, value) in custom_headers {
+                header_builder = header_builder.add_custom_header(key.clone(), value.clone());
            }
+        }

-        let response = req
+        let headers = header_builder.build();
+        let response = self
+            .client
+            .post(url)
+            .headers(headers)
+            .json(&request)
             .send()
             .await
             .map_err(|e| anyhow::anyhow!("Failed to send embedding request: {e}"))?;
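The OpenAI provider keeps a per-provider escape hatch: a configured OPENAI_TIMEOUT builds a dedicated client, while the default path shares the pooled client and its keep-alive connections. A hedged sketch of that selection logic as a standalone helper (select_client is a hypothetical name, not part of the PR):

use std::sync::Arc;

use reqwest::Client;

use crate::providers::provider_common::{create_provider_client, get_shared_client};

// Only a configured timeout pays the cost of a private connection pool;
// every other provider instance shares one client process-wide.
fn select_client(timeout_secs: Option<u64>) -> anyhow::Result<Arc<Client>> {
    match timeout_secs {
        Some(_) => create_provider_client(timeout_secs), // dedicated pool
        None => Ok(get_shared_client()),                 // shared pool
    }
}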
diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs
index 0352012fc462..2df4aab8d9cb 100644
--- a/crates/goose/src/providers/openrouter.rs
+++ b/crates/goose/src/providers/openrouter.rs
@@ -2,10 +2,15 @@ use anyhow::{Error, Result};
 use async_trait::async_trait;
 use reqwest::Client;
 use serde_json::{json, Value};
-use std::time::Duration;
+use std::sync::Arc;
+use url::Url;

 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
+use super::provider_common::{
+    build_endpoint_url, get_shared_client, retry_with_backoff, AuthType, HeaderBuilder,
+    ProviderConfigBuilder, RetryConfig,
+};
 use super::utils::{
     emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat,
     is_google_model,
@@ -14,7 +19,6 @@ use crate::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
 use mcp_core::tool::Tool;
-use url::Url;

 pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
 pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";
@@ -32,10 +36,12 @@ pub const OPENROUTER_DOC_URL: &str = "https://openrouter.ai/models";
 #[derive(serde::Serialize)]
 pub struct OpenRouterProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     api_key: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }

 impl Default for OpenRouterProvider {
@@ -48,80 +54,91 @@ impl OpenRouterProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("OPENROUTER_API_KEY")?;
-        let host: String = config
-            .get_param("OPENROUTER_HOST")
-            .unwrap_or_else(|_| "https://openrouter.ai".to_string());
+        let config_builder = ProviderConfigBuilder::new(config, "OPENROUTER");

-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        let api_key = config_builder.get_api_key()?;
+        let host = config_builder.get_host("https://openrouter.ai");
+
+        // Use shared client for better connection pooling
+        let client = get_shared_client();
+
+        // Configure retry settings
+        let retry_config = RetryConfig::default();

         Ok(Self {
             client,
             host,
             api_key,
             model,
+            retry_config,
         })
     }

     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let base_url = Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let url = base_url.join("api/v1/chat/completions").map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
-        })?;
-
-        let response = self
-            .client
-            .post(url)
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .header("HTTP-Referer", "https://block.github.io/goose")
-            .header("X-Title", "Goose")
-            .json(&payload)
-            .send()
-            .await?;
-
-        // Handle Google-compatible model responses differently
-        if is_google_model(&payload) {
-            return handle_response_google_compat(response).await;
-        }
-
-        // For OpenAI-compatible models, parse the response body to JSON
-        let response_body = handle_response_openai_compat(response)
-            .await
-            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;
-
-        // OpenRouter can return errors in 200 OK responses, so we have to check for errors explicitly
-        // https://openrouter.ai/docs/api-reference/errors
-        if let Some(error_obj) = response_body.get("error") {
-            // If there's an error object, extract the error message and code
-            let error_message = error_obj
-                .get("message")
-                .and_then(|m| m.as_str())
-                .unwrap_or("Unknown OpenRouter error");
-
-            let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
-
-            // Check for context length errors in the error message
-            if error_code == 400 && error_message.contains("maximum context length") {
-                return Err(ProviderError::ContextLengthExceeded(
-                    error_message.to_string(),
-                ));
+        let url = build_endpoint_url(&self.host, "api/v1/chat/completions")?;
+
+        // Build headers using HeaderBuilder
+        let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer)
+            .add_custom_header(
+                "HTTP-Referer".to_string(),
+                "https://block.github.io/goose".to_string(),
+            )
+            .add_custom_header("X-Title".to_string(), "Goose".to_string())
+            .build();
+
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self
+                .client
+                .post(url.clone())
+                .headers(headers.clone())
+                .json(&payload)
+                .send()
+                .await?;
+
+            // Handle Google-compatible model responses differently
+            if is_google_model(&payload) {
+                return handle_response_google_compat(response).await;
             }

-            // Return appropriate error based on the OpenRouter error code
-            match error_code {
-                401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
-                429 => return Err(ProviderError::RateLimitExceeded(error_message.to_string())),
-                500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
-                _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
+            // For OpenAI-compatible models, parse the response body to JSON
+            let response_body = handle_response_openai_compat(response).await.map_err(|e| {
+                ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
+            })?;
+
+            // OpenRouter can return errors in 200 OK responses, so we have to check for errors explicitly
+            // https://openrouter.ai/docs/api-reference/errors
+            if let Some(error_obj) = response_body.get("error") {
+                // If there's an error object, extract the error message and code
+                let error_message = error_obj
+                    .get("message")
+                    .and_then(|m| m.as_str())
+                    .unwrap_or("Unknown OpenRouter error");
+
+                let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
+
+                // Check for context length errors in the error message
+                if error_code == 400 && error_message.contains("maximum context length") {
+                    return Err(ProviderError::ContextLengthExceeded(
+                        error_message.to_string(),
+                    ));
+                }
+
+                // Return appropriate error based on the OpenRouter error code
+                match error_code {
+                    401 | 403 => {
+                        return Err(ProviderError::Authentication(error_message.to_string()))
+                    }
+                    429 => return Err(ProviderError::RateLimitExceeded(error_message.to_string())),
+                    500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
+                    _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
+                }
             }
-        }

-        // No error detected, return the response body
-        Ok(response_body)
+            // No error detected, return the response body
+            Ok(response_body)
+        })
+        .await
     }
 }
diff --git a/crates/goose/src/providers/pricing.rs b/crates/goose/src/providers/pricing.rs
index b817907a0e6f..7b6b0ec70190 100644
--- a/crates/goose/src/providers/pricing.rs
+++ b/crates/goose/src/providers/pricing.rs
@@ -44,12 +44,15 @@ pub struct PricingInfo {
 pub struct PricingCache {
     /// In-memory cache
     memory_cache: Arc<RwLock<Option<HashMap<String, PricingInfo>>>>,
+    /// Active model cache for frequently accessed models
+    active_model_cache: Arc<RwLock<Option<(String, String, PricingInfo)>>>,
 }

 impl PricingCache {
     pub fn new() -> Self {
         Self {
             memory_cache: Arc::new(RwLock::new(None)),
+            active_model_cache: Arc::new(RwLock::new(None)),
         }
     }

@@ -106,6 +109,32 @@ impl PricingCache {
         Ok(())
     }

+    /// Get pricing for a specific model with active model caching
+    pub async fn get_active_model_pricing(
+        &self,
+        provider: &str,
+        model: &str,
+    ) -> Option<PricingInfo> {
+        // Check active model cache first
+        {
+            let cache = self.active_model_cache.read().await;
+            if let Some((cached_provider, cached_model, info)) = &*cache {
+                if cached_provider == provider && cached_model == model {
+                    return Some(info.clone());
+                }
+            }
+        }
+
+        // Fetch and cache
+        if let Some(info) = self.get_model_pricing(provider, model).await {
+            let mut cache = self.active_model_cache.write().await;
+            *cache = Some((provider.to_string(), model.to_string(), info.clone()));
+            Some(info)
+        } else {
+            None
+        }
+    }
+
     /// Get pricing for a specific model
     pub async fn get_model_pricing(&self, provider: &str, model: &str) -> Option<PricingInfo> {
         // Try memory cache first
@@ -303,6 +332,13 @@ pub async fn get_model_pricing(provider: &str, model: &str) -> Option<PricingInfo>
+pub async fn get_active_model_pricing(provider: &str, model: &str) -> Option<PricingInfo> {
+    PRICING_CACHE
+        .get_active_model_pricing(provider, model)
+        .await
+}
+
 /// Force refresh pricing data
 pub async fn refresh_pricing() -> Result<()> {
     PRICING_CACHE.refresh().await
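The new active_model_cache is a one-slot cache keyed by the (provider, model) pair: the common case of repeated lookups for the model currently in use hits a cheap RwLock read and skips the full pricing map. A minimal standalone model of the idea (Price and OneSlot are illustrative names, not the PR's types):

use std::sync::Arc;

use tokio::sync::RwLock;

#[derive(Clone)]
struct Price {
    input: f64,
    output: f64,
}

// Single-entry cache: holds at most one (provider, model, price) triple.
struct OneSlot(Arc<RwLock<Option<(String, String, Price)>>>);

impl OneSlot {
    // Read path: only returns a hit if both keys match the cached pair.
    async fn get(&self, provider: &str, model: &str) -> Option<Price> {
        let slot = self.0.read().await;
        match &*slot {
            Some((p, m, price)) if p == provider && m == model => Some(price.clone()),
            _ => None,
        }
    }

    // Write path: a lookup for a different model simply evicts the old entry.
    async fn put(&self, provider: &str, model: &str, price: Price) {
        *self.0.write().await = Some((provider.into(), model.into(), price));
    }
}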
diff --git a/crates/goose/src/providers/provider_common.rs b/crates/goose/src/providers/provider_common.rs
new file mode 100644
index 000000000000..58e089e512e7
--- /dev/null
+++ b/crates/goose/src/providers/provider_common.rs
@@ -0,0 +1,655 @@
+use anyhow::{anyhow, Result};
+use lazy_static::lazy_static;
+use reqwest::{Client, StatusCode};
+use serde::Serialize;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::time::sleep;
+use url::Url;
+
+use crate::providers::errors::ProviderError;
+use uuid::Uuid;
+
+/// Trait for collecting metrics about provider requests
+pub trait ProviderMetrics: Send + Sync {
+    /// Called when a request starts
+    fn on_request_start(&self, provider: &str, endpoint: &str);
+
+    /// Called when a request completes successfully
+    fn on_request_success(&self, provider: &str, endpoint: &str, duration_ms: u64);
+
+    /// Called when a request fails
+    fn on_request_failure(
+        &self,
+        provider: &str,
+        endpoint: &str,
+        error: &ProviderError,
+        duration_ms: u64,
+    );
+
+    /// Called when a retry is attempted
+    fn on_retry_attempt(&self, provider: &str, endpoint: &str, attempt: u32);
+}
+
+/// Simple cache trait for providers that want response caching
+pub trait ProviderCache: Send + Sync {
+    /// Get a cached response if available
+    fn get(&self, key: &str) -> Option<serde_json::Value>;
+
+    /// Store a response in the cache
+    fn set(&self, key: &str, value: serde_json::Value, ttl_secs: u64);
+
+    /// Generate a cache key from request parameters
+    fn make_key(&self, provider: &str, endpoint: &str, payload: &serde_json::Value) -> String {
+        use std::collections::hash_map::DefaultHasher;
+        use std::hash::{Hash, Hasher};
+
+        let mut hasher = DefaultHasher::new();
+        provider.hash(&mut hasher);
+        endpoint.hash(&mut hasher);
+        payload.to_string().hash(&mut hasher);
+        format!("{}:{}:{}", provider, endpoint, hasher.finish())
+    }
+}
+
+/// Default timeout for HTTP requests
+pub const DEFAULT_TIMEOUT_SECS: u64 = 600;
+
+/// Common retry configuration for providers
+#[derive(Debug, Clone)]
+pub struct RetryConfig {
+    pub max_retries: u32,
+    pub initial_delay_ms: u64,
+    pub max_delay_ms: u64,
+    pub backoff_multiplier: f64,
+}
+
+impl Default for RetryConfig {
+    fn default() -> Self {
+        Self {
+            max_retries: 3,
+            initial_delay_ms: 1000,
+            max_delay_ms: 32000,
+            backoff_multiplier: 2.0,
+        }
+    }
+}
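With these defaults the delay schedule is 1s, 2s, 4s for the three retries; the 32s cap only matters for larger retry counts. A small worked example of the arithmetic the retry loop below applies (delay_n = min(initial * multiplier^(n-1), max), before any custom delay override):

// Reproduces the schedule retry_with_backoff derives from RetryConfig:
// it sleeps the current delay, then multiplies and caps it.
fn delays(initial_ms: u64, multiplier: f64, max_ms: u64, retries: u32) -> Vec<u64> {
    let mut out = Vec::new();
    let mut delay = initial_ms;
    for _ in 0..retries {
        out.push(delay);
        delay = (((delay as f64) * multiplier) as u64).min(max_ms);
    }
    out
}

fn main() {
    // Defaults: 1000ms initial, x2.0, capped at 32000ms, 3 retries.
    assert_eq!(delays(1000, 2.0, 32000, 3), vec![1000, 2000, 4000]);
    // With 8 retries the cap kicks in from the 6th retry onward.
    assert_eq!(delays(1000, 2.0, 32000, 8)[7], 32000);
}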
+/// Authentication type for providers
+#[derive(Debug, Clone)]
+pub enum AuthType {
+    Bearer,
+    ApiKey,
+    Custom(String),
+}
+
+/// Common headers builder for providers
+pub struct HeaderBuilder {
+    auth_token: String,
+    auth_type: AuthType,
+    custom_headers: HashMap<String, String>,
+}
+
+impl HeaderBuilder {
+    pub fn new(auth_token: String, auth_type: AuthType) -> Self {
+        Self {
+            auth_token,
+            auth_type,
+            custom_headers: HashMap::new(),
+        }
+    }
+
+    pub fn add_custom_header(mut self, key: String, value: String) -> Self {
+        self.custom_headers.insert(key, value);
+        self
+    }
+
+    pub fn add_request_id(mut self) -> Self {
+        let request_id = Uuid::new_v4().to_string();
+        self.custom_headers
+            .insert("X-Request-ID".to_string(), request_id.clone());
+        self.custom_headers
+            .insert("X-Trace-ID".to_string(), request_id);
+        self
+    }
+
+    pub fn build(self) -> reqwest::header::HeaderMap {
+        let mut headers = reqwest::header::HeaderMap::new();
+
+        // Add authorization header
+        match self.auth_type {
+            AuthType::Bearer => {
+                headers.insert(
+                    reqwest::header::AUTHORIZATION,
+                    format!("Bearer {}", self.auth_token).parse().unwrap(),
+                );
+            }
+            AuthType::ApiKey => {
+                headers.insert("X-API-Key", self.auth_token.parse().unwrap());
+            }
+            AuthType::Custom(header_name) => {
+                if let Ok(name) = reqwest::header::HeaderName::from_bytes(header_name.as_bytes()) {
+                    headers.insert(name, self.auth_token.parse().unwrap());
+                }
+            }
+        }
+
+        // Add compression support headers
+        headers.insert(
+            reqwest::header::ACCEPT_ENCODING,
+            "gzip, deflate, br".parse().unwrap(),
+        );
+
+        // Add User-Agent header
+        headers.insert(
+            reqwest::header::USER_AGENT,
+            format!("Goose/{} (Rust)", env!("CARGO_PKG_VERSION"))
+                .parse()
+                .unwrap(),
+        );
+
+        // Add custom headers
+        for (key, value) in self.custom_headers {
+            if let (Ok(header_name), Ok(header_value)) = (
+                reqwest::header::HeaderName::from_bytes(key.as_bytes()),
+                value.parse(),
+            ) {
+                headers.insert(header_name, header_value);
+            }
+        }
+
+        headers
+    }
+}
+
+/// Connection pool configuration
+pub struct ConnectionPoolConfig {
+    /// Maximum idle connections per host
+    pub max_idle_per_host: usize,
+    /// Time before idle connections are closed
+    pub idle_timeout_secs: u64,
+    /// Maximum number of connections per host
+    pub max_connections_per_host: Option<usize>,
+    /// Enable HTTP/2
+    pub http2_enabled: bool,
+}
+
+impl Default for ConnectionPoolConfig {
+    fn default() -> Self {
+        Self {
+            max_idle_per_host: 10,
+            idle_timeout_secs: 90,
+            max_connections_per_host: Some(50),
+            http2_enabled: true,
+        }
+    }
+}
+
+/// Create a default HTTP client with common settings
+pub fn create_default_client(timeout_secs: Option<u64>) -> Result<Client> {
+    create_client_with_config(timeout_secs, ConnectionPoolConfig::default())
+}
+
+/// Create an HTTP client with custom configuration
+pub fn create_client_with_config(
+    timeout_secs: Option<u64>,
+    pool_config: ConnectionPoolConfig,
+) -> Result<Client> {
+    let mut builder = Client::builder()
+        .timeout(Duration::from_secs(
+            timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
+        ))
+        .pool_idle_timeout(Duration::from_secs(pool_config.idle_timeout_secs))
+        .pool_max_idle_per_host(pool_config.max_idle_per_host)
+        .gzip(true) // Enable automatic gzip decompression
+        .brotli(true) // Enable automatic brotli decompression
+        .tcp_keepalive(Duration::from_secs(60)) // Keep connections alive
+        .tcp_nodelay(true) // Disable Nagle's algorithm for lower latency
+        .connect_timeout(Duration::from_secs(30)); // Timeout for establishing connection
+
+    if pool_config.http2_enabled {
+        builder = builder
+            .http2_prior_knowledge()
+            .http2_keep_alive_interval(Duration::from_secs(10))
+            .http2_keep_alive_timeout(Duration::from_secs(20));
+    }
+
+    builder
+        .build()
+        .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))
+}
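A hedged usage sketch of the pool configuration above, with one caveat worth flagging for review: http2_enabled routes through reqwest's http2_prior_knowledge(), which makes the client HTTP/2-only, so requests against an HTTP/1.1-only endpoint will fail; disabling it is the conservative choice when the target's protocol support is unknown (low_traffic_client is an illustrative helper, not part of the PR):

use reqwest::Client;

use crate::providers::provider_common::{create_client_with_config, ConnectionPoolConfig};

// A smaller pool with shorter idle lifetimes for a rarely used provider.
fn low_traffic_client() -> anyhow::Result<Client> {
    create_client_with_config(
        Some(120), // 2-minute overall request timeout
        ConnectionPoolConfig {
            max_idle_per_host: 2,
            idle_timeout_secs: 30,
            max_connections_per_host: Some(8),
            http2_enabled: false, // stay HTTP/1.1-compatible
        },
    )
}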
+// Global shared HTTP client for providers that want to share connections
+lazy_static! {
+    static ref SHARED_CLIENT: Arc<Client> =
+        Arc::new(create_default_client(None).expect("Failed to create shared HTTP client"));
+}
+
+/// Get the shared HTTP client instance
+pub fn get_shared_client() -> Arc<Client> {
+    SHARED_CLIENT.clone()
+}
+
+/// Create a provider-specific HTTP client with custom timeout
+pub fn create_provider_client(timeout_secs: Option<u64>) -> Result<Arc<Client>> {
+    Ok(Arc::new(create_default_client(timeout_secs)?))
+}
+
+/// Build endpoint URL from base and path
+pub fn build_endpoint_url(base: &str, path: &str) -> Result<Url, ProviderError> {
+    let base_url = Url::parse(base)
+        .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+    base_url
+        .join(path)
+        .map_err(|e| ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")))
+}
+
+/// Maximum request payload size (10MB)
+pub const MAX_REQUEST_SIZE: usize = 10 * 1024 * 1024;
+
+/// Check if a request payload is within size limits
+pub fn validate_request_size(payload: &serde_json::Value) -> Result<(), ProviderError> {
+    let size = serde_json::to_string(payload)
+        .map_err(|e| ProviderError::RequestFailed(format!("Failed to serialize payload: {}", e)))?
+        .len();
+
+    if size > MAX_REQUEST_SIZE {
+        Err(ProviderError::RequestFailed(format!(
+            "Request payload too large: {} bytes (max: {} bytes). Consider reducing the message history or content size.",
+            size, MAX_REQUEST_SIZE
+        )))
+    } else {
+        Ok(())
+    }
+}
+
+/// Check if an error is retryable
+pub trait IsRetryable {
+    fn is_retryable(&self) -> bool;
+}
+
+impl IsRetryable for ProviderError {
+    fn is_retryable(&self) -> bool {
+        matches!(
+            self,
+            ProviderError::RateLimitExceeded(_)
+                | ProviderError::ServerError(_)
+                | ProviderError::RequestFailed(_)
+        )
+    }
+}
+
+/// Retry an async operation with exponential backoff
+pub async fn retry_with_backoff<T, F, Fut>(
+    config: &RetryConfig,
+    mut operation: F,
+) -> Result<T, ProviderError>
+where
+    F: FnMut() -> Fut,
+    Fut: std::future::Future<Output = Result<T, ProviderError>>,
+{
+    let mut attempts = 0;
+    let mut delay_ms = config.initial_delay_ms;
+
+    loop {
+        match operation().await {
+            Ok(result) => return Ok(result),
+            Err(e) if e.is_retryable() && attempts < config.max_retries => {
+                attempts += 1;
+                tracing::warn!(
+                    "Retryable error (attempt {}/{}): {}. Retrying in {}ms...",
+                    attempts,
+                    config.max_retries,
+                    e,
+                    delay_ms
+                );
+
+                sleep(Duration::from_millis(delay_ms)).await;
+
+                // Update delay with exponential backoff
+                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
+                delay_ms = delay_ms.min(config.max_delay_ms);
+            }
+            Err(e) => return Err(e),
+        }
+    }
+}
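The call shape every migrated provider uses: the closure must produce a fresh future per attempt, which is why the call sites above clone the URL and headers before each send. A hedged, self-contained sketch (post_with_retries and its parameters are illustrative stand-ins for what a caller has in scope):

use crate::providers::errors::ProviderError;
use crate::providers::provider_common::{retry_with_backoff, RetryConfig};

async fn post_with_retries(
    client: &reqwest::Client,
    url: url::Url,
    payload: serde_json::Value,
) -> Result<serde_json::Value, ProviderError> {
    retry_with_backoff(&RetryConfig::default(), || async {
        // Map transport errors into a retryable ProviderError variant so
        // the loop tries again; non-retryable variants short-circuit.
        let response = client
            .post(url.clone())
            .json(&payload)
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
        response
            .json::<serde_json::Value>()
            .await
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))
    })
    .await
}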
+/// Retry an async operation with custom delay extraction from errors
+pub async fn retry_with_backoff_and_custom_delay<T, F, Fut, D>(
+    config: &RetryConfig,
+    mut operation: F,
+    mut extract_delay: D,
+) -> Result<T, ProviderError>
+where
+    F: FnMut() -> Fut,
+    Fut: std::future::Future<Output = Result<T, ProviderError>>,
+    D: FnMut(&ProviderError) -> Option<u64>,
+{
+    let mut attempts = 0;
+    let mut delay_ms = config.initial_delay_ms;
+
+    loop {
+        match operation().await {
+            Ok(result) => return Ok(result),
+            Err(e) if e.is_retryable() && attempts < config.max_retries => {
+                attempts += 1;
+
+                // Try to extract custom delay from error
+                let custom_delay_ms = extract_delay(&e);
+                let actual_delay_ms = custom_delay_ms.unwrap_or(delay_ms);
+
+                tracing::warn!(
+                    "Retryable error (attempt {}/{}): {}. Retrying in {}ms{}...",
+                    attempts,
+                    config.max_retries,
+                    e,
+                    actual_delay_ms,
+                    if custom_delay_ms.is_some() {
+                        " (custom delay)"
+                    } else {
+                        ""
+                    }
+                );
+
+                sleep(Duration::from_millis(actual_delay_ms)).await;
+
+                // Only update delay with exponential backoff if no custom delay was found
+                if custom_delay_ms.is_none() {
+                    delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
+                    delay_ms = delay_ms.min(config.max_delay_ms);
+                }
+            }
+            Err(e) => return Err(e),
+        }
+    }
+}
+
+/// Common response handler for providers
+pub async fn handle_provider_response(
+    response: reqwest::Response,
+    provider_name: &str,
+) -> Result<serde_json::Value, ProviderError> {
+    let status = response.status();
+    let response_text = response.text().await
+        .map_err(|e| {
+            if e.is_timeout() {
+                ProviderError::RequestFailed(format!(
+                    "{} request timed out. The provider may be slow or the request may be too large. Consider increasing the timeout or reducing the request size.",
+                    provider_name
+                ))
+            } else if e.is_connect() {
+                ProviderError::RequestFailed(format!(
+                    "Failed to connect to {} API. Please check your network connection and the provider's status.",
+                    provider_name
+                ))
+            } else {
+                ProviderError::RequestFailed(format!("Failed to read response from {}: {}", provider_name, e))
+            }
+        })?;
+
+    if status.is_success() {
+        serde_json::from_str(&response_text)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid JSON response: {}", e)))
+    } else {
+        let error_msg = format!(
+            "{} API error ({}): {}",
+            provider_name, status, response_text
+        );
+
+        match status {
+            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
+                Err(ProviderError::Authentication(error_msg))
+            }
+            StatusCode::TOO_MANY_REQUESTS => Err(ProviderError::RateLimitExceeded(error_msg)),
+            StatusCode::BAD_REQUEST => {
+                // Check if it's a context length error
+                if response_text.to_lowercase().contains("context")
+                    || response_text.to_lowercase().contains("too long")
+                    || response_text.to_lowercase().contains("exceeds")
+                {
+                    Err(ProviderError::ContextLengthExceeded(error_msg))
+                } else {
+                    Err(ProviderError::RequestFailed(error_msg))
+                }
+            }
+            s if s.is_server_error() => Err(ProviderError::ServerError(error_msg)),
+            _ => Err(ProviderError::RequestFailed(error_msg)),
+        }
+    }
+}
+
+/// Configuration builder for providers
+pub struct ProviderConfigBuilder<'a> {
+    config: &'a crate::config::Config,
+    prefix: String,
+}
+
+impl<'a> ProviderConfigBuilder<'a> {
+    pub fn new(config: &'a crate::config::Config, prefix: &str) -> Self {
+        Self {
+            config,
+            prefix: prefix.to_uppercase(),
+        }
+    }
+
+    pub fn get_api_key(&self) -> Result<String> {
+        self.config
+            .get_secret(&format!("{}_API_KEY", self.prefix))
+            .map_err(|e| anyhow!("Failed to get API key: {}", e))
+    }
+
+    pub fn get_host(&self, default: &str) -> String {
+        self.config
+            .get_param(&format!("{}_HOST", self.prefix))
+            .unwrap_or_else(|_| default.to_string())
+    }
+
+    pub fn get_model(&self, default: &str) -> String {
+        self.config
+            .get_param(&format!("{}_MODEL", self.prefix))
+            .unwrap_or_else(|_| default.to_string())
+    }
+
+    pub fn get_param(&self, param: &str, default: Option<&str>) -> Option<String> {
+        self.config
+            .get_param(&format!("{}_{}", self.prefix, param))
+            .ok()
+            .or_else(|| default.map(|s| s.to_string()))
+    }
+}
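A sketch of the from_env pattern the providers in this diff converge on through the builder: one prefix replaces the per-key string literals. The "EXAMPLE" prefix and read_example_config are placeholders for illustration, not a provider added by this PR:

use crate::providers::provider_common::ProviderConfigBuilder;

fn read_example_config() -> anyhow::Result<(String, String, Option<String>)> {
    let config = crate::config::Config::global();
    let builder = ProviderConfigBuilder::new(config, "EXAMPLE");
    let api_key = builder.get_api_key()?;                   // EXAMPLE_API_KEY (secret)
    let host = builder.get_host("https://api.example.com"); // EXAMPLE_HOST or the default
    let base = builder.get_param("BASE_PATH", None);        // EXAMPLE_BASE_PATH, if set
    Ok((api_key, host, base))
}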
+/// Base provider struct that others can compose with
+pub struct BaseProvider {
+    pub client: Client,
+    pub host: String,
+    pub retry_config: RetryConfig,
+}
+
+impl BaseProvider {
+    pub fn new(host: String, retry_config: Option<RetryConfig>) -> Result<Self> {
+        Ok(Self {
+            client: create_default_client(None)?,
+            host,
+            retry_config: retry_config.unwrap_or_default(),
+        })
+    }
+
+    /// Make a POST request with retry logic
+    pub async fn post_json<T: Serialize>(
+        &self,
+        endpoint: &str,
+        headers: reqwest::header::HeaderMap,
+        payload: &T,
+    ) -> Result<serde_json::Value, ProviderError> {
+        let url = build_endpoint_url(&self.host, endpoint)?;
+
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self
+                .client
+                .post(url.clone())
+                .headers(headers.clone())
+                .json(payload)
+                .send()
+                .await
+                .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;
+
+            handle_provider_response(response, "Provider").await
+        })
+        .await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_header_builder() {
+        let headers = HeaderBuilder::new("test-token".to_string(), AuthType::Bearer)
+            .add_custom_header("X-Custom".to_string(), "value".to_string())
+            .build();
+
+        assert_eq!(headers.get("authorization").unwrap(), "Bearer test-token");
+        assert_eq!(headers.get("x-custom").unwrap(), "value");
+    }
+
+    #[test]
+    fn test_build_endpoint_url() {
+        let url = build_endpoint_url("https://api.example.com", "/v1/chat").unwrap();
+        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
+
+        let url = build_endpoint_url("https://api.example.com/", "v1/chat").unwrap();
+        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
+    }
+
+    #[test]
+    fn test_retry_config_default() {
+        let config = RetryConfig::default();
+        assert_eq!(config.max_retries, 3);
+        assert_eq!(config.initial_delay_ms, 1000);
+        assert_eq!(config.backoff_multiplier, 2.0);
+    }
+
+    #[tokio::test]
+    async fn test_retry_with_backoff_success() {
+        let config = RetryConfig {
+            max_retries: 3,
+            initial_delay_ms: 10,
+            max_delay_ms: 100,
+            backoff_multiplier: 2.0,
+        };
+
+        let result = retry_with_backoff(&config, || async { Ok::<i32, ProviderError>(42) }).await;
+
+        assert_eq!(result.unwrap(), 42);
+    }
+
+    #[tokio::test]
+    async fn test_retry_with_backoff_eventual_success() {
+        use std::sync::atomic::{AtomicU32, Ordering};
+
+        let config = RetryConfig {
+            max_retries: 3,
+            initial_delay_ms: 10,
+            max_delay_ms: 100,
+            backoff_multiplier: 2.0,
+        };
+
+        let attempts = AtomicU32::new(0);
+        let result = retry_with_backoff(&config, || async {
+            let current = attempts.fetch_add(1, Ordering::SeqCst) + 1;
+            if current < 3 {
+                Err(ProviderError::RateLimitExceeded("Rate limited".to_string()))
+            } else {
+                Ok::<i32, ProviderError>(42)
+            }
+        })
+        .await;
+
+        assert_eq!(result.unwrap(), 42);
+        assert_eq!(attempts.load(Ordering::SeqCst), 3); // Should succeed on third try
+    }
+
+    #[tokio::test]
+    async fn test_retry_with_backoff_max_retries_exceeded() {
+        let config = RetryConfig {
+            max_retries: 2,
+            initial_delay_ms: 10,
+            max_delay_ms: 100,
+            backoff_multiplier: 2.0,
+        };
+
+        let result = retry_with_backoff(&config, || async {
+            Err::<i32, ProviderError>(ProviderError::RateLimitExceeded("Rate limited".to_string()))
+        })
+        .await;
+
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_retry_with_backoff_non_retryable_error() {
+        let config = RetryConfig {
+            max_retries: 3,
+            initial_delay_ms: 10,
+            max_delay_ms: 100,
+            backoff_multiplier: 2.0,
+        };
+
+        let result = retry_with_backoff(&config, || async {
+            Err::<i32, ProviderError>(ProviderError::Authentication("Auth failed".to_string()))
+        })
+        .await;
+
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_retry_with_custom_delay() {
+        use std::sync::atomic::{AtomicU32, Ordering};
+
+        let config = RetryConfig {
+            max_retries: 3,
+            initial_delay_ms: 100,
+            max_delay_ms: 1000,
+            backoff_multiplier: 2.0,
+        };
+
+        let attempts = AtomicU32::new(0);
+        let start = std::time::Instant::now();
+
+        let result = retry_with_backoff_and_custom_delay(
+            &config,
+            || async {
+                let current = attempts.fetch_add(1, Ordering::SeqCst) + 1;
+                if current < 3 {
+                    Err(ProviderError::RateLimitExceeded("Rate limited".to_string()))
+                } else {
+                    Ok::<i32, ProviderError>(42)
+                }
+            },
+            |_| Some(50), // Always return 50ms custom delay
+        )
+        .await;
+
+        let elapsed = start.elapsed();
+
+        assert_eq!(result.unwrap(), 42);
+        assert_eq!(attempts.load(Ordering::SeqCst), 3);
+        // Should have used custom delay (50ms) twice
+        assert!(elapsed.as_millis() >= 100 && elapsed.as_millis() < 200);
+    }
+
+    #[test]
+    fn test_is_retryable() {
+        assert!(ProviderError::RateLimitExceeded("test".to_string()).is_retryable());
+        assert!(ProviderError::ServerError("test".to_string()).is_retryable());
+        assert!(ProviderError::RequestFailed("test".to_string()).is_retryable());
+
+        assert!(!ProviderError::Authentication("test".to_string()).is_retryable());
+        assert!(!ProviderError::ContextLengthExceeded("test".to_string()).is_retryable());
+        assert!(!ProviderError::UsageError("test".to_string()).is_retryable());
+    }
+}
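No provider in this section wires up the custom-delay variant yet; only the tests exercise it. A hypothetical extractor showing what the D: FnMut(&ProviderError) -> Option<u64> hook is for, e.g. honoring a Retry-After value a provider embeds in its error text (delay_from_error and the "retry-after:" convention are assumptions for illustration, not PR behavior):

use crate::providers::errors::ProviderError;

// If the error message carries "retry-after: N" (seconds), honor it
// instead of the exponential schedule; otherwise fall back to backoff.
fn delay_from_error(err: &ProviderError) -> Option<u64> {
    let text = err.to_string().to_lowercase();
    let rest = text.split("retry-after:").nth(1)?;
    let secs: u64 = rest.trim().split_whitespace().next()?.parse().ok()?;
    Some(secs * 1000) // retry delays are in milliseconds
}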
diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs
index 19b95deb83f6..ce3cef56bbb1 100644
--- a/crates/goose/src/providers/sagemaker_tgi.rs
+++ b/crates/goose/src/providers/sagemaker_tgi.rs
@@ -8,10 +8,10 @@ use aws_sdk_bedrockruntime::config::ProvideCredentials;
 use aws_sdk_sagemakerruntime::Client as SageMakerClient;
 use mcp_core::Tool;
 use serde_json::{json, Value};
-use tokio::time::sleep;

 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
+use super::provider_common::{retry_with_backoff, RetryConfig};
 use super::utils::emit_debug_trace;
 use crate::message::{Message, MessageContent};
 use crate::model::ModelConfig;
@@ -30,6 +32,8 @@ pub struct SageMakerTgiProvider {
     sagemaker_client: SageMakerClient,
     endpoint_name: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }

 impl SageMakerTgiProvider {
@@ -76,10 +78,14 @@ impl SageMakerTgiProvider {
         let sagemaker_client = SageMakerClient::new(&config_with_timeout);

+        // Configure retry settings
+        let retry_config = RetryConfig::default();
+
         Ok(Self {
             sagemaker_client,
             endpoint_name,
             model,
+            retry_config,
         })
     }

@@ -157,27 +163,45 @@ impl SageMakerTgiProvider {
             ProviderError::RequestFailed(format!("Failed to serialize request: {}", e))
         })?;

-        let response = self
-            .sagemaker_client
-            .invoke_endpoint()
-            .endpoint_name(&self.endpoint_name)
-            .content_type("application/json")
-            .body(body.into_bytes().into())
-            .send()
-            .await
-            .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?;
-
-        let response_body = response
-            .body
-            .as_ref()
-            .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;
-        let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
-        })?;
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self
+                .sagemaker_client
+                .invoke_endpoint()
+                .endpoint_name(&self.endpoint_name)
+                .content_type("application/json")
+                .body(body.clone().into_bytes().into())
+                .send()
+                .await
+                .map_err(|e| {
+                    // Convert AWS SDK errors to appropriate ProviderError types
+                    let error_msg = format!("SageMaker invoke failed: {}", e);
+                    if error_msg.contains("ThrottlingException") {
+                        ProviderError::RateLimitExceeded(error_msg)
+                    } else if error_msg.contains("ValidationException")
+                        && error_msg.contains("payload size")
+                    {
+                        ProviderError::ContextLengthExceeded(error_msg)
+                    } else if error_msg.contains("ModelError") {
+                        ProviderError::ExecutionError(error_msg)
+                    } else {
+                        ProviderError::RequestFailed(error_msg)
+                    }
+                })?;
+
+            let response_body = response
+                .body
+                .as_ref()
+                .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;
+            let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
+                ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
+            })?;

-        serde_json::from_str(response_text).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
+            serde_json::from_str(response_text).map_err(|e| {
+                ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
+            })
         })
+        .await
     }

     fn parse_tgi_response(&self, response: Value) -> Result<Message, ProviderError> {
@@ -303,63 +327,30 @@ impl Provider for SageMakerTgiProvider {
             ProviderError::RequestFailed(format!("Failed to create request: {}", e))
         })?;

-        // Retry configuration
-        const MAX_RETRIES: u32 = 3;
-        const INITIAL_BACKOFF_MS: u64 = 1000; // 1 second
-        const MAX_BACKOFF_MS: u64 = 30000; // 30 seconds
-
-        let mut attempts = 0;
-        let mut backoff_ms = INITIAL_BACKOFF_MS;
-
-        loop {
-            attempts += 1;
-
-            match self.invoke_endpoint(request_payload.clone()).await {
-                Ok(response) => {
-                    let message = self.parse_tgi_response(response)?;
-
-                    // TGI doesn't provide usage statistics, so we estimate
-                    let usage = Usage {
-                        input_tokens: Some(0), // Would need to tokenize input to get accurate count
-                        output_tokens: Some(0), // Would need to tokenize output to get accurate count
-                        total_tokens: Some(0),
-                    };
-
-                    // Add debug trace
-                    let debug_payload = serde_json::json!({
-                        "system": system,
-                        "messages": messages,
-                        "tools": tools
-                    });
-                    emit_debug_trace(
-                        &self.model,
-                        &debug_payload,
-                        &serde_json::to_value(&message).unwrap_or_default(),
-                        &usage,
-                    );
-
-                    let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
-                    return Ok((message, provider_usage));
-                }
-                Err(err) => {
-                    if attempts > MAX_RETRIES {
-                        return Err(err);
-                    }
+        let response = self.invoke_endpoint(request_payload.clone()).await?;
+        let message = self.parse_tgi_response(response)?;

-                    // Log retry attempt
-                    tracing::warn!(
-                        "SageMaker TGI request failed (attempt {}/{}), retrying in {} ms: {:?}",
-                        attempts,
-                        MAX_RETRIES,
-                        backoff_ms,
-                        err
-                    );
-
-                    // Wait before retry
-                    sleep(Duration::from_millis(backoff_ms)).await;
-                    backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
-                }
-            }
-        }
+        // TGI doesn't provide usage statistics, so we estimate
+        let usage = Usage {
+            input_tokens: Some(0), // Would need to tokenize input to get accurate count
+            output_tokens: Some(0), // Would need to tokenize output to get accurate count
+            total_tokens: Some(0),
+        };
+
+        // Add debug trace
+        let debug_payload = serde_json::json!({
+            "system": system,
+            "messages": messages,
+            "tools": tools
+        });
+        emit_debug_trace(
+            &self.model,
+            &debug_payload,
+            &serde_json::to_value(&message).unwrap_or_default(),
+            &usage,
+        );
+
+        let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
+        Ok((message, provider_usage))
     }
 }
diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs
index 54309ec4e54a..98d5814697ca 100644
--- a/crates/goose/src/providers/snowflake.rs
+++ b/crates/goose/src/providers/snowflake.rs
@@ -3,17 +3,19 @@
 use async_trait::async_trait;
 use reqwest::{Client, StatusCode};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
-use std::time::Duration;
+use std::sync::Arc;

 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
 use super::errors::ProviderError;
 use super::formats::snowflake::{create_request, get_usage, response_to_message};
-use super::utils::{get_model, ImageFormat};
+use super::provider_common::{
+    build_endpoint_url, get_shared_client, retry_with_backoff, ProviderConfigBuilder, RetryConfig,
+};
+use super::utils::{emit_debug_trace, get_model, ImageFormat};
 use crate::config::ConfigError;
 use crate::message::Message;
 use crate::model::ModelConfig;
 use mcp_core::tool::Tool;
-use url::Url;

 pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-3-7-sonnet";
 pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &["claude-3-7-sonnet", "claude-3-5-sonnet"];
@@ -35,11 +37,13 @@ impl SnowflakeAuth {
 #[derive(Debug, serde::Serialize)]
 pub struct SnowflakeProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     auth: SnowflakeAuth,
     model: ModelConfig,
     image_format: ImageFormat,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }

 impl Default for SnowflakeProvider {
@@ -52,6 +56,9 @@ impl SnowflakeProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
+        let _config_builder = ProviderConfigBuilder::new(config, "SNOWFLAKE");
+
+        // Try to get host from params or secrets
         let mut host: Result<String, ConfigError> = config.get_param("SNOWFLAKE_HOST");
         if host.is_err() {
             host = config.get_secret("SNOWFLAKE_HOST")
@@ -73,12 +80,11 @@ impl SnowflakeProvider {
             host = format!("{}.snowflakecomputing.com", host);
         }

+        // Try to get token from params or secrets
         let mut token: Result<String, ConfigError> = config.get_param("SNOWFLAKE_TOKEN");
-
         if token.is_err() {
             token = config.get_secret("SNOWFLAKE_TOKEN")
         }
-
         if token.is_err() {
             return Err(ConfigError::NotFound(
                 "Did not find SNOWFLAKE_TOKEN in either config file or keyring".to_string(),
             )
             .into());
         }

-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        // Use shared client for better connection pooling
+        let client = get_shared_client();
+
+        // Configure retry settings
+        let retry_config = RetryConfig::default();

         // Use token-based authentication
         let api_key = token?;
@@ -98,6 +106,7 @@ impl SnowflakeProvider {
             auth: SnowflakeAuth::token(api_key),
             model,
             image_format: ImageFormat::OpenAi,
+            retry_config,
         })
     }

@@ -109,32 +118,28 @@ impl SnowflakeProvider {
     }

     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let base_url_str =
-            if !self.host.starts_with("https://") && !self.host.starts_with("http://") {
-                format!("https://{}", self.host)
-            } else {
-                self.host.clone()
-            };
-        let base_url = Url::parse(&base_url_str)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let path = "api/v2/cortex/inference:complete";
-        let url = base_url.join(path).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
-        })?;
-
-        let auth_header = self.ensure_auth_header().await?;
-        let response = self
-            .client
-            .post(url)
-            .header("Authorization", auth_header)
-            .header("User-Agent", "Goose")
-            .json(&payload)
-            .send()
-            .await?;
-
-        let status = response.status();
-
-        let payload_text: String = response.text().await.ok().unwrap_or_default();
diff --git a/crates/goose/src/providers/venice.rs b/crates/goose/src/providers/venice.rs
index 81ee9c0b85ad..59eddf93ed8a 100644
--- a/crates/goose/src/providers/venice.rs
+++ b/crates/goose/src/providers/venice.rs
@@ -4,10 +4,14 @@ use chrono::Utc;
 use reqwest::{Client, Response};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
-use std::time::Duration;
+use std::sync::Arc;
 
 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
+use super::provider_common::{
+    build_endpoint_url, get_shared_client, retry_with_backoff, AuthType, HeaderBuilder,
+    ProviderConfigBuilder, RetryConfig,
+};
 use crate::message::{Message, MessageContent};
 use crate::model::ModelConfig;
 use mcp_core::{tool::Tool, Role, ToolCall, ToolResult};
@@ -71,12 +75,14 @@ const FALLBACK_MODELS: [&str; 3] = [
 #[derive(Debug, Serialize, Deserialize)]
 pub struct VeniceProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     base_path: String,
     models_path: String,
     api_key: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }
 
 impl Default for VeniceProvider {
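
`from_env` below reads settings through `ProviderConfigBuilder`, which the call sites (`get_api_key()`, `get_host(default)`, `get_param(name, default)`) suggest prefixes keys with the provider name and falls back to the secret store for API keys. The builder is not shown in this diff; the following is a hypothetical reconstruction, assuming `Config::global()` yields a `&'static Config` exposing the same `get_param`/`get_secret` methods already used in these files:

```rust
use anyhow::Result;

use crate::config::Config;

/// Hypothetical reconstruction of provider_common::ProviderConfigBuilder;
/// only the call sites in this diff are authoritative.
pub struct ProviderConfigBuilder<'a> {
    config: &'a Config,
    prefix: &'a str,
}

impl<'a> ProviderConfigBuilder<'a> {
    pub fn new(config: &'a Config, prefix: &'a str) -> Self {
        Self { config, prefix }
    }

    fn key(&self, name: &str) -> String {
        // e.g. prefix "VENICE" + name "HOST" -> "VENICE_HOST"
        format!("{}_{}", self.prefix, name)
    }

    /// API keys are secrets: look them up in the keyring-backed store.
    pub fn get_api_key(&self) -> Result<String> {
        let api_key: String = self.config.get_secret(&self.key("API_KEY"))?;
        Ok(api_key)
    }

    /// Hosts are plain params with a provider-supplied default.
    pub fn get_host(&self, default: &str) -> String {
        let host: Result<String, _> = self.config.get_param(&self.key("HOST"));
        host.unwrap_or_else(|_| default.to_string())
    }

    /// Generic param lookup; returns the default when the key is unset.
    pub fn get_param(&self, name: &str, default: Option<&str>) -> Option<String> {
        let value: Result<String, _> = self.config.get_param(&self.key(name));
        value.ok().or_else(|| default.map(str::to_string))
    }
}
```
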
@@ -89,23 +95,25 @@ impl VeniceProvider {
     pub fn from_env(mut model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("VENICE_API_KEY")?;
-        let host: String = config
-            .get_param("VENICE_HOST")
-            .unwrap_or_else(|_| VENICE_DEFAULT_HOST.to_string());
-        let base_path: String = config
-            .get_param("VENICE_BASE_PATH")
-            .unwrap_or_else(|_| VENICE_DEFAULT_BASE_PATH.to_string());
-        let models_path: String = config
-            .get_param("VENICE_MODELS_PATH")
-            .unwrap_or_else(|_| VENICE_DEFAULT_MODELS_PATH.to_string());
+        let config_builder = ProviderConfigBuilder::new(config, "VENICE");
+
+        let api_key = config_builder.get_api_key()?;
+        let host = config_builder.get_host(VENICE_DEFAULT_HOST);
+        let base_path = config_builder
+            .get_param("BASE_PATH", Some(VENICE_DEFAULT_BASE_PATH))
+            .unwrap_or_else(|| VENICE_DEFAULT_BASE_PATH.to_string());
+        let models_path = config_builder
+            .get_param("MODELS_PATH", Some(VENICE_DEFAULT_MODELS_PATH))
+            .unwrap_or_else(|| VENICE_DEFAULT_MODELS_PATH.to_string());
 
         // Ensure we only keep the bare model id internally
         model.model_name = strip_flags(&model.model_name).to_string();
 
-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        // Use shared client for better connection pooling
+        let client = get_shared_client();
+
+        // Configure retry settings
+        let retry_config = RetryConfig::default();
 
         let instance = Self {
             client,
@@ -114,106 +122,112 @@ impl VeniceProvider {
             models_path,
             api_key,
             model,
+            retry_config,
         };
 
         Ok(instance)
     }
 
     async fn post(&self, path: &str, body: &str) -> Result<Response, ProviderError> {
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let url = base_url
-            .join(path)
-            .map_err(|e| ProviderError::RequestFailed(format!("Failed to construct URL: {e}")))?;
+        let url = build_endpoint_url(&self.host, path)?;
+
+        // Build headers using HeaderBuilder
+        let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer).build();
+
         // Choose GET for models endpoint, POST otherwise
-        let method = if path.contains("models") {
-            tracing::debug!("Using GET method for models endpoint");
-            self.client.get(url.clone())
-        } else {
-            tracing::debug!("Using POST method for completions endpoint");
-            self.client.post(url.clone())
-        };
+        let is_models_endpoint = path.contains("models");
+
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            // Log the request details
+            tracing::debug!("Venice request URL: {}", url);
+            tracing::debug!("Venice request body: {}", body);
+
+            let response = if is_models_endpoint {
+                tracing::debug!("Using GET method for models endpoint");
+                self.client.get(url.clone())
+                    .headers(headers.clone())
+                    .send()
+                    .await?
+            } else {
+                tracing::debug!("Using POST method for completions endpoint");
+                self.client.post(url.clone())
+                    .headers(headers.clone())
+                    .body(body.to_string())
+                    .send()
+                    .await?
+            };
 
-        // Log the request details
-        tracing::debug!("Venice request URL: {}", url);
-        tracing::debug!("Venice request body: {}", body);
-
-        let response = method
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .header("Content-Type", "application/json")
-            .body(body.to_string())
-            .send()
-            .await?;
-
-        let status = response.status();
-        tracing::debug!("Venice response status: {}", status);
-
-        if !status.is_success() {
-            // Read response body for more details on error
-            let error_body = response.text().await.unwrap_or_default();
-
-            // Log full error response for debugging
-            tracing::debug!("Full Venice error response: {}", error_body);
-
-            // Try to parse the error response
-            if let Ok(json) = serde_json::from_str::<Value>(&error_body) {
-                // Print the full JSON error for better debugging
-                println!(
-                    "Venice API error response: {}",
-                    serde_json::to_string_pretty(&json).unwrap_or_else(|_| json.to_string())
-                );
-
-                // Check for tool support errors
-                if let Some(details) = json.get("details") {
-                    // Specifically look for tool support issues
-                    if let Some(tools) = details.get("tools") {
-                        if let Some(errors) = tools.get("_errors") {
-                            if errors.to_string().contains("not supported by this model") {
-                                let model_name = self.model.model_name.clone();
-                                return Err(ProviderError::RequestFailed(
-                                    format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
-                                ));
+            let status = response.status();
+            tracing::debug!("Venice response status: {}", status);
+
+            if !status.is_success() {
+                // Read response body for more details on error
+                let error_body = response.text().await.unwrap_or_default();
+
+                // Log full error response for debugging
+                tracing::debug!("Full Venice error response: {}", error_body);
+
+                // Try to parse the error response
+                if let Ok(json) = serde_json::from_str::<Value>(&error_body) {
+                    // Print the full JSON error for better debugging
+                    println!(
+                        "Venice API error response: {}",
+                        serde_json::to_string_pretty(&json).unwrap_or_else(|_| json.to_string())
+                    );
+
+                    // Check for tool support errors
+                    if let Some(details) = json.get("details") {
+                        // Specifically look for tool support issues
+                        if let Some(tools) = details.get("tools") {
+                            if let Some(errors) = tools.get("_errors") {
+                                if errors.to_string().contains("not supported by this model") {
+                                    let model_name = self.model.model_name.clone();
+                                    return Err(ProviderError::RequestFailed(
+                                        format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
+                                    ));
+                                }
                            }
                        }
                    }
-                }
-                // Check for specific error message in context.issues
-                if let Some(context) = json.get("context") {
-                    if let Some(issues) = context.get("issues") {
-                        if let Some(issues_array) = issues.as_array() {
-                            for issue in issues_array {
-                                if let Some(message) = issue.get("message").and_then(|m| m.as_str())
-                                {
-                                    if message.contains("tools is not supported by this model") {
-                                        let model_name = self.model.model_name.clone();
-                                        return Err(ProviderError::RequestFailed(
-                                            format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
-                                        ));
+                    // Check for specific error message in context.issues
+                    if let Some(context) = json.get("context") {
+                        if let Some(issues) = context.get("issues") {
+                            if let Some(issues_array) = issues.as_array() {
+                                for issue in issues_array {
+                                    if let Some(message) = issue.get("message").and_then(|m| m.as_str())
+                                    {
+                                        if message.contains("tools is not supported by this model") {
+                                            let model_name = self.model.model_name.clone();
+                                            return Err(ProviderError::RequestFailed(
+                                                format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
+                                            ));
+                                        }
                                    }
                                }
                            }
                        }
                    }
-                }
-                // General error extraction
-                if let Some(error_msg) = json.get("error").and_then(|e| e.as_str()) {
-                    return Err(ProviderError::RequestFailed(format!(
-                        "Venice API error: {}",
-                        error_msg
-                    )));
+                    // General error extraction
+                    if let Some(error_msg) = json.get("error").and_then(|e| e.as_str()) {
+                        return Err(ProviderError::RequestFailed(format!(
+                            "Venice API error: {}",
+                            error_msg
+                        )));
+                    }
                }
 
-            // Fallback for unparseable errors
-            return Err(ProviderError::RequestFailed(format!(
-                "Venice API request failed with status code {}",
-                status
-            )));
-        }
+                // Fallback for unparseable errors
+                return Err(ProviderError::RequestFailed(format!(
+                    "Venice API request failed with status code {}",
+                    status
+                )));
+            }
 
-        Ok(response)
+            Ok(response)
+        }).await
     }
 }
 
@@ -251,24 +265,8 @@ impl Provider for VeniceProvider {
     }
 
     async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
-        // Fetch supported models via Venice API
-        let base_url = url::Url::parse(&self.host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {}", e)))?;
-        let models_url = base_url.join(&self.models_path).map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct models URL: {}", e))
-        })?;
-        let response = self
-            .client
-            .get(models_url)
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .send()
-            .await?;
-        if !response.status().is_success() {
-            return Err(ProviderError::RequestFailed(format!(
-                "Venice API request failed with status {}",
-                response.status()
-            )));
-        }
+        // Fetch supported models via Venice API using the post method
+        let response = self.post(&self.models_path, "").await?;
         let body = response.text().await?;
         let json: serde_json::Value = serde_json::from_str(&body)
             .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse JSON: {}", e)))?;
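
Both providers above now delegate URL construction to `build_endpoint_url`, which has to reproduce the two fixes the deleted hand-rolled code made by hand: prefixing a missing `https://` scheme (snowflake.rs) and keeping a trailing slash so `Url::join` does not drop the last path segment of the base (xai.rs below). A sketch of that helper and of `HeaderBuilder` under those assumptions; only the call-site signatures are taken from the diff:

```rust
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use url::Url;

use crate::providers::errors::ProviderError;

/// Sketch: normalize the scheme and trailing slash, then join the path.
/// Without the trailing slash, Url::join would replace the last path
/// segment of the base URL (e.g. the "/v1" in "https://api.x.ai/v1").
pub fn build_endpoint_url(host: &str, path: &str) -> Result<Url, ProviderError> {
    let mut base = if host.starts_with("http://") || host.starts_with("https://") {
        host.to_string()
    } else {
        format!("https://{host}")
    };
    if !base.ends_with('/') {
        base.push('/');
    }
    let base = Url::parse(&base)
        .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
    base.join(path.trim_start_matches('/')).map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
    })
}

/// The diff only uses Bearer auth; other variants are speculative.
pub enum AuthType {
    Bearer,
}

/// Sketch of HeaderBuilder; the call sites only show
/// `HeaderBuilder::new(api_key, AuthType::Bearer).build()`.
pub struct HeaderBuilder {
    api_key: String,
    auth: AuthType,
}

impl HeaderBuilder {
    pub fn new(api_key: String, auth: AuthType) -> Self {
        Self { api_key, auth }
    }

    pub fn build(self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let value = match self.auth {
            AuthType::Bearer => format!("Bearer {}", self.api_key),
        };
        if let Ok(v) = HeaderValue::from_str(&value) {
            headers.insert(AUTHORIZATION, v);
        }
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers
    }
}
```
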
diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs
index 7e91a23f8b9e..eb46ed747a88 100644
--- a/crates/goose/src/providers/xai.rs
+++ b/crates/goose/src/providers/xai.rs
@@ -3,14 +3,17 @@ use crate::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
-use crate::providers::utils::get_model;
+use crate::providers::provider_common::{
+    build_endpoint_url, get_shared_client, retry_with_backoff, AuthType, HeaderBuilder,
+    ProviderConfigBuilder, RetryConfig,
+};
+use crate::providers::utils::{emit_debug_trace, get_model, handle_response_openai_compat};
 use anyhow::Result;
 use async_trait::async_trait;
 use mcp_core::Tool;
-use reqwest::{Client, StatusCode};
+use reqwest::Client;
 use serde_json::Value;
-use std::time::Duration;
-use url::Url;
+use std::sync::Arc;
 
 pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
 pub const XAI_DEFAULT_MODEL: &str = "grok-3";
@@ -39,10 +42,12 @@ pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";
 #[derive(serde::Serialize)]
 pub struct XaiProvider {
     #[serde(skip)]
-    client: Client,
+    client: Arc<Client>,
     host: String,
     api_key: String,
     model: ModelConfig,
+    #[serde(skip)]
+    retry_config: RetryConfig,
 }
 
 impl Default for XaiProvider {
@@ -55,72 +60,48 @@ impl XaiProvider {
 impl XaiProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("XAI_API_KEY")?;
-        let host: String = config
-            .get_param("XAI_HOST")
-            .unwrap_or_else(|_| XAI_API_HOST.to_string());
+        let config_builder = ProviderConfigBuilder::new(config, "XAI");
+
+        let api_key = config_builder.get_api_key()?;
+        let host = config_builder.get_host(XAI_API_HOST);
+
+        // Use shared client for better connection pooling
+        let client = get_shared_client();
 
-        let client = Client::builder()
-            .timeout(Duration::from_secs(600))
-            .build()?;
+        // Configure retry settings
+        let retry_config = RetryConfig::default();
 
         Ok(Self {
             client,
             host,
             api_key,
             model,
+            retry_config,
         })
     }
 
     async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
-        // Ensure the host ends with a slash for proper URL joining
-        let host = if self.host.ends_with('/') {
-            self.host.clone()
-        } else {
-            format!("{}/", self.host)
-        };
-        let base_url = Url::parse(&host)
-            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
-        let url = base_url.join("chat/completions").map_err(|e| {
-            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
-        })?;
+        let url = build_endpoint_url(&self.host, "chat/completions")?;
+
+        // Build headers using HeaderBuilder
+        let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer).build();
 
         tracing::debug!("xAI API URL: {}", url);
         tracing::debug!("xAI request model: {:?}", self.model.model_name);
 
-        let response = self
-            .client
-            .post(url)
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .json(&payload)
-            .send()
-            .await?;
-
-        let status = response.status();
-        let payload: Option<Value> = response.json().await.ok();
-
-        match status {
-            StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
-            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
-                Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
-                Status: {}. Response: {:?}", status, payload)))
-            }
-            StatusCode::PAYLOAD_TOO_LARGE => {
-                Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
-            }
-            StatusCode::TOO_MANY_REQUESTS => {
-                Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
-            }
-            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
-                Err(ProviderError::ServerError(format!("{:?}", payload)))
-            }
-            _ => {
-                tracing::debug!(
-                    "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
-                );
-                Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
-            }
-        }
+        // Use retry logic for resilience
+        retry_with_backoff(&self.retry_config, || async {
+            let response = self
+                .client
+                .post(url.clone())
+                .headers(headers.clone())
+                .json(&payload)
+                .send()
+                .await?;
+
+            handle_response_openai_compat(response).await
+        })
+        .await
     }
 }
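
The match block deleted above is the status-code mapping that `handle_response_openai_compat` centralizes for OpenAI-compatible APIs. A sketch of that shared helper, reconstructed from the removed arms (the real `utils` implementation may word the messages differently):

```rust
use reqwest::{Response, StatusCode};
use serde_json::Value;

use crate::providers::errors::ProviderError;

/// Map an OpenAI-compatible HTTP response onto ProviderError variants,
/// mirroring the per-provider match arms this refactor deletes.
pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
    let status = response.status();
    let payload: Option<Value> = response.json().await.ok();

    match status {
        StatusCode::OK => payload.ok_or_else(|| {
            ProviderError::RequestFailed("Response body is not valid JSON".to_string())
        }),
        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
            Err(ProviderError::Authentication(format!(
                "Authentication failed. Please ensure your API keys are valid and have the \
                 required permissions. Status: {status}. Response: {payload:?}"
            )))
        }
        StatusCode::PAYLOAD_TOO_LARGE => {
            Err(ProviderError::ContextLengthExceeded(format!("{payload:?}")))
        }
        StatusCode::TOO_MANY_REQUESTS => {
            Err(ProviderError::RateLimitExceeded(format!("{payload:?}")))
        }
        StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
            Err(ProviderError::ServerError(format!("{payload:?}")))
        }
        _ => Err(ProviderError::RequestFailed(format!(
            "Request failed with status: {status}"
        ))),
    }
}
```
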
@@ -175,7 +156,37 @@ impl Provider for XaiProvider {
             Err(e) => return Err(e),
         };
         let model = get_model(&response);
-        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
+        emit_debug_trace(&self.model, &payload, &response, &usage);
         Ok((message, ProviderUsage::new(model, usage)))
     }
+
+    /// Fetch supported models from xAI API; returns Err on failure, Ok(None) if no models found
+    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
+        let url = build_endpoint_url(&self.host, "models")?;
+
+        // Build headers using HeaderBuilder
+        let headers = HeaderBuilder::new(self.api_key.clone(), AuthType::Bearer).build();
+
+        let response = self.client.get(url).headers(headers).send().await?;
+        let json: serde_json::Value = response.json().await?;
+
+        if let Some(err_obj) = json.get("error") {
+            let msg = err_obj
+                .get("message")
+                .and_then(|v| v.as_str())
+                .unwrap_or("unknown error");
+            return Err(ProviderError::Authentication(msg.to_string()));
+        }
+
+        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
+            ProviderError::UsageError("Missing data field in JSON response".into())
+        })?;
+
+        let mut models: Vec<String> = data
+            .iter()
+            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
+            .collect();
+        models.sort();
+        Ok(Some(models))
+    }
 }
diff --git a/ui/desktop/src/utils/costDatabase.ts b/ui/desktop/src/utils/costDatabase.ts
index 82c684bc2adc..c8832a1c22a5 100644
--- a/ui/desktop/src/utils/costDatabase.ts
+++ b/ui/desktop/src/utils/costDatabase.ts
@@ -7,8 +7,175 @@ export interface ModelCostInfo {
   currency: string; // Currency symbol
 }
 
-// In-memory cache for current session only
-const sessionPricingCache = new Map<string, ModelCostInfo | null>();
+// In-memory cache for current model pricing only
+let currentModelPricing: {
+  provider: string;
+  model: string;
+  costInfo: ModelCostInfo | null;
+} | null = null;
+
+// Request batching to prevent duplicate API calls
+let pendingRequests = new Map<string, Promise<ModelCostInfo | null>>();
+
+// LocalStorage keys
+const PRICING_CACHE_KEY = 'goose_pricing_cache';
+const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp';
+const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models';
+const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds
+const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache
+
+interface PricingItem {
+  provider: string;
+  model: string;
+  input_token_cost: number;
+  output_token_cost: number;
+  currency: string;
+}
+
+interface PricingCacheData {
+  pricing: PricingItem[];
+  timestamp: number;
+}
+
+interface RecentlyUsedModel {
+  provider: string;
+  model: string;
+  lastUsed: number;
+}
+
+/**
+ * Get recently used models from localStorage
+ */
+function getRecentlyUsedModels(): RecentlyUsedModel[] {
+  try {
+    const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY);
+    return stored ? JSON.parse(stored) : [];
+  } catch (error) {
+    console.error('Error loading recently used models:', error);
+    return [];
+  }
+}
+
+/**
+ * Add a model to the recently used list
+ */
+function addToRecentlyUsed(provider: string, model: string): void {
+  try {
+    let recentModels = getRecentlyUsedModels();
+
+    // Remove existing entry if present
+    recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model));
+
+    // Add to front
+    recentModels.unshift({ provider, model, lastUsed: Date.now() });
+
+    // Keep only the most recent models
+    recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS);
+
+    localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels));
+  } catch (error) {
+    console.error('Error saving recently used models:', error);
+  }
+}
+
+/**
+ * Load pricing data from localStorage cache - only for recently used models
+ */
+function loadPricingFromLocalStorage(): PricingCacheData | null {
+  try {
+    const cached = localStorage.getItem(PRICING_CACHE_KEY);
+    const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY);
+
+    if (cached && timestamp) {
+      const cacheAge = Date.now() - parseInt(timestamp, 10);
+      if (cacheAge < CACHE_TTL_MS) {
+        const fullCache = JSON.parse(cached) as PricingCacheData;
+        const recentModels = getRecentlyUsedModels();
+
+        // Filter to only include recently used models
+        const filteredPricing = fullCache.pricing.filter((p) =>
+          recentModels.some((r) => r.provider === p.provider && r.model === p.model)
+        );
+
+        console.log(
+          `Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)`
+        );
+
+        return {
+          pricing: filteredPricing,
+          timestamp: fullCache.timestamp,
+        };
+      } else {
+        console.log('LocalStorage pricing cache expired');
+      }
+    }
+  } catch (error) {
+    console.error('Error loading pricing from localStorage:', error);
+  }
+  return null;
+}
+
+/**
+ * Save pricing data to localStorage - merge with existing data
+ */
+function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void {
+  try {
+    if (mergeWithExisting) {
+      // Load existing full cache
+      const existingCached = localStorage.getItem(PRICING_CACHE_KEY);
+      if (existingCached) {
+        const existingData = JSON.parse(existingCached) as PricingCacheData;
+
+        // Create a map of existing pricing for quick lookup
+        const pricingMap = new Map<string, PricingItem>();
+        existingData.pricing.forEach((p) => {
+          pricingMap.set(`${p.provider}/${p.model}`, p);
+        });
+
+        // Update with new data
+        data.pricing.forEach((p) => {
+          pricingMap.set(`${p.provider}/${p.model}`, p);
+        });
+
+        // Convert back to array
+        data = {
+          pricing: Array.from(pricingMap.values()),
+          timestamp: data.timestamp,
+        };
+      }
+    }
+
+    localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data));
+    localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString());
+    console.log(`Saved ${data.pricing.length} models to localStorage cache`);
+  } catch (error) {
+    console.error('Error saving pricing to localStorage:', error);
+  }
+}
+
+/**
+ * Clean up pricing cache to prevent excessive storage usage
+ */
+function cleanupPricingCache(): void {
+  try {
+    // Remove old cache entries if they exist
+    const oldKeys = [
+      'modelCosts',
+      'modelCostsTimestamp',
+      'goose_model_costs',
+      'goose_model_costs_timestamp',
+    ];
+
+    oldKeys.forEach((key) => {
+      if (localStorage.getItem(key)) {
+        localStorage.removeItem(key);
+        console.log(`Removed old cache key: ${key}`);
+      }
+    });
+  } catch (error) {
+    console.error('Error cleaning up pricing cache:', error);
+  }
+}
 
 /**
  * Fetch pricing data from backend for specific provider/model
@@ -92,21 +259,6 @@ async function fetchPricingForModel(
   return null;
 }
 
-/**
- * Initialize the cost database - no-op since we fetch on demand now
- */
-export async function initializeCostDatabase(): Promise<void> {
-  // Clear session cache on init
-  sessionPricingCache.clear();
-}
-
-/**
- * Update model costs from providers - no-op since we fetch on demand
- */
-export async function updateAllModelCosts(): Promise<void> {
-  // No-op - we fetch on demand now
-}
-
 /**
  * Parse OpenRouter model ID to extract provider and model
  * e.g., "anthropic/claude-sonnet-4" -> ["anthropic", "claude-sonnet-4"]
@@ -120,29 +272,123 @@ function parseOpenRouterModel(modelId: string): [string, string] | null {
 }
 
 /**
- * Get cost information for a specific model with session caching
+ * Initialize the cost database - only load commonly used models on startup
  */
-export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
-  const cacheKey = `${provider}/${model}`;
+export async function initializeCostDatabase(): Promise<void> {
+  try {
+    // Clean up any existing large caches first
+    cleanupPricingCache();
+
+    // First check if we have valid cached data
+    const cachedData = loadPricingFromLocalStorage();
+    if (cachedData && cachedData.pricing.length > 0) {
+      console.log('Using cached pricing data from localStorage');
+      return;
+    }
 
-  // Check session cache first
-  if (sessionPricingCache.has(cacheKey)) {
-    return sessionPricingCache.get(cacheKey) || null;
-  }
+    // List of commonly used models to pre-fetch
+    const commonModels = [
+      { provider: 'openai', model: 'gpt-4o' },
+      { provider: 'openai', model: 'gpt-4o-mini' },
+      { provider: 'openai', model: 'gpt-4-turbo' },
+      { provider: 'openai', model: 'gpt-4' },
+      { provider: 'openai', model: 'gpt-3.5-turbo' },
+      { provider: 'anthropic', model: 'claude-3-5-sonnet' },
+      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
+      { provider: 'anthropic', model: 'claude-3-opus' },
+      { provider: 'anthropic', model: 'claude-3-sonnet' },
+      { provider: 'anthropic', model: 'claude-3-haiku' },
+      { provider: 'google', model: 'gemini-1.5-pro' },
+      { provider: 'google', model: 'gemini-1.5-flash' },
+      { provider: 'deepseek', model: 'deepseek-chat' },
+      { provider: 'deepseek', model: 'deepseek-reasoner' },
+      { provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' },
+      { provider: 'meta-llama', model: 'llama-3.1-405b-instruct' },
+    ];
+
+    // Get recently used models
+    const recentModels = getRecentlyUsedModels();
+
+    // Combine common and recent models (deduplicated)
+    const modelsToFetch = new Map<string, { provider: string; model: string }>();
+
+    // Add common models
+    commonModels.forEach((m) => {
+      modelsToFetch.set(`${m.provider}/${m.model}`, m);
+    });
 
-  // For OpenRouter models, also check if we have cached data under the parsed provider/model
-  if (provider.toLowerCase() === 'openrouter') {
-    const parsed = parseOpenRouterModel(model);
-    if (parsed) {
-      const [parsedProvider, parsedModel] = parsed;
-      const parsedCacheKey = `${parsedProvider}/${parsedModel}`;
-      if (sessionPricingCache.has(parsedCacheKey)) {
-        const cachedData = sessionPricingCache.get(parsedCacheKey) || null;
-        // Also cache it under the original OpenRouter key for future lookups
-        sessionPricingCache.set(cacheKey, cachedData);
-        return cachedData;
-      }
-    }
+    // Add recent models
+    recentModels.forEach((m) => {
+      modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
+    });
+
+    console.log(`Initializing cost database with ${modelsToFetch.size} models...`);
+
+    // Fetch only the pricing we need
+    const apiUrl = getApiUrl('/config/pricing');
+    const secretKey = getSecretKey();
+
+    const headers: HeadersInit = { 'Content-Type': 'application/json' };
+    if (secretKey) {
+      headers['X-Secret-Key'] = secretKey;
+    }
+
+    const response = await fetch(apiUrl, {
+      method: 'POST',
+      headers,
+      body: JSON.stringify({
+        configured_only: false,
+        models: Array.from(modelsToFetch.values()), // Send specific models if API supports it
+      }),
+    });
+
+    if (!response.ok) {
+      console.error('Failed to fetch initial pricing data:', response.status);
+      return;
+    }
+
+    const data = await response.json();
+    console.log(`Fetched pricing for ${data.pricing?.length || 0} models`);
+
+    if (data.pricing && data.pricing.length > 0) {
+      // Filter to only the models we requested (in case API returns all)
+      const filteredPricing = data.pricing.filter((p: PricingItem) =>
+        modelsToFetch.has(`${p.provider}/${p.model}`)
+      );
+
+      // Save to localStorage
+      const cacheData: PricingCacheData = {
+        pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work
+        timestamp: Date.now(),
+      };
+      savePricingToLocalStorage(cacheData, false); // Don't merge on initial load
+    }
+  } catch (error) {
+    console.error('Error initializing cost database:', error);
+  }
+}
+
+/**
+ * Update model costs from providers - no longer needed
+ */
+export async function updateAllModelCosts(): Promise<void> {
+  // No-op - we fetch on demand now
+}
+
+/**
+ * Get cost information for a specific model with caching
+ */
+export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
+  // Track this model as recently used
+  addToRecentlyUsed(provider, model);
+
+  // Check if it's the same model we already have cached in memory
+  if (
+    currentModelPricing &&
+    currentModelPricing.provider === provider &&
+    currentModelPricing.model === model
+  ) {
+    return currentModelPricing.costInfo;
   }
 
   // For local/free providers, return zero cost immediately
@@ -153,11 +399,48 @@ export function getCostForModel(provider: string, model: string): ModelCostInfo
       output_token_cost: 0,
       currency: '$',
     };
-    sessionPricingCache.set(cacheKey, zeroCost);
+    currentModelPricing = { provider, model, costInfo: zeroCost };
     return zeroCost;
   }
 
-  // Need to fetch - return null and let component handle async fetch
+  // Check localStorage cache (which now only contains recently used models)
+  const cachedData = loadPricingFromLocalStorage();
+  if (cachedData) {
+    const pricing = cachedData.pricing.find((p) => {
+      const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();
+
+      // More flexible model matching - handle versioned models
+      let modelMatch = p.model === model;
+
+      // If exact match fails, try matching without version suffix
+      if (!modelMatch && model.includes('-20')) {
+        // Remove date suffix like -20241022
+        const modelWithoutDate = model.replace(/-20\d{6}$/, '');
+        modelMatch = p.model === modelWithoutDate;
+
+        // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
+        if (!modelMatch) {
+          const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
+          modelMatch = p.model === modelWithDots;
+        }
+      }
+
+      return providerMatch && modelMatch;
+    });
+
+    if (pricing) {
+      const costInfo = {
+        input_token_cost: pricing.input_token_cost,
+        output_token_cost: pricing.output_token_cost,
+        currency: pricing.currency || '$',
+      };
+      currentModelPricing = { provider, model, costInfo };
+      return costInfo;
+    }
+  }
+
+  // Need to fetch new pricing - return null for now
+  // The component will handle the async fetch
   return null;
 }
 
@@ -168,45 +451,82 @@ export async function fetchAndCachePricing(
   provider: string,
   model: string
 ): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
-  try {
-    const cacheKey = `${provider}/${model}`;
-    const costInfo = await fetchPricingForModel(provider, model);
-
-    // Cache the result in session cache under the original key
-    sessionPricingCache.set(cacheKey, costInfo);
-
-    // For OpenRouter models, also cache under the parsed provider/model key
-    // This helps with cross-referencing between frontend requests and backend responses
-    if (provider.toLowerCase() === 'openrouter') {
-      const parsed = parseOpenRouterModel(model);
-      if (parsed) {
-        const [parsedProvider, parsedModel] = parsed;
-        const parsedCacheKey = `${parsedProvider}/${parsedModel}`;
-        sessionPricingCache.set(parsedCacheKey, costInfo);
-      }
+  const key = `${provider}/${model}`;
+
+  // Check if request is already pending
+  if (pendingRequests.has(key)) {
+    console.log(`Request already pending for ${key}, waiting...`);
+    try {
+      const result = await pendingRequests.get(key);
+      return result ? { costInfo: result } : { costInfo: null, error: 'model_not_found' };
+    } catch (error) {
+      return null;
+    }
+  }
+
+  try {
+    // Create promise for batching
+    const promise = fetchPricingForModel(provider, model);
+    pendingRequests.set(key, promise);
+
+    const costInfo = await promise;
 
     if (costInfo) {
+      // Cache the result in memory
+      currentModelPricing = { provider, model, costInfo };
+
+      // Update localStorage cache with this new data
+      const cachedData = loadPricingFromLocalStorage();
+      if (cachedData) {
+        // Check if this model already exists in cache
+        const existingIndex = cachedData.pricing.findIndex(
+          (p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model
+        );
+
+        const newPricing = {
+          provider,
+          model,
+          input_token_cost: costInfo.input_token_cost,
+          output_token_cost: costInfo.output_token_cost,
+          currency: costInfo.currency,
+        };
+
+        if (existingIndex >= 0) {
+          // Update existing
+          cachedData.pricing[existingIndex] = newPricing;
+        } else {
+          // Add new
+          cachedData.pricing.push(newPricing);
+        }
+
+        // Save updated cache
+        savePricingToLocalStorage(cachedData);
+      }
+
       return { costInfo };
     } else {
-      // Model not found in pricing data
+      // Cache the null result in memory
+      currentModelPricing = { provider, model, costInfo: null };
+
+      // Check if the API call succeeded but model wasn't found
+      // We can determine this by checking if we got a response but no matching model
      return { costInfo: null, error: 'model_not_found' };
    }
  } catch (error) {
+    console.error('Error in fetchAndCachePricing:', error);
    // This is a real API/network error
    return null;
+  } finally {
+    // Always remove from pending
+    pendingRequests.delete(key);
  }
}

/**
- * Refresh pricing data from backend
+ * Refresh pricing data from backend - only refresh recently used models
 */
 export async function refreshPricing(): Promise<boolean> {
   try {
-    // Clear session cache to force re-fetch
-    sessionPricingCache.clear();
-
-    // The actual refresh happens on the backend when we call with configured_only: false
     const apiUrl = getApiUrl('/config/pricing');
     const secretKey = getSecretKey();
 
@@ -215,14 +535,64 @@ export async function refreshPricing(): Promise<boolean> {
       headers['X-Secret-Key'] = secretKey;
     }
 
+    // Get recently used models to refresh
+    const recentModels = getRecentlyUsedModels();
+
+    // Add some common models as well
+    const commonModels = [
+      { provider: 'openai', model: 'gpt-4o' },
+      { provider: 'openai', model: 'gpt-4o-mini' },
+      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
+      { provider: 'google', model: 'gemini-1.5-pro' },
+    ];
+
+    // Combine and deduplicate
+    const modelsToRefresh = new Map<string, { provider: string; model: string }>();
+
+    commonModels.forEach((m) => {
+      modelsToRefresh.set(`${m.provider}/${m.model}`, m);
+    });
+
+    recentModels.forEach((m) => {
+      modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
+    });
+
+    console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`);
+
     const response = await fetch(apiUrl, {
       method: 'POST',
       headers,
-      body: JSON.stringify({ configured_only: false }),
+      body: JSON.stringify({
+        configured_only: false,
+        models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it
+      }),
     });
 
-    return response.ok;
+    if (response.ok) {
+      const data = await response.json();
+
+      if (data.pricing && data.pricing.length > 0) {
+        // Filter to only the models we requested (in case API returns all)
+        const filteredPricing = data.pricing.filter((p: PricingItem) =>
+          modelsToRefresh.has(`${p.provider}/${p.model}`)
+        );
+
+        // Save fresh data to localStorage (merge with existing)
+        const cacheData: PricingCacheData = {
+          pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50),
+          timestamp: Date.now(),
+        };
+        savePricingToLocalStorage(cacheData, true); // Merge with existing
+      }
+
+      // Clear current memory cache to force re-fetch
+      currentModelPricing = null;
+      return true;
+    }
+
+    return false;
   } catch (error) {
+    console.error('Error refreshing pricing data:', error);
     return false;
   }
 }
@@ -233,7 +603,7 @@ declare global {
   interface Window {
     getCostForModel?: typeof getCostForModel;
     fetchAndCachePricing?: typeof fetchAndCachePricing;
     refreshPricing?: typeof refreshPricing;
-    sessionPricingCache?: typeof sessionPricingCache;
+    currentModelPricing?: typeof currentModelPricing;
   }
 }
 
@@ -241,5 +611,5 @@ if (process.env.NODE_ENV === 'development' || typeof window !== 'undefined') {
   window.getCostForModel = getCostForModel;
   window.fetchAndCachePricing = fetchAndCachePricing;
   window.refreshPricing = refreshPricing;
-  window.sessionPricingCache = sessionPricingCache;
+  window.currentModelPricing = currentModelPricing;
 }