Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion crates/goose/src/agents/tool_vectordb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

use crate::config::Config;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolRecord {
pub tool_name: String,
Expand Down Expand Up @@ -53,7 +55,25 @@ impl ToolVectorDB {
Ok(tool_db)
}

fn get_db_path() -> Result<PathBuf> {
pub fn get_db_path() -> Result<PathBuf> {
let config = Config::global();

// Check for custom database path override
if let Ok(custom_path) = config.get_param::<String>("GOOSE_VECTOR_DB_PATH") {
let path = PathBuf::from(custom_path);

// Validate the path is absolute
if !path.is_absolute() {
return Err(anyhow::anyhow!(
"GOOSE_VECTOR_DB_PATH must be an absolute path, got: {}",
path.display()
));
}

return Ok(path);
}

// Fall back to default XDG-based path
let data_dir = Xdg::new()
.context("Failed to determine base strategy")?
.data_dir();
Expand Down Expand Up @@ -363,6 +383,7 @@ mod tests {
use super::*;

#[tokio::test]
#[serial_test::serial]
async fn test_tool_vectordb_creation() {
let db = ToolVectorDB::new(Some("test_tools_vectordb_creation".to_string()))
.await
Expand All @@ -372,6 +393,7 @@ mod tests {
}

#[tokio::test]
#[serial_test::serial]
async fn test_tool_vectordb_operations() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_tool_vectordb_operations".to_string())).await?;
Expand Down Expand Up @@ -440,6 +462,7 @@ mod tests {
}

#[tokio::test]
#[serial_test::serial]
async fn test_empty_db() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_empty_db".to_string())).await?;
Expand All @@ -458,6 +481,7 @@ mod tests {
}

#[tokio::test]
#[serial_test::serial]
async fn test_tool_deletion() -> Result<()> {
// Create a new database instance with a unique table name
let db = ToolVectorDB::new(Some("test_tool_deletion".to_string())).await?;
Expand Down Expand Up @@ -490,4 +514,74 @@ mod tests {

Ok(())
}

#[test]
#[serial_test::serial]
fn test_custom_db_path_override() -> Result<()> {
use std::env;
use tempfile::TempDir;

// Create a temporary directory for testing
let temp_dir = TempDir::new().unwrap();
let custom_path = temp_dir.path().join("custom_vector_db");

// Set the environment variable
env::set_var("GOOSE_VECTOR_DB_PATH", custom_path.to_str().unwrap());

// Test that get_db_path returns the custom path
let db_path = ToolVectorDB::get_db_path()?;
assert_eq!(db_path, custom_path);

// Clean up
env::remove_var("GOOSE_VECTOR_DB_PATH");

Ok(())
}

#[test]
#[serial_test::serial]
fn test_custom_db_path_validation() {
use std::env;

// Test that relative paths are rejected
env::set_var("GOOSE_VECTOR_DB_PATH", "relative/path");

let result = ToolVectorDB::get_db_path();
assert!(
result.is_err(),
"Expected error for relative path, got: {:?}",
result
);
assert!(result
.unwrap_err()
.to_string()
.contains("must be an absolute path"));

// Clean up
env::remove_var("GOOSE_VECTOR_DB_PATH");
}

#[test]
#[serial_test::serial]
fn test_fallback_to_default_path() -> Result<()> {
use std::env;

// Ensure no custom path is set
env::remove_var("GOOSE_VECTOR_DB_PATH");

// Test that it falls back to default XDG path
let db_path = ToolVectorDB::get_db_path()?;
assert!(
db_path.to_string_lossy().contains("goose"),
"Path should contain 'goose', got: {}",
db_path.display()
);
assert!(
db_path.to_string_lossy().contains("tool_db"),
"Path should contain 'tool_db', got: {}",
db_path.display()
);

Ok(())
}
}
Loading