diff --git a/ui/desktop/src/components/ProviderGuard.tsx b/ui/desktop/src/components/ProviderGuard.tsx index 76e789d6e167..227be70e49bd 100644 --- a/ui/desktop/src/components/ProviderGuard.tsx +++ b/ui/desktop/src/components/ProviderGuard.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState } from 'react'; +import { useEffect, useState, useMemo } from 'react'; import { useNavigate } from 'react-router-dom'; import { useConfig } from './ConfigContext'; import { SetupModal } from './SetupModal'; @@ -8,6 +8,8 @@ import WelcomeGooseLogo from './WelcomeGooseLogo'; import { toastService } from '../toasts'; import { OllamaSetup } from './OllamaSetup'; import ApiKeyTester from './ApiKeyTester'; +import { SwitchModelModal } from './settings/models/subcomponents/SwitchModelModal'; +import { createNavigationHandler } from '../utils/navigationUtils'; import { Goose, OpenRouter, Tetrate } from './icons'; @@ -24,6 +26,10 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG const [showFirstTimeSetup, setShowFirstTimeSetup] = useState(false); const [showOllamaSetup, setShowOllamaSetup] = useState(false); const [userInActiveSetup, setUserInActiveSetup] = useState(false); + const [showSwitchModelModal, setShowSwitchModelModal] = useState(false); + const [switchModelProvider, setSwitchModelProvider] = useState(null); + + const setView = useMemo(() => createNavigationHandler(navigate), [navigate]); const [openRouterSetupState, setOpenRouterSetupState] = useState<{ show: boolean; @@ -45,18 +51,8 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG try { const result = await startTetrateSetup(); if (result.success) { - setTetrateSetupState({ - show: true, - title: 'Setup Complete!', - message: result.message, - showRetry: false, - autoClose: 3000, - }); - setTimeout(() => { - setShowFirstTimeSetup(false); - setHasProvider(true); - navigate('/', { replace: true }); - }, 3000); + setSwitchModelProvider('tetrate'); + setShowSwitchModelModal(true); } else { setTetrateSetupState({ show: true, @@ -76,34 +72,33 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG } }; - const handleApiKeySuccess = async (provider: string, model: string, apiKey: string) => { + const handleApiKeySuccess = async (provider: string, _model: string, apiKey: string) => { const keyName = `${provider.toUpperCase()}_API_KEY`; await upsert(keyName, apiKey, true); await upsert('GOOSE_PROVIDER', provider, false); - await upsert('GOOSE_MODEL', model, false); + setSwitchModelProvider(provider); + setShowSwitchModelModal(true); + }; + + const handleModelSelected = () => { + setShowSwitchModelModal(false); setUserInActiveSetup(false); setShowFirstTimeSetup(false); setHasProvider(true); navigate('/', { replace: true }); }; + const handleSwitchModelClose = () => { + setShowSwitchModelModal(false); + }; + const handleOpenRouterSetup = async () => { try { const result = await startOpenRouterSetup(); if (result.success) { - setOpenRouterSetupState({ - show: true, - title: 'Setup Complete!', - message: result.message, - showRetry: false, - autoClose: 3000, - }); - setTimeout(() => { - setShowFirstTimeSetup(false); - setHasProvider(true); - navigate('/', { replace: true }); - }, 3000); + setSwitchModelProvider('openrouter'); + setShowSwitchModelModal(true); } else { setOpenRouterSetupState({ show: true, @@ -337,6 +332,17 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG autoClose={tetrateSetupState.autoClose} /> )} + + {showSwitchModelModal && ( + + )} ); } diff --git a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx index abcc270033b0..73bb0c8bcb1a 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/SwitchModelModal.tsx @@ -1,5 +1,5 @@ import { useEffect, useState, useCallback } from 'react'; -import { ArrowLeftRight, ExternalLink } from 'lucide-react'; +import { Bot, ExternalLink } from 'lucide-react'; import { Dialog, @@ -20,18 +20,66 @@ import Model, { getProviderMetadata, fetchModelsForProviders } from '../modelInt import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils'; import { ProviderType } from '../../../../api'; +const PREFERRED_MODEL_PATTERNS = [ + /claude-sonnet-4/i, + /claude-4/i, + /gpt-4o(?!-mini)/i, + /claude-3-5-sonnet/i, + /claude-3\.5-sonnet/i, + /gpt-4-turbo/i, + /gpt-4(?!-|o)/i, + /claude-3-opus/i, + /claude-3-sonnet/i, + /gemini-pro/i, + /llama-3/i, + /gpt-4o-mini/i, + /claude-3-haiku/i, + /gemini/i, +]; + +function findPreferredModel( + models: { value: string; label: string; provider: string }[] +): string | null { + if (models.length === 0) return null; + + const validModels = models.filter( + (m) => m.value !== 'custom' && m.value !== '__loading__' && !m.value.startsWith('__') + ); + + if (validModels.length === 0) return null; + + for (const pattern of PREFERRED_MODEL_PATTERNS) { + const match = validModels.find((m) => pattern.test(m.value)); + if (match) { + return match.value; + } + } + + return validModels[0].value; +} + type SwitchModelModalProps = { sessionId: string | null; onClose: () => void; setView: (view: View) => void; + onModelSelected?: () => void; + initialProvider?: string | null; + titleOverride?: string; }; -export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelModalProps) => { +export const SwitchModelModal = ({ + sessionId, + onClose, + setView, + onModelSelected, + initialProvider, + titleOverride, +}: SwitchModelModalProps) => { const { getProviders, getProviderModels, read } = useConfig(); const { changeModel } = useModelAndProvider(); const [providerOptions, setProviderOptions] = useState<{ value: string; label: string }[]>([]); type ModelOption = { value: string; label: string; provider: string; isDisabled?: boolean }; const [modelOptions, setModelOptions] = useState<{ options: ModelOption[] }[]>([]); - const [provider, setProvider] = useState(null); + const [provider, setProvider] = useState(initialProvider || null); const [model, setModel] = useState(''); const [isCustomModel, setIsCustomModel] = useState(false); const [validationErrors, setValidationErrors] = useState({ @@ -95,6 +143,9 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod } await changeModel(sessionId, modelObj); + if (onModelSelected) { + onModelSelected(); + } onClose(); } }; @@ -209,11 +260,25 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod })(); }, [getProviders, getProviderModels, usePredefinedModels, read]); - // Filter model options based on selected provider const filteredModelOptions = provider ? modelOptions.filter((group) => group.options[0]?.provider === provider) : []; + useEffect(() => { + if (!provider || loadingModels || model || isCustomModel) return; + + const providerModels = modelOptions + .filter((group) => group.options[0]?.provider === provider) + .flatMap((group) => group.options); + + if (providerModels.length > 0) { + const preferredModel = findPreferredModel(providerModels); + if (preferredModel) { + setModel(preferredModel); + } + } + }, [provider, modelOptions, loadingModels, model, isCustomModel]); + // Handle model selection change const handleModelChange = (newValue: unknown) => { const selectedOption = newValue as { value: string; label: string; provider: string } | null; @@ -277,30 +342,16 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod - - Switch models + + {titleOverride || 'Switch models'} - Configure your AI model providers by adding their API keys. Your keys are stored - securely and encrypted locally. + Select a provider and model to use for your conversations.
- - {usePredefinedModels ? ( - /* Predefined Models Section */
@@ -448,13 +499,24 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod )}
- - - + + + + Quick start guide + +
+ + +
diff --git a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx index f970e9335678..55683122ec50 100644 --- a/ui/desktop/src/components/settings/providers/ProviderGrid.tsx +++ b/ui/desktop/src/components/settings/providers/ProviderGrid.tsx @@ -10,6 +10,8 @@ import { import { Plus } from 'lucide-react'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '../../ui/dialog'; import CustomProviderForm from './modal/subcomponents/forms/CustomProviderForm'; +import { SwitchModelModal } from '../models/subcomponents/SwitchModelModal'; +import type { View } from '../../../utils/navigationUtils'; const GridLayout = memo(function GridLayout({ children }: { children: React.ReactNode }) { return ( @@ -50,21 +52,30 @@ function ProviderCards({ providers, isOnboarding, refreshProviders, - onProviderLaunch, + setView, + onModelSelected, }: { providers: ProviderDetails[]; isOnboarding: boolean; refreshProviders?: () => void; - onProviderLaunch: (provider: ProviderDetails) => void; + setView?: (view: View) => void; + onModelSelected?: () => void; }) { const [configuringProvider, setConfiguringProvider] = useState(null); const [showCustomProviderModal, setShowCustomProviderModal] = useState(false); + const [showSwitchModelModal, setShowSwitchModelModal] = useState(false); + const [switchModelProvider, setSwitchModelProvider] = useState(null); const [editingProvider, setEditingProvider] = useState<{ id: string; config: DeclarativeProviderConfig; isEditable: boolean; } | null>(null); + const handleProviderLaunchWithModelSelection = useCallback((provider: ProviderDetails) => { + setSwitchModelProvider(provider.name); + setShowSwitchModelModal(true); + }, []); + const openModal = useCallback( (provider: ProviderDetails) => setConfiguringProvider(provider), [] @@ -101,11 +112,14 @@ function ProviderCards({ body: data, throwOnError: true, }); + const providerId = editingProvider.id; setShowCustomProviderModal(false); setEditingProvider(null); if (refreshProviders) { refreshProviders(); } + setSwitchModelProvider(providerId); + setShowSwitchModelModal(true); }, [editingProvider, refreshProviders] ); @@ -122,6 +136,32 @@ function ProviderCards({ } }, [refreshProviders]); + const onProviderConfigured = useCallback( + (provider: ProviderDetails) => { + setConfiguringProvider(null); + if (refreshProviders) { + refreshProviders(); + } + setSwitchModelProvider(provider.name); + setShowSwitchModelModal(true); + }, + [refreshProviders] + ); + + const onCloseSwitchModelModal = useCallback(() => { + setShowSwitchModelModal(false); + }, []); + + const handleSetView = useCallback( + (view: View) => { + setShowSwitchModelModal(false); + if (setView) { + setView(view); + } + }, + [setView] + ); + const handleCreateCustomProvider = useCallback( async (data: UpdateCustomProviderRequest) => { const { createCustomProvider } = await import('../../../api'); @@ -130,6 +170,7 @@ function ProviderCards({ if (refreshProviders) { refreshProviders(); } + setShowSwitchModelModal(true); }, [refreshProviders] ); @@ -144,7 +185,7 @@ function ProviderCards({ key={provider.name} provider={provider} onConfigure={() => configureProviderViaModal(provider)} - onLaunch={() => onProviderLaunch(provider)} + onLaunch={() => handleProviderLaunchWithModelSelection(provider)} isOnboarding={isOnboarding} /> )); @@ -154,7 +195,7 @@ function ProviderCards({ ); return cards; - }, [providers, isOnboarding, configureProviderViaModal, onProviderLaunch]); + }, [providers, isOnboarding, configureProviderViaModal, handleProviderLaunchWithModelSelection]); const initialData = editingProvider && { engine: editingProvider.config.engine.toLowerCase() + '_compatible', @@ -187,6 +228,17 @@ function ProviderCards({ + )} + {showSwitchModelModal && ( + )} @@ -197,12 +249,14 @@ export default function ProviderGrid({ providers, isOnboarding, refreshProviders, - onProviderLaunch, + setView, + onModelSelected, }: { providers: ProviderDetails[]; isOnboarding: boolean; refreshProviders?: () => void; - onProviderLaunch?: (provider: ProviderDetails) => void; + setView?: (view: View) => void; + onModelSelected?: () => void; }) { return ( @@ -210,7 +264,8 @@ export default function ProviderGrid({ providers={providers} isOnboarding={isOnboarding} refreshProviders={refreshProviders} - onProviderLaunch={onProviderLaunch || (() => {})} + setView={setView} + onModelSelected={onModelSelected} /> ); diff --git a/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx b/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx index 2110b6f60ab5..1c745a3fd3ba 100644 --- a/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx +++ b/ui/desktop/src/components/settings/providers/ProviderSettingsPage.tsx @@ -1,10 +1,11 @@ -import { useEffect, useState, useCallback, useRef } from 'react'; +import { useEffect, useState, useCallback, useRef, useMemo } from 'react'; +import { useNavigate } from 'react-router-dom'; import { ScrollArea } from '../../ui/scroll-area'; import BackButton from '../../ui/BackButton'; import ProviderGrid from './ProviderGrid'; import { useConfig } from '../../ConfigContext'; -import { ProviderDetails, setConfigProvider } from '../../../api'; -import { toastService } from '../../../toasts'; +import { ProviderDetails } from '../../../api'; +import { createNavigationHandler } from '../../../utils/navigationUtils'; interface ProviderSettingsProps { onClose: () => void; @@ -18,10 +19,13 @@ export default function ProviderSettings({ onProviderLaunched, }: ProviderSettingsProps) { const { getProviders } = useConfig(); + const navigate = useNavigate(); const [loading, setLoading] = useState(true); const [providers, setProviders] = useState([]); const initialLoadDone = useRef(false); + const setView = useMemo(() => createNavigationHandler(navigate), [navigate]); + // Create a function to load providers that can be called multiple times const loadProviders = useCallback(async () => { setLoading(true); @@ -54,47 +58,6 @@ export default function ProviderSettings({ } }, [getProviders]); - // Handler for when a provider is launched if this component is used as part of onboarding page - const handleProviderLaunch = useCallback( - async (provider: ProviderDetails) => { - const provider_name = provider.name; - const model = provider.metadata.default_model; - - try { - await setConfigProvider({ - body: { - provider: provider_name, - model, - }, - throwOnError: true, - }); - - toastService.configure({ silent: false }); - toastService.success({ - title: 'Success!', - msg: `Started goose with ${model} by ${provider.metadata.display_name}. You can change the model via the dropdown.`, - }); - - if (onProviderLaunched) { - onProviderLaunched(); - } else { - onClose(); - } - } catch (error) { - console.error(`Failed to initialize with provider ${provider_name}:`, error); - - // Show error toast - toastService.configure({ silent: false }); - toastService.error({ - title: 'Initialization Failed', - msg: `Failed to initialize with ${provider.metadata.display_name}: ${error instanceof Error ? error.message : String(error)}`, - traceback: error instanceof Error ? error.stack || '' : '', - }); - } - }, - [onClose, onProviderLaunched] - ); - return (
@@ -127,8 +90,9 @@ export default function ProviderSettings({ )}
diff --git a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx index 1858589eb855..a0b76483cad3 100644 --- a/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx +++ b/ui/desktop/src/components/settings/providers/modal/ProviderConfiguationModal.tsx @@ -23,11 +23,13 @@ import { Button } from '../../../../components/ui/button'; interface ProviderConfigurationModalProps { provider: ProviderDetails; onClose: () => void; + onConfigured?: (provider: ProviderDetails) => void; } export default function ProviderConfigurationModal({ provider, onClose, + onConfigured, }: ProviderConfigurationModalProps) { const [validationErrors, setValidationErrors] = useState>({}); const { upsert, remove } = useConfig(); @@ -83,7 +85,11 @@ export default function ProviderConfigurationModal({ try { await providerConfigSubmitHandler(upsert, provider, toSubmit); - onClose(); + if (onConfigured) { + onConfigured(provider); + } else { + onClose(); + } } catch (error) { setError(`${error}`); }