Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 128 additions & 110 deletions crates/goose/src/agents/tool_vectordb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,137 +382,164 @@ pub fn generate_table_id() -> String {
mod tests {
use super::*;

#[tokio::test]
#[serial_test::serial]
async fn test_tool_vectordb_creation() {
let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string()))
.await
.unwrap();
db.clear_tools().await.unwrap();
assert_eq!(db.table_name, "test_tools_vectordb_creation");
impl ToolVectorDB {
async fn new_test_db(
base_name: &str,
) -> Result<(Self, impl std::future::Future<Output = ()>)> {
let unique_name = format!("{}_{}", base_name, uuid::Uuid::new_v4().simple());
let db = Self::new(Some(unique_name)).await?;

let table_name = db.table_name.clone();
let connection = db.connection.clone();

let cleanup = async move {
let _ = async move {
let _ = connection.read().await.drop_table(&table_name).await;
};
};

Ok((db, cleanup))
}
}

#[tokio::test]
#[serial_test::serial]
async fn test_tool_vectordb_operations() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?;
async fn test_tool_vectordb_creation() -> Result<()> {
let (db, cleanup) = ToolVectorDB::new_test_db("test_tools_vectordb_creation").await?;

// Clear any existing tools
db.clear_tools().await?;

// Create test tool records
let test_tools = vec![
ToolRecord {
tool_name: "test_tool_1".to_string(),
description: "A test tool for reading files".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.1; 1536], // Mock embedding vector
extension_name: "test_extension".to_string(),
},
ToolRecord {
tool_name: "test_tool_2".to_string(),
description: "A test tool for writing files".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.2; 1536], // Different mock embedding vector
extension_name: "test_extension".to_string(),
},
];
let result = async {
db.clear_tools().await?;
assert!(db.table_name.contains("test_tools_vectordb_creation"));
Ok(())
}
.await;

// Index the test tools
db.index_tools(test_tools).await?;
cleanup.await;
result
}

// Search for tools using a query vector similar to test_tool_1
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 2, None).await?;
#[tokio::test]
#[serial_test::serial]
async fn test_tool_vectordb_operations() -> Result<()> {
let (db, cleanup) = ToolVectorDB::new_test_db("test_tool_vectordb_operations").await?;

let result = async {
db.clear_tools().await?;

let test_tools = vec![
ToolRecord {
tool_name: "test_tool_1".to_string(),
description: "A test tool for reading files".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.1; 1536],
extension_name: "test_extension".to_string(),
},
ToolRecord {
tool_name: "test_tool_2".to_string(),
description: "A test tool for writing files".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.2; 1536],
extension_name: "test_extension".to_string(),
},
];

db.index_tools(test_tools).await?;

let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 2, None).await?;

assert_eq!(results.len(), 2, "Should find both tools");
assert_eq!(
results[0].tool_name, "test_tool_1",
"First result should be test_tool_1"
);
assert_eq!(
results[1].tool_name, "test_tool_2",
"Second result should be test_tool_2"
);

// Verify results
assert_eq!(results.len(), 2, "Should find both tools");
assert_eq!(
results[0].tool_name, "test_tool_1",
"First result should be test_tool_1"
);
assert_eq!(
results[1].tool_name, "test_tool_2",
"Second result should be test_tool_2"
);
let results = db
.search_tools(query_vector.clone(), 2, Some("test_extension"))
.await?;
assert_eq!(
results.len(),
2,
"Should find both tools with test_extension"
);

// Test filtering by extension name
let results = db
.search_tools(query_vector.clone(), 2, Some("test_extension"))
.await?;
assert_eq!(
results.len(),
2,
"Should find both tools with test_extension"
);
let results = db
.search_tools(query_vector.clone(), 2, Some("nonexistent_extension"))
.await?;
assert_eq!(
results.len(),
0,
"Should find no tools with nonexistent_extension"
);

let results = db
.search_tools(query_vector.clone(), 2, Some("nonexistent_extension"))
.await?;
assert_eq!(
results.len(),
0,
"Should find no tools with nonexistent_extension"
);
Ok(())
}
.await;

Ok(())
cleanup.await;
result
}

#[tokio::test]
#[serial_test::serial]
async fn test_empty_db() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?;
let (db, cleanup) = ToolVectorDB::new_test_db("test_empty_db").await?;

// Clear any existing tools
db.clear_tools().await?;
let result = async {
db.clear_tools().await?;

// Search in empty database
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2, None).await?;
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2, None).await?;

// Verify no results returned
assert_eq!(results.len(), 0, "Empty database should return no results");
assert_eq!(results.len(), 0, "Empty database should return no results");
Ok(())
}
.await;

Ok(())
cleanup.await;
result
}

#[tokio::test]
#[serial_test::serial]
async fn test_tool_deletion() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?;

// Clear any existing tools
db.clear_tools().await?;

// Create and index a test tool
let test_tool = ToolRecord {
tool_name: "test_tool_to_delete".to_string(),
description: "A test tool that will be deleted".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(),
vector: vec![0.1; 1536],
extension_name: "test_extension".to_string(),
};
let (db, cleanup) = ToolVectorDB::new_test_db("test_tool_deletion").await?;

db.index_tools(vec![test_tool]).await?;
let result = async {
db.clear_tools().await?;

// Verify tool exists
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 1, "Tool should exist before deletion");
let test_tool = ToolRecord {
tool_name: "test_tool_to_delete".to_string(),
description: "A test tool that will be deleted".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.1; 1536],
extension_name: "test_extension".to_string(),
};

// Delete the tool
db.remove_tool("test_tool_to_delete").await?;
db.index_tools(vec![test_tool]).await?;

// Verify tool is gone
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 0, "Tool should be deleted");
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 1, "Tool should exist before deletion");

Ok(())
db.remove_tool("test_tool_to_delete").await?;

let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 0, "Tool should be deleted");

Ok(())
}
.await;

cleanup.await;
result
}

#[test]
Expand All @@ -521,20 +548,15 @@ mod tests {
use std::env;
use tempfile::TempDir;

// Create a temporary directory for testing
let temp_dir = TempDir::new().unwrap();
let custom_path = temp_dir.path().join("custom_vector_db");

// Set the environment variable
env::set_var("GOOSE_VECTOR_DB_PATH", custom_path.to_str().unwrap());

// Test that get_db_path returns the custom path
let db_path = ToolVectorDB::get_db_path()?;
assert_eq!(db_path, custom_path);

// Clean up
env::remove_var("GOOSE_VECTOR_DB_PATH");

Ok(())
}

Expand All @@ -543,7 +565,6 @@ mod tests {
fn test_custom_db_path_validation() {
use std::env;

// Test that relative paths are rejected
env::set_var("GOOSE_VECTOR_DB_PATH", "relative/path");

let result = ToolVectorDB::get_db_path();
Expand All @@ -557,7 +578,6 @@ mod tests {
.to_string()
.contains("must be an absolute path"));

// Clean up
env::remove_var("GOOSE_VECTOR_DB_PATH");
}

Expand All @@ -566,10 +586,8 @@ mod tests {
fn test_fallback_to_default_path() -> Result<()> {
use std::env;

// Ensure no custom path is set
env::remove_var("GOOSE_VECTOR_DB_PATH");

// Test that it falls back to default XDG path
let db_path = ToolVectorDB::get_db_path()?;
assert!(
db_path.to_string_lossy().contains("goose"),
Expand Down
1 change: 0 additions & 1 deletion crates/goose/src/recipe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ impl RecipeBuilder {
#[cfg(test)]
mod tests {
use super::*;
use std::fs;

#[test]
fn test_from_content_with_json() {
Expand Down
Loading