diff --git a/Cargo.lock b/Cargo.lock index ad6b9772759c..34534d5e29a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1826,6 +1826,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "combine" version = "4.6.7" @@ -2813,7 +2822,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4449,7 +4458,7 @@ checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" dependencies = [ "hermit-abi 0.4.0", "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5188,7 +5197,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -5384,6 +5393,7 @@ dependencies = [ "eventsource-client", "futures", "mcp-core", + "mockito", "nanoid", "nix 0.30.1", "rand 0.8.5", @@ -5574,6 +5584,30 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "mockito" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "log", + "rand 0.9.1", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "moka" version = "0.11.3" @@ -6682,7 +6716,7 @@ dependencies = [ "once_cell", "socket2 0.5.8", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -7231,7 +7265,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -7244,7 +7278,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -7769,6 +7803,12 @@ dependencies = [ "quote", ] +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "simple_asn1" version = "0.6.3" @@ -8291,7 +8331,7 @@ dependencies = [ "getrandom 0.3.1", "once_cell", "rustix 0.38.44", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -9481,7 +9521,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 92425e8d5216..cc05182a2285 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -36,3 +36,4 @@ webbrowser = "1.0" serde_urlencoded = "0.7" [dev-dependencies] +mockito = "1.5" diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index 760d861ffba7..77c534fc0071 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -84,8 +84,8 @@ impl StreamableHttpActor { debug!("Sending message to MCP endpoint: {}", message_str); // Parse the message to determine if it's a request that expects a response - let parsed_message = serde_json::from_str::(&message_str) - .map_err(Error::Serialization)?; + let parsed_message = + serde_json::from_str::(&message_str).map_err(Error::Serialization)?; let expects_response = matches!( parsed_message, @@ -511,3 +511,491 @@ impl Transport for StreamableHttpTransport { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use mockito::Server; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::sync::mpsc; + use tokio::sync::RwLock; + + #[test] + fn test_message_parsing_request() { + // Test that we can parse a JSON-RPC request message using the mcp-core types + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "capabilities": {} + } + }); + + let message_str = serde_json::to_string(&request_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str); + assert!( + parsed_message.is_ok(), + "Should be able to parse JSON-RPC request message" + ); + } + + #[test] + fn test_message_parsing_response() { + // Test that we can parse a JSON-RPC response message + let response_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": { + "capabilities": {} + } + }); + + let message_str = serde_json::to_string(&response_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str); + assert!( + parsed_message.is_ok(), + "Should be able to parse JSON-RPC response message" + ); + } + + #[test] + fn test_message_parsing_notification() { + // Test that we can parse a JSON-RPC notification message + let notification_json = json!({ + "jsonrpc": "2.0", + "method": "initialized", + "params": {} + }); + + let message_str = serde_json::to_string(¬ification_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str); + assert!( + parsed_message.is_ok(), + "Should be able to parse JSON-RPC notification message" + ); + } + + #[test] + fn test_message_parsing_error() { + // Test that we can parse a JSON-RPC error message + let error_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32600, + "message": "Invalid Request" + } + }); + + let message_str = serde_json::to_string(&error_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str); + assert!( + parsed_message.is_ok(), + "Should be able to parse JSON-RPC error message" + ); + } + + #[test] + fn test_message_parsing_invalid_json() { + let invalid_json = "{ invalid json }"; + let parsed_message = serde_json::from_str::(invalid_json); + assert!(parsed_message.is_err(), "Invalid JSON should fail to parse"); + } + + #[test] + fn test_transport_message_recv_parsing() { + // Test that we can parse messages as TransportMessageRecv (the type used for incoming messages) + let response_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": { + "capabilities": {} + } + }); + + let message_str = serde_json::to_string(&response_json).unwrap(); + + // For incoming messages + let parsed_message = serde_json::from_str::(&message_str); + assert!( + parsed_message.is_ok(), + "Should be able to parse response as TransportMessageRecv" + ); + } + + #[test] + fn test_untagged_enum_serialization_issue() { + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + }); + + let message_str = serde_json::to_string(&request_json).unwrap(); + + let parsed_as_jsonrpc = serde_json::from_str::(&message_str); + assert!( + parsed_as_jsonrpc.is_ok(), + "Should be able to parse request as JsonRpcMessage" + ); + } + + #[test] + fn test_expects_response_logic_with_number_id() { + // Check if a message expects a response + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} + }); + + let message_str = serde_json::to_string(&request_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str).unwrap(); + + // This should match the logic in handle_outgoing_message after the fix + // The original code used: JsonRpcMessage::Request(JsonRpcRequest { id: Number(_), .. }) + let expects_response = match parsed_message { + JsonRpcMessage::Request(_) => true, + _ => false, + }; + + assert!(expects_response, "Request with ID should expect a response"); + + // Test notification (should not expect response) + let notification_json = json!({ + "jsonrpc": "2.0", + "method": "initialized", + "params": {} + }); + + let message_str = serde_json::to_string(¬ification_json).unwrap(); + let parsed_message = serde_json::from_str::(&message_str).unwrap(); + + let expects_response = match parsed_message { + JsonRpcMessage::Request(_) => true, + _ => false, + }; + + assert!( + !expects_response, + "Notification should not expect a response" + ); + } + + #[tokio::test] + async fn test_handle_outgoing_message_successful_request() { + // Set up a mock HTTP server + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"jsonrpc":"2.0","id":1,"result":{"capabilities":{}}}"#) + .create_async() + .await; + + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, mut orx) = mpsc::channel(32); + + // Create the actor + let session_id = Arc::new(RwLock::new(None)); + let mut actor = StreamableHttpActor::new( + rx, + otx, + server.url(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Create a JSON-RPC request message + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "capabilities": {} + } + }); + let message_str = serde_json::to_string(&request_json).unwrap(); + + // Test handle_outgoing_message + let result = actor.handle_outgoing_message(message_str).await; + assert!(result.is_ok(), "handle_outgoing_message should succeed"); + + // Verify the mock was called + mock.assert_async().await; + + // Check that a response was received + let response = + tokio::time::timeout(std::time::Duration::from_millis(100), orx.recv()).await; + assert!(response.is_ok(), "Should receive a response"); + assert!(response.unwrap().is_some(), "Response should not be None"); + } + + #[tokio::test] + async fn test_handle_outgoing_message_notification() { + // Set up a mock HTTP server for notifications (202 Accepted, no body) + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/") + .with_status(202) + .create_async() + .await; + + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, mut orx) = mpsc::channel(32); + + // Create the actor + let session_id = Arc::new(RwLock::new(None)); + let mut actor = StreamableHttpActor::new( + rx, + otx, + server.url(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Create a JSON-RPC notification message (no id) + let notification_json = json!({ + "jsonrpc": "2.0", + "method": "initialized", + "params": {} + }); + let message_str = serde_json::to_string(¬ification_json).unwrap(); + + // Test handle_outgoing_message + let result = actor.handle_outgoing_message(message_str).await; + assert!( + result.is_ok(), + "handle_outgoing_message should succeed for notification" + ); + + // Verify the mock was called + mock.assert_async().await; + + // For notifications, we shouldn't receive a response + let response = + tokio::time::timeout(std::time::Duration::from_millis(100), orx.recv()).await; + assert!( + response.is_err(), + "Should not receive a response for notification" + ); + } + + #[tokio::test] + async fn test_handle_outgoing_message_http_error() { + // Set up a mock HTTP server that returns an error + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/") + .with_status(500) + .with_body("Internal Server Error") + .create_async() + .await; + + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, _orx) = mpsc::channel(32); + + // Create the actor + let session_id = Arc::new(RwLock::new(None)); + let mut actor = StreamableHttpActor::new( + rx, + otx, + server.url(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Create a JSON-RPC request message + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "test", + "params": {} + }); + let message_str = serde_json::to_string(&request_json).unwrap(); + + // Test handle_outgoing_message + let result = actor.handle_outgoing_message(message_str).await; + assert!( + result.is_err(), + "handle_outgoing_message should fail with HTTP error" + ); + + // Verify it's an HTTP error + match result.unwrap_err() { + Error::HttpError { status, .. } => { + assert_eq!(status, 500, "Should return HTTP 500 error"); + } + _ => panic!("Expected HttpError"), + } + + // Verify the mock was called + mock.assert_async().await; + } + + #[tokio::test] + async fn test_handle_outgoing_message_session_id_handling() { + // Set up a mock HTTP server that returns a session ID + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/") + .with_status(200) + .with_header("content-type", "application/json") + .with_header("Mcp-Session-Id", "test-session-123") + .with_body(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) + .create_async() + .await; + + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, _orx) = mpsc::channel(32); + + // Create the actor + let session_id = Arc::new(RwLock::new(None)); + let session_id_clone = Arc::clone(&session_id); + let mut actor = StreamableHttpActor::new( + rx, + otx, + server.url(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Create a JSON-RPC request message + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} + }); + let message_str = serde_json::to_string(&request_json).unwrap(); + + // Test handle_outgoing_message + let result = actor.handle_outgoing_message(message_str).await; + assert!(result.is_ok(), "handle_outgoing_message should succeed"); + + // Verify the session ID was stored + let stored_session_id = session_id_clone.read().await; + assert_eq!( + stored_session_id.as_ref(), + Some(&"test-session-123".to_string()), + "Session ID should be stored" + ); + + // Verify the mock was called + mock.assert_async().await; + } + + #[tokio::test] + async fn test_handle_outgoing_message_invalid_json() { + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, _orx) = mpsc::channel(32); + + // Create the actor + let session_id = Arc::new(RwLock::new(None)); + let mut actor = StreamableHttpActor::new( + rx, + otx, + "http://localhost:8080".to_string(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Test with invalid JSON + let invalid_json = "{ invalid json }"; + + // Test handle_outgoing_message + let result = actor + .handle_outgoing_message(invalid_json.to_string()) + .await; + assert!( + result.is_err(), + "handle_outgoing_message should fail with invalid JSON" + ); + + // Verify it's a serialization error + match result.unwrap_err() { + Error::Serialization(_) => { + // Expected error type + } + _ => panic!("Expected Serialization error"), + } + } + + #[tokio::test] + async fn test_handle_outgoing_message_session_not_found() { + // Set up a mock HTTP server that returns 404 (session not found) + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/") + .with_status(404) + .with_body("Session not found") + .create_async() + .await; + + // Create channels for the actor + let (_tx, rx) = mpsc::channel(32); + let (otx, _orx) = mpsc::channel(32); + + // Create the actor with an existing session ID + let session_id = Arc::new(RwLock::new(Some("old-session".to_string()))); + let session_id_clone = Arc::clone(&session_id); + let mut actor = StreamableHttpActor::new( + rx, + otx, + server.url(), + session_id, + HashMap::new(), + HashMap::new(), + ); + + // Create a JSON-RPC request message + let request_json = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "test", + "params": {} + }); + let message_str = serde_json::to_string(&request_json).unwrap(); + + // Test handle_outgoing_message + let result = actor.handle_outgoing_message(message_str).await; + assert!( + result.is_err(), + "handle_outgoing_message should fail with 404" + ); + + // Verify it's a session error and the session ID was cleared + match result.unwrap_err() { + Error::SessionError(_) => { + // Expected error type + } + _ => panic!("Expected SessionError"), + } + + // Verify the session ID was cleared + let stored_session_id = session_id_clone.read().await; + assert!( + stored_session_id.is_none(), + "Session ID should be cleared on 404" + ); + + // Verify the mock was called + mock.assert_async().await; + } +}