Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions ui/desktop/src/components/ProviderGuard.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useEffect, useState } from 'react';
import { useEffect, useState, useMemo } from 'react';
import { useNavigate } from 'react-router-dom';
import { useConfig } from './ConfigContext';
import { SetupModal } from './SetupModal';
Expand All @@ -8,6 +8,8 @@ import WelcomeGooseLogo from './WelcomeGooseLogo';
import { toastService } from '../toasts';
import { OllamaSetup } from './OllamaSetup';
import ApiKeyTester from './ApiKeyTester';
import { SwitchModelModal } from './settings/models/subcomponents/SwitchModelModal';
import { createNavigationHandler } from '../utils/navigationUtils';

import { Goose, OpenRouter, Tetrate } from './icons';

Expand All @@ -24,6 +26,10 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG
const [showFirstTimeSetup, setShowFirstTimeSetup] = useState(false);
const [showOllamaSetup, setShowOllamaSetup] = useState(false);
const [userInActiveSetup, setUserInActiveSetup] = useState(false);
const [showSwitchModelModal, setShowSwitchModelModal] = useState(false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these don't strike me as independent booleans, might be good at some point to change them into an annotated enum like thing

const [switchModelProvider, setSwitchModelProvider] = useState<string | null>(null);

const setView = useMemo(() => createNavigationHandler(navigate), [navigate]);

const [openRouterSetupState, setOpenRouterSetupState] = useState<{
show: boolean;
Expand All @@ -45,18 +51,8 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG
try {
const result = await startTetrateSetup();
if (result.success) {
setTetrateSetupState({
show: true,
title: 'Setup Complete!',
message: result.message,
showRetry: false,
autoClose: 3000,
});
setTimeout(() => {
setShowFirstTimeSetup(false);
setHasProvider(true);
navigate('/', { replace: true });
}, 3000);
setSwitchModelProvider('tetrate');
setShowSwitchModelModal(true);
} else {
setTetrateSetupState({
show: true,
Expand All @@ -76,34 +72,33 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG
}
};

const handleApiKeySuccess = async (provider: string, model: string, apiKey: string) => {
const handleApiKeySuccess = async (provider: string, _model: string, apiKey: string) => {
const keyName = `${provider.toUpperCase()}_API_KEY`;
await upsert(keyName, apiKey, true);
await upsert('GOOSE_PROVIDER', provider, false);
await upsert('GOOSE_MODEL', model, false);

setSwitchModelProvider(provider);
setShowSwitchModelModal(true);
};

const handleModelSelected = () => {
setShowSwitchModelModal(false);
setUserInActiveSetup(false);
setShowFirstTimeSetup(false);
setHasProvider(true);
navigate('/', { replace: true });
};

const handleSwitchModelClose = () => {
setShowSwitchModelModal(false);
};

const handleOpenRouterSetup = async () => {
try {
const result = await startOpenRouterSetup();
if (result.success) {
setOpenRouterSetupState({
show: true,
title: 'Setup Complete!',
message: result.message,
showRetry: false,
autoClose: 3000,
});
setTimeout(() => {
setShowFirstTimeSetup(false);
setHasProvider(true);
navigate('/', { replace: true });
}, 3000);
setSwitchModelProvider('openrouter');
setShowSwitchModelModal(true);
} else {
setOpenRouterSetupState({
show: true,
Expand Down Expand Up @@ -337,6 +332,17 @@ export default function ProviderGuard({ didSelectProvider, children }: ProviderG
autoClose={tetrateSetupState.autoClose}
/>
)}

{showSwitchModelModal && (
<SwitchModelModal
sessionId={null}
onClose={handleSwitchModelClose}
setView={setView}
onModelSelected={handleModelSelected}
initialProvider={switchModelProvider}
titleOverride="Choose Model"
/>
)}
</div>
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { useEffect, useState, useCallback } from 'react';
import { ArrowLeftRight, ExternalLink } from 'lucide-react';
import { Bot, ExternalLink } from 'lucide-react';

import {
Dialog,
Expand All @@ -20,18 +20,66 @@ import Model, { getProviderMetadata, fetchModelsForProviders } from '../modelInt
import { getPredefinedModelsFromEnv, shouldShowPredefinedModels } from '../predefinedModelsUtils';
import { ProviderType } from '../../../../api';

const PREFERRED_MODEL_PATTERNS = [
/claude-sonnet-4/i,
/claude-4/i,
/gpt-4o(?!-mini)/i,
/claude-3-5-sonnet/i,
/claude-3\.5-sonnet/i,
/gpt-4-turbo/i,
/gpt-4(?!-|o)/i,
/claude-3-opus/i,
/claude-3-sonnet/i,
/gemini-pro/i,
/llama-3/i,
/gpt-4o-mini/i,
/claude-3-haiku/i,
/gemini/i,
];

function findPreferredModel(
models: { value: string; label: string; provider: string }[]
): string | null {
if (models.length === 0) return null;

const validModels = models.filter(
(m) => m.value !== 'custom' && m.value !== '__loading__' && !m.value.startsWith('__')
);

if (validModels.length === 0) return null;

for (const pattern of PREFERRED_MODEL_PATTERNS) {
const match = validModels.find((m) => pattern.test(m.value));
if (match) {
return match.value;
}
}

return validModels[0].value;
}

type SwitchModelModalProps = {
sessionId: string | null;
onClose: () => void;
setView: (view: View) => void;
onModelSelected?: () => void;
initialProvider?: string | null;
titleOverride?: string;
};
export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelModalProps) => {
export const SwitchModelModal = ({
sessionId,
onClose,
setView,
onModelSelected,
initialProvider,
titleOverride,
}: SwitchModelModalProps) => {
const { getProviders, getProviderModels, read } = useConfig();
const { changeModel } = useModelAndProvider();
const [providerOptions, setProviderOptions] = useState<{ value: string; label: string }[]>([]);
type ModelOption = { value: string; label: string; provider: string; isDisabled?: boolean };
const [modelOptions, setModelOptions] = useState<{ options: ModelOption[] }[]>([]);
const [provider, setProvider] = useState<string | null>(null);
const [provider, setProvider] = useState<string | null>(initialProvider || null);
const [model, setModel] = useState<string>('');
const [isCustomModel, setIsCustomModel] = useState(false);
const [validationErrors, setValidationErrors] = useState({
Expand Down Expand Up @@ -95,6 +143,9 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod
}

await changeModel(sessionId, modelObj);
if (onModelSelected) {
onModelSelected();
}
onClose();
}
};
Expand Down Expand Up @@ -209,11 +260,25 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod
})();
}, [getProviders, getProviderModels, usePredefinedModels, read]);

// Filter model options based on selected provider
const filteredModelOptions = provider
? modelOptions.filter((group) => group.options[0]?.provider === provider)
: [];

useEffect(() => {
if (!provider || loadingModels || model || isCustomModel) return;

const providerModels = modelOptions
.filter((group) => group.options[0]?.provider === provider)
.flatMap((group) => group.options);

if (providerModels.length > 0) {
const preferredModel = findPreferredModel(providerModels);
if (preferredModel) {
setModel(preferredModel);
}
}
}, [provider, modelOptions, loadingModels, model, isCustomModel]);

// Handle model selection change
const handleModelChange = (newValue: unknown) => {
const selectedOption = newValue as { value: string; label: string; provider: string } | null;
Expand Down Expand Up @@ -277,30 +342,16 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod
<DialogContent className="sm:max-w-[500px]">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<ArrowLeftRight size={24} className="text-textStandard" />
Switch models
<Bot size={24} className="text-textStandard" />
{titleOverride || 'Switch models'}
</DialogTitle>
<DialogDescription>
Configure your AI model providers by adding their API keys. Your keys are stored
securely and encrypted locally.
Select a provider and model to use for your conversations.
</DialogDescription>
</DialogHeader>

<div className="flex flex-col gap-4 py-4">
<div>
<a
href={QUICKSTART_GUIDE_URL}
target="_blank"
rel="noopener noreferrer"
className="flex items-center text-textStandard font-medium text-sm"
>
<ExternalLink size={16} className="mr-1" />
View quick start guide
</a>
</div>

{usePredefinedModels ? (
/* Predefined Models Section */
<div className="w-full flex flex-col gap-4">
<div className="flex justify-between items-center">
<label className="text-sm font-medium text-textStandard">Choose a model:</label>
Expand Down Expand Up @@ -448,13 +499,24 @@ export const SwitchModelModal = ({ sessionId, onClose, setView }: SwitchModelMod
)}
</div>

<DialogFooter className="pt-2">
<Button variant="outline" onClick={handleClose} type="button">
Cancel
</Button>
<Button onClick={handleSubmit} disabled={!isValid}>
Select model
</Button>
<DialogFooter className="pt-4 flex-col sm:flex-row gap-3">
<a
href={QUICKSTART_GUIDE_URL}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center text-text-muted hover:text-textStandard text-sm mr-auto"
>
<ExternalLink size={14} className="mr-1" />
Quick start guide
</a>
<div className="flex gap-2">
<Button variant="outline" onClick={handleClose} type="button">
Cancel
</Button>
<Button onClick={handleSubmit} disabled={!isValid}>
Select model
</Button>
</div>
</DialogFooter>
</DialogContent>
</Dialog>
Expand Down
Loading
Loading