Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions ui/desktop/src/components/settings_v2/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
import { initializeSystem } from '../../../utils/providerUtils';
import { toastError, toastSuccess } from '../../../toasts';
import { ProviderDetails } from '@/src/api';
import { getProviderMetadata } from './modelInterface';
import Model, { getProviderMetadata } from './modelInterface';
import { ProviderMetadata } from '../../../api';
import type { ExtensionConfig, FixedExtensionEntry } from '../../ConfigContext';

// titles
const CHANGE_MODEL_TOAST_TITLE = 'Model selected';
const START_AGENT_TITLE = 'Initialize agent';
export const UNKNOWN_PROVIDER_TITLE = 'Provider name lookup';

// errors
const SWITCH_MODEL_AGENT_ERROR_MSG = 'Failed to start agent with selected model';
const CONFIG_UPDATE_ERROR_MSG = 'Failed to update configuration settings';
const CONFIG_READ_MODEL_ERROR_MSG = 'Failed to read GOOSE_MODEL or GOOSE_PROVIDER from config';
const CHANGE_MODEL_ERROR_TITLE = 'Change failed';
const SWITCH_MODEL_AGENT_ERROR_MSG =
'Failed to start agent with selected model -- please try again';
const CONFIG_UPDATE_ERROR_MSG = 'Failed to update configuration settings -- please try again';
export const UNKNOWN_PROVIDER_MSG = 'Unknown provider in config -- please inspect your config.yaml';

// success
const CHANGE_MODEL_TOAST_TITLE = 'Model changed';
const SWITCH_MODEL_SUCCESS_MSG = 'Successfully switched models';
const INITIALIZE_SYSTEM_WITH_MODEL_SUCCESS_MSG = 'Successfully started Goose';

interface changeModelProps {
model: string;
provider: string;
model: Model;
writeToConfig: (key: string, value: unknown, is_secret: boolean) => Promise<void>;
getExtensions?: (b: boolean) => Promise<FixedExtensionEntry[]>;
addExtension?: (name: string, config: ExtensionConfig, enabled: boolean) => Promise<void>;
Expand All @@ -31,20 +29,21 @@ interface changeModelProps {
// TODO: error handling
export async function changeModel({
model,
provider,
writeToConfig,
getExtensions,
addExtension,
}: changeModelProps) {
const modelName = model.name;
const providerName = model.provider;
try {
await initializeSystem(provider, model, {
await initializeSystem(providerName, modelName, {
getExtensions,
addExtension,
});
} catch (error) {
console.error(`Failed to change model at agent step -- ${model} ${provider}`);
console.error(`Failed to change model at agent step -- ${modelName} ${providerName}`);
toastError({
title: CHANGE_MODEL_TOAST_TITLE,
title: CHANGE_MODEL_ERROR_TITLE,
msg: SWITCH_MODEL_AGENT_ERROR_MSG,
traceback: error,
});
Expand All @@ -53,12 +52,12 @@ export async function changeModel({
}

try {
await writeToConfig('GOOSE_PROVIDER', provider, false);
await writeToConfig('GOOSE_MODEL', model, false);
await writeToConfig('GOOSE_PROVIDER', providerName, false);
await writeToConfig('GOOSE_MODEL', modelName, false);
} catch (error) {
console.error(`Failed to change model at config step -- ${model} ${provider}`);
console.error(`Failed to change model at config step -- ${modelName} ${providerName}}`);
toastError({
title: CHANGE_MODEL_TOAST_TITLE,
title: CHANGE_MODEL_ERROR_TITLE,
msg: CONFIG_UPDATE_ERROR_MSG,
traceback: error,
});
Expand All @@ -68,7 +67,7 @@ export async function changeModel({
// show toast
toastSuccess({
title: CHANGE_MODEL_TOAST_TITLE,
msg: `${SWITCH_MODEL_SUCCESS_MSG} -- using ${model} from ${provider}`,
msg: `${SWITCH_MODEL_SUCCESS_MSG} -- using ${model.alias ?? modelName} from ${model.subtext ?? providerName}`,
});
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import Model from '../modelInterface';
import Model, { getProviderMetadata } from '../modelInterface';
import { useRecentModels } from './recentModels';
import { changeModel, getCurrentModelAndProvider } from '../index';
import { useConfig } from '../../../ConfigContext';
Expand All @@ -14,6 +14,7 @@ interface ModelRadioListProps {
providedModelList?: Model[];
}

// renders a model list and handles changing models when user clicks on them
export function BaseModelsList({
renderItem,
className = '',
Expand All @@ -28,9 +29,8 @@ export function BaseModelsList({
} else {
modelList = providedModelList;
}
const { read, upsert } = useConfig();
const [selectedModel, setSelectedModel] = useState<string | null>(null);
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
const { read, upsert, getProviders } = useConfig();
const [selectedModel, setSelectedModel] = useState<Model | null>(null);
const [isInitialized, setIsInitialized] = useState(false);

// Load current model/provider once on component mount
Expand All @@ -41,8 +41,18 @@ export function BaseModelsList({
try {
const result = await getCurrentModelAndProvider({ readFromConfig: read });
if (isMounted) {
setSelectedModel(result.model);
setSelectedProvider(result.provider);
// try to look up the model in the modelList
let currentModel: Model;
const match = modelList.find(
(model) => model.name == result.model && model.provider == result.provider
);
// no matches so just create a model object (maybe user updated config.yaml from CLI usage, manual editing etc)
if (!match) {
currentModel = { name: result.model, provider: result.provider };
} else {
currentModel = match;
}
setSelectedModel(currentModel);
setIsInitialized(true);
}
} catch (error) {
Expand All @@ -61,19 +71,21 @@ export function BaseModelsList({
}, [read]);

const handleModelSelection = async (modelName: string, providerName: string) => {
await changeModel({ model: modelName, provider: providerName, writeToConfig: upsert });
await changeModel({ model: selectedModel, writeToConfig: upsert });
};

// Updated to work with CustomRadio
const handleRadioChange = async (model: Model) => {
if (selectedModel === model.name) {
if (selectedModel.name === model.name && selectedModel.provider === model.provider) {
console.log(`Model "${model.name}" is already active.`);
return;
}

// Update local state immediately for UI feedback
setSelectedModel(model.name);
setSelectedProvider(model.provider);
const providerMetaData = await getProviderMetadata(model.provider, getProviders);
const providerDisplayName = providerMetaData.display_name;

// Update local state immediately for UI feedback and add in display name
setSelectedModel({ ...model, alias: providerDisplayName });

try {
await handleModelSelection(model.name, model.provider);
Expand All @@ -92,7 +104,7 @@ export function BaseModelsList({
{modelList.map((model) =>
renderItem({
model,
isSelected: selectedModel === model.name,
isSelected: selectedModel === model,
onSelect: () => handleRadioChange(model),
})
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { Select } from '../../../ui/Select';
import { useConfig } from '../../../ConfigContext';
import { changeModel } from '../index';
import type { View } from '../../../../App';
import Model, { getProviderMetadata } from '../modelInterface';

const ModalButtons = ({ onSubmit, onCancel, isValid, validationErrors }) => (
<div>
Expand Down Expand Up @@ -77,9 +78,11 @@ export const AddModelModal = ({ onClose, setView }: AddModelModalProps) => {
const isFormValid = validateForm();

if (isFormValid) {
const providerMetaData = await getProviderMetadata(provider, getProviders);
const providerDisplayName = providerMetaData.display_name;

await changeModel({
model: model,
provider: provider,
model: { name: model, provider: provider, subtext: providerDisplayName } as Model, // pass in a Model object
writeToConfig: upsert,
getExtensions,
addExtension,
Expand Down
Loading