Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 229 additions & 40 deletions crates/goose/src/execution/manager.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use crate::agents::extension::PlatformExtensionContext;
use crate::agents::Agent;
use crate::config::paths::Paths;
use crate::model::ModelConfig;
use crate::providers::create;
use crate::scheduler_factory::SchedulerFactory;
use crate::scheduler_trait::SchedulerTrait;
use anyhow::Result;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tokio::sync::{OnceCell, RwLock};
use tracing::{debug, info, warn};
use tracing::{debug, info};

const DEFAULT_MAX_SESSION: usize = 100;

Expand All @@ -23,7 +21,7 @@ pub struct AgentManager {
}

impl AgentManager {
/// Reset the global singleton - ONLY for testing
#[cfg(test)]
pub fn reset_for_test() {
unsafe {
// Cast away the const to get mutable access
Expand All @@ -34,7 +32,6 @@ impl AgentManager {
}
}

// Private constructor - prevents direct instantiation in production
async fn new(max_sessions: Option<usize>) -> Result<Self> {
let schedule_file_path = Paths::data_dir().join("schedule.json");

Expand All @@ -49,8 +46,6 @@ impl AgentManager {
default_provider: Arc::new(RwLock::new(None)),
};

let _ = manager.configure_default_provider().await;

Ok(manager)
}

Expand All @@ -73,39 +68,6 @@ impl AgentManager {
*self.default_provider.write().await = Some(provider);
}

pub async fn configure_default_provider(&self) -> Result<()> {
let provider_name = std::env::var("GOOSE_DEFAULT_PROVIDER")
.or_else(|_| std::env::var("GOOSE_PROVIDER__TYPE"))
.ok();

let model_name = std::env::var("GOOSE_DEFAULT_MODEL")
.or_else(|_| std::env::var("GOOSE_PROVIDER__MODEL"))
.ok();

if provider_name.is_none() || model_name.is_none() {
return Ok(());
}

if let (Some(provider_name), Some(model_name)) = (provider_name, model_name) {
match ModelConfig::new(&model_name) {
Ok(model_config) => match create(&provider_name, model_config).await {
Ok(provider) => {
self.set_default_provider(provider).await;
info!(
"Configured default provider: {} with model: {}",
provider_name, model_name
);
}
Err(e) => {
warn!("Failed to create default provider {}: {}", provider_name, e)
}
},
Err(e) => warn!("Failed to create model config for {}: {}", model_name, e),
}
}
Ok(())
}

pub async fn get_or_create_agent(&self, session_id: String) -> Result<Arc<Agent>> {
{
let mut sessions = self.sessions.write().await;
Expand Down Expand Up @@ -154,3 +116,230 @@ impl AgentManager {
self.sessions.read().await.len()
}
}

#[cfg(test)]
mod tests {
use serial_test::serial;
use std::sync::Arc;

use crate::execution::{manager::AgentManager, SessionExecutionMode};

#[test]
fn test_execution_mode_constructors() {
assert_eq!(
SessionExecutionMode::chat(),
SessionExecutionMode::Interactive
);
assert_eq!(
SessionExecutionMode::scheduled(),
SessionExecutionMode::Background
);

let parent = "parent-123".to_string();
assert_eq!(
SessionExecutionMode::task(parent.clone()),
SessionExecutionMode::SubTask {
parent_session: parent
}
);
}

#[tokio::test]
#[serial]
async fn test_session_isolation() {
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();

let session1 = uuid::Uuid::new_v4().to_string();
let session2 = uuid::Uuid::new_v4().to_string();

let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap();

let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap();

// Different sessions should have different agents
assert!(!Arc::ptr_eq(&agent1, &agent2));

// Getting the same session should return the same agent
let agent1_again = manager.get_or_create_agent(session1).await.unwrap();

assert!(Arc::ptr_eq(&agent1, &agent1_again));

AgentManager::reset_for_test();
}

#[tokio::test]
#[serial]
async fn test_session_limit() {
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();

let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect();

for session in &sessions {
manager.get_or_create_agent(session.clone()).await.unwrap();
}

// Create a new session after cleanup
let new_session = "new-session".to_string();
let _new_agent = manager.get_or_create_agent(new_session).await.unwrap();

assert_eq!(manager.session_count().await, 100);
}

#[tokio::test]
#[serial]
async fn test_remove_session() {
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();
let session = String::from("remove-test");

manager.get_or_create_agent(session.clone()).await.unwrap();
assert!(manager.has_session(&session).await);

manager.remove_session(&session).await.unwrap();
assert!(!manager.has_session(&session).await);

assert!(manager.remove_session(&session).await.is_err());
}

#[tokio::test]
#[serial]
async fn test_concurrent_access() {
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();
let session = String::from("concurrent-test");

let mut handles = vec![];
for _ in 0..10 {
let mgr = Arc::clone(&manager);
let sess = session.clone();
handles.push(tokio::spawn(async move {
mgr.get_or_create_agent(sess).await.unwrap()
}));
}

let agents: Vec<_> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();

for agent in &agents[1..] {
assert!(Arc::ptr_eq(&agents[0], agent));
}

assert_eq!(manager.session_count().await, 1);
}

#[tokio::test]
#[serial]
async fn test_concurrent_session_creation_race_condition() {
// Test that concurrent attempts to create the same new session ID
// result in only one agent being created (tests double-check pattern)
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();
let session_id = String::from("race-condition-test");

// Spawn multiple tasks trying to create the same NEW session simultaneously
let mut handles = vec![];
for _ in 0..20 {
let sess = session_id.clone();
let mgr_clone = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
mgr_clone.get_or_create_agent(sess).await.unwrap()
}));
}

// Collect all agents
let agents: Vec<_> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();

for agent in &agents[1..] {
assert!(
Arc::ptr_eq(&agents[0], agent),
"All concurrent requests should get the same agent"
);
}
assert_eq!(manager.session_count().await, 1);
}

#[tokio::test]
#[serial]
async fn test_set_default_provider() {
use crate::providers::testprovider::TestProvider;
use std::sync::Arc;

AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();

// Create a test provider for replaying (doesn't need inner provider)
let temp_file = format!(
"{}/test_provider_{}.json",
std::env::temp_dir().display(),
std::process::id()
);

// Create an empty test provider (will fail on actual use but that's ok for this test)
let test_provider = TestProvider::new_replaying(&temp_file)
.unwrap_or_else(|_| TestProvider::new_replaying("/tmp/dummy.json").unwrap());

manager.set_default_provider(Arc::new(test_provider)).await;

let session = String::from("provider-test");
let _agent = manager.get_or_create_agent(session.clone()).await.unwrap();

assert!(manager.has_session(&session).await);
}

#[tokio::test]
#[serial]
async fn test_eviction_updates_last_used() {
AgentManager::reset_for_test();
// Test that accessing a session updates its last_used timestamp
// and affects eviction order
let manager = AgentManager::instance().await.unwrap();

let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect();

for session in &sessions {
manager.get_or_create_agent(session.clone()).await.unwrap();
// Small delay to ensure different timestamps
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}

// Access the first session again to update its last_used
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
manager
.get_or_create_agent(sessions[0].clone())
.await
.unwrap();

// Now create a 101st session - should evict session2 (least recently used)
let session101 = String::from("session-101");
manager
.get_or_create_agent(session101.clone())
.await
.unwrap();

assert!(manager.has_session(&sessions[0]).await);
assert!(!manager.has_session(&sessions[1]).await);
assert!(manager.has_session(&session101).await);
}

#[tokio::test]
#[serial]
async fn test_remove_nonexistent_session_error() {
// Test that removing a non-existent session returns an error
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();
let session = String::from("never-created");

let result = manager.remove_session(&session).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
}
Loading
Loading