diff --git a/Cargo.lock b/Cargo.lock index 61beaa6f1db1..c0632ed47090 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3423,6 +3423,7 @@ dependencies = [ "tracing", "tracing-appender", "tracing-subscriber", + "uuid", "webbrowser 1.0.4", "winapi", ] @@ -3514,6 +3515,7 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "utoipa", + "uuid", ] [[package]] diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 89b4b0e54650..0986ec3219e0 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -47,6 +47,7 @@ shlex = "1.3.0" async-trait = "0.1.86" base64 = "0.22.1" regex = "1.11.1" +uuid = { version = "1.11", features = ["v4"] } nix = { version = "0.30.1", features = ["process", "signal"] } tar = "0.4" # Web server dependencies diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index ff218f408e02..000f71ec64de 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -1,5 +1,6 @@ use cliclack::spinner; use console::style; +use etcetera::{choose_app_strategy, AppStrategy}; use goose::agents::extension::ToolInfo; use goose::agents::extension_manager::get_parameter_names; use goose::agents::platform_tools::{ @@ -7,6 +8,8 @@ use goose::agents::platform_tools::{ }; use goose::agents::Agent; use goose::agents::{extension::Envs, ExtensionConfig}; +use goose::config::base::APP_STRATEGY; +use goose::config::custom_providers::{CustomProviderConfig, ProviderEngine}; use goose::config::extensions::name_to_key; use goose::config::permission::PermissionLevel; use goose::config::{ @@ -14,6 +17,7 @@ use goose::config::{ PermissionManager, }; use goose::conversation::message::Message; +use goose::providers::base::ModelInfo; use goose::providers::{create, providers}; use rmcp::model::{Tool, ToolAnnotations}; use rmcp::object; @@ -221,6 +225,11 @@ pub async fn handle_configure() -> Result<(), Box> { "Configure Providers", "Change provider or update credentials", ) + .item( + "custom_providers", + "Custom Providers", + "Add OpenAI or Anthropic compatible APIs", + ) .item("add", "Add Extension", "Connect to a new extension") .item( "toggle", @@ -241,6 +250,7 @@ pub async fn handle_configure() -> Result<(), Box> { "remove" => remove_extension_dialog(), "settings" => configure_settings_dialog().await.and(Ok(())), "providers" => configure_provider_dialog().await.and(Ok(())), + "custom_providers" => configure_custom_provider_dialog(), _ => unreachable!(), } } @@ -1650,3 +1660,169 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { Ok(()) } + +pub fn configure_custom_provider_dialog() -> Result<(), Box> { + let action = cliclack::select("What would you like to do?") + .item( + "add", + "Add A Custom Provider", + "Add a new OpenAI/Anthropic compatible Provider", + ) + .item( + "remove", + "Remove Custom Provider", + "Remove an existing custom provider", + ) + .interact()?; + + match action { + "add" => { + let provider_type = cliclack::select("What type of API is this?") + .item( + "openai_compatible", + "OpenAI Compatible", + "Uses OpenAI API format", + ) + .item( + "anthropic_compatible", + "Anthropic Compatible", + "Uses Anthropic API format", + ) + .item( + "ollama_compatible", + "Ollama Compatible", + "Uses Ollama API format", + ) + .interact()?; + + let display_name: String = cliclack::input("What should we call this provider?") + .placeholder("Your Provider Name") + .validate(|input: &String| { + if input.is_empty() { + Err("Please enter a name") + } else { + Ok(()) + } + }) + .interact()?; + + let api_url: String = cliclack::input("Provider API URL:") + .placeholder("https://api.example.com/v1/messages") + .validate(|input: &String| { + if !input.starts_with("http://") && !input.starts_with("https://") { + Err("Inputed URL must start with either http:// or https://") + } else { + Ok(()) + } + }) + .interact()?; + + let api_key: String = cliclack::password("API key:").mask('▪').interact()?; + + let models_input: String = cliclack::input("Available models (seperate with commas):") + .placeholder("model-a, model-b, model-c") + .validate(|input: &String| { + if input.trim().is_empty() { + Err("Please enter at least one model name") + } else { + Ok(()) + } + }) + .interact()?; + + let models: Vec = models_input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + // Generate name from display name + let id = format!("custom_{}", display_name.to_lowercase().replace(' ', "_")); + + // api-key -> keyring + let config = Config::global(); + let api_key_name = format!("{}_API_KEY", id.to_uppercase()); + config.set_secret(&api_key_name, Value::String(api_key))?; + + let display_name_clone = display_name.clone(); + + let model_infos: Vec = models + .iter() + .map(|name| ModelInfo::new(name.clone(), 128000)) + .collect(); + + // create final provider config + let provider_config = CustomProviderConfig { + name: id.clone(), + engine: match provider_type { + "openai_compatible" => ProviderEngine::OpenAI, + "anthropic_compatible" => ProviderEngine::Anthropic, + "ollama_compatible" => ProviderEngine::Ollama, + _ => unreachable!(), + }, + display_name: display_name_clone, + description: Some(format!("Custom {} provider", display_name)), + api_key_env: api_key_name, + base_url: api_url, + models: model_infos, + headers: None, + timeout_seconds: None, + }; + + let config_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + let custom_providers_dir = config_dir.join("custom_providers"); + std::fs::create_dir_all(&custom_providers_dir)?; + + let json_content = serde_json::to_string_pretty(&provider_config)?; + let file_path = custom_providers_dir.join(format!("{}.json", id)); + std::fs::write(file_path, json_content)?; + + cliclack::outro(format!("Custom provider added: {}", display_name))?; + } + "remove" => { + // load custom providers from JSON files + let config_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + let custom_providers_dir = config_dir.join("custom_providers"); + + let custom_providers = if custom_providers_dir.exists() { + goose::config::custom_providers::load_custom_providers(&custom_providers_dir)? + } else { + Vec::new() + }; + if custom_providers.is_empty() { + cliclack::outro("No custom providers added just yet.")?; + return Ok(()); + } + + let provider_items: Vec<_> = custom_providers + .iter() + .map(|p| (&p.name, &p.display_name, "Custom provider")) + .collect(); + + let selected_id = cliclack::select("Which custom provider would you like to remove?") + .items(&provider_items) + .interact()?; + + // TODO: remove api-key from keyring + + let config = Config::global(); + let api_key_name = format!("{}_API_KEY", selected_id.to_uppercase()); + let _ = config.delete_secret(&api_key_name); + + // remove json file + let file_path = custom_providers_dir.join(format!("{}.json", selected_id)); + if file_path.exists() { + std::fs::remove_file(file_path)?; + } + + cliclack::outro(format!("Removed custom provider: {}", selected_id))?; + } + _ => unreachable!(), + } + + Ok(()) +} diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index 555d64f9a308..7abfdd37f2e8 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -40,6 +40,7 @@ serde_yaml = "0.9.34" utoipa = { version = "4.1", features = ["axum_extras", "chrono"] } reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking", "multipart"], default-features = false } tokio-util = "0.7.15" +uuid = { version = "1.11", features = ["v4"] } [[bin]] name = "goosed" diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index d1c0305239c0..875c28d7dac5 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -367,6 +367,8 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::config_management::read_all_config, super::routes::config_management::providers, super::routes::config_management::upsert_permissions, + super::routes::config_management::create_custom_provider, + super::routes::config_management::remove_custom_provider, super::routes::agent::get_tools, super::routes::agent::add_sub_recipes, super::routes::reply::confirm_permission, @@ -397,6 +399,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::config_management::ExtensionQuery, super::routes::config_management::ToolPermission, super::routes::config_management::UpsertPermissionsQuery, + super::routes::config_management::CreateCustomProviderRequest, super::routes::reply::PermissionConfirmationRequest, super::routes::context::ContextManageRequest, super::routes::context::ContextManageResponse, diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 8ce58b00b588..36779da59f77 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -241,7 +241,7 @@ async fn update_agent_provider( let agent = state .get_agent() .await - .map_err(|_| StatusCode::PRECONDITION_FAILED)?; + .map_err(|_e| StatusCode::PRECONDITION_FAILED)?; let config = Config::global(); let model = match payload @@ -259,7 +259,7 @@ async fn update_agent_provider( agent .update_provider(new_provider) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(StatusCode::OK) } diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 9dc2d0eb912f..d450cc69b573 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -8,7 +8,6 @@ use axum::{ }; use etcetera::{choose_app_strategy, AppStrategy}; use goose::config::APP_STRATEGY; -use goose::config::{extensions::name_to_key, PermissionManager}; use goose::config::{Config, ConfigError}; use goose::config::{ExtensionConfigManager, ExtensionEntry}; use goose::model::ModelConfig; @@ -80,6 +79,15 @@ pub struct UpsertPermissionsQuery { pub tool_permissions: Vec, } +#[derive(Deserialize, ToSchema)] +pub struct CreateCustomProviderRequest { + pub provider_type: String, + pub display_name: String, + pub api_url: String, + pub api_key: String, + pub models: Vec, +} + #[utoipa::path( post, path = "/config/upsert", @@ -229,7 +237,7 @@ pub async fn add_extension( let extensions = ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let key = name_to_key(&extension_query.name); + let key = goose::config::extensions::name_to_key(&extension_query.name); let is_update = extensions.iter().any(|e| e.config.key() == key); @@ -264,7 +272,7 @@ pub async fn remove_extension( ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let key = name_to_key(&name); + let key = goose::config::extensions::name_to_key(&name); match ExtensionConfigManager::remove(&key) { Ok(_) => Ok(Json(format!("Removed extension {}", name))), Err(_) => Err(StatusCode::NOT_FOUND), @@ -306,7 +314,68 @@ pub async fn providers( ) -> Result>, StatusCode> { verify_secret_key(&headers, &state)?; - let providers_metadata = get_providers(); + let mut providers_metadata = get_providers(); + + let config_dir = etcetera::choose_app_strategy(goose::config::base::APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + let custom_providers_dir = config_dir.join("custom_providers"); + + if custom_providers_dir.exists() { + if let Ok(entries) = std::fs::read_dir(&custom_providers_dir) { + for entry in entries.flatten() { + if let Some(extension) = entry.path().extension() { + if extension == "json" { + if let Ok(content) = std::fs::read_to_string(entry.path()) { + if let Ok(custom_provider) = serde_json::from_str::< + goose::config::custom_providers::CustomProviderConfig, + >(&content) + { + // CustomProviderConfig => ProviderMetadata + let default_model = custom_provider + .models + .first() + .map(|m| m.name.clone()) + .unwrap_or_default(); + + let metadata = goose::providers::base::ProviderMetadata { + name: custom_provider.name.clone(), + display_name: custom_provider.display_name.clone(), + description: custom_provider + .description + .clone() + .unwrap_or_else(|| { + format!( + "Custom {} provider", + custom_provider.display_name + ) + }), + default_model, + known_models: custom_provider.models.clone(), + model_doc_link: "Custom provider".to_string(), + config_keys: vec![ + goose::providers::base::ConfigKey::new( + &custom_provider.api_key_env, + true, + true, + None, + ), + goose::providers::base::ConfigKey::new( + "CUSTOM_PROVIDER_BASE_URL", + true, + false, + Some(&custom_provider.base_url), + ), + ], + }; + providers_metadata.push(metadata); + } + } + } + } + } + } + } let providers_response: Vec = providers_metadata .into_iter() @@ -493,7 +562,7 @@ pub async fn upsert_permissions( ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let mut permission_manager = PermissionManager::default(); + let mut permission_manager = goose::config::PermissionManager::default(); for tool_permission in &query.tool_permissions { permission_manager.update_user_permission( @@ -639,6 +708,132 @@ pub async fn get_current_model( }))) } +#[utoipa::path( + post, + path = "/config/custom-providers", + request_body = CreateCustomProviderRequest, + responses( + (status = 200, description = "Custom provider created successfully", body = String), + (status = 400, description = "Invalid request"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn create_custom_provider( + State(state): State>, + headers: HeaderMap, + Json(request): Json, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + // use display name as name + let id = format!( + "custom_{}", + request.display_name.to_lowercase().replace(' ', "_") + ); + + // key naming convention + let api_key_name = format!( + "{}_API_KEY", + request + .display_name + .to_uppercase() + .replace(" ", "_") + .replace("-", "_") + ); + + // api-key -> keyring + let config = Config::global(); + config + .set_secret( + &api_key_name, + serde_json::Value::String(request.api_key.clone()), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // models -> ModelInfo fmt + let model_infos: Vec = request + .models + .iter() + .map(|name| goose::providers::base::ModelInfo::new(name.clone(), 128000)) + .collect(); + + let provider_config = goose::config::custom_providers::CustomProviderConfig { + name: id.clone(), + engine: match request.provider_type.as_str() { + "openai_compatible" => goose::config::custom_providers::ProviderEngine::OpenAI, + "anthropic_compatible" => goose::config::custom_providers::ProviderEngine::Anthropic, + "ollama_compatible" => goose::config::custom_providers::ProviderEngine::Ollama, + _ => return Err(StatusCode::BAD_REQUEST), + }, + display_name: request.display_name, + description: None, + api_key_env: api_key_name, + base_url: request.api_url, + models: model_infos, + headers: None, + timeout_seconds: None, + }; + + // create custom provider + let config_dir = etcetera::choose_app_strategy(goose::config::base::APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + let custom_providers_dir = config_dir.join("custom_providers"); + std::fs::create_dir_all(&custom_providers_dir) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let json_content = serde_json::to_string_pretty(&provider_config) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let file_path = custom_providers_dir.join(format!("{}.json", id)); + std::fs::write(file_path, json_content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Err(e) = goose::providers::refresh_custom_providers() { + tracing::warn!("Failed to refresh custom providers after creation: {}", e); + } + + Ok(Json(format!("Custom provider added - ID: {}", id))) +} + +#[utoipa::path( + delete, + path = "/config/custom-providers/{id}", + responses( + (status = 200, description = "Custom provider removed successfully", body = String), + (status = 404, description = "Provider not found"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn remove_custom_provider( + State(state): State>, + headers: HeaderMap, + axum::extract::Path(id): axum::extract::Path, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + let config = Config::global(); + let api_key_name = format!("{}_API_KEY", id.to_uppercase()); + let _ = config.delete_secret(&api_key_name); + + // remove provider + let config_dir = etcetera::choose_app_strategy(goose::config::base::APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + let custom_providers_dir = config_dir.join("custom_providers"); + let file_path = custom_providers_dir.join(format!("{}.json", id)); + + if file_path.exists() { + std::fs::remove_file(file_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Err(e) = goose::providers::refresh_custom_providers() { + tracing::warn!("Failed to refresh custom providers after deletion: {}", e); + } + } else { + return Err(StatusCode::NOT_FOUND); + } + + Ok(Json(format!("Removed custom provider: {}", id))) +} + pub fn routes(state: Arc) -> Router { Router::new() .route("/config", get(read_all_config)) @@ -656,6 +851,11 @@ pub fn routes(state: Arc) -> Router { .route("/config/validate", get(validate_config)) .route("/config/permissions", post(upsert_permissions)) .route("/config/current-model", get(get_current_model)) + .route("/config/custom-providers", post(create_custom_provider)) + .route( + "/config/custom-providers/{id}", + delete(remove_custom_provider), + ) .with_state(state) } diff --git a/crates/goose/src/config/custom_providers.rs b/crates/goose/src/config/custom_providers.rs new file mode 100644 index 000000000000..c6bbaeba6b1c --- /dev/null +++ b/crates/goose/src/config/custom_providers.rs @@ -0,0 +1,189 @@ +use crate::providers::base::ModelInfo; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; + +/// custom providers +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ProviderEngine { + OpenAI, + Ollama, + Anthropic, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomProviderConfig { + pub name: String, + pub engine: ProviderEngine, + pub display_name: String, + pub description: Option, + pub api_key_env: String, + pub base_url: String, + pub models: Vec, + pub headers: Option>, + pub timeout_seconds: Option, +} + +impl CustomProviderConfig { + pub fn id(&self) -> &str { + &self.name + } + + pub fn display_name(&self) -> &str { + &self.display_name + } + + pub fn models(&self) -> &[ModelInfo] { + &self.models + } +} + +/// load custom providers +pub fn load_custom_providers(dir: &Path) -> Result> { + let mut configs = Vec::new(); + + if !dir.exists() { + return Ok(configs); + } + + for entry in std::fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + + if path.extension().and_then(|s| s.to_str()) == Some("json") { + let content = std::fs::read_to_string(&path)?; + let config: CustomProviderConfig = serde_json::from_str(&content) + .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e))?; + + configs.push(config); + } else { + } + } + + Ok(configs) +} + +/// register custom providers +pub fn register_custom_providers( + registry: &mut crate::providers::provider_registry::ProviderRegistry, + dir: &Path, +) -> Result<()> { + use crate::model::ModelConfig; + use crate::providers::{ + anthropic::AnthropicProvider, ollama::OllamaProvider, openai::OpenAiProvider, + }; + + let configs = load_custom_providers(dir)?; + + for config in configs { + let config_clone = config.clone(); + + match config.engine { + ProviderEngine::OpenAI => { + let description = config + .description + .clone() + .unwrap_or_else(|| format!("Custom {} provider", config.display_name)); + let default_model = config + .models + .first() + .map(|m| m.name.clone()) + .unwrap_or_default(); + let known_models: Vec = config + .models + .iter() + .map(|m| crate::providers::base::ModelInfo { + name: m.name.clone(), + context_limit: m.context_limit as usize, + input_token_cost: m.input_token_cost, + output_token_cost: m.output_token_cost, + currency: m.currency.clone(), + supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)), + }) + .collect(); + + registry.register_with_name::( + config.name.clone(), + config.display_name.clone(), + description, + default_model, + known_models, + move |model: ModelConfig| { + OpenAiProvider::from_custom_config(model, config_clone.clone()) + }, + ); + } + ProviderEngine::Ollama => { + let description = config + .description + .clone() + .unwrap_or_else(|| format!("Custom {} provider", config.display_name)); + let default_model = config + .models + .first() + .map(|m| m.name.clone()) + .unwrap_or_default(); + let known_models: Vec = config + .models + .iter() + .map(|m| crate::providers::base::ModelInfo { + name: m.name.clone(), + context_limit: m.context_limit as usize, + input_token_cost: m.input_token_cost, + output_token_cost: m.output_token_cost, + currency: m.currency.clone(), + supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)), + }) + .collect(); + + registry.register_with_name::( + config.name.clone(), + config.display_name.clone(), + description, + default_model, + known_models, + move |model: ModelConfig| { + OllamaProvider::from_custom_config(model, config_clone.clone()) + }, + ); + } + ProviderEngine::Anthropic => { + let description = config + .description + .clone() + .unwrap_or_else(|| format!("Custom {} provider", config.display_name)); + let default_model = config + .models + .first() + .map(|m| m.name.clone()) + .unwrap_or_default(); + let known_models: Vec = config + .models + .iter() + .map(|m| crate::providers::base::ModelInfo { + name: m.name.clone(), + context_limit: m.context_limit as usize, + input_token_cost: m.input_token_cost, + output_token_cost: m.output_token_cost, + currency: m.currency.clone(), + supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)), + }) + .collect(); + + registry.register_with_name::( + config.name.clone(), + config.display_name.clone(), + description, + default_model, + known_models, + move |model: ModelConfig| { + AnthropicProvider::from_custom_config(model, config_clone.clone()) + }, + ); + } + } + } + Ok(()) +} diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index eaa40072ea5e..dda2a92d6682 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -1,4 +1,5 @@ pub mod base; +pub mod custom_providers; mod experiments; pub mod extensions; pub mod permission; @@ -6,6 +7,7 @@ pub mod signup_openrouter; pub use crate::agents::ExtensionConfig; pub use base::{Config, ConfigError, APP_STRATEGY}; +pub use custom_providers::CustomProviderConfig; pub use experiments::ExperimentManager; pub use extensions::{ExtensionConfigManager, ExtensionEntry}; pub use permission::PermissionManager; diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index aca6a4ef3896..0895091c8ee6 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -15,6 +15,7 @@ use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; use super::utils::{emit_debug_trace, get_model, map_http_error_to_provider_error}; +use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; use crate::impl_provider_default; use crate::model::ModelConfig; @@ -65,6 +66,23 @@ impl AnthropicProvider { Ok(Self { api_client, model }) } + pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + let global_config = crate::config::Config::global(); + let api_key: String = global_config + .get_secret(&config.api_key_env) + .map_err(|_| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?; + + let auth = AuthMethod::ApiKey { + header_name: "x-api-key".to_string(), + key: api_key, + }; + + let api_client = ApiClient::new(config.base_url, auth)? + .with_header("anthropic-version", ANTHROPIC_API_VERSION)?; + + Ok(Self { api_client, model }) + } + fn get_conditional_headers(&self) -> Vec<(&str, &str)> { let mut headers = Vec::new(); diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 1ebf97344938..ac93ab4d2193 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use super::{ anthropic::AnthropicProvider, @@ -16,13 +16,18 @@ use super::{ ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, + provider_registry::ProviderRegistry, sagemaker_tgi::SageMakerTgiProvider, snowflake::SnowflakeProvider, venice::VeniceProvider, xai::XaiProvider, }; +use crate::config::base::APP_STRATEGY; +use crate::config::custom_providers::register_custom_providers; use crate::model::ModelConfig; use anyhow::Result; +use etcetera::{choose_app_strategy, AppStrategy}; +use once_cell::sync::Lazy; #[cfg(test)] use super::errors::ProviderError; @@ -39,27 +44,54 @@ fn default_fallback_turns() -> usize { 2 } +static REGISTRY: Lazy> = Lazy::new(|| { + let registry = ProviderRegistry::new().with_providers(|registry| { + registry.register::(|model| OpenAiProvider::from_env(model)); + registry.register::(|model| AnthropicProvider::from_env(model)); + registry.register::(|model| OllamaProvider::from_env(model)); + registry.register::(|model| AzureProvider::from_env(model)); + registry.register::(|model| BedrockProvider::from_env(model)); + registry.register::(|model| ClaudeCodeProvider::from_env(model)); + registry.register::(|model| DatabricksProvider::from_env(model)); + registry.register::(|model| GcpVertexAIProvider::from_env(model)); + registry.register::(|model| GeminiCliProvider::from_env(model)); + registry.register::(|model| GoogleProvider::from_env(model)); + registry.register::(|model| GroqProvider::from_env(model)); + registry.register::(|model| LiteLLMProvider::from_env(model)); + registry.register::(|model| OpenRouterProvider::from_env(model)); + registry.register::(|model| SageMakerTgiProvider::from_env(model)); + registry.register::(|model| SnowflakeProvider::from_env(model)); + registry.register::(|model| VeniceProvider::from_env(model)); + registry.register::(|model| XaiProvider::from_env(model)); + + load_custom_providers_into_registry(registry); + }); + RwLock::new(registry) +}); + +fn load_custom_providers_into_registry(registry: &mut ProviderRegistry) { + let config_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + + if let Err(e) = register_custom_providers(registry, &config_dir.join("custom_providers")) { + tracing::warn!("Failed to load custom providers: {}", e); + } +} + pub fn providers() -> Vec { - vec![ - AnthropicProvider::metadata(), - AzureProvider::metadata(), - BedrockProvider::metadata(), - ClaudeCodeProvider::metadata(), - DatabricksProvider::metadata(), - GcpVertexAIProvider::metadata(), - GeminiCliProvider::metadata(), - // GithubCopilotProvider::metadata(), - GoogleProvider::metadata(), - GroqProvider::metadata(), - LiteLLMProvider::metadata(), - OllamaProvider::metadata(), - OpenAiProvider::metadata(), - OpenRouterProvider::metadata(), - SageMakerTgiProvider::metadata(), - VeniceProvider::metadata(), - SnowflakeProvider::metadata(), - XaiProvider::metadata(), - ] + REGISTRY.read().unwrap().all_metadata() +} + +pub fn refresh_custom_providers() -> Result<()> { + let mut registry = REGISTRY.write().unwrap(); + + registry.remove_custom_providers(); + + load_custom_providers_into_registry(&mut registry); + + tracing::info!("Custom providers refreshed"); + Ok(()) } pub fn create(name: &str, model: ModelConfig) -> Result> { @@ -71,7 +103,10 @@ pub fn create(name: &str, model: ModelConfig) -> Result> { return create_lead_worker_from_env(name, &model, &lead_model_name); } - create_provider(name, model) + + let result = REGISTRY.read().unwrap().create(name, model); + + result } /// Create a lead/worker provider from environment variables @@ -132,9 +167,15 @@ fn create_lead_worker_from_env( worker_config }; - // Create the providers - let lead_provider = create_provider(&lead_provider_name, lead_model_config)?; - let worker_provider = create_provider(default_provider_name, worker_model_config)?; + // create providers + let lead_provider = REGISTRY + .read() + .unwrap() + .create(&lead_provider_name, lead_model_config)?; + let worker_provider = REGISTRY + .read() + .unwrap() + .create(default_provider_name, worker_model_config)?; // Create the lead/worker provider with configured settings Ok(Arc::new(LeadWorkerProvider::new_with_settings( @@ -146,31 +187,6 @@ fn create_lead_worker_from_env( ))) } -fn create_provider(name: &str, model: ModelConfig) -> Result> { - // We use Arc instead of Box to be able to clone for multiple async tasks - match name { - "anthropic" => Ok(Arc::new(AnthropicProvider::from_env(model)?)), - "aws_bedrock" => Ok(Arc::new(BedrockProvider::from_env(model)?)), - "azure_openai" => Ok(Arc::new(AzureProvider::from_env(model)?)), - "claude-code" => Ok(Arc::new(ClaudeCodeProvider::from_env(model)?)), - "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)), - "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)), - "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)), - // "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)), - "google" => Ok(Arc::new(GoogleProvider::from_env(model)?)), - "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)), - "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)), - "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)), - "openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)), - "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)), - "sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)), - "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)), - "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)), - "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)), - _ => Err(anyhow::anyhow!("Unknown provider: {}", name)), - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 3e04fba896ee..1f9857ca916c 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -23,6 +23,7 @@ pub mod ollama; pub mod openai; pub mod openrouter; pub mod pricing; +pub mod provider_registry; mod retry; pub mod sagemaker_tgi; pub mod snowflake; @@ -33,4 +34,4 @@ pub mod utils_universal_openai_stream; pub mod venice; pub mod xai; -pub use factory::{create, providers}; +pub use factory::{create, providers, refresh_custom_providers}; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 7f17420fcf5d..a92bca38edc4 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -3,6 +3,7 @@ use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat}; +use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; use crate::conversation::Conversation; use crate::impl_provider_default; @@ -74,6 +75,38 @@ impl OllamaProvider { Ok(Self { api_client, model }) } + pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); + + // Parse and normalize the custom URL + let base = + if config.base_url.starts_with("http://") || config.base_url.starts_with("https://") { + config.base_url.clone() + } else { + format!("http://{}", config.base_url) + }; + + let mut base_url = Url::parse(&base) + .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?; + + // Set default port if missing and not using standard ports + let explicit_default_port = + config.base_url.ends_with(":80") || config.base_url.ends_with(":443"); + let is_https = base_url.scheme() == "https"; + + if base_url.port().is_none() && !explicit_default_port && !is_https { + base_url + .set_port(Some(OLLAMA_DEFAULT_PORT)) + .map_err(|_| anyhow::anyhow!("Failed to set default port"))?; + } + + // No authentication for Ollama + let auth = AuthMethod::Custom(Box::new(NoAuth)); + let api_client = ApiClient::with_timeout(base_url.to_string(), auth, timeout)?; + + Ok(Self { api_client, model }) + } + async fn post(&self, payload: &Value) -> Result { let response = self .api_client diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index fa4211f128bc..3e0f6db714f8 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -20,6 +20,7 @@ use super::utils::{ emit_debug_trace, get_model, handle_response_openai_compat, handle_status_openai_compat, ImageFormat, }; +use crate::config::custom_providers::CustomProviderConfig; use crate::conversation::message::Message; use crate::impl_provider_default; use crate::model::ModelConfig; @@ -106,6 +107,49 @@ impl OpenAiProvider { }) } + pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + let global_config = crate::config::Config::global(); + let api_key: String = global_config + .get_secret(&config.api_key_env) + .map_err(|_e| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?; + + let url = url::Url::parse(&config.base_url) + .map_err(|e| anyhow::anyhow!("Invalid base URL '{}': {}", config.base_url, e))?; + + let host = format!("{}://{}", url.scheme(), url.host_str().unwrap_or("")); + let base_path = url.path().trim_start_matches('/').to_string(); + let base_path = if base_path.is_empty() { + "v1/chat/completions".to_string() + } else { + base_path + }; + + let timeout_secs = config.timeout_seconds.unwrap_or(600); + let auth = AuthMethod::BearerToken(api_key); + let mut api_client = + ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?; + + // Add custom headers if present + if let Some(headers) = &config.headers { + let mut header_map = reqwest::header::HeaderMap::new(); + for (key, value) in headers { + let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?; + let header_value = reqwest::header::HeaderValue::from_str(value)?; + header_map.insert(header_name, header_value); + } + api_client = api_client.with_headers(header_map)?; + } + + Ok(Self { + api_client, + base_path, + organization: None, + project: None, + model, + custom_headers: config.headers, + }) + } + async fn post(&self, payload: &Value) -> Result { let response = self .api_client diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs new file mode 100644 index 000000000000..bae0737162a4 --- /dev/null +++ b/crates/goose/src/providers/provider_registry.rs @@ -0,0 +1,103 @@ +use super::base::{Provider, ProviderMetadata}; +use crate::model::ModelConfig; +use anyhow::Result; +use std::collections::HashMap; +use std::sync::Arc; + +type ProviderConstructor = Box Result> + Send + Sync>; + +struct ProviderEntry { + metadata: ProviderMetadata, + constructor: ProviderConstructor, +} + +pub struct ProviderRegistry { + entries: HashMap, +} + +impl ProviderRegistry { + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + pub fn register(&mut self, constructor: F) + where + P: Provider + 'static, + F: Fn(ModelConfig) -> Result

+ Send + Sync + 'static, + { + let metadata = P::metadata(); + let name = metadata.name.clone(); + + self.entries.insert( + name, + ProviderEntry { + metadata, + constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))), + }, + ); + } + + /// create provider with custom name + pub fn register_with_name( + &mut self, + custom_name: String, + display_name: String, + description: String, + default_model: String, + known_models: Vec, + constructor: F, + ) where + P: Provider + 'static, + F: Fn(ModelConfig) -> Result

+ Send + Sync + 'static, + { + let base_metadata = P::metadata(); + let custom_metadata = ProviderMetadata { + name: custom_name.clone(), + display_name, + description, + default_model, + known_models, + model_doc_link: base_metadata.model_doc_link, + config_keys: base_metadata.config_keys, + }; + + self.entries.insert( + custom_name, + ProviderEntry { + metadata: custom_metadata, + constructor: Box::new(move |model| Ok(Arc::new(constructor(model)?))), + }, + ); + } + + pub fn with_providers(mut self, setup: F) -> Self + where + F: FnOnce(&mut Self), + { + setup(&mut self); + self + } + + pub fn create(&self, name: &str, model: ModelConfig) -> Result> { + let _available_providers: Vec<_> = self.entries.keys().collect(); + + let entry = self + .entries + .get(name) + .ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", name))?; + + let result = (entry.constructor)(model); + + result + } + + pub fn all_metadata(&self) -> Vec { + self.entries.values().map(|e| e.metadata.clone()).collect() + } + + pub fn remove_custom_providers(&mut self) { + self.entries.retain(|name, _| !name.starts_with("custom_")); + } +} diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 52642f80c7bf..8660caf64242 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -133,6 +133,78 @@ } } }, + "/config/custom-providers": { + "post": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "create_custom_provider", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateCustomProviderRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Custom provider created successfully", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + }, + "400": { + "description": "Invalid request" + }, + "500": { + "description": "Internal server error" + } + } + } + }, + "/config/custom-providers/{id}": { + "delete": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "remove_custom_provider", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Custom provider removed successfully", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + }, + "404": { + "description": "Provider not found" + }, + "500": { + "description": "Internal server error" + } + } + } + }, "/config/extensions": { "get": { "tags": [ @@ -1297,6 +1369,36 @@ } } }, + "CreateCustomProviderRequest": { + "type": "object", + "required": [ + "provider_type", + "display_name", + "api_url", + "api_key", + "models" + ], + "properties": { + "api_key": { + "type": "string" + }, + "api_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "provider_type": { + "type": "string" + } + } + }, "CreateRecipeRequest": { "type": "object", "required": [ diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index d3b39f08533a..e519c1fa4633 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from '@hey-api/client-fetch'; -import type { AddSubRecipesData, AddSubRecipesResponse2, GetToolsData, GetToolsResponse, ReadAllConfigData, ReadAllConfigResponse, BackupConfigData, BackupConfigResponse, GetExtensionsData, GetExtensionsResponse, AddExtensionData, AddExtensionResponse, RemoveExtensionData, RemoveExtensionResponse, InitConfigData, InitConfigResponse, UpsertPermissionsData, UpsertPermissionsResponse, ProvidersData, ProvidersResponse2, ReadConfigData, RecoverConfigData, RecoverConfigResponse, RemoveConfigData, RemoveConfigResponse, UpsertConfigData, UpsertConfigResponse, ValidateConfigData, ValidateConfigResponse, ConfirmPermissionData, ManageContextData, ManageContextResponse, CreateRecipeData, CreateRecipeResponse2, DecodeRecipeData, DecodeRecipeResponse2, EncodeRecipeData, EncodeRecipeResponse2, CreateScheduleData, CreateScheduleResponse, DeleteScheduleData, DeleteScheduleResponse, ListSchedulesData, ListSchedulesResponse2, UpdateScheduleData, UpdateScheduleResponse, InspectRunningJobData, InspectRunningJobResponse, KillRunningJobData, PauseScheduleData, PauseScheduleResponse, RunNowHandlerData, RunNowHandlerResponse, SessionsHandlerData, SessionsHandlerResponse, UnpauseScheduleData, UnpauseScheduleResponse, ListSessionsData, ListSessionsResponse, GetSessionHistoryData, GetSessionHistoryResponse } from './types.gen'; +import type { AddSubRecipesData, AddSubRecipesResponse2, GetToolsData, GetToolsResponse, ReadAllConfigData, ReadAllConfigResponse, BackupConfigData, BackupConfigResponse, CreateCustomProviderData, CreateCustomProviderResponse, RemoveCustomProviderData, RemoveCustomProviderResponse, GetExtensionsData, GetExtensionsResponse, AddExtensionData, AddExtensionResponse, RemoveExtensionData, RemoveExtensionResponse, InitConfigData, InitConfigResponse, UpsertPermissionsData, UpsertPermissionsResponse, ProvidersData, ProvidersResponse2, ReadConfigData, RecoverConfigData, RecoverConfigResponse, RemoveConfigData, RemoveConfigResponse, UpsertConfigData, UpsertConfigResponse, ValidateConfigData, ValidateConfigResponse, ConfirmPermissionData, ManageContextData, ManageContextResponse, CreateRecipeData, CreateRecipeResponse2, DecodeRecipeData, DecodeRecipeResponse2, EncodeRecipeData, EncodeRecipeResponse2, CreateScheduleData, CreateScheduleResponse, DeleteScheduleData, DeleteScheduleResponse, ListSchedulesData, ListSchedulesResponse2, UpdateScheduleData, UpdateScheduleResponse, InspectRunningJobData, InspectRunningJobResponse, KillRunningJobData, PauseScheduleData, PauseScheduleResponse, RunNowHandlerData, RunNowHandlerResponse, SessionsHandlerData, SessionsHandlerResponse, UnpauseScheduleData, UnpauseScheduleResponse, ListSessionsData, ListSessionsResponse, GetSessionHistoryData, GetSessionHistoryResponse } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -50,6 +50,24 @@ export const backupConfig = (options?: Opt }); }; +export const createCustomProvider = (options: Options) => { + return (options.client ?? _heyApiClient).post({ + url: '/config/custom-providers', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options?.headers + } + }); +}; + +export const removeCustomProvider = (options: Options) => { + return (options.client ?? _heyApiClient).delete({ + url: '/config/custom-providers/{id}', + ...options + }); +}; + export const getExtensions = (options?: Options) => { return (options?.client ?? _heyApiClient).get({ url: '/config/extensions', diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index eeb5a691b218..30c69651f64d 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -96,6 +96,14 @@ export type ContextManageResponse = { tokenCounts: Array; }; +export type CreateCustomProviderRequest = { + api_key: string; + api_url: string; + display_name: string; + models: Array; + provider_type: string; +}; + export type CreateRecipeRequest = { activities?: Array | null; author?: AuthorRequest | null; @@ -884,6 +892,62 @@ export type BackupConfigResponses = { export type BackupConfigResponse = BackupConfigResponses[keyof BackupConfigResponses]; +export type CreateCustomProviderData = { + body: CreateCustomProviderRequest; + path?: never; + query?: never; + url: '/config/custom-providers'; +}; + +export type CreateCustomProviderErrors = { + /** + * Invalid request + */ + 400: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type CreateCustomProviderResponses = { + /** + * Custom provider created successfully + */ + 200: string; +}; + +export type CreateCustomProviderResponse = CreateCustomProviderResponses[keyof CreateCustomProviderResponses]; + +export type RemoveCustomProviderData = { + body?: never; + path: { + id: string; + }; + query?: never; + url: '/config/custom-providers/{id}'; +}; + +export type RemoveCustomProviderErrors = { + /** + * Provider not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type RemoveCustomProviderResponses = { + /** + * Custom provider removed successfully + */ + 200: string; +}; + +export type RemoveCustomProviderResponse = RemoveCustomProviderResponses[keyof RemoveCustomProviderResponses]; + export type GetExtensionsData = { body?: never; path?: never; diff --git a/ui/desktop/src/components/ConfigContext.tsx b/ui/desktop/src/components/ConfigContext.tsx index fd97bca068ce..f39175f819e5 100644 --- a/ui/desktop/src/components/ConfigContext.tsx +++ b/ui/desktop/src/components/ConfigContext.tsx @@ -170,9 +170,15 @@ export const ConfigProvider: React.FC = ({ children }) => { const getProviders = useCallback( async (forceRefresh = false): Promise => { if (forceRefresh || providersList.length === 0) { - const response = await providers(); - setProvidersList(response.data || []); - return response.data || []; + try { + const response = await providers(); + const providersData = response.data || []; + setProvidersList(providersData); + return providersData; + } catch (error) { + console.error('Failed to fetch providers:', error); + return []; + } } return providersList; }, @@ -189,9 +195,11 @@ export const ConfigProvider: React.FC = ({ children }) => { // Load providers try { const providersResponse = await providers(); - setProvidersList(providersResponse.data || []); + const providersData = providersResponse.data || []; + setProvidersList(providersData); } catch (error) { console.error('Failed to load providers:', error); + setProvidersList([]); } // Load extensions diff --git a/ui/desktop/src/components/MarkdownContent.test.tsx b/ui/desktop/src/components/MarkdownContent.test.tsx index d48713d1edad..163ecb2afd84 100644 --- a/ui/desktop/src/components/MarkdownContent.test.tsx +++ b/ui/desktop/src/components/MarkdownContent.test.tsx @@ -1,5 +1,6 @@ import { describe, it, expect, vi } from 'vitest'; -import { render, screen, waitFor } from '@testing-library/react'; +import { render } from '@testing-library/react'; +import { screen, waitFor } from '@testing-library/dom'; import MarkdownContent from './MarkdownContent'; // Mock the icons to avoid import issues diff --git a/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx b/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx index 0a2ff2622177..d2068bcfa955 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx @@ -159,11 +159,14 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { // Add the "Custom model" option to each provider group formattedModelOptions.forEach((group) => { - group.options.push({ - value: 'custom', - label: 'Use custom model', - provider: group.options[0]?.provider, - }); + const providerName = group.options[0]?.provider; + if (providerName && !providerName.startsWith('custom_')) { + group.options.push({ + value: 'custom', + label: 'Use custom model', + provider: providerName, + }); + } }); setModelOptions(formattedModelOptions); diff --git a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx index c30166106a7e..851fbe794a3c 100644 --- a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx +++ b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx @@ -1,8 +1,12 @@ -import React, { memo, useMemo, useCallback } from 'react'; +import React, { memo, useMemo, useCallback, useState } from 'react'; import { ProviderCard } from './subcomponents/ProviderCard'; +import CardContainer from './subcomponents/CardContainer'; import { ProviderModalProvider, useProviderModal } from './modal/ProviderModalProvider'; import ProviderConfigurationModal from './modal/ProviderConfiguationModal'; -import { ProviderDetails } from '../../../api'; +import { ProviderDetails, CreateCustomProviderRequest } from '../../../api'; +import { Plus } from 'lucide-react'; +import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../ui/dialog'; +import CustomProviderForm from './modal/subcomponents/forms/CustomProviderForm'; const GridLayout = memo(function GridLayout({ children }: { children: React.ReactNode }) { return ( @@ -18,6 +22,27 @@ const GridLayout = memo(function GridLayout({ children }: { children: React.Reac ); }); +const CustomProviderCard = memo(function CustomProviderCard({ onClick }: { onClick: () => void }) { + return ( + + +

+
Add
+
Custom Provider
+
+ + } + grayedOut={false} + borderStyle="dashed" + /> + ); +}); + // Memoize the ProviderCards component const ProviderCards = memo(function ProviderCards({ providers, @@ -31,6 +56,7 @@ const ProviderCards = memo(function ProviderCards({ onProviderLaunch: (provider: ProviderDetails) => void; }) { const { openModal } = useProviderModal(); + const [showCustomProviderModal, setShowCustomProviderModal] = useState(false); // Memoize these functions so they don't get recreated on every render const configureProviderViaModal = useCallback( @@ -42,7 +68,7 @@ const ProviderCards = memo(function ProviderCards({ refreshProviders(); } }, - onDelete: () => { + onDelete: (_values: unknown) => { if (refreshProviders) { refreshProviders(); } @@ -56,7 +82,7 @@ const ProviderCards = memo(function ProviderCards({ const deleteProviderConfigViaModal = useCallback( (provider: ProviderDetails) => { openModal(provider, { - onDelete: () => { + onDelete: (_values: unknown) => { // Only refresh if the function is provided if (refreshProviders) { refreshProviders(); @@ -68,12 +94,30 @@ const ProviderCards = memo(function ProviderCards({ [openModal, refreshProviders] ); + const handleCreateCustomProvider = useCallback( + async (data: CreateCustomProviderRequest) => { + try { + const { createCustomProvider } = await import('../../../api'); + await createCustomProvider({ body: data }); + setShowCustomProviderModal(false); + if (refreshProviders) { + refreshProviders(); + } + } catch (error) { + console.error('Failed to create custom provider:', error); + } + }, + [refreshProviders] + ); + // We don't need an intermediate function here // Just pass the onProviderLaunch directly // Use useMemo to memoize the cards array const providerCards = useMemo(() => { - return providers.map((provider) => ( + // providers needs to be an array + const providersArray = Array.isArray(providers) ? providers : []; + const cards = providersArray.map((provider) => ( )); + + if (!isOnboarding) { + cards.push( + setShowCustomProviderModal(true)} /> + ); + } + + return cards; }, [ providers, isOnboarding, @@ -91,7 +143,23 @@ const ProviderCards = memo(function ProviderCards({ onProviderLaunch, ]); - return <>{providerCards}; + return ( + <> + {providerCards} + + + + + Add Custom Provider + + setShowCustomProviderModal(false)} + /> + + + + ); }); export default memo(function ProviderGrid({ diff --git a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx index 67070265b4df..4cdaf016db9f 100644 --- a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx +++ b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx @@ -3,7 +3,7 @@ import { Dialog, DialogContent, DialogDescription, - DialogFooter, + // DialogFooter, DialogHeader, DialogTitle, } from '../../../ui/dialog'; @@ -18,7 +18,7 @@ import OllamaForm from './subcomponents/forms/OllamaForm'; import { useConfig } from '../../../ConfigContext'; import { useModelAndProvider } from '../../../ModelAndProviderContext'; import { AlertTriangle } from 'lucide-react'; -import { ConfigKey } from '../../../../api'; +import { ConfigKey, removeCustomProvider } from '../../../../api'; interface FormValues { [key: string]: string | number | boolean | null; @@ -162,13 +162,21 @@ export default function ProviderConfigurationModal() { } try { - // Remove the provider configuration - // get the keys - const params = currentProvider.metadata.config_keys; - - // go through the keys are remove them - for (const param of params) { - await remove(param.name, param.secret); + const isCustomProvider = currentProvider.name.startsWith('custom_'); + + if (isCustomProvider) { + await removeCustomProvider({ + path: { id: currentProvider.name }, + }); + } else { + // Remove the provider configuration + // get the keys + const params = currentProvider.metadata.config_keys; + + // go through the keys are remove them + for (const param of params) { + await remove(param.name, param.secret); + } } // Call onDelete callback if provided @@ -235,23 +243,21 @@ export default function ProviderConfigurationModal() { ) : null} - - { - setShowDeleteConfirmation(false); - setIsActiveProvider(false); - }} - canDelete={isConfigured && !isActiveProvider} // Disable delete button for active provider - providerName={currentProvider.metadata.display_name} - isActiveProvider={isActiveProvider} // Pass this to actions for button state - /> - + { + setShowDeleteConfirmation(false); + setIsActiveProvider(false); + }} + canDelete={isConfigured && !isActiveProvider} // Disable delete button for active provider + providerName={currentProvider.metadata.display_name} + isActiveProvider={isActiveProvider} // Pass this to actions for button state + /> ); diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx new file mode 100644 index 000000000000..842b2cd5ead0 --- /dev/null +++ b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx @@ -0,0 +1,182 @@ +import React, { useState } from 'react'; +import { Input } from '../../../../../ui/input'; +import { Select } from '../../../../../ui/Select'; +import { Button } from '../../../../../ui/button'; +import { SecureStorageNotice } from '../SecureStorageNotice'; +import { Checkbox } from '../../../../../ui/checkbox'; + +interface CustomProviderFormProps { + onSubmit: (data: { + provider_type: string; + display_name: string; + api_url: string; + api_key: string; + models: string[]; + }) => void; + onCancel: () => void; +} + +export default function CustomProviderForm({ onSubmit, onCancel }: CustomProviderFormProps) { + const [providerType, setProviderType] = useState('openai_compatible'); + const [displayName, setDisplayName] = useState(''); + const [apiUrl, setApiUrl] = useState(''); + const [apiKey, setApiKey] = useState(''); + const [models, setModels] = useState(''); + const [isLocalModel, setIsLocalModel] = useState(false); + const [validationErrors, setValidationErrors] = useState>({}); + + const handleLocalModels = (checked: boolean) => { + setIsLocalModel(checked); + if (checked) { + setApiKey('notrequired'); + } else { + setApiKey(''); + } + }; + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + + const errors: Record = {}; + if (!displayName) errors.displayName = 'Display name is required'; + if (!apiUrl) errors.apiUrl = 'API URL is required'; + if (!isLocalModel && !apiKey) errors.apiKey = 'API key is required'; + if (!models) errors.models = 'At least one model is required'; + + if (Object.keys(errors).length > 0) { + setValidationErrors(errors); + return; + } + + // parse custom models (separate with commas) + const modelList = models + .split(',') + .map((m) => m.trim()) + .filter((m) => m); + + onSubmit({ + provider_type: providerType, + display_name: displayName, + api_url: apiUrl, + api_key: apiKey, + models: modelList, + }); + }; + + return ( +
+
+ + setDisplayName(e.target.value)} + placeholder="Your Provider Name" + className={validationErrors.displayName ? 'border-red-500' : ''} + /> + {validationErrors.displayName && ( +

{validationErrors.displayName}

+ )} +
+ +
+ + setApiUrl(e.target.value)} + placeholder="https://api.example.com/v1/messages" + className={validationErrors.apiUrl ? 'border-red-500' : ''} + /> + {validationErrors.apiUrl && ( +

{validationErrors.apiUrl}

+ )} +
+ +
+ + setApiKey(e.target.value)} + placeholder="Your API key" + className={validationErrors.apiKey ? 'border-red-500' : ''} + disabled={isLocalModel} + /> + {validationErrors.apiKey && ( +

{validationErrors.apiKey}

+ )} + +
+ + +
+
+ +
+ + setModels(e.target.value)} + placeholder="model-a, model-b, model-c" + className={validationErrors.models ? 'border-red-500' : ''} + /> + {validationErrors.models && ( +

{validationErrors.models}

+ )} +
+ + + +
+ + +
+ + ); +} diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/DefaultProviderSetupForm.tsx b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/DefaultProviderSetupForm.tsx index 294f3052de1b..6f8e2dd69476 100644 --- a/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/DefaultProviderSetupForm.tsx +++ b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/DefaultProviderSetupForm.tsx @@ -3,9 +3,7 @@ import { Input } from '../../../../../ui/input'; import { useConfig } from '../../../../../ConfigContext'; // Adjust this import path as needed import { ProviderDetails, ConfigKey } from '../../../../../../api'; -interface ValidationErrors { - [key: string]: string; -} +type ValidationErrors = Record; interface DefaultProviderSetupFormProps { configValues: Record; @@ -36,33 +34,30 @@ export default function DefaultProviderSetupForm({ // Try to load actual values from config for each parameter that is not secret for (const parameter of parameters) { - if (parameter.required) { - try { - // Check if there's a stored value in the config system - const configKey = `${parameter.name}`; - const configResponse = await read(configKey, parameter.secret || false); - - if (configResponse) { - // Use the value from the config provider - newValues[parameter.name] = String(configResponse); - } else if ( - parameter.default !== undefined && - parameter.default !== null && - !configValues[parameter.name] - ) { - // Fall back to default value if no config value exists - newValues[parameter.name] = String(parameter.default); - } - } catch (error) { - console.error(`Failed to load config for ${parameter.name}:`, error); - // Fall back to default if read operation fails - if ( - parameter.default !== undefined && - parameter.default !== null && - !configValues[parameter.name] - ) { - newValues[parameter.name] = String(parameter.default); - } + try { + // Check if there's a stored value in the config system + const configKey = `${parameter.name}`; + const configResponse = await read(configKey, parameter.secret || false); + + if (configResponse) { + newValues[parameter.name] = parameter.secret ? 'true' : String(configResponse); + } else if ( + parameter.default !== undefined && + parameter.default !== null && + !configValues[parameter.name] + ) { + // Fall back to default value if no config value exists + newValues[parameter.name] = String(parameter.default); + } + } catch (error) { + console.error(`Failed to load config for ${parameter.name}:`, error); + // Fall back to default if read operation fails + if ( + parameter.default !== undefined && + parameter.default !== null && + !configValues[parameter.name] + ) { + newValues[parameter.name] = String(parameter.default); } } } @@ -85,6 +80,11 @@ export default function DefaultProviderSetupForm({ return parameters.filter((param) => param.required === true); }, [parameters]); + // TODO: show all params, not just required ones + // const allParameters = useMemo(() => { + // return parameters; + // }, [parameters]); + // Helper function to generate appropriate placeholder text const getPlaceholder = (parameter: ConfigKey): string => { // If default is defined and not null, show it @@ -92,8 +92,30 @@ export default function DefaultProviderSetupForm({ return `Default: ${parameter.default}`; } - // Otherwise, use the parameter name as a hint - return parameter.name.toUpperCase(); + const name = parameter.name.toLowerCase(); + if (name.includes('api_key')) return 'Your API key'; + if (name.includes('api_url') || name.includes('host')) return 'https://api.example.com'; + if (name.includes('models')) return 'model-a, model-b'; + + return parameter.name + .replace(/_/g, ' ') + .replace(/([A-Z])/g, ' $1') + .replace(/^./, (str) => str.toUpperCase()) + .trim(); + }; + + // helper for custom labels + const getFieldLabel = (parameter: ConfigKey): string => { + const name = parameter.name.toLowerCase(); + if (name.includes('api_key')) return 'API Key'; + if (name.includes('api_url') || name.includes('host')) return 'API Host'; + if (name.includes('models')) return 'Models'; + + return parameter.name + .replace(/_/g, ' ') + .replace(/([A-Z])/g, ' $1') + .replace(/^./, (str) => str.toUpperCase()) + .trim(); }; if (isLoading) { @@ -111,25 +133,30 @@ export default function DefaultProviderSetupForm({ requiredParameters.map((parameter) => (
+ onChange={(e: React.ChangeEvent) => { + console.log(`Setting ${parameter.name} to:`, e.target.value); setConfigValues((prev) => ({ ...prev, [parameter.name]: e.target.value, - })) - } + })); + }} placeholder={getPlaceholder(parameter)} className={`w-full h-14 px-4 font-regular rounded-lg shadow-none ${ validationErrors[parameter.name] ? 'border-2 border-red-500' : 'border border-borderSubtle hover:border-borderStandard' } bg-background-default text-lg placeholder:text-textSubtle font-regular text-textStandard`} - required={true} + required={parameter.required} /> + {validationErrors[parameter.name] && ( +

{validationErrors[parameter.name]}

+ )}
)) )} diff --git a/ui/desktop/src/components/settings/providers/subcomponents/CardContainer.tsx b/ui/desktop/src/components/settings/providers/subcomponents/CardContainer.tsx index 5a6e02e9e496..4d9225768b57 100644 --- a/ui/desktop/src/components/settings/providers/subcomponents/CardContainer.tsx +++ b/ui/desktop/src/components/settings/providers/subcomponents/CardContainer.tsx @@ -6,6 +6,7 @@ interface CardContainerProps { onClick: () => void; grayedOut: boolean; testId?: string; + borderStyle?: 'solid' | 'dashed'; } function GlowingRing() { @@ -33,6 +34,7 @@ export default function CardContainer({ onClick, grayedOut = false, testId, + borderStyle = 'solid', }: CardContainerProps) { return (
}
{/* Apply opacity only to the header when grayed out */} diff --git a/ui/desktop/src/components/ui/button.tsx b/ui/desktop/src/components/ui/button.tsx index 04cc74cb6298..ff2516a83d21 100644 --- a/ui/desktop/src/components/ui/button.tsx +++ b/ui/desktop/src/components/ui/button.tsx @@ -4,7 +4,7 @@ import { cva, type VariantProps } from 'class-variance-authority'; import { cn } from '../../utils'; const buttonVariants = cva( - "inline-flex items-center justify-center gap-2 whitespace-nowrap text-sm transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[1px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive", + "inline-flex items-center justify-center gap-2 whitespace-nowrap text-sm transition-all cursor-pointer disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[1px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive", { variants: { variant: {