From 4ea231ae0ec3024d8ec6501b5d15487e35706de0 Mon Sep 17 00:00:00 2001 From: joulei Date: Fri, 15 Aug 2025 00:59:10 -0300 Subject: [PATCH 1/3] fix: add reqwest dependency to transport-streamable-http-client feature - Fix compilation error when using transport-streamable-http-client feature due to missing dependency - Move From implementation to reqwest module --- crates/rmcp/Cargo.toml | 2 +- .../src/transport/common/reqwest/streamable_http_client.rs | 6 ++++++ crates/rmcp/src/transport/streamable_http_client.rs | 5 ----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 8e9b48e9b..c36a815e6 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -107,7 +107,7 @@ transport-worker = ["dep:tokio-stream"] # Streamable HTTP client -transport-streamable-http-client = ["client-side-sse", "transport-worker"] +transport-streamable-http-client = ["client-side-sse", "transport-worker", "reqwest"] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index fd3aa1d54..5af907ef8 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -14,6 +14,12 @@ use crate::{ }, }; +impl From for StreamableHttpError { + fn from(e: reqwest::Error) -> Self { + StreamableHttpError::Client(e) + } +} + impl StreamableHttpClient for reqwest::Client { type Error = reqwest::Error; diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 20159a01f..42446a9c5 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -48,11 +48,6 @@ pub enum StreamableHttpError { Auth(#[from] crate::transport::auth::AuthError), } -impl From for StreamableHttpError { - fn from(e: reqwest::Error) -> Self { - StreamableHttpError::Client(e) - } -} pub enum StreamableHttpPostResponse { Accepted, From ce2ac8ba27c7340d22b9b1b775273dcccf823715 Mon Sep 17 00:00:00 2001 From: joulei Date: Mon, 18 Aug 2025 12:19:55 -0300 Subject: [PATCH 2/3] feat(rmcp): enhance transport features by decoupling reqwest - Added reqwest features for reqwest-based implementations. - Updated documentation - Modified error handling in SSE transport to use `String` for content type. - Updated examples to include new features --- crates/rmcp/Cargo.toml | 4 +- crates/rmcp/README.md | 6 +- crates/rmcp/src/transport/common/reqwest.rs | 8 +- .../transport/common/reqwest/sse_client.rs | 37 +++- .../common/reqwest/streamable_http_client.rs | 24 +++ crates/rmcp/src/transport/sse_client.rs | 91 +++++++++- .../src/transport/streamable_http_client.rs | 162 +++++++++++++++++- examples/clients/Cargo.toml | 4 +- examples/rig-integration/Cargo.toml | 5 +- examples/simple-chat-client/Cargo.toml | 5 +- 10 files changed, 321 insertions(+), 25 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index c36a815e6..7b2ceaee3 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -102,12 +102,14 @@ server-side-http = [ client-side-sse = ["dep:sse-stream", "dep:http"] transport-sse-client = ["client-side-sse", "transport-worker"] +transport-sse-client-reqwest = ["transport-sse-client", "reqwest"] transport-worker = ["dep:tokio-stream"] # Streamable HTTP client -transport-streamable-http-client = ["client-side-sse", "transport-worker", "reqwest"] +transport-streamable-http-client = ["client-side-sse", "transport-worker"] +transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "reqwest"] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 578a228a5..1ce43196c 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -199,8 +199,10 @@ RMCP uses feature flags to control which components are included: - `transport-async-rw`: Async read/write support - `transport-io`: I/O stream support - `transport-child-process`: Child process support - - `transport-sse-client` / `transport-sse-server`: SSE support - - `transport-streamable-http-client` / `transport-streamable-http-server`: HTTP streaming + - `transport-sse-client` / `transport-sse-server`: SSE support (client agnostic) + - `transport-sse-client-reqwest`: a default `reqwest` implementation of the SSE client + - `transport-streamable-http-client` / `transport-streamable-http-server`: HTTP streaming (client agnostic, see [`StreamableHttpClientTransport`] for details) + - `transport-streamable-http-client-reqwest`: a default `reqwest` implementation of the streamable http client - `auth`: OAuth2 authentication support - `schemars`: JSON Schema generation (for tool definitions) diff --git a/crates/rmcp/src/transport/common/reqwest.rs b/crates/rmcp/src/transport/common/reqwest.rs index 5395d571e..4f9dc0dc5 100644 --- a/crates/rmcp/src/transport/common/reqwest.rs +++ b/crates/rmcp/src/transport/common/reqwest.rs @@ -1,7 +1,7 @@ -#[cfg(feature = "transport-streamable-http-client")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client-reqwest")))] mod streamable_http_client; -#[cfg(feature = "transport-sse-client")] -#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client")))] +#[cfg(feature = "transport-sse-client-reqwest")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-sse-client-reqwest")))] mod sse_client; diff --git a/crates/rmcp/src/transport/common/reqwest/sse_client.rs b/crates/rmcp/src/transport/common/reqwest/sse_client.rs index 37fe78417..a5362d79c 100644 --- a/crates/rmcp/src/transport/common/reqwest/sse_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/sse_client.rs @@ -11,6 +11,12 @@ use crate::transport::{ sse_client::{SseClient, SseClientConfig, SseTransportError}, }; +impl From for SseTransportError { + fn from(e: reqwest::Error) -> Self { + SseTransportError::Client(e) + } +} + impl SseClient for reqwest::Client { type Error = reqwest::Error; @@ -55,7 +61,9 @@ impl SseClient for reqwest::Client { match response.headers().get(reqwest::header::CONTENT_TYPE) { Some(ct) => { if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) { - return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); + return Err(SseTransportError::UnexpectedContentType(Some( + String::from_utf8_lossy(ct.as_bytes()).to_string(), + ))); } } None => { @@ -68,6 +76,33 @@ impl SseClient for reqwest::Client { } impl SseClientTransport { + /// Creates a new transport using reqwest with the specified SSE endpoint. + /// + /// This is a convenience method that creates a transport using the default + /// reqwest client. This method is only available when the + /// `transport-sse-client-reqwest` feature is enabled. + /// + /// # Arguments + /// + /// * `uri` - The SSE endpoint to connect to + /// + /// # Example + /// + /// ```rust + /// use rmcp::transport::SseClientTransport; + /// + /// // Enable the reqwest feature in Cargo.toml: + /// // rmcp = { version = "0.5", features = ["transport-sse-client-reqwest"] } + /// + /// # async fn example() -> Result<(), Box> { + /// let transport = SseClientTransport::start("http://localhost:8000/sse").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// # Feature requirement + /// + /// This method requires the `transport-sse-client-reqwest` feature. pub async fn start( uri: impl Into>, ) -> Result> { diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index 5af907ef8..4f69a6a6c 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -131,6 +131,30 @@ impl StreamableHttpClient for reqwest::Client { } impl StreamableHttpClientTransport { + /// Creates a new transport using reqwest with the specified URI. + /// + /// This is a convenience method that creates a transport using the default + /// reqwest client. This method is only available when the + /// `transport-streamable-http-client-reqwest` feature is enabled. + /// + /// # Arguments + /// + /// * `uri` - The server URI to connect to + /// + /// # Example + /// + /// ```rust,no_run + /// use rmcp::transport::StreamableHttpClientTransport; + /// + /// // Enable the reqwest feature in Cargo.toml: + /// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] } + /// + /// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); + /// ``` + /// + /// # Feature requirement + /// + /// This method requires the `transport-streamable-http-client-reqwest` feature. pub fn from_uri(uri: impl Into>) -> Self { StreamableHttpClientTransport::with_client( reqwest::Client::default(), diff --git a/crates/rmcp/src/transport/sse_client.rs b/crates/rmcp/src/transport/sse_client.rs index 7b6f280c9..910df5f58 100644 --- a/crates/rmcp/src/transport/sse_client.rs +++ b/crates/rmcp/src/transport/sse_client.rs @@ -3,7 +3,7 @@ use std::{pin::Pin, sync::Arc}; use futures::{StreamExt, future::BoxFuture}; use http::Uri; -use reqwest::header::HeaderValue; + use sse_stream::Error as SseError; use thiserror::Error; @@ -28,7 +28,7 @@ pub enum SseTransportError { #[error("unexpected end of stream")] UnexpectedEndOfStream, #[error("Unexpected content type: {0:?}")] - UnexpectedContentType(Option), + UnexpectedContentType(Option), #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] #[error("Auth error: {0}")] @@ -39,12 +39,6 @@ pub enum SseTransportError { InvalidUriParts(#[from] http::uri::InvalidUriParts), } -impl From for SseTransportError { - fn from(e: reqwest::Error) -> Self { - SseTransportError::Client(e) - } -} - pub trait SseClient: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + Sync + 'static; fn post_message( @@ -77,6 +71,87 @@ impl SseStreamReconnect for SseClientReconnect { } } type ServerMessageStream = Pin>>>; + +/// A client-agnostic SSE transport for RMCP that supports Server-Sent Events. +/// +/// This transport allows you to choose your preferred HTTP client implementation +/// by implementing the [`SseClient`] trait. The transport handles SSE streaming +/// and automatic reconnection. +/// +/// # Usage +/// +/// ## Using reqwest +/// +/// ```rust +/// use rmcp::transport::SseClientTransport; +/// +/// // Enable the reqwest feature in Cargo.toml: +/// // rmcp = { version = "0.5", features = ["transport-sse-client-reqwest"] } +/// +/// # async fn example() -> Result<(), Box> { +/// let transport = SseClientTransport::start("http://localhost:8000/sse").await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Using a custom HTTP client +/// +/// ```rust +/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig}; +/// use std::sync::Arc; +/// use futures::stream::BoxStream; +/// use rmcp::model::ClientJsonRpcMessage; +/// use sse_stream::{Sse, Error as SseError}; +/// use http::Uri; +/// +/// #[derive(Clone)] +/// struct MyHttpClient; +/// +/// #[derive(Debug, thiserror::Error)] +/// struct MyError; +/// +/// impl std::fmt::Display for MyError { +/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +/// write!(f, "MyError") +/// } +/// } +/// +/// impl SseClient for MyHttpClient { +/// type Error = MyError; +/// +/// async fn post_message( +/// &self, +/// _uri: Uri, +/// _message: ClientJsonRpcMessage, +/// _auth_token: Option, +/// ) -> Result<(), rmcp::transport::sse_client::SseTransportError> { +/// todo!() +/// } +/// +/// async fn get_stream( +/// &self, +/// _uri: Uri, +/// _last_event_id: Option, +/// _auth_token: Option, +/// ) -> Result>, rmcp::transport::sse_client::SseTransportError> { +/// todo!() +/// } +/// } +/// +/// # async fn example() -> Result<(), Box> { +/// let config = SseClientConfig { +/// sse_endpoint: "http://localhost:8000/sse".into(), +/// ..Default::default() +/// }; +/// let transport = SseClientTransport::start_with_client(MyHttpClient, config).await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// # Feature Flags +/// +/// - `transport-sse-client`: Base feature providing the generic transport infrastructure +/// - `transport-sse-client-reqwest`: Includes reqwest HTTP client support with convenience methods pub struct SseClientTransport { client: C, config: SseClientConfig, diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 42446a9c5..1e29dfae9 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -48,7 +48,6 @@ pub enum StreamableHttpError { Auth(#[from] crate::transport::auth::AuthError), } - pub enum StreamableHttpPostResponse { Accepted, Json(ServerJsonRpcMessage, Option), @@ -478,9 +477,170 @@ impl Worker for StreamableHttpClientWorker { } } +/// A client-agnostic HTTP transport for RMCP that supports streaming responses. +/// +/// This transport allows you to choose your preferred HTTP client implementation +/// by implementing the [`StreamableHttpClient`] trait. The transport handles +/// session management, SSE streaming, and automatic reconnection. +/// +/// # Usage +/// +/// ## Using reqwest +/// +/// ```rust,no_run +/// use rmcp::transport::StreamableHttpClientTransport; +/// +/// // Enable the reqwest feature in Cargo.toml: +/// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] } +/// +/// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); +/// ``` +/// +/// ## Using a custom HTTP client +/// +/// ```rust,no_run +/// use rmcp::transport::streamable_http_client::{ +/// StreamableHttpClient, +/// StreamableHttpClientTransport, +/// StreamableHttpClientTransportConfig +/// }; +/// use std::sync::Arc; +/// use futures::stream::BoxStream; +/// use rmcp::model::ClientJsonRpcMessage; +/// use sse_stream::{Sse, Error as SseError}; +/// +/// #[derive(Clone)] +/// struct MyHttpClient; +/// +/// #[derive(Debug, thiserror::Error)] +/// struct MyError; +/// +/// impl std::fmt::Display for MyError { +/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +/// write!(f, "MyError") +/// } +/// } +/// +/// impl StreamableHttpClient for MyHttpClient { +/// type Error = MyError; +/// +/// async fn post_message( +/// &self, +/// _uri: Arc, +/// _message: ClientJsonRpcMessage, +/// _session_id: Option>, +/// _auth_header: Option, +/// ) -> Result> { +/// todo!() +/// } +/// +/// async fn delete_session( +/// &self, +/// _uri: Arc, +/// _session_id: Arc, +/// _auth_header: Option, +/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { +/// todo!() +/// } +/// +/// async fn get_stream( +/// &self, +/// _uri: Arc, +/// _session_id: Arc, +/// _last_event_id: Option, +/// _auth_header: Option, +/// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { +/// todo!() +/// } +/// } +/// +/// let transport = StreamableHttpClientTransport::with_client( +/// MyHttpClient, +/// StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp") +/// ); +/// ``` +/// +/// # Feature Flags +/// +/// - `transport-streamable-http-client`: Base feature providing the generic transport infrastructure +/// - `transport-streamable-http-client-reqwest`: Includes reqwest HTTP client support with convenience methods pub type StreamableHttpClientTransport = WorkerTransport>; impl StreamableHttpClientTransport { + /// Creates a new transport with a custom HTTP client implementation. + /// + /// This method allows you to use any HTTP client that implements the [`StreamableHttpClient`] trait. + /// Use this when you want to use a custom HTTP client or when the reqwest feature is not enabled. + /// + /// # Arguments + /// + /// * `client` - Your HTTP client implementation + /// * `config` - Transport configuration including the server URI + /// + /// # Example + /// + /// ```rust,no_run + /// use rmcp::transport::streamable_http_client::{ + /// StreamableHttpClient, + /// StreamableHttpClientTransport, + /// StreamableHttpClientTransportConfig + /// }; + /// use std::sync::Arc; + /// use futures::stream::BoxStream; + /// use rmcp::model::ClientJsonRpcMessage; + /// use sse_stream::{Sse, Error as SseError}; + /// + /// // Define your custom client + /// #[derive(Clone)] + /// struct MyHttpClient; + /// + /// #[derive(Debug, thiserror::Error)] + /// struct MyError; + /// + /// impl std::fmt::Display for MyError { + /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + /// write!(f, "MyError") + /// } + /// } + /// + /// impl StreamableHttpClient for MyHttpClient { + /// type Error = MyError; + /// + /// async fn post_message( + /// &self, + /// _uri: Arc, + /// _message: ClientJsonRpcMessage, + /// _session_id: Option>, + /// _auth_header: Option, + /// ) -> Result> { + /// todo!() + /// } + /// + /// async fn delete_session( + /// &self, + /// _uri: Arc, + /// _session_id: Arc, + /// _auth_header: Option, + /// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { + /// todo!() + /// } + /// + /// async fn get_stream( + /// &self, + /// _uri: Arc, + /// _session_id: Arc, + /// _last_event_id: Option, + /// _auth_header: Option, + /// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { + /// todo!() + /// } + /// } + /// + /// let transport = StreamableHttpClientTransport::with_client( + /// MyHttpClient, + /// StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp") + /// ); + /// ``` pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self { let worker = StreamableHttpClientWorker::new(client, config); WorkerTransport::spawn(worker) diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 5dcb2dc6c..5e97c3be0 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -9,9 +9,9 @@ publish = false [dependencies] rmcp = { workspace = true, features = [ "client", - "transport-sse-client", + "transport-sse-client-reqwest", "reqwest", - "transport-streamable-http-client", + "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth" diff --git a/examples/rig-integration/Cargo.toml b/examples/rig-integration/Cargo.toml index 4f643d8a1..a472086d9 100644 --- a/examples/rig-integration/Cargo.toml +++ b/examples/rig-integration/Cargo.toml @@ -17,10 +17,9 @@ rig-core = "0.15.1" tokio = { version = "1", features = ["full"] } rmcp = { workspace = true, features = [ "client", - "reqwest", "transport-child-process", - "transport-sse-client", - "transport-streamable-http-client" + "transport-sse-client-reqwest", + "transport-streamable-http-client-reqwest" ] } anyhow = "1.0" serde_json = "1" diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml index c99353ee9..612b7a865 100644 --- a/examples/simple-chat-client/Cargo.toml +++ b/examples/simple-chat-client/Cargo.toml @@ -17,8 +17,7 @@ toml = "0.9" rmcp = { workspace = true, features = [ "client", "transport-child-process", - "transport-sse-client", - "transport-streamable-http-client", - "reqwest" + "transport-sse-client-reqwest", + "transport-streamable-http-client-reqwest" ], no-default-features = true } clap = { version = "4.0", features = ["derive"] } From eeb5ec2843030019ea3db1a546d623d6d3a7372c Mon Sep 17 00:00:00 2001 From: joulei Date: Tue, 19 Aug 2025 11:16:45 -0300 Subject: [PATCH 3/3] feat(rmcp): enhance transport features by decoupling reqwest - Added reqwest features for reqwest-based implementations - Updated documentation - Modified error handling in SSE transport to use `String` - Updated examples to include new features --- README.md | 1 + crates/rmcp/src/handler/server/router/tool.rs | 14 -- crates/rmcp/src/handler/server/tool.rs | 37 ---- crates/rmcp/src/model.rs | 38 +++- crates/rmcp/src/transport/child_process.rs | 179 +++++++++++++++--- .../common/reqwest/streamable_http_client.rs | 2 +- crates/rmcp/src/transport/sse_client.rs | 1 - .../src/transport/streamable_http_client.rs | 25 ++- .../streamable_http_server/session/local.rs | 58 ++++-- .../transport/streamable_http_server/tower.rs | 13 ++ crates/rmcp/src/transport/worker.rs | 26 +-- crates/rmcp/tests/test_structured_output.rs | 17 +- 12 files changed, 280 insertions(+), 131 deletions(-) diff --git a/README.md b/README.md index 64c9d6a10..692b7a6f6 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ See [oauth_support](docs/OAUTH_SUPPORT.md) for details. - [Schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.ts) ## Related Projects +- [rustfs-mcp](https://github.com/rustfs/rustfs/tree/main/crates/mcp) - High-performance MCP server providing S3-compatible object storage operations for AI/LLM integration - [containerd-mcp-server](https://github.com/jokemanfire/mcp-containerd) - A containerd-based MCP server implementation ## Development diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 2a100de96..4a5ba34f2 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -6,7 +6,6 @@ use schemars::JsonSchema; use crate::{ handler::server::tool::{ CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, - validate_against_schema, }, model::{CallToolResult, Tool, ToolAnnotations}, }; @@ -246,19 +245,6 @@ where let result = (item.call)(context).await?; - // Validate structured content against output schema if present - if let Some(ref output_schema) = item.attr.output_schema { - // When output_schema is defined, structured_content is required - if result.structured_content.is_none() { - return Err(crate::ErrorData::invalid_params( - "Tool with output_schema must return structured_content", - None, - )); - } - // Validate the structured content against the schema - validate_against_schema(result.structured_content.as_ref().unwrap(), output_schema)?; - } - Ok(result) } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 8d5c82133..bdb336983 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -67,43 +67,6 @@ pub fn schema_for_type() -> JsonObject { } } -/// Validate that a JSON value conforms to basic type constraints from a schema. -/// -/// Note: This is a basic validation that only checks type compatibility. -/// For full JSON Schema validation, a dedicated validation library would be needed. -pub fn validate_against_schema( - value: &serde_json::Value, - schema: &JsonObject, -) -> Result<(), crate::ErrorData> { - // Basic type validation - if let Some(schema_type) = schema.get("type").and_then(|t| t.as_str()) { - let value_type = get_json_value_type(value); - - if schema_type != value_type { - return Err(crate::ErrorData::invalid_params( - format!( - "Value type does not match schema. Expected '{}', got '{}'", - schema_type, value_type - ), - None, - )); - } - } - - Ok(()) -} - -fn get_json_value_type(value: &serde_json::Value) -> &'static str { - match value { - serde_json::Value::Null => "null", - serde_json::Value::Bool(_) => "boolean", - serde_json::Value::Number(_) => "number", - serde_json::Value::String(_) => "string", - serde_json::Value::Array(_) => "array", - serde_json::Value::Object(_) => "object", - } -} - /// Call [`schema_for_type`] with a cache pub fn cached_schema_for_type() -> Arc { thread_local! { diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index f8a449337..c51eaa1e0 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -15,7 +15,7 @@ pub use extension::*; pub use meta::*; pub use prompt::*; pub use resource::*; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Value; pub use tool::*; @@ -1260,12 +1260,32 @@ impl CallToolResult { } } - /// Validate that content or structured content is provided - pub fn validate(&self) -> Result<(), &'static str> { - match (&self.content, &self.structured_content) { - (None, None) => Err("either content or structured_content must be provided"), - _ => Ok(()), + /// Convert the `structured_content` part of response into a certain type. + /// + /// # About json schema validation + /// Since rust is a strong type language, we don't need to do json schema validation here. + /// + /// But if you do have to validate the response data, you can use [`jsonschema`](https://crates.io/crates/jsonschema) crate. + pub fn into_typed(self) -> Result + where + T: DeserializeOwned, + { + let raw_text = match (self.structured_content, &self.content) { + (Some(value), _) => return serde_json::from_value(value), + (None, Some(contents)) => { + if let Some(text) = contents.first().and_then(|c| c.as_text()) { + let text = &text.text; + Some(text) + } else { + None + } + } + (None, None) => None, + }; + if let Some(text) = raw_text { + return serde_json::from_str(text); } + serde_json::from_value(serde_json::Value::Null) } } @@ -1294,7 +1314,11 @@ impl<'de> Deserialize<'de> for CallToolResult { }; // Validate mutual exclusivity - result.validate().map_err(serde::de::Error::custom)?; + if result.content.is_none() && result.structured_content.is_none() { + return Err(serde::de::Error::custom( + "CallToolResult must have either content or structured_content", + )); + } Ok(result) } diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index 2e7c034ff..e384ad555 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -1,14 +1,16 @@ use std::process::Stdio; +use futures::future::Future; use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap}; use tokio::{ io::AsyncRead, process::{ChildStderr, ChildStdin, ChildStdout}, }; -use super::{IntoTransport, Transport}; -use crate::service::ServiceRole; +use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; +use crate::RoleClient; +const MAX_WAIT_ON_DROP_SECS: u64 = 3; /// The parts of a child process. type ChildProcessParts = ( Box, @@ -36,18 +38,23 @@ fn child_process(mut child: Box) -> std::io::Result, } pub struct ChildWithCleanup { - inner: Box, + inner: Option>, } impl Drop for ChildWithCleanup { fn drop(&mut self) { - if let Err(e) = self.inner.start_kill() { - tracing::warn!("Failed to kill child process: {e}"); + // We should not use start_kill(), instead we should use kill() to avoid zombies + if let Some(mut inner) = self.inner.take() { + // We don't care about the result, just try to kill it + tokio::spawn(async move { + if let Err(e) = Box::into_pin(inner.kill()).await { + tracing::warn!("Error killing child process: {}", e); + } + }); } } } @@ -64,7 +71,7 @@ pin_project_lite::pin_project! { impl TokioChildProcessOut { /// Get the process ID of the child process. pub fn id(&self) -> Option { - self.child.inner.id() + self.child.inner.as_ref()?.id() } } @@ -92,23 +99,51 @@ impl TokioChildProcess { /// Get the process ID of the child process. pub fn id(&self) -> Option { - self.child.inner.id() + self.child.inner.as_ref()?.id() + } + + /// Gracefully shutdown the child process + /// + /// This will first wait for the child process to exit normally with a timeout. + /// If the child process doesn't exit within the timeout, it will be killed. + pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { + if let Some(mut child) = self.child.inner.take() { + let wait_fut = Box::into_pin(child.wait()); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { + if let Err(e) = Box::into_pin(child.kill()).await { + tracing::warn!("Error killing child: {e}"); + return Err(e); + } + }, + res = wait_fut => { + match res { + Ok(status) => { + tracing::info!("Child exited gracefully {}", status); + } + Err(e) => { + tracing::warn!("Error waiting for child: {e}"); + return Err(e); + } + } + } + } + } + Ok(()) + } + + /// Take ownership of the inner child process + pub fn into_inner(mut self) -> Option> { + self.child.inner.take() } /// Split this helper into a reader (stdout) and writer (stdin). + #[deprecated( + since = "0.5.0", + note = "use the Transport trait implementation instead" + )] pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { - let TokioChildProcess { - child, - child_stdin, - child_stdout, - } = self; - ( - TokioChildProcessOut { - child, - child_stdout, - }, - child_stdin, - ) + unimplemented!("This method is deprecated, use the Transport trait implementation instead"); } } @@ -156,20 +191,31 @@ impl TokioChildProcessBuilder { let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; + let transport = AsyncRwTransport::new(stdout, stdin); let proc = TokioChildProcess { - child: ChildWithCleanup { inner: child }, - child_stdin: stdin, - child_stdout: stdout, + child: ChildWithCleanup { inner: Some(child) }, + transport, }; Ok((proc, stderr_opt)) } } -impl IntoTransport for TokioChildProcess { - fn into_transport(self) -> impl Transport + 'static { - IntoTransport::::into_transport( - self.split(), - ) +impl Transport for TokioChildProcess { + type Error = std::io::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.transport.send(item) + } + + fn receive(&mut self) -> impl Future>> + Send { + self.transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.graceful_shutdown() } } @@ -183,3 +229,78 @@ impl ConfigureCommandExt for tokio::process::Command { self } } + +#[cfg(unix)] +#[cfg(test)] +mod tests { + use tokio::process::Command; + + use super::*; + + #[tokio::test] + async fn test_tokio_child_process_drop() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + // Drop the child process + drop(child_process); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } + + #[tokio::test] + async fn test_tokio_child_process_graceful_shutdown() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let mut child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + child_process.graceful_shutdown().await.unwrap(); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } +} diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index 4f69a6a6c..a18df6e07 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -42,7 +42,7 @@ impl StreamableHttpClient for reqwest::Client { } let response = request_builder.send().await?; if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { - return Err(StreamableHttpError::SeverDoesNotSupportSse); + return Err(StreamableHttpError::ServerDoesNotSupportSse); } let response = response.error_for_status()?; match response.headers().get(reqwest::header::CONTENT_TYPE) { diff --git a/crates/rmcp/src/transport/sse_client.rs b/crates/rmcp/src/transport/sse_client.rs index 910df5f58..7a0705c2f 100644 --- a/crates/rmcp/src/transport/sse_client.rs +++ b/crates/rmcp/src/transport/sse_client.rs @@ -3,7 +3,6 @@ use std::{pin::Pin, sync::Arc}; use futures::{StreamExt, future::BoxFuture}; use http::Uri; - use sse_stream::Error as SseError; use thiserror::Error; diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 1e29dfae9..29f1dc998 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -33,21 +33,28 @@ pub enum StreamableHttpError { #[error("Unexpected content type: {0:?}")] UnexpectedContentType(Option), #[error("Server does not support SSE")] - SeverDoesNotSupportSse, + ServerDoesNotSupportSse, #[error("Server does not support delete session")] - SeverDoesNotSupportDeleteSession, + ServerDoesNotSupportDeleteSession, #[error("Tokio join error: {0}")] TokioJoinError(#[from] tokio::task::JoinError), #[error("Deserialize error: {0}")] Deserialize(#[from] serde_json::Error), #[error("Transport channel closed")] TransportChannelClosed, + #[error("Missing session id in HTTP response")] + MissingSessionIdInResponse, #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] #[error("Auth error: {0}")] Auth(#[from] crate::transport::auth::AuthError), } +#[derive(Debug, Clone, Error)] +pub enum StreamableHttpProtocolError { + #[error("Missing session id in response")] + MissingSessionIdInResponse, +} pub enum StreamableHttpPostResponse { Accepted, Json(ServerJsonRpcMessage, Option), @@ -255,7 +262,7 @@ impl Worker for StreamableHttpClientWorker { async fn run( self, mut context: super::worker::WorkerContext, - ) -> Result<(), WorkerQuitReason> { + ) -> Result<(), WorkerQuitReason> { let channel_buffer_capacity = self.config.channel_buffer_capacity; let (sse_worker_tx, mut sse_worker_rx) = tokio::sync::mpsc::channel::(channel_buffer_capacity); @@ -272,7 +279,7 @@ impl Worker for StreamableHttpClientWorker { .post_message(config.uri.clone(), initialize_request, None, None) .await .map_err(WorkerQuitReason::fatal_context("send initialize request"))? - .expect_initialized::() + .expect_initialized::() .await .map_err(WorkerQuitReason::fatal_context( "process initialize response", @@ -282,7 +289,7 @@ impl Worker for StreamableHttpClientWorker { } else { if !self.config.allow_stateless { return Err(WorkerQuitReason::fatal( - "missing session id in initialize response", + StreamableHttpError::::MissingSessionIdInResponse, "process initialize response", )); } @@ -302,7 +309,7 @@ impl Worker for StreamableHttpClientWorker { Ok(_) => { tracing::info!(session_id = session_id.as_ref(), "delete session success") } - Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => { + Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => { tracing::info!( session_id = session_id.as_ref(), "server doesn't support delete session" @@ -332,7 +339,7 @@ impl Worker for StreamableHttpClientWorker { .map_err(WorkerQuitReason::fatal_context( "send initialized notification", ))? - .expect_accepted::() + .expect_accepted::() .map_err(WorkerQuitReason::fatal_context( "process initialized notification response", ))?; @@ -367,14 +374,14 @@ impl Worker for StreamableHttpClientWorker { )); tracing::debug!("got common stream"); } - Err(StreamableHttpError::SeverDoesNotSupportSse) => { + Err(StreamableHttpError::ServerDoesNotSupportSse) => { tracing::debug!("server doesn't support sse, skip common stream"); } Err(e) => { // fail to get common stream tracing::error!("fail to get common stream: {e}"); return Err(WorkerQuitReason::fatal( - "fail to get general purpose event stream", + e, "get general purpose event stream", )); } diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index c1c4f8935..5458c4046 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -296,12 +296,8 @@ pub enum SessionError { SessionServiceTerminated, #[error("Invalid event id")] InvalidEventId, - #[error("Transport closed")] - TransportClosed, #[error("IO error: {0}")] Io(#[from] std::io::Error), - #[error("Tokio join error {0}")] - TokioJoinError(#[from] tokio::task::JoinError), } impl From for std::io::Error { @@ -317,7 +313,7 @@ enum OutboundChannel { RequestWise { id: HttpRequestId, close: bool }, Common, } - +#[derive(Debug)] pub struct StreamableHttpMessageReceiver { pub http_request_id: Option, pub inner: Receiver, @@ -534,8 +530,8 @@ impl LocalSessionWorker { } } } - -enum SessionEvent { +#[derive(Debug)] +pub enum SessionEvent { ClientMessage { message: ClientJsonRpcMessage, http_request_id: Option, @@ -695,14 +691,31 @@ impl LocalSessionHandle { pub type SessionTransport = WorkerTransport; +#[derive(Debug, Error)] +pub enum LocalSessionWorkerError { + #[error("transport terminated")] + TransportTerminated, + #[error("unexpected message: {0:?}")] + UnexpectedEvent(SessionEvent), + #[error("fail to send initialize request {0}")] + FailToSendInitializeRequest(SessionError), + #[error("fail to handle message: {0}")] + FailToHandleMessage(SessionError), + #[error("keep alive timeout after {}ms", _0.as_millis())] + KeepAliveTimeout(Duration), + #[error("Transport closed")] + TransportClosed, + #[error("Tokio join error {0}")] + TokioJoinError(#[from] tokio::task::JoinError), +} impl Worker for LocalSessionWorker { - type Error = SessionError; + type Error = LocalSessionWorkerError; type Role = RoleServer; fn err_closed() -> Self::Error { - SessionError::TransportClosed + LocalSessionWorkerError::TransportClosed } fn err_join(e: tokio::task::JoinError) -> Self::Error { - SessionError::TokioJoinError(e) + LocalSessionWorkerError::TokioJoinError(e) } fn config(&self) -> crate::transport::worker::WorkerConfig { crate::transport::worker::WorkerConfig { @@ -711,18 +724,24 @@ impl Worker for LocalSessionWorker { } } #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))] - async fn run(mut self, mut context: WorkerContext) -> Result<(), WorkerQuitReason> { + async fn run( + mut self, + mut context: WorkerContext, + ) -> Result<(), WorkerQuitReason> { enum InnerEvent { FromHttpService(SessionEvent), FromHandler(WorkerSendRequest), } // waiting for initialize request let evt = self.event_rx.recv().await.ok_or_else(|| { - WorkerQuitReason::fatal("transport terminated", "get initialize request") + WorkerQuitReason::fatal( + LocalSessionWorkerError::TransportTerminated, + "get initialize request", + ) })?; let SessionEvent::InitializeRequest { request, responder } = evt else { return Err(WorkerQuitReason::fatal( - "unexpected message", + LocalSessionWorkerError::UnexpectedEvent(evt), "get initialize request", )); }; @@ -732,7 +751,9 @@ impl Worker for LocalSessionWorker { .send(Ok(send_initialize_response.message)) .map_err(|_| { WorkerQuitReason::fatal( - "failed to send initialize response to http service", + LocalSessionWorkerError::FailToSendInitializeRequest( + SessionError::SessionServiceTerminated, + ), "send initialize response", ) })?; @@ -749,7 +770,7 @@ impl Worker for LocalSessionWorker { if let Some(event) = event { InnerEvent::FromHttpService(event) } else { - return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event")) + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event")) } }, from_handler = context.recv_from_handler() => { @@ -759,7 +780,7 @@ impl Worker for LocalSessionWorker { return Err(WorkerQuitReason::Cancelled) } _ = keep_alive_timeout => { - return Err(WorkerQuitReason::fatal("keep live timeout", "poll next session event")) + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event")) } }; match event { @@ -779,7 +800,10 @@ impl Worker for LocalSessionWorker { // no need to unregister resource } }; - let handle_result = self.handle_server_message(message).await; + let handle_result = self + .handle_server_message(message) + .await + .map_err(LocalSessionWorkerError::FailToHandleMessage); let _ = responder.send(handle_result).inspect_err(|error| { tracing::warn!(?error, "failed to send message to http service handler"); }); diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 90473b326..475dbff3c 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -44,6 +44,19 @@ impl Default for StreamableHttpServerConfig { } } +/// # Streamable Http Server +/// +/// ## Extract information from raw http request +/// +/// The http service will consume the request body, however the rest part will be remain and injected into [`crate::model::Extensions`], +/// which you can get from [`crate::service::RequestContext`]. +/// ```rust +/// use rmcp::handler::server::tool::Extension; +/// use http::request::Parts; +/// async fn my_tool(Extension(parts): Extension) { +/// tracing::info!("http parts:{parts:?}") +/// } +/// ``` pub struct StreamableHttpService { pub config: StreamableHttpServerConfig, session_manager: Arc, diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index 5ae9098ea..eaabc506e 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -7,12 +7,12 @@ use super::{IntoTransport, Transport}; use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; #[derive(Debug, thiserror::Error)] -pub enum WorkerQuitReason { +pub enum WorkerQuitReason { #[error("Join error {0}")] Join(#[from] tokio::task::JoinError), #[error("Transport fatal {error}, when {context}")] Fatal { - error: Cow<'static, str>, + error: E, context: Cow<'static, str>, }, #[error("Transport canncelled")] @@ -23,18 +23,16 @@ pub enum WorkerQuitReason { HandlerTerminated, } -impl WorkerQuitReason { - pub fn fatal(msg: impl Into>, context: impl Into>) -> Self { +impl WorkerQuitReason { + pub fn fatal(error: E, context: impl Into>) -> Self { Self::Fatal { - error: msg.into(), + error, context: context.into(), } } - pub fn fatal_context( - context: impl Into>, - ) -> impl FnOnce(E) -> Self { + pub fn fatal_context(context: impl Into>) -> impl FnOnce(E) -> Self { |e| Self::Fatal { - error: Cow::Owned(format!("{e}")), + error: e, context: context.into(), } } @@ -48,7 +46,7 @@ pub trait Worker: Sized + Send + 'static { fn run( self, context: WorkerContext, - ) -> impl Future> + Send; + ) -> impl Future>> + Send; fn config(&self) -> WorkerConfig { WorkerConfig::default() } @@ -62,7 +60,7 @@ pub struct WorkerSendRequest { pub struct WorkerTransport { rx: tokio::sync::mpsc::Receiver>, send_service: tokio::sync::mpsc::Sender>, - join_handle: Option>>, + join_handle: Option>>>, _drop_guard: tokio_util::sync::DropGuard, ct: CancellationToken, } @@ -159,14 +157,16 @@ impl WorkerContext { pub async fn send_to_handler( &mut self, item: RxJsonRpcMessage, - ) -> Result<(), WorkerQuitReason> { + ) -> Result<(), WorkerQuitReason> { self.to_handler_tx .send(item) .await .map_err(|_| WorkerQuitReason::HandlerTerminated) } - pub async fn recv_from_handler(&mut self) -> Result, WorkerQuitReason> { + pub async fn recv_from_handler( + &mut self, + ) -> Result, WorkerQuitReason> { self.from_handler_rx .recv() .await diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index f8e8beb0a..171213e0b 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -170,13 +170,24 @@ async fn test_structured_error_in_call_result() { #[tokio::test] async fn test_mutual_exclusivity_validation() { + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] + pub struct Response { + message: String, + } + let response = Response { + message: "Hello".into(), + }; // Test that content and structured_content can both be passed separately - let content_result = CallToolResult::success(vec![Content::text("Hello")]); + let content_result = CallToolResult::success(vec![Content::json(response.clone()).unwrap()]); let structured_result = CallToolResult::structured(json!({"message": "Hello"})); // Verify the validation - assert!(content_result.validate().is_ok()); - assert!(structured_result.validate().is_ok()); + content_result + .into_typed::() + .expect("Failed to extract content"); + structured_result + .into_typed::() + .expect("Failed to extract content"); // Try to create a result with both fields let json_with_both = json!({