Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,27 @@ pub struct SessionSettings {
}

pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
// Load config and get provider/model
let config = Config::global();

let (saved_provider, saved_model) = if session_config.resume {
if let Some(ref session_id) = session_config.session_id {
match SessionManager::get_session(session_id, false).await {
Ok(session_data) => (
session_data.provider_name,
session_data.model_config.map(|mc| mc.model_name),
),
Err(_) => (None, None),
}
} else {
(None, None)
}
} else {
(None, None)
};

let provider_name = session_config
.provider
.or(saved_provider)
.or_else(|| {
session_config
.settings
Expand All @@ -262,6 +278,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {

let model_name = session_config
.model
.or(saved_model)
.or_else(|| {
session_config
.settings
Expand All @@ -280,7 +297,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
})
.with_temperature(temperature);

// Create the agent
let agent: Agent = Agent::new();

if let Some(sub_recipes) = session_config.sub_recipes {
Expand All @@ -304,10 +320,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
process::exit(1);
}
};
// Keep a reference to the provider for display_session_info
let provider_for_display = Arc::clone(&new_provider);

// Log model information at startup
if let Some(lead_worker) = new_provider.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
tracing::info!(
Expand Down Expand Up @@ -362,6 +376,14 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession {
session_config.session_id.unwrap()
};

agent
.persist_provider_config(&session_id)
.await
.unwrap_or_else(|e| {
output::render_error(&format!("Failed to save provider config: {}", e));
process::exit(1);
});

agent
.extension_manager
.set_context(PlatformExtensionContext {
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-server/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use goose::agents::ExtensionConfig;
use goose::config::permission::PermissionLevel;
use goose::config::ExtensionEntry;
use goose::conversation::Conversation;
use goose::model::ModelConfig;
use goose::permission::permission_confirmation::PrincipalType;
use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata, ProviderType};
use goose::session::{Session, SessionInsights, SessionType};
Expand Down Expand Up @@ -442,6 +443,7 @@ derive_utoipa!(Icon as IconSchema);
PermissionLevel,
PrincipalType,
ModelInfo,
ModelConfig,
Session,
SessionInsights,
SessionType,
Expand Down
18 changes: 14 additions & 4 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use futures::stream::BoxStream;
use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
use uuid::Uuid;
Expand Down Expand Up @@ -885,6 +885,7 @@ impl Agent {
if let Some((new_provider, role, model)) = autopilot.check_for_switch(&conversation, self.provider().await?).await? {
debug!("AutoPilot switching to {} role with model {}", role, model);
self.update_provider(new_provider).await?;
self.persist_provider_config(&session.id).await?;

yield AgentEvent::ModelChange {
model: model.clone(),
Expand Down Expand Up @@ -1198,10 +1199,19 @@ impl Agent {
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
let mut current_provider = self.provider.lock().await;
*current_provider = Some(provider.clone());
drop(current_provider);

self.update_router_tool_selector(Some(provider), None)
.await?;
Ok(())
self.update_router_tool_selector(Some(provider), None).await
}

pub async fn persist_provider_config(&self, session_id: &str) -> Result<()> {
let provider = self.provider().await?;
SessionManager::update_session(session_id)
.provider_name(provider.get_name())
.model_config(provider.get_model_config())
.apply()
.await
.context("Failed to persist provider config to session")
}

pub async fn update_router_tool_selector(
Expand Down
5 changes: 5 additions & 0 deletions crates/goose/src/agents/subagent_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ fn get_agent_messages(
.await
.map_err(|e| anyhow!("Failed to set provider on sub agent: {}", e))?;

agent
.persist_provider_config(&session.id)
.await
.map_err(|e| anyhow!("Failed to persist provider config for sub agent: {}", e))?;

for extension in task_config.extensions {
if let Err(e) = agent.add_extension(extension.clone()).await {
debug!(
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/execution/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl AgentManager {
.await;
if let Some(provider) = &*self.default_provider.read().await {
agent.update_provider(Arc::clone(provider)).await?;
agent.persist_provider_config(&session_id).await?;
}

let mut sessions = self.sessions.write().await;
Expand Down
3 changes: 2 additions & 1 deletion crates/goose/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use utoipa::ToSchema;

const DEFAULT_CONTEXT_LIMIT: usize = 128_000;

Expand Down Expand Up @@ -67,7 +68,7 @@ static MODEL_SPECIFIC_LIMITS: Lazy<Vec<(&'static str, usize)>> = Lazy::new(|| {
]
});

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ModelConfig {
pub model_name: String,
pub context_limit: Option<usize>,
Expand Down
22 changes: 14 additions & 8 deletions crates/goose/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,12 +1152,6 @@ async fn run_scheduled_job_internal(
}
}

if let Err(e) = agent.update_provider(agent_provider).await {
return Err(JobExecutionError {
job_id: job.id.clone(),
error: format!("Failed to set provider on agent: {}", e),
});
}
tracing::info!("Agent configured with provider for job '{}'", job.id);

let current_dir = match std::env::current_dir() {
Expand Down Expand Up @@ -1186,15 +1180,27 @@ async fn run_scheduled_job_internal(
}
};

// Update the job with the session ID if we have access to the jobs arc
if let Err(e) = agent.update_provider(agent_provider).await {
return Err(JobExecutionError {
job_id: job.id.clone(),
error: format!("Failed to set provider on agent: {}", e),
});
}

if let Err(e) = agent.persist_provider_config(&session.id).await {
return Err(JobExecutionError {
job_id: job.id.clone(),
error: format!("Failed to save provider config to session: {}", e),
});
}

if let (Some(jobs_arc), Some(job_id_str)) = (jobs_arc.as_ref(), job_id.as_ref()) {
let mut jobs_guard = jobs_arc.lock().await;
if let Some((_, job_def)) = jobs_guard.get_mut(job_id_str) {
job_def.current_session_id = Some(session.id.clone());
}
}

// Use prompt if available, otherwise fall back to instructions
let prompt_text = recipe
.prompt
.as_ref()
Expand Down
Loading