diff --git a/Cargo.lock b/Cargo.lock index 86a2b4973fcd..fe9b34877273 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5394,6 +5394,7 @@ dependencies = [ "nix 0.30.1", "rand 0.8.5", "reqwest 0.11.27", + "rmcp", "serde", "serde_json", "serde_urlencoded", diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 52e01d3d8dcb..edd54ed0e260 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -36,9 +36,7 @@ use goose::providers::pricing::initialize_pricing_cache; use goose::session; use input::InputResult; use mcp_core::handler::ToolError; -use mcp_core::protocol::JsonRpcMessage; -use mcp_core::protocol::JsonRpcNotification; -use rmcp::model::PromptMessage; +use rmcp::model::{JsonRpcMessage, JsonRpcNotification, Notification, PromptMessage}; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; @@ -1024,10 +1022,11 @@ impl Session { } } Some(Ok(AgentEvent::McpNotification((_id, message)))) => { - if let JsonRpcMessage::Notification(JsonRpcNotification{ - method, - params: Some(Value::Object(o)), - .. + if let JsonRpcMessage::Notification( JsonRpcNotification { + notification: Notification { + method, + params: o,.. + },.. }) = message { match method.as_str() { "notifications/message" => { diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index c54ee567609b..6090f6d587fc 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -13,12 +13,12 @@ use std::os::unix::fs::PermissionsExt; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, - protocol::{JsonRpcMessage, ServerCapabilities}, + protocol::ServerCapabilities, tool::{Tool, ToolAnnotations}, }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; -use rmcp::model::{AnnotateAble, Content, Prompt, RawResource, Resource}; +use rmcp::model::{AnnotateAble, Content, JsonRpcMessage, Prompt, RawResource, Resource}; mod docx_tool; mod pdf_tool; diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 307a456065c6..3344905d9ad4 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -22,17 +22,20 @@ use tokio::{ use url::Url; use include_dir::{include_dir, Dir}; -use mcp_core::tool::ToolAnnotations; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, - protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities}, - tool::Tool, + protocol::ServerCapabilities, + tool::{Tool, ToolAnnotations}, }; + use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; -use rmcp::model::{Content, Prompt, PromptArgument, PromptTemplate, Resource}; -use rmcp::model::Role; +use rmcp::model::{ + Content, JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification, Prompt, + PromptArgument, PromptTemplate, Resource, Role, +}; +use rmcp::object; use self::editor_models::{create_editor_model, EditorModel}; use self::shell::{expand_path, get_shell_config, is_absolute_path, normalize_line_endings}; @@ -671,15 +674,18 @@ impl DeveloperRouter { let line = String::from_utf8_lossy(&stdout_buf); notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "shell", - "stream": "stdout", - "output": line.to_string(), - } - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": { + "type": "shell", + "stream": "stdout", + "output": line.to_string(), + } + }), + extensions: Default::default(), + } })).ok(); combined_output.push_str(&line); @@ -694,15 +700,18 @@ impl DeveloperRouter { let line = String::from_utf8_lossy(&stderr_buf); notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": { - "type": "shell", - "stream": "stderr", - "output": line.to_string(), - } - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": { + "type": "shell", + "stream": "stderr", + "output": line.to_string(), + } + }), + extensions: Default::default(), + } })).ok(); combined_output.push_str(&line); diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index e36273428f1c..4ad8aeb27c83 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -7,17 +7,7 @@ use base64::Engine; use chrono::NaiveDate; use indoc::indoc; use lazy_static::lazy_static; -use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::ToolAnnotations; -use oauth_pkce::PkceOAuth2Client; -use regex::Regex; -use rmcp::model::{AnnotateAble, Content, Prompt, RawResource, Resource}; -use serde_json::{json, Value}; -use std::io::Cursor; -use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; -use storage::CredentialsManager; -use tokio::sync::mpsc; - use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, protocol::ServerCapabilities, @@ -25,6 +15,14 @@ use mcp_core::{ }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; +use oauth_pkce::PkceOAuth2Client; +use regex::Regex; +use rmcp::model::{AnnotateAble, Content, JsonRpcMessage, Prompt, RawResource, Resource}; +use serde_json::{json, Value}; +use std::io::Cursor; +use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; +use storage::CredentialsManager; +use tokio::sync::mpsc; use google_docs1::{self, Docs}; use google_drive3::common::ReadSeek; diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 16b3122d80db..102e21a5b423 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -1,6 +1,15 @@ use async_trait::async_trait; use etcetera::{choose_app_strategy, AppStrategy}; use indoc::formatdoc; +use mcp_core::{ + handler::{PromptError, ResourceError, ToolError}, + protocol::ServerCapabilities, + tool::{Tool, ToolAnnotations, ToolCall}, +}; +use mcp_server::router::CapabilitiesBuilder; +use mcp_server::Router; +use rmcp::model::JsonRpcMessage; +use rmcp::model::{Content, Prompt, Resource}; use serde_json::{json, Value}; use std::{ collections::HashMap, @@ -12,15 +21,6 @@ use std::{ }; use tokio::sync::mpsc; -use mcp_core::{ - handler::{PromptError, ResourceError, ToolError}, - protocol::{JsonRpcMessage, ServerCapabilities}, - tool::{Tool, ToolAnnotations, ToolCall}, -}; -use mcp_server::router::CapabilitiesBuilder; -use mcp_server::Router; -use rmcp::model::{Content, Prompt, Resource}; - // MemoryRouter implementation #[derive(Clone)] pub struct MemoryRouter { diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index 0ce94b12a045..588242bd3096 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -1,18 +1,17 @@ use anyhow::Result; use include_dir::{include_dir, Dir}; use indoc::formatdoc; -use rmcp::model::{Content, Prompt, Resource, Role}; -use serde_json::{json, Value}; -use std::{future::Future, pin::Pin}; -use tokio::sync::mpsc; - use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, - protocol::{JsonRpcMessage, ServerCapabilities}, + protocol::ServerCapabilities, tool::{Tool, ToolAnnotations}, }; use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; +use rmcp::model::{Content, JsonRpcMessage, Prompt, Resource, Role}; +use serde_json::{json, Value}; +use std::{future::Future, pin::Pin}; +use tokio::sync::mpsc; static TUTORIALS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/tutorial/tutorials"); diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 36ceed48492d..24a1f7eb104d 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -18,8 +18,8 @@ use goose::{ permission::{Permission, PermissionConfirmation}, session, }; -use mcp_core::{protocol::JsonRpcMessage, ToolResult}; -use rmcp::model::Content; +use mcp_core::ToolResult; +use rmcp::model::{Content, JsonRpcMessage}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index ed755a7a2f7c..40fe5563eee0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -6,7 +6,6 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use futures::stream::BoxStream; use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; -use mcp_core::protocol::JsonRpcMessage; use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo}; use crate::agents::extension_manager::{get_parameter_names, ExtensionManager}; @@ -45,7 +44,7 @@ use crate::scheduler_trait::SchedulerTrait; use crate::tool_monitor::{ToolCall, ToolMonitor}; use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult}; use regex::Regex; -use rmcp::model::{Content, Prompt}; +use rmcp::model::{Content, JsonRpcMessage, Prompt}; use serde_json::Value; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio_util::sync::CancellationToken; @@ -775,7 +774,7 @@ impl Agent { let mcp_notifications = self.get_mcp_notifications().await; for notification in mcp_notifications { if let JsonRpcMessage::Notification(notif) = ¬ification { - if let Some(data) = notif.params.as_ref().and_then(|p| p.get("data")) { + if let Some(data) = notif.notification.params.get("data") { if let (Some(subagent_id), Some(_message)) = ( data.get("subagent_id").and_then(|v| v.as_str()), data.get("message").and_then(|v| v.as_str()), diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 6adac4009bb4..89753a5795ef 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -826,9 +826,10 @@ mod tests { use mcp_client::client::Error; use mcp_client::client::McpClientTrait; use mcp_core::protocol::{ - CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult, - ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, + ListToolsResult, ReadResourceResult, }; + use rmcp::model::JsonRpcMessage; use serde_json::json; use tokio::sync::mpsc; diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 030c787732b6..0992d47fd60a 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -6,10 +6,11 @@ use crate::{ }; use anyhow::anyhow; use chrono::{DateTime, Utc}; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; use mcp_core::{handler::ToolError, tool::Tool}; +use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification}; +use rmcp::object; use serde::{Deserialize, Serialize}; -use serde_json::{self, json}; +// use serde_json::{self}; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, instrument}; @@ -112,18 +113,21 @@ impl SubAgent { /// Send an MCP notification about the subagent's activity pub async fn send_mcp_notification(&self, notification_type: &str, message: &str) { let notification = JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "level": "info", - "logger": format!("subagent_{}", self.id), - "data": { - "subagent_id": self.id, - "type": notification_type, - "message": message, - "timestamp": Utc::now().to_rfc3339() - } - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "level": "info", + "logger": format!("subagent_{}", self.id), + "data": { + "subagent_id": self.id, + "type": notification_type, + "message": message, + "timestamp": Utc::now().to_rfc3339() + } + }), + extensions: Default::default(), + }, }); if let Err(e) = self.config.mcp_tx.send(notification).await { diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index 9a71ad4ad7f9..14726a75470c 100644 --- a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -7,10 +7,11 @@ use crate::agents::subagent_execution_tool::task_execution_tracker::{ use crate::agents::subagent_execution_tool::tasks::process_task; use crate::agents::subagent_execution_tool::workers::spawn_worker; use crate::agents::subagent_task_config::TaskConfig; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; use tokio::time::Instant; const EXECUTION_STATUS_COMPLETED: &str = "completed"; @@ -46,7 +47,7 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, - notifier: mpsc::Sender, + notifier: Sender, task_config: TaskConfig, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index faa2bcd5a578..d6f431ede629 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -6,14 +6,14 @@ use crate::agents::subagent_execution_tool::{ tasks_manager::TasksManager, }; use crate::agents::subagent_task_config::TaskConfig; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use serde_json::{json, Value}; -use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; pub async fn execute_tasks( input: Value, execution_mode: ExecutionMode, - notifier: mpsc::Sender, + notifier: Sender, task_config: TaskConfig, tasks_manager: &TasksManager, ) -> Result { diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index f3860253eb78..fc400dad326b 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -8,7 +8,7 @@ use crate::agents::{ subagent_execution_tool::task_types::ExecutionMode, subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use tokio::sync::mpsc; use tokio_stream; diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index c720459e01ae..dab102392799 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -1,5 +1,5 @@ -use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; -use serde_json::json; +use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification}; +use rmcp::object; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; @@ -12,6 +12,7 @@ use crate::agents::subagent_execution_tool::notification_events::{ use crate::agents::subagent_execution_tool::task_types::{Task, TaskInfo, TaskResult, TaskStatus}; use crate::agents::subagent_execution_tool::utils::{count_by_status, get_task_name}; use serde_json::Value; +use tokio::sync::mpsc::Sender; #[derive(Debug, Clone, PartialEq)] pub enum DisplayMode { @@ -67,7 +68,7 @@ impl TaskExecutionTracker { pub fn new( tasks: Vec, display_mode: DisplayMode, - notifier: mpsc::Sender, + notifier: Sender, ) -> Self { let task_map = tasks .into_iter() @@ -155,11 +156,14 @@ impl TaskExecutionTracker { if let Err(e) = self.notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": event.to_notification_data() - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": event.to_notification_data() + }), + extensions: Default::default(), + }, })) { tracing::warn!("Failed to send live output notification: {}", e); @@ -228,11 +232,14 @@ impl TaskExecutionTracker { if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": event.to_notification_data() - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": event.to_notification_data() + }), + extensions: Default::default(), + }, })) { tracing::warn!("Failed to send tasks update notification: {}", e); @@ -289,11 +296,14 @@ impl TaskExecutionTracker { if let Err(e) = self .notifier .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/message".to_string(), - params: Some(json!({ - "data": event.to_notification_data() - })), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": event.to_notification_data() + }), + extensions: Default::default(), + }, })) { tracing::warn!("Failed to send tasks complete notification: {}", e); diff --git a/crates/goose/src/agents/subagent_task_config.rs b/crates/goose/src/agents/subagent_task_config.rs index 261fb82b6f5f..282bc7a72ccf 100644 --- a/crates/goose/src/agents/subagent_task_config.rs +++ b/crates/goose/src/agents/subagent_task_config.rs @@ -1,6 +1,6 @@ use crate::agents::extension_manager::ExtensionManager; use crate::providers::base::Provider; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use std::fmt; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index ea997dfd5ebc..9af001fe7666 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_stream::try_stream; use futures::stream::{self, BoxStream}; use futures::{Stream, StreamExt}; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use tokio::sync::Mutex; use crate::config::permission::PermissionLevel; diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index a678e8f20643..92425e8d5216 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -11,6 +11,7 @@ mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.7", features = ["io"] } reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls-native-roots"] } +rmcp = { workspace = true } eventsource-client = "0.12.0" futures = "0.3" serde = { version = "1.0", features = ["derive"] } diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 5cade18d66ee..e80b18a75fe8 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,8 +1,12 @@ use mcp_core::protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, - JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, + CallToolResult, GetPromptResult, Implementation, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, }; + +use rmcp::model::{ + JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId, +}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::sync::{ @@ -112,7 +116,7 @@ where T: TransportHandle + Send + Sync + 'static, { service: Mutex>>, - next_id: AtomicU64, + next_id_counter: AtomicU64, // Added for atomic ID generation server_capabilities: Option, server_info: Option, notification_subscribers: Arc>>>, @@ -135,8 +139,14 @@ where Ok(message) => { tracing::info!("Received message: {:?}", message); match message { - JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) - | JsonRpcMessage::Error(JsonRpcError { id: Some(id), .. }) => { + JsonRpcMessage::Response(JsonRpcResponse { + id: NumberOrString::Number(id), + .. + }) + | JsonRpcMessage::Error(JsonRpcError { + id: NumberOrString::Number(id), + .. + }) => { service_ptr.respond(&id.to_string(), Ok(message)).await; } _ => { @@ -158,7 +168,7 @@ where Ok(Self { service: Mutex::new(middleware.layer(service)), - next_id: AtomicU64::new(1), + next_id_counter: AtomicU64::new(1), server_capabilities: None, server_info: None, notification_subscribers, @@ -172,7 +182,8 @@ where { let mut service = self.service.lock().await; service.ready().await.map_err(|_| Error::NotReady)?; - let id = self.next_id.fetch_add(1, Ordering::SeqCst); + let id_num = self.next_id_counter.fetch_add(1, Ordering::SeqCst); + let id = RequestId::Number(id_num as u32); let mut params = params.clone(); params["_meta"] = json!({ @@ -180,10 +191,13 @@ where }); let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(id), - method: method.to_string(), - params: Some(params), + jsonrpc: JsonRpcVersion2_0, + id, + request: Request { + method: method.to_string(), + params: params.as_object().unwrap().clone(), + extensions: Default::default(), + }, }); let response_msg = service @@ -201,35 +215,26 @@ where })?; match response_msg { - JsonRpcMessage::Response(JsonRpcResponse { - id, result, error, .. - }) => { - // Verify id matches - if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { + JsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { + // Verify id matches - convert current id to match expected format + let expected_id = RequestId::Number((id_num) as u32); + if id != expected_id { return Err(Error::UnexpectedResponse( "id mismatch for JsonRpcResponse".to_string(), )); } - if let Some(err) = error { - Err(Error::RpcError { - code: err.code, - message: err.message, - }) - } else if let Some(r) = result { - Ok(serde_json::from_value(r)?) - } else { - Err(Error::UnexpectedResponse("missing result".to_string())) - } + Ok(serde_json::from_value(serde_json::to_value(result)?)?) } JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { - if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { + let expected_id = RequestId::Number((id_num) as u32); + if id != expected_id { return Err(Error::UnexpectedResponse( "id mismatch for JsonRpcError".to_string(), )); } Err(Error::RpcError { - code: error.code, - message: error.message, + code: error.code.0, // Extract the i32 from ErrorCode + message: error.message.to_string(), // Convert Cow to String }) } _ => { @@ -247,9 +252,12 @@ where service.ready().await.map_err(|_| Error::NotReady)?; let notification = JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params: Some(params.clone()), + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: method.to_string(), + params: params.as_object().unwrap().clone(), + extensions: Default::default(), + }, }); service diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 12432c644a09..0bdc680c9f8d 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -1,5 +1,5 @@ use futures::future::BoxFuture; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use rmcp::model::{JsonRpcMessage, JsonRpcRequest}; use std::collections::HashMap; use std::sync::Arc; use std::task::{Context, Poll}; @@ -50,8 +50,8 @@ where let pending_requests = self.pending_requests.clone(); Box::pin(async move { - match request { - JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => { + match &request { + JsonRpcMessage::Request(JsonRpcRequest { id, .. }) => { // Create a channel to receive the response let (sender, receiver) = oneshot::channel(); pending_requests.insert(id.to_string(), sender).await; @@ -59,15 +59,17 @@ where transport.send(request).await?; receiver.await.map_err(|_| Error::ChannelClosed)? } - JsonRpcMessage::Request(_) => { - // Handle notifications without waiting for a response - transport.send(request).await?; - Ok(JsonRpcMessage::Nil) - } JsonRpcMessage::Notification(_) => { // Handle notifications without waiting for a response transport.send(request).await?; - Ok(JsonRpcMessage::Nil) + // Return a dummy response for notifications + let dummy_response: JsonRpcMessage = + JsonRpcMessage::Response(rmcp::model::JsonRpcResponse { + jsonrpc: rmcp::model::JsonRpcVersion2_0, + id: rmcp::model::RequestId::Number(0), + result: serde_json::Map::new(), + }); + Ok(dummy_response) } _ => Err(Error::UnsupportedMessage), } diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 76895d5126c2..1a4c8e1b933f 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 7a38aca9100c..56a00fcf8a64 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -2,8 +2,8 @@ use crate::transport::Error; use async_trait::async_trait; use eventsource_client::{Client, SSE}; use futures::TryStreamExt; -use mcp_core::protocol::JsonRpcMessage; use reqwest::Client as HttpClient; +use rmcp::model::JsonRpcMessage; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index afe10e8dc577..e721f0e5ef55 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; +use rmcp::model::JsonRpcMessage; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::sync::{mpsc, Mutex}; diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index 7b39218b25a1..8f3380da08f4 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -3,8 +3,8 @@ use crate::transport::Error; use async_trait::async_trait; use eventsource_client::{Client, SSE}; use futures::TryStreamExt; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; use reqwest::Client as HttpClient; +use rmcp::model::{JsonRpcMessage, JsonRpcRequest, NumberOrString::Number}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; @@ -89,7 +89,7 @@ impl StreamableHttpActor { let expects_response = matches!( parsed_message, - JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) + JsonRpcMessage::Request(JsonRpcRequest { id: Number(_), .. }) ); // Try to send the request diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 413d01d9d88b..159f08caf449 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -4,8 +4,10 @@ use std::{ }; use futures::{Future, Stream}; -use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcResponse}; use pin_project::pin_project; +use rmcp::model::{ + ErrorData, JsonRpcError, JsonRpcMessage, JsonRpcResponse, JsonRpcVersion2_0, RequestId, +}; use router::McpRequest; use tokio::{ io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, @@ -151,14 +153,11 @@ where Ok(msg) => { match msg { JsonRpcMessage::Request(request) => { - // Serialize request for logging - let id = request.id; let request_json = serde_json::to_string(&request) .unwrap_or_else(|_| "Failed to serialize request".to_string()); tracing::info!( - request_id = ?id, - method = ?request.method, + method = ?request.request.method, json = %request_json, "Received request" ); @@ -184,16 +183,11 @@ where Err(e) => { let error_msg = e.into().to_string(); tracing::error!(error = %error_msg, "Request processing failed"); - JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id, - result: None, - error: Some(mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INTERNAL_ERROR, - message: error_msg, - data: None, - }), - } + + // Return an error response instead of a regular response + return Err(ServerError::Transport(TransportError::Protocol( + error_msg, + ))); } }; @@ -226,39 +220,38 @@ where } JsonRpcMessage::Response(_) | JsonRpcMessage::Notification(_) - | JsonRpcMessage::Nil + | JsonRpcMessage::BatchRequest(_) + | JsonRpcMessage::BatchResponse(_) | JsonRpcMessage::Error(_) => { - // Ignore responses, notifications and nil messages for now + // Ignore responses, notifications, batch messages and error messages for now continue; } } } Err(e) => { // Convert transport error to JSON-RPC error response - let error = match e { - TransportError::Json(_) | TransportError::InvalidMessage(_) => { - mcp_core::protocol::ErrorData { - code: mcp_core::protocol::PARSE_ERROR, - message: e.to_string(), - data: None, - } - } - TransportError::Protocol(_) => mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INVALID_REQUEST, - message: e.to_string(), + let error_data = match e { + TransportError::Json(_) | TransportError::InvalidMessage(_) => ErrorData { + code: rmcp::model::ErrorCode::PARSE_ERROR, + message: e.to_string().into(), + data: None, + }, + TransportError::Protocol(_) => ErrorData { + code: rmcp::model::ErrorCode::INVALID_REQUEST, + message: e.to_string().into(), data: None, }, - _ => mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INTERNAL_ERROR, - message: e.to_string(), + _ => ErrorData { + code: rmcp::model::ErrorCode::INTERNAL_ERROR, + message: e.to_string().into(), data: None, }, }; let error_response = JsonRpcMessage::Error(JsonRpcError { - jsonrpc: "2.0".to_string(), - id: None, - error, + jsonrpc: JsonRpcVersion2_0, + id: RequestId::Number(0), // Use a default ID for transport errors + error: error_data, }); if let Err(e) = transport.write_message(error_response).await { diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index 32688cb035bf..a9757c7e8dd1 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -1,11 +1,10 @@ use anyhow::Result; use mcp_core::handler::{PromptError, ResourceError}; -use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::ToolAnnotations; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; -use rmcp::model::{Content, Prompt, PromptArgument, RawResource, Resource}; +use rmcp::model::{Content, JsonRpcMessage, Prompt, PromptArgument, RawResource, Resource}; use serde_json::Value; use std::{future::Future, pin::Pin, sync::Arc}; use tokio::sync::mpsc; diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index c2298e216fdf..4d6186f3eac7 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -5,17 +5,18 @@ use std::{ }; type PromptFuture = Pin> + Send + 'static>>; - use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcMessage, - JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, - PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities, - ToolsCapability, + CallToolResult, GetPromptResult, Implementation, InitializeResult, ListPromptsResult, + ListResourcesResult, ListToolsResult, PromptsCapability, ReadResourceResult, + ResourcesCapability, ServerCapabilities, ToolsCapability, }, }; -use rmcp::model::{Content, Prompt, PromptMessage, PromptMessageRole, Resource, ResourceContents}; +use rmcp::model::{ + Content, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcVersion2_0, Prompt, + PromptMessage, PromptMessageRole, RequestId, Resource, ResourceContents, +}; use serde_json::Value; use tokio::sync::mpsc; use tower_service::Service; @@ -101,15 +102,32 @@ pub trait Router: Send + Sync + 'static { fn get_prompt(&self, prompt_name: &str) -> PromptFuture; // Helper method to create base response - fn create_response(&self, id: Option) -> JsonRpcResponse { + fn create_response(&self, id: RequestId) -> JsonRpcResponse { JsonRpcResponse { - jsonrpc: "2.0".to_string(), + jsonrpc: JsonRpcVersion2_0, id, - result: None, - error: None, + result: serde_json::Map::new(), } } + // Helper method to set result on response + fn set_result( + &self, + response: &mut JsonRpcResponse, + result: T, + ) -> Result<(), RouterError> { + let value = serde_json::to_value(result) + .map_err(|e| RouterError::Internal(format!("JSON serialization error: {}", e)))?; + + if let Some(obj) = value.as_object() { + response.result = obj.clone(); + } else { + return Err(RouterError::Internal("Result must be a JSON object".into())); + } + + Ok(()) + } + fn handle_initialize( &self, req: JsonRpcRequest, @@ -126,11 +144,7 @@ pub trait Router: Send + Sync + 'static { }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -147,11 +161,7 @@ pub trait Router: Send + Sync + 'static { next_cursor: None, }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -162,9 +172,7 @@ pub trait Router: Send + Sync + 'static { notifier: mpsc::Sender, ) -> impl Future> + Send { async move { - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; + let params = &req.request.params; let name = params .get("name") @@ -185,11 +193,7 @@ pub trait Router: Send + Sync + 'static { }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -206,11 +210,7 @@ pub trait Router: Send + Sync + 'static { next_cursor: None, }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -220,9 +220,7 @@ pub trait Router: Send + Sync + 'static { req: JsonRpcRequest, ) -> impl Future> + Send { async move { - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; + let params = &req.request.params; let uri = params .get("uri") @@ -240,11 +238,7 @@ pub trait Router: Send + Sync + 'static { }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -259,11 +253,7 @@ pub trait Router: Send + Sync + 'static { let result = ListPromptsResult { prompts }; let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - + self.set_result(&mut response, result)?; Ok(response) } } @@ -274,9 +264,7 @@ pub trait Router: Send + Sync + 'static { ) -> impl Future> + Send { async move { // Validate and extract parameters - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; + let params = &req.request.params; // Extract "name" field let prompt_name = params @@ -381,13 +369,11 @@ pub trait Router: Send + Sync + 'static { // Build the final response let mut response = self.create_response(req.id); - response.result = Some( - serde_json::to_value(GetPromptResult { - description: Some(description_filled), - messages, - }) - .map_err(|e| RouterError::Internal(format!("JSON serialization error: {}", e)))?, - ); + let result = GetPromptResult { + description: Some(description_filled), + messages, + }; + self.set_result(&mut response, result)?; Ok(response) } } @@ -416,7 +402,7 @@ where let this = self.0.clone(); Box::pin(async move { - let result = match req.request.method.as_str() { + let result = match req.request.request.method.as_str() { "initialize" => this.handle_initialize(req.request).await, "tools/list" => this.handle_tools_list(req.request).await, "tools/call" => this.handle_tools_call(req.request, req.notifier).await, @@ -425,9 +411,9 @@ where "prompts/list" => this.handle_prompts_list(req.request).await, "prompts/get" => this.handle_prompts_get(req.request).await, _ => { - let mut response = this.create_response(req.request.id); - response.error = Some(RouterError::MethodNotFound(req.request.method).into()); - Ok(response) + return Err( + RouterError::MethodNotFound(req.request.request.method.clone()).into(), + ); } };