diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index db13b3d9f218..554e423de13b 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -321,6 +321,7 @@ derive_utoipa!(Icon as IconSchema); paths( super::routes::health::status, super::routes::config_management::backup_config, + super::routes::config_management::detect_provider, super::routes::config_management::recover_config, super::routes::config_management::validate_config, super::routes::config_management::init_config, @@ -379,6 +380,8 @@ derive_utoipa!(Icon as IconSchema); components(schemas( super::routes::config_management::UpsertConfigQuery, super::routes::config_management::ConfigKeyQuery, + super::routes::config_management::DetectProviderResponse, + super::routes::config_management::DetectProviderRequest, super::routes::config_management::ConfigResponse, super::routes::config_management::ProvidersResponse, super::routes::config_management::ProviderDetails, diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index d479147f648c..43bfdd4c822c 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -9,6 +9,7 @@ use goose::config::paths::Paths; use goose::config::ExtensionEntry; use goose::config::{Config, ConfigError}; use goose::model::ModelConfig; +use goose::providers::auto_detect::detect_provider_from_api_key; use goose::providers::base::ProviderMetadata; use goose::providers::pricing::{ get_all_pricing, get_model_pricing, parse_model_id, refresh_pricing, @@ -85,6 +86,16 @@ pub struct CreateCustomProviderRequest { pub supports_streaming: Option, } +#[derive(Deserialize, ToSchema)] +pub struct DetectProviderRequest { + pub api_key: String, +} + +#[derive(Serialize, ToSchema)] +pub struct DetectProviderResponse { + pub provider_name: String, + pub models: Vec, +} #[utoipa::path( post, path = "/config/upsert", @@ -662,6 +673,28 @@ pub async fn get_current_model() -> Result, StatusCode> { }))) } +#[utoipa::path( + post, + path = "/config/detect-provider", + request_body = DetectProviderRequest, + responses( + (status = 200, description = "Provider detected successfully", body = DetectProviderResponse), + (status = 404, description = "No matching provider found"), + (status = 500, description = "Internal server error") + ) +)] +pub async fn detect_provider( + Json(detect_request): Json, +) -> Result, StatusCode> { + match detect_provider_from_api_key(&detect_request.api_key).await { + Some((provider_name, models)) => Ok(Json(DetectProviderResponse { + provider_name, + models, + })), + None => Err(StatusCode::NOT_FOUND), + } +} + #[utoipa::path( post, path = "/config/custom-providers", @@ -725,6 +758,7 @@ pub fn routes(state: Arc) -> Router { .route("/config/extensions/{name}", delete(remove_extension)) .route("/config/providers", get(providers)) .route("/config/providers/{name}/models", get(get_provider_models)) + .route("/config/detect-provider", post(detect_provider)) .route("/config/pricing", post(get_pricing)) .route("/config/init", post(init_config)) .route("/config/backup", post(backup_config)) diff --git a/crates/goose/src/providers/auto_detect.rs b/crates/goose/src/providers/auto_detect.rs new file mode 100644 index 000000000000..3d97c7102b4d --- /dev/null +++ b/crates/goose/src/providers/auto_detect.rs @@ -0,0 +1,52 @@ +use crate::model::ModelConfig; + +pub async fn detect_provider_from_api_key(api_key: &str) -> Option<(String, Vec)> { + let provider_tests = vec![ + ("anthropic", "ANTHROPIC_API_KEY"), + ("openai", "OPENAI_API_KEY"), + ("google", "GOOGLE_API_KEY"), + ("groq", "GROQ_API_KEY"), + ("xai", "XAI_API_KEY"), + ("ollama", "OLLAMA_API_KEY"), + //("openrouter", "OPENROUTER_API_KEY"), Open Router seems to return the models also without a key + ]; + + let tasks: Vec<_> = provider_tests + .into_iter() + .map(|(provider_name, env_key)| { + let api_key = api_key.to_string(); + tokio::spawn(async move { + let original_value = std::env::var(env_key).ok(); + std::env::set_var(env_key, &api_key); + + let result = match crate::providers::create( + provider_name, + ModelConfig::new_or_fail("default"), + ) + .await + { + Ok(provider) => match provider.fetch_supported_models().await { + Ok(Some(models)) => Some((provider_name.to_string(), models)), + _ => None, + }, + Err(_) => None, + }; + + match original_value { + Some(val) => std::env::set_var(env_key, val), + None => std::env::remove_var(env_key), + } + + result + }) + }) + .collect(); + + for task in tasks { + if let Ok(Some(result)) = task.await { + return Some(result); + } + } + + None +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 52e8ba0185b9..9b26c25a1d1f 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,5 +1,6 @@ pub mod anthropic; mod api_client; +pub mod auto_detect; pub mod azure; pub mod azureauth; pub mod base; diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 15aeca146fc7..1795cc7e8f66 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -274,6 +274,25 @@ impl Provider for OllamaProvider { } })) } + + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self.api_client.response_get("v1/models").await?; + + let json: Value = response.json().await?; + + let arr = match json.get("data").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut models: Vec = arr + .iter() + .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string)) + .collect(); + + models.sort(); + Ok(Some(models)) + } } impl OllamaProvider { diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index a53946407b46..9564cfb0c2fc 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -444,6 +444,42 @@ } } }, + "/config/detect-provider": { + "post": { + "tags": [ + "super::routes::config_management" + ], + "operationId": "detect_provider", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DetectProviderRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Provider detected successfully", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DetectProviderResponse" + } + } + } + }, + "404": { + "description": "No matching provider found" + }, + "500": { + "description": "Internal server error" + } + } + } + }, "/config/extensions": { "get": { "tags": [ @@ -2329,6 +2365,35 @@ } } }, + "DetectProviderRequest": { + "type": "object", + "required": [ + "api_key" + ], + "properties": { + "api_key": { + "type": "string" + } + } + }, + "DetectProviderResponse": { + "type": "object", + "required": [ + "provider_name", + "models" + ], + "properties": { + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "provider_name": { + "type": "string" + } + } + }, "EmbeddedResource": { "type": "object", "required": [ diff --git a/ui/desktop/src/api/sdk.gen.ts b/ui/desktop/src/api/sdk.gen.ts index 13fe5d7899da..ba4c33de6ff0 100644 --- a/ui/desktop/src/api/sdk.gen.ts +++ b/ui/desktop/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from './client'; -import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ParseRecipeData, ParseRecipeResponses, ParseRecipeErrors, SaveRecipeData, SaveRecipeResponses, SaveRecipeErrors, ScanRecipeData, ScanRecipeResponses, ReplyData, ReplyResponses, ReplyErrors, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, ImportSessionData, ImportSessionResponses, ImportSessionErrors, GetSessionInsightsData, GetSessionInsightsResponses, GetSessionInsightsErrors, DeleteSessionData, DeleteSessionResponses, DeleteSessionErrors, GetSessionData, GetSessionResponses, GetSessionErrors, UpdateSessionDescriptionData, UpdateSessionDescriptionResponses, UpdateSessionDescriptionErrors, ExportSessionData, ExportSessionResponses, ExportSessionErrors, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesResponses, UpdateSessionUserRecipeValuesErrors, StatusData, StatusResponses } from './types.gen'; +import type { AddSubRecipesData, AddSubRecipesResponses, AddSubRecipesErrors, ExtendPromptData, ExtendPromptResponses, ExtendPromptErrors, ResumeAgentData, ResumeAgentResponses, ResumeAgentErrors, UpdateSessionConfigData, UpdateSessionConfigResponses, UpdateSessionConfigErrors, StartAgentData, StartAgentResponses, StartAgentErrors, GetToolsData, GetToolsResponses, GetToolsErrors, UpdateAgentProviderData, UpdateAgentProviderResponses, UpdateAgentProviderErrors, UpdateRouterToolSelectorData, UpdateRouterToolSelectorResponses, UpdateRouterToolSelectorErrors, ReadAllConfigData, ReadAllConfigResponses, BackupConfigData, BackupConfigResponses, BackupConfigErrors, CreateCustomProviderData, CreateCustomProviderResponses, CreateCustomProviderErrors, RemoveCustomProviderData, RemoveCustomProviderResponses, RemoveCustomProviderErrors, DetectProviderData, DetectProviderResponses, DetectProviderErrors, GetExtensionsData, GetExtensionsResponses, GetExtensionsErrors, AddExtensionData, AddExtensionResponses, AddExtensionErrors, RemoveExtensionData, RemoveExtensionResponses, RemoveExtensionErrors, InitConfigData, InitConfigResponses, InitConfigErrors, UpsertPermissionsData, UpsertPermissionsResponses, UpsertPermissionsErrors, ProvidersData, ProvidersResponses, GetProviderModelsData, GetProviderModelsResponses, GetProviderModelsErrors, ReadConfigData, ReadConfigResponses, ReadConfigErrors, RecoverConfigData, RecoverConfigResponses, RecoverConfigErrors, RemoveConfigData, RemoveConfigResponses, RemoveConfigErrors, UpsertConfigData, UpsertConfigResponses, UpsertConfigErrors, ValidateConfigData, ValidateConfigResponses, ValidateConfigErrors, ConfirmPermissionData, ConfirmPermissionResponses, ConfirmPermissionErrors, ManageContextData, ManageContextResponses, ManageContextErrors, StartOpenrouterSetupData, StartOpenrouterSetupResponses, StartTetrateSetupData, StartTetrateSetupResponses, CreateRecipeData, CreateRecipeResponses, CreateRecipeErrors, DecodeRecipeData, DecodeRecipeResponses, DecodeRecipeErrors, DeleteRecipeData, DeleteRecipeResponses, DeleteRecipeErrors, EncodeRecipeData, EncodeRecipeResponses, EncodeRecipeErrors, ListRecipesData, ListRecipesResponses, ListRecipesErrors, ParseRecipeData, ParseRecipeResponses, ParseRecipeErrors, SaveRecipeData, SaveRecipeResponses, SaveRecipeErrors, ScanRecipeData, ScanRecipeResponses, ReplyData, ReplyResponses, ReplyErrors, CreateScheduleData, CreateScheduleResponses, CreateScheduleErrors, DeleteScheduleData, DeleteScheduleResponses, DeleteScheduleErrors, ListSchedulesData, ListSchedulesResponses, ListSchedulesErrors, UpdateScheduleData, UpdateScheduleResponses, UpdateScheduleErrors, InspectRunningJobData, InspectRunningJobResponses, InspectRunningJobErrors, KillRunningJobData, KillRunningJobResponses, PauseScheduleData, PauseScheduleResponses, PauseScheduleErrors, RunNowHandlerData, RunNowHandlerResponses, RunNowHandlerErrors, SessionsHandlerData, SessionsHandlerResponses, SessionsHandlerErrors, UnpauseScheduleData, UnpauseScheduleResponses, UnpauseScheduleErrors, ListSessionsData, ListSessionsResponses, ListSessionsErrors, ImportSessionData, ImportSessionResponses, ImportSessionErrors, GetSessionInsightsData, GetSessionInsightsResponses, GetSessionInsightsErrors, DeleteSessionData, DeleteSessionResponses, DeleteSessionErrors, GetSessionData, GetSessionResponses, GetSessionErrors, UpdateSessionDescriptionData, UpdateSessionDescriptionResponses, UpdateSessionDescriptionErrors, ExportSessionData, ExportSessionResponses, ExportSessionErrors, UpdateSessionUserRecipeValuesData, UpdateSessionUserRecipeValuesResponses, UpdateSessionUserRecipeValuesErrors, StatusData, StatusResponses } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -134,6 +134,17 @@ export const removeCustomProvider = (optio }); }; +export const detectProvider = (options: Options) => { + return (options.client ?? _heyApiClient).post({ + url: '/config/detect-provider', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options.headers + } + }); +}; + export const getExtensions = (options?: Options) => { return (options?.client ?? _heyApiClient).get({ url: '/config/extensions', diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index e4b3ccbeeda3..34802c47cb3e 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -148,6 +148,15 @@ export type DeleteRecipeRequest = { id: string; }; +export type DetectProviderRequest = { + api_key: string; +}; + +export type DetectProviderResponse = { + models: Array; + provider_name: string; +}; + export type EmbeddedResource = { _meta?: { [key: string]: unknown; @@ -1238,6 +1247,33 @@ export type RemoveCustomProviderResponses = { export type RemoveCustomProviderResponse = RemoveCustomProviderResponses[keyof RemoveCustomProviderResponses]; +export type DetectProviderData = { + body: DetectProviderRequest; + path?: never; + query?: never; + url: '/config/detect-provider'; +}; + +export type DetectProviderErrors = { + /** + * No matching provider found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type DetectProviderResponses = { + /** + * Provider detected successfully + */ + 200: DetectProviderResponse; +}; + +export type DetectProviderResponse2 = DetectProviderResponses[keyof DetectProviderResponses]; + export type GetExtensionsData = { body?: never; path?: never; diff --git a/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx b/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx index 4b437ae95a4a..5dc426aecf97 100644 --- a/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx +++ b/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx @@ -3,7 +3,7 @@ import { ScrollArea } from '../../ui/scroll-area'; import BackButton from '../../ui/BackButton'; import ProviderGrid from './ProviderGrid'; import { useConfig } from '../../ConfigContext'; -import { ProviderDetails } from '../../../api'; +import { ProviderDetails, detectProvider } from '../../../api'; import { toastService } from '../../../toasts'; interface ProviderSettingsProps { @@ -16,12 +16,12 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett const [loading, setLoading] = useState(true); const [providers, setProviders] = useState([]); const initialLoadDone = useRef(false); + const [testApiKey, setTestApiKey] = useState(''); + const [detectedProvider, setDetectedProvider] = useState(''); - // Create a function to load providers that can be called multiple times const loadProviders = useCallback(async () => { setLoading(true); try { - // Only force refresh when explicitly requested, not on initial load const result = await getProviders(!initialLoadDone.current); if (result) { setProviders(result); @@ -34,13 +34,10 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett } }, [getProviders]); - // Load providers only once when component mounts useEffect(() => { loadProviders(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); // Intentionally not including loadProviders in deps to prevent reloading + }, [loadProviders]); - // This function will be passed to ProviderGrid for manual refreshes after config changes const refreshProviders = useCallback(() => { if (initialLoadDone.current) { getProviders(true).then((result) => { @@ -49,19 +46,15 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett } }, [getProviders]); - // Handler for when a provider is launched if this component is used as part of onboarding page const handleProviderLaunch = useCallback( async (provider: ProviderDetails) => { const provider_name = provider.name; const model = provider.metadata.default_model; try { - // update the config - // set GOOSE_PROVIDER in the config file upsert('GOOSE_PROVIDER', provider_name, false).then((_) => console.log('Setting GOOSE_PROVIDER to', provider_name) ); - // set GOOSE_MODEL in the config file upsert('GOOSE_MODEL', model, false).then((_) => console.log('Setting GOOSE_MODEL to', model) ); @@ -76,7 +69,6 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett } catch (error) { console.error(`Failed to initialize with provider ${provider_name}:`, error); - // Show error toast toastService.configure({ silent: false }); toastService.error({ title: 'Initialization Failed', @@ -88,11 +80,22 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett [onClose, upsert] ); + const handleDetectProvider = async () => { + try { + const response = await detectProvider({ body: { api_key: testApiKey }, throwOnError: true }); + if (response.data) { + setDetectedProvider(response.data.provider_name); + } + } catch (error) { + console.error('Detection failed:', error); + setDetectedProvider('Error'); + } + }; + return (
- {/* Consistent header pattern with back button */}
@@ -110,8 +113,26 @@ export default function ProviderSettings({ onClose, isOnboarding }: ProviderSett
+
+ setTestApiKey(e.target.value)} + placeholder="Enter API key" + className="mr-3 p-2 w-80 border border-border-default bg-background-default text-text-default rounded" + /> + + {detectedProvider && ( +
Detected: {detectedProvider}
+ )} +
+
- {/* Content Area */}
{loading ? (