71 changes: 71 additions & 0 deletions crates/goose/src/providers/custom_providers.rs
@@ -0,0 +1,71 @@
use super::base::ModelInfo;
use super::provider_registry::ProviderRegistry;
use crate::model::ModelConfig;
use crate::providers::ollama::OllamaProvider;
use crate::providers::openai::OpenAiProvider;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ProviderEngine {
    OpenAI,
    Ollama,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomProviderConfig {
    pub name: String,
    pub engine: ProviderEngine,
    pub display_name: String,
    pub description: Option<String>,
    pub api_key_env: String,
    pub base_url: String,
    pub models: Vec<ModelInfo>,
    // Optional fields for OpenAI-compatible providers
    pub headers: Option<HashMap<String, String>>,
    pub timeout_seconds: Option<u64>,
}

pub fn load_custom_providers(dir: &Path) -> Result<Vec<CustomProviderConfig>> {
    let mut configs = Vec::new();

    if !dir.exists() {
        return Ok(configs);
    }

    for entry in std::fs::read_dir(dir)? {
        let entry = entry?;
        let path = entry.path();

        if path.extension().and_then(|s| s.to_str()) == Some("json") {
            let content = std::fs::read_to_string(&path)?;
            let config: CustomProviderConfig = serde_json::from_str(&content)
                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e))?;

            configs.push(config);
        }
    }

    Ok(configs)
}

pub fn register_custom_providers(registry: &mut ProviderRegistry, dir: &Path) -> Result<()> {
    for config in load_custom_providers(dir)? {
        match config.engine {
            ProviderEngine::OpenAI => {
                registry.register(move |model: ModelConfig| {
                    OpenAiProvider::from_custom_config(model, config.clone())
                });
            }
            ProviderEngine::Ollama => {
                registry.register(move |model: ModelConfig| {
                    OllamaProvider::from_custom_config(model, config.clone())
                });
            }
        }
    }
    Ok(())
}
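For reference, a file dropped into the custom_providers directory is expected to deserialize into CustomProviderConfig. A minimal, hypothetical example as a test (not part of this PR): the models list is left empty because ModelInfo's fields are not shown in this diff, and all names and values are illustrative. The exact directory comes from choose_app_strategy in factory.rs.

#[cfg(test)]
mod custom_provider_config_example {
    use super::*;

    #[test]
    fn parses_an_openai_style_definition() {
        // Hypothetical file contents, e.g. <config_dir>/custom_providers/my_llm.json
        let json = r#"{
            "name": "my-llm",
            "engine": "openai",
            "display_name": "My LLM",
            "description": "Internal OpenAI-compatible gateway",
            "api_key_env": "MY_LLM_API_KEY",
            "base_url": "https://llm.internal.example/v1",
            "models": [],
            "headers": { "X-Team": "platform" },
            "timeout_seconds": 120
        }"#;

        let config: CustomProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config.engine, ProviderEngine::OpenAI));
        assert_eq!(config.api_key_env, "MY_LLM_API_KEY");
    }
}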
94 changes: 43 additions & 51 deletions crates/goose/src/providers/factory.rs
@@ -1,5 +1,8 @@
use once_cell::sync::Lazy;
use std::sync::Arc;

#[cfg(test)]
use super::errors::ProviderError;
use super::{
    anthropic::AnthropicProvider,
    azure::AzureProvider,
@@ -21,11 +24,12 @@ use super::{
    venice::VeniceProvider,
    xai::XaiProvider,
};
use crate::config::APP_STRATEGY;
use crate::model::ModelConfig;
use crate::providers::custom_providers::register_custom_providers;
use crate::providers::provider_registry::ProviderRegistry;
use anyhow::Result;

#[cfg(test)]
use super::errors::ProviderError;
use etcetera::{choose_app_strategy, AppStrategy};
#[cfg(test)]
use rmcp::model::Tool;

@@ -39,27 +43,38 @@ fn default_fallback_turns() -> usize {
    2
}

static REGISTRY: Lazy<ProviderRegistry> = Lazy::new(|| {
    let mut registry = ProviderRegistry::new();

    registry.register(|model| OpenAiProvider::from_env(model));
    registry.register(|model| AnthropicProvider::from_env(model));
    registry.register(|model| AzureProvider::from_env(model));
    registry.register(|model| BedrockProvider::from_env(model));
    registry.register(|model| ClaudeCodeProvider::from_env(model));
    registry.register(|model| DatabricksProvider::from_env(model));
    registry.register(|model| GcpVertexAIProvider::from_env(model));
    registry.register(|model| GeminiCliProvider::from_env(model));
    registry.register(|model| GoogleProvider::from_env(model));
    registry.register(|model| GroqProvider::from_env(model));
    registry.register(|model| LiteLLMProvider::from_env(model));
    registry.register(|model| OllamaProvider::from_env(model));
    registry.register(|model| OpenRouterProvider::from_env(model));
    registry.register(|model| SageMakerTgiProvider::from_env(model));
    registry.register(|model| VeniceProvider::from_env(model));
    registry.register(|model| SnowflakeProvider::from_env(model));
    registry.register(|model| XaiProvider::from_env(model));

    let config_dir = choose_app_strategy(APP_STRATEGY.clone())
        .expect("goose requires a home dir")
        .config_dir();

    register_custom_providers(&mut registry, &config_dir.join("custom_providers"));

    registry
});

pub fn providers() -> Vec<ProviderMetadata> {
    vec![
        AnthropicProvider::metadata(),
        AzureProvider::metadata(),
        BedrockProvider::metadata(),
        ClaudeCodeProvider::metadata(),
        DatabricksProvider::metadata(),
        GcpVertexAIProvider::metadata(),
        GeminiCliProvider::metadata(),
        // GithubCopilotProvider::metadata(),
        GoogleProvider::metadata(),
        GroqProvider::metadata(),
        LiteLLMProvider::metadata(),
        OllamaProvider::metadata(),
        OpenAiProvider::metadata(),
        OpenRouterProvider::metadata(),
        SageMakerTgiProvider::metadata(),
        VeniceProvider::metadata(),
        SnowflakeProvider::metadata(),
        XaiProvider::metadata(),
    ]
    REGISTRY.all_metadata()
}

pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
@@ -71,7 +86,9 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {

        return create_lead_worker_from_env(name, &model, &lead_model_name);
    }
    create_provider(name, model)

    // Default: create regular provider using registry
    REGISTRY.create(name, model)
}

/// Create a lead/worker provider from environment variables
@@ -133,8 +150,8 @@ fn create_lead_worker_from_env(
    };

    // Create the providers
    let lead_provider = create_provider(&lead_provider_name, lead_model_config)?;
    let worker_provider = create_provider(default_provider_name, worker_model_config)?;
    let lead_provider = REGISTRY.create(&lead_provider_name, lead_model_config)?;
    let worker_provider = REGISTRY.create(default_provider_name, worker_model_config)?;

    // Create the lead/worker provider with configured settings
    Ok(Arc::new(LeadWorkerProvider::new_with_settings(
@@ -146,31 +163,6 @@
    )))
}

fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
    // We use Arc instead of Box to be able to clone for multiple async tasks
    match name {
        "anthropic" => Ok(Arc::new(AnthropicProvider::from_env(model)?)),
        "aws_bedrock" => Ok(Arc::new(BedrockProvider::from_env(model)?)),
        "azure_openai" => Ok(Arc::new(AzureProvider::from_env(model)?)),
        "claude-code" => Ok(Arc::new(ClaudeCodeProvider::from_env(model)?)),
        "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
        "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
        "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)),
        // "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
        "google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
        "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
        "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)),
        "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
        "openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
        "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
        "sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)),
        "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
        "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
        "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
        _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
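One behavioral note on the providers() change above: the hand-written vec![..] had a fixed ordering, while REGISTRY.all_metadata() collects values out of a HashMap, so the order of the returned metadata is unspecified. If callers want a stable listing, a small sort suffices (a sketch, assuming ProviderMetadata keeps its name field):

// Hypothetical caller-side fix-up for display purposes only.
let mut metas = providers();
metas.sort_by(|a, b| a.name.cmp(&b.name));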
2 changes: 2 additions & 0 deletions crates/goose/src/providers/mod.rs
@@ -4,6 +4,7 @@ pub mod azureauth;
pub mod base;
pub mod bedrock;
pub mod claude_code;
mod custom_providers;
pub mod databricks;
pub mod embedding;
pub mod errors;
@@ -22,6 +23,7 @@ pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod pricing;
mod provider_registry;
pub mod sagemaker_tgi;
pub mod snowflake;
pub mod testprovider;
12 changes: 12 additions & 0 deletions crates/goose/src/providers/ollama.rs
@@ -4,6 +4,7 @@ use super::utils::{get_model, handle_response_openai_compat};
use crate::impl_provider_default;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::custom_providers::CustomProviderConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use crate::utils::safe_truncate;
use anyhow::Result;
@@ -52,6 +53,17 @@ impl OllamaProvider {
        })
    }

    pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result<Self> {
        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT));
        let client = Client::builder().timeout(timeout).build()?;

        Ok(Self {
            client,
            host: config.base_url,
            model,
        })
    }

    /// Get the base URL for Ollama API calls
    fn get_base_url(&self) -> Result<Url, ProviderError> {
        // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
29 changes: 29 additions & 0 deletions crates/goose/src/providers/openai.rs
@@ -21,6 +21,7 @@ use crate::impl_provider_default;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::MessageStream;
use crate::providers::custom_providers::CustomProviderConfig;
use crate::providers::formats::openai::response_to_streaming_message;
use crate::providers::utils::handle_status_openai_compat;
use rmcp::model::Tool;
@@ -87,6 +88,34 @@ impl OpenAiProvider {
        })
    }

    pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result<Self> {
        let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(600));
        let client = Client::builder().timeout(timeout).build()?;

        let api_key = std::env::var(&config.api_key_env)
            .map_err(|_| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?;

        let url = url::Url::parse(&config.base_url)?;
        let host = format!(
            "{}://{}:{}",
            url.scheme(),
            url.host_str().unwrap_or(""),
            url.port_or_known_default().unwrap_or(443)
        );
        let base_path = url.path().trim_start_matches('/').to_string();

        Ok(Self {
            client,
            host,
            base_path,
            api_key,
            organization: None,
            project: None,
            model,
            custom_headers: config.headers,
        })
    }

    /// Helper function to add OpenAI-specific headers to a request
    fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        // Add organization header if present
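The host/base_path split performed by from_custom_config above can be illustrated with the url crate directly. A sketch (the hostname is made up): "https://llm.internal.example/v1" yields host "https://llm.internal.example:443" and base_path "v1", since https defaults to port 443 via port_or_known_default().

// Illustration only; mirrors the parsing done in from_custom_config.
let url = url::Url::parse("https://llm.internal.example/v1").expect("valid url");
assert_eq!(url.scheme(), "https");
assert_eq!(url.host_str(), Some("llm.internal.example"));
assert_eq!(url.port_or_known_default(), Some(443));
assert_eq!(url.path().trim_start_matches('/'), "v1");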
55 changes: 55 additions & 0 deletions crates/goose/src/providers/provider_registry.rs
@@ -0,0 +1,55 @@
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;

use super::base::{Provider, ProviderMetadata};
use crate::model::ModelConfig;

type ProviderConstructor = Box<dyn Fn(ModelConfig) -> Result<Arc<dyn Provider>> + Send + Sync>;

struct ProviderEntry {
    metadata: ProviderMetadata,
    constructor: ProviderConstructor,
}

pub struct ProviderRegistry {
    entries: HashMap<String, ProviderEntry>,
}

impl ProviderRegistry {
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    pub fn register<P, F>(&mut self, constructor: F)
    where
        P: Provider + 'static,
        F: Fn(ModelConfig) -> Result<P> + Send + Sync + 'static,
    {
        let metadata = P::metadata();
        let name = metadata.name.clone();

        self.entries.insert(
            name,
            ProviderEntry {
                metadata,
                constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))),
            },
        );
    }

    pub fn create(&self, name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
        let entry = self
            .entries
            .get(name)
            .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?;

        (entry.constructor)(model)
    }

    pub fn all_metadata(&self) -> Vec<ProviderMetadata> {
        self.entries.values().map(|e| e.metadata.clone()).collect()
    }
}
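A quick usage sketch of the registry on its own, mirroring how factory.rs wires it up. This is not part of the PR: "openai" is assumed to equal OpenAiProvider::metadata().name, and model_config stands in for a ModelConfig built elsewhere inside the providers module (where this private type is visible).

// Register a constructor keyed by the provider's metadata name, then build by name.
let mut registry = ProviderRegistry::new();
registry.register(|model| OpenAiProvider::from_env(model));
let provider = registry.create("openai", model_config)?;
// Enumerate whatever has been registered, e.g. for a settings UI.
let names: Vec<String> = registry.all_metadata().into_iter().map(|m| m.name).collect();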