Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1177,9 +1177,7 @@ impl Agent {
}
}

for msg in &messages_to_add {
SessionManager::add_message(&session_config.id, msg).await?;
}
SessionManager::add_messages(&session_config.id, messages_to_add.messages()).await?;
conversation.extend(messages_to_add);
if exit_chat {
break;
Expand Down
101 changes: 101 additions & 0 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ impl SessionManager {
Self::instance().await?.add_message(id, message).await
}

pub async fn add_messages(id: &str, messages: &[Message]) -> Result<()> {
Self::instance().await?.add_messages(id, messages).await
}

pub async fn replace_conversation(id: &str, conversation: &Conversation) -> Result<()> {
Self::instance()
.await?
Expand Down Expand Up @@ -992,6 +996,41 @@ impl SessionStorage {
Ok(())
}

async fn add_messages(&self, session_id: &str, messages: &[Message]) -> Result<()> {
if messages.is_empty() {
return Ok(());
}

let mut tx = self.pool.begin().await?;

for message in messages {
let metadata_json = serde_json::to_string(&message.metadata)?;

sqlx::query(
r#"
INSERT INTO messages (session_id, role, content_json, created_timestamp, metadata_json)
VALUES (?, ?, ?, ?, ?)
"#,
)
.bind(session_id)
.bind(role_to_string(&message.role))
.bind(serde_json::to_string(&message.content)?)
.bind(message.created)
.bind(metadata_json)
.execute(&mut *tx)
.await?;
}

// Only update sessions table ONCE at the end
sqlx::query("UPDATE sessions SET updated_at = datetime('now') WHERE id = ?")
.bind(session_id)
.execute(&mut *tx)
.await?;

tx.commit().await?;
Ok(())
}

async fn replace_conversation(
&self,
session_id: &str,
Expand Down Expand Up @@ -1356,4 +1395,66 @@ mod tests {
assert!(imported.user_set_name);
assert_eq!(imported.working_dir, PathBuf::from("/tmp/test"));
}

#[tokio::test]
async fn test_add_messages_batch() {
let temp_dir = TempDir::new().unwrap();
let db_path = temp_dir.path().join("test_batch.db");
let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());

let session = storage
.create_session(
PathBuf::from("/tmp/test"),
"Batch test".to_string(),
SessionType::User,
)
.await
.unwrap();

// Create multiple messages to add in batch
let messages = vec![
Message {
id: None,
role: Role::User,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("message 1")],
metadata: Default::default(),
},
Message {
id: None,
role: Role::Assistant,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("response 1")],
metadata: Default::default(),
},
Message {
id: None,
role: Role::User,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("message 2")],
metadata: Default::default(),
},
Message {
id: None,
role: Role::Assistant,
created: chrono::Utc::now().timestamp_millis(),
content: vec![MessageContent::text("response 2")],
metadata: Default::default(),
},
];

// Add messages in batch
storage.add_messages(&session.id, &messages).await.unwrap();

// Verify all messages were added
let retrieved_session = storage.get_session(&session.id, true).await.unwrap();
assert_eq!(retrieved_session.message_count, 4);

let conversation = retrieved_session.conversation.unwrap();
assert_eq!(conversation.messages().len(), 4);
assert_eq!(conversation.messages()[0].role, Role::User);
assert_eq!(conversation.messages()[1].role, Role::Assistant);
assert_eq!(conversation.messages()[2].role, Role::User);
assert_eq!(conversation.messages()[3].role, Role::Assistant);
}
}