diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index f1824d935708..cf196b871762 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -8,7 +8,7 @@ use goose::agents::platform_tools::{ }; use goose::agents::Agent; use goose::agents::{extension::Envs, ExtensionConfig}; -use goose::config::custom_providers::CustomProviderConfig; +use goose::config::declarative_providers::{create_custom_provider, remove_custom_provider}; use goose::config::extensions::{ get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, name_to_key, remove_extension, set_extension, set_extension_enabled, @@ -425,7 +425,7 @@ pub async fn configure_provider_dialog() -> Result> { // Create selection items from provider metadata let provider_items: Vec<(&String, &str, &str)> = available_providers .iter() - .map(|p| (&p.name, p.display_name.as_str(), p.description.as_str())) + .map(|(p, _)| (&p.name, p.display_name.as_str(), p.description.as_str())) .collect(); // Get current default provider if it exists @@ -439,9 +439,9 @@ pub async fn configure_provider_dialog() -> Result> { .interact()?; // Get the selected provider's metadata - let provider_meta = available_providers + let (provider_meta, _) = available_providers .iter() - .find(|p| &p.name == provider_name) + .find(|(p, _)| &p.name == provider_name) .expect("Selected provider must exist in metadata"); // Configure required provider keys @@ -1915,7 +1915,7 @@ fn add_provider() -> Result<(), Box> { .initial_value(true) .interact()?; - CustomProviderConfig::create_and_save( + create_custom_provider( provider_type, display_name.clone(), api_url, @@ -1929,9 +1929,9 @@ fn add_provider() -> Result<(), Box> { } fn remove_provider() -> Result<(), Box> { - let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); + let custom_providers_dir = goose::config::declarative_providers::custom_providers_dir(); let custom_providers = if custom_providers_dir.exists() { - goose::config::custom_providers::load_custom_providers(&custom_providers_dir)? + goose::config::declarative_providers::load_custom_providers(&custom_providers_dir)? } else { Vec::new() }; @@ -1950,7 +1950,7 @@ fn remove_provider() -> Result<(), Box> { .items(&provider_items) .interact()?; - CustomProviderConfig::remove(selected_id)?; + remove_custom_provider(selected_id)?; cliclack::outro(format!("Removed custom provider: {}", selected_id))?; Ok(()) } diff --git a/crates/goose-server/src/auth.rs b/crates/goose-server/src/auth.rs index ba13f59b5eee..a18c83972fdc 100644 --- a/crates/goose-server/src/auth.rs +++ b/crates/goose-server/src/auth.rs @@ -10,6 +10,9 @@ pub async fn check_token( request: Request, next: Next, ) -> Result { + if request.uri().path() == "/status" { + return Ok(next.run(request).await); + } let secret_key = request .headers() .get("X-Secret-Key") diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index db13b3d9f218..42df6681f107 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -5,7 +5,7 @@ use goose::config::permission::PermissionLevel; use goose::config::ExtensionEntry; use goose::conversation::Conversation; use goose::permission::permission_confirmation::PrincipalType; -use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata}; +use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata, ProviderType}; use goose::session::{Session, SessionInsights}; use rmcp::model::{ Annotations, Content, EmbeddedResource, Icon, ImageContent, JsonObject, RawAudioContent, @@ -14,6 +14,9 @@ use rmcp::model::{ }; use utoipa::{OpenApi, ToSchema}; +use goose::config::declarative_providers::{ + DeclarativeProviderConfig, LoadedProvider, ProviderEngine, +}; use goose::conversation::message::{ ContextLengthExceeded, FrontendToolRequest, Message, MessageContent, MessageMetadata, RedactedThinkingContent, SummarizationRequested, ThinkingContent, ToolConfirmationRequest, @@ -335,6 +338,8 @@ derive_utoipa!(Icon as IconSchema); super::routes::config_management::get_provider_models, super::routes::config_management::upsert_permissions, super::routes::config_management::create_custom_provider, + super::routes::config_management::get_custom_provider, + super::routes::config_management::update_custom_provider, super::routes::config_management::remove_custom_provider, super::routes::agent::start_agent, super::routes::agent::resume_agent, @@ -386,7 +391,7 @@ derive_utoipa!(Icon as IconSchema); super::routes::config_management::ExtensionQuery, super::routes::config_management::ToolPermission, super::routes::config_management::UpsertPermissionsQuery, - super::routes::config_management::CreateCustomProviderRequest, + super::routes::config_management::UpdateCustomProviderRequest, super::routes::reply::PermissionConfirmationRequest, super::routes::reply::ChatRequest, super::routes::context::ContextManageRequest, @@ -420,6 +425,10 @@ derive_utoipa!(Icon as IconSchema); JsonObjectSchema, RoleSchema, ProviderMetadata, + ProviderType, + LoadedProvider, + ProviderEngine, + DeclarativeProviderConfig, ExtensionEntry, ExtensionConfig, ConfigKey, diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index d479147f648c..c1b1de5d2777 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -1,15 +1,17 @@ use crate::routes::utils::check_provider_configured; use crate::state::AppState; +use axum::routing::put; use axum::{ extract::Path, routing::{delete, get, post}, Json, Router, }; +use goose::config::declarative_providers::LoadedProvider; use goose::config::paths::Paths; use goose::config::ExtensionEntry; use goose::config::{Config, ConfigError}; use goose::model::ModelConfig; -use goose::providers::base::ProviderMetadata; +use goose::providers::base::{ProviderMetadata, ProviderType}; use goose::providers::pricing::{ get_all_pricing, get_model_pricing, parse_model_id, refresh_pricing, }; @@ -57,6 +59,7 @@ pub struct ProviderDetails { pub name: String, pub metadata: ProviderMetadata, pub is_configured: bool, + pub provider_type: ProviderType, } #[derive(Serialize, ToSchema)] @@ -76,8 +79,8 @@ pub struct UpsertPermissionsQuery { } #[derive(Deserialize, ToSchema)] -pub struct CreateCustomProviderRequest { - pub provider_type: String, +pub struct UpdateCustomProviderRequest { + pub engine: String, pub display_name: String, pub api_url: String, pub api_key: String, @@ -225,9 +228,7 @@ pub async fn add_extension( (status = 500, description = "Internal server error") ) )] -pub async fn remove_extension( - axum::extract::Path(name): axum::extract::Path, -) -> Result, StatusCode> { +pub async fn remove_extension(Path(name): Path) -> Result, StatusCode> { let key = goose::config::extensions::name_to_key(&name); goose::config::remove_extension(&key); Ok(Json(format!("Removed extension {}", name))) @@ -258,72 +259,17 @@ pub async fn read_all_config() -> Result, StatusCode> { ) )] pub async fn providers() -> Result>, StatusCode> { - let mut providers_metadata = get_providers().await; - - let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); - - if custom_providers_dir.exists() { - if let Ok(entries) = std::fs::read_dir(&custom_providers_dir) { - for entry in entries.flatten() { - if let Some(extension) = entry.path().extension() { - if extension == "json" { - if let Ok(content) = std::fs::read_to_string(entry.path()) { - if let Ok(custom_provider) = serde_json::from_str::< - goose::config::custom_providers::CustomProviderConfig, - >(&content) - { - // CustomProviderConfig => ProviderMetadata - let default_model = custom_provider - .models - .first() - .map(|m| m.name.clone()) - .unwrap_or_default(); - - let metadata = goose::providers::base::ProviderMetadata { - name: custom_provider.name.clone(), - display_name: custom_provider.display_name.clone(), - description: custom_provider - .description - .clone() - .unwrap_or_else(|| { - format!("{} (custom)", custom_provider.display_name) - }), - default_model, - known_models: custom_provider.models.clone(), - model_doc_link: "Custom provider".to_string(), - config_keys: vec![ - goose::providers::base::ConfigKey::new( - &custom_provider.api_key_env, - true, - true, - None, - ), - goose::providers::base::ConfigKey::new( - "CUSTOM_PROVIDER_BASE_URL", - true, - false, - Some(&custom_provider.base_url), - ), - ], - }; - providers_metadata.push(metadata); - } - } - } - } - } - } - } - - let providers_response: Vec = providers_metadata + let providers = get_providers().await; + let providers_response: Vec = providers .into_iter() - .map(|metadata| { - let is_configured = check_provider_configured(&metadata); + .map(|(metadata, provider_type)| { + let is_configured = check_provider_configured(&metadata, provider_type); ProviderDetails { name: metadata.name.clone(), metadata, is_configured, + provider_type, } }) .collect(); @@ -347,11 +293,28 @@ pub async fn providers() -> Result>, StatusCode> { pub async fn get_provider_models( Path(name): Path, ) -> Result>, StatusCode> { - let all = get_providers().await; - let Some(metadata) = all.into_iter().find(|m| m.name == name) else { + let loaded_provider = goose::config::declarative_providers::load_provider(name.as_str()).ok(); + // TODO(Douwe): support a get models url for custom providers + if let Some(loaded_provider) = loaded_provider { + return Ok(Json( + loaded_provider + .config + .models + .into_iter() + .map(|m| m.name) + .collect::>(), + )); + } + + let all = get_providers() + .await + .into_iter() + //.map(|(m, p)| m) + .collect::>(); + let Some((metadata, provider_type)) = all.into_iter().find(|(m, _)| m.name == name) else { return Err(StatusCode::BAD_REQUEST); }; - if !check_provider_configured(&metadata) { + if !check_provider_configured(&metadata, provider_type) { return Err(StatusCode::BAD_REQUEST); } @@ -449,12 +412,9 @@ pub async fn get_pricing( } } } else { - // Get only configured providers' pricing - let providers_metadata = get_providers().await; - - for metadata in providers_metadata { + for (metadata, provider_type) in get_providers().await { // Skip unconfigured providers if filtering - if !check_provider_configured(&metadata) { + if !check_provider_configured(&metadata, provider_type) { continue; } @@ -647,25 +607,10 @@ pub async fn validate_config() -> Result, StatusCode> { } } -#[utoipa::path( - get, - path = "/config/current-model", - responses( - (status = 200, description = "Current model retrieved successfully", body = String), - ) -)] -pub async fn get_current_model() -> Result, StatusCode> { - let current_model = goose::providers::base::get_current_model(); - - Ok(Json(serde_json::json!({ - "model": current_model - }))) -} - #[utoipa::path( post, path = "/config/custom-providers", - request_body = CreateCustomProviderRequest, + request_body = UpdateCustomProviderRequest, responses( (status = 200, description = "Custom provider created successfully", body = String), (status = 400, description = "Invalid request"), @@ -673,10 +618,10 @@ pub async fn get_current_model() -> Result, StatusCode> { ) )] pub async fn create_custom_provider( - Json(request): Json, + Json(request): Json, ) -> Result, StatusCode> { - let config = goose::config::custom_providers::CustomProviderConfig::create_and_save( - &request.provider_type, + let config = goose::config::declarative_providers::create_custom_provider( + &request.engine, request.display_name, request.api_url, request.api_key, @@ -692,6 +637,24 @@ pub async fn create_custom_provider( Ok(Json(format!("Custom provider added - ID: {}", config.id()))) } +#[utoipa::path( + get, + path = "/config/custom-providers/{id}", + responses( + (status = 200, description = "Custom provider retrieved successfully", body = LoadedProvider), + (status = 404, description = "Provider not found"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn get_custom_provider( + Path(id): Path, +) -> Result, StatusCode> { + let loaded_provider = goose::config::declarative_providers::load_provider(id.as_str()) + .map_err(|_| StatusCode::NOT_FOUND)?; + + Ok(Json(loaded_provider)) +} + #[utoipa::path( delete, path = "/config/custom-providers/{id}", @@ -701,10 +664,8 @@ pub async fn create_custom_provider( (status = 500, description = "Internal server error") ) )] -pub async fn remove_custom_provider( - axum::extract::Path(id): axum::extract::Path, -) -> Result, StatusCode> { - goose::config::custom_providers::CustomProviderConfig::remove(&id) +pub async fn remove_custom_provider(Path(id): Path) -> Result, StatusCode> { + goose::config::declarative_providers::remove_custom_provider(&id) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if let Err(e) = goose::providers::refresh_custom_providers().await { @@ -714,6 +675,38 @@ pub async fn remove_custom_provider( Ok(Json(format!("Removed custom provider: {}", id))) } +#[utoipa::path( + put, + path = "/config/custom-providers/{id}", + request_body = UpdateCustomProviderRequest, + responses( + (status = 200, description = "Custom provider updated successfully", body = String), + (status = 404, description = "Provider not found"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn update_custom_provider( + Path(id): Path, + Json(request): Json, +) -> Result, StatusCode> { + goose::config::declarative_providers::update_custom_provider( + &id, + &request.engine, + request.display_name, + request.api_url, + request.api_key, + request.models, + request.supports_streaming, + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Err(e) = goose::providers::refresh_custom_providers().await { + tracing::warn!("Failed to refresh custom providers after update: {}", e); + } + + Ok(Json(format!("Updated custom provider: {}", id))) +} + pub fn routes(state: Arc) -> Router { Router::new() .route("/config", get(read_all_config)) @@ -731,12 +724,13 @@ pub fn routes(state: Arc) -> Router { .route("/config/recover", post(recover_config)) .route("/config/validate", get(validate_config)) .route("/config/permissions", post(upsert_permissions)) - .route("/config/current-model", get(get_current_model)) .route("/config/custom-providers", post(create_custom_provider)) .route( "/config/custom-providers/{id}", delete(remove_custom_provider), ) + .route("/config/custom-providers/{id}", put(update_custom_provider)) + .route("/config/custom-providers/{id}", get(get_custom_provider)) .with_state(state) } @@ -768,39 +762,4 @@ mod tests { assert!(gpt4_limit.is_some()); assert_eq!(gpt4_limit.unwrap().context_limit, 128_000); } - - #[tokio::test] - async fn test_get_provider_models_unknown_provider() { - let mut headers = HeaderMap::new(); - headers.insert("X-Secret-Key", "test".parse().unwrap()); - - let result = get_provider_models(Path("unknown_provider".to_string())).await; - - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn test_get_provider_models_openai_configured() { - std::env::set_var("OPENAI_API_KEY", "test-key"); - - let mut headers = HeaderMap::new(); - headers.insert("X-Secret-Key", "test".parse().unwrap()); - - let result = get_provider_models(Path("openai".to_string())).await; - - // The response should be BAD_REQUEST since the API key is invalid (authentication error) - assert!( - result.is_err(), - "Expected error response from OpenAI provider with invalid key" - ); - let status_code = result.unwrap_err(); - - assert!(status_code == StatusCode::BAD_REQUEST, - "Expected BAD_REQUEST (authentication error) or INTERNAL_SERVER_ERROR (other errors), got: {}", - status_code - ); - - std::env::remove_var("OPENAI_API_KEY"); - } } diff --git a/crates/goose-server/src/routes/utils.rs b/crates/goose-server/src/routes/utils.rs index c1c32fd00668..81bd41d6b4a3 100644 --- a/crates/goose-server/src/routes/utils.rs +++ b/crates/goose-server/src/routes/utils.rs @@ -1,5 +1,6 @@ +use goose::config::declarative_providers::load_provider; use goose::config::Config; -use goose::providers::base::{ConfigKey, ProviderMetadata}; +use goose::providers::base::{ConfigKey, ProviderMetadata, ProviderType}; use serde::{Deserialize, Serialize}; use std::env; use std::error::Error; @@ -27,7 +28,7 @@ pub fn inspect_key(key_name: &str, is_secret: bool) -> Result bool { +pub fn check_provider_configured(metadata: &ProviderMetadata, provider_type: ProviderType) -> bool { let config = Config::global(); + // TODO(Douwe): if the provider doesn't need an API key, it should be considered configured always + if provider_type == ProviderType::Custom || provider_type == ProviderType::Declarative { + if let Ok(loaded_provider) = load_provider(metadata.name.as_str()) { + return config + .get_secret::(&loaded_provider.config.api_key_env) + .map(|s| !s.is_empty()) + .unwrap_or(false); + } + } // Special case: Zero-config providers (no config keys) if metadata.config_keys.is_empty() { // Check if the provider has been explicitly configured via the UI diff --git a/crates/goose/src/config/custom_providers.rs b/crates/goose/src/config/custom_providers.rs deleted file mode 100644 index e486f39f4d24..000000000000 --- a/crates/goose/src/config/custom_providers.rs +++ /dev/null @@ -1,212 +0,0 @@ -use crate::config::paths::Paths; -use crate::config::Config; -use crate::model::ModelConfig; -use crate::providers::anthropic::AnthropicProvider; -use crate::providers::base::ModelInfo; -use crate::providers::ollama::OllamaProvider; -use crate::providers::openai::OpenAiProvider; -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::path::Path; - -pub fn custom_providers_dir() -> std::path::PathBuf { - Paths::config_dir().join("custom_providers") -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum ProviderEngine { - OpenAI, - Ollama, - Anthropic, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CustomProviderConfig { - pub name: String, - pub engine: ProviderEngine, - pub display_name: String, - pub description: Option, - pub api_key_env: String, - pub base_url: String, - pub models: Vec, - pub headers: Option>, - pub timeout_seconds: Option, - pub supports_streaming: Option, -} - -impl CustomProviderConfig { - pub fn id(&self) -> &str { - &self.name - } - - pub fn display_name(&self) -> &str { - &self.display_name - } - - pub fn models(&self) -> &[ModelInfo] { - &self.models - } - - pub fn generate_id(display_name: &str) -> String { - format!("custom_{}", display_name.to_lowercase().replace(' ', "_")) - } - - pub fn generate_api_key_name(id: &str) -> String { - format!("{}_API_KEY", id.to_uppercase()) - } - - pub fn create_and_save( - provider_type: &str, - display_name: String, - api_url: String, - api_key: String, - models: Vec, - supports_streaming: Option, - ) -> Result { - let id = Self::generate_id(&display_name); - let api_key_name = Self::generate_api_key_name(&id); - - let config = Config::global(); - config.set_secret(&api_key_name, serde_json::Value::String(api_key))?; - - let model_infos: Vec = models - .into_iter() - .map(|name| ModelInfo::new(name, 128000)) - .collect(); - - let provider_config = CustomProviderConfig { - name: id.clone(), - engine: match provider_type { - "openai_compatible" => ProviderEngine::OpenAI, - "anthropic_compatible" => ProviderEngine::Anthropic, - "ollama_compatible" => ProviderEngine::Ollama, - _ => return Err(anyhow::anyhow!("Invalid provider type: {}", provider_type)), - }, - display_name: display_name.clone(), - description: Some(format!("Custom {} provider", display_name)), - api_key_env: api_key_name, - base_url: api_url, - models: model_infos, - headers: None, - timeout_seconds: None, - supports_streaming, - }; - - // save to JSON file - let custom_providers_dir = custom_providers_dir(); - std::fs::create_dir_all(&custom_providers_dir)?; - - let json_content = serde_json::to_string_pretty(&provider_config)?; - let file_path = custom_providers_dir.join(format!("{}.json", id)); - std::fs::write(file_path, json_content)?; - - Ok(provider_config) - } - - pub fn remove(id: &str) -> Result<()> { - let config = Config::global(); - let api_key_name = Self::generate_api_key_name(id); - let _ = config.delete_secret(&api_key_name); - - let custom_providers_dir = custom_providers_dir(); - let file_path = custom_providers_dir.join(format!("{}.json", id)); - - if file_path.exists() { - std::fs::remove_file(file_path)?; - } - - Ok(()) - } -} - -pub fn load_custom_providers(dir: &Path) -> Result> { - if !dir.exists() { - return Ok(Vec::new()); - } - - std::fs::read_dir(dir)? - .filter_map(|entry| { - let path = entry.ok()?.path(); - (path.extension()? == "json").then_some(path) - }) - .map(|path| { - let content = std::fs::read_to_string(&path)?; - serde_json::from_str(&content) - .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e)) - }) - .collect() -} - -pub fn register_custom_providers( - registry: &mut crate::providers::provider_registry::ProviderRegistry, - dir: &Path, -) -> Result<()> { - let configs = load_custom_providers(dir)?; - - for config in configs { - let config_clone = config.clone(); - let description = config - .description - .clone() - .unwrap_or_else(|| format!("Custom {} provider", config.display_name)); - let default_model = config - .models - .first() - .map(|m| m.name.clone()) - .unwrap_or_default(); - let known_models: Vec = config - .models - .iter() - .map(|m| ModelInfo { - name: m.name.clone(), - context_limit: m.context_limit, - input_token_cost: m.input_token_cost, - output_token_cost: m.output_token_cost, - currency: m.currency.clone(), - supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)), - }) - .collect(); - - match config.engine { - ProviderEngine::OpenAI => { - registry.register_with_name::( - config.name.clone(), - config.display_name.clone(), - description, - default_model, - known_models, - move |model: ModelConfig| { - OpenAiProvider::from_custom_config(model, config_clone.clone()) - }, - ); - } - ProviderEngine::Ollama => { - registry.register_with_name::( - config.name.clone(), - config.display_name.clone(), - description, - default_model, - known_models, - move |model: ModelConfig| { - OllamaProvider::from_custom_config(model, config_clone.clone()) - }, - ); - } - ProviderEngine::Anthropic => { - registry.register_with_name::( - config.name.clone(), - config.display_name.clone(), - description, - default_model, - known_models, - move |model: ModelConfig| { - AnthropicProvider::from_custom_config(model, config_clone.clone()) - }, - ); - } - } - } - Ok(()) -} diff --git a/crates/goose/src/config/declarative_providers.rs b/crates/goose/src/config/declarative_providers.rs new file mode 100644 index 000000000000..10df3ff1021b --- /dev/null +++ b/crates/goose/src/config/declarative_providers.rs @@ -0,0 +1,317 @@ +use crate::config::paths::Paths; +use crate::config::Config; +use crate::providers::anthropic::AnthropicProvider; +use crate::providers::base::{ModelInfo, ProviderType}; +use crate::providers::ollama::OllamaProvider; +use crate::providers::openai::OpenAiProvider; +use anyhow::Result; +use include_dir::{include_dir, Dir}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Mutex; +use utoipa::ToSchema; + +static FIXED_PROVIDERS: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/providers/declarative"); + +pub fn custom_providers_dir() -> std::path::PathBuf { + Paths::config_dir().join("custom_providers") +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum ProviderEngine { + OpenAI, + Ollama, + Anthropic, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct DeclarativeProviderConfig { + pub name: String, + pub engine: ProviderEngine, + pub display_name: String, + pub description: Option, + pub api_key_env: String, + pub base_url: String, + pub models: Vec, + pub headers: Option>, + pub timeout_seconds: Option, + pub supports_streaming: Option, +} + +impl DeclarativeProviderConfig { + pub fn id(&self) -> &str { + &self.name + } + + pub fn display_name(&self) -> &str { + &self.display_name + } + + pub fn models(&self) -> &[ModelInfo] { + &self.models + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct LoadedProvider { + pub config: DeclarativeProviderConfig, + pub is_editable: bool, +} + +static ID_GENERATION_LOCK: Lazy> = Lazy::new(|| Mutex::new(())); + +pub fn generate_id(display_name: &str) -> String { + let _guard = ID_GENERATION_LOCK.lock().unwrap(); + + let normalized = display_name.to_lowercase().replace(' ', "_"); + let base_id = format!("custom_{}", normalized); + + let custom_dir = custom_providers_dir(); + let mut candidate_id = base_id.clone(); + let mut counter = 1; + + while custom_dir.join(format!("{}.json", candidate_id)).exists() { + candidate_id = format!("{}_{}", base_id, counter); + counter += 1; + } + + candidate_id +} + +pub fn generate_api_key_name(id: &str) -> String { + format!("{}_API_KEY", id.to_uppercase()) +} + +pub fn create_custom_provider( + engine: &str, + display_name: String, + api_url: String, + api_key: String, + models: Vec, + supports_streaming: Option, +) -> Result { + let id = generate_id(&display_name); + let api_key_name = generate_api_key_name(&id); + + let config = Config::global(); + config.set_secret(&api_key_name, serde_json::Value::String(api_key))?; + + let model_infos: Vec = models + .into_iter() + .map(|name| ModelInfo::new(name, 128000)) + .collect(); + + let provider_config = DeclarativeProviderConfig { + name: id.clone(), + engine: match engine { + "openai_compatible" => ProviderEngine::OpenAI, + "anthropic_compatible" => ProviderEngine::Anthropic, + "ollama_compatible" => ProviderEngine::Ollama, + _ => return Err(anyhow::anyhow!("Invalid provider type: {}", engine)), + }, + display_name: display_name.clone(), + description: Some(format!("Custom {} provider", display_name)), + api_key_env: api_key_name, + base_url: api_url, + models: model_infos, + headers: None, + timeout_seconds: None, + supports_streaming, + }; + + let custom_providers_dir = custom_providers_dir(); + std::fs::create_dir_all(&custom_providers_dir)?; + + let json_content = serde_json::to_string_pretty(&provider_config)?; + let file_path = custom_providers_dir.join(format!("{}.json", id)); + std::fs::write(file_path, json_content)?; + + Ok(provider_config) +} + +pub fn update_custom_provider( + id: &str, + provider_type: &str, + display_name: String, + api_url: String, + api_key: String, + models: Vec, + supports_streaming: Option, +) -> Result<()> { + let loaded_provider = load_provider(id)?; + let existing_config = loaded_provider.config; + let editable = loaded_provider.is_editable; + + let config = Config::global(); + if !api_key.is_empty() { + config.set_secret( + &existing_config.api_key_env, + serde_json::Value::String(api_key), + )?; + } + + if editable { + let model_infos: Vec = models + .into_iter() + .map(|name| ModelInfo::new(name, 128000)) + .collect(); + + let updated_config = DeclarativeProviderConfig { + name: id.to_string(), + engine: match provider_type { + "openai_compatible" => ProviderEngine::OpenAI, + "anthropic_compatible" => ProviderEngine::Anthropic, + "ollama_compatible" => ProviderEngine::Ollama, + _ => return Err(anyhow::anyhow!("Invalid provider type: {}", provider_type)), + }, + display_name, + description: existing_config.description, + api_key_env: existing_config.api_key_env, + base_url: api_url, + models: model_infos, + headers: existing_config.headers, + timeout_seconds: existing_config.timeout_seconds, + supports_streaming, + }; + + let file_path = custom_providers_dir().join(format!("{}.json", id)); + let json_content = serde_json::to_string_pretty(&updated_config)?; + std::fs::write(file_path, json_content)?; + } + Ok(()) +} + +pub fn remove_custom_provider(id: &str) -> Result<()> { + let config = Config::global(); + let api_key_name = generate_api_key_name(id); + let _ = config.delete_secret(&api_key_name); + + let custom_providers_dir = custom_providers_dir(); + let file_path = custom_providers_dir.join(format!("{}.json", id)); + + if file_path.exists() { + std::fs::remove_file(file_path)?; + } + + Ok(()) +} + +pub fn load_provider(id: &str) -> Result { + let custom_file_path = custom_providers_dir().join(format!("{}.json", id)); + + if custom_file_path.exists() { + let content = std::fs::read_to_string(&custom_file_path)?; + let config: DeclarativeProviderConfig = serde_json::from_str(&content)?; + return Ok(LoadedProvider { + config, + is_editable: true, + }); + } + + for file in FIXED_PROVIDERS.files() { + if file.path().extension().and_then(|s| s.to_str()) != Some("json") { + continue; + } + + let content = file + .contents_utf8() + .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?; + + let config: DeclarativeProviderConfig = serde_json::from_str(content)?; + if config.name == id { + return Ok(LoadedProvider { + config, + is_editable: false, + }); + } + } + + Err(anyhow::anyhow!("Provider not found: {}", id)) +} +pub fn load_custom_providers(dir: &Path) -> Result> { + if !dir.exists() { + return Ok(Vec::new()); + } + + std::fs::read_dir(dir)? + .filter_map(|entry| { + let path = entry.ok()?.path(); + (path.extension()? == "json").then_some(path) + }) + .map(|path| { + let content = std::fs::read_to_string(&path)?; + serde_json::from_str(&content) + .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", path.display(), e)) + }) + .collect() +} + +fn load_fixed_providers() -> Result> { + let mut res = Vec::new(); + for file in FIXED_PROVIDERS.files() { + if file.path().extension().and_then(|s| s.to_str()) != Some("json") { + continue; + } + + let content = file + .contents_utf8() + .ok_or_else(|| anyhow::anyhow!("Failed to read file as UTF-8: {:?}", file.path()))?; + + let config: DeclarativeProviderConfig = serde_json::from_str(content)?; + res.push(config) + } + + Ok(res) +} + +pub fn register_declarative_providers( + registry: &mut crate::providers::provider_registry::ProviderRegistry, +) -> Result<()> { + let dir = custom_providers_dir(); + let custom_providers = load_custom_providers(&dir)?; + let fixed_providers = load_fixed_providers()?; + for config in fixed_providers { + register_declarative_provider(registry, config, ProviderType::Declarative); + } + + for config in custom_providers { + register_declarative_provider(registry, config, ProviderType::Custom); + } + + Ok(()) +} + +pub fn register_declarative_provider( + registry: &mut crate::providers::provider_registry::ProviderRegistry, + config: DeclarativeProviderConfig, + provider_type: ProviderType, +) { + let config_clone = config.clone(); + + match config.engine { + ProviderEngine::OpenAI => { + registry.register_with_name::( + &config, + provider_type, + move |model| OpenAiProvider::from_custom_config(model, config_clone.clone()), + ); + } + ProviderEngine::Ollama => { + registry.register_with_name::( + &config, + provider_type, + move |model| OllamaProvider::from_custom_config(model, config_clone.clone()), + ); + } + ProviderEngine::Anthropic => { + registry.register_with_name::( + &config, + provider_type, + move |model| AnthropicProvider::from_custom_config(model, config_clone.clone()), + ); + } + } +} diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index bdb4d7678000..3ab6f3497ffa 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -1,5 +1,5 @@ pub mod base; -pub mod custom_providers; +pub mod declarative_providers; mod experiments; pub mod extensions; pub mod paths; @@ -9,7 +9,7 @@ pub mod signup_tetrate; pub use crate::agents::ExtensionConfig; pub use base::{Config, ConfigError}; -pub use custom_providers::CustomProviderConfig; +pub use declarative_providers::DeclarativeProviderConfig; pub use experiments::ExperimentManager; pub use extensions::{ get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 86b6953486d6..bddc4197040f 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -15,7 +15,7 @@ use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; use super::utils::{emit_debug_trace, get_model, map_http_error_to_provider_error}; -use crate::config::custom_providers::CustomProviderConfig; +use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::model::ModelConfig; use crate::providers::retry::ProviderRetry; @@ -69,7 +69,10 @@ impl AnthropicProvider { }) } - pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + pub fn from_custom_config( + model: ModelConfig, + config: DeclarativeProviderConfig, + ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config .get_secret(&config.api_key_env) diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 447923097d35..f8dfadee3757 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -81,6 +81,14 @@ impl ModelInfo { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +pub enum ProviderType { + Preferred, + Builtin, + Declarative, + Custom, +} + /// Metadata about a provider's configuration requirements and capabilities #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct ProviderMetadata { @@ -93,7 +101,6 @@ pub struct ProviderMetadata { /// The default/recommended model for this provider pub default_model: String, /// A list of currently known models with their capabilities - /// TODO: eventually query the apis directly pub known_models: Vec, /// Link to the docs where models can be found pub model_doc_link: String, @@ -132,7 +139,6 @@ impl ProviderMetadata { } } - /// Create a new ProviderMetadata with ModelInfo objects that include cost data pub fn with_models( name: &str, display_name: &str, diff --git a/crates/goose/src/providers/declarative/deepseek.json b/crates/goose/src/providers/declarative/deepseek.json new file mode 100644 index 000000000000..04347d1e6f1f --- /dev/null +++ b/crates/goose/src/providers/declarative/deepseek.json @@ -0,0 +1,29 @@ +{ + "name": "custom_deepseek", + "engine": "openai", + "display_name": "DeepSeek", + "description": "Custom DeepSeek provider", + "api_key_env": "DEEPSEEK_API_KEY", + "base_url": "https://api.deepseek.com", + "models": [ + { + "name": "deepseek-chat", + "context_limit": 128000, + "input_token_cost": null, + "output_token_cost": null, + "currency": null, + "supports_cache_control": null + }, + { + "name": "deepseek-reasoner", + "context_limit": 128000, + "input_token_cost": null, + "output_token_cost": null, + "currency": null, + "supports_cache_control": null + } + ], + "headers": null, + "timeout_seconds": null, + "supports_streaming": true +} \ No newline at end of file diff --git a/crates/goose/src/providers/declarative/groq.json b/crates/goose/src/providers/declarative/groq.json new file mode 100644 index 000000000000..8578ed8af4a7 --- /dev/null +++ b/crates/goose/src/providers/declarative/groq.json @@ -0,0 +1,31 @@ +{ + "name": "groq", + "engine": "openai", + "display_name": "Groq (d)", + "description": "Fast inference with Groq hardware", + "api_key_env": "GROQ_API_KEY", + "base_url": "https://api.groq.com/openai/v1/chat/completions", + "models": [ + { + "name": "openai/gpt-oss-120b", + "context_limit": 131072 + }, + { + "name": "llama-3.1-8b-instant", + "context_limit": 131072 + }, + { + "name": "llama-3.3-70b-versatile", + "context_limit": 131072 + }, + { + "name": "meta-llama/llama-guard-4-12b", + "context_limit": 131072 + }, + { + "name": "openai/gpt-oss-20b", + "context_limit": 131072 + } + ], + "supports_streaming": true +} \ No newline at end of file diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index be6b6eb59733..bdc8f569ec60 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -12,7 +12,6 @@ use super::{ gemini_cli::GeminiCliProvider, githubcopilot::GithubCopilotProvider, google::GoogleProvider, - groq::GroqProvider, lead_worker::LeadWorkerProvider, litellm::LiteLLMProvider, ollama::OllamaProvider, @@ -25,8 +24,9 @@ use super::{ venice::VeniceProvider, xai::XaiProvider, }; -use crate::config::custom_providers::{custom_providers_dir, register_custom_providers}; +use crate::config::declarative_providers::register_declarative_providers; use crate::model::ModelConfig; +use crate::providers::base::ProviderType; use anyhow::Result; use tokio::sync::OnceCell; @@ -38,28 +38,43 @@ static REGISTRY: OnceCell> = OnceCell::const_new(); async fn init_registry() -> RwLock { let mut registry = ProviderRegistry::new().with_providers(|registry| { - registry.register::(|m| Box::pin(AnthropicProvider::from_env(m))); - registry.register::(|m| Box::pin(AzureProvider::from_env(m))); - registry.register::(|m| Box::pin(BedrockProvider::from_env(m))); - registry.register::(|m| Box::pin(ClaudeCodeProvider::from_env(m))); - registry.register::(|m| Box::pin(CursorAgentProvider::from_env(m))); - registry.register::(|m| Box::pin(DatabricksProvider::from_env(m))); - registry.register::(|m| Box::pin(GcpVertexAIProvider::from_env(m))); - registry.register::(|m| Box::pin(GeminiCliProvider::from_env(m))); registry - .register::(|m| Box::pin(GithubCopilotProvider::from_env(m))); - registry.register::(|m| Box::pin(GoogleProvider::from_env(m))); - registry.register::(|m| Box::pin(GroqProvider::from_env(m))); - registry.register::(|m| Box::pin(LiteLLMProvider::from_env(m))); - registry.register::(|m| Box::pin(OllamaProvider::from_env(m))); - registry.register::(|m| Box::pin(OpenAiProvider::from_env(m))); - registry.register::(|m| Box::pin(OpenRouterProvider::from_env(m))); + .register::(|m| Box::pin(AnthropicProvider::from_env(m)), true); + registry.register::(|m| Box::pin(AzureProvider::from_env(m)), false); + registry.register::(|m| Box::pin(BedrockProvider::from_env(m)), false); registry - .register::(|m| Box::pin(SageMakerTgiProvider::from_env(m))); - registry.register::(|m| Box::pin(SnowflakeProvider::from_env(m))); - registry.register::(|m| Box::pin(TetrateProvider::from_env(m))); - registry.register::(|m| Box::pin(VeniceProvider::from_env(m))); - registry.register::(|m| Box::pin(XaiProvider::from_env(m))); + .register::(|m| Box::pin(ClaudeCodeProvider::from_env(m)), true); + registry.register::( + |m| Box::pin(CursorAgentProvider::from_env(m)), + false, + ); + registry + .register::(|m| Box::pin(DatabricksProvider::from_env(m)), true); + registry.register::( + |m| Box::pin(GcpVertexAIProvider::from_env(m)), + false, + ); + registry + .register::(|m| Box::pin(GeminiCliProvider::from_env(m)), false); + registry.register::( + |m| Box::pin(GithubCopilotProvider::from_env(m)), + false, + ); + registry.register::(|m| Box::pin(GoogleProvider::from_env(m)), true); + registry.register::(|m| Box::pin(LiteLLMProvider::from_env(m)), false); + registry.register::(|m| Box::pin(OllamaProvider::from_env(m)), true); + registry.register::(|m| Box::pin(OpenAiProvider::from_env(m)), true); + registry + .register::(|m| Box::pin(OpenRouterProvider::from_env(m)), true); + registry.register::( + |m| Box::pin(SageMakerTgiProvider::from_env(m)), + false, + ); + registry + .register::(|m| Box::pin(SnowflakeProvider::from_env(m)), false); + registry.register::(|m| Box::pin(TetrateProvider::from_env(m)), true); + registry.register::(|m| Box::pin(VeniceProvider::from_env(m)), false); + registry.register::(|m| Box::pin(XaiProvider::from_env(m)), false); }); if let Err(e) = load_custom_providers_into_registry(&mut registry) { tracing::warn!("Failed to load custom providers: {}", e); @@ -68,16 +83,19 @@ async fn init_registry() -> RwLock { } fn load_custom_providers_into_registry(registry: &mut ProviderRegistry) -> Result<()> { - let config_dir = custom_providers_dir(); - register_custom_providers(registry, &config_dir) + register_declarative_providers(registry) } async fn get_registry() -> &'static RwLock { REGISTRY.get_or_init(init_registry).await } -pub async fn providers() -> Vec { - get_registry().await.read().unwrap().all_metadata() +pub async fn providers() -> Vec<(ProviderMetadata, ProviderType)> { + get_registry() + .await + .read() + .unwrap() + .all_metadata_with_types() } pub async fn refresh_custom_providers() -> Result<()> { diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs deleted file mode 100644 index dec37a278887..000000000000 --- a/crates/goose/src/providers/groq.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::api_client::{ApiClient, AuthMethod}; -use super::errors::ProviderError; -use super::retry::ProviderRetry; -use super::utils::{get_model, handle_response_openai_compat}; -use crate::conversation::message::Message; -use crate::model::ModelConfig; -use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; -use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; -use anyhow::Result; -use async_trait::async_trait; -use rmcp::model::Tool; -use serde_json::Value; - -pub const GROQ_API_HOST: &str = "https://api.groq.com"; -pub const GROQ_DEFAULT_MODEL: &str = "moonshotai/kimi-k2-instruct"; -pub const GROQ_KNOWN_MODELS: &[&str] = &[ - "gemma2-9b-it", - "llama-3.3-70b-versatile", - "moonshotai/kimi-k2-instruct", - "qwen/qwen3-32b", -]; - -pub const GROQ_DOC_URL: &str = "https://console.groq.com/docs/models"; - -#[derive(serde::Serialize)] -pub struct GroqProvider { - #[serde(skip)] - api_client: ApiClient, - model: ModelConfig, -} - -impl GroqProvider { - pub async fn from_env(model: ModelConfig) -> Result { - let config = crate::config::Config::global(); - let api_key: String = config.get_secret("GROQ_API_KEY")?; - let host: String = config - .get_param("GROQ_HOST") - .unwrap_or_else(|_| GROQ_API_HOST.to_string()); - - let auth = AuthMethod::BearerToken(api_key); - let api_client = ApiClient::new(host, auth)?; - - Ok(Self { api_client, model }) - } - - async fn post(&self, payload: Value) -> Result { - let response = self - .api_client - .response_post("openai/v1/chat/completions", &payload) - .await?; - handle_response_openai_compat(response).await - } -} - -#[async_trait] -impl Provider for GroqProvider { - fn metadata() -> ProviderMetadata { - ProviderMetadata::new( - "groq", - "Groq", - "Fast inference with Groq hardware", - GROQ_DEFAULT_MODEL, - GROQ_KNOWN_MODELS.to_vec(), - GROQ_DOC_URL, - vec![ - ConfigKey::new("GROQ_API_KEY", true, true, None), - ConfigKey::new("GROQ_HOST", false, false, Some(GROQ_API_HOST)), - ], - ) - } - - fn get_model_config(&self) -> ModelConfig { - self.model.clone() - } - - #[tracing::instrument( - skip(self, model_config, system, messages, tools), - fields(model_config, input, output, input_tokens, output_tokens, total_tokens) - )] - async fn complete_with_model( - &self, - model_config: &ModelConfig, - system: &str, - messages: &[Message], - tools: &[Tool], - ) -> Result<(Message, ProviderUsage), ProviderError> { - let payload = create_request( - model_config, - system, - messages, - tools, - &super::utils::ImageFormat::OpenAi, - )?; - - let response = self.with_retry(|| self.post(payload.clone())).await?; - - let message = response_to_message(&response)?; - let usage = response.get("usage").map(get_usage).unwrap_or_else(|| { - tracing::debug!("Failed to get usage data"); - Usage::default() - }); - let response_model = get_model(&response); - super::utils::emit_debug_trace(model_config, &payload, &response, &usage); - Ok((message, ProviderUsage::new(response_model, usage))) - } - - /// Fetch supported models from Groq; returns Err on failure, Ok(None) if no models found - async fn fetch_supported_models(&self) -> Result>, ProviderError> { - let response = self - .api_client - .request("openai/v1/models") - .header("Content-Type", "application/json")? - .response_get() - .await?; - let response = handle_response_openai_compat(response).await?; - - let data = response - .get("data") - .and_then(|v| v.as_array()) - .ok_or_else(|| { - ProviderError::UsageError("Missing or invalid `data` field in response".into()) - })?; - - let mut model_names: Vec = data - .iter() - .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(String::from)) - .collect(); - model_names.sort(); - Ok(Some(model_names)) - } -} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 52e8ba0185b9..d50502a42b92 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -16,7 +16,6 @@ pub mod gcpvertexai; pub mod gemini_cli; pub mod githubcopilot; pub mod google; -pub mod groq; pub mod lead_worker; pub mod litellm; pub mod oauth; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 15aeca146fc7..f0044f6328b8 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -3,7 +3,7 @@ use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, Provider use super::errors::ProviderError; use super::retry::ProviderRetry; use super::utils::{get_model, handle_response_openai_compat, handle_status_openai_compat}; -use crate::config::custom_providers::CustomProviderConfig; +use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::conversation::Conversation; @@ -89,7 +89,10 @@ impl OllamaProvider { }) } - pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + pub fn from_custom_config( + model: ModelConfig, + config: DeclarativeProviderConfig, + ) -> Result { let timeout = Duration::from_secs(config.timeout_seconds.unwrap_or(OLLAMA_TIMEOUT)); // Parse and normalize the custom URL diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 9c5794c1de24..31e4041b5b2f 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -20,7 +20,7 @@ use super::utils::{ emit_debug_trace, get_model, handle_response_openai_compat, handle_status_openai_compat, ImageFormat, }; -use crate::config::custom_providers::CustomProviderConfig; +use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -110,7 +110,10 @@ impl OpenAiProvider { }) } - pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result { + pub fn from_custom_config( + model: ModelConfig, + config: DeclarativeProviderConfig, + ) -> Result { let global_config = crate::config::Config::global(); let api_key: String = global_config .get_secret(&config.api_key_env) diff --git a/crates/goose/src/providers/provider_registry.rs b/crates/goose/src/providers/provider_registry.rs index 5a0ebe4332e4..9e6f568ce45c 100644 --- a/crates/goose/src/providers/provider_registry.rs +++ b/crates/goose/src/providers/provider_registry.rs @@ -1,4 +1,5 @@ -use super::base::{Provider, ProviderMetadata}; +use super::base::{ModelInfo, Provider, ProviderMetadata, ProviderType}; +use crate::config::DeclarativeProviderConfig; use crate::model::ModelConfig; use anyhow::Result; use futures::future::BoxFuture; @@ -11,6 +12,7 @@ type ProviderConstructor = pub struct ProviderEntry { metadata: ProviderMetadata, pub(crate) constructor: ProviderConstructor, + provider_type: ProviderType, } #[derive(Default)] @@ -25,7 +27,7 @@ impl ProviderRegistry { } } - pub fn register(&mut self, constructor: F) + pub fn register(&mut self, constructor: F, preferred: bool) where P: Provider + 'static, F: Fn(ModelConfig) -> BoxFuture<'static, Result

> + Send + Sync + 'static, @@ -44,26 +46,50 @@ impl ProviderRegistry { Ok(Arc::new(provider) as Arc) }) }), + provider_type: if preferred { + ProviderType::Preferred + } else { + ProviderType::Builtin + }, }, ); } pub fn register_with_name( &mut self, - custom_name: String, - display_name: String, - description: String, - default_model: String, - known_models: Vec, + config: &DeclarativeProviderConfig, + provider_type: ProviderType, constructor: F, ) where P: Provider + 'static, F: Fn(ModelConfig) -> Result

+ Send + Sync + 'static, { let base_metadata = P::metadata(); + let description = config + .description + .clone() + .unwrap_or_else(|| format!("Custom {} provider", config.display_name)); + let default_model = config + .models + .first() + .map(|m| m.name.clone()) + .unwrap_or_default(); + let known_models: Vec = config + .models + .iter() + .map(|m| ModelInfo { + name: m.name.clone(), + context_limit: m.context_limit, + input_token_cost: m.input_token_cost, + output_token_cost: m.output_token_cost, + currency: m.currency.clone(), + supports_cache_control: Some(m.supports_cache_control.unwrap_or(false)), + }) + .collect(); + let custom_metadata = ProviderMetadata { - name: custom_name.clone(), - display_name, + name: config.name.clone(), + display_name: config.display_name.clone(), description, default_model, known_models, @@ -72,7 +98,7 @@ impl ProviderRegistry { }; self.entries.insert( - custom_name, + config.name.clone(), ProviderEntry { metadata: custom_metadata, constructor: Arc::new(move |model| { @@ -82,6 +108,7 @@ impl ProviderRegistry { Ok(Arc::new(provider) as Arc) }) }), + provider_type, }, ); } @@ -103,8 +130,11 @@ impl ProviderRegistry { (entry.constructor)(model).await } - pub fn all_metadata(&self) -> Vec { - self.entries.values().map(|e| e.metadata.clone()).collect() + pub fn all_metadata_with_types(&self) -> Vec<(ProviderMetadata, ProviderType)> { + self.entries + .values() + .map(|e| (e.metadata.clone(), e.provider_type)) + .collect() } pub fn remove_custom_providers(&mut self) { diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index b03dcc2354f3..9c6ee2aa2a9d 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -80,6 +80,14 @@ pub fn map_http_error_to_provider_error( ); ProviderError::Authentication(message) } + StatusCode::PAYLOAD_TOO_LARGE => { + let payload_str = if let Some(payload) = &payload { + payload.to_string() + } else { + "Payload is too large.".to_string() + }; + ProviderError::ContextLengthExceeded(payload_str) + } StatusCode::BAD_REQUEST => { let mut error_msg = "Unknown error".to_string(); if let Some(payload) = &payload { @@ -929,12 +937,12 @@ mod tests { "The model 'gpt-5' does not exist (code: model_not_found, type: invalid_request_error) (status 404)".to_string(), )), ), - // Non-JSON body error (tests parse failure path) + // Non-JSON body error (tests 413 PAYLOAD_TOO_LARGE -> ContextLengthExceeded) ( 413, Some(Value::String("Payload Too Large".to_string())), - Err(ProviderError::RequestFailed( - "Request failed with status: 413 Payload Too Large".to_string(), + Err(ProviderError::ContextLengthExceeded( + "Payload is too large.".to_string(), )), ), ]; diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 157c588817e7..510fb30b9362 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -10,8 +10,8 @@ use goose::providers::base::Provider; use goose::providers::{ anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider, databricks::DatabricksProvider, gcpvertexai::GcpVertexAIProvider, google::GoogleProvider, - groq::GroqProvider, ollama::OllamaProvider, openai::OpenAiProvider, - openrouter::OpenRouterProvider, xai::XaiProvider, + ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, + xai::XaiProvider, }; #[derive(Debug, PartialEq)] @@ -24,7 +24,6 @@ enum ProviderType { Databricks, GcpVertexAI, Google, - Groq, Ollama, OpenRouter, Xai, @@ -43,7 +42,6 @@ impl ProviderType { ProviderType::Bedrock => &["AWS_PROFILE"], ProviderType::Databricks => &["DATABRICKS_HOST"], ProviderType::Google => &["GOOGLE_API_KEY"], - ProviderType::Groq => &["GROQ_API_KEY"], ProviderType::Ollama => &[], ProviderType::OpenRouter => &["OPENROUTER_API_KEY"], ProviderType::GcpVertexAI => &["GCP_PROJECT_ID", "GCP_LOCATION"], @@ -80,7 +78,6 @@ impl ProviderType { Arc::new(GcpVertexAIProvider::from_env(model_config).await?) } ProviderType::Google => Arc::new(GoogleProvider::from_env(model_config).await?), - ProviderType::Groq => Arc::new(GroqProvider::from_env(model_config).await?), ProviderType::Ollama => Arc::new(OllamaProvider::from_env(model_config).await?), ProviderType::OpenRouter => Arc::new(OpenRouterProvider::from_env(model_config).await?), ProviderType::Xai => Arc::new(XaiProvider::from_env(model_config).await?), @@ -305,16 +302,6 @@ mod tests { .await } - #[tokio::test] - async fn test_agent_with_groq() -> Result<()> { - run_test_with_config(TestConfig { - provider_type: ProviderType::Groq, - model: "gemma2-9b-it", - context_window: 9_000, - }) - .await - } - #[tokio::test] async fn test_agent_with_openrouter() -> Result<()> { run_test_with_config(TestConfig { diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 85449c845194..7863943da5d0 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -9,7 +9,6 @@ use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; use goose::providers::errors::ProviderError; use goose::providers::google::GOOGLE_DEFAULT_MODEL; -use goose::providers::groq::GROQ_DEFAULT_MODEL; use goose::providers::litellm::LITELLM_DEFAULT_MODEL; use goose::providers::ollama::OLLAMA_DEFAULT_MODEL; use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; @@ -501,11 +500,6 @@ async fn test_ollama_provider() -> Result<()> { test_provider("Ollama", OLLAMA_DEFAULT_MODEL, &["OLLAMA_HOST"], None).await } -#[tokio::test] -async fn test_groq_provider() -> Result<()> { - test_provider("Groq", GROQ_DEFAULT_MODEL, &["GROQ_API_KEY"], None).await -} - #[tokio::test] async fn test_anthropic_provider() -> Result<()> { test_provider( diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index a53946407b46..cb12e582cfd4 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -382,7 +382,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateCustomProviderRequest" + "$ref": "#/components/schemas/UpdateCustomProviderRequest" } } }, @@ -409,6 +409,84 @@ } }, "/config/custom-providers/{id}": { + "get": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "get_custom_provider", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Custom provider retrieved successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/LoadedProvider" + } + } + } + }, + "404": { + "description": "Provider not found" + }, + "500": { + "description": "Internal server error" + } + } + }, + "put": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "update_custom_provider", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UpdateCustomProviderRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Custom provider updated successfully", + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + } + }, + "404": { + "description": "Provider not found" + }, + "500": { + "description": "Internal server error" + } + } + }, "delete": { "tags": [ "super::routes::config_management" @@ -2203,40 +2281,6 @@ "$ref": "#/components/schemas/Message" } }, - "CreateCustomProviderRequest": { - "type": "object", - "required": [ - "provider_type", - "display_name", - "api_url", - "api_key", - "models" - ], - "properties": { - "api_key": { - "type": "string" - }, - "api_url": { - "type": "string" - }, - "display_name": { - "type": "string" - }, - "models": { - "type": "array", - "items": { - "type": "string" - } - }, - "provider_type": { - "type": "string" - }, - "supports_streaming": { - "type": "boolean", - "nullable": true - } - } - }, "CreateRecipeRequest": { "type": "object", "required": [ @@ -2296,6 +2340,61 @@ } } }, + "DeclarativeProviderConfig": { + "type": "object", + "required": [ + "name", + "engine", + "display_name", + "api_key_env", + "base_url", + "models" + ], + "properties": { + "api_key_env": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "description": { + "type": "string", + "nullable": true + }, + "display_name": { + "type": "string" + }, + "engine": { + "$ref": "#/components/schemas/ProviderEngine" + }, + "headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "nullable": true + }, + "models": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ModelInfo" + } + }, + "name": { + "type": "string" + }, + "supports_streaming": { + "type": "boolean", + "nullable": true + }, + "timeout_seconds": { + "type": "integer", + "format": "int64", + "nullable": true, + "minimum": 0 + } + } + }, "DecodeRecipeRequest": { "type": "object", "required": [ @@ -2979,6 +3078,21 @@ } } }, + "LoadedProvider": { + "type": "object", + "required": [ + "config", + "is_editable" + ], + "properties": { + "config": { + "$ref": "#/components/schemas/DeclarativeProviderConfig" + }, + "is_editable": { + "type": "boolean" + } + } + }, "Message": { "type": "object", "description": "A message to or from an LLM", @@ -3347,7 +3461,8 @@ "required": [ "name", "metadata", - "is_configured" + "is_configured", + "provider_type" ], "properties": { "is_configured": { @@ -3358,9 +3473,20 @@ }, "name": { "type": "string" + }, + "provider_type": { + "$ref": "#/components/schemas/ProviderType" } } }, + "ProviderEngine": { + "type": "string", + "enum": [ + "openai", + "ollama", + "anthropic" + ] + }, "ProviderMetadata": { "type": "object", "description": "Metadata about a provider's configuration requirements and capabilities", @@ -3398,7 +3524,7 @@ "items": { "$ref": "#/components/schemas/ModelInfo" }, - "description": "A list of currently known models with their capabilities\nTODO: eventually query the apis directly" + "description": "A list of currently known models with their capabilities" }, "model_doc_link": { "type": "string", @@ -3410,6 +3536,15 @@ } } }, + "ProviderType": { + "type": "string", + "enum": [ + "Preferred", + "Builtin", + "Declarative", + "Custom" + ] + }, "ProvidersResponse": { "type": "object", "required": [ @@ -4449,6 +4584,40 @@ } } }, + "UpdateCustomProviderRequest": { + "type": "object", + "required": [ + "engine", + "display_name", + "api_url", + "api_key", + "models" + ], + "properties": { + "api_key": { + "type": "string" + }, + "api_url": { + "type": "string" + }, + "display_name": { + "type": "string" + }, + "engine": { + "type": "string" + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "supports_streaming": { + "type": "boolean", + "nullable": true + } + } + }, "UpdateProviderRequest": { "type": "object", "required": [ diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index 13fe5d7899da..c87f1dbc69b7 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from './client'; -import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ParseRecipeData, ParseRecipeResponses, ParseRecipeErrors, SaveRecipeData, SaveRecipeResponses, SaveRecipeErrors, ScanRecipeData, ScanRecipeResponses, ReplyData, ReplyResponses, ReplyErrors, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, ImportSessionData, ImportSessionResponses, ImportSessionErrors, GetSessionInsightsData, GetSessionInsightsResponses, GetSessionInsightsErrors, DeleteSessionData, DeleteSessionResponses, DeleteSessionErrors, GetSessionData, GetSessionResponses, GetSessionErrors, UpdateSessionDescriptionData, UpdateSessionDescriptionResponses, UpdateSessionDescriptionErrors, ExportSessionData, ExportSessionResponses, ExportSessionErrors, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesResponses, UpdateSessionUserRecipeValuesErrors, StatusData, StatusResponses } from './types.gen'; +import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetCustomProviderData, GetCustomProviderResponses, GetCustomProviderErrors, UpdateCustomProviderData, UpdateCustomProviderResponses, UpdateCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ParseRecipeData, ParseRecipeResponses, ParseRecipeErrors, SaveRecipeData, SaveRecipeResponses, SaveRecipeErrors, ScanRecipeData, ScanRecipeResponses, ReplyData, ReplyResponses, ReplyErrors, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, ImportSessionData, ImportSessionResponses, ImportSessionErrors, GetSessionInsightsData, GetSessionInsightsResponses, GetSessionInsightsErrors, DeleteSessionData, DeleteSessionResponses, DeleteSessionErrors, GetSessionData, GetSessionResponses, GetSessionErrors, UpdateSessionDescriptionData, UpdateSessionDescriptionResponses, UpdateSessionDescriptionErrors, ExportSessionData, ExportSessionResponses, ExportSessionErrors, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesResponses, UpdateSessionUserRecipeValuesErrors, StatusData, StatusResponses } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -134,6 +134,24 @@ export const removeCustomProvider = (optio }); }; +export const getCustomProvider = (options: Options) => { + return (options.client ?? _heyApiClient).get({ + url: '/config/custom-providers/{id}', + ...options + }); +}; + +export const updateCustomProvider = (options: Options) => { + return (options.client ?? _heyApiClient).put({ + url: '/config/custom-providers/{id}', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers + } + }); +}; + export const getExtensions = (options?: Options) => { return (options?.client ?? _heyApiClient).get({ url: '/config/extensions', diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index e4b3ccbeeda3..a20a4f34df29 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -110,15 +110,6 @@ export type ContextManageResponse = { export type Conversation = Array; -export type CreateCustomProviderRequest = { - api_key: string; - api_url: string; - display_name: string; - models: Array; - provider_type: string; - supports_streaming?: boolean | null; -}; - export type CreateRecipeRequest = { author?: AuthorRequest | null; session_id: string; @@ -136,6 +127,21 @@ export type CreateScheduleRequest = { recipe_source: string; }; +export type DeclarativeProviderConfig = { + api_key_env: string; + base_url: string; + description?: string | null; + display_name: string; + engine: ProviderEngine; + headers?: { + [key: string]: string; + } | null; + models: Array; + name: string; + supports_streaming?: boolean | null; + timeout_seconds?: number | null; +}; + export type DecodeRecipeRequest = { deeplink: string; }; @@ -367,6 +373,11 @@ export type ListSchedulesResponse = { jobs: Array; }; +export type LoadedProvider = { + config: DeclarativeProviderConfig; + is_editable: boolean; +}; + /** * A message to or from an LLM */ @@ -473,8 +484,11 @@ export type ProviderDetails = { is_configured: boolean; metadata: ProviderMetadata; name: string; + provider_type: ProviderType; }; +export type ProviderEngine = 'openai' | 'ollama' | 'anthropic'; + /** * Metadata about a provider's configuration requirements and capabilities */ @@ -497,7 +511,6 @@ export type ProviderMetadata = { display_name: string; /** * A list of currently known models with their capabilities - * TODO: eventually query the apis directly */ known_models: Array; /** @@ -510,6 +523,8 @@ export type ProviderMetadata = { name: string; }; +export type ProviderType = 'Preferred' | 'Builtin' | 'Declarative' | 'Custom'; + export type ProvidersResponse = { providers: Array; }; @@ -852,6 +867,15 @@ export type ToolResponse = { }; }; +export type UpdateCustomProviderRequest = { + api_key: string; + api_url: string; + display_name: string; + engine: string; + models: Array; + supports_streaming?: boolean | null; +}; + export type UpdateProviderRequest = { model?: string | null; provider: string; @@ -1183,7 +1207,7 @@ export type BackupConfigResponses = { export type BackupConfigResponse = BackupConfigResponses[keyof BackupConfigResponses]; export type CreateCustomProviderData = { - body: CreateCustomProviderRequest; + body: UpdateCustomProviderRequest; path?: never; query?: never; url: '/config/custom-providers'; @@ -1238,6 +1262,64 @@ export type RemoveCustomProviderResponses = { export type RemoveCustomProviderResponse = RemoveCustomProviderResponses[keyof RemoveCustomProviderResponses]; +export type GetCustomProviderData = { + body?: never; + path: { + id: string; + }; + query?: never; + url: '/config/custom-providers/{id}'; +}; + +export type GetCustomProviderErrors = { + /** + * Provider not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetCustomProviderResponses = { + /** + * Custom provider retrieved successfully + */ + 200: LoadedProvider; +}; + +export type GetCustomProviderResponse = GetCustomProviderResponses[keyof GetCustomProviderResponses]; + +export type UpdateCustomProviderData = { + body: UpdateCustomProviderRequest; + path: { + id: string; + }; + query?: never; + url: '/config/custom-providers/{id}'; +}; + +export type UpdateCustomProviderErrors = { + /** + * Provider not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type UpdateCustomProviderResponses = { + /** + * Custom provider updated successfully + */ + 200: string; +}; + +export type UpdateCustomProviderResponse = UpdateCustomProviderResponses[keyof UpdateCustomProviderResponses]; + export type GetExtensionsData = { body?: never; path?: never; diff --git a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx index dca71110880b..051ee3af375d 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx @@ -18,6 +18,7 @@ import { useModelAndProvider } from '../../../ModelAndProviderContext'; import type { View } from '../../../../utils/navigationUtils'; import Model, { getProviderMetadata } from '../modelInterface'; import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils'; +import { ProviderType } from '../../../../api'; type SwitchModelModalProps = { sessionId: string | null; @@ -165,8 +166,9 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod const results = await Promise.all(modelPromises); // Process results and build grouped options - const groupedOptions: { options: { value: string; label: string; provider: string }[] }[] = - []; + const groupedOptions: { + options: { value: string; label: string; provider: string; providerType: ProviderType }[]; + }[] = []; const errors: string[] = []; results.forEach(({ provider: p, models, error }) => { @@ -178,13 +180,19 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod options: p.metadata.known_models.map(({ name }) => ({ value: name, label: name, + providerType: p.provider_type, provider: p.name, })), }); } } else if (models && models.length > 0) { groupedOptions.push({ - options: models.map((m) => ({ value: m, label: m, provider: p.name })), + options: models.map((m) => ({ + value: m, + label: m, + provider: p.name, + providerType: p.provider_type, + })), }); } }); @@ -196,12 +204,14 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod // Add the "Custom model" option to each provider group groupedOptions.forEach((group) => { - const providerName = group.options[0]?.provider; - if (providerName && !providerName.startsWith('custom_')) { + const option = group.options[0]; + const providerName = option?.provider; + if (providerName && option?.providerType !== 'Custom') { group.options.push({ value: 'custom', label: 'Use custom model', provider: providerName, + providerType: option?.providerType, }); } }); diff --git a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx index 661342782ae9..cdce9e8cfe7d 100644 --- a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx +++ b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx @@ -3,7 +3,11 @@ import { ProviderCard } from './subcomponents/ProviderCard'; import CardContainer from './subcomponents/CardContainer'; import { ProviderModalProvider, useProviderModal } from './modal/ProviderModalProvider'; import ProviderConfigurationModal from './modal/ProviderConfiguationModal'; -import { ProviderDetails, CreateCustomProviderRequest } from '../../../api'; +import { + DeclarativeProviderConfig, + ProviderDetails, + UpdateCustomProviderRequest, +} from '../../../api'; import { Plus } from 'lucide-react'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../ui/dialog'; import CustomProviderForm from './modal/subcomponents/forms/CustomProviderForm'; @@ -43,7 +47,6 @@ const CustomProviderCard = memo(function CustomProviderCard({ onClick }: { onCli ); }); -// Memoize the ProviderCards component const ProviderCards = memo(function ProviderCards({ providers, isOnboarding, @@ -57,28 +60,69 @@ const ProviderCards = memo(function ProviderCards({ }) { const { openModal } = useProviderModal(); const [showCustomProviderModal, setShowCustomProviderModal] = useState(false); + const [editingProvider, setEditingProvider] = useState<{ + id: string; + config: DeclarativeProviderConfig; + isEditable: boolean; + } | null>(null); - // Memoize these functions so they don't get recreated on every render const configureProviderViaModal = useCallback( - (provider: ProviderDetails) => { - openModal(provider, { - onSubmit: () => { - // Only refresh if the function is provided - if (refreshProviders) { - refreshProviders(); - } - }, - onDelete: (_values: unknown) => { - if (refreshProviders) { - refreshProviders(); - } - }, - formProps: {}, - }); + async (provider: ProviderDetails) => { + if (provider.provider_type === 'Custom' || provider.provider_type === 'Declarative') { + const { getCustomProvider } = await import('../../../api'); + const result = await getCustomProvider({ path: { id: provider.name }, throwOnError: true }); + + if (result.data) { + setEditingProvider({ + id: provider.name, + config: result.data.config, + isEditable: result.data.is_editable, + }); + setShowCustomProviderModal(true); + } + } else { + openModal(provider, { + onSubmit: () => { + if (refreshProviders) { + refreshProviders(); + } + }, + onDelete: (_values: unknown) => { + if (refreshProviders) { + refreshProviders(); + } + }, + formProps: {}, + }); + } }, [openModal, refreshProviders] ); + const handleUpdateCustomProvider = useCallback( + async (data: UpdateCustomProviderRequest) => { + if (!editingProvider) return; + + const { updateCustomProvider } = await import('../../../api'); + await updateCustomProvider({ + path: { id: editingProvider.id }, + body: data, + throwOnError: true, + }); + setShowCustomProviderModal(false); + setEditingProvider(null); + if (refreshProviders) { + refreshProviders(); + } + }, + [editingProvider, refreshProviders] + ); + + const handleCloseModal = useCallback(() => { + setShowCustomProviderModal(false); + setEditingProvider(null); + }, []); + const deleteProviderConfigViaModal = useCallback( (provider: ProviderDetails) => { openModal(provider, { @@ -95,22 +139,17 @@ const ProviderCards = memo(function ProviderCards({ ); const handleCreateCustomProvider = useCallback( - async (data: CreateCustomProviderRequest) => { - try { - const { createCustomProvider } = await import('../../../api'); - await createCustomProvider({ body: data }); - setShowCustomProviderModal(false); - if (refreshProviders) { - refreshProviders(); - } - } catch (error) { - console.error('Failed to create custom provider:', error); + async (data: UpdateCustomProviderRequest) => { + const { createCustomProvider } = await import('../../../api'); + await createCustomProvider({ body: data, throwOnError: true }); + setShowCustomProviderModal(false); + if (refreshProviders) { + refreshProviders(); } }, [refreshProviders] ); - // Use useMemo to memoize the cards array const providerCards = useMemo(() => { // providers needs to be an array const providersArray = Array.isArray(providers) ? providers : []; @@ -138,21 +177,33 @@ const ProviderCards = memo(function ProviderCards({ onProviderLaunch, ]); + const initialData = editingProvider && { + engine: editingProvider.config.engine.toLowerCase() + '_compatible', + display_name: editingProvider.config.display_name, + api_url: editingProvider.config.base_url, + api_key: '', + models: editingProvider.config.models.map((m) => m.name), + supports_streaming: editingProvider.config.supports_streaming ?? true, + }; + + const editable = editingProvider ? editingProvider.isEditable : true; + const title = (editingProvider ? (editable ? 'Edit' : 'Configure') : 'Add') + ' Provider'; return ( <> {providerCards} - -

+ - Add Custom Provider + {title} setShowCustomProviderModal(false)} + initialData={initialData} + isEditable={editable} + onSubmit={editingProvider ? handleUpdateCustomProvider : handleCreateCustomProvider} + onCancel={handleCloseModal} /> - + {' '} ); }); diff --git a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx index 58b0df67dc86..b8971a69b6de 100644 --- a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx +++ b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx @@ -162,7 +162,7 @@ export default function ProviderConfigurationModal() { } try { - const isCustomProvider = currentProvider.name.startsWith('custom_'); + const isCustomProvider = currentProvider.provider_type === 'Custom'; if (isCustomProvider) { await removeCustomProvider({ diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx index 9ce694fa958b..be57a3ec9918 100644 --- a/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx +++ b/ui/desktop/src/components/settings/providers/modal/subcomponents/forms/CustomProviderForm.tsx @@ -1,24 +1,25 @@ -import React, { useState } from 'react'; +import React, { useState, useEffect } from 'react'; import { Input } from '../../../../../ui/input'; import { Select } from '../../../../../ui/Select'; import { Button } from '../../../../../ui/button'; import { SecureStorageNotice } from '../SecureStorageNotice'; import { Checkbox } from '@radix-ui/themes'; +import { UpdateCustomProviderRequest } from '../../../../../../api'; interface CustomProviderFormProps { - onSubmit: (data: { - provider_type: string; - display_name: string; - api_url: string; - api_key: string; - models: string[]; - supports_streaming: boolean; - }) => void; + onSubmit: (data: UpdateCustomProviderRequest) => void; onCancel: () => void; + initialData: UpdateCustomProviderRequest | null; + isEditable?: boolean; } -export default function CustomProviderForm({ onSubmit, onCancel }: CustomProviderFormProps) { - const [providerType, setProviderType] = useState('openai_compatible'); +export default function CustomProviderForm({ + onSubmit, + onCancel, + initialData, + isEditable, +}: CustomProviderFormProps) { + const [engine, setEngine] = useState('openai_compatible'); const [displayName, setDisplayName] = useState(''); const [apiUrl, setApiUrl] = useState(''); const [apiKey, setApiKey] = useState(''); @@ -27,6 +28,22 @@ export default function CustomProviderForm({ onSubmit, onCancel }: CustomProvide const [supportsStreaming, setSupportsStreaming] = useState(true); const [validationErrors, setValidationErrors] = useState>({}); + useEffect(() => { + if (initialData) { + const engineMap: Record = { + openai: 'openai_compatible', + anthropic: 'anthropic_compatible', + ollama: 'ollama_compatible', + }; + + setEngine(engineMap[initialData.engine.toLowerCase()] || 'openai_compatible'); + setDisplayName(initialData.display_name); + setApiUrl(initialData.api_url); + setModels(initialData.models.join(', ')); + setSupportsStreaming(initialData.supports_streaming ?? true); + } + }, [initialData]); + const handleLocalModels = (checked: boolean) => { setIsLocalModel(checked); if (checked) { @@ -42,7 +59,7 @@ export default function CustomProviderForm({ onSubmit, onCancel }: CustomProvide const errors: Record = {}; if (!displayName) errors.displayName = 'Display name is required'; if (!apiUrl) errors.apiUrl = 'API URL is required'; - if (!isLocalModel && !apiKey) errors.apiKey = 'API key is required'; + if (!isLocalModel && !apiKey && !initialData) errors.apiKey = 'API key is required'; if (!models) errors.models = 'At least one model is required'; if (Object.keys(errors).length > 0) { @@ -56,7 +73,7 @@ export default function CustomProviderForm({ onSubmit, onCancel }: CustomProvide .filter((m) => m); onSubmit({ - provider_type: providerType, + engine, display_name: displayName, api_url: apiUrl, api_key: apiKey, @@ -67,92 +84,94 @@ export default function CustomProviderForm({ onSubmit, onCancel }: CustomProvide return (
-
- - setDisplayName(e.target.value)} - placeholder="Your Provider Name" - aria-invalid={!!validationErrors.displayName} - aria-describedby={validationErrors.displayName ? 'display-name-error' : undefined} - className={validationErrors.displayName ? 'border-red-500' : ''} - /> - {validationErrors.displayName && ( -

- {validationErrors.displayName} -

- )} -
- -
- - setApiUrl(e.target.value)} - placeholder="https://api.example.com/v1/messages" - aria-invalid={!!validationErrors.apiUrl} - aria-describedby={validationErrors.apiUrl ? 'api-url-error' : undefined} - className={validationErrors.apiUrl ? 'border-red-500' : ''} - /> - {validationErrors.apiUrl && ( -

- {validationErrors.apiUrl} -

- )} -
+ {isEditable && ( + <> +
+ + setDisplayName(e.target.value)} + placeholder="Your Provider Name" + aria-invalid={!!validationErrors.displayName} + aria-describedby={validationErrors.displayName ? 'display-name-error' : undefined} + className={validationErrors.displayName ? 'border-red-500' : ''} + /> + {validationErrors.displayName && ( +

+ {validationErrors.displayName} +

+ )} +
+
+ + setApiUrl(e.target.value)} + placeholder="https://api.example.com/v1/messages" + aria-invalid={!!validationErrors.apiUrl} + aria-describedby={validationErrors.apiUrl ? 'api-url-error' : undefined} + className={validationErrors.apiUrl ? 'border-red-500' : ''} + /> + {validationErrors.apiUrl && ( +

+ {validationErrors.apiUrl} +

+ )} +
+ + )}
setApiKey(e.target.value)} - placeholder="Your API key" + placeholder={initialData ? 'Leave blank to keep existing key' : 'Your API key'} aria-invalid={!!validationErrors.apiKey} aria-describedby={validationErrors.apiKey ? 'api-key-error' : undefined} className={validationErrors.apiKey ? 'border-red-500' : ''} @@ -179,62 +198,64 @@ export default function CustomProviderForm({ onSubmit, onCancel }: CustomProvide

)} -
- - -
-
- -
- - setModels(e.target.value)} - placeholder="model-a, model-b, model-c" - aria-invalid={!!validationErrors.models} - aria-describedby={validationErrors.models ? 'available-models-error' : undefined} - className={validationErrors.models ? 'border-red-500' : ''} - /> - {validationErrors.models && ( -

- {validationErrors.models} -

+ {!initialData && ( +
+ + +
)}
- -
- setSupportsStreaming(checked as boolean)} - /> - -
- + {isEditable && ( + <> +
+ + setModels(e.target.value)} + placeholder="model-a, model-b, model-c" + aria-invalid={!!validationErrors.models} + aria-describedby={validationErrors.models ? 'available-models-error' : undefined} + className={validationErrors.models ? 'border-red-500' : ''} + /> + {validationErrors.models && ( +

+ {validationErrors.models} +

+ )} +
+
+ setSupportsStreaming(checked as boolean)} + /> + +
+ + )} -
- +
);