diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index c9d36338164f..b355b6f244d6 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -285,6 +285,42 @@ impl Provider for OllamaProvider { } })) } + + async fn fetch_supported_models(&self) -> Result>, ProviderError> { + let response = self + .api_client + .response_get("api/tags") + .await + .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; + + if !response.status().is_success() { + return Err(ProviderError::RequestFailed(format!( + "Failed to fetch models: HTTP {}", + response.status() + ))); + } + + let json_response = response.json::().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse response: {}", e)) + })?; + + let models = json_response + .get("models") + .and_then(|m| m.as_array()) + .ok_or_else(|| { + ProviderError::RequestFailed("No models array in response".to_string()) + })?; + + let mut model_names: Vec = models + .iter() + .filter_map(|model| model.get("name").and_then(|n| n.as_str()).map(String::from)) + .collect(); + + // Sort alphabetically + model_names.sort(); + + Ok(Some(model_names)) + } } impl OllamaProvider { diff --git a/ui/desktop/src/components/settings/models/modelInterface.ts b/ui/desktop/src/components/settings/models/modelInterface.ts index 4a397647dde8..b3cec2469e19 100644 --- a/ui/desktop/src/components/settings/models/modelInterface.ts +++ b/ui/desktop/src/components/settings/models/modelInterface.ts @@ -39,3 +39,38 @@ export async function getProviderMetadata( } return matches.metadata; } + +export interface ProviderModelsResult { + provider: ProviderDetails; + models: string[] | null; + error: string | null; +} + +/** + * Fetches models for all active providers in parallel. + * Falls back to known_models if fetching fails or returns no models. + */ +export async function fetchModelsForProviders( + activeProviders: ProviderDetails[], + getProviderModelsFunc: (providerName: string) => Promise +): Promise { + const modelPromises = activeProviders.map(async (p) => { + const providerName = p.name; + try { + let models = await getProviderModelsFunc(providerName); + if ((!models || models.length === 0) && p.metadata.known_models?.length) { + models = p.metadata.known_models.map((m) => m.name); + } + return { provider: p, models, error: null }; + } catch (e: unknown) { + const errorMessage = `Failed to fetch models for ${providerName}${e instanceof Error ? `: ${e.message}` : ''}`; + return { + provider: p, + models: null, + error: errorMessage, + }; + } + }); + + return await Promise.all(modelPromises); +} diff --git a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx index 267f504d74a0..e58932ea5381 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx @@ -5,6 +5,7 @@ import { Button } from '../../../ui/button'; import { Select } from '../../../ui/Select'; import { Input } from '../../../ui/input'; import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils'; +import { fetchModelsForProviders } from '../modelInterface'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../../ui/dialog'; interface LeadWorkerSettingsProps { @@ -13,7 +14,7 @@ interface LeadWorkerSettingsProps { } export function LeadWorkerSettings({ isOpen, onClose }: LeadWorkerSettingsProps) { - const { read, upsert, getProviders, remove } = useConfig(); + const { read, upsert, getProviders, getProviderModels, remove } = useConfig(); const { currentModel } = useModelAndProvider(); const [leadModel, setLeadModel] = useState(''); const [workerModel, setWorkerModel] = useState(''); @@ -103,13 +104,18 @@ export function LeadWorkerSettings({ isOpen, onClose }: LeadWorkerSettingsProps) const providers = await getProviders(false); const activeProviders = providers.filter((p) => p.is_configured); - activeProviders.forEach(({ metadata, name }) => { - if (metadata.known_models) { - metadata.known_models.forEach((model) => { + const results = await fetchModelsForProviders(activeProviders, getProviderModels); + results.forEach(({ provider: p, models, error }) => { + if (error) { + console.error(error); + } + + if (models && models.length > 0) { + models.forEach((modelName) => { options.push({ - value: model.name, - label: `${model.name} (${metadata.display_name})`, - provider: name, + value: modelName, + label: `${modelName} (${p.metadata.display_name})`, + provider: p.name, }); }); } @@ -128,7 +134,7 @@ export function LeadWorkerSettings({ isOpen, onClose }: LeadWorkerSettingsProps) }; loadConfig(); - }, [read, getProviders, currentModel, isOpen]); + }, [read, getProviders, getProviderModels, currentModel, isOpen]); // If current models are not in the list (e.g., previously set to custom), switch to custom mode useEffect(() => { diff --git a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx index 051ee3af375d..abcc270033b0 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx @@ -16,7 +16,7 @@ import { Select } from '../../../ui/Select'; import { useConfig } from '../../../ConfigContext'; import { useModelAndProvider } from '../../../ModelAndProviderContext'; import type { View } from '../../../../utils/navigationUtils'; -import Model, { getProviderMetadata } from '../modelInterface'; +import Model, { getProviderMetadata, fetchModelsForProviders } from '../modelInterface'; import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils'; import { ProviderType } from '../../../../api'; @@ -146,24 +146,7 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod setLoadingModels(true); // Fetching models for all providers - const modelPromises = activeProviders.map(async (p) => { - const providerName = p.name; - try { - let models = await getProviderModels(providerName); - // Fallback to known_models if server returned none - if ((!models || models.length === 0) && p.metadata.known_models?.length) { - models = p.metadata.known_models.map((m) => m.name); - } - return { provider: p, models, error: null }; - } catch (e: unknown) { - return { - provider: p, - models: null, - error: `Failed to fetch models for ${providerName}${e instanceof Error ? `: ${e.message}` : ''}`, - }; - } - }); - const results = await Promise.all(modelPromises); + const results = await fetchModelsForProviders(activeProviders, getProviderModels); // Process results and build grouped options const groupedOptions: {