Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 132 additions & 30 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,36 +156,10 @@ impl SessionManager {
}

pub async fn create_session(working_dir: PathBuf, description: String) -> Result<Session> {
let today = chrono::Utc::now().format("%Y%m%d").to_string();
let storage = Self::instance().await?;

let mut tx = storage.pool.begin().await?;

let max_idx = sqlx::query_scalar::<_, Option<i32>>(
"SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER)) FROM sessions WHERE id LIKE ?",
)
.bind(format!("{}_%", today))
.fetch_one(&mut *tx)
.await?
.unwrap_or(0);

let session_id = format!("{}_{}", today, max_idx + 1);

sqlx::query(
r#"
INSERT INTO sessions (id, description, working_dir, extension_data)
VALUES (?, ?, ?, '{}')
"#,
)
.bind(&session_id)
.bind(&description)
.bind(working_dir.to_string_lossy().as_ref())
.execute(&mut *tx)
.await?;

tx.commit().await?;

Self::get_session(&session_id, false).await
Self::instance()
.await?
.create_session(working_dir, description)
.await
}

pub async fn get_session(id: &str, include_messages: bool) -> Result<Session> {
Expand Down Expand Up @@ -606,6 +580,32 @@ impl SessionStorage {
Ok(())
}

async fn create_session(&self, working_dir: PathBuf, description: String) -> Result<Session> {
let today = chrono::Utc::now().format("%Y%m%d").to_string();
Ok(sqlx::query_as(
r#"
INSERT INTO sessions (id, description, working_dir, extension_data)
VALUES (
? || '_' || CAST(COALESCE((
SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
FROM sessions
WHERE id LIKE ? || '_%'
), 0) + 1 AS TEXT),
?,
?,
'{}'
)
RETURNING *
"#,
)
.bind(&today)
.bind(&today)
.bind(&description)
.bind(working_dir.to_string_lossy().as_ref())
.fetch_one(&self.pool)
.await?)
}

async fn get_session(&self, id: &str, include_messages: bool) -> Result<Session> {
let mut session = sqlx::query_as::<_, Session>(
r#"
Expand Down Expand Up @@ -859,3 +859,105 @@ impl SessionStorage {
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::conversation::message::{Message, MessageContent};
use tempfile::TempDir;

const NUM_CONCURRENT_SESSIONS: i32 = 10;

#[tokio::test]
async fn test_concurrent_session_creation() {
let temp_dir = TempDir::new().unwrap();
let db_path = temp_dir.path().join("test_sessions.db");

let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());

let mut handles = vec![];

for i in 0..NUM_CONCURRENT_SESSIONS {
let session_storage = Arc::clone(&storage);
let handle = tokio::spawn(async move {
let working_dir = PathBuf::from(format!("/tmp/test_{}", i));
let description = format!("Test session {}", i);

let session = session_storage
.create_session(working_dir.clone(), description)
.await
.unwrap();

session_storage
.add_message(
&session.id,
&Message {
id: None,
role: Role::User,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("hello world")],
metadata: Default::default(),
},
)
.await
.unwrap();

session_storage
.add_message(
&session.id,
&Message {
id: None,
role: Role::Assistant,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("sup world?")],
metadata: Default::default(),
},
)
.await
.unwrap();

session_storage
.apply_update(
SessionUpdateBuilder::new(session.id.clone())
.description(format!("Updated session {}", i))
.total_tokens(Some(100 * i)),
)
.await
.unwrap();

let updated = session_storage
.get_session(&session.id, true)
.await
.unwrap();
assert_eq!(updated.message_count, 2);
assert_eq!(updated.total_tokens, Some(100 * i));

session.id
});
handles.push(handle);
}

let mut results = vec![];
for handle in handles {
results.push(handle.await.unwrap());
}

assert_eq!(results.len(), NUM_CONCURRENT_SESSIONS as usize);

let unique_ids: std::collections::HashSet<_> = results.iter().collect();
assert_eq!(unique_ids.len(), NUM_CONCURRENT_SESSIONS as usize);

let sessions = storage.list_sessions().await.unwrap();
assert_eq!(sessions.len(), NUM_CONCURRENT_SESSIONS as usize);

for session in &sessions {
assert_eq!(session.message_count, 2);
assert!(session.description.starts_with("Updated session"));
}

let insights = storage.get_insights().await.unwrap();
assert_eq!(insights.total_sessions, NUM_CONCURRENT_SESSIONS as usize);
let expected_tokens = 100 * NUM_CONCURRENT_SESSIONS * (NUM_CONCURRENT_SESSIONS - 1) / 2;
assert_eq!(insights.total_tokens, expected_tokens as i64);
}
}
Loading