diff --git a/ui/desktop/src/components/settings_v2/models/index.ts b/ui/desktop/src/components/settings_v2/models/index.ts index 879cf92d09cc..b7bb1cced965 100644 --- a/ui/desktop/src/components/settings_v2/models/index.ts +++ b/ui/desktop/src/components/settings_v2/models/index.ts @@ -1,28 +1,26 @@ import { initializeSystem } from '../../../utils/providerUtils'; import { toastError, toastSuccess } from '../../../toasts'; import { ProviderDetails } from '@/src/api'; -import { getProviderMetadata } from './modelInterface'; +import Model, { getProviderMetadata } from './modelInterface'; import { ProviderMetadata } from '../../../api'; import type { ExtensionConfig, FixedExtensionEntry } from '../../ConfigContext'; // titles -const CHANGE_MODEL_TOAST_TITLE = 'Model selected'; -const START_AGENT_TITLE = 'Initialize agent'; export const UNKNOWN_PROVIDER_TITLE = 'Provider name lookup'; // errors -const SWITCH_MODEL_AGENT_ERROR_MSG = 'Failed to start agent with selected model'; -const CONFIG_UPDATE_ERROR_MSG = 'Failed to update configuration settings'; -const CONFIG_READ_MODEL_ERROR_MSG = 'Failed to read GOOSE_MODEL or GOOSE_PROVIDER from config'; +const CHANGE_MODEL_ERROR_TITLE = 'Change failed'; +const SWITCH_MODEL_AGENT_ERROR_MSG = + 'Failed to start agent with selected model -- please try again'; +const CONFIG_UPDATE_ERROR_MSG = 'Failed to update configuration settings -- please try again'; export const UNKNOWN_PROVIDER_MSG = 'Unknown provider in config -- please inspect your config.yaml'; // success +const CHANGE_MODEL_TOAST_TITLE = 'Model changed'; const SWITCH_MODEL_SUCCESS_MSG = 'Successfully switched models'; -const INITIALIZE_SYSTEM_WITH_MODEL_SUCCESS_MSG = 'Successfully started Goose'; interface changeModelProps { - model: string; - provider: string; + model: Model; writeToConfig: (key: string, value: unknown, is_secret: boolean) => Promise; getExtensions?: (b: boolean) => Promise; addExtension?: (name: string, config: ExtensionConfig, enabled: boolean) => Promise; @@ -31,20 +29,21 @@ interface changeModelProps { // TODO: error handling export async function changeModel({ model, - provider, writeToConfig, getExtensions, addExtension, }: changeModelProps) { + const modelName = model.name; + const providerName = model.provider; try { - await initializeSystem(provider, model, { + await initializeSystem(providerName, modelName, { getExtensions, addExtension, }); } catch (error) { - console.error(`Failed to change model at agent step -- ${model} ${provider}`); + console.error(`Failed to change model at agent step -- ${modelName} ${providerName}`); toastError({ - title: CHANGE_MODEL_TOAST_TITLE, + title: CHANGE_MODEL_ERROR_TITLE, msg: SWITCH_MODEL_AGENT_ERROR_MSG, traceback: error, }); @@ -53,12 +52,12 @@ export async function changeModel({ } try { - await writeToConfig('GOOSE_PROVIDER', provider, false); - await writeToConfig('GOOSE_MODEL', model, false); + await writeToConfig('GOOSE_PROVIDER', providerName, false); + await writeToConfig('GOOSE_MODEL', modelName, false); } catch (error) { - console.error(`Failed to change model at config step -- ${model} ${provider}`); + console.error(`Failed to change model at config step -- ${modelName} ${providerName}}`); toastError({ - title: CHANGE_MODEL_TOAST_TITLE, + title: CHANGE_MODEL_ERROR_TITLE, msg: CONFIG_UPDATE_ERROR_MSG, traceback: error, }); @@ -68,7 +67,7 @@ export async function changeModel({ // show toast toastSuccess({ title: CHANGE_MODEL_TOAST_TITLE, - msg: `${SWITCH_MODEL_SUCCESS_MSG} -- using ${model} from ${provider}`, + msg: `${SWITCH_MODEL_SUCCESS_MSG} -- using ${model.alias ?? modelName} from ${model.subtext ?? providerName}`, }); } } diff --git a/ui/desktop/src/components/settings_v2/models/model_list/BaseModelsList.tsx b/ui/desktop/src/components/settings_v2/models/model_list/BaseModelsList.tsx index 075df28027cf..ff07aad8d746 100644 --- a/ui/desktop/src/components/settings_v2/models/model_list/BaseModelsList.tsx +++ b/ui/desktop/src/components/settings_v2/models/model_list/BaseModelsList.tsx @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import Model from '../modelInterface'; +import Model, { getProviderMetadata } from '../modelInterface'; import { useRecentModels } from './recentModels'; import { changeModel, getCurrentModelAndProvider } from '../index'; import { useConfig } from '../../../ConfigContext'; @@ -14,6 +14,7 @@ interface ModelRadioListProps { providedModelList?: Model[]; } +// renders a model list and handles changing models when user clicks on them export function BaseModelsList({ renderItem, className = '', @@ -28,9 +29,8 @@ export function BaseModelsList({ } else { modelList = providedModelList; } - const { read, upsert } = useConfig(); - const [selectedModel, setSelectedModel] = useState(null); - const [selectedProvider, setSelectedProvider] = useState(null); + const { read, upsert, getProviders } = useConfig(); + const [selectedModel, setSelectedModel] = useState(null); const [isInitialized, setIsInitialized] = useState(false); // Load current model/provider once on component mount @@ -41,8 +41,18 @@ export function BaseModelsList({ try { const result = await getCurrentModelAndProvider({ readFromConfig: read }); if (isMounted) { - setSelectedModel(result.model); - setSelectedProvider(result.provider); + // try to look up the model in the modelList + let currentModel: Model; + const match = modelList.find( + (model) => model.name == result.model && model.provider == result.provider + ); + // no matches so just create a model object (maybe user updated config.yaml from CLI usage, manual editing etc) + if (!match) { + currentModel = { name: result.model, provider: result.provider }; + } else { + currentModel = match; + } + setSelectedModel(currentModel); setIsInitialized(true); } } catch (error) { @@ -61,19 +71,21 @@ export function BaseModelsList({ }, [read]); const handleModelSelection = async (modelName: string, providerName: string) => { - await changeModel({ model: modelName, provider: providerName, writeToConfig: upsert }); + await changeModel({ model: selectedModel, writeToConfig: upsert }); }; // Updated to work with CustomRadio const handleRadioChange = async (model: Model) => { - if (selectedModel === model.name) { + if (selectedModel.name === model.name && selectedModel.provider === model.provider) { console.log(`Model "${model.name}" is already active.`); return; } - // Update local state immediately for UI feedback - setSelectedModel(model.name); - setSelectedProvider(model.provider); + const providerMetaData = await getProviderMetadata(model.provider, getProviders); + const providerDisplayName = providerMetaData.display_name; + + // Update local state immediately for UI feedback and add in display name + setSelectedModel({ ...model, alias: providerDisplayName }); try { await handleModelSelection(model.name, model.provider); @@ -92,7 +104,7 @@ export function BaseModelsList({ {modelList.map((model) => renderItem({ model, - isSelected: selectedModel === model.name, + isSelected: selectedModel === model, onSelect: () => handleRadioChange(model), }) )} diff --git a/ui/desktop/src/components/settings_v2/models/subcomponents/AddModelModal.tsx b/ui/desktop/src/components/settings_v2/models/subcomponents/AddModelModal.tsx index a552e5dcb236..b813240092fd 100644 --- a/ui/desktop/src/components/settings_v2/models/subcomponents/AddModelModal.tsx +++ b/ui/desktop/src/components/settings_v2/models/subcomponents/AddModelModal.tsx @@ -9,6 +9,7 @@ import { Select } from '../../../ui/Select'; import { useConfig } from '../../../ConfigContext'; import { changeModel } from '../index'; import type { View } from '../../../../App'; +import Model, { getProviderMetadata } from '../modelInterface'; const ModalButtons = ({ onSubmit, onCancel, isValid, validationErrors }) => (
@@ -77,9 +78,11 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => { const isFormValid = validateForm(); if (isFormValid) { + const providerMetaData = await getProviderMetadata(provider, getProviders); + const providerDisplayName = providerMetaData.display_name; + await changeModel({ - model: model, - provider: provider, + model: { name: model, provider: provider, subtext: providerDisplayName } as Model, // pass in a Model object writeToConfig: upsert, getExtensions, addExtension,