diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index 293360e234c8..73e2f703aba5 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -12,6 +12,8 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; +use crate::config::Config; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolRecord { pub tool_name: String, @@ -53,7 +55,25 @@ impl ToolVectorDB { Ok(tool_db) } - fn get_db_path() -> Result { + pub fn get_db_path() -> Result { + let config = Config::global(); + + // Check for custom database path override + if let Ok(custom_path) = config.get_param::("GOOSE_VECTOR_DB_PATH") { + let path = PathBuf::from(custom_path); + + // Validate the path is absolute + if !path.is_absolute() { + return Err(anyhow::anyhow!( + "GOOSE_VECTOR_DB_PATH must be an absolute path, got: {}", + path.display() + )); + } + + return Ok(path); + } + + // Fall back to default XDG-based path let data_dir = Xdg::new() .context("Failed to determine base strategy")? .data_dir(); @@ -363,6 +383,7 @@ mod tests { use super::*; #[tokio::test] + #[serial_test::serial] async fn test_tool_vectordb_creation() { let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string())) .await @@ -372,6 +393,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_tool_vectordb_operations() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?; @@ -440,6 +462,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_empty_db() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?; @@ -458,6 +481,7 @@ mod tests { } #[tokio::test] + #[serial_test::serial] async fn test_tool_deletion() -> Result<()> { // Create a new database instance with a unique table name let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?; @@ -490,4 +514,74 @@ mod tests { Ok(()) } + + #[test] + #[serial_test::serial] + fn test_custom_db_path_override() -> Result<()> { + use std::env; + use tempfile::TempDir; + + // Create a temporary directory for testing + let temp_dir = TempDir::new().unwrap(); + let custom_path = temp_dir.path().join("custom_vector_db"); + + // Set the environment variable + env::set_var("GOOSE_VECTOR_DB_PATH", custom_path.to_str().unwrap()); + + // Test that get_db_path returns the custom path + let db_path = ToolVectorDB::get_db_path()?; + assert_eq!(db_path, custom_path); + + // Clean up + env::remove_var("GOOSE_VECTOR_DB_PATH"); + + Ok(()) + } + + #[test] + #[serial_test::serial] + fn test_custom_db_path_validation() { + use std::env; + + // Test that relative paths are rejected + env::set_var("GOOSE_VECTOR_DB_PATH", "relative/path"); + + let result = ToolVectorDB::get_db_path(); + assert!( + result.is_err(), + "Expected error for relative path, got: {:?}", + result + ); + assert!(result + .unwrap_err() + .to_string() + .contains("must be an absolute path")); + + // Clean up + env::remove_var("GOOSE_VECTOR_DB_PATH"); + } + + #[test] + #[serial_test::serial] + fn test_fallback_to_default_path() -> Result<()> { + use std::env; + + // Ensure no custom path is set + env::remove_var("GOOSE_VECTOR_DB_PATH"); + + // Test that it falls back to default XDG path + let db_path = ToolVectorDB::get_db_path()?; + assert!( + db_path.to_string_lossy().contains("goose"), + "Path should contain 'goose', got: {}", + db_path.display() + ); + assert!( + db_path.to_string_lossy().contains("tool_db"), + "Path should contain 'tool_db', got: {}", + db_path.display() + ); + + Ok(()) + } }