From 5c11c743a92656059cc463e3f75be7a23da4c5a0 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Tue, 16 Dec 2025 08:32:51 -0800 Subject: [PATCH] feat: add support for custom server notifications https://github.com/modelcontextprotocol/rust-sdk/pull/556 introduced support for custom client notifications, so this PR makes the complementary change, adding support for custom server notifications. MCP clients, particularly ones that offer "experimental" capabilities, may wish to handle custom server notifications that are not part of the standard MCP specification. This change introduces a new `CustomServerNotification` type that allows a client to process such custom notifications. - introduces `CustomServerNotification` to carry arbitrary methods/params while still preserving meta/extensions; wires it into the `ServerNotification` union and `serde` so `params` can be decoded with `params_as` - allows client handlers to receive custom notifications via a new `on_custom_notification` hook - adds integration coverage that sends a custom server notification end-to-end and asserts the client sees the method and payload Test: ```shell cargo test -p rmcp --features client test_custom_server_notification_reaches_client ``` --- crates/rmcp/src/handler/client.rs | 11 +++ crates/rmcp/src/handler/server.rs | 4 +- crates/rmcp/src/model.rs | 45 +++++++-- crates/rmcp/src/model/meta.rs | 11 ++- crates/rmcp/src/model/serde_impl.rs | 8 +- .../client_json_rpc_message_schema.json | 6 +- ...lient_json_rpc_message_schema_current.json | 6 +- .../server_json_rpc_message_schema.json | 16 ++++ ...erver_json_rpc_message_schema_current.json | 16 ++++ crates/rmcp/tests/test_notification.rs | 92 ++++++++++++++++--- 10 files changed, 180 insertions(+), 35 deletions(-) diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 147f2fc29..d023c5d27 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -56,6 +56,9 @@ impl Service for H { ServerNotification::PromptListChangedNotification(_notification_no_param) => { self.on_prompt_list_changed(context).await } + ServerNotification::CustomNotification(notification) => { + self.on_custom_notification(notification, context).await + } }; Ok(()) } @@ -166,6 +169,14 @@ pub trait ClientHandler: Sized + Send + Sync + 'static { ) -> impl Future + Send + '_ { std::future::ready(()) } + fn on_custom_notification( + &self, + notification: CustomNotification, + context: NotificationContext, + ) -> impl Future + Send + '_ { + let _ = (notification, context); + std::future::ready(()) + } fn get_info(&self) -> ClientInfo { ClientInfo::default() diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index fd062dbde..2b55cacb5 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -90,7 +90,7 @@ impl Service for H { ClientNotification::RootsListChangedNotification(_notification) => { self.on_roots_list_changed(context).await } - ClientNotification::CustomClientNotification(notification) => { + ClientNotification::CustomNotification(notification) => { self.on_custom_notification(notification, context).await } }; @@ -230,7 +230,7 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { } fn on_custom_notification( &self, - notification: CustomClientNotification, + notification: CustomNotification, context: NotificationContext, ) -> impl Future + Send + '_ { let _ = (notification, context); diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 33b507da8..92eca08b5 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -627,13 +627,13 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled"); pub type CancelledNotification = Notification; -/// A catch-all notification the client can use to send custom messages to a server. +/// A catch-all notification either side can use to send custom messages to its peer. /// /// This preserves the raw `method` name and `params` payload so handlers can /// deserialize them into domain-specific types. #[derive(Debug, Clone)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -pub struct CustomClientNotification { +pub struct CustomNotification { pub method: String, pub params: Option, /// extensions will carry anything possible in the context, including [`Meta`] @@ -643,7 +643,7 @@ pub struct CustomClientNotification { pub extensions: Extensions, } -impl CustomClientNotification { +impl CustomNotification { pub fn new(method: impl Into, params: Option) -> Self { Self { method: method.into(), @@ -1786,7 +1786,7 @@ ts_union!( | ProgressNotification | InitializedNotification | RootsListChangedNotification - | CustomClientNotification; + | CustomNotification; ); ts_union!( @@ -1817,7 +1817,8 @@ ts_union!( | ResourceUpdatedNotification | ResourceListChangedNotification | ToolListChangedNotification - | PromptListChangedNotification; + | PromptListChangedNotification + | CustomNotification; ); ts_union!( @@ -1907,7 +1908,7 @@ mod tests { serde_json::from_value(raw.clone()).expect("invalid notification"); match &message { ClientJsonRpcMessage::Notification(JsonRpcNotification { - notification: ClientNotification::CustomClientNotification(notification), + notification: ClientNotification::CustomNotification(notification), .. }) => { assert_eq!(notification.method, "notifications/custom"); @@ -1927,6 +1928,38 @@ mod tests { assert_eq!(json, raw); } + #[test] + fn test_custom_server_notification_roundtrip() { + let raw = json!( { + "jsonrpc": JsonRpcVersion2_0, + "method": "notifications/custom-server", + "params": {"hello": "world"}, + }); + + let message: ServerJsonRpcMessage = + serde_json::from_value(raw.clone()).expect("invalid notification"); + match &message { + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: ServerNotification::CustomNotification(notification), + .. + }) => { + assert_eq!(notification.method, "notifications/custom-server"); + assert_eq!( + notification + .params + .as_ref() + .and_then(|p| p.get("hello")) + .expect("hello present"), + "world" + ); + } + _ => panic!("Expected custom server notification"), + } + + let json = serde_json::to_value(message).expect("valid json"); + assert_eq!(json, raw); + } + #[test] fn test_request_conversion() { let raw = json!( { diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index a03fc056b..1054d6728 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use super::{ - ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject, - JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest, + ClientNotification, ClientRequest, CustomNotification, Extensions, JsonObject, JsonRpcMessage, + NumberOrString, ProgressToken, ServerNotification, ServerRequest, }; pub trait GetMeta { @@ -18,7 +18,7 @@ pub trait GetExtensions { fn extensions_mut(&mut self) -> &mut Extensions; } -impl GetExtensions for CustomClientNotification { +impl GetExtensions for CustomNotification { fn extensions(&self) -> &Extensions { &self.extensions } @@ -27,7 +27,7 @@ impl GetExtensions for CustomClientNotification { } } -impl GetMeta for CustomClientNotification { +impl GetMeta for CustomNotification { fn get_meta_mut(&mut self) -> &mut Meta { self.extensions_mut().get_or_insert_default() } @@ -104,7 +104,7 @@ variant_extension! { ProgressNotification InitializedNotification RootsListChangedNotification - CustomClientNotification + CustomNotification } } @@ -117,6 +117,7 @@ variant_extension! { ResourceListChangedNotification ToolListChangedNotification PromptListChangedNotification + CustomNotification } } #[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] diff --git a/crates/rmcp/src/model/serde_impl.rs b/crates/rmcp/src/model/serde_impl.rs index 65e14361e..b43335f30 100644 --- a/crates/rmcp/src/model/serde_impl.rs +++ b/crates/rmcp/src/model/serde_impl.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use serde::{Deserialize, Serialize}; use super::{ - CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request, + CustomNotification, Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam, RequestOptionalParam, }; #[derive(Serialize, Deserialize)] @@ -249,7 +249,7 @@ where } } -impl Serialize for CustomClientNotification { +impl Serialize for CustomNotification { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -277,7 +277,7 @@ impl Serialize for CustomClientNotification { } } -impl<'de> Deserialize<'de> for CustomClientNotification { +impl<'de> Deserialize<'de> for CustomNotification { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, @@ -294,7 +294,7 @@ impl<'de> Deserialize<'de> for CustomClientNotification { if let Some(meta) = _meta { extensions.insert(meta); } - Ok(CustomClientNotification { + Ok(CustomNotification { extensions, method: body.method, params, diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json index 13df28e27..c1c3a73eb 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json @@ -396,8 +396,8 @@ "content" ] }, - "CustomClientNotification": { - "description": "A catch-all notification the client can use to send custom messages to a server.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", + "CustomNotification": { + "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", "type": "object", "properties": { "method": { @@ -651,7 +651,7 @@ "$ref": "#/definitions/NotificationNoParam2" }, { - "$ref": "#/definitions/CustomClientNotification" + "$ref": "#/definitions/CustomNotification" } ], "required": [ diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json index 13df28e27..c1c3a73eb 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json @@ -396,8 +396,8 @@ "content" ] }, - "CustomClientNotification": { - "description": "A catch-all notification the client can use to send custom messages to a server.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", + "CustomNotification": { + "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", "type": "object", "properties": { "method": { @@ -651,7 +651,7 @@ "$ref": "#/definitions/NotificationNoParam2" }, { - "$ref": "#/definitions/CustomClientNotification" + "$ref": "#/definitions/CustomNotification" } ], "required": [ diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json index abee96512..dcc3086fa 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json @@ -392,6 +392,19 @@ "content" ] }, + "CustomNotification": { + "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", + "type": "object", + "properties": { + "method": { + "type": "string" + }, + "params": true + }, + "required": [ + "method" + ] + }, "CancelledNotificationMethod": { "type": "string", "format": "const", @@ -977,6 +990,9 @@ }, { "$ref": "#/definitions/NotificationNoParam3" + }, + { + "$ref": "#/definitions/CustomNotification" } ], "required": [ diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json index abee96512..dcc3086fa 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json @@ -392,6 +392,19 @@ "content" ] }, + "CustomNotification": { + "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", + "type": "object", + "properties": { + "method": { + "type": "string" + }, + "params": true + }, + "required": [ + "method" + ] + }, "CancelledNotificationMethod": { "type": "string", "format": "const", @@ -977,6 +990,9 @@ }, { "$ref": "#/definitions/NotificationNoParam3" + }, + { + "$ref": "#/definitions/CustomNotification" } ], "required": [ diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index 408b7c850..cce04364d 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, model::{ - ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam, - ServerCapabilities, ServerInfo, SubscribeRequestParam, + ClientNotification, CustomNotification, ResourceUpdatedNotificationParam, + ServerCapabilities, ServerInfo, ServerNotification, SubscribeRequestParam, }, }; use serde_json::json; @@ -106,12 +106,11 @@ struct CustomServer { impl ServerHandler for CustomServer { async fn on_custom_notification( &self, - notification: CustomClientNotification, + notification: CustomNotification, _context: rmcp::service::NotificationContext, ) { - let CustomClientNotification { method, params, .. } = notification; - let mut payload = self.payload.lock().await; - *payload = Some((method, params)); + let CustomNotification { method, params, .. } = notification; + *self.payload.lock().await = Some((method, params)); self.receive_signal.notify_one(); } } @@ -148,20 +147,89 @@ async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> let client = ().serve(client_transport).await?; client - .send_notification(ClientNotification::CustomClientNotification( - CustomClientNotification::new( - "notifications/custom-test", - Some(json!({ "foo": "bar" })), - ), + .send_notification(ClientNotification::CustomNotification( + CustomNotification::new("notifications/custom-test", Some(json!({ "foo": "bar" }))), )) .await?; tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; - let (method, params) = payload.lock().await.clone().expect("payload set"); + let (method, params) = payload.lock().await.take().expect("payload set"); assert_eq!("notifications/custom-test", method); assert_eq!(Some(json!({ "foo": "bar" })), params); client.cancel().await?; Ok(()) } + +struct CustomServerNotifier; + +impl ServerHandler for CustomServerNotifier { + async fn on_initialized(&self, context: rmcp::service::NotificationContext) { + let peer = context.peer.clone(); + tokio::spawn(async move { + peer.send_notification(ServerNotification::CustomNotification( + CustomNotification::new( + "notifications/custom-test", + Some(json!({ "hello": "world" })), + ), + )) + .await + .expect("send custom notification"); + }); + } +} + +struct CustomClient { + receive_signal: Arc, + payload: Arc>>, +} + +impl ClientHandler for CustomClient { + async fn on_custom_notification( + &self, + notification: CustomNotification, + _context: rmcp::service::NotificationContext, + ) { + let CustomNotification { method, params, .. } = notification; + *self.payload.lock().await = Some((method, params)); + self.receive_signal.notify_one(); + } +} + +#[tokio::test] +async fn test_custom_server_notification_reaches_client() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + tokio::spawn(async move { + let server = CustomServerNotifier {}.serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + let receive_signal = Arc::new(Notify::new()); + let payload = Arc::new(Mutex::new(None)); + + let client = CustomClient { + receive_signal: receive_signal.clone(), + payload: payload.clone(), + } + .serve(client_transport) + .await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; + + let (method, params) = payload.lock().await.take().expect("payload set"); + assert_eq!("notifications/custom-test", method); + assert_eq!(Some(json!({ "hello": "world" })), params); + + client.cancel().await?; + Ok(()) +}