diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index 3d86ff092b29..615f7943213f 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -1,8 +1,6 @@ use crate::agents::extension::PlatformExtensionContext; use crate::agents::Agent; use crate::config::paths::Paths; -use crate::model::ModelConfig; -use crate::providers::create; use crate::scheduler_factory::SchedulerFactory; use crate::scheduler_trait::SchedulerTrait; use anyhow::Result; @@ -10,7 +8,7 @@ use lru::LruCache; use std::num::NonZeroUsize; use std::sync::Arc; use tokio::sync::{OnceCell, RwLock}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; const DEFAULT_MAX_SESSION: usize = 100; @@ -23,7 +21,7 @@ pub struct AgentManager { } impl AgentManager { - /// Reset the global singleton - ONLY for testing + #[cfg(test)] pub fn reset_for_test() { unsafe { // Cast away the const to get mutable access @@ -34,7 +32,6 @@ impl AgentManager { } } - // Private constructor - prevents direct instantiation in production async fn new(max_sessions: Option) -> Result { let schedule_file_path = Paths::data_dir().join("schedule.json"); @@ -49,8 +46,6 @@ impl AgentManager { default_provider: Arc::new(RwLock::new(None)), }; - let _ = manager.configure_default_provider().await; - Ok(manager) } @@ -73,39 +68,6 @@ impl AgentManager { *self.default_provider.write().await = Some(provider); } - pub async fn configure_default_provider(&self) -> Result<()> { - let provider_name = std::env::var("GOOSE_DEFAULT_PROVIDER") - .or_else(|_| std::env::var("GOOSE_PROVIDER__TYPE")) - .ok(); - - let model_name = std::env::var("GOOSE_DEFAULT_MODEL") - .or_else(|_| std::env::var("GOOSE_PROVIDER__MODEL")) - .ok(); - - if provider_name.is_none() || model_name.is_none() { - return Ok(()); - } - - if let (Some(provider_name), Some(model_name)) = (provider_name, model_name) { - match ModelConfig::new(&model_name) { - Ok(model_config) => match create(&provider_name, model_config).await { - Ok(provider) => { - self.set_default_provider(provider).await; - info!( - "Configured default provider: {} with model: {}", - provider_name, model_name - ); - } - Err(e) => { - warn!("Failed to create default provider {}: {}", provider_name, e) - } - }, - Err(e) => warn!("Failed to create model config for {}: {}", model_name, e), - } - } - Ok(()) - } - pub async fn get_or_create_agent(&self, session_id: String) -> Result> { { let mut sessions = self.sessions.write().await; @@ -154,3 +116,230 @@ impl AgentManager { self.sessions.read().await.len() } } + +#[cfg(test)] +mod tests { + use serial_test::serial; + use std::sync::Arc; + + use crate::execution::{manager::AgentManager, SessionExecutionMode}; + + #[test] + fn test_execution_mode_constructors() { + assert_eq!( + SessionExecutionMode::chat(), + SessionExecutionMode::Interactive + ); + assert_eq!( + SessionExecutionMode::scheduled(), + SessionExecutionMode::Background + ); + + let parent = "parent-123".to_string(); + assert_eq!( + SessionExecutionMode::task(parent.clone()), + SessionExecutionMode::SubTask { + parent_session: parent + } + ); + } + + #[tokio::test] + #[serial] + async fn test_session_isolation() { + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + + let session1 = uuid::Uuid::new_v4().to_string(); + let session2 = uuid::Uuid::new_v4().to_string(); + + let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap(); + + let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap(); + + // Different sessions should have different agents + assert!(!Arc::ptr_eq(&agent1, &agent2)); + + // Getting the same session should return the same agent + let agent1_again = manager.get_or_create_agent(session1).await.unwrap(); + + assert!(Arc::ptr_eq(&agent1, &agent1_again)); + + AgentManager::reset_for_test(); + } + + #[tokio::test] + #[serial] + async fn test_session_limit() { + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + + let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); + + for session in &sessions { + manager.get_or_create_agent(session.clone()).await.unwrap(); + } + + // Create a new session after cleanup + let new_session = "new-session".to_string(); + let _new_agent = manager.get_or_create_agent(new_session).await.unwrap(); + + assert_eq!(manager.session_count().await, 100); + } + + #[tokio::test] + #[serial] + async fn test_remove_session() { + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + let session = String::from("remove-test"); + + manager.get_or_create_agent(session.clone()).await.unwrap(); + assert!(manager.has_session(&session).await); + + manager.remove_session(&session).await.unwrap(); + assert!(!manager.has_session(&session).await); + + assert!(manager.remove_session(&session).await.is_err()); + } + + #[tokio::test] + #[serial] + async fn test_concurrent_access() { + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + let session = String::from("concurrent-test"); + + let mut handles = vec![]; + for _ in 0..10 { + let mgr = Arc::clone(&manager); + let sess = session.clone(); + handles.push(tokio::spawn(async move { + mgr.get_or_create_agent(sess).await.unwrap() + })); + } + + let agents: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + for agent in &agents[1..] { + assert!(Arc::ptr_eq(&agents[0], agent)); + } + + assert_eq!(manager.session_count().await, 1); + } + + #[tokio::test] + #[serial] + async fn test_concurrent_session_creation_race_condition() { + // Test that concurrent attempts to create the same new session ID + // result in only one agent being created (tests double-check pattern) + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + let session_id = String::from("race-condition-test"); + + // Spawn multiple tasks trying to create the same NEW session simultaneously + let mut handles = vec![]; + for _ in 0..20 { + let sess = session_id.clone(); + let mgr_clone = Arc::clone(&manager); + handles.push(tokio::spawn(async move { + mgr_clone.get_or_create_agent(sess).await.unwrap() + })); + } + + // Collect all agents + let agents: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + for agent in &agents[1..] { + assert!( + Arc::ptr_eq(&agents[0], agent), + "All concurrent requests should get the same agent" + ); + } + assert_eq!(manager.session_count().await, 1); + } + + #[tokio::test] + #[serial] + async fn test_set_default_provider() { + use crate::providers::testprovider::TestProvider; + use std::sync::Arc; + + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + + // Create a test provider for replaying (doesn't need inner provider) + let temp_file = format!( + "{}/test_provider_{}.json", + std::env::temp_dir().display(), + std::process::id() + ); + + // Create an empty test provider (will fail on actual use but that's ok for this test) + let test_provider = TestProvider::new_replaying(&temp_file) + .unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap()); + + manager.set_default_provider(Arc::new(test_provider)).await; + + let session = String::from("provider-test"); + let _agent = manager.get_or_create_agent(session.clone()).await.unwrap(); + + assert!(manager.has_session(&session).await); + } + + #[tokio::test] + #[serial] + async fn test_eviction_updates_last_used() { + AgentManager::reset_for_test(); + // Test that accessing a session updates its last_used timestamp + // and affects eviction order + let manager = AgentManager::instance().await.unwrap(); + + let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); + + for session in &sessions { + manager.get_or_create_agent(session.clone()).await.unwrap(); + // Small delay to ensure different timestamps + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + + // Access the first session again to update its last_used + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + manager + .get_or_create_agent(sessions[0].clone()) + .await + .unwrap(); + + // Now create a 101st session - should evict session2 (least recently used) + let session101 = String::from("session-101"); + manager + .get_or_create_agent(session101.clone()) + .await + .unwrap(); + + assert!(manager.has_session(&sessions[0]).await); + assert!(!manager.has_session(&sessions[1]).await); + assert!(manager.has_session(&session101).await); + } + + #[tokio::test] + #[serial] + async fn test_remove_nonexistent_session_error() { + // Test that removing a non-existent session returns an error + AgentManager::reset_for_test(); + let manager = AgentManager::instance().await.unwrap(); + let session = String::from("never-created"); + + let result = manager.remove_session(&session).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("not found")); + } +} diff --git a/crates/goose/tests/execution_tests.rs b/crates/goose/tests/execution_tests.rs deleted file mode 100644 index 07e8932563b7..000000000000 --- a/crates/goose/tests/execution_tests.rs +++ /dev/null @@ -1,256 +0,0 @@ -mod execution_tests { - use goose::execution::manager::AgentManager; - use goose::execution::SessionExecutionMode; - use serial_test::serial; - use std::sync::Arc; - - #[test] - fn test_execution_mode_constructors() { - assert_eq!( - SessionExecutionMode::chat(), - SessionExecutionMode::Interactive - ); - assert_eq!( - SessionExecutionMode::scheduled(), - SessionExecutionMode::Background - ); - - let parent = "parent-123".to_string(); - assert_eq!( - SessionExecutionMode::task(parent.clone()), - SessionExecutionMode::SubTask { - parent_session: parent - } - ); - } - - #[tokio::test] - #[serial] - async fn test_session_isolation() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - - let session1 = uuid::Uuid::new_v4().to_string(); - let session2 = uuid::Uuid::new_v4().to_string(); - - let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap(); - - let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap(); - - // Different sessions should have different agents - assert!(!Arc::ptr_eq(&agent1, &agent2)); - - // Getting the same session should return the same agent - let agent1_again = manager.get_or_create_agent(session1).await.unwrap(); - - assert!(Arc::ptr_eq(&agent1, &agent1_again)); - - AgentManager::reset_for_test(); - } - - #[tokio::test] - #[serial] - async fn test_session_limit() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - - let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); - - for session in &sessions { - manager.get_or_create_agent(session.clone()).await.unwrap(); - } - - // Create a new session after cleanup - let new_session = "new-session".to_string(); - let _new_agent = manager.get_or_create_agent(new_session).await.unwrap(); - - assert_eq!(manager.session_count().await, 100); - } - - #[tokio::test] - #[serial] - async fn test_remove_session() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - let session = String::from("remove-test"); - - manager.get_or_create_agent(session.clone()).await.unwrap(); - assert!(manager.has_session(&session).await); - - manager.remove_session(&session).await.unwrap(); - assert!(!manager.has_session(&session).await); - - assert!(manager.remove_session(&session).await.is_err()); - } - - #[tokio::test] - #[serial] - async fn test_concurrent_access() { - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - let session = String::from("concurrent-test"); - - let mut handles = vec![]; - for _ in 0..10 { - let mgr = Arc::clone(&manager); - let sess = session.clone(); - handles.push(tokio::spawn(async move { - mgr.get_or_create_agent(sess).await.unwrap() - })); - } - - let agents: Vec<_> = futures::future::join_all(handles) - .await - .into_iter() - .map(|r| r.unwrap()) - .collect(); - - for agent in &agents[1..] { - assert!(Arc::ptr_eq(&agents[0], agent)); - } - - assert_eq!(manager.session_count().await, 1); - } - - #[tokio::test] - #[serial] - async fn test_concurrent_session_creation_race_condition() { - // Test that concurrent attempts to create the same new session ID - // result in only one agent being created (tests double-check pattern) - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - let session_id = String::from("race-condition-test"); - - // Spawn multiple tasks trying to create the same NEW session simultaneously - let mut handles = vec![]; - for _ in 0..20 { - let sess = session_id.clone(); - let mgr_clone = Arc::clone(&manager); - handles.push(tokio::spawn(async move { - mgr_clone.get_or_create_agent(sess).await.unwrap() - })); - } - - // Collect all agents - let agents: Vec<_> = futures::future::join_all(handles) - .await - .into_iter() - .map(|r| r.unwrap()) - .collect(); - - for agent in &agents[1..] { - assert!( - Arc::ptr_eq(&agents[0], agent), - "All concurrent requests should get the same agent" - ); - } - assert_eq!(manager.session_count().await, 1); - } - - #[tokio::test] - #[serial] - async fn test_configure_default_provider() { - use std::env; - - AgentManager::reset_for_test(); - - let original_provider = env::var("GOOSE_DEFAULT_PROVIDER").ok(); - let original_model = env::var("GOOSE_DEFAULT_MODEL").ok(); - - env::set_var("GOOSE_DEFAULT_PROVIDER", "openai"); - env::set_var("GOOSE_DEFAULT_MODEL", "gpt-4o-mini"); - - let manager = AgentManager::instance().await.unwrap(); - let result = manager.configure_default_provider().await; - - assert!(result.is_ok()); - - // Restore original env vars - if let Some(val) = original_provider { - env::set_var("GOOSE_DEFAULT_PROVIDER", val); - } else { - env::remove_var("GOOSE_DEFAULT_PROVIDER"); - } - if let Some(val) = original_model { - env::set_var("GOOSE_DEFAULT_MODEL", val); - } else { - env::remove_var("GOOSE_DEFAULT_MODEL"); - } - } - - #[tokio::test] - #[serial] - async fn test_set_default_provider() { - use goose::providers::testprovider::TestProvider; - use std::sync::Arc; - - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - - // Create a test provider for replaying (doesn't need inner provider) - let temp_file = format!( - "{}/test_provider_{}.json", - std::env::temp_dir().display(), - std::process::id() - ); - - // Create an empty test provider (will fail on actual use but that's ok for this test) - let test_provider = TestProvider::new_replaying(&temp_file) - .unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap()); - - manager.set_default_provider(Arc::new(test_provider)).await; - - let session = String::from("provider-test"); - let _agent = manager.get_or_create_agent(session.clone()).await.unwrap(); - - assert!(manager.has_session(&session).await); - } - - #[tokio::test] - #[serial] - async fn test_eviction_updates_last_used() { - AgentManager::reset_for_test(); - // Test that accessing a session updates its last_used timestamp - // and affects eviction order - let manager = AgentManager::instance().await.unwrap(); - - let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect(); - - for session in &sessions { - manager.get_or_create_agent(session.clone()).await.unwrap(); - // Small delay to ensure different timestamps - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - } - - // Access the first session again to update its last_used - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - manager - .get_or_create_agent(sessions[0].clone()) - .await - .unwrap(); - - // Now create a 101st session - should evict session2 (least recently used) - let session101 = String::from("session-101"); - manager - .get_or_create_agent(session101.clone()) - .await - .unwrap(); - - assert!(manager.has_session(&sessions[0]).await); - assert!(!manager.has_session(&sessions[1]).await); - assert!(manager.has_session(&session101).await); - } - - #[tokio::test] - #[serial] - async fn test_remove_nonexistent_session_error() { - // Test that removing a non-existent session returns an error - AgentManager::reset_for_test(); - let manager = AgentManager::instance().await.unwrap(); - let session = String::from("never-created"); - - let result = manager.remove_session(&session).await; - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("not found")); - } -}