-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Fixes and improvements to get custom providers working in the UI #4557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ use crate::routes::utils::check_provider_configured; | |
| use crate::state::AppState; | ||
| use axum::{ | ||
| extract::{Path, State}, | ||
| routing::{delete, get, post}, | ||
| routing::{delete, get, post, put}, | ||
| Json, Router, | ||
| }; | ||
| use etcetera::{choose_app_strategy, AppStrategy}; | ||
|
|
@@ -87,6 +87,19 @@ pub struct CreateCustomProviderRequest { | |
| pub supports_streaming: Option<bool>, | ||
| } | ||
|
|
||
| #[derive(Deserialize, Serialize, ToSchema)] | ||
| pub struct UpdateCustomProviderRequest { | ||
| pub display_name: Option<String>, | ||
| pub api_url: Option<String>, | ||
| pub api_key: Option<String>, | ||
| pub models: Option<Vec<String>>, | ||
| pub supports_streaming: Option<bool>, | ||
| // Editable provider JSON fields | ||
| pub description: Option<String>, | ||
| pub headers: Option<std::collections::HashMap<String, String>>, | ||
| pub timeout_seconds: Option<u64>, | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| post, | ||
| path = "/config/upsert", | ||
|
|
@@ -104,6 +117,18 @@ pub async fn upsert_config( | |
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let config = Config::global(); | ||
|
|
||
| // Defensive guard: if this is a secret check and the client sent a boolean 'true' to | ||
| // indicate presence, do not write that boolean into the secret storage. Treat as a no-op. | ||
| if query.is_secret { | ||
| if let Value::Bool(b) = &query.value { | ||
| if *b { | ||
| tracing::info!(key = %query.key, "Skipping upsert for secret key with boolean true (presence-only)"); | ||
| return Ok(Json(Value::String(format!("Skipped secret upsert for {} (presence-only)", query.key)))); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| let result = config.set(&query.key, query.value, query.is_secret); | ||
|
|
||
| match result { | ||
|
|
@@ -316,64 +341,7 @@ pub async fn providers( | |
| ) -> Result<Json<Vec<ProviderDetails>>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let mut providers_metadata = get_providers(); | ||
|
|
||
| let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); | ||
|
|
||
| if custom_providers_dir.exists() { | ||
| if let Ok(entries) = std::fs::read_dir(&custom_providers_dir) { | ||
| for entry in entries.flatten() { | ||
| if let Some(extension) = entry.path().extension() { | ||
| if extension == "json" { | ||
| if let Ok(content) = std::fs::read_to_string(entry.path()) { | ||
| if let Ok(custom_provider) = serde_json::from_str::< | ||
| goose::config::custom_providers::CustomProviderConfig, | ||
| >(&content) | ||
| { | ||
| // CustomProviderConfig => ProviderMetadata | ||
| let default_model = custom_provider | ||
| .models | ||
| .first() | ||
| .map(|m| m.name.clone()) | ||
| .unwrap_or_default(); | ||
|
|
||
| let metadata = goose::providers::base::ProviderMetadata { | ||
| name: custom_provider.name.clone(), | ||
| display_name: custom_provider.display_name.clone(), | ||
| description: custom_provider | ||
| .description | ||
| .clone() | ||
| .unwrap_or_else(|| { | ||
| format!("{} (custom)", custom_provider.display_name) | ||
| }), | ||
| default_model, | ||
| known_models: custom_provider.models.clone(), | ||
| model_doc_link: "Custom provider".to_string(), | ||
| config_keys: vec![ | ||
| goose::providers::base::ConfigKey::new( | ||
| &custom_provider.api_key_env, | ||
| true, | ||
| true, | ||
| None, | ||
| ), | ||
| goose::providers::base::ConfigKey::new( | ||
| "CUSTOM_PROVIDER_BASE_URL", | ||
| true, | ||
| false, | ||
| Some(&custom_provider.base_url), | ||
| ), | ||
| ], | ||
| }; | ||
| providers_metadata.push(metadata); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| let providers_response: Vec<ProviderDetails> = providers_metadata | ||
| let providers_response: Vec<ProviderDetails> = get_providers() | ||
| .into_iter() | ||
| .map(|metadata| { | ||
| let is_configured = check_provider_configured(&metadata); | ||
|
|
@@ -766,6 +734,8 @@ pub async fn get_current_model( | |
|
|
||
| #[utoipa::path( | ||
| post, | ||
|
|
||
|
|
||
| path = "/config/custom-providers", | ||
| request_body = CreateCustomProviderRequest, | ||
| responses( | ||
|
|
@@ -798,6 +768,124 @@ pub async fn create_custom_provider( | |
| Ok(Json(format!("Custom provider added - ID: {}", config.id()))) | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| put, | ||
| path = "/config/custom-providers/{id}", | ||
| request_body = UpdateCustomProviderRequest, | ||
| responses( | ||
| (status = 200, description = "Custom provider updated successfully", body = String), | ||
| (status = 400, description = "Invalid request"), | ||
| (status = 404, description = "Provider not found"), | ||
| (status = 500, description = "Internal server error") | ||
| ) | ||
| )] | ||
| pub async fn update_custom_provider( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you remove the over eager LLM comments here? |
||
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| axum::extract::Path(id): axum::extract::Path<String>, | ||
| Json(request): Json<UpdateCustomProviderRequest>, | ||
| ) -> Result<Json<String>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we got rid of this, but presumably syncing to main will tell you |
||
|
|
||
| // Log incoming update for debugging | ||
| let payload_value = serde_json::to_value(&request).unwrap_or(serde_json::Value::Null); | ||
| tracing::info!(id = %id, payload = ?payload_value, "update_custom_provider called"); | ||
|
|
||
| let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); | ||
| let file_path = custom_providers_dir.join(format!("{}.json", id)); | ||
|
|
||
| if !file_path.exists() { | ||
| return Err(StatusCode::NOT_FOUND); | ||
| } | ||
|
|
||
| // Read existing config | ||
| let content = std::fs::read_to_string(&file_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| let mut config: goose::config::custom_providers::CustomProviderConfig = | ||
| serde_json::from_str(&content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
|
||
| // Update fields if provided | ||
| if let Some(display_name) = request.display_name { | ||
| config.display_name = display_name; | ||
| } | ||
| if let Some(api_url) = request.api_url { | ||
| config.base_url = api_url; | ||
| } | ||
| if let Some(models) = request.models { | ||
| config.models = models | ||
| .into_iter() | ||
| .map(|m| goose::providers::base::ModelInfo::new(m, 128000)) | ||
| .collect(); | ||
| } | ||
| if let Some(s) = request.supports_streaming { | ||
| config.supports_streaming = Some(s); | ||
| } | ||
|
|
||
| // Persist API key into secrets if provided | ||
| // Update optional JSON fields if provided | ||
| if let Some(desc) = request.description { | ||
| config.description = Some(desc); | ||
| } | ||
| if let Some(hdrs) = request.headers { | ||
| config.headers = Some(hdrs); | ||
| } | ||
| if let Some(t) = request.timeout_seconds { | ||
| config.timeout_seconds = Some(t); | ||
| } | ||
|
|
||
| // Persist API key into secrets if provided | ||
| if let Some(api_key) = request.api_key { | ||
| let cfg = goose::config::Config::global(); | ||
| if let Err(e) = cfg.set_secret(&config.api_key_env, serde_json::Value::String(api_key)) { | ||
| tracing::error!("Failed to set secret for {}: {}", config.api_key_env, e); | ||
| return Err(StatusCode::INTERNAL_SERVER_ERROR); | ||
| } | ||
| } | ||
|
|
||
| // Save updated JSON atomically | ||
| let tmp = file_path.with_extension("json.tmp"); | ||
| let json_content = serde_json::to_string_pretty(&config).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| std::fs::write(&tmp, &json_content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| std::fs::rename(&tmp, &file_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
|
||
| // Refresh in-memory providers | ||
| if let Err(e) = goose::providers::refresh_custom_providers() { | ||
| tracing::warn!("Failed to refresh custom providers after update: {}", e); | ||
| } | ||
|
|
||
| Ok(Json(format!("Updated custom provider: {}", id))) | ||
| } | ||
|
|
||
|
|
||
| #[utoipa::path( | ||
| get, | ||
| path = "/config/custom-providers/{id}", | ||
| responses( | ||
| (status = 200, description = "Custom provider retrieved successfully", body = goose::config::custom_providers::CustomProviderConfig), | ||
| (status = 404, description = "Provider not found"), | ||
| (status = 500, description = "Internal server error") | ||
| ) | ||
| )] | ||
| pub async fn get_custom_provider( | ||
| State(state): State<Arc<AppState>>, | ||
| headers: HeaderMap, | ||
| axum::extract::Path(id): axum::extract::Path<String>, | ||
| ) -> Result<Json<goose::config::custom_providers::CustomProviderConfig>, StatusCode> { | ||
| verify_secret_key(&headers, &state)?; | ||
|
|
||
| let custom_providers_dir = goose::config::custom_providers::custom_providers_dir(); | ||
| let file_path = custom_providers_dir.join(format!("{}.json", id)); | ||
|
|
||
| if !file_path.exists() { | ||
| return Err(StatusCode::NOT_FOUND); | ||
| } | ||
|
|
||
| let content = std::fs::read_to_string(&file_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| let config: goose::config::custom_providers::CustomProviderConfig = | ||
| serde_json::from_str(&content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
|
||
| Ok(Json(config)) | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
| delete, | ||
| path = "/config/custom-providers/{id}", | ||
|
|
@@ -845,7 +933,7 @@ pub fn routes(state: Arc<AppState>) -> Router { | |
| .route("/config/custom-providers", post(create_custom_provider)) | ||
| .route( | ||
| "/config/custom-providers/{id}", | ||
| delete(remove_custom_provider), | ||
| get(get_custom_provider).delete(remove_custom_provider).put(update_custom_provider), | ||
| ) | ||
| .with_state(state) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -265,14 +265,17 @@ async fn reply_handler( | |
| Ok(stream) => stream, | ||
| Err(e) => { | ||
| tracing::error!("Failed to start reply stream: {:?}", e); | ||
| stream_event( | ||
| MessageEvent::Error { | ||
| error: e.to_string(), | ||
| }, | ||
| let err_text = e.to_string(); | ||
| // send Error event (for telemetry / UI error handling) | ||
| let _ = stream_event( | ||
| MessageEvent::Error { error: err_text.clone() }, | ||
| &task_tx, | ||
| &cancel_token, | ||
| ) | ||
| .await; | ||
| // also send a visible assistant message so the UI shows it inline | ||
| let assistant_msg = Message::assistant().with_text(format!("Provider error: {}", err_text)); | ||
| let _ = stream_event(MessageEvent::Message { message: assistant_msg }, &task_tx, &cancel_token).await; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we want to show the error (I thought we already did?) I think it would be better to just render the Error event instead of stream it as a Message
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when does this happen? |
||
| return; | ||
| } | ||
| }; | ||
|
|
@@ -312,8 +315,32 @@ async fn reply_handler( | |
| track_tool_telemetry(content, all_messages.messages()); | ||
| } | ||
|
|
||
| // Push and send the message event | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this comment |
||
| all_messages.push(message.clone()); | ||
| stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; | ||
| stream_event(MessageEvent::Message { message: message.clone() }, &tx, &cancel_token).await; | ||
|
|
||
| // If this message appears to be a provider streaming error produced by | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this doesn't feel right. |
||
| // the OpenAI-format streaming handler, surface an Error event as | ||
| // well so the frontend's error handling path runs and the UI | ||
| // visibly shows the failure. We detect this by looking for the | ||
| // distinctive prefix we emit when streaming errors are encountered. | ||
| if let Some(first_text) = message | ||
| .content | ||
| .iter() | ||
| .find_map(|c| match c { | ||
| goose::conversation::message::MessageContent::Text(t) => Some(t.text.clone()), | ||
| _ => None, | ||
| }) | ||
| .filter(|s| s.starts_with("LLM streaming error encountered")) | ||
| { | ||
| // Send a short error string (avoid flooding the SSE with huge payloads) | ||
| let short = if first_text.len() > 1024 { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use .len or [..] for text since that breaks CJK. we have a function to this, safe_truncate |
||
| format!("{}...", &first_text[..1024]) | ||
| } else { | ||
| first_text | ||
| }; | ||
| let _ = stream_event(MessageEvent::Error { error: short }, &tx, &cancel_token).await; | ||
| } | ||
| } | ||
| Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { | ||
| // Replace the message history with the compacted messages | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was slightly confusing to me -- where do we have a secret with a boolean and value and why should it be skipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this is because our API returns true when getting a secret value or false if there is no value. however, I don't think we should block setting secret values to true - it is not much a of a value to keep secret. we should just not do this write. is there a particular case where this is happening?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the windows ui the true value seemed to cause keyring issues. I have tried a detailed split , but that failed. Retry with a subset, to get the custom provider bugs fixed first.