diff --git a/crates/goose-cli/src/commands/acp.rs b/crates/goose-cli/src/commands/acp.rs index 37896208ee8b..4a26e0dea0a4 100644 --- a/crates/goose-cli/src/commands/acp.rs +++ b/crates/goose-cli/src/commands/acp.rs @@ -230,18 +230,22 @@ impl GooseAcpAgent { }; let provider = create(&provider_name, model_config).await?; - // Create a shared agent instance + let session = SessionManager::create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "ACP Session".to_string(), + SessionType::Hidden, + ) + .await?; + let agent = Agent::new(); - agent.update_provider(provider.clone()).await?; + agent.update_provider(provider.clone(), &session.id).await?; - // Load and add extensions just like the normal CLI let extensions_to_run: Vec<_> = get_all_extensions() .into_iter() .filter(|ext| ext.enabled) .map(|ext| ext.config) .collect(); - // Add extensions to the agent in parallel let agent_ptr = Arc::new(agent); let mut set = JoinSet::new(); let mut waiting_on = HashSet::new(); @@ -257,7 +261,6 @@ impl GooseAcpAgent { }); } - // Wait for all extensions to load while let Some(result) = set.join_next().await { match result { Ok((name, Ok(_))) => { @@ -274,7 +277,6 @@ impl GooseAcpAgent { } } - // Unwrap the Arc to get the agent back let agent = Arc::try_unwrap(agent_ptr) .map_err(|_| anyhow::anyhow!("Failed to unwrap agent Arc"))?; diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index c28a9ce4f719..dd7ce318e587 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -21,6 +21,7 @@ use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::provider_test::test_provider_configuration; use goose::providers::{create, providers}; +use goose::session::{SessionManager, SessionType}; use serde_json::Value; use std::collections::HashMap; @@ -1368,7 +1369,6 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { .collect(); extensions.push("platform".to_string()); - // Sort extensions alphabetically by name extensions.sort(); let selected_extension_name = cliclack::select("Choose an extension to configure tools") @@ -1380,8 +1380,6 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { ) .interact()?; - // Fetch tools for the selected extension - // Load config and get provider/model let config = Config::global(); let provider_name: String = config @@ -1393,10 +1391,16 @@ pub async fn configure_tool_permissions_dialog() -> anyhow::Result<()> { .expect("No model configured. Please set model first"); let model_config = ModelConfig::new(&model)?; - // Create the agent + let session = SessionManager::create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Tool Permission Configuration".to_string(), + SessionType::Hidden, + ) + .await?; + let agent = Agent::new(); let new_provider = create(&provider_name, model_config).await?; - agent.update_provider(new_provider).await?; + agent.update_provider(new_provider, &session.id).await?; if let Some(config) = get_extension_by_name(&selected_extension_name) { agent .add_extension(config.clone()) diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index b47ddcdf8d85..e1e49f4b91f2 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -158,12 +158,17 @@ pub async fn handle_web( let model_config = goose::model::ModelConfig::new(&model)?; - // Create the agent + let init_session = SessionManager::create_session( + std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")), + "Web Agent Initialization".to_string(), + SessionType::Hidden, + ) + .await?; + let agent = Agent::new(); let provider = goose::providers::create(&provider_name, model_config).await?; - agent.update_provider(provider).await?; + agent.update_provider(provider, &init_session.id).await?; - // Load and enable extensions from config let enabled_configs = goose::config::get_enabled_extensions(); for config in enabled_configs { if let Err(e) = agent.add_extension(config.clone()).await { @@ -177,7 +182,6 @@ pub async fn handle_web( auth_token, }; - // Build router let app = Router::new() .route("/", get(serve_index)) .route("/session/{session_name}", get(serve_session)) diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index ccbbfba4fec4..5dae5c2b62db 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -217,16 +217,20 @@ where ) .await; - agent - .update_provider(provider_arc as Arc) - .await?; - let session = SessionManager::create_session( PathBuf::default(), "scenario-runner".to_string(), SessionType::Hidden, ) .await?; + + agent + .update_provider( + provider_arc as Arc, + &session.id, + ) + .await?; + let mut cli_session = CliSession::new( agent, session.id, diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 6845e3f2e222..8c9fc966139a 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -149,7 +149,15 @@ async fn offer_extension_debugging_help( // Create a minimal agent for debugging let debug_agent = Agent::new(); - debug_agent.update_provider(provider).await?; + + let session = SessionManager::create_session( + std::env::current_dir()?, + "CLI Session".to_string(), + SessionType::Hidden, + ) + .await?; + + debug_agent.update_provider(provider, &session.id).await?; // Add the developer extension if available to help with debugging let extensions = get_all_extensions(); @@ -166,12 +174,6 @@ async fn offer_extension_debugging_help( } } - let session = SessionManager::create_session( - std::env::current_dir()?, - "CLI Session".to_string(), - SessionType::Hidden, - ) - .await?; let mut debug_session = CliSession::new( debug_agent, session.id, @@ -246,11 +248,24 @@ pub struct SessionSettings { } pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { - // Load config and get provider/model let config = Config::global(); + let (saved_provider, saved_model_config) = if session_config.resume { + if let Some(ref session_id) = session_config.session_id { + match SessionManager::get_session(session_id, false).await { + Ok(session_data) => (session_data.provider_name, session_data.model_config), + Err(_) => (None, None), + } + } else { + (None, None) + } + } else { + (None, None) + }; + let provider_name = session_config .provider + .or(saved_provider) .or_else(|| { session_config .settings @@ -262,6 +277,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { let model_name = session_config .model + .or_else(|| saved_model_config.as_ref().map(|mc| mc.model_name.clone())) .or_else(|| { session_config .settings @@ -271,16 +287,26 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { .or_else(|| config.get_goose_model().ok()) .expect("No model configured. Run 'goose configure' first"); - let temperature = session_config.settings.as_ref().and_then(|s| s.temperature); - - let model_config = goose::model::ModelConfig::new(&model_name) - .unwrap_or_else(|e| { - output::render_error(&format!("Failed to create model configuration: {}", e)); - process::exit(1); - }) - .with_temperature(temperature); + let model_config = if session_config.resume + && saved_model_config + .as_ref() + .is_some_and(|mc| mc.model_name == model_name) + { + let mut config = saved_model_config.unwrap(); + if let Some(temp) = session_config.settings.as_ref().and_then(|s| s.temperature) { + config = config.with_temperature(Some(temp)); + } + config + } else { + let temperature = session_config.settings.as_ref().and_then(|s| s.temperature); + goose::model::ModelConfig::new(&model_name) + .unwrap_or_else(|e| { + output::render_error(&format!("Failed to create model configuration: {}", e)); + process::exit(1); + }) + .with_temperature(temperature) + }; - // Create the agent let agent: Agent = Agent::new(); agent @@ -304,10 +330,8 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { process::exit(1); } }; - // Keep a reference to the provider for display_session_info let provider_for_display = Arc::clone(&new_provider); - // Log model information at startup if let Some(lead_worker) = new_provider.as_lead_worker() { let (lead_model, worker_model) = lead_worker.get_model_info(); tracing::info!( @@ -319,14 +343,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { tracing::info!("🤖 Using model: {}", model_name); } - agent - .update_provider(new_provider) - .await - .unwrap_or_else(|e| { - output::render_error(&format!("Failed to initialize agent: {}", e)); - process::exit(1); - }); - let session_id: String = if session_config.no_session { let working_dir = std::env::current_dir().expect("Could not get working directory"); let session = SessionManager::create_session( @@ -362,6 +378,14 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { session_config.session_id.unwrap() }; + agent + .update_provider(new_provider, &session_id) + .await + .unwrap_or_else(|e| { + output::render_error(&format!("Failed to initialize agent: {}", e)); + process::exit(1); + }); + agent .extension_manager .set_context(PlatformExtensionContext { diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index a75841f9cffe..33aeea69a45b 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -4,6 +4,7 @@ use goose::agents::ExtensionConfig; use goose::config::permission::PermissionLevel; use goose::config::ExtensionEntry; use goose::conversation::Conversation; +use goose::model::ModelConfig; use goose::permission::permission_confirmation::PrincipalType; use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata, ProviderType}; use goose::session::{Session, SessionInsights, SessionType}; @@ -447,6 +448,7 @@ derive_utoipa!(Icon as IconSchema); PermissionLevel, PrincipalType, ModelInfo, + ModelConfig, Session, SessionInsights, SessionType, diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 3d1be23ae215..2a3d8a014d9c 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -16,7 +16,7 @@ use goose::agents::ExtensionConfig; use goose::config::{Config, GooseMode}; use goose::model::ModelConfig; use goose::prompt_template::render_global_file; -use goose::providers::{create, create_with_named_model}; +use goose::providers::create; use goose::recipe::Recipe; use goose::recipe_deeplink; use goose::session::session_manager::SessionType; @@ -192,7 +192,7 @@ async fn resume_agent( State(state): State>, Json(payload): Json, ) -> Result, ErrorResponse> { - let session = SessionManager::get_session(&payload.session_id, true) + let session = SessionManager::get_session(&payload.session_id, false) .await .map_err(|err| { error!("Failed to resume session {}: {}", payload.session_id, err); @@ -204,7 +204,7 @@ async fn resume_agent( if payload.load_model_and_extensions { let agent = state - .get_agent_for_route(payload.session_id) + .get_agent_for_route(payload.session_id.clone()) .await .map_err(|code| ErrorResponse { message: "Failed to get agent for route".into(), @@ -214,25 +214,39 @@ async fn resume_agent( let config = Config::global(); let provider_result = async { - let provider_name: String = config.get_goose_provider().map_err(|_| ErrorResponse { - message: "Could not configure agent: missing provider".into(), - status: StatusCode::INTERNAL_SERVER_ERROR, - })?; - - let model: String = config.get_goose_model().map_err(|_| ErrorResponse { - message: "Could not configure agent: missing model".into(), - status: StatusCode::INTERNAL_SERVER_ERROR, - })?; - - let provider = create_with_named_model(&provider_name, &model) - .await - .map_err(|_| ErrorResponse { - message: "Could not configure agent: missing model".into(), + let provider_name = session + .provider_name + .clone() + .or_else(|| config.get_goose_provider().ok()) + .ok_or_else(|| ErrorResponse { + message: "Could not configure agent: missing provider".into(), status: StatusCode::INTERNAL_SERVER_ERROR, })?; + let model_config = match session.model_config.clone() { + Some(saved_config) => saved_config, + None => { + let model_name = config.get_goose_model().map_err(|_| ErrorResponse { + message: "Could not configure agent: missing model".into(), + status: StatusCode::INTERNAL_SERVER_ERROR, + })?; + ModelConfig::new(&model_name).map_err(|e| ErrorResponse { + message: format!("Could not configure agent: invalid model {}", e), + status: StatusCode::INTERNAL_SERVER_ERROR, + })? + } + }; + + let provider = + create(&provider_name, model_config) + .await + .map_err(|e| ErrorResponse { + message: format!("Could not create provider: {}", e), + status: StatusCode::INTERNAL_SERVER_ERROR, + })?; + agent - .update_provider(provider) + .update_provider(provider, &payload.session_id) .await .map_err(|e| ErrorResponse { message: format!("Could not configure agent: {}", e), @@ -428,12 +442,15 @@ async fn update_agent_provider( ) })?; - agent.update_provider(new_provider).await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to update provider: {}", e), - ) - })?; + agent + .update_provider(new_provider, &payload.session_id) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to update provider: {}", e), + ) + })?; Ok(()) } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index 014d63b302ad..4e4bb5795903 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -10,15 +10,21 @@ use goose::session::SessionManager; use std::path::PathBuf; #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { let _ = dotenv(); - let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL) - .await - .expect("Couldn't create provider"); + let provider = create_with_named_model("databricks", DATABRICKS_DEFAULT_MODEL).await?; let agent = Agent::new(); - let _ = agent.update_provider(provider).await; + + let session = SessionManager::create_session( + PathBuf::default(), + "max-turn-test".to_string(), + SessionType::Hidden, + ) + .await?; + + let _ = agent.update_provider(provider, &session.id).await; let config = ExtensionConfig::stdio( "developer", @@ -27,21 +33,13 @@ async fn main() { DEFAULT_EXTENSION_TIMEOUT, ) .with_args(vec!["mcp", "developer"]); - agent.add_extension(config).await.unwrap(); + agent.add_extension(config).await?; println!("Extensions:"); for extension in agent.list_extensions().await { println!(" {}", extension); } - let session = SessionManager::create_session( - PathBuf::default(), - "max-turn-test".to_string(), - SessionType::Hidden, - ) - .await - .expect("session manager creation failed"); - let session_config = SessionConfig { id: session.id, schedule_id: None, @@ -52,13 +50,12 @@ async fn main() { let user_message = Message::user() .with_text("can you summarize the readme.md in this dir using just a haiku?"); - let mut stream = agent - .reply(user_message, session_config, None) - .await - .unwrap(); + let mut stream = agent.reply(user_message, session_config, None).await?; while let Some(Ok(AgentEvent::Message(message))) = stream.next().await { - println!("{}", serde_json::to_string_pretty(&message).unwrap()); + println!("{}", serde_json::to_string_pretty(&message)?); println!("\n"); } + + Ok(()) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index ed28ea8f1463..6e797a7caf7a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -3,7 +3,7 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use futures::stream::BoxStream; use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use uuid::Uuid; @@ -1238,13 +1238,23 @@ impl Agent { prompt_manager.add_system_prompt_extra(instruction); } - pub async fn update_provider(&self, provider: Arc) -> Result<()> { + pub async fn update_provider( + &self, + provider: Arc, + session_id: &str, + ) -> Result<()> { let mut current_provider = self.provider.lock().await; *current_provider = Some(provider.clone()); - self.update_router_tool_selector(Some(provider), None) + self.update_router_tool_selector(Some(provider.clone()), None) .await?; - Ok(()) + + SessionManager::update_session(session_id) + .provider_name(provider.get_name()) + .model_config(provider.get_model_config()) + .apply() + .await + .context("Failed to persist provider config to session") } pub async fn update_router_tool_selector( diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index f9755c716771..40b1a99d6a8a 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -18,6 +18,8 @@ use crate::providers::toolshim::{ use crate::agents::recipe_tools::dynamic_task_tools::should_enabled_subagents; use crate::session::SessionManager; +#[cfg(test)] +use crate::session::SessionType; use rmcp::model::Tool; fn coerce_value(s: &str, schema: &Value) -> Value { @@ -439,9 +441,16 @@ mod tests { ) -> anyhow::Result<()> { let agent = crate::agents::Agent::new(); + let session = SessionManager::create_session( + std::path::PathBuf::default(), + "test-prepare-tools".to_string(), + SessionType::Hidden, + ) + .await?; + let model_config = ModelConfig::new("test-model").unwrap(); let provider = std::sync::Arc::new(MockProvider { model_config }); - agent.update_provider(provider).await?; + agent.update_provider(provider, &session.id).await?; // Disable the router to trigger sorting agent.disable_router_for_recipe().await; diff --git a/crates/goose/src/agents/subagent_handler.rs b/crates/goose/src/agents/subagent_handler.rs index 0ab0310a7837..2d6f7f8167a9 100644 --- a/crates/goose/src/agents/subagent_handler.rs +++ b/crates/goose/src/agents/subagent_handler.rs @@ -117,7 +117,7 @@ fn get_agent_messages( .map_err(|e| anyhow!("Failed to get sub agent session file path: {}", e))?; agent - .update_provider(task_config.provider) + .update_provider(task_config.provider, &session_id) .await .map_err(|e| anyhow!("Failed to set provider on sub agent: {}", e))?; diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index c222be2d42da..efd596d803c9 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -87,7 +87,9 @@ impl AgentManager { }) .await; if let Some(provider) = &*self.default_provider.read().await { - agent.update_provider(Arc::clone(provider)).await?; + agent + .update_provider(Arc::clone(provider), &session_id) + .await?; } let mut sessions = self.sessions.write().await; diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index 2b3ee255b7f4..9a9f8eb546dc 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -1,6 +1,7 @@ use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use thiserror::Error; +use utoipa::ToSchema; const DEFAULT_CONTEXT_LIMIT: usize = 128_000; @@ -67,7 +68,7 @@ static MODEL_SPECIFIC_LIMITS: Lazy> = Lazy::new(|| { ] }); -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct ModelConfig { pub model_name: String, pub context_limit: Option, diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 045922074a84..09883327103c 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -664,8 +664,6 @@ async fn execute_job( } } - agent.update_provider(agent_provider).await?; - let session = SessionManager::create_session( std::env::current_dir()?, format!("Scheduled job: {}", job.id), @@ -673,6 +671,8 @@ async fn execute_job( ) .await?; + agent.update_provider(agent_provider, &session.id).await?; + let mut jobs_guard = jobs.lock().await; if let Some((_, job_def)) = jobs_guard.get_mut(job_id.as_str()) { job_def.current_session_id = Some(session.id.clone()); diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 8780b09536f6..faf9b4551e02 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -1,6 +1,7 @@ use crate::config::paths::Paths; use crate::conversation::message::Message; use crate::conversation::Conversation; +use crate::model::ModelConfig; use crate::providers::base::{Provider, MSG_COUNT_FOR_SESSION_NAME_GENERATION}; use crate::recipe::Recipe; use crate::session::extension_data::ExtensionData; @@ -18,7 +19,7 @@ use tokio::sync::OnceCell; use tracing::{info, warn}; use utoipa::ToSchema; -const CURRENT_SCHEMA_VERSION: i32 = 5; +const CURRENT_SCHEMA_VERSION: i32 = 6; pub const SESSIONS_FOLDER: &str = "sessions"; pub const DB_NAME: &str = "sessions.db"; @@ -89,6 +90,8 @@ pub struct Session { pub user_recipe_values: Option>, pub conversation: Option, pub message_count: usize, + pub provider_name: Option, + pub model_config: Option, } pub struct SessionUpdateBuilder { @@ -107,6 +110,8 @@ pub struct SessionUpdateBuilder { schedule_id: Option>, recipe: Option>, user_recipe_values: Option>>, + provider_name: Option>, + model_config: Option>, } #[derive(Serialize, ToSchema, Debug)] @@ -134,6 +139,8 @@ impl SessionUpdateBuilder { schedule_id: None, recipe: None, user_recipe_values: None, + provider_name: None, + model_config: None, } } @@ -218,6 +225,16 @@ impl SessionUpdateBuilder { self } + pub fn provider_name(mut self, provider_name: impl Into) -> Self { + self.provider_name = Some(Some(provider_name.into())); + self + } + + pub fn model_config(mut self, model_config: ModelConfig) -> Self { + self.model_config = Some(Some(model_config)); + self + } + pub async fn apply(self) -> Result<()> { SessionManager::apply_update(self).await } @@ -375,6 +392,8 @@ impl Default for Session { user_recipe_values: None, conversation: None, message_count: 0, + provider_name: None, + model_config: None, } } } @@ -397,6 +416,9 @@ impl sqlx::FromRow<'_, sqlx::sqlite::SqliteRow> for Session { let user_recipe_values = user_recipe_values_json.and_then(|json| serde_json::from_str(&json).ok()); + let model_config_json: Option = row.try_get("model_config_json").ok().flatten(); + let model_config = model_config_json.and_then(|json| serde_json::from_str(&json).ok()); + let name: String = { let name_val: String = row.try_get("name").unwrap_or_default(); if !name_val.is_empty() { @@ -434,6 +456,8 @@ impl sqlx::FromRow<'_, sqlx::sqlite::SqliteRow> for Session { user_recipe_values, conversation: None, message_count: row.try_get("message_count").unwrap_or(0) as usize, + provider_name: row.try_get("provider_name").ok().flatten(), + model_config, }) } } @@ -521,7 +545,9 @@ impl SessionStorage { accumulated_output_tokens INTEGER, schedule_id TEXT, recipe_json TEXT, - user_recipe_values_json TEXT + user_recipe_values_json TEXT, + provider_name TEXT, + model_config_json TEXT ) "#, ) @@ -618,14 +644,20 @@ impl SessionStorage { None => None, }; + let model_config_json = match &session.model_config { + Some(model_config) => Some(serde_json::to_string(model_config)?), + None => None, + }; + sqlx::query( r#" INSERT INTO sessions ( id, name, user_set_name, session_type, working_dir, created_at, updated_at, extension_data, total_tokens, input_tokens, output_tokens, accumulated_total_tokens, accumulated_input_tokens, accumulated_output_tokens, - schedule_id, recipe_json, user_recipe_values_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + schedule_id, recipe_json, user_recipe_values_json, + provider_name, model_config_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(&session.id) @@ -645,6 +677,8 @@ impl SessionStorage { .bind(&session.schedule_id) .bind(recipe_json) .bind(user_recipe_values_json) + .bind(&session.provider_name) + .bind(model_config_json) .execute(&mut *tx) .await?; @@ -771,6 +805,23 @@ impl SessionStorage { .execute(&self.pool) .await?; } + 6 => { + sqlx::query( + r#" + ALTER TABLE sessions ADD COLUMN provider_name TEXT + "#, + ) + .execute(&self.pool) + .await?; + + sqlx::query( + r#" + ALTER TABLE sessions ADD COLUMN model_config_json TEXT + "#, + ) + .execute(&self.pool) + .await?; + } _ => { anyhow::bail!("Unknown migration version: {}", version); } @@ -824,7 +875,8 @@ impl SessionStorage { SELECT id, working_dir, name, description, user_set_name, session_type, created_at, updated_at, extension_data, total_tokens, input_tokens, output_tokens, accumulated_total_tokens, accumulated_input_tokens, accumulated_output_tokens, - schedule_id, recipe_json, user_recipe_values_json + schedule_id, recipe_json, user_recipe_values_json, + provider_name, model_config_json FROM sessions WHERE id = ? "#, @@ -850,6 +902,7 @@ impl SessionStorage { Ok(session) } + #[allow(clippy::too_many_lines)] async fn apply_update(&self, builder: SessionUpdateBuilder) -> Result<()> { let mut updates = Vec::new(); let mut query = String::from("UPDATE sessions SET "); @@ -884,6 +937,8 @@ impl SessionStorage { add_update!(builder.schedule_id, "schedule_id"); add_update!(builder.recipe, "recipe_json"); add_update!(builder.user_recipe_values, "user_recipe_values_json"); + add_update!(builder.provider_name, "provider_name"); + add_update!(builder.model_config, "model_config_json"); if updates.is_empty() { return Ok(()); @@ -940,6 +995,15 @@ impl SessionStorage { .transpose()?; q = q.bind(user_recipe_values_json); } + if let Some(provider_name) = builder.provider_name { + q = q.bind(provider_name); + } + if let Some(model_config) = builder.model_config { + let model_config_json = model_config + .map(|mc| serde_json::to_string(&mc)) + .transpose()?; + q = q.bind(model_config_json); + } let mut tx = self.pool.begin().await?; q = q.bind(&builder.session_id); @@ -1050,6 +1114,7 @@ impl SessionStorage { s.total_tokens, s.input_tokens, s.output_tokens, s.accumulated_total_tokens, s.accumulated_input_tokens, s.accumulated_output_tokens, s.schedule_id, s.recipe_json, s.user_recipe_values_json, + s.provider_name, s.model_config_json, COUNT(m.id) as message_count FROM sessions s INNER JOIN messages m ON s.id = m.session_id diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 5a4389a4dea6..51aaa55ebebf 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -356,7 +356,6 @@ mod tests { async fn test_max_turns_limit() -> Result<()> { let agent = Agent::new(); let provider = Arc::new(MockToolProvider::new()); - agent.update_provider(provider).await?; let user_message = Message::user().with_text("Hello"); let session = SessionManager::create_session( @@ -365,6 +364,9 @@ mod tests { SessionType::Hidden, ) .await?; + + agent.update_provider(provider, &session.id).await?; + let session_config = SessionConfig { id: session.id, schedule_id: None, diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 8bc8dd6034f6..83500adb150f 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -3509,6 +3509,44 @@ } } }, + "ModelConfig": { + "type": "object", + "required": [ + "model_name", + "toolshim" + ], + "properties": { + "context_limit": { + "type": "integer", + "nullable": true, + "minimum": 0 + }, + "fast_model": { + "type": "string", + "nullable": true + }, + "max_tokens": { + "type": "integer", + "format": "int32", + "nullable": true + }, + "model_name": { + "type": "string" + }, + "temperature": { + "type": "number", + "format": "float", + "nullable": true + }, + "toolshim": { + "type": "boolean" + }, + "toolshim_model": { + "type": "string", + "nullable": true + } + } + }, "ModelInfo": { "type": "object", "description": "Information about a model's capabilities", @@ -4270,6 +4308,14 @@ "type": "integer", "minimum": 0 }, + "model_config": { + "allOf": [ + { + "$ref": "#/components/schemas/ModelConfig" + } + ], + "nullable": true + }, "name": { "type": "string" }, @@ -4278,6 +4324,10 @@ "format": "int32", "nullable": true }, + "provider_name": { + "type": "string", + "nullable": true + }, "recipe": { "allOf": [ { diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 075081d48c20..8c24f9ffa098 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -410,6 +410,16 @@ export type MessageMetadata = { userVisible: boolean; }; +export type ModelConfig = { + context_limit?: number | null; + fast_model?: string | null; + max_tokens?: number | null; + model_name: string; + temperature?: number | null; + toolshim: boolean; + toolshim_model?: string | null; +}; + /** * Information about a model's capabilities */ @@ -687,8 +697,10 @@ export type Session = { id: string; input_tokens?: number | null; message_count: number; + model_config?: ModelConfig | null; name: string; output_tokens?: number | null; + provider_name?: string | null; recipe?: Recipe | null; schedule_id?: string | null; session_type?: SessionType;