Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ impl<H: ClientHandler> Service<RoleClient> for H {
ServerNotification::PromptListChangedNotification(_notification_no_param) => {
self.on_prompt_list_changed(context).await
}
ServerNotification::CustomNotification(notification) => {
self.on_custom_notification(notification, context).await
}
};
Ok(())
}
Expand Down Expand Up @@ -166,6 +169,14 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_custom_notification(
&self,
notification: CustomNotification,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
let _ = (notification, context);
std::future::ready(())
}

fn get_info(&self) -> ClientInfo {
ClientInfo::default()
Expand Down
4 changes: 2 additions & 2 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<H: ServerHandler> Service<RoleServer> for H {
ClientNotification::RootsListChangedNotification(_notification) => {
self.on_roots_list_changed(context).await
}
ClientNotification::CustomClientNotification(notification) => {
ClientNotification::CustomNotification(notification) => {
self.on_custom_notification(notification, context).await
}
};
Expand Down Expand Up @@ -230,7 +230,7 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
}
fn on_custom_notification(
&self,
notification: CustomClientNotification,
notification: CustomNotification,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
let _ = (notification, context);
Expand Down
45 changes: 39 additions & 6 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,13 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled");
pub type CancelledNotification =
Notification<CancelledNotificationMethod, CancelledNotificationParam>;

/// A catch-all notification the client can use to send custom messages to a server.
/// A catch-all notification either side can use to send custom messages to its peer.
///
/// This preserves the raw `method` name and `params` payload so handlers can
/// deserialize them into domain-specific types.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CustomClientNotification {
pub struct CustomNotification {
pub method: String,
pub params: Option<Value>,
/// extensions will carry anything possible in the context, including [`Meta`]
Expand All @@ -643,7 +643,7 @@ pub struct CustomClientNotification {
pub extensions: Extensions,
}

impl CustomClientNotification {
impl CustomNotification {
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
Self {
method: method.into(),
Expand Down Expand Up @@ -1786,7 +1786,7 @@ ts_union!(
| ProgressNotification
| InitializedNotification
| RootsListChangedNotification
| CustomClientNotification;
| CustomNotification;
);

ts_union!(
Expand Down Expand Up @@ -1817,7 +1817,8 @@ ts_union!(
| ResourceUpdatedNotification
| ResourceListChangedNotification
| ToolListChangedNotification
| PromptListChangedNotification;
| PromptListChangedNotification
| CustomNotification;
);

ts_union!(
Expand Down Expand Up @@ -1907,7 +1908,7 @@ mod tests {
serde_json::from_value(raw.clone()).expect("invalid notification");
match &message {
ClientJsonRpcMessage::Notification(JsonRpcNotification {
notification: ClientNotification::CustomClientNotification(notification),
notification: ClientNotification::CustomNotification(notification),
..
}) => {
assert_eq!(notification.method, "notifications/custom");
Expand All @@ -1927,6 +1928,38 @@ mod tests {
assert_eq!(json, raw);
}

#[test]
fn test_custom_server_notification_roundtrip() {
let raw = json!( {
"jsonrpc": JsonRpcVersion2_0,
"method": "notifications/custom-server",
"params": {"hello": "world"},
});

let message: ServerJsonRpcMessage =
serde_json::from_value(raw.clone()).expect("invalid notification");
match &message {
ServerJsonRpcMessage::Notification(JsonRpcNotification {
notification: ServerNotification::CustomNotification(notification),
..
}) => {
assert_eq!(notification.method, "notifications/custom-server");
assert_eq!(
notification
.params
.as_ref()
.and_then(|p| p.get("hello"))
.expect("hello present"),
"world"
);
}
_ => panic!("Expected custom server notification"),
}

let json = serde_json::to_value(message).expect("valid json");
assert_eq!(json, raw);
}

#[test]
fn test_request_conversion() {
let raw = json!( {
Expand Down
11 changes: 6 additions & 5 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject,
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
ClientNotification, ClientRequest, CustomNotification, Extensions, JsonObject, JsonRpcMessage,
NumberOrString, ProgressToken, ServerNotification, ServerRequest,
};

pub trait GetMeta {
Expand All @@ -18,7 +18,7 @@ pub trait GetExtensions {
fn extensions_mut(&mut self) -> &mut Extensions;
}

impl GetExtensions for CustomClientNotification {
impl GetExtensions for CustomNotification {
fn extensions(&self) -> &Extensions {
&self.extensions
}
Expand All @@ -27,7 +27,7 @@ impl GetExtensions for CustomClientNotification {
}
}

impl GetMeta for CustomClientNotification {
impl GetMeta for CustomNotification {
fn get_meta_mut(&mut self) -> &mut Meta {
self.extensions_mut().get_or_insert_default()
}
Expand Down Expand Up @@ -104,7 +104,7 @@ variant_extension! {
ProgressNotification
InitializedNotification
RootsListChangedNotification
CustomClientNotification
CustomNotification
}
}

Expand All @@ -117,6 +117,7 @@ variant_extension! {
ResourceListChangedNotification
ToolListChangedNotification
PromptListChangedNotification
CustomNotification
}
}
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
Expand Down
8 changes: 4 additions & 4 deletions crates/rmcp/src/model/serde_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::borrow::Cow;
use serde::{Deserialize, Serialize};

use super::{
CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
CustomNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
RequestNoParam, RequestOptionalParam,
};
#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -249,7 +249,7 @@ where
}
}

impl Serialize for CustomClientNotification {
impl Serialize for CustomNotification {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
Expand Down Expand Up @@ -277,7 +277,7 @@ impl Serialize for CustomClientNotification {
}
}

impl<'de> Deserialize<'de> for CustomClientNotification {
impl<'de> Deserialize<'de> for CustomNotification {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
Expand All @@ -294,7 +294,7 @@ impl<'de> Deserialize<'de> for CustomClientNotification {
if let Some(meta) = _meta {
extensions.insert(meta);
}
Ok(CustomClientNotification {
Ok(CustomNotification {
extensions,
method: body.method,
params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@
"content"
]
},
"CustomClientNotification": {
"description": "A catch-all notification the client can use to send custom messages to a server.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"CustomNotification": {
"description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"type": "object",
"properties": {
"method": {
Expand Down Expand Up @@ -651,7 +651,7 @@
"$ref": "#/definitions/NotificationNoParam2"
},
{
"$ref": "#/definitions/CustomClientNotification"
"$ref": "#/definitions/CustomNotification"
}
],
"required": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@
"content"
]
},
"CustomClientNotification": {
"description": "A catch-all notification the client can use to send custom messages to a server.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"CustomNotification": {
"description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"type": "object",
"properties": {
"method": {
Expand Down Expand Up @@ -651,7 +651,7 @@
"$ref": "#/definitions/NotificationNoParam2"
},
{
"$ref": "#/definitions/CustomClientNotification"
"$ref": "#/definitions/CustomNotification"
}
],
"required": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@
"content"
]
},
"CustomNotification": {
"description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"type": "object",
"properties": {
"method": {
"type": "string"
},
"params": true
},
"required": [
"method"
]
},
"CancelledNotificationMethod": {
"type": "string",
"format": "const",
Expand Down Expand Up @@ -977,6 +990,9 @@
},
{
"$ref": "#/definitions/NotificationNoParam3"
},
{
"$ref": "#/definitions/CustomNotification"
}
],
"required": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@
"content"
]
},
"CustomNotification": {
"description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.",
"type": "object",
"properties": {
"method": {
"type": "string"
},
"params": true
},
"required": [
"method"
]
},
"CancelledNotificationMethod": {
"type": "string",
"format": "const",
Expand Down Expand Up @@ -977,6 +990,9 @@
},
{
"$ref": "#/definitions/NotificationNoParam3"
},
{
"$ref": "#/definitions/CustomNotification"
}
],
"required": [
Expand Down
Loading