Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::conversation::message::{Message, MessageContent};
use goose::conversation::Conversation;
use goose::execution::SessionExecutionMode;
use goose::mcp_utils::ToolResult;
use goose::permission::{Permission, PermissionConfirmation};
use goose::session::SessionManager;
Expand Down Expand Up @@ -207,10 +206,7 @@ async fn reply_handler(
let task_tx = tx.clone();

drop(tokio::spawn(async move {
let agent = match state
.get_agent(session_id.clone(), SessionExecutionMode::Interactive)
.await
{
let agent = match state.get_agent(session_id.clone()).await {
Ok(agent) => agent,
Err(e) => {
tracing::error!("Failed to get session agent: {}", e);
Expand Down
21 changes: 6 additions & 15 deletions crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use axum::http::StatusCode;
use goose::execution::manager::AgentManager;
use goose::execution::SessionExecutionMode;
use goose::scheduler_trait::SchedulerTrait;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
Expand Down Expand Up @@ -46,26 +45,18 @@ impl AppState {
}
}

pub async fn get_agent(
&self,
session_id: String,
mode: SessionExecutionMode,
) -> anyhow::Result<Arc<goose::agents::Agent>> {
self.agent_manager
.get_or_create_agent(session_id, mode)
.await
pub async fn get_agent(&self, session_id: String) -> anyhow::Result<Arc<goose::agents::Agent>> {
self.agent_manager.get_or_create_agent(session_id).await
}

/// Get agent for route handlers - always uses Interactive mode and converts any error to 500
pub async fn get_agent_for_route(
&self,
session_id: String,
) -> Result<Arc<goose::agents::Agent>, StatusCode> {
self.get_agent(session_id, SessionExecutionMode::Interactive)
.await
.map_err(|e| {
tracing::error!("Failed to get agent: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})
self.get_agent(session_id).await.map_err(|e| {
tracing::error!("Failed to get agent: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})
}
}
53 changes: 14 additions & 39 deletions crates/goose/src/execution/manager.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
//! Agent lifecycle management with session isolation

use super::SessionExecutionMode;
use crate::agents::Agent;
use crate::config::APP_STRATEGY;
use crate::model::ModelConfig;
Expand Down Expand Up @@ -112,49 +109,27 @@ impl AgentManager {
Ok(())
}

pub async fn get_or_create_agent(
&self,
session_id: String,
mode: SessionExecutionMode,
) -> Result<Arc<Agent>> {
let agent = {
pub async fn get_or_create_agent(&self, session_id: String) -> Result<Arc<Agent>> {
{
let mut sessions = self.sessions.write().await;
if let Some(agent) = sessions.get(&session_id) {
debug!("Found existing agent for session {}", session_id);
return Ok(Arc::clone(agent));
}

info!(
"Creating new agent for session {} with mode {}",
session_id, mode
);
let agent = Arc::new(Agent::new());
sessions.put(session_id.clone(), Arc::clone(&agent));
agent
};

match &mode {
SessionExecutionMode::Interactive | SessionExecutionMode::Background => {
debug!("Setting scheduler on agent for session {}", session_id);
agent.set_scheduler(Arc::clone(&self.scheduler)).await;
}
SessionExecutionMode::SubTask { .. } => {
debug!(
"SubTask mode for session {}, skipping scheduler setup",
session_id
);
if let Some(existing) = sessions.get(&session_id) {
return Ok(Arc::clone(existing));
}
}

let agent = Arc::new(Agent::new());
agent.set_scheduler(Arc::clone(&self.scheduler)).await;
if let Some(provider) = &*self.default_provider.read().await {
debug!(
"Setting default provider on agent for session {}",
session_id
);
let _ = agent.update_provider(Arc::clone(provider)).await;
agent.update_provider(Arc::clone(provider)).await?;
}

Ok(agent)
let mut sessions = self.sessions.write().await;
if let Some(existing) = sessions.get(&session_id) {
Ok(Arc::clone(existing))
} else {
sessions.put(session_id, agent.clone());
Ok(agent)
}
}

pub async fn remove_session(&self, session_id: &str) -> Result<()> {
Expand Down
96 changes: 12 additions & 84 deletions crates/goose/tests/execution_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,15 @@ mod execution_tests {
let session1 = uuid::Uuid::new_v4().to_string();
let session2 = uuid::Uuid::new_v4().to_string();

let agent1 = manager
.get_or_create_agent(session1.clone(), SessionExecutionMode::Interactive)
.await
.unwrap();
let agent1 = manager.get_or_create_agent(session1.clone()).await.unwrap();

let agent2 = manager
.get_or_create_agent(session2.clone(), SessionExecutionMode::Interactive)
.await
.unwrap();
let agent2 = manager.get_or_create_agent(session2.clone()).await.unwrap();

// Different sessions should have different agents
assert!(!Arc::ptr_eq(&agent1, &agent2));

// Getting the same session should return the same agent
let agent1_again = manager
.get_or_create_agent(session1, SessionExecutionMode::chat())
.await
.unwrap();
let agent1_again = manager.get_or_create_agent(session1).await.unwrap();

assert!(Arc::ptr_eq(&agent1, &agent1_again));

Expand All @@ -66,18 +57,12 @@ mod execution_tests {
let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect();

for session in &sessions {
manager
.get_or_create_agent(session.clone(), SessionExecutionMode::chat())
.await
.unwrap();
manager.get_or_create_agent(session.clone()).await.unwrap();
}

// Create a new session after cleanup
let new_session = "new-session".to_string();
let _new_agent = manager
.get_or_create_agent(new_session, SessionExecutionMode::chat())
.await
.unwrap();
let _new_agent = manager.get_or_create_agent(new_session).await.unwrap();

assert_eq!(manager.session_count().await, 100);
}
Expand All @@ -89,18 +74,13 @@ mod execution_tests {
let manager = AgentManager::instance().await.unwrap();
let session = String::from("remove-test");

manager
.get_or_create_agent(session.clone(), SessionExecutionMode::chat())
.await
.unwrap();
manager.get_or_create_agent(session.clone()).await.unwrap();
assert!(manager.has_session(&session).await);

manager.remove_session(&session).await.unwrap();
assert!(!manager.has_session(&session).await);

assert!(manager.remove_session(&session).await.is_err());

AgentManager::reset_for_test();
}

#[tokio::test]
Expand All @@ -115,9 +95,7 @@ mod execution_tests {
let mgr = Arc::clone(&manager);
let sess = session.clone();
handles.push(tokio::spawn(async move {
mgr.get_or_create_agent(sess, SessionExecutionMode::chat())
.await
.unwrap()
mgr.get_or_create_agent(sess).await.unwrap()
}));
}

Expand All @@ -132,33 +110,6 @@ mod execution_tests {
}

assert_eq!(manager.session_count().await, 1);

AgentManager::reset_for_test();
}

#[tokio::test]
#[serial]
async fn test_different_modes_same_session() {
AgentManager::reset_for_test();
let manager = AgentManager::instance().await.unwrap();
let session_id = String::from("mode-test");

// Create initial agent
let agent1 = manager
.get_or_create_agent(session_id.clone(), SessionExecutionMode::chat())
.await
.unwrap();

// Get same session with different mode - should return same agent
// (mode is stored but agent is reused)
let agent2 = manager
.get_or_create_agent(session_id.clone(), SessionExecutionMode::Background)
.await
.unwrap();

assert!(Arc::ptr_eq(&agent1, &agent2));

AgentManager::reset_for_test();
}

#[tokio::test]
Expand All @@ -176,10 +127,7 @@ mod execution_tests {
let sess = session_id.clone();
let mgr_clone = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
mgr_clone
.get_or_create_agent(sess, SessionExecutionMode::Interactive)
.await
.unwrap()
mgr_clone.get_or_create_agent(sess).await.unwrap()
}));
}

Expand All @@ -190,18 +138,13 @@ mod execution_tests {
.map(|r| r.unwrap())
.collect();

// All should be the same agent (double-check pattern should prevent duplicates)
for agent in &agents[1..] {
assert!(
Arc::ptr_eq(&agents[0], agent),
"All concurrent requests should get the same agent"
);
}

// Only one session should exist
assert_eq!(manager.session_count().await, 1);

AgentManager::reset_for_test();
}

#[tokio::test]
Expand Down Expand Up @@ -233,8 +176,6 @@ mod execution_tests {
} else {
env::remove_var("GOOSE_DEFAULT_MODEL");
}

AgentManager::reset_for_test();
}

#[tokio::test]
Expand All @@ -260,14 +201,9 @@ mod execution_tests {
manager.set_default_provider(Arc::new(test_provider)).await;

let session = String::from("provider-test");
let _agent = manager
.get_or_create_agent(session.clone(), SessionExecutionMode::Interactive)
.await
.unwrap();
let _agent = manager.get_or_create_agent(session.clone()).await.unwrap();

assert!(manager.has_session(&session).await);

AgentManager::reset_for_test();
}

#[tokio::test]
Expand All @@ -281,34 +217,28 @@ mod execution_tests {
let sessions: Vec<_> = (0..100).map(|i| format!("session-{}", i)).collect();

for session in &sessions {
manager
.get_or_create_agent(session.clone(), SessionExecutionMode::chat())
.await
.unwrap();
manager.get_or_create_agent(session.clone()).await.unwrap();
// Small delay to ensure different timestamps
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}

// Access the first session again to update its last_used
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
manager
.get_or_create_agent(sessions[0].clone(), SessionExecutionMode::Interactive)
.get_or_create_agent(sessions[0].clone())
.await
.unwrap();

// Now create a 101st session - should evict session2 (least recently used)
let session101 = String::from("session-101");
manager
.get_or_create_agent(session101.clone(), SessionExecutionMode::Interactive)
.get_or_create_agent(session101.clone())
.await
.unwrap();

// session1 should still exist (recently accessed)
// session2 should be evicted (least recently used)
assert!(manager.has_session(&sessions[0]).await);
assert!(!manager.has_session(&sessions[1]).await);
assert!(manager.has_session(&session101).await);
AgentManager::reset_for_test();
}

#[tokio::test]
Expand All @@ -322,7 +252,5 @@ mod execution_tests {
let result = manager.remove_session(&session).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));

AgentManager::reset_for_test();
}
}
Loading