diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index 540f7762fa84..0e9265a6521c 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -366,6 +366,7 @@ impl<'__s> ToSchema<'__s> for AnnotatedSchema { super::routes::config_management::get_extensions, super::routes::config_management::read_all_config, super::routes::config_management::providers, + super::routes::config_management::get_provider_models, super::routes::config_management::upsert_permissions, super::routes::config_management::create_custom_provider, super::routes::config_management::remove_custom_provider, diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 79a177943d57..0224cc79565d 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -2,7 +2,7 @@ use super::utils::verify_secret_key; use crate::routes::utils::check_provider_configured; use crate::state::AppState; use axum::{ - extract::State, + extract::{Path, State}, routing::{delete, get, post}, Json, Router, }; @@ -386,6 +386,66 @@ pub async fn providers( Ok(Json(providers_response)) } +#[utoipa::path( + get, + path = "/config/providers/{name}/models", + params( + ("name" = String, Path, description = "Provider name (e.g., openai)") + ), + responses( + (status = 200, description = "Models fetched successfully", body = [String]), + (status = 400, description = "Unknown provider, provider not configured, or authentication error"), + (status = 429, description = "Rate limit exceeded"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn get_provider_models( + State(state): State>, + headers: HeaderMap, + Path(name): Path, +) -> Result>, StatusCode> { + verify_secret_key(&headers, &state)?; + + let all = get_providers(); + let Some(metadata) = all.into_iter().find(|m| m.name == name) else { + return Err(StatusCode::BAD_REQUEST); + }; + if !check_provider_configured(&metadata) { + return Err(StatusCode::BAD_REQUEST); + } + + let model_config = + ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let provider = goose::providers::create(&name, model_config) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + match provider.fetch_supported_models().await { + Ok(Some(models)) => Ok(Json(models)), + Ok(None) => Ok(Json(Vec::new())), + Err(provider_error) => { + use goose::providers::errors::ProviderError; + let status_code = match provider_error { + // Permanent misconfigurations - client should fix configuration + ProviderError::Authentication(_) => StatusCode::BAD_REQUEST, + ProviderError::UsageError(_) => StatusCode::BAD_REQUEST, + + // Transient errors - client should retry later + ProviderError::RateLimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS, + + // All other errors - internal server error + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + tracing::warn!( + "Provider {} failed to fetch models: {}", + name, + provider_error + ); + Err(status_code) + } + } +} + #[derive(Serialize, ToSchema)] pub struct PricingData { pub provider: String, @@ -771,6 +831,7 @@ pub fn routes(state: Arc) -> Router { .route("/config/extensions", post(add_extension)) .route("/config/extensions/{name}", delete(remove_extension)) .route("/config/providers", get(providers)) + .route("/config/providers/{name}/models", get(get_provider_models)) .route("/config/pricing", post(get_pricing)) .route("/config/init", post(init_config)) .route("/config/backup", post(backup_config)) @@ -790,8 +851,7 @@ pub fn routes(state: Arc) -> Router { mod tests { use super::*; - #[tokio::test] - async fn test_read_model_limits() { + async fn create_test_state() -> Arc { let test_state = AppState::new( Arc::new(goose::agents::Agent::default()), "test".to_string(), @@ -805,6 +865,12 @@ mod tests { .await .unwrap(); test_state.set_scheduler(sched).await; + test_state + } + + #[tokio::test] + async fn test_read_model_limits() { + let test_state = create_test_state().await; let mut headers = HeaderMap::new(); headers.insert("X-Secret-Key", "test".parse().unwrap()); @@ -829,4 +895,47 @@ mod tests { assert!(gpt4_limit.is_some()); assert_eq!(gpt4_limit.unwrap().context_limit, 128_000); } + + #[tokio::test] + async fn test_get_provider_models_unknown_provider() { + let test_state = create_test_state().await; + let mut headers = HeaderMap::new(); + headers.insert("X-Secret-Key", "test".parse().unwrap()); + + let result = get_provider_models( + State(test_state), + headers, + Path("unknown_provider".to_string()), + ) + .await; + + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn test_get_provider_models_openai_configured() { + std::env::set_var("OPENAI_API_KEY", "test-key"); + + let test_state = create_test_state().await; + let mut headers = HeaderMap::new(); + headers.insert("X-Secret-Key", "test".parse().unwrap()); + + let result = + get_provider_models(State(test_state), headers, Path("openai".to_string())).await; + + // The response should be BAD_REQUEST since the API key is invalid (authentication error) + assert!( + result.is_err(), + "Expected error response from OpenAI provider with invalid key" + ); + let status_code = result.unwrap_err(); + + assert!(status_code == StatusCode::BAD_REQUEST, + "Expected BAD_REQUEST (authentication error) or INTERNAL_SERVER_ERROR (other errors), got: {}", + status_code + ); + + std::env::remove_var("OPENAI_API_KEY"); + } } diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 0ac9f0123db9..61542bf46a31 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -522,6 +522,49 @@ } } }, + "/config/providers/{name}/models": { + "get": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "get_provider_models", + "parameters": [ + { + "name": "name", + "in": "path", + "description": "Provider name (e.g., openai)", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Models fetched successfully", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + }, + "400": { + "description": "Unknown provider, provider not configured, or authentication error" + }, + "429": { + "description": "Rate limit exceeded" + }, + "500": { + "description": "Internal server error" + } + } + } + }, "/config/read": { "post": { "tags": [ diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index e5407d98db3d..a36b1ce97c9f 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from './client'; -import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen'; +import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ScanRecipeData, ScanRecipeResponses, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, GetSessionHistoryData, GetSessionHistoryResponses, GetSessionHistoryErrors } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -158,6 +158,13 @@ export const providers = (options?: Option }); }; +export const getProviderModels = (options: Options) => { + return (options.client ?? _heyApiClient).get({ + url: '/config/providers/{name}/models', + ...options + }); +}; + export const readConfig = (options: Options) => { return (options.client ?? _heyApiClient).post({ url: '/config/read', diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 34e30c2ba596..76a9e3fcd13b 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -1261,6 +1261,42 @@ export type ProvidersResponses = { export type ProvidersResponse2 = ProvidersResponses[keyof ProvidersResponses]; +export type GetProviderModelsData = { + body?: never; + path: { + /** + * Provider name (e.g., openai) + */ + name: string; + }; + query?: never; + url: '/config/providers/{name}/models'; +}; + +export type GetProviderModelsErrors = { + /** + * Unknown provider, provider not configured, or authentication error + */ + 400: unknown; + /** + * Rate limit exceeded + */ + 429: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetProviderModelsResponses = { + /** + * Models fetched successfully + */ + 200: Array; +}; + +export type GetProviderModelsResponse = GetProviderModelsResponses[keyof GetProviderModelsResponses]; + export type ReadConfigData = { body: ConfigKeyQuery; path?: never; diff --git a/ui/desktop/src/components/ConfigContext.tsx b/ui/desktop/src/components/ConfigContext.tsx index f39175f819e5..8bda007082de 100644 --- a/ui/desktop/src/components/ConfigContext.tsx +++ b/ui/desktop/src/components/ConfigContext.tsx @@ -8,6 +8,7 @@ import { addExtension as apiAddExtension, removeExtension as apiRemoveExtension, providers, + getProviderModels as apiGetProviderModels, } from '../api'; import type { ConfigResponse, @@ -39,6 +40,7 @@ interface ConfigContextType { removeExtension: (name: string) => Promise; getProviders: (b: boolean) => Promise; getExtensions: (b: boolean) => Promise; + getProviderModels: (providerName: string) => Promise; disableAllExtensions: () => Promise; enableBotExtensions: (extensions: ExtensionConfig[]) => Promise; } @@ -185,6 +187,21 @@ export const ConfigProvider: React.FC = ({ children }) => { [providersList] ); + const getProviderModels = useCallback(async (providerName: string): Promise => { + try { + const response = await apiGetProviderModels({ + path: { name: providerName }, + headers: { + 'X-Secret-Key': await window.electron.getSecretKey(), + }, + }); + return response.data || []; + } catch (error) { + console.error(`Failed to fetch models for provider ${providerName}:`, error); + return []; + } + }, []); + useEffect(() => { // Load all configuration data and providers on mount (async () => { @@ -242,6 +259,7 @@ export const ConfigProvider: React.FC = ({ children }) => { toggleExtension, getProviders, getExtensions, + getProviderModels, disableAllExtensions, enableBotExtensions, }; @@ -257,6 +275,7 @@ export const ConfigProvider: React.FC = ({ children }) => { toggleExtension, getProviders, getExtensions, + getProviderModels, reloadConfig, ]); diff --git a/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx b/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx index ae50156b571e..88ef4e695249 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/AddModelModal.tsx @@ -24,12 +24,11 @@ type AddModelModalProps = { setView: (view: View) => void; }; export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { - const { getProviders, read } = useConfig(); + const { getProviders, getProviderModels, read } = useConfig(); const { changeModel } = useModelAndProvider(); const [providerOptions, setProviderOptions] = useState<{ value: string; label: string }[]>([]); - const [modelOptions, setModelOptions] = useState< - { options: { value: string; label: string; provider: string }[] }[] - >([]); + type ModelOption = { value: string; label: string; provider: string; isDisabled?: boolean }; + const [modelOptions, setModelOptions] = useState<{ options: ModelOption[] }[]>([]); const [provider, setProvider] = useState(null); const [model, setModel] = useState(''); const [isCustomModel, setIsCustomModel] = useState(false); @@ -42,6 +41,7 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { const [usePredefinedModels] = useState(shouldShowPredefinedModels()); const [selectedPredefinedModel, setSelectedPredefinedModel] = useState(null); const [predefinedModels, setPredefinedModels] = useState([]); + const [loadingModels, setLoadingModels] = useState(false); // Validate form data const validateForm = useCallback(() => { @@ -141,24 +141,60 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { }, ]); - // Format model options by provider - const formattedModelOptions: { - options: { value: string; label: string; provider: string }[]; - }[] = []; - activeProviders.forEach(({ metadata, name }) => { - if (metadata.known_models && metadata.known_models.length > 0) { - formattedModelOptions.push({ - options: metadata.known_models.map(({ name: modelName }) => ({ - value: modelName, - label: modelName, - provider: name, - })), + setLoadingModels(true); + + // Fetching models for all providers + const modelPromises = activeProviders.map(async (p) => { + const providerName = p.name; + try { + let models = await getProviderModels(providerName); + // Fallback to known_models if server returned none + if ((!models || models.length === 0) && p.metadata.known_models?.length) { + models = p.metadata.known_models.map((m) => m.name); + } + return { provider: p, models, error: null }; + } catch (e: unknown) { + return { + provider: p, + models: null, + error: `Failed to fetch models for ${providerName}${e instanceof Error ? `: ${e.message}` : ''}`, + }; + } + }); + const results = await Promise.all(modelPromises); + + // Process results and build grouped options + const groupedOptions: { options: { value: string; label: string; provider: string }[] }[] = + []; + const errors: string[] = []; + + results.forEach(({ provider: p, models, error }) => { + if (error) { + errors.push(error); + // Fallback to metadata known_models on error + if (p.metadata.known_models && p.metadata.known_models.length > 0) { + groupedOptions.push({ + options: p.metadata.known_models.map(({ name }) => ({ + value: name, + label: name, + provider: p.name, + })), + }); + } + } else if (models && models.length > 0) { + groupedOptions.push({ + options: models.map((m) => ({ value: m, label: m, provider: p.name })), }); } }); + // Log errors if any providers failed (don't show to user) + if (errors.length > 0) { + console.error('Provider model fetch errors:', errors); + } + // Add the "Custom model" option to each provider group - formattedModelOptions.forEach((group) => { + groupedOptions.forEach((group) => { const providerName = group.options[0]?.provider; if (providerName && !providerName.startsWith('custom_')) { group.options.push({ @@ -169,13 +205,15 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { } }); - setModelOptions(formattedModelOptions); - setOriginalModelOptions(formattedModelOptions); - } catch (error) { - console.error('Failed to load providers:', error); + setModelOptions(groupedOptions); + setOriginalModelOptions(groupedOptions); + } catch (error: unknown) { + console.error('Failed to query providers:', error); + } finally { + setLoadingModels(false); } })(); - }, [getProviders, usePredefinedModels, read]); + }, [getProviders, getProviderModels, usePredefinedModels, read]); // Filter model options based on selected provider const filteredModelOptions = provider @@ -347,7 +385,7 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { setIsCustomModel(false); } }} - placeholder="Provider" + placeholder="Provider, type to search" isClearable /> {attemptedSubmit && validationErrors.provider && ( @@ -360,12 +398,30 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { {!isCustomModel ? (