diff --git a/crates/goose-cli/src/commands/mcp.rs b/crates/goose-cli/src/commands/mcp.rs index 620bcdfe91b4..14cf395a699d 100644 --- a/crates/goose-cli/src/commands/mcp.rs +++ b/crates/goose-cli/src/commands/mcp.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use goose_mcp::{ - AutoVisualiserRouter, ComputerControllerRouter, DeveloperServer, MemoryRouter, TutorialServer, + AutoVisualiserRouter, ComputerControllerRouter, DeveloperServer, MemoryServer, TutorialServer, }; use mcp_server::router::RouterService; use mcp_server::{BoundedService, ByteTransport, Server}; @@ -65,10 +65,27 @@ pub async fn run_server(name: &str) -> Result<()> { return Ok(()); } + if name == "memory" { + let service = MemoryServer::new().serve(stdio()).await.inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + + service.waiting().await?; + return Ok(()); + } + + // Handle old MCP-based servers + if name == "memory" { + let service = MemoryServer::new().serve(stdio()).await.inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + service.waiting().await?; + return Ok(()); + } + // Handle old MCP-based servers let router: Option> = match name { "computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))), - "memory" => Some(Box::new(RouterService(MemoryRouter::new()))), _ => None, }; diff --git a/crates/goose-mcp/examples/mcp.rs b/crates/goose-mcp/examples/mcp.rs index 052e78572672..6e02ac88724f 100644 --- a/crates/goose-mcp/examples/mcp.rs +++ b/crates/goose-mcp/examples/mcp.rs @@ -1,9 +1,6 @@ // An example script to run an MCP server use anyhow::Result; -use goose_mcp::MemoryRouter; -use mcp_server::router::RouterService; -use mcp_server::{ByteTransport, Server}; -use tokio::io::{stdin, stdout}; +use goose_mcp::MemoryServer; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; @@ -24,13 +21,16 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); - // Create an instance of our counter router - let router = RouterService(MemoryRouter::new()); + // Create an instance of our memory server + let memory_server = MemoryServer::new(); - // Create and run the server - let server = Server::new(router); - let transport = ByteTransport::new(stdin(), stdout()); + // Run the server using rmcp + let transport = rmcp::transport::stdio(); tracing::info!("Server initialized and ready to handle requests"); - Ok(server.run(transport).await?) + let running_service = rmcp::service::serve_directly(memory_server, transport, None); + + // Wait for the service to complete + running_service.waiting().await?; + Ok(()) } diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index 329efe23e21a..2a369213ede4 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -16,5 +16,5 @@ pub mod tutorial; pub use autovisualiser::AutoVisualiserRouter; pub use computercontroller::ComputerControllerRouter; pub use developer::rmcp_developer::DeveloperServer; -pub use memory::MemoryRouter; +pub use memory::MemoryServer; pub use tutorial::TutorialServer; diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 2bd1d20a57b0..91109e797b54 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -1,128 +1,83 @@ -use async_trait::async_trait; use etcetera::{choose_app_strategy, AppStrategy}; use indoc::formatdoc; -use mcp_core::{ - handler::{PromptError, ResourceError}, - protocol::ServerCapabilities, - tool::ToolCall, +use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::{ + CallToolResult, Content, ErrorCode, ErrorData, Implementation, ServerCapabilities, + ServerInfo, + }, + schemars::JsonSchema, + tool, tool_handler, tool_router, ServerHandler, }; -use mcp_server::router::CapabilitiesBuilder; -use mcp_server::Router; -use rmcp::model::{ - Content, ErrorCode, ErrorData, JsonRpcMessage, Prompt, Resource, Tool, ToolAnnotations, -}; -use rmcp::object; -use serde_json::Value; +use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, fs, - future::Future, io::{self, Read, Write}, path::PathBuf, - pin::Pin, }; -use tokio::sync::mpsc; -// MemoryRouter implementation +/// Parameters for the remember_memory tool +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RememberMemoryParams { + /// The category to store the memory in + pub category: String, + /// The data to remember + pub data: String, + /// Optional tags for the memory + #[serde(default)] + pub tags: Vec, + /// Whether to store globally or locally + pub is_global: bool, +} + +/// Parameters for the retrieve_memories tool +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RetrieveMemoriesParams { + /// The category to retrieve memories from (use "*" for all) + pub category: String, + /// Whether to retrieve from global or local storage + pub is_global: bool, +} + +/// Parameters for the remove_memory_category tool +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RemoveMemoryCategoryParams { + /// The category to remove (use "*" for all) + pub category: String, + /// Whether to remove from global or local storage + pub is_global: bool, +} + +/// Parameters for the remove_specific_memory tool +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RemoveSpecificMemoryParams { + /// The category containing the memory + pub category: String, + /// The content of the memory to remove + pub memory_content: String, + /// Whether to remove from global or local storage + pub is_global: bool, +} + +/// Memory MCP Server using official RMCP SDK #[derive(Clone)] -pub struct MemoryRouter { - tools: Vec, +pub struct MemoryServer { + tool_router: ToolRouter, instructions: String, global_memory_dir: PathBuf, local_memory_dir: PathBuf, } -impl Default for MemoryRouter { +impl Default for MemoryServer { fn default() -> Self { Self::new() } } -impl MemoryRouter { +#[tool_router(router = tool_router)] +impl MemoryServer { pub fn new() -> Self { - let remember_memory = Tool::new( - "remember_memory", - "Stores a memory with optional tags in a specified category", - object!({ - "type": "object", - "properties": { - "category": {"type": "string"}, - "data": {"type": "string"}, - "tags": {"type": "array", "items": {"type": "string"}}, - "is_global": {"type": "boolean"} - }, - "required": ["category", "data", "is_global"] - }), - ) - .annotate(ToolAnnotations { - title: Some("Remember Memory".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - idempotent_hint: Some(true), - open_world_hint: Some(false), - }); - - let retrieve_memories = Tool::new( - "retrieve_memories", - "Retrieves all memories from a specified category", - object!({ - "type": "object", - "properties": { - "category": {"type": "string"}, - "is_global": {"type": "boolean"} - }, - "required": ["category", "is_global"] - }), - ) - .annotate(ToolAnnotations { - title: Some("Retrieve Memory".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - idempotent_hint: Some(false), - open_world_hint: Some(false), - }); - - let remove_memory_category = Tool::new( - "remove_memory_category", - "Removes all memories within a specified category", - object!({ - "type": "object", - "properties": { - "category": {"type": "string"}, - "is_global": {"type": "boolean"} - }, - "required": ["category", "is_global"] - }), - ) - .annotate(ToolAnnotations { - title: Some("Remove Memory Category".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(true), - idempotent_hint: Some(false), - open_world_hint: Some(false), - }); - - let remove_specific_memory = Tool::new( - "remove_specific_memory", - "Removes a specific memory within a specified category", - object!({ - "type": "object", - "properties": { - "category": {"type": "string"}, - "memory_content": {"type": "string"}, - "is_global": {"type": "boolean"} - }, - "required": ["category", "memory_content", "is_global"] - }), - ) - .annotate(ToolAnnotations { - title: Some("Remove Specific Memory".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(true), - idempotent_hint: Some(false), - open_world_hint: Some(false), - }); - let instructions = formatdoc! {r#" This extension allows storage and retrieval of categorized information with tagging support. It's designed to help manage important information across sessions in a systematic and organized manner. @@ -242,12 +197,7 @@ impl MemoryRouter { .unwrap_or_else(|_| PathBuf::from(".config/goose/memory")); let mut memory_router = Self { - tools: vec![ - remember_memory, - retrieve_memories, - remove_memory_category, - remove_specific_memory, - ], + tool_router: Self::tool_router(), instructions: instructions.clone(), global_memory_dir, local_memory_dir, @@ -405,7 +355,7 @@ impl MemoryRouter { Ok(memories) } - pub fn remove_specific_memory( + pub fn remove_specific_memory_internal( &self, category: &str, memory_content: &str, @@ -453,178 +403,134 @@ impl MemoryRouter { Ok(()) } - async fn execute_tool_call(&self, tool_call: ToolCall) -> Result { - match tool_call.name.as_str() { - "remember_memory" => { - let args = MemoryArgs::from_value(&tool_call.arguments)?; - let data = args.data.filter(|d| !d.is_empty()).ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Data must exist when remembering a memory", - ) - })?; - self.remember("context", args.category, data, &args.tags, args.is_global)?; - Ok(format!("Stored memory in category: {}", args.category)) - } - "retrieve_memories" => { - let args = MemoryArgs::from_value(&tool_call.arguments)?; - let memories = if args.category == "*" { - self.retrieve_all(args.is_global)? - } else { - self.retrieve(args.category, args.is_global)? - }; - Ok(format!("Retrieved memories: {:?}", memories)) - } - "remove_memory_category" => { - let args = MemoryArgs::from_value(&tool_call.arguments)?; - if args.category == "*" { - self.clear_all_global_or_local_memories(args.is_global)?; - Ok(format!( - "Cleared all memory {} categories", - if args.is_global { "global" } else { "local" } - )) - } else { - self.clear_memory(args.category, args.is_global)?; - Ok(format!("Cleared memories in category: {}", args.category)) - } - } - "remove_specific_memory" => { - let args = MemoryArgs::from_value(&tool_call.arguments)?; - let memory_content = tool_call.arguments["memory_content"].as_str().unwrap(); - self.remove_specific_memory(args.category, memory_content, args.is_global)?; - Ok(format!( - "Removed specific memory from category: {}", - args.category - )) - } - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Unknown tool")), + /// Stores a memory with optional tags in a specified category + #[tool( + name = "remember_memory", + description = "Stores a memory with optional tags in a specified category" + )] + pub async fn remember_memory( + &self, + params: Parameters, + ) -> Result { + let params = params.0; + + if params.data.is_empty() { + return Err(ErrorData::new( + ErrorCode::INVALID_PARAMS, + "Data must not be empty when remembering a memory".to_string(), + None, + )); } - } -} -#[async_trait] -impl Router for MemoryRouter { - fn name(&self) -> String { - "memory".to_string() - } + let tags: Vec<&str> = params.tags.iter().map(|s| s.as_str()).collect(); + self.remember( + "context", + ¶ms.category, + ¶ms.data, + &tags, + params.is_global, + ) + .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; - fn instructions(&self) -> String { - self.instructions.clone() + Ok(CallToolResult::success(vec![Content::text(format!( + "Stored memory in category: {}", + params.category + ))])) } - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new().with_tools(false).build() - } + /// Retrieves all memories from a specified category + #[tool( + name = "retrieve_memories", + description = "Retrieves all memories from a specified category" + )] + pub async fn retrieve_memories( + &self, + params: Parameters, + ) -> Result { + let params = params.0; - fn list_tools(&self) -> Vec { - self.tools.clone() + let memories = if params.category == "*" { + self.retrieve_all(params.is_global) + } else { + self.retrieve(¶ms.category, params.is_global) + } + .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; + + Ok(CallToolResult::success(vec![Content::text(format!( + "Retrieved memories: {:?}", + memories + ))])) } - fn call_tool( + /// Removes all memories within a specified category + #[tool( + name = "remove_memory_category", + description = "Removes all memories within a specified category" + )] + pub async fn remove_memory_category( &self, - tool_name: &str, - arguments: Value, - _notifier: mpsc::Sender, - ) -> Pin, ErrorData>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - - Box::pin(async move { - let tool_call = ToolCall { - name: tool_name, - arguments, - }; - match this.execute_tool_call(tool_call).await { - Ok(result) => Ok(vec![Content::text(result)]), - Err(err) => Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - err.to_string(), - None, - )), - } - }) - } + params: Parameters, + ) -> Result { + let params = params.0; + + let message = if params.category == "*" { + self.clear_all_global_or_local_memories(params.is_global) + .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; + format!( + "Cleared all memory {} categories", + if params.is_global { "global" } else { "local" } + ) + } else { + self.clear_memory(¶ms.category, params.is_global) + .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; + format!("Cleared memories in category: {}", params.category) + }; - fn list_resources(&self) -> Vec { - Vec::new() + Ok(CallToolResult::success(vec![Content::text(message)])) } - fn read_resource( + /// Removes a specific memory within a specified category + #[tool( + name = "remove_specific_memory", + description = "Removes a specific memory within a specified category" + )] + pub async fn remove_specific_memory( &self, - _uri: &str, - ) -> Pin> + Send + 'static>> { - Box::pin(async move { Ok("".to_string()) }) - } - fn list_prompts(&self) -> Vec { - vec![] - } + params: Parameters, + ) -> Result { + let params = params.0; + + self.remove_specific_memory_internal( + ¶ms.category, + ¶ms.memory_content, + params.is_global, + ) + .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))) - }) + Ok(CallToolResult::success(vec![Content::text(format!( + "Removed specific memory from category: {}", + params.category + ))])) } } -#[derive(Debug)] -struct MemoryArgs<'a> { - category: &'a str, - data: Option<&'a str>, - tags: Vec<&'a str>, - is_global: bool, -} - -impl<'a> MemoryArgs<'a> { - // Category is required, data is optional, tags are optional, is_global is optional - fn from_value(args: &'a Value) -> Result { - let category = args["category"].as_str().ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidInput, "Category must be a string") - })?; - - if category.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Category must be a string", - )); +#[tool_handler(router = self.tool_router)] +impl ServerHandler for MemoryServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + server_info: Implementation { + name: "goose-memory".to_string(), + version: env!("CARGO_PKG_VERSION").to_owned(), + }, + capabilities: ServerCapabilities::builder().enable_tools().build(), + instructions: Some(self.instructions.clone()), + ..Default::default() } - - let data = args.get("data").and_then(|d| d.as_str()); - - let tags = match &args["tags"] { - Value::Array(arr) => arr.iter().filter_map(|v| v.as_str()).collect(), - Value::String(s) => vec![s.as_str()], - _ => Vec::new(), - }; - - let is_global = match &args.get("is_global") { - // Default to false if no is_global flag is provided - Some(Value::Bool(b)) => *b, - Some(Value::String(s)) => s.to_lowercase() == "true", - None => false, - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "is_global must be a boolean or string 'true'/'false'", - )) - } - }; - - Ok(Self { - category, - data, - tags, - is_global, - }) } } +// Remove the old MemoryArgs struct since we're using the new parameter structs + #[cfg(test)] mod tests { use super::*; @@ -635,8 +541,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let memory_base = temp_dir.path().join("test_memory"); - let router = MemoryRouter { - tools: vec![], + let router = MemoryServer { + tool_router: ToolRouter::new(), instructions: String::new(), global_memory_dir: memory_base.join("global"), local_memory_dir: memory_base.join("local"), @@ -676,8 +582,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let memory_base = temp_dir.path().join("nonexistent_memory"); - let router = MemoryRouter { - tools: vec![], + let router = MemoryServer { + tool_router: ToolRouter::new(), instructions: String::new(), global_memory_dir: memory_base.join("global"), local_memory_dir: memory_base.join("local"), @@ -692,8 +598,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let memory_base = temp_dir.path().join("workflow_test"); - let router = MemoryRouter { - tools: vec![], + let router = MemoryServer { + tool_router: ToolRouter::new(), instructions: String::new(), global_memory_dir: memory_base.join("global"), local_memory_dir: memory_base.join("local"), @@ -729,8 +635,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let memory_base = temp_dir.path().join("write_test"); - let router = MemoryRouter { - tools: vec![], + let router = MemoryServer { + tool_router: ToolRouter::new(), instructions: String::new(), global_memory_dir: memory_base.join("global"), local_memory_dir: memory_base.join("local"), @@ -751,8 +657,8 @@ mod tests { let temp_dir = tempdir().unwrap(); let memory_base = temp_dir.path().join("remove_test"); - let router = MemoryRouter { - tools: vec![], + let router = MemoryServer { + tool_router: ToolRouter::new(), instructions: String::new(), global_memory_dir: memory_base.join("global"), local_memory_dir: memory_base.join("local"), @@ -769,7 +675,7 @@ mod tests { assert_eq!(memories.len(), 1); router - .remove_specific_memory("category", "remove_this", false) + .remove_specific_memory_internal("category", "remove_this", false) .unwrap(); let memories_after = router.retrieve("category", false).unwrap(); diff --git a/crates/goose-server/src/commands/mcp.rs b/crates/goose-server/src/commands/mcp.rs index 76443b93c58c..eba1428c4d30 100644 --- a/crates/goose-server/src/commands/mcp.rs +++ b/crates/goose-server/src/commands/mcp.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use goose_mcp::{ - AutoVisualiserRouter, ComputerControllerRouter, DeveloperServer, MemoryRouter, TutorialServer, + AutoVisualiserRouter, ComputerControllerRouter, DeveloperServer, MemoryServer, TutorialServer, }; use mcp_server::router::RouterService; use mcp_server::{BoundedService, ByteTransport, Server}; @@ -55,10 +55,17 @@ pub async fn run(name: &str) -> Result<()> { return Ok(()); } + if name == "memory" { + let service = MemoryServer::new().serve(stdio()).await.inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + service.waiting().await?; + return Ok(()); + } + // Handle old MCP-based servers let router: Option> = match name { "computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))), - "memory" => Some(Box::new(RouterService(MemoryRouter::new()))), _ => None, };