diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 73e2f703aba5..0cc86b876eea 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -382,137 +382,164 @@ pub fn generate_table_id() -> String { mod tests { use super::*; - #[tokio::test] - #[serial_test::serial] - async fn test_tool_vectordb_creation() { - let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string())) - .await - .unwrap(); - db.clear_tools().await.unwrap(); - assert_eq!(db.table_name, "test_tools_vectordb_creation"); + impl ToolVectorDB { + async fn new_test_db( + base_name: &str, + ) -> Result<(Self, impl std::future::Future)> { + let unique_name = format!("{}_{}", base_name, uuid::Uuid::new_v4().simple()); + let db = Self::new(Some(unique_name)).await?; + + let table_name = db.table_name.clone(); + let connection = db.connection.clone(); + + let cleanup = async move { + let _ = async move { + let _ = connection.read().await.drop_table(&table_name).await; + }; + }; + + Ok((db, cleanup)) + } } #[tokio::test] #[serial_test::serial] - async fn test_tool_vectordb_operations() -> Result<()> { - // Create a new database instance with a unique table name - let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?; + async fn test_tool_vectordb_creation() -> Result<()> { + let (db, cleanup) = ToolVectorDB::new_test_db("test_tools_vectordb_creation").await?; - // Clear any existing tools - db.clear_tools().await?; - - // Create test tool records - let test_tools = vec![ - ToolRecord { - tool_name: "test_tool_1".to_string(), - description: "A test tool for reading files".to_string(), - schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# - .to_string(), - vector: vec![0.1; 1536], // Mock embedding vector - extension_name: "test_extension".to_string(), - }, - ToolRecord { - tool_name: "test_tool_2".to_string(), - description: "A test tool for writing files".to_string(), - schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# - .to_string(), - vector: vec![0.2; 1536], // Different mock embedding vector - extension_name: "test_extension".to_string(), - }, - ]; + let result = async { + db.clear_tools().await?; + assert!(db.table_name.contains("test_tools_vectordb_creation")); + Ok(()) + } + .await; - // Index the test tools - db.index_tools(test_tools).await?; + cleanup.await; + result + } - // Search for tools using a query vector similar to test_tool_1 - let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector.clone(), 2, None).await?; + #[tokio::test] + #[serial_test::serial] + async fn test_tool_vectordb_operations() -> Result<()> { + let (db, cleanup) = ToolVectorDB::new_test_db("test_tool_vectordb_operations").await?; + + let result = async { + db.clear_tools().await?; + + let test_tools = vec![ + ToolRecord { + tool_name: "test_tool_1".to_string(), + description: "A test tool for reading files".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# + .to_string(), + vector: vec![0.1; 1536], + extension_name: "test_extension".to_string(), + }, + ToolRecord { + tool_name: "test_tool_2".to_string(), + description: "A test tool for writing files".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# + .to_string(), + vector: vec![0.2; 1536], + extension_name: "test_extension".to_string(), + }, + ]; + + db.index_tools(test_tools).await?; + + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector.clone(), 2, None).await?; + + assert_eq!(results.len(), 2, "Should find both tools"); + assert_eq!( + results[0].tool_name, "test_tool_1", + "First result should be test_tool_1" + ); + assert_eq!( + results[1].tool_name, "test_tool_2", + "Second result should be test_tool_2" + ); - // Verify results - assert_eq!(results.len(), 2, "Should find both tools"); - assert_eq!( - results[0].tool_name, "test_tool_1", - "First result should be test_tool_1" - ); - assert_eq!( - results[1].tool_name, "test_tool_2", - "Second result should be test_tool_2" - ); + let results = db + .search_tools(query_vector.clone(), 2, Some("test_extension")) + .await?; + assert_eq!( + results.len(), + 2, + "Should find both tools with test_extension" + ); - // Test filtering by extension name - let results = db - .search_tools(query_vector.clone(), 2, Some("test_extension")) - .await?; - assert_eq!( - results.len(), - 2, - "Should find both tools with test_extension" - ); + let results = db + .search_tools(query_vector.clone(), 2, Some("nonexistent_extension")) + .await?; + assert_eq!( + results.len(), + 0, + "Should find no tools with nonexistent_extension" + ); - let results = db - .search_tools(query_vector.clone(), 2, Some("nonexistent_extension")) - .await?; - assert_eq!( - results.len(), - 0, - "Should find no tools with nonexistent_extension" - ); + Ok(()) + } + .await; - Ok(()) + cleanup.await; + result } #[tokio::test] #[serial_test::serial] async fn test_empty_db() -> Result<()> { - // Create a new database instance with a unique table name - let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?; + let (db, cleanup) = ToolVectorDB::new_test_db("test_empty_db").await?; - // Clear any existing tools - db.clear_tools().await?; + let result = async { + db.clear_tools().await?; - // Search in empty database - let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector, 2, None).await?; + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector, 2, None).await?; - // Verify no results returned - assert_eq!(results.len(), 0, "Empty database should return no results"); + assert_eq!(results.len(), 0, "Empty database should return no results"); + Ok(()) + } + .await; - Ok(()) + cleanup.await; + result } #[tokio::test] #[serial_test::serial] async fn test_tool_deletion() -> Result<()> { - // Create a new database instance with a unique table name - let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?; - - // Clear any existing tools - db.clear_tools().await?; - - // Create and index a test tool - let test_tool = ToolRecord { - tool_name: "test_tool_to_delete".to_string(), - description: "A test tool that will be deleted".to_string(), - schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(), - vector: vec![0.1; 1536], - extension_name: "test_extension".to_string(), - }; + let (db, cleanup) = ToolVectorDB::new_test_db("test_tool_deletion").await?; - db.index_tools(vec![test_tool]).await?; + let result = async { + db.clear_tools().await?; - // Verify tool exists - let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector.clone(), 1, None).await?; - assert_eq!(results.len(), 1, "Tool should exist before deletion"); + let test_tool = ToolRecord { + tool_name: "test_tool_to_delete".to_string(), + description: "A test tool that will be deleted".to_string(), + schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# + .to_string(), + vector: vec![0.1; 1536], + extension_name: "test_extension".to_string(), + }; - // Delete the tool - db.remove_tool("test_tool_to_delete").await?; + db.index_tools(vec![test_tool]).await?; - // Verify tool is gone - let results = db.search_tools(query_vector.clone(), 1, None).await?; - assert_eq!(results.len(), 0, "Tool should be deleted"); + let query_vector = vec![0.1; 1536]; + let results = db.search_tools(query_vector.clone(), 1, None).await?; + assert_eq!(results.len(), 1, "Tool should exist before deletion"); - Ok(()) + db.remove_tool("test_tool_to_delete").await?; + + let results = db.search_tools(query_vector.clone(), 1, None).await?; + assert_eq!(results.len(), 0, "Tool should be deleted"); + + Ok(()) + } + .await; + + cleanup.await; + result } #[test] @@ -521,20 +548,15 @@ mod tests { use std::env; use tempfile::TempDir; - // Create a temporary directory for testing let temp_dir = TempDir::new().unwrap(); let custom_path = temp_dir.path().join("custom_vector_db"); - // Set the environment variable env::set_var("GOOSE_VECTOR_DB_PATH", custom_path.to_str().unwrap()); - // Test that get_db_path returns the custom path let db_path = ToolVectorDB::get_db_path()?; assert_eq!(db_path, custom_path); - // Clean up env::remove_var("GOOSE_VECTOR_DB_PATH"); - Ok(()) } @@ -543,7 +565,6 @@ mod tests { fn test_custom_db_path_validation() { use std::env; - // Test that relative paths are rejected env::set_var("GOOSE_VECTOR_DB_PATH", "relative/path"); let result = ToolVectorDB::get_db_path(); @@ -557,7 +578,6 @@ mod tests { .to_string() .contains("must be an absolute path")); - // Clean up env::remove_var("GOOSE_VECTOR_DB_PATH"); } @@ -566,10 +586,8 @@ mod tests { fn test_fallback_to_default_path() -> Result<()> { use std::env; - // Ensure no custom path is set env::remove_var("GOOSE_VECTOR_DB_PATH"); - // Test that it falls back to default XDG path let db_path = ToolVectorDB::get_db_path()?; assert!( db_path.to_string_lossy().contains("goose"), diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index 52baff428b51..09aac54a593b 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -432,7 +432,6 @@ impl RecipeBuilder { #[cfg(test)] mod tests { use super::*; - use std::fs; #[test] fn test_from_content_with_json() {