diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 8c814478f2b1..eb5adaeeaff4 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -241,9 +241,6 @@ impl MemoryRouter { .map(|strategy| strategy.in_config_dir("memory")) .unwrap_or_else(|_| PathBuf::from(".config/goose/memory")); - fs::create_dir_all(&global_memory_dir).unwrap(); - fs::create_dir_all(&local_memory_dir).unwrap(); - let mut memory_router = Self { tools: vec![ remember_memory, @@ -353,6 +350,10 @@ impl MemoryRouter { ) -> io::Result<()> { let memory_file_path = self.get_memory_file(category, is_global); + if let Some(parent) = memory_file_path.parent() { + fs::create_dir_all(parent)?; + } + let mut file = fs::OpenOptions::new() .append(true) .create(true) @@ -446,7 +447,9 @@ impl MemoryRouter { } else { &self.local_memory_dir }; - fs::remove_dir_all(base_dir)?; + if base_dir.exists() { + fs::remove_dir_all(base_dir)?; + } Ok(()) } @@ -617,3 +620,163 @@ impl<'a> MemoryArgs<'a> { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_lazy_directory_creation() { + let temp_dir = tempdir().unwrap(); + let memory_base = temp_dir.path().join("test_memory"); + + let router = MemoryRouter { + tools: vec![], + instructions: String::new(), + global_memory_dir: memory_base.join("global"), + local_memory_dir: memory_base.join("local"), + }; + + assert!(!router.global_memory_dir.exists()); + assert!(!router.local_memory_dir.exists()); + + router + .remember( + "test_context", + "test_category", + "test_data", + &["tag1"], + false, + ) + .unwrap(); + + assert!(router.local_memory_dir.exists()); + assert!(!router.global_memory_dir.exists()); + + router + .remember( + "test_context", + "global_category", + "global_data", + &["global_tag"], + true, + ) + .unwrap(); + + assert!(router.global_memory_dir.exists()); + } + + #[test] + fn test_clear_nonexistent_directories() { + let temp_dir = tempdir().unwrap(); + let memory_base = temp_dir.path().join("nonexistent_memory"); + + let router = MemoryRouter { + tools: vec![], + instructions: String::new(), + global_memory_dir: memory_base.join("global"), + local_memory_dir: memory_base.join("local"), + }; + + assert!(router.clear_all_global_or_local_memories(false).is_ok()); + assert!(router.clear_all_global_or_local_memories(true).is_ok()); + } + + #[test] + fn test_remember_retrieve_clear_workflow() { + let temp_dir = tempdir().unwrap(); + let memory_base = temp_dir.path().join("workflow_test"); + + let router = MemoryRouter { + tools: vec![], + instructions: String::new(), + global_memory_dir: memory_base.join("global"), + local_memory_dir: memory_base.join("local"), + }; + + router + .remember( + "context", + "test_category", + "test_data_content", + &["test_tag"], + false, + ) + .unwrap(); + + let memories = router.retrieve("test_category", false).unwrap(); + assert!(!memories.is_empty()); + + let has_content = memories.values().any(|v| { + v.iter() + .any(|content| content.contains("test_data_content")) + }); + assert!(has_content); + + router.clear_memory("test_category", false).unwrap(); + + let memories_after_clear = router.retrieve("test_category", false).unwrap(); + assert!(memories_after_clear.is_empty()); + } + + #[test] + fn test_directory_creation_on_write() { + let temp_dir = tempdir().unwrap(); + let memory_base = temp_dir.path().join("write_test"); + + let router = MemoryRouter { + tools: vec![], + instructions: String::new(), + global_memory_dir: memory_base.join("global"), + local_memory_dir: memory_base.join("local"), + }; + + assert!(!router.local_memory_dir.exists()); + + router + .remember("context", "category", "data", &[], false) + .unwrap(); + + assert!(router.local_memory_dir.exists()); + assert!(router.local_memory_dir.join("category.txt").exists()); + } + + #[test] + fn test_remove_specific_memory() { + let temp_dir = tempdir().unwrap(); + let memory_base = temp_dir.path().join("remove_test"); + + let router = MemoryRouter { + tools: vec![], + instructions: String::new(), + global_memory_dir: memory_base.join("global"), + local_memory_dir: memory_base.join("local"), + }; + + router + .remember("context", "category", "keep_this", &[], false) + .unwrap(); + router + .remember("context", "category", "remove_this", &[], false) + .unwrap(); + + let memories = router.retrieve("category", false).unwrap(); + assert_eq!(memories.len(), 1); + + router + .remove_specific_memory("category", "remove_this", false) + .unwrap(); + + let memories_after = router.retrieve("category", false).unwrap(); + let has_removed = memories_after + .values() + .any(|v| v.iter().any(|content| content.contains("remove_this"))); + assert!(!has_removed); + + let has_kept = memories_after + .values() + .any(|v| v.iter().any(|content| content.contains("keep_this"))); + assert!(has_kept); + } +}