From 73aa3e77f6a1173eddbe3728b379e3e42682959e Mon Sep 17 00:00:00 2001 From: Douwe Osinga Date: Mon, 4 Aug 2025 15:42:57 +0200 Subject: [PATCH 1/3] Custom providers infra --- .../goose/src/providers/custom_providers.rs | 71 ++++++++++++++ crates/goose/src/providers/factory.rs | 94 +++++++++---------- crates/goose/src/providers/mod.rs | 2 + crates/goose/src/providers/ollama.rs | 15 +++ crates/goose/src/providers/openai.rs | 32 +++++++ .../goose/src/providers/provider_registry.rs | 54 +++++++++++ 6 files changed, 217 insertions(+), 51 deletions(-) create mode 100644 crates/goose/src/providers/custom_providers.rs create mode 100644 crates/goose/src/providers/provider_registry.rs diff --git a/crates/goose/src/providers/custom_providers.rs b/crates/goose/src/providers/custom_providers.rs new file mode 100644 index 000000000000..f671771c0d17 --- /dev/null +++ b/crates/goose/src/providers/custom_providers.rs @@ -0,0 +1,71 @@ +use super::base::ModelInfo; +use super::provider_registry::ProviderRegistry; +use crate::model::ModelConfig; +use crate::providers::ollama::OllamaProvider; +use crate::providers::openai::OpenAiProvider; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ProviderEngine { + OpenAI, + Ollama, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct CustomProviderConfig { + pub name: String, + pub engine: ProviderEngine, + pub display_name: String, + pub description: Option, + pub api_key_env: String, + pub base_url: String, + pub models: Vec, + // Optional fields for OpenAI-compatible providers + pub headers: Option>, + pub timeout_seconds: Option, +} + +pub fn load_custom_providers(dir: &Path) -> Result> { + let mut configs = Vec::new(); + + if !dir.exists() { + return Ok(configs); + } + + for entry in std::fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + + if path.extension().and_then(|s| s.to_str()) == Some("json") { + let content = std::fs::read_to_string(&path)?; + let config: CustomProviderConfig = serde_json::from_str(&content) + .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e))?; + + configs.push(config); + } + } + + Ok(configs) +} + +pub fn register_custom_providers(registry: &mut ProviderRegistry, dir: &Path) -> Result<()> { + for config in load_custom_providers(dir)? { + match config.engine { + ProviderEngine::OpenAI => { + registry.register(move |model: ModelConfig| { + OpenAiProvider::from_custom_config(model, config) + }); + } + ProviderEngine::Ollama => { + registry.register(move |model: ModelConfig| { + OllamaProvider::from_custom_config(model, config) + }); + } + } + } + Ok(()) +} diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index b0cb696b71bf..1e122de2dce4 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,5 +1,8 @@ +use once_cell::sync::Lazy; use std::sync::Arc; +#[cfg(test)] +use super::errors::ProviderError; use super::{ anthropic::AnthropicProvider, azure::AzureProvider, @@ -21,11 +24,12 @@ use super::{ venice::VeniceProvider, xai::XaiProvider, }; +use crate::config::APP_STRATEGY; use crate::model::ModelConfig; +use crate::providers::custom_providers::register_custom_providers; +use crate::providers::provider_registry::ProviderRegistry; use anyhow::Result; - -#[cfg(test)] -use super::errors::ProviderError; +use etcetera::{choose_app_strategy, AppStrategy}; #[cfg(test)] use rmcp::model::Tool; @@ -39,27 +43,38 @@ fn default_fallback_turns() -> usize { 2 } +static REGISTRY: Lazy = Lazy::new(|| { + let mut registry = ProviderRegistry::new(); + + registry.register(|model| OpenAiProvider::from_env(model)); + registry.register(|model| AnthropicProvider::from_env(model)); + registry.register(|model| AzureProvider::from_env(model)); + registry.register(|model| BedrockProvider::from_env(model)); + registry.register(|model| ClaudeCodeProvider::from_env(model)); + registry.register(|model| DatabricksProvider::from_env(model)); + registry.register(|model| GcpVertexAIProvider::from_env(model)); + registry.register(|model| GeminiCliProvider::from_env(model)); + registry.register(|model| GoogleProvider::from_env(model)); + registry.register(|model| GroqProvider::from_env(model)); + registry.register(|model| LiteLLMProvider::from_env(model)); + registry.register(|model| OllamaProvider::from_env(model)); + registry.register(|model| OpenRouterProvider::from_env(model)); + registry.register(|model| SageMakerTgiProvider::from_env(model)); + registry.register(|model| VeniceProvider::from_env(model)); + registry.register(|model| SnowflakeProvider::from_env(model)); + registry.register(|model| XaiProvider::from_env(model)); + + let config_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + + register_custom_providers(&mut registry, &config_dir.join("custom_providers")); + + registry +}); + pub fn providers() -> Vec { - vec![ - AnthropicProvider::metadata(), - AzureProvider::metadata(), - BedrockProvider::metadata(), - ClaudeCodeProvider::metadata(), - DatabricksProvider::metadata(), - GcpVertexAIProvider::metadata(), - GeminiCliProvider::metadata(), - // GithubCopilotProvider::metadata(), - GoogleProvider::metadata(), - GroqProvider::metadata(), - LiteLLMProvider::metadata(), - OllamaProvider::metadata(), - OpenAiProvider::metadata(), - OpenRouterProvider::metadata(), - SageMakerTgiProvider::metadata(), - VeniceProvider::metadata(), - SnowflakeProvider::metadata(), - XaiProvider::metadata(), - ] + REGISTRY.all_metadata() } pub fn create(name: &str, model: ModelConfig) -> Result> { @@ -71,7 +86,9 @@ pub fn create(name: &str, model: ModelConfig) -> Result> { return create_lead_worker_from_env(name, &model, &lead_model_name); } - create_provider(name, model) + + // Default: create regular provider using registry + REGISTRY.create(name, model) } /// Create a lead/worker provider from environment variables @@ -133,8 +150,8 @@ fn create_lead_worker_from_env( }; // Create the providers - let lead_provider = create_provider(&lead_provider_name, lead_model_config)?; - let worker_provider = create_provider(default_provider_name, worker_model_config)?; + let lead_provider = REGISTRY.create(&lead_provider_name, lead_model_config)?; + let worker_provider = REGISTRY.create(default_provider_name, worker_model_config)?; // Create the lead/worker provider with configured settings Ok(Arc::new(LeadWorkerProvider::new_with_settings( @@ -146,31 +163,6 @@ fn create_lead_worker_from_env( ))) } -fn create_provider(name: &str, model: ModelConfig) -> Result> { - // We use Arc instead of Box to be able to clone for multiple async tasks - match name { - "anthropic" => Ok(Arc::new(AnthropicProvider::from_env(model)?)), - "aws_bedrock" => Ok(Arc::new(BedrockProvider::from_env(model)?)), - "azure_openai" => Ok(Arc::new(AzureProvider::from_env(model)?)), - "claude-code" => Ok(Arc::new(ClaudeCodeProvider::from_env(model)?)), - "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)), - "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)), - "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)), - // "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)), - "google" => Ok(Arc::new(GoogleProvider::from_env(model)?)), - "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)), - "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)), - "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)), - "openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)), - "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)), - "sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)), - "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)), - "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)), - "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)), - _ => Err(anyhow::anyhow!("Unknown provider: {}", name)), - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 38c810d4d171..24b8e4412891 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -4,6 +4,7 @@ pub mod azureauth; pub mod base; pub mod bedrock; pub mod claude_code; +mod custom_providers; pub mod databricks; pub mod embedding; pub mod errors; @@ -22,6 +23,7 @@ pub mod ollama; pub mod openai; pub mod openrouter; pub mod pricing; +mod provider_registry; pub mod sagemaker_tgi; pub mod snowflake; pub mod testprovider; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 1fa9300a2457..fc37e641cd49 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -4,6 +4,7 @@ use super::utils::{get_model, handle_response_openai_compat}; use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::custom_providers::CustomProviderConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use crate::utils::safe_truncate; use anyhow::Result; @@ -52,6 +53,20 @@ impl OllamaProvider { }) } + pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + use reqwest::Client; + use std::time::Duration; + + let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); + let client = Client::builder().timeout(timeout).build()?; + + Ok(Self { + client, + host: config.base_url, + model, + }) + } + /// Get the base URL for Ollama API calls fn get_base_url(&self) -> Result { // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index e57e9ae46286..2081aef58b33 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -21,6 +21,7 @@ use crate::impl_provider_default; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::MessageStream; +use crate::providers::custom_providers::CustomProviderConfig; use crate::providers::formats::openai::response_to_streaming_message; use crate::providers::utils::handle_status_openai_compat; use rmcp::model::Tool; @@ -87,6 +88,37 @@ impl OpenAiProvider { }) } + pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + use reqwest::Client; + use std::time::Duration; + + let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(600)); + let client = Client::builder().timeout(timeout).build()?; + + let api_key = std::env::var(&config.api_key_env) + .map_err(|_| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?; + + let url = url::Url::parse(&config.base_url)?; + let host = format!( + "{}://{}:{}", + url.scheme(), + url.host_str().unwrap_or(""), + url.port_or_known_default().unwrap_or(443) + ); + let base_path = url.path().trim_start_matches('/').to_string(); + + Ok(Self { + client, + host, + base_path, + api_key, + organization: None, + project: None, + model, + custom_headers: config.headers, + }) + } + /// Helper function to add OpenAI-specific headers to a request fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { // Add organization header if present diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs new file mode 100644 index 000000000000..d329745a943a --- /dev/null +++ b/crates/goose/src/providers/provider_registry.rs @@ -0,0 +1,54 @@ +use anyhow::Result; +use std::collections::HashMap; +use std::sync::Arc; + +use super::base::{Provider, ProviderMetadata}; +use crate::model::ModelConfig; + +type ProviderConstructor = fn(ModelConfig) -> Result>; + +struct ProviderEntry { + metadata: ProviderMetadata, + constructor: ProviderConstructor, +} + +pub struct ProviderRegistry { + entries: HashMap, +} + +impl ProviderRegistry { + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + pub fn register

(&mut self, constructor: fn(ModelConfig) -> Result

) + where + P: Provider + 'static, + { + let metadata = P::metadata(); + let name = metadata.name.clone(); + + self.entries.insert( + name, + ProviderEntry { + metadata, + constructor: move |model| Ok(Arc::new(constructor(model)?)), + }, + ); + } + + pub fn create(&self, name: &str, model: ModelConfig) -> Result> { + let entry = self + .entries + .get(name) + .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?; + + (entry.constructor)(model) + } + + pub fn all_metadata(&self) -> Vec { + self.entries.values().map(|e| e.metadata.clone()).collect() + } +} From 8cb97c188a0ffc7103a1dfa9ac1e980977c428eb Mon Sep 17 00:00:00 2001 From: Douwe Osinga Date: Mon, 4 Aug 2025 19:44:07 +0200 Subject: [PATCH 2/3] Small change --- crates/goose/src/providers/ollama.rs | 3 --- crates/goose/src/providers/openai.rs | 3 --- crates/goose/src/providers/provider_registry.rs | 4 ++-- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index fc37e641cd49..239e365c32a6 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -54,9 +54,6 @@ impl OllamaProvider { } pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { - use reqwest::Client; - use std::time::Duration; - let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); let client = Client::builder().timeout(timeout).build()?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 2081aef58b33..f5d315c4c7f4 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -89,9 +89,6 @@ impl OpenAiProvider { } pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { - use reqwest::Client; - use std::time::Duration; - let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(600)); let client = Client::builder().timeout(timeout).build()?; diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index d329745a943a..e83703378639 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use super::base::{Provider, ProviderMetadata}; use crate::model::ModelConfig; -type ProviderConstructor = fn(ModelConfig) -> Result>; +type ProviderConstructor = Box Result> + Send + Sync>; struct ProviderEntry { metadata: ProviderMetadata, @@ -34,7 +34,7 @@ impl ProviderRegistry { name, ProviderEntry { metadata, - constructor: move |model| Ok(Arc::new(constructor(model)?)), + constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))), }, ); } From bbf3ec394790993ac8b01ddfcc5b247252b02d73 Mon Sep 17 00:00:00 2001 From: Douwe Osinga Date: Mon, 4 Aug 2025 23:12:59 +0200 Subject: [PATCH 3/3] fix it for real --- crates/goose/src/providers/custom_providers.rs | 6 +++--- crates/goose/src/providers/provider_registry.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/providers/custom_providers.rs b/crates/goose/src/providers/custom_providers.rs index f671771c0d17..5a9e29a4b9b8 100644 --- a/crates/goose/src/providers/custom_providers.rs +++ b/crates/goose/src/providers/custom_providers.rs @@ -15,7 +15,7 @@ pub enum ProviderEngine { Ollama, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct CustomProviderConfig { pub name: String, pub engine: ProviderEngine, @@ -57,12 +57,12 @@ pub fn register_custom_providers(registry: &mut ProviderRegistry, dir: &Path) -> match config.engine { ProviderEngine::OpenAI => { registry.register(move |model: ModelConfig| { - OpenAiProvider::from_custom_config(model, config) + OpenAiProvider::from_custom_config(model, config.clone()) }); } ProviderEngine::Ollama => { registry.register(move |model: ModelConfig| { - OllamaProvider::from_custom_config(model, config) + OllamaProvider::from_custom_config(model, config.clone()) }); } } diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index e83703378639..d9fcc5d11d31 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -23,9 +23,10 @@ impl ProviderRegistry { } } - pub fn register

(&mut self, constructor: fn(ModelConfig) -> Result

) + pub fn register(&mut self, constructor: F) where P: Provider + 'static, + F: Fn(ModelConfig) -> Result

+ Send + Sync + 'static, { let metadata = P::metadata(); let name = metadata.name.clone();