Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 208 additions & 26 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@
let options = SqliteConnectOptions::new()
.filename(db_path)
.create_if_missing(create_if_missing)
.busy_timeout(std::time::Duration::from_secs(5));
.busy_timeout(std::time::Duration::from_secs(5))
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);

sqlx::SqlitePool::connect_with(options).await.map_err(|e| {
anyhow::anyhow!(
Expand Down Expand Up @@ -780,32 +781,38 @@
name: String,
session_type: SessionType,
) -> Result<Session> {
let mut tx = self.pool.begin().await?;

let today = chrono::Utc::now().format("%Y%m%d").to_string();
Ok(sqlx::query_as(
r#"
INSERT INTO sessions (id, name, user_set_name, session_type, working_dir, extension_data)
VALUES (
? || '_' || CAST(COALESCE((
SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
FROM sessions
WHERE id LIKE ? || '_%'
), 0) + 1 AS TEXT),
?,
FALSE,
?,
?,
'{}'
)
RETURNING *
"#,
)
.bind(&today)
.bind(&today)
.bind(&name)
.bind(session_type.to_string())
.bind(working_dir.to_string_lossy().as_ref())
.fetch_one(&self.pool)
.await?)
let session = sqlx::query_as(
r#"
INSERT INTO sessions (id, name, user_set_name, session_type, working_dir, extension_data)
VALUES (
? || '_' || CAST(COALESCE((
SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
FROM sessions
WHERE id LIKE ? || '_%'
), 0) + 1 AS TEXT),
?,
FALSE,
?,
?,
'{}'
)
RETURNING *
"#,
)
.bind(&today)
.bind(&today)
.bind(&name)
.bind(session_type.to_string())
.bind(working_dir.to_string_lossy().as_ref())
// .fetch_one(&self.pool) <-- this contributes to the race condition
.fetch_one(&mut *tx)
.await?;

tx.commit().await?;
Ok(session)
}

async fn get_session(&self, id: &str, include_messages: bool) -> Result<Session> {
Expand Down Expand Up @@ -1364,4 +1371,179 @@
assert!(imported.user_set_name);
assert_eq!(imported.working_dir, PathBuf::from("/tmp/test"));
}

/// Test for WAL mode race condition matching build_session() pattern
///
/// This test closely simulates the actual build_session() flow:
/// 1. Determine if we need to create a new session (session_id is None)
/// 2. Call create_session() to create it
/// 3. Get the returned session_id
/// 4. Immediately call get_session() with that id (like CliSession::new does)
///
/// This matches the code in builder.rs and mod.rs where sessions are created
/// and immediately read.
#[tokio::test]
async fn test_wal_race_condition_create_then_get() {
use std::time::Duration;

const NUM_TASKS: usize = 100;
let mut handles = vec![];

for i in 0..NUM_TASKS {
let handle = tokio::spawn(async move {
// Wait for all tasks to be ready

// Simulate build_session() logic:
// Step 1: session_id is None, so we need to create a new session
let session_id: Option<String> = None;

Check warning on line 1399 in crates/goose/src/session/session_manager.rs

View workflow job for this annotation

GitHub Actions / Check Rust Code Format

Diff in /home/runner/work/goose/goose/crates/goose/src/session/session_manager.rs
// Step 2: Create session (like builder.rs)
let session_id = if session_id.is_none() {
let session =
SessionManager::create_session(
PathBuf::from(format!("/tmp/test_{}", i)),
format!("Race test session {}", i),
SessionType::User,
)
.await
.expect("Failed to create session");
Some(session.id)
} else {
session_id
};

// Step 3: Now simulate CliSession::new() which immediately reads the session
// (like mod.rs:138-149)
let session_id = session_id.unwrap();

// This is the critical read that happens in CliSession::new
// It tries to load the conversation from the just-created session
let fetched = SessionManager::get_session(&session_id, true) // include_messages=true like real code
.await;

match fetched {
Ok(fetched_session) => {
assert_eq!(
fetched_session.id, session_id,
"Session ID mismatch for session {}",
i
);
println!(
"✅ SUCCESS: Session {} found immediately after creation",
session_id
);
Ok(session_id)
}
Err(e) => {
// This is the race condition we're testing for
eprintln!("⚠️ RACE DETECTED: Session {} not found immediately after creation: {}",
session_id, e);
Err(format!("Session {} not found: {}", session_id, e))
}
}
});

handles.push(handle);
}

// Collect results
let mut errors = vec![];
for handle in handles {
match handle.await.unwrap() {
Ok(_) => {}
Err(e) => errors.push(e),
}
}

// Give WAL time to checkpoint
tokio::time::sleep(Duration::from_millis(100)).await;

// Report any race conditions detected
if !errors.is_empty() {
panic!(
"WAL race condition detected in {} out of {} tasks:\n{}",
errors.len(),
NUM_TASKS,
errors.join("\n")
);
}
}

/// Exercises the create-then-read sequence using the same `block_in_place` +
/// `block_on` shape that CliSession::new uses.
///
/// Every task creates a session against a fresh temporary database and then
/// synchronously reads it back from inside the async task, to see whether the
/// blocking pattern makes the WAL race more likely.
///
/// # Panics
///
/// Panics, listing every failing iteration, if any freshly created session
/// could not be read back.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_wal_race_with_blocking_pattern() {
    const NUM_ITERATIONS: usize = 100;

    let temp_dir = TempDir::new().unwrap();
    let db_path = temp_dir.path().join("test_blocking_race.db");
    let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());

    // Spawn all iterations up front; each task owns its own handle to the storage.
    let tasks: Vec<_> = (0..NUM_ITERATIONS)
        .map(|i| {
            let storage = Arc::clone(&storage);
            tokio::spawn(async move {
                // Create a session
                let created = storage
                    .create_session(
                        PathBuf::from(format!("/tmp/test_{}", i)),
                        format!("Blocking test {}", i),
                        SessionType::User,
                    )
                    .await
                    .unwrap();
                let session_id = created.id;

                // Simulate CliSession::new's blocking pattern: drive the read to
                // completion synchronously from within the async task.
                let lookup = tokio::task::block_in_place(|| {
                    tokio::runtime::Handle::current()
                        .block_on(storage.get_session(&session_id, false))
                });

                if let Err(e) = lookup {
                    eprintln!(
                        "⚠️ RACE DETECTED with block_in_place: Session {} not found: {}",
                        session_id, e
                    );
                    return Err(format!(
                        "Session {} not found with blocking: {}",
                        session_id, e
                    ));
                }

                println!(
                    "✅ SUCCESS (blocking): Session {} found immediately after creation",
                    session_id
                );
                Ok(session_id)
            })
        })
        .collect();

    // Gather the failure messages; a panicked task is re-raised by the unwrap.
    let mut failures = Vec::new();
    for task in tasks {
        if let Err(message) = task.await.unwrap() {
            failures.push(message);
        }
    }

    assert!(
        failures.is_empty(),
        "WAL race condition detected with blocking pattern in {} out of {} iterations:\n{}",
        failures.len(),
        NUM_ITERATIONS,
        failures.join("\n")
    );
}
}
Loading