diff --git a/Cargo.lock b/Cargo.lock index b1de675077c6..b09db839f466 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2682,6 +2682,7 @@ dependencies = [ "jsonwebtoken", "keyring", "lazy_static", + "lru", "mcp-client", "mcp-core", "minijinja", diff --git a/crates/goose-server/src/commands/agent.rs b/crates/goose-server/src/commands/agent.rs index 3746437523d2..84e42b556527 100644 --- a/crates/goose-server/src/commands/agent.rs +++ b/crates/goose-server/src/commands/agent.rs @@ -1,13 +1,7 @@ -use std::sync::Arc; - use crate::configuration; use crate::state; use anyhow::Result; use axum::middleware; -use etcetera::{choose_app_strategy, AppStrategy}; -use goose::agents::Agent; -use goose::config::APP_STRATEGY; -use goose::scheduler_factory::SchedulerFactory; use goose_server::auth::check_token; use tower_http::cors::{Any, CorsLayer}; use tracing::info; @@ -32,49 +26,7 @@ pub async fn run() -> Result<()> { let secret_key = std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string()); - let new_agent = Agent::new(); - - // Only initialize provider and extensions when running in standalone goosed mode - // This prevents breaking the Electron app which manages its own provider setup - if std::env::var("GOOSE_STANDALONE_MODE").unwrap_or_else(|_| "false".to_string()) == "true" { - tracing::info!("Running in standalone mode - initializing provider and extensions"); - - // Initialize provider like the CLI does - let config = goose::config::Config::global(); - - let provider_name: String = config - .get_param("GOOSE_PROVIDER") - .expect("No provider configured. Run 'goose configure' first"); - - let model_name: String = config - .get_param("GOOSE_MODEL") - .expect("No model configured. Run 'goose configure' first"); - - let model_config = goose::model::ModelConfig::new(&model_name) - .expect("Failed to create model configuration"); - - let provider = goose::providers::create(&provider_name, model_config) - .expect("Failed to create provider"); - - new_agent - .update_provider(provider) - .await - .expect("Failed to update agent provider"); - } - - let agent_ref = Arc::new(new_agent); - - let app_state = state::AppState::new(agent_ref.clone()); - - let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())? - .data_dir() - .join("schedules.json"); - - let scheduler_instance = SchedulerFactory::create(schedule_file_path).await?; - app_state.set_scheduler(scheduler_instance.clone()).await; - - // NEW: Provide scheduler access to the agent - agent_ref.set_scheduler(scheduler_instance).await; + let app_state = state::AppState::new().await?; let cors = CorsLayer::new() .allow_origin(Any) diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 2a0ffbdf32d3..1c546d5d66f1 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -1,5 +1,4 @@ use crate::state::AppState; -use axum::response::IntoResponse; use axum::{ extract::{Query, State}, http::StatusCode, @@ -28,7 +27,6 @@ use tracing::error; #[derive(Deserialize, utoipa::ToSchema)] pub struct ExtendPromptRequest { extension: String, - #[allow(dead_code)] session_id: String, } @@ -40,7 +38,6 @@ pub struct ExtendPromptResponse { #[derive(Deserialize, utoipa::ToSchema)] pub struct AddSubRecipesRequest { sub_recipes: Vec, - #[allow(dead_code)] session_id: String, } @@ -53,27 +50,23 @@ pub struct AddSubRecipesResponse { pub struct UpdateProviderRequest { provider: String, model: Option, - #[allow(dead_code)] session_id: String, } #[derive(Deserialize, utoipa::ToSchema)] pub struct SessionConfigRequest { response: Option, - #[allow(dead_code)] session_id: String, } #[derive(Deserialize, utoipa::ToSchema)] pub struct GetToolsQuery { extension_name: Option, - #[allow(dead_code)] session_id: String, } #[derive(Deserialize, utoipa::ToSchema)] pub struct UpdateRouterToolSelectorRequest { - #[allow(dead_code)] session_id: String, } @@ -116,8 +109,6 @@ async fn start_agent( State(state): State>, Json(payload): Json, ) -> Result, StatusCode> { - state.reset().await; - let session_id = session::generate_session_id(); let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1; @@ -203,7 +194,7 @@ async fn add_sub_recipes( State(state): State>, Json(payload): Json, ) -> Result, StatusCode> { - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(payload.session_id).await?; agent.add_sub_recipes(payload.sub_recipes.clone()).await; Ok(Json(AddSubRecipesResponse { success: true })) } @@ -222,7 +213,7 @@ async fn extend_prompt( State(state): State>, Json(payload): Json, ) -> Result, StatusCode> { - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(payload.session_id).await?; agent.extend_system_prompt(payload.extension.clone()).await; Ok(Json(ExtendPromptResponse { success: true })) } @@ -247,7 +238,7 @@ async fn get_tools( ) -> Result>, StatusCode> { let config = Config::global(); let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string()); - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(query.session_id).await?; let permission_manager = PermissionManager::default(); let mut tools: Vec = agent @@ -298,35 +289,37 @@ async fn get_tools( async fn update_agent_provider( State(state): State>, Json(payload): Json, -) -> Result { - let agent = state.get_agent().await; +) -> Result { + let agent = state + .get_agent_for_route(payload.session_id.clone()) + .await?; + let config = Config::global(); let model = match payload .model .or_else(|| config.get_param("GOOSE_MODEL").ok()) { Some(m) => m, - None => return Err((StatusCode::BAD_REQUEST, "No model specified".to_string())), + None => { + tracing::error!("No model specified"); + return Err(StatusCode::BAD_REQUEST); + } }; let model_config = ModelConfig::new(&model).map_err(|e| { - ( - StatusCode::BAD_REQUEST, - format!("Invalid model config: {}", e), - ) + tracing::error!("Invalid model config: {}", e); + StatusCode::BAD_REQUEST })?; let new_provider = create(&payload.provider, model_config).map_err(|e| { - ( - StatusCode::BAD_REQUEST, - format!("Failed to create provider: {}", e), - ) + tracing::error!("Failed to create provider: {}", e); + StatusCode::BAD_REQUEST })?; - agent - .update_provider(new_provider) - .await - .map_err(|_e| (StatusCode::INTERNAL_SERVER_ERROR, String::new()))?; + agent.update_provider(new_provider).await.map_err(|e| { + tracing::error!("Failed to update provider: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; Ok(StatusCode::OK) } @@ -344,17 +337,15 @@ async fn update_agent_provider( )] async fn update_router_tool_selector( State(state): State>, - Json(_payload): Json, -) -> Result, Json> { - let agent = state.get_agent().await; + Json(payload): Json, +) -> Result, StatusCode> { + let agent = state.get_agent_for_route(payload.session_id).await?; agent .update_router_tool_selector(None, Some(true)) .await .map_err(|e| { tracing::error!("Failed to update tool selection strategy: {}", e); - Json(ErrorResponse { - error: format!("Failed to update tool selection strategy: {}", e), - }) + StatusCode::INTERNAL_SERVER_ERROR })?; Ok(Json( @@ -376,8 +367,8 @@ async fn update_router_tool_selector( async fn update_session_config( State(state): State>, Json(payload): Json, -) -> Result, Json> { - let agent = state.get_agent().await; +) -> Result, StatusCode> { + let agent = state.get_agent_for_route(payload.session_id).await?; if let Some(response) = payload.response { agent.add_final_output_tool(response).await; diff --git a/crates/goose-server/src/routes/audio.rs b/crates/goose-server/src/routes/audio.rs index 473c37b5800e..8fa36a1945aa 100644 --- a/crates/goose-server/src/routes/audio.rs +++ b/crates/goose-server/src/routes/audio.rs @@ -391,13 +391,13 @@ pub fn routes(state: Arc) -> Router { mod tests { use super::*; use axum::{body::Body, http::Request}; + use serde_json::json; use tower::ServiceExt; - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_transcribe_endpoint_requires_auth() { - let state = AppState::new(Arc::new(goose::agents::Agent::new())); + let state = AppState::new().await.unwrap(); let app = routes(state); - // Test without auth header let request = Request::builder() .uri("/audio/transcribe") @@ -413,40 +413,18 @@ mod tests { .unwrap(); let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert!( + response.status() == StatusCode::PRECONDITION_FAILED + || response.status() == StatusCode::UNAUTHORIZED + ); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_transcribe_endpoint_validates_size() { - let state = AppState::new(Arc::new(goose::agents::Agent::new())); - let app = routes(state); - - // Create a large base64 string (simulating > 25MB audio) - let large_audio = BASE64.encode(vec![0u8; MAX_AUDIO_SIZE_BYTES + 1]); - - let request = Request::builder() - .uri("/audio/transcribe") - .method("POST") - .header("content-type", "application/json") - .header("x-secret-key", "test-secret") - .body(Body::from( - serde_json::to_string(&serde_json::json!({ - "audio": large_audio, - "mime_type": "audio/webm" - })) - .unwrap(), - )) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); - } - - #[tokio::test] - async fn test_transcribe_endpoint_validates_mime_type() { - let state = AppState::new(Arc::new(goose::agents::Agent::new())); + let state = AppState::new().await.unwrap(); let app = routes(state); + let large_data = "a".repeat(30 * 1024 * 1024); // 30MB let request = Request::builder() .uri("/audio/transcribe") .method("POST") @@ -468,9 +446,9 @@ mod tests { ); } - #[tokio::test] - async fn test_transcribe_endpoint_handles_invalid_base64() { - let state = AppState::new(Arc::new(goose::agents::Agent::new())); + #[tokio::test(flavor = "multi_thread")] + async fn test_transcribe_endpoint_validates_mime_type() { + let state = AppState::new().await.unwrap(); let app = routes(state); let request = Request::builder() diff --git a/crates/goose-server/src/routes/context.rs b/crates/goose-server/src/routes/context.rs index 205d2837c6b4..6f8764b074f9 100644 --- a/crates/goose-server/src/routes/context.rs +++ b/crates/goose-server/src/routes/context.rs @@ -13,6 +13,8 @@ pub struct ContextManageRequest { pub messages: Vec, /// Operation to perform: "truncation" or "summarize" pub manage_action: String, + /// Optional session ID for session-specific agent + pub session_id: String, } /// Response from context management operations @@ -44,7 +46,7 @@ async fn manage_context( State(state): State>, Json(request): Json, ) -> Result, StatusCode> { - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(request.session_id).await?; let mut processed_messages = Conversation::new_unvalidated(vec![]); let mut token_counts: Vec = vec![]; diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index 8102373a6cdd..ecdaad4eab06 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -96,33 +96,31 @@ struct ExtensionResponse { message: Option, } +/// Request structure for adding an extension, combining session_id with the extension config +#[derive(Deserialize)] +struct AddExtensionRequest { + session_id: String, + #[serde(flatten)] + config: ExtensionConfigRequest, +} + /// Handler for adding a new extension configuration. async fn add_extension( State(state): State>, - raw: axum::extract::Json, + Json(request): Json, ) -> Result, StatusCode> { - // Log the raw request for debugging + // Log the request for debugging tracing::info!( - "Received extension request: {}", - serde_json::to_string_pretty(&raw.0).unwrap() + "Received extension request for session: {}", + request.session_id ); - // Try to parse into our enum - let request: ExtensionConfigRequest = match serde_json::from_value(raw.0.clone()) { - Ok(req) => req, - Err(e) => { - tracing::error!("Failed to parse extension request: {}", e); - tracing::error!( - "Raw request was: {}", - serde_json::to_string_pretty(&raw.0).unwrap() - ); - return Err(StatusCode::UNPROCESSABLE_ENTITY); - } - }; + let session_id = request.session_id.clone(); + let extension_request = request.config; // If this is a Stdio extension that uses npx, check for Node.js installation #[cfg(target_os = "windows")] - if let ExtensionConfigRequest::Stdio { cmd, .. } = &request { + if let ExtensionConfigRequest::Stdio { cmd, .. } = &extension_request { if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { // Check if Node.js is installed in standard locations let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() @@ -175,7 +173,7 @@ async fn add_extension( } // Construct ExtensionConfig with Envs populated from keyring based on provided env_keys. - let extension_config: ExtensionConfig = match request { + let extension_config: ExtensionConfig = match extension_request { ExtensionConfigRequest::Sse { name, uri, @@ -267,7 +265,7 @@ async fn add_extension( }, }; - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(session_id).await?; let response = agent.add_extension(extension_config).await; // Respond with the result. @@ -289,13 +287,20 @@ async fn add_extension( } } +#[derive(Deserialize)] +struct RemoveExtensionRequest { + name: String, + session_id: String, +} + /// Handler for removing an extension by name async fn remove_extension( State(state): State>, - Json(name): Json, + Json(request): Json, ) -> Result, StatusCode> { - let agent = state.get_agent().await; - match agent.remove_extension(&name).await { + let agent = state.get_agent_for_route(request.session_id).await?; + + match agent.remove_extension(&request.name).await { Ok(_) => Ok(Json(ExtensionResponse { error: false, message: None, diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index 5da49941113d..96268a9bdf4e 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -25,6 +25,7 @@ pub struct CreateRecipeRequest { activities: Option>, #[serde(default)] author: Option, + session_id: String, } #[derive(Debug, Deserialize, ToSchema)] @@ -108,13 +109,13 @@ pub struct ListRecipeResponse { async fn create_recipe( State(state): State>, Json(request): Json, -) -> Result, (StatusCode, Json)> { +) -> Result, StatusCode> { tracing::info!( "Recipe creation request received with {} messages", request.messages.len() ); - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(request.session_id).await?; // Create base recipe from agent state and messages let recipe_result = agent @@ -143,12 +144,7 @@ async fn create_recipe( } Err(e) => { tracing::error!("Error details: {:?}", e); - let error_message = format!("Recipe creation failed: {}", e); - let error_response = CreateRecipeResponse { - recipe: None, - error: Some(error_message), - }; - Err((StatusCode::BAD_REQUEST, Json(error_response))) + Err(StatusCode::BAD_REQUEST) } } } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 68db8328c2ab..e68aef92837d 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -10,6 +10,7 @@ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; +use goose::execution::SessionExecutionMode; use goose::{ agents::{AgentEvent, SessionConfig}, permission::permission_confirmation::PrincipalType, @@ -86,7 +87,7 @@ fn track_tool_telemetry(content: &MessageContent, all_messages: &[Message]) { #[derive(Debug, Deserialize, Serialize)] struct ChatRequest { messages: Vec, - session_id: Option, + session_id: String, recipe_name: Option, recipe_version: Option, } @@ -178,9 +179,9 @@ async fn reply_handler( "Session started" ); - if let (Some(recipe_name), Some(session_id)) = - (request.recipe_name.clone(), request.session_id.clone()) - { + let session_id = request.session_id.clone(); + + if let Some(recipe_name) = request.recipe_name.clone() { if state.mark_recipe_run_if_absent(&session_id).await { let recipe_version = request .recipe_version @@ -204,16 +205,28 @@ async fn reply_handler( let messages = Conversation::new_unvalidated(request.messages); - let session_id = request.session_id.ok_or_else(|| { - tracing::error!("session_id is required but was not provided"); - StatusCode::BAD_REQUEST - })?; - let task_cancel = cancel_token.clone(); let task_tx = tx.clone(); drop(tokio::spawn(async move { - let agent = state.get_agent().await; + let agent = match state + .get_agent(session_id.clone(), SessionExecutionMode::Interactive) + .await + { + Ok(agent) => agent, + Err(e) => { + tracing::error!("Failed to get session agent: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: format!("Failed to get session agent: {}", e), + }, + &task_tx, + &task_cancel, + ) + .await; + return; + } + }; // Load session metadata to get the working directory and other config let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { @@ -453,7 +466,6 @@ pub struct PermissionConfirmationRequest { #[serde(default = "default_principal_type")] principal_type: PrincipalType, action: String, - #[allow(dead_code)] session_id: String, } @@ -475,7 +487,7 @@ pub async fn confirm_permission( State(state): State>, Json(request): Json, ) -> Result, StatusCode> { - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(request.session_id).await?; let permission = match request.action.as_str() { "always_allow" => Permission::AlwaysAllow, "allow_once" => Permission::AllowOnce, @@ -499,7 +511,6 @@ pub async fn confirm_permission( struct ToolResultRequest { id: String, result: ToolResult>, - #[allow(dead_code)] session_id: String, } @@ -524,7 +535,7 @@ async fn submit_tool_result( } }; - let agent = state.get_agent().await; + let agent = state.get_agent_for_route(payload.session_id).await?; agent.handle_tool_result(payload.id, payload.result).await; Ok(Json(json!({"status": "ok"}))) } @@ -548,7 +559,6 @@ mod tests { use super::*; use goose::conversation::message::Message; use goose::{ - agents::Agent, model::ModelConfig, providers::{ base::{Provider, ProviderUsage, Usage}, @@ -589,18 +599,17 @@ mod tests { use super::*; use axum::{body::Body, http::Request}; use goose::conversation::message::Message; - use std::sync::Arc; + use serde_json::json; use tower::ServiceExt; - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn test_reply_endpoint() { let mock_model_config = ModelConfig::new("test-model").unwrap(); - let mock_provider = Arc::new(MockProvider { + let mock_provider = MockProvider { model_config: mock_model_config, - }); - let agent = Agent::new(); - let _ = agent.update_provider(mock_provider).await; - let state = AppState::new(Arc::new(agent)); + }; + + let state = AppState::new().await.unwrap(); let app = routes(state); @@ -612,7 +621,7 @@ mod tests { .body(Body::from( serde_json::to_string(&ChatRequest { messages: vec![Message::user().with_text("test message")], - session_id: Some("test-session".to_string()), + session_id: "test-session".to_string(), recipe_name: None, recipe_version: None, }) diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index c8d987c7431c..9109d855d0dc 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -60,6 +60,7 @@ pub struct SessionInsights { } #[derive(Serialize, ToSchema, Debug)] +#[allow(dead_code)] #[serde(rename_all = "camelCase")] pub struct ActivityHeatmapCell { pub week: usize, diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index ba36dafdf2d4..8d5f5d24cb06 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,18 +1,15 @@ -use goose::agents::Agent; +use axum::http::StatusCode; +use goose::execution::manager::AgentManager; +use goose::execution::SessionExecutionMode; use goose::scheduler_trait::SchedulerTrait; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::sync::RwLock; - -type AgentRef = Arc; - #[derive(Clone)] pub struct AppState { - agent: Arc>, - pub scheduler: Arc>>>, + pub(crate) agent_manager: Arc, pub recipe_file_hash_map: Arc>>, pub session_counter: Arc, /// Tracks sessions that have already emitted recipe telemetry to prevent double counting. @@ -20,31 +17,18 @@ pub struct AppState { } impl AppState { - pub fn new(agent: AgentRef) -> Arc { - Arc::new(Self { - agent: Arc::new(RwLock::new(agent)), - scheduler: Arc::new(RwLock::new(None)), + pub async fn new() -> anyhow::Result> { + let agent_manager = Arc::new(AgentManager::new(None).await?); + Ok(Arc::new(Self { + agent_manager, recipe_file_hash_map: Arc::new(Mutex::new(HashMap::new())), session_counter: Arc::new(AtomicUsize::new(0)), recipe_session_tracker: Arc::new(Mutex::new(HashSet::new())), - }) - } - - pub async fn get_agent(&self) -> AgentRef { - self.agent.read().await.clone() - } - - pub async fn set_scheduler(&self, sched: Arc) { - let mut guard = self.scheduler.write().await; - *guard = Some(sched); + })) } pub async fn scheduler(&self) -> Result, anyhow::Error> { - self.scheduler - .read() - .await - .clone() - .ok_or_else(|| anyhow::anyhow!("Scheduler not initialized")) + self.agent_manager.scheduler().await } pub async fn set_recipe_file_hash_map(&self, hash_map: HashMap) { @@ -52,41 +36,6 @@ impl AppState { *map = hash_map; } - pub async fn reset(&self) { - let mut agent = self.agent.write().await; - let new_agent = Agent::new(); - - // Only initialize provider when running in standalone goosed mode - // This prevents breaking the Electron app which manages its own provider setup - if std::env::var("GOOSE_STANDALONE_MODE").unwrap_or_else(|_| "false".to_string()) == "true" - { - tracing::info!("Running in standalone mode - initializing provider"); - - let config = goose::config::Config::global(); - - let provider_name: String = config - .get_param("GOOSE_PROVIDER") - .expect("No provider configured. Run 'goose configure' first"); - - let model_name: String = config - .get_param("GOOSE_MODEL") - .expect("No model configured. Run 'goose configure' first"); - - let model_config = goose::model::ModelConfig::new(&model_name) - .expect("Failed to create model configuration"); - - let provider = goose::providers::create(&provider_name, model_config) - .expect("Failed to create provider"); - - new_agent - .update_provider(provider) - .await - .expect("Failed to update agent provider"); - } - - *agent = Arc::new(new_agent); - } - pub async fn mark_recipe_run_if_absent(&self, session_id: &str) -> bool { let mut sessions = self.recipe_session_tracker.lock().await; if sessions.contains(session_id) { @@ -96,4 +45,27 @@ impl AppState { true } } + + pub async fn get_agent( + &self, + session_id: String, + mode: SessionExecutionMode, + ) -> anyhow::Result> { + self.agent_manager + .get_or_create_agent(session_id, mode) + .await + } + + /// Get agent for route handlers - always uses Interactive mode and converts any error to 500 + pub async fn get_agent_for_route( + &self, + session_id: String, + ) -> Result, StatusCode> { + self.get_agent(session_id, SessionExecutionMode::Interactive) + .await + .map_err(|e| { + tracing::error!("Failed to get agent: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + }) + } } diff --git a/crates/goose-server/tests/pricing_api_test.rs b/crates/goose-server/tests/pricing_api_test.rs deleted file mode 100644 index 456710c2fa53..000000000000 --- a/crates/goose-server/tests/pricing_api_test.rs +++ /dev/null @@ -1,41 +0,0 @@ -use axum::http::StatusCode; -use axum::Router; -use axum::{body::Body, http::Request}; -use etcetera::AppStrategy; -use serde_json::json; -use std::sync::Arc; -use tower::ServiceExt; - -async fn create_test_app() -> Router { - let agent = Arc::new(goose::agents::Agent::default()); - let state = goose_server::AppState::new(agent); - - // Add scheduler setup like in the existing tests - let sched_storage_path = etcetera::choose_app_strategy(goose::config::APP_STRATEGY.clone()) - .unwrap() - .data_dir() - .join("schedules.json"); - let sched = goose::scheduler_factory::SchedulerFactory::create_legacy(sched_storage_path) - .await - .unwrap(); - state.set_scheduler(sched).await; - - goose_server::routes::config_management::routes(state) -} - -#[tokio::test] -async fn test_pricing_endpoint_basic() { - // Basic test to ensure pricing endpoint responds correctly - let app = create_test_app().await; - - let request = Request::builder() - .uri("/config/pricing") - .method("POST") - .header("content-type", "application/json") - .header("x-secret-key", "test") - .body(Body::from(json!({"configured_only": true}).to_string())) - .unwrap(); - - let response = app.oneshot(request).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); -} diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index e96e3e30e8b9..e58781393c72 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -15,6 +15,7 @@ tokio = { version = "1.43", features = ["full"] } reqwest = { version = "0.12.9", features = ["json", "rustls-tls-native-roots"], default-features = false } [dependencies] +lru = "0.12" mcp-client = { path = "../mcp-client" } mcp-core = { path = "../mcp-core" } rmcp = { workspace = true, features = [ diff --git a/crates/goose/src/agents/context.rs b/crates/goose/src/agents/context.rs index 8b594bbc0ab5..2a00b76da716 100644 --- a/crates/goose/src/agents/context.rs +++ b/crates/goose/src/agents/context.rs @@ -33,7 +33,7 @@ impl Agent { // Only add an assistant message if we have room for it and it won't cause another overflow let assistant_message = Message::assistant().with_text("I had run into a context length exceeded error so I truncated some of the oldest messages in our conversation."); let assistant_tokens = - token_counter.count_chat_tokens("", &[assistant_message.clone()], &[]); + token_counter.count_chat_tokens("", std::slice::from_ref(&assistant_message), &[]); let current_total: usize = new_token_counts.iter().sum(); if current_total + assistant_tokens <= target_context_limit { diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index daa00a96eda4..690b506d2c31 100644 --- a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -45,7 +45,7 @@ pub async fn execute_single_task( .await; let execution_time = start_time.elapsed().as_millis(); - let stats = calculate_stats(&[result.clone()], execution_time); + let stats = calculate_stats(std::slice::from_ref(&result), execution_time); ExecutionResponse { status: EXECUTION_STATUS_COMPLETED.to_string(), diff --git a/crates/goose/src/conversation/mod.rs b/crates/goose/src/conversation/mod.rs index f1dbfa1e75b8..4b7588acacf8 100644 --- a/crates/goose/src/conversation/mod.rs +++ b/crates/goose/src/conversation/mod.rs @@ -86,7 +86,7 @@ impl Conversation { } } - pub fn iter(&self) -> std::slice::Iter { + pub fn iter(&self) -> std::slice::Iter<'_, Message> { self.0.iter() } diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs new file mode 100644 index 000000000000..3ef38237e0d5 --- /dev/null +++ b/crates/goose/src/execution/manager.rs @@ -0,0 +1,150 @@ +//! Agent lifecycle management with session isolation + +use super::SessionExecutionMode; +use crate::agents::Agent; +use crate::config::APP_STRATEGY; +use crate::model::ModelConfig; +use crate::providers::create; +use crate::scheduler_factory::SchedulerFactory; +use crate::scheduler_trait::SchedulerTrait; +use anyhow::Result; +use etcetera::{choose_app_strategy, AppStrategy}; +use lru::LruCache; +use std::num::NonZeroUsize; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; + +pub struct AgentManager { + sessions: Arc>>>, + scheduler: Arc, + default_provider: Arc>>>, +} + +impl AgentManager { + pub async fn new(max_sessions: Option) -> Result { + // Construct scheduler with the standard goose-server path + let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())? + .data_dir() + .join("schedule.json"); + + let scheduler = SchedulerFactory::create(schedule_file_path).await?; + + let capacity = NonZeroUsize::new(max_sessions.unwrap_or(100)) + .unwrap_or_else(|| NonZeroUsize::new(100).unwrap()); + + let manager = Self { + sessions: Arc::new(RwLock::new(LruCache::new(capacity))), + scheduler, + default_provider: Arc::new(RwLock::new(None)), + }; + + let _ = manager.configure_default_provider().await; + + Ok(manager) + } + + pub async fn scheduler(&self) -> Result> { + Ok(Arc::clone(&self.scheduler)) + } + + pub async fn set_default_provider(&self, provider: Arc) { + debug!("Setting default provider on AgentManager"); + *self.default_provider.write().await = Some(provider); + } + + pub async fn configure_default_provider(&self) -> Result<()> { + let provider_name = std::env::var("GOOSE_DEFAULT_PROVIDER") + .or_else(|_| std::env::var("GOOSE_PROVIDER__TYPE")) + .ok(); + + let model_name = std::env::var("GOOSE_DEFAULT_MODEL") + .or_else(|_| std::env::var("GOOSE_PROVIDER__MODEL")) + .ok(); + + if provider_name.is_none() || model_name.is_none() { + return Ok(()); + } + + if let (Some(provider_name), Some(model_name)) = (provider_name, model_name) { + match ModelConfig::new(&model_name) { + Ok(model_config) => match create(&provider_name, model_config) { + Ok(provider) => { + self.set_default_provider(provider).await; + info!( + "Configured default provider: {} with model: {}", + provider_name, model_name + ); + } + Err(e) => { + warn!("Failed to create default provider {}: {}", provider_name, e) + } + }, + Err(e) => warn!("Failed to create model config for {}: {}", model_name, e), + } + } + Ok(()) + } + + pub async fn get_or_create_agent( + &self, + session_id: String, + mode: SessionExecutionMode, + ) -> Result> { + let agent = { + let mut sessions = self.sessions.write().await; + if let Some(agent) = sessions.get(&session_id) { + debug!("Found existing agent for session {}", session_id); + return Ok(Arc::clone(agent)); + } + + info!( + "Creating new agent for session {} with mode {}", + session_id, mode + ); + let agent = Arc::new(Agent::new()); + sessions.put(session_id.clone(), Arc::clone(&agent)); + agent + }; + + match &mode { + SessionExecutionMode::Interactive | SessionExecutionMode::Background => { + debug!("Setting scheduler on agent for session {}", session_id); + agent.set_scheduler(Arc::clone(&self.scheduler)).await; + } + SessionExecutionMode::SubTask { .. } => { + debug!( + "SubTask mode for session {}, skipping scheduler setup", + session_id + ); + } + } + + if let Some(provider) = &*self.default_provider.read().await { + debug!( + "Setting default provider on agent for session {}", + session_id + ); + let _ = agent.update_provider(Arc::clone(provider)).await; + } + + Ok(agent) + } + + pub async fn remove_session(&self, session_id: &str) -> Result<()> { + let mut sessions = self.sessions.write().await; + sessions + .pop(session_id) + .ok_or_else(|| anyhow::anyhow!("Session {} not found", session_id))?; + info!("Removed session {}", session_id); + Ok(()) + } + + pub async fn has_session(&self, session_id: &str) -> bool { + self.sessions.read().await.contains(session_id) + } + + pub async fn session_count(&self) -> usize { + self.sessions.read().await.len() + } +} diff --git a/crates/goose/src/execution/mod.rs b/crates/goose/src/execution/mod.rs new file mode 100644 index 000000000000..beff687ab8d2 --- /dev/null +++ b/crates/goose/src/execution/mod.rs @@ -0,0 +1,45 @@ +//! Unified execution management for Goose agents +//! +//! This module provides centralized agent lifecycle management with session isolation, +//! enabling multiple concurrent sessions with independent agents, extensions, and providers. + +pub mod manager; + +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum SessionExecutionMode { + Interactive, + Background, + SubTask { parent_session: String }, +} + +impl SessionExecutionMode { + /// Create an interactive chat mode + pub fn chat() -> Self { + Self::Interactive + } + + /// Create a background/scheduled mode + pub fn scheduled() -> Self { + Self::Background + } + + /// Create a sub-task mode with parent reference + pub fn task(parent: String) -> Self { + Self::SubTask { + parent_session: parent, + } + } +} + +impl fmt::Display for SessionExecutionMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Interactive => write!(f, "interactive"), + Self::Background => write!(f, "background"), + Self::SubTask { parent_session } => write!(f, "subtask(parent: {})", parent_session), + } + } +} diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 7e7234a6075e..99ffe352a8d2 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -2,6 +2,7 @@ pub mod agents; pub mod config; pub mod context_mgmt; pub mod conversation; +pub mod execution; pub mod logging; pub mod model; pub mod oauth; diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 5cd6dc9e14f8..f804aa7f0a89 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -144,7 +144,11 @@ pub async fn detect_read_only_tools( .unwrap_or_else(|_| "You are a good analyst and can detect operations whether they have read-only operations.".to_string()); let res = provider - .complete(&system_prompt, check_messages.messages(), &[tool.clone()]) + .complete( + &system_prompt, + check_messages.messages(), + std::slice::from_ref(&tool), + ) .await; // Process the response and return an empty vector if the response is invalid diff --git a/crates/goose/src/recipe/template_recipe.rs b/crates/goose/src/recipe/template_recipe.rs index 7396bb99f5af..3190195b2b54 100644 --- a/crates/goose/src/recipe/template_recipe.rs +++ b/crates/goose/src/recipe/template_recipe.rs @@ -112,7 +112,7 @@ fn add_template_in_env( content: &str, recipe_dir: String, undefined_behavior: UndefinedBehavior, -) -> Result { +) -> Result> { let mut env = minijinja::Environment::new(); env.set_undefined_behavior(undefined_behavior); env.set_loader(move |name| { @@ -136,7 +136,7 @@ fn get_env_with_template_variables( content: &str, recipe_dir: String, undefined_behavior: UndefinedBehavior, -) -> Result<(Environment, HashSet)> { +) -> Result<(Environment<'_>, HashSet)> { let env = add_template_in_env(content, recipe_dir, undefined_behavior)?; let template = env.get_template(CURRENT_TEMPLATE_NAME).unwrap(); let state = template.eval_to_state(())?; diff --git a/crates/goose/tests/execution_tests.rs b/crates/goose/tests/execution_tests.rs new file mode 100644 index 000000000000..2ba8fbce8152 --- /dev/null +++ b/crates/goose/tests/execution_tests.rs @@ -0,0 +1,323 @@ +mod execution_tests { + use goose::execution::manager::AgentManager; + use goose::execution::SessionExecutionMode; + use serial_test::serial; + use std::sync::Arc; + + #[test] + fn test_execution_mode_constructors() { + assert_eq!( + SessionExecutionMode::chat(), + SessionExecutionMode::Interactive + ); + assert_eq!( + SessionExecutionMode::scheduled(), + SessionExecutionMode::Background + ); + + let parent = "parent-123".to_string(); + assert_eq!( + SessionExecutionMode::task(parent.clone()), + SessionExecutionMode::SubTask { + parent_session: parent + } + ); + } + + #[tokio::test] + async fn test_session_isolation() { + let manager = AgentManager::new(None).await.unwrap(); + + let session1 = uuid::Uuid::new_v4().to_string(); + let session2 = uuid::Uuid::new_v4().to_string(); + + let agent1 = manager + .get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + let agent2 = manager + .get_or_create_agent(session2.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + // Different sessions should have different agents + assert!(!Arc::ptr_eq(&agent1, &agent2)); + + // Getting the same session should return the same agent + let agent1_again = manager + .get_or_create_agent(session1, SessionExecutionMode::chat()) + .await + .unwrap(); + + assert!(Arc::ptr_eq(&agent1, &agent1_again)); + } + + #[tokio::test] + async fn test_session_limit() { + let manager = AgentManager::new(Some(3)).await.unwrap(); + + let sessions: Vec<_> = (0..3) + .map(|i| String::from(format!("session-{}", i))) + .collect(); + + for session in &sessions { + manager + .get_or_create_agent(session.clone(), SessionExecutionMode::chat()) + .await + .unwrap(); + } + + // Create a new session after cleanup + let new_session = "new-session".to_string(); + let _new_agent = manager + .get_or_create_agent(new_session, SessionExecutionMode::chat()) + .await + .unwrap(); + + assert_eq!(manager.session_count().await, 3); + assert!(!manager.has_session(&sessions[0]).await); + } + + #[tokio::test] + async fn test_remove_session() { + let manager = AgentManager::new(None).await.unwrap(); + let session = String::from("remove-test"); + + manager + .get_or_create_agent(session.clone(), SessionExecutionMode::chat()) + .await + .unwrap(); + assert!(manager.has_session(&session).await); + + manager.remove_session(&session).await.unwrap(); + assert!(!manager.has_session(&session).await); + + assert!(manager.remove_session(&session).await.is_err()); + } + + #[tokio::test] + async fn test_concurrent_access() { + let manager = Arc::new(AgentManager::new(None).await.unwrap()); + let session = String::from("concurrent-test"); + + let mut handles = vec![]; + for _ in 0..10 { + let mgr = Arc::clone(&manager); + let sess = session.clone(); + handles.push(tokio::spawn(async move { + mgr.get_or_create_agent(sess, SessionExecutionMode::chat()) + .await + .unwrap() + })); + } + + let agents: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + for agent in &agents[1..] { + assert!(Arc::ptr_eq(&agents[0], agent)); + } + + assert_eq!(manager.session_count().await, 1); + } + + #[tokio::test] + async fn test_different_modes_same_session() { + let manager = AgentManager::new(None).await.unwrap(); + let session_id = String::from("mode-test"); + + // Create initial agent + let agent1 = manager + .get_or_create_agent(session_id.clone(), SessionExecutionMode::chat()) + .await + .unwrap(); + + // Get same session with different mode - should return same agent + // (mode is stored but agent is reused) + let agent2 = manager + .get_or_create_agent(session_id.clone(), SessionExecutionMode::Background) + .await + .unwrap(); + + assert!(Arc::ptr_eq(&agent1, &agent2)); + } + + #[tokio::test] + async fn test_concurrent_session_creation_race_condition() { + // Test that concurrent attempts to create the same new session ID + // result in only one agent being created (tests double-check pattern) + let manager = Arc::new(AgentManager::new(None).await.unwrap()); + let session_id = String::from("race-condition-test"); + + // Spawn multiple tasks trying to create the same NEW session simultaneously + let mut handles = vec![]; + for _ in 0..20 { + let sess = session_id.clone(); + let mgr_clone = Arc::clone(&manager); + handles.push(tokio::spawn(async move { + mgr_clone + .get_or_create_agent(sess, SessionExecutionMode::Interactive) + .await + .unwrap() + })); + } + + // Collect all agents + let agents: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // All should be the same agent (double-check pattern should prevent duplicates) + for agent in &agents[1..] { + assert!( + Arc::ptr_eq(&agents[0], agent), + "All concurrent requests should get the same agent" + ); + } + + // Only one session should exist + assert_eq!(manager.session_count().await, 1); + } + + #[tokio::test] + async fn test_edge_case_max_sessions_one() { + let manager = AgentManager::new(Some(1)).await.unwrap(); + + let session1 = String::from("only-session"); + manager + .get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + assert_eq!(manager.session_count().await, 1); + + // Creating second session should evict the first + let session2 = String::from("new-session"); + manager + .get_or_create_agent(session2.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + assert!(!manager.has_session(&session1).await); + assert!(manager.has_session(&session2).await); + assert_eq!(manager.session_count().await, 1); + } + + #[tokio::test] + #[serial] + async fn test_configure_default_provider() { + use std::env; + + let original_provider = env::var("GOOSE_DEFAULT_PROVIDER").ok(); + let original_model = env::var("GOOSE_DEFAULT_MODEL").ok(); + + env::set_var("GOOSE_DEFAULT_PROVIDER", "openai"); + env::set_var("GOOSE_DEFAULT_MODEL", "gpt-4o-mini"); + + let manager = AgentManager::new(None).await.unwrap(); + let result = manager.configure_default_provider().await; + + assert!(result.is_ok()); + + // Restore original env vars + if let Some(val) = original_provider { + env::set_var("GOOSE_DEFAULT_PROVIDER", val); + } else { + env::remove_var("GOOSE_DEFAULT_PROVIDER"); + } + if let Some(val) = original_model { + env::set_var("GOOSE_DEFAULT_MODEL", val); + } else { + env::remove_var("GOOSE_DEFAULT_MODEL"); + } + } + + #[tokio::test] + async fn test_set_default_provider() { + use goose::providers::testprovider::TestProvider; + use std::sync::Arc; + + let manager = AgentManager::new(None).await.unwrap(); + + // Create a test provider for replaying (doesn't need inner provider) + let temp_file = format!( + "{}/test_provider_{}.json", + std::env::temp_dir().display(), + std::process::id() + ); + + // Create an empty test provider (will fail on actual use but that's ok for this test) + let test_provider = TestProvider::new_replaying(&temp_file) + .unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap()); + + manager.set_default_provider(Arc::new(test_provider)).await; + + let session = String::from("provider-test"); + let _agent = manager + .get_or_create_agent(session.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + assert!(manager.has_session(&session).await); + } + + #[tokio::test] + async fn test_eviction_updates_last_used() { + // Test that accessing a session updates its last_used timestamp + // and affects eviction order + let manager = AgentManager::new(Some(2)).await.unwrap(); + + let session1 = String::from("session-1"); + let session2 = String::from("session-2"); + + manager + .get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + // Small delay to ensure different timestamps + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + manager + .get_or_create_agent(session2.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + // Access session1 again to update its last_used + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + manager + .get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + // Now create a third session - should evict session2 (least recently used) + let session3 = String::from("session-3"); + manager + .get_or_create_agent(session3.clone(), SessionExecutionMode::Interactive) + .await + .unwrap(); + + // session1 should still exist (recently accessed) + // session2 should be evicted (least recently used) + assert!(manager.has_session(&session1).await); + assert!(!manager.has_session(&session2).await); + assert!(manager.has_session(&session3).await); + } + + #[tokio::test] + async fn test_remove_nonexistent_session_error() { + // Test that removing a non-existent session returns an error + let manager = AgentManager::new(None).await.unwrap(); + let session = String::from("never-created"); + + let result = manager.remove_session(&session).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("not found")); + } +} diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index c0cf1ca5d35d..2fc1bdee0b28 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -137,8 +137,8 @@ impl ProviderTester { .provider .complete( "You are a helpful weather assistant.", - &[message.clone()], - &[weather_tool.clone()], + std::slice::from_ref(&message), + std::slice::from_ref(&weather_tool), ) .await?; diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index fd2b4b835e3f..69c0ba28323e 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -1769,7 +1769,8 @@ "description": "Request payload for context management operations", "required": [ "messages", - "manageAction" + "manageAction", + "sessionId" ], "properties": { "manageAction": { @@ -1782,6 +1783,10 @@ "$ref": "#/components/schemas/Message" }, "description": "Collection of messages to be managed" + }, + "sessionId": { + "type": "string", + "description": "Optional session ID for session-specific agent" } } }, @@ -1849,7 +1854,8 @@ "required": [ "messages", "title", - "description" + "description", + "session_id" ], "properties": { "activities": { @@ -1876,6 +1882,9 @@ "$ref": "#/components/schemas/Message" } }, + "session_id": { + "type": "string" + }, "title": { "type": "string" } diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index bf6a30e182c6..f05649ef8f7f 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -83,6 +83,10 @@ export type ContextManageRequest = { * Collection of messages to be managed */ messages: Array; + /** + * Optional session ID for session-specific agent + */ + sessionId: string; }; /** @@ -113,6 +117,7 @@ export type CreateRecipeRequest = { author?: AuthorRequest | null; description: string; messages: Array; + session_id: string; title: string; }; diff --git a/ui/desktop/src/components/extensions/ExtensionsView.tsx b/ui/desktop/src/components/extensions/ExtensionsView.tsx index 1e2c52fd4863..8907435dc1d2 100644 --- a/ui/desktop/src/components/extensions/ExtensionsView.tsx +++ b/ui/desktop/src/components/extensions/ExtensionsView.tsx @@ -1,4 +1,5 @@ import { View, ViewOptions } from '../../utils/navigationUtils'; +import { useChatContext } from '../../contexts/ChatContext'; import ExtensionsSection from '../settings/extensions/ExtensionsSection'; import { ExtensionConfig } from '../../api'; import { MainPanelLayout } from '../Layout/MainPanelLayout'; @@ -30,6 +31,12 @@ export default function ExtensionsView({ const [isAddModalOpen, setIsAddModalOpen] = useState(false); const [refreshKey, setRefreshKey] = useState(0); const { addExtension } = useConfig(); + const chatContext = useChatContext(); + const sessionId = chatContext?.chat.sessionId || ''; + + if (!sessionId) { + console.error('ExtensionsView: No session ID available'); + } // Trigger refresh when deep link config changes (i.e., when a deep link is processed) useEffect(() => { @@ -46,9 +53,20 @@ export default function ExtensionsView({ // Close the modal immediately handleModalClose(); + if (!sessionId) { + console.warn('Cannot activate extension without session'); + setRefreshKey((prevKey) => prevKey + 1); + return; + } + const extensionConfig = createExtensionConfig(formData); + try { - await activateExtension({ addToConfig: addExtension, extensionConfig: extensionConfig }); + await activateExtension({ + addToConfig: addExtension, + extensionConfig: extensionConfig, + sessionId: sessionId, + }); // Trigger a refresh of the extensions list setRefreshKey((prevKey) => prevKey + 1); } catch (error) { @@ -97,6 +115,7 @@ export default function ExtensionsView({
{ }; mockFetch.mockResolvedValue(mockResponse); - const response = await extensionApiCall('/extensions/add', mockExtensionConfig); + const response = await extensionApiCall( + '/extensions/add', + mockExtensionConfig, + {}, + 'test-session' + ); expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/extensions/add', { method: 'POST', @@ -69,7 +74,7 @@ describe('Agent API', () => { 'Content-Type': 'application/json', 'X-Secret-Key': 'secret-key', }, - body: JSON.stringify(mockExtensionConfig), + body: JSON.stringify({ ...mockExtensionConfig, session_id: 'test-session' }), }); expect(mockToastService.loading).toHaveBeenCalledWith({ @@ -92,7 +97,12 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - const response = await extensionApiCall('/extensions/remove', 'test-extension'); + const response = await extensionApiCall( + '/extensions/remove', + 'test-extension', + {}, + 'test-session' + ); expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/extensions/remove', { method: 'POST', @@ -100,7 +110,7 @@ describe('Agent API', () => { 'Content-Type': 'application/json', 'X-Secret-Key': 'secret-key', }, - body: JSON.stringify('test-extension'), + body: JSON.stringify({ name: 'test-extension', session_id: 'test-session' }), }); expect(mockToastService.loading).not.toHaveBeenCalled(); // No loading toast for removal @@ -120,9 +130,9 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await expect(extensionApiCall('/extensions/add', mockExtensionConfig)).rejects.toThrow( - 'Server returned 500: Internal Server Error' - ); + await expect( + extensionApiCall('/extensions/add', mockExtensionConfig, {}, 'test-session') + ).rejects.toThrow('Server returned 500: Internal Server Error'); expect(mockToastService.error).toHaveBeenCalledWith({ title: 'test-extension', @@ -139,9 +149,9 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await expect(extensionApiCall('/extensions/add', mockExtensionConfig)).rejects.toThrow( - 'Agent is not initialized. Please initialize the agent first.' - ); + await expect( + extensionApiCall('/extensions/add', mockExtensionConfig, {}, 'test-session') + ).rejects.toThrow('Agent is not initialized. Please initialize the agent first.'); expect(mockToastService.error).toHaveBeenCalledWith({ title: 'test-extension', @@ -157,9 +167,9 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await expect(extensionApiCall('/extensions/remove', 'test-extension')).rejects.toThrow( - 'Error deactivating extension: Extension not found' - ); + await expect( + extensionApiCall('/extensions/remove', 'test-extension', {}, 'test-session') + ).rejects.toThrow('Error deactivating extension: Extension not found'); expect(mockToastService.error).toHaveBeenCalledWith({ title: 'test-extension', @@ -175,7 +185,12 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - const response = await extensionApiCall('/extensions/add', mockExtensionConfig); + const response = await extensionApiCall( + '/extensions/add', + mockExtensionConfig, + {}, + 'test-session' + ); expect(mockToastService.success).toHaveBeenCalledWith({ title: 'test-extension', @@ -189,9 +204,9 @@ describe('Agent API', () => { const networkError = new Error('Network error'); mockFetch.mockRejectedValue(networkError); - await expect(extensionApiCall('/extensions/add', mockExtensionConfig)).rejects.toThrow( - 'Network error' - ); + await expect( + extensionApiCall('/extensions/add', mockExtensionConfig, {}, 'test-session') + ).rejects.toThrow('Network error'); expect(mockToastService.error).toHaveBeenCalledWith({ title: 'test-extension', @@ -207,7 +222,12 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await extensionApiCall('/extensions/add', mockExtensionConfig, { silent: true }); + await extensionApiCall( + '/extensions/add', + mockExtensionConfig, + { silent: true }, + 'test-session' + ); expect(mockToastService.configure).toHaveBeenCalledWith({ silent: true }); }); @@ -232,7 +252,7 @@ describe('Agent API', () => { const { replaceWithShims } = await import('./utils'); vi.mocked(replaceWithShims).mockResolvedValue('/path/to/python'); - await addToAgent(mockExtensionConfig); + await addToAgent(mockExtensionConfig, {}, 'test-session'); expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/extensions/add', { method: 'POST', @@ -244,6 +264,7 @@ describe('Agent API', () => { ...mockExtensionConfig, name: 'testextension', cmd: '/path/to/python', + session_id: 'test-session', }), }); }); @@ -256,7 +277,7 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await expect(addToAgent(mockExtensionConfig)).rejects.toThrow( + await expect(addToAgent(mockExtensionConfig, {}, 'test-session')).rejects.toThrow( 'Agent is not initialized. Please initialize the agent first.' ); }); @@ -274,7 +295,7 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await addToAgent(sseConfig); + await addToAgent(sseConfig, {}, 'test-session'); expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/extensions/add', { method: 'POST', @@ -285,6 +306,7 @@ describe('Agent API', () => { body: JSON.stringify({ ...sseConfig, name: 'sseextension', + session_id: 'test-session', }), }); }); @@ -298,7 +320,7 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await removeFromAgent('Test Extension'); + await removeFromAgent('Test Extension', {}, 'test-session'); expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/extensions/remove', { method: 'POST', @@ -306,7 +328,7 @@ describe('Agent API', () => { 'Content-Type': 'application/json', 'X-Secret-Key': 'secret-key', }, - body: JSON.stringify('testextension'), + body: JSON.stringify({ name: 'testextension', session_id: 'test-session' }), }); }); @@ -318,7 +340,7 @@ describe('Agent API', () => { }; mockFetch.mockResolvedValue(mockResponse); - await expect(removeFromAgent('Test Extension')).rejects.toThrow(); + await expect(removeFromAgent('Test Extension', {}, 'test-session')).rejects.toThrow(); expect(mockToastService.error).toHaveBeenCalled(); }); diff --git a/ui/desktop/src/components/settings/extensions/agent-api.ts b/ui/desktop/src/components/settings/extensions/agent-api.ts index a807e140a53b..bc229a597e25 100644 --- a/ui/desktop/src/components/settings/extensions/agent-api.ts +++ b/ui/desktop/src/components/settings/extensions/agent-api.ts @@ -14,7 +14,8 @@ interface ApiResponse { export async function extensionApiCall( endpoint: string, payload: ExtensionConfig | string, - options: ToastServiceOptions & { isDelete?: boolean } = {} + options: ToastServiceOptions & { isDelete?: boolean } = {}, + sessionId: string ): Promise { // Configure toast notifications toastService.configure(options); @@ -43,6 +44,16 @@ export async function extensionApiCall( } try { + // Build the request body + let requestBody: ExtensionConfig | { name: string; session_id: string }; + if (typeof payload === 'object') { + // For adding extensions (ExtensionConfig) + requestBody = { ...payload, session_id: sessionId }; + } else { + // For removing extensions (just the name string) + requestBody = { name: payload, session_id: sessionId }; + } + // Step 2: Make the API call const response = await fetch(getApiUrl(endpoint), { method: 'POST', @@ -50,7 +61,7 @@ export async function extensionApiCall( 'Content-Type': 'application/json', 'X-Secret-Key': await window.electron.getSecretKey(), }, - body: JSON.stringify(payload), + body: JSON.stringify(requestBody), }); // Step 3: Handle non-successful responses @@ -142,7 +153,8 @@ async function parseResponseData(response: Response): Promise { */ export async function addToAgent( extension: ExtensionConfig, - options: ToastServiceOptions = {} + options: ToastServiceOptions = {}, + sessionId: string ): Promise { try { if (extension.type === 'stdio') { @@ -151,7 +163,7 @@ export async function addToAgent( extension.name = sanitizeName(extension.name); - return await extensionApiCall('/extensions/add', extension, options); + return await extensionApiCall('/extensions/add', extension, options, sessionId); } catch (error) { // Check if this is a 428 error and make the message more descriptive if (error instanceof Error && error.message && error.message.includes('428')) { @@ -170,10 +182,11 @@ export async function addToAgent( */ export async function removeFromAgent( name: string, - options: ToastServiceOptions & { isDelete?: boolean } = {} + options: ToastServiceOptions & { isDelete?: boolean } = {}, + sessionId: string ): Promise { try { - return await extensionApiCall('/extensions/remove', sanitizeName(name), options); + return await extensionApiCall('/extensions/remove', sanitizeName(name), options, sessionId); } catch (error) { const action = options.isDelete ? 'remove' : 'deactivate'; console.error(`Failed to ${action} extension ${name} from agent:`, error); diff --git a/ui/desktop/src/components/settings/extensions/deeplink.ts b/ui/desktop/src/components/settings/extensions/deeplink.ts index 79ff7d96ea4a..ffe1f6f74dbf 100644 --- a/ui/desktop/src/components/settings/extensions/deeplink.ts +++ b/ui/desktop/src/components/settings/extensions/deeplink.ts @@ -1,6 +1,5 @@ import type { ExtensionConfig } from '../../../api'; import { toastService } from '../../../toasts'; -import { activateExtension } from './extension-manager'; import { DEFAULT_EXTENSION_TIMEOUT } from './utils'; /** @@ -180,7 +179,11 @@ export async function addExtensionFromDeepLink( try { console.log('No env vars required, activating extension directly'); - await activateExtension({ extensionConfig: config, addToConfig: addExtensionFn }); + // Note: deeplink activation doesn't have access to sessionId + // The extension will be added to config but not activated in the current session + // It will be activated when the next session starts + console.warn('Extension will be added to config but requires a session to activate'); + await addExtensionFn(config.name, config, true); } catch (error) { console.error('Failed to activate extension from deeplink:', error); throw error; diff --git a/ui/desktop/src/components/settings/extensions/extension-manager.test.ts b/ui/desktop/src/components/settings/extensions/extension-manager.test.ts index 7310eb038e5c..75a15dbcdca9 100644 --- a/ui/desktop/src/components/settings/extensions/extension-manager.test.ts +++ b/ui/desktop/src/components/settings/extensions/extension-manager.test.ts @@ -43,10 +43,15 @@ describe('Extension Manager', () => { await activateExtension({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, { silent: false }); + expect(mockAddToAgent).toHaveBeenCalledWith( + mockExtensionConfig, + { silent: false }, + 'test-session' + ); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, true); }); @@ -57,11 +62,16 @@ describe('Extension Manager', () => { await expect( activateExtension({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }) ).rejects.toThrow('Agent failed'); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, { silent: false }); + expect(mockAddToAgent).toHaveBeenCalledWith( + mockExtensionConfig, + { silent: false }, + 'test-session' + ); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, false); }); @@ -73,13 +83,18 @@ describe('Extension Manager', () => { await expect( activateExtension({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }) ).rejects.toThrow('Config failed'); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, { silent: false }); + expect(mockAddToAgent).toHaveBeenCalledWith( + mockExtensionConfig, + { silent: false }, + 'test-session' + ); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, true); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension'); + expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}, 'test-session'); }); }); @@ -89,10 +104,15 @@ describe('Extension Manager', () => { await addToAgentOnStartup({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, { silent: true }); + expect(mockAddToAgent).toHaveBeenCalledWith( + mockExtensionConfig, + { silent: true }, + 'test-session' + ); expect(mockAddToConfig).not.toHaveBeenCalled(); }); @@ -101,11 +121,16 @@ describe('Extension Manager', () => { await addToAgentOnStartup({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, toastOptions: { silent: false }, }); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, { silent: false }); + expect(mockAddToAgent).toHaveBeenCalledWith( + mockExtensionConfig, + { silent: false }, + 'test-session' + ); expect(mockAddToConfig).not.toHaveBeenCalled(); }); @@ -118,6 +143,7 @@ describe('Extension Manager', () => { await addToAgentOnStartup({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }); @@ -132,6 +158,7 @@ describe('Extension Manager', () => { await addToAgentOnStartup({ addToConfig: mockAddToConfig, + sessionId: 'test-session', extensionConfig: mockExtensionConfig, }); @@ -153,6 +180,7 @@ describe('Extension Manager', () => { await updateExtension({ enabled: true, addToConfig: mockAddToConfig, + sessionId: 'test-session', removeFromConfig: mockRemoveFromConfig, extensionConfig: mockExtensionConfig, originalName: 'test-extension', @@ -160,7 +188,8 @@ describe('Extension Manager', () => { expect(mockAddToAgent).toHaveBeenCalledWith( { ...mockExtensionConfig, name: 'test-extension' }, - { silent: true } + { silent: true }, + 'test-session' ); expect(mockAddToConfig).toHaveBeenCalledWith( 'test-extension', @@ -183,16 +212,22 @@ describe('Extension Manager', () => { await updateExtension({ enabled: true, addToConfig: mockAddToConfig, + sessionId: 'test-session', removeFromConfig: mockRemoveFromConfig, extensionConfig: { ...mockExtensionConfig, name: 'new-extension' }, originalName: 'old-extension', }); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('old-extension', { silent: true }); + expect(mockRemoveFromAgent).toHaveBeenCalledWith( + 'old-extension', + { silent: true }, + 'test-session' + ); expect(mockRemoveFromConfig).toHaveBeenCalledWith('old-extension'); expect(mockAddToAgent).toHaveBeenCalledWith( { ...mockExtensionConfig, name: 'new-extension' }, - { silent: true } + { silent: true }, + 'test-session' ); expect(mockAddToConfig).toHaveBeenCalledWith( 'new-extension', @@ -208,6 +243,7 @@ describe('Extension Manager', () => { await updateExtension({ enabled: false, addToConfig: mockAddToConfig, + sessionId: 'test-session', removeFromConfig: mockRemoveFromConfig, extensionConfig: mockExtensionConfig, originalName: 'test-extension', @@ -235,9 +271,10 @@ describe('Extension Manager', () => { toggle: 'toggleOn', extensionConfig: mockExtensionConfig, addToConfig: mockAddToConfig, + sessionId: 'test-session', }); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}); + expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}, 'test-session'); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, true); }); @@ -249,9 +286,10 @@ describe('Extension Manager', () => { toggle: 'toggleOff', extensionConfig: mockExtensionConfig, addToConfig: mockAddToConfig, + sessionId: 'test-session', }); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}); + expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}, 'test-session'); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, false); }); @@ -265,10 +303,11 @@ describe('Extension Manager', () => { toggle: 'toggleOn', extensionConfig: mockExtensionConfig, addToConfig: mockAddToConfig, + sessionId: 'test-session', }) ).rejects.toThrow('Agent failed'); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}); + expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}, 'test-session'); // addToConfig is called during the rollback (toggleOff) expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, false); }); @@ -283,12 +322,13 @@ describe('Extension Manager', () => { toggle: 'toggleOn', extensionConfig: mockExtensionConfig, addToConfig: mockAddToConfig, + sessionId: 'test-session', }) ).rejects.toThrow('Config failed'); - expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}); + expect(mockAddToAgent).toHaveBeenCalledWith(mockExtensionConfig, {}, 'test-session'); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, true); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}); + expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}, 'test-session'); }); it('should update config even if agent removal fails when toggling off', async () => { @@ -301,10 +341,11 @@ describe('Extension Manager', () => { toggle: 'toggleOff', extensionConfig: mockExtensionConfig, addToConfig: mockAddToConfig, + sessionId: 'test-session', }) ).rejects.toThrow('Agent removal failed'); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}); + expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', {}, 'test-session'); expect(mockAddToConfig).toHaveBeenCalledWith('test-extension', mockExtensionConfig, false); }); }); @@ -317,9 +358,14 @@ describe('Extension Manager', () => { await deleteExtension({ name: 'test-extension', removeFromConfig: mockRemoveFromConfig, + sessionId: 'test-session', }); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', { isDelete: true }); + expect(mockRemoveFromAgent).toHaveBeenCalledWith( + 'test-extension', + { isDelete: true }, + 'test-session' + ); expect(mockRemoveFromConfig).toHaveBeenCalledWith('test-extension'); }); @@ -332,10 +378,15 @@ describe('Extension Manager', () => { deleteExtension({ name: 'test-extension', removeFromConfig: mockRemoveFromConfig, + sessionId: 'test-session', }) ).rejects.toThrow('Agent removal failed'); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', { isDelete: true }); + expect(mockRemoveFromAgent).toHaveBeenCalledWith( + 'test-extension', + { isDelete: true }, + 'test-session' + ); expect(mockRemoveFromConfig).toHaveBeenCalledWith('test-extension'); }); @@ -349,10 +400,15 @@ describe('Extension Manager', () => { deleteExtension({ name: 'test-extension', removeFromConfig: mockRemoveFromConfig, + sessionId: 'test-session', }) ).rejects.toThrow('Config removal failed'); - expect(mockRemoveFromAgent).toHaveBeenCalledWith('test-extension', { isDelete: true }); + expect(mockRemoveFromAgent).toHaveBeenCalledWith( + 'test-extension', + { isDelete: true }, + 'test-session' + ); expect(mockRemoveFromConfig).toHaveBeenCalledWith('test-extension'); }); }); diff --git a/ui/desktop/src/components/settings/extensions/extension-manager.ts b/ui/desktop/src/components/settings/extensions/extension-manager.ts index 3a1cdb538dcf..a27693eadf5e 100644 --- a/ui/desktop/src/components/settings/extensions/extension-manager.ts +++ b/ui/desktop/src/components/settings/extensions/extension-manager.ts @@ -5,6 +5,7 @@ import { addToAgent, removeFromAgent, sanitizeName } from './agent-api'; interface ActivateExtensionProps { addToConfig: (name: string, extensionConfig: ExtensionConfig, enabled: boolean) => Promise; extensionConfig: ExtensionConfig; + sessionId: string; } type ExtensionError = { @@ -55,10 +56,11 @@ async function retryWithBackoff(fn: () => Promise, options: RetryOptions = export async function activateExtension({ addToConfig, extensionConfig, + sessionId, }: ActivateExtensionProps): Promise { try { // AddToAgent - await addToAgent(extensionConfig, { silent: false }); + await addToAgent(extensionConfig, { silent: false }, sessionId); } catch (error) { console.error('Failed to add extension to agent:', error); // add to config with enabled = false @@ -74,7 +76,7 @@ export async function activateExtension({ console.error('Failed to add extension to config:', error); // remove from Agent try { - await removeFromAgent(extensionConfig.name); + await removeFromAgent(extensionConfig.name, {}, sessionId); } catch (removeError) { console.error('Failed to remove extension from agent after config failure:', removeError); } @@ -87,6 +89,7 @@ interface AddToAgentOnStartupProps { addToConfig: (name: string, extensionConfig: ExtensionConfig, enabled: boolean) => Promise; extensionConfig: ExtensionConfig; toastOptions?: ToastServiceOptions; + sessionId: string; } /** @@ -96,9 +99,10 @@ export async function addToAgentOnStartup({ addToConfig, extensionConfig, toastOptions = { silent: true }, + sessionId, }: AddToAgentOnStartupProps): Promise { try { - await retryWithBackoff(() => addToAgent(extensionConfig, toastOptions), { + await retryWithBackoff(() => addToAgent(extensionConfig, toastOptions, sessionId), { retries: 3, delayMs: 1000, shouldRetry: (error: ExtensionError) => @@ -121,6 +125,7 @@ export async function addToAgentOnStartup({ extensionConfig, addToConfig, toastOptions: { silent: true }, + sessionId, }); } catch (toggleErr) { console.error('Failed to toggle off after error:', toggleErr); @@ -134,6 +139,7 @@ interface UpdateExtensionProps { removeFromConfig: (name: string) => Promise; extensionConfig: ExtensionConfig; originalName?: string; + sessionId: string; } /** @@ -145,6 +151,7 @@ export async function updateExtension({ removeFromConfig, extensionConfig, originalName, + sessionId, }: UpdateExtensionProps) { // Sanitize the new name to match the behavior when adding extensions const sanitizedNewName = sanitizeName(extensionConfig.name); @@ -158,7 +165,7 @@ export async function updateExtension({ // First remove the old extension from agent (using original name) try { - await removeFromAgent(originalName!, { silent: true }); // Suppress removal toast since we'll show update toast + await removeFromAgent(originalName!, { silent: true }, sessionId); // Suppress removal toast since we'll show update toast } catch (error) { console.error('Failed to remove old extension from agent during rename:', error); // Continue with the process even if agent removal fails @@ -182,7 +189,7 @@ export async function updateExtension({ if (enabled) { try { // AddToAgent with silent option to avoid duplicate toasts - await addToAgent(sanitizedExtensionConfig, { silent: true }); + await addToAgent(sanitizedExtensionConfig, { silent: true }, sessionId); } catch (error) { console.error('[updateExtension]: Failed to add renamed extension to agent:', error); throw error; @@ -212,7 +219,7 @@ export async function updateExtension({ if (enabled) { try { // AddToAgent with silent option to avoid duplicate toasts - await addToAgent(sanitizedExtensionConfig, { silent: true }); + await addToAgent(sanitizedExtensionConfig, { silent: true }, sessionId); } catch (error) { console.error('[updateExtension]: Failed to add extension to agent during update:', error); // Failed to add to agent -- show that error to user and do not update the config file @@ -254,6 +261,7 @@ interface ToggleExtensionProps { extensionConfig: ExtensionConfig; addToConfig: (name: string, extensionConfig: ExtensionConfig, enabled: boolean) => Promise; toastOptions?: ToastServiceOptions; + sessionId: string; } /** @@ -264,14 +272,19 @@ export async function toggleExtension({ extensionConfig, addToConfig, toastOptions = {}, + sessionId, }: ToggleExtensionProps) { // disabled to enabled if (toggle == 'toggleOn') { try { // add to agent with toast options - await addToAgent(extensionConfig, { - ...toastOptions, - }); + await addToAgent( + extensionConfig, + { + ...toastOptions, + }, + sessionId + ); } catch (error) { console.error('Error adding extension to agent. Will try to toggle back off.'); try { @@ -280,6 +293,7 @@ export async function toggleExtension({ extensionConfig, addToConfig, toastOptions: { silent: true }, // otherwise we will see a toast for removing something that was never added + sessionId, }); } catch (toggleError) { console.error('Failed to toggle extension off after agent error:', toggleError); @@ -294,7 +308,7 @@ export async function toggleExtension({ console.error('Failed to update config after enabling extension:', error); // remove from agent try { - await removeFromAgent(extensionConfig.name, toastOptions); + await removeFromAgent(extensionConfig.name, toastOptions, sessionId); } catch (removeError) { console.error('Failed to remove extension from agent after config failure:', removeError); } @@ -304,7 +318,7 @@ export async function toggleExtension({ // enabled to disabled let agentRemoveError = null; try { - await removeFromAgent(extensionConfig.name, toastOptions); + await removeFromAgent(extensionConfig.name, toastOptions, sessionId); } catch (error) { // note there was an error, but attempt to remove from config anyway console.error('Error removing extension from agent', extensionConfig.name, error); @@ -329,16 +343,17 @@ export async function toggleExtension({ interface DeleteExtensionProps { name: string; removeFromConfig: (name: string) => Promise; + sessionId: string; } /** * Deletes an extension completely from both agent and config */ -export async function deleteExtension({ name, removeFromConfig }: DeleteExtensionProps) { +export async function deleteExtension({ name, removeFromConfig, sessionId }: DeleteExtensionProps) { // remove from agent let agentRemoveError = null; try { - await removeFromAgent(name, { isDelete: true }); + await removeFromAgent(name, { isDelete: true }, sessionId); } catch (error) { console.error('Failed to remove extension from agent during deletion:', error); agentRemoveError = error; diff --git a/ui/desktop/src/hooks/useRecipeManager.ts b/ui/desktop/src/hooks/useRecipeManager.ts index 8b91da5314a1..ce182d383f65 100644 --- a/ui/desktop/src/hooks/useRecipeManager.ts +++ b/ui/desktop/src/hooks/useRecipeManager.ts @@ -195,6 +195,7 @@ export const useRecipeManager = (chat: ChatType, recipeConfig?: Recipe | null) = messages: messagesRef.current, title: '', description: '', + session_id: chat.sessionId, }; const response = await createRecipe(createRecipeRequest); @@ -243,7 +244,7 @@ export const useRecipeManager = (chat: ChatType, recipeConfig?: Recipe | null) = return () => { window.removeEventListener('make-agent-from-chat', handleMakeAgent); }; - }, []); + }, [chat.sessionId]); return { recipeConfig: finalRecipeConfig, diff --git a/ui/desktop/src/recipe/index.ts b/ui/desktop/src/recipe/index.ts index 4e24119ef488..b3004bb543c3 100644 --- a/ui/desktop/src/recipe/index.ts +++ b/ui/desktop/src/recipe/index.ts @@ -39,6 +39,7 @@ export interface CreateRecipeRequest { contact?: string; metadata?: string; }; + session_id: string; } export type CreateRecipeResponse = ApiCreateRecipeResponse; @@ -69,6 +70,7 @@ export async function createRecipe(request: CreateRecipeRequest): Promise