Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,18 +254,27 @@ async fn start_agent(
}

if let Some(recipe) = original_recipe {
manager
.update(&session.id)
.recipe(Some(recipe))
.apply()
.await
.map_err(|err| {
error!("Failed to update session with recipe: {}", err);
ErrorResponse {
message: format!("Failed to update session with recipe: {}", err),
status: StatusCode::INTERNAL_SERVER_ERROR,
let mut update = manager.update(&session.id).recipe(Some(recipe.clone()));

if let Some(ref settings) = recipe.settings {
if let Some(ref provider) = settings.goose_provider {
update = update.provider_name(provider);

if let Some(ref model) = settings.goose_model {
if let Ok(model_config) = ModelConfig::new(model) {
update = update.model_config(model_config);
}
}
})?;
}
}

update.apply().await.map_err(|err| {
error!("Failed to update session with recipe: {}", err);
ErrorResponse {
message: format!("Failed to update session with recipe: {}", err),
status: StatusCode::INTERNAL_SERVER_ERROR,
}
})?;
}

// Refetch session to get all updates
Expand Down
3 changes: 2 additions & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,8 @@ impl Agent {
None => {
let model_name = config
.get_goose_model()
.map_err(|_| anyhow!("Could not configure agent: missing model"))?;
.ok()
.ok_or_else(|| anyhow!("Could not configure agent: missing model"))?;
crate::model::ModelConfig::new(&model_name)
.map_err(|e| anyhow!("Could not configure agent: invalid model {}", e))?
}
Expand Down
8 changes: 8 additions & 0 deletions ui/desktop/src/components/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import { useToolCount } from './alerts/useToolCount';
import { getThinkingMessage, getTextAndImageContent } from '../types/message';
import ParameterInputModal from './ParameterInputModal';
import { substituteParameters } from '../utils/providerUtils';
import { useModelAndProvider } from './ModelAndProviderContext';
import CreateRecipeFromSessionModal from './recipes/CreateRecipeFromSessionModal';
import { toastSuccess } from '../toasts';
import { Recipe } from '../recipe';
Expand Down Expand Up @@ -176,6 +177,13 @@ export default function BaseChat({
});

const recipe = session?.recipe;
const { setProviderAndModel } = useModelAndProvider();

useEffect(() => {
if (session?.provider_name && session?.model_config?.model_name) {
setProviderAndModel(session.provider_name, session.model_config.model_name);
}
}, [session?.provider_name, session?.model_config?.model_name, setProviderAndModel]);

useEffect(() => {
if (!recipe) return;
Expand Down
8 changes: 8 additions & 0 deletions ui/desktop/src/components/ModelAndProviderContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ interface ModelAndProviderContextType {
getCurrentModelDisplayName: () => Promise<string>;
getCurrentProviderDisplayName: () => Promise<string>; // Gets provider display name from subtext
refreshCurrentModelAndProvider: () => Promise<void>;
setProviderAndModel: (provider: string, model: string) => void;
}

interface ModelAndProviderProviderProps {
Expand Down Expand Up @@ -173,6 +174,11 @@ export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> =
}
}, [getCurrentModelAndProvider]);

const setProviderAndModel = useCallback((provider: string, model: string) => {
setCurrentProvider(provider);
setCurrentModel(model);
}, []);

// Load initial model and provider on mount
useEffect(() => {
refreshCurrentModelAndProvider();
Expand All @@ -189,6 +195,7 @@ export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> =
getCurrentModelDisplayName,
getCurrentProviderDisplayName,
refreshCurrentModelAndProvider,
setProviderAndModel,
}),
[
currentModel,
Expand All @@ -200,6 +207,7 @@ export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> =
getCurrentModelDisplayName,
getCurrentProviderDisplayName,
refreshCurrentModelAndProvider,
setProviderAndModel,
]
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
import { useCurrentModelInfo } from '../../../BaseChat';
import { useConfig } from '../../../ConfigContext';
import { getProviderMetadata } from '../modelInterface';
import { getModelDisplayName } from '../predefinedModelsUtils';
import { Alert } from '../../../alerts';
import BottomMenuAlertPopover from '../../../bottom_menu/BottomMenuAlertPopover';

Expand All @@ -32,9 +33,6 @@ export default function ModelsBottomBar({
const {
currentModel,
currentProvider,
getCurrentModelAndProviderForDisplay,
getCurrentModelDisplayName,
getCurrentProviderDisplayName,
} = useModelAndProvider();
const currentModelInfo = useCurrentModelInfo();
const { read, getProviders } = useConfig();
Expand Down Expand Up @@ -62,7 +60,6 @@ export default function ModelsBottomBar({
// Refresh lead/worker status when modal closes
const handleLeadWorkerModalClose = () => {
setIsLeadWorkerModalOpen(false);
// Refresh the lead/worker status after modal closes
const checkLeadWorker = async () => {
try {
const leadModel = await read('GOOSE_LEAD_MODEL', false);
Expand All @@ -78,8 +75,6 @@ export default function ModelsBottomBar({
checkLeadWorker();
};

// Since currentModelInfo.mode is not working, let's determine mode differently
// We'll need to get the lead model and compare it with the current model
const [leadModelName, setLeadModelName] = useState<string>('');
const [currentActiveModel, setCurrentActiveModel] = useState<string>('');

Expand Down Expand Up @@ -111,20 +106,16 @@ export default function ModelsBottomBar({
? currentModelInfo.model
: currentModel || providerDefaultModel || displayModelName;

// Update display provider when current provider changes
useEffect(() => {
if (currentProvider) {
(async () => {
const providerDisplayName = await getCurrentProviderDisplayName();
if (providerDisplayName) {
setDisplayProvider(providerDisplayName);
} else {
const modelProvider = await getCurrentModelAndProviderForDisplay();
setDisplayProvider(modelProvider.provider);
}
})();
}
}, [currentProvider, getCurrentProviderDisplayName, getCurrentModelAndProviderForDisplay]);
if (!currentProvider) return;
getProviderMetadata(currentProvider, getProviders)
.then((metadata) => {
setDisplayProvider(metadata.display_name || currentProvider);
})
.catch(() => {
setDisplayProvider(currentProvider);
});
}, [currentProvider, currentModel, getProviders]);

// Fetch provider default model when provider changes and no current model
useEffect(() => {
Expand All @@ -139,18 +130,14 @@ export default function ModelsBottomBar({
}
})();
} else if (currentModel) {
// Clear provider default when we have a current model
setProviderDefaultModel(null);
}
}, [currentProvider, currentModel, getProviders]);

// Update display model name when current model changes
useEffect(() => {
(async () => {
const displayName = await getCurrentModelDisplayName();
setDisplayModelName(displayName);
})();
}, [currentModel, getCurrentModelDisplayName]);
if (!currentModel) return;
setDisplayModelName(getModelDisplayName(currentModel));
}, [currentModel]);

return (
<div className="relative flex items-center" ref={dropdownRef}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,8 @@ export const SwitchModelModal = ({
const [providerOptions, setProviderOptions] = useState<{ value: string; label: string }[]>([]);
type ModelOption = { value: string; label: string; provider: string; isDisabled?: boolean };
const [modelOptions, setModelOptions] = useState<{ options: ModelOption[] }[]>([]);
const [provider, setProvider] = useState<string | null>(
initialProvider || currentProvider || null
);
// Only use currentModel if we're not switching to a different provider
// Otherwise, let the auto-select logic pick an appropriate model for the new provider
const [model, setModel] = useState<string>(
initialProvider && initialProvider !== currentProvider ? '' : currentModel || ''
);
const [provider, setProvider] = useState<string | null>(initialProvider || currentProvider || null);
const [model, setModel] = useState<string>(currentModel || '');
const [isCustomModel, setIsCustomModel] = useState(false);
const [validationErrors, setValidationErrors] = useState({
provider: '',
Expand Down Expand Up @@ -172,12 +166,10 @@ export const SwitchModelModal = ({
}

await changeModel(sessionId, modelObj);
onModelSelected?.(modelObj.name);

trackModelChanged(modelObj.provider || '', modelObj.name);

if (onModelSelected) {
onModelSelected(modelObj.name);
}
onClose();
}
};
Expand Down Expand Up @@ -212,8 +204,7 @@ export const SwitchModelModal = ({
// Load providers for manual model selection
(async () => {
try {
// Force refresh if initialProvider is set (OAuth flow needs fresh data)
const providersResponse = await getProviders(!!initialProvider);
const providersResponse = await getProviders(false);
const activeProviders = providersResponse.filter((provider) => provider.is_configured);
// Create provider options and add "Use other provider" option
setProviderOptions([
Expand Down Expand Up @@ -282,7 +273,7 @@ export const SwitchModelModal = ({
setLoadingModels(false);
}
})();
}, [getProviders, usePredefinedModels, read, initialProvider]);
}, [getProviders, usePredefinedModels, read]);

const filteredModelOptions = provider
? modelOptions.filter((group) => group.options[0]?.provider === provider)
Expand Down