From 11369fbdbf069b504319fe6dc40ac5b4a99a1448 Mon Sep 17 00:00:00 2001 From: btdeviant Date: Sun, 15 Jun 2025 19:28:39 -0700 Subject: [PATCH 1/3] feat: Adding streamable-http transport MCP support. Retaining sse for backwards compatibility --- crates/goose/src/agents/extension.rs | 11 +-- crates/goose/src/agents/extension_manager.rs | 3 +- crates/mcp-client/examples/streamable_http.rs | 14 ++-- crates/mcp-client/src/lib.rs | 4 +- .../src/transport/streamable_http.rs | 83 +++++++------------ 5 files changed, 40 insertions(+), 75 deletions(-) diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index 545c4c222743..98d07691c2cf 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -227,12 +227,7 @@ impl ExtensionConfig { } } - pub fn streamable_http, T: Into>( - name: S, - uri: S, - description: S, - timeout: T, - ) -> Self { + pub fn streamable_http, T: Into>(name: S, uri: S, description: S, timeout: T) -> Self { Self::StreamableHttp { name: name.into(), uri: uri.into(), @@ -314,9 +309,7 @@ impl std::fmt::Display for ExtensionConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri), - ExtensionConfig::StreamableHttp { name, uri, .. } => { - write!(f, "StreamableHttp({}: {})", name, uri) - } + ExtensionConfig::StreamableHttp { name, uri, .. } => write!(f, "StreamableHttp({}: {})", name, uri), ExtensionConfig::Stdio { name, cmd, args, .. } => { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index a418fe112a9e..bb947e68f0d4 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -204,8 +204,7 @@ impl ExtensionManager { .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = - StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); + let transport = StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); let handle = transport.start().await?; Box::new( McpClient::connect( diff --git a/crates/mcp-client/examples/streamable_http.rs b/crates/mcp-client/examples/streamable_http.rs index 0fd856ba661b..6aea20f74f95 100644 --- a/crates/mcp-client/examples/streamable_http.rs +++ b/crates/mcp-client/examples/streamable_http.rs @@ -19,14 +19,14 @@ async fn main() -> Result<()> { // Create example headers let mut headers = HashMap::new(); headers.insert("X-Custom-Header".to_string(), "example-value".to_string()); - headers.insert( - "User-Agent".to_string(), - "MCP-StreamableHttp-Client/1.0".to_string(), - ); + headers.insert("User-Agent".to_string(), "MCP-StreamableHttp-Client/1.0".to_string()); // Create the Streamable HTTP transport with headers - let transport = - StreamableHttpTransport::with_headers("http://localhost:8000/mcp", HashMap::new(), headers); + let transport = StreamableHttpTransport::with_headers( + "http://localhost:8000/mcp", + HashMap::new(), + headers + ); // Start transport let handle = transport.start().await?; @@ -90,4 +90,4 @@ async fn main() -> Result<()> { println!("Streamable HTTP transport example completed successfully!"); Ok(()) -} +} \ No newline at end of file diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index f6ed51dc467b..a3de24d2465a 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -4,6 +4,4 @@ pub mod transport; pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; pub use service::McpService; -pub use transport::{ - SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, -}; +pub use transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle}; diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index cc3f4fc5d172..9310653f69a0 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::time::Duration; -use tracing::{debug, error, warn}; +use tracing::{debug, warn, error}; use url::Url; use super::{serialize_and_send, Transport, TransportHandle}; @@ -83,17 +83,13 @@ impl StreamableHttpActor { debug!("Sending message to MCP endpoint: {}", message_str); // Parse the message to determine if it's a request that expects a response - let parsed_message: JsonRpcMessage = - serde_json::from_str(&message_str).map_err(Error::Serialization)?; + let parsed_message: JsonRpcMessage = serde_json::from_str(&message_str) + .map_err(|e| Error::Serialization(e))?; - let expects_response = matches!( - parsed_message, - JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) - ); + let expects_response = matches!(parsed_message, JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })); // Build the HTTP request - let mut request = self - .http_client + let mut request = self.http_client .post(&self.mcp_endpoint) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") @@ -110,9 +106,7 @@ impl StreamableHttpActor { } // Send the request - let response = request - .send() - .await + let response = request.send().await .map_err(|e| Error::StreamableHttpError(format!("HTTP request failed: {}", e)))?; // Handle HTTP error status codes @@ -121,13 +115,9 @@ impl StreamableHttpActor { if status.as_u16() == 404 { // Session not found - clear our session ID *self.session_id.write().await = None; - return Err(Error::SessionError( - "Session expired or not found".to_string(), - )); + return Err(Error::SessionError("Session expired or not found".to_string())); } - let error_text = response - .text() - .await + let error_text = response.text().await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(Error::HttpError { status: status.as_u16(), @@ -144,8 +134,7 @@ impl StreamableHttpActor { } // Handle the response based on content type - let content_type = response - .headers() + let content_type = response.headers() .get("content-type") .and_then(|h| h.to_str().ok()) .unwrap_or(""); @@ -157,14 +146,13 @@ impl StreamableHttpActor { } } else if content_type.starts_with("application/json") || expects_response { // Handle single JSON response - let response_text = response.text().await.map_err(|e| { - Error::StreamableHttpError(format!("Failed to read response: {}", e)) - })?; - + let response_text = response.text().await + .map_err(|e| Error::StreamableHttpError(format!("Failed to read response: {}", e)))?; + if !response_text.is_empty() { - let json_message: JsonRpcMessage = - serde_json::from_str(&response_text).map_err(Error::Serialization)?; - + let json_message: JsonRpcMessage = serde_json::from_str(&response_text) + .map_err(|e| Error::Serialization(e))?; + let _ = self.sender.send(json_message).await; } } @@ -174,23 +162,20 @@ impl StreamableHttpActor { } /// Handle streaming HTTP response that uses Server-Sent Events format - /// + /// /// This is called when the server responds to an HTTP POST with `text/event-stream` /// content-type, indicating it wants to stream multiple JSON-RPC messages back /// rather than sending a single response. This is part of the Streamable HTTP /// specification, not a separate SSE transport. - async fn handle_streaming_response( - &mut self, - response: reqwest::Response, - ) -> Result<(), Error> { + async fn handle_streaming_response(&mut self, response: reqwest::Response) -> Result<(), Error> { use futures::StreamExt; - use tokio::io::AsyncBufReadExt; use tokio_util::io::StreamReader; + use tokio::io::AsyncBufReadExt; // Convert the response body to a stream reader - let stream = response - .bytes_stream() - .map(|result| result.map_err(std::io::Error::other)); + let stream = response.bytes_stream().map(|result| { + result.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + }); let reader = StreamReader::new(stream); let mut lines = tokio::io::BufReader::new(reader).lines(); @@ -260,8 +245,7 @@ impl StreamableHttpTransportHandle { /// Manually terminate the session by sending HTTP DELETE pub async fn terminate_session(&self) -> Result<(), Error> { if let Some(session_id) = self.session_id.read().await.as_ref() { - let mut request = self - .http_client + let mut request = self.http_client .delete(&self.mcp_endpoint) .header("Mcp-Session-Id", session_id); @@ -287,8 +271,7 @@ impl StreamableHttpTransportHandle { /// Create a GET request to establish a streaming connection for server-initiated messages pub async fn listen_for_server_messages(&self) -> Result<(), Error> { - let mut request = self - .http_client + let mut request = self.http_client .get(&self.mcp_endpoint) .header("Accept", "text/event-stream"); @@ -302,9 +285,8 @@ impl StreamableHttpTransportHandle { request = request.header(key, value); } - let response = request.send().await.map_err(|e| { - Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)) - })?; + let response = request.send().await + .map_err(|e| Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)))?; if !response.status().is_success() { if response.status().as_u16() == 405 { @@ -321,15 +303,12 @@ impl StreamableHttpTransportHandle { // Handle the streaming connection in a separate task let receiver = self.receiver.clone(); let url = response.url().clone(); - + tokio::spawn(async move { let client = match eventsource_client::ClientBuilder::for_url(url.as_str()) { Ok(builder) => builder.build(), Err(e) => { - error!( - "Failed to create streaming client for GET connection: {}", - e - ); + error!("Failed to create streaming client for GET connection: {}", e); return; } }; @@ -376,11 +355,7 @@ impl StreamableHttpTransport { } } - pub fn with_headers>( - mcp_endpoint: S, - env: HashMap, - headers: HashMap, - ) -> Self { + pub fn with_headers>(mcp_endpoint: S, env: HashMap, headers: HashMap) -> Self { Self { mcp_endpoint: mcp_endpoint.into(), env, @@ -444,4 +419,4 @@ impl Transport for StreamableHttpTransport { // No additional cleanup needed Ok(()) } -} +} \ No newline at end of file From 8177e52ecc3112600b6cf563a9159c1b0a8b4831 Mon Sep 17 00:00:00 2001 From: btdeviant Date: Tue, 24 Jun 2025 18:18:30 -0700 Subject: [PATCH 2/3] linting --- crates/goose/src/agents/extension.rs | 11 ++- crates/goose/src/agents/extension_manager.rs | 3 +- crates/mcp-client/examples/streamable_http.rs | 14 ++-- crates/mcp-client/src/lib.rs | 4 +- .../src/transport/streamable_http.rs | 83 ++++++++++++------- 5 files changed, 75 insertions(+), 40 deletions(-) diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index 98d07691c2cf..545c4c222743 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -227,7 +227,12 @@ impl ExtensionConfig { } } - pub fn streamable_http, T: Into>(name: S, uri: S, description: S, timeout: T) -> Self { + pub fn streamable_http, T: Into>( + name: S, + uri: S, + description: S, + timeout: T, + ) -> Self { Self::StreamableHttp { name: name.into(), uri: uri.into(), @@ -309,7 +314,9 @@ impl std::fmt::Display for ExtensionConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri), - ExtensionConfig::StreamableHttp { name, uri, .. } => write!(f, "StreamableHttp({}: {})", name, uri), + ExtensionConfig::StreamableHttp { name, uri, .. } => { + write!(f, "StreamableHttp({}: {})", name, uri) + } ExtensionConfig::Stdio { name, cmd, args, .. } => { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index bb947e68f0d4..a418fe112a9e 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -204,7 +204,8 @@ impl ExtensionManager { .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); + let transport = + StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); let handle = transport.start().await?; Box::new( McpClient::connect( diff --git a/crates/mcp-client/examples/streamable_http.rs b/crates/mcp-client/examples/streamable_http.rs index 6aea20f74f95..0fd856ba661b 100644 --- a/crates/mcp-client/examples/streamable_http.rs +++ b/crates/mcp-client/examples/streamable_http.rs @@ -19,14 +19,14 @@ async fn main() -> Result<()> { // Create example headers let mut headers = HashMap::new(); headers.insert("X-Custom-Header".to_string(), "example-value".to_string()); - headers.insert("User-Agent".to_string(), "MCP-StreamableHttp-Client/1.0".to_string()); + headers.insert( + "User-Agent".to_string(), + "MCP-StreamableHttp-Client/1.0".to_string(), + ); // Create the Streamable HTTP transport with headers - let transport = StreamableHttpTransport::with_headers( - "http://localhost:8000/mcp", - HashMap::new(), - headers - ); + let transport = + StreamableHttpTransport::with_headers("http://localhost:8000/mcp", HashMap::new(), headers); // Start transport let handle = transport.start().await?; @@ -90,4 +90,4 @@ async fn main() -> Result<()> { println!("Streamable HTTP transport example completed successfully!"); Ok(()) -} \ No newline at end of file +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index a3de24d2465a..f6ed51dc467b 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -4,4 +4,6 @@ pub mod transport; pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; pub use service::McpService; -pub use transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle}; +pub use transport::{ + SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, +}; diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index 9310653f69a0..cc3f4fc5d172 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::time::Duration; -use tracing::{debug, warn, error}; +use tracing::{debug, error, warn}; use url::Url; use super::{serialize_and_send, Transport, TransportHandle}; @@ -83,13 +83,17 @@ impl StreamableHttpActor { debug!("Sending message to MCP endpoint: {}", message_str); // Parse the message to determine if it's a request that expects a response - let parsed_message: JsonRpcMessage = serde_json::from_str(&message_str) - .map_err(|e| Error::Serialization(e))?; + let parsed_message: JsonRpcMessage = + serde_json::from_str(&message_str).map_err(Error::Serialization)?; - let expects_response = matches!(parsed_message, JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. })); + let expects_response = matches!( + parsed_message, + JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) + ); // Build the HTTP request - let mut request = self.http_client + let mut request = self + .http_client .post(&self.mcp_endpoint) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") @@ -106,7 +110,9 @@ impl StreamableHttpActor { } // Send the request - let response = request.send().await + let response = request + .send() + .await .map_err(|e| Error::StreamableHttpError(format!("HTTP request failed: {}", e)))?; // Handle HTTP error status codes @@ -115,9 +121,13 @@ impl StreamableHttpActor { if status.as_u16() == 404 { // Session not found - clear our session ID *self.session_id.write().await = None; - return Err(Error::SessionError("Session expired or not found".to_string())); + return Err(Error::SessionError( + "Session expired or not found".to_string(), + )); } - let error_text = response.text().await + let error_text = response + .text() + .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(Error::HttpError { status: status.as_u16(), @@ -134,7 +144,8 @@ impl StreamableHttpActor { } // Handle the response based on content type - let content_type = response.headers() + let content_type = response + .headers() .get("content-type") .and_then(|h| h.to_str().ok()) .unwrap_or(""); @@ -146,13 +157,14 @@ impl StreamableHttpActor { } } else if content_type.starts_with("application/json") || expects_response { // Handle single JSON response - let response_text = response.text().await - .map_err(|e| Error::StreamableHttpError(format!("Failed to read response: {}", e)))?; - + let response_text = response.text().await.map_err(|e| { + Error::StreamableHttpError(format!("Failed to read response: {}", e)) + })?; + if !response_text.is_empty() { - let json_message: JsonRpcMessage = serde_json::from_str(&response_text) - .map_err(|e| Error::Serialization(e))?; - + let json_message: JsonRpcMessage = + serde_json::from_str(&response_text).map_err(Error::Serialization)?; + let _ = self.sender.send(json_message).await; } } @@ -162,20 +174,23 @@ impl StreamableHttpActor { } /// Handle streaming HTTP response that uses Server-Sent Events format - /// + /// /// This is called when the server responds to an HTTP POST with `text/event-stream` /// content-type, indicating it wants to stream multiple JSON-RPC messages back /// rather than sending a single response. This is part of the Streamable HTTP /// specification, not a separate SSE transport. - async fn handle_streaming_response(&mut self, response: reqwest::Response) -> Result<(), Error> { + async fn handle_streaming_response( + &mut self, + response: reqwest::Response, + ) -> Result<(), Error> { use futures::StreamExt; - use tokio_util::io::StreamReader; use tokio::io::AsyncBufReadExt; + use tokio_util::io::StreamReader; // Convert the response body to a stream reader - let stream = response.bytes_stream().map(|result| { - result.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) - }); + let stream = response + .bytes_stream() + .map(|result| result.map_err(std::io::Error::other)); let reader = StreamReader::new(stream); let mut lines = tokio::io::BufReader::new(reader).lines(); @@ -245,7 +260,8 @@ impl StreamableHttpTransportHandle { /// Manually terminate the session by sending HTTP DELETE pub async fn terminate_session(&self) -> Result<(), Error> { if let Some(session_id) = self.session_id.read().await.as_ref() { - let mut request = self.http_client + let mut request = self + .http_client .delete(&self.mcp_endpoint) .header("Mcp-Session-Id", session_id); @@ -271,7 +287,8 @@ impl StreamableHttpTransportHandle { /// Create a GET request to establish a streaming connection for server-initiated messages pub async fn listen_for_server_messages(&self) -> Result<(), Error> { - let mut request = self.http_client + let mut request = self + .http_client .get(&self.mcp_endpoint) .header("Accept", "text/event-stream"); @@ -285,8 +302,9 @@ impl StreamableHttpTransportHandle { request = request.header(key, value); } - let response = request.send().await - .map_err(|e| Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)))?; + let response = request.send().await.map_err(|e| { + Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)) + })?; if !response.status().is_success() { if response.status().as_u16() == 405 { @@ -303,12 +321,15 @@ impl StreamableHttpTransportHandle { // Handle the streaming connection in a separate task let receiver = self.receiver.clone(); let url = response.url().clone(); - + tokio::spawn(async move { let client = match eventsource_client::ClientBuilder::for_url(url.as_str()) { Ok(builder) => builder.build(), Err(e) => { - error!("Failed to create streaming client for GET connection: {}", e); + error!( + "Failed to create streaming client for GET connection: {}", + e + ); return; } }; @@ -355,7 +376,11 @@ impl StreamableHttpTransport { } } - pub fn with_headers>(mcp_endpoint: S, env: HashMap, headers: HashMap) -> Self { + pub fn with_headers>( + mcp_endpoint: S, + env: HashMap, + headers: HashMap, + ) -> Self { Self { mcp_endpoint: mcp_endpoint.into(), env, @@ -419,4 +444,4 @@ impl Transport for StreamableHttpTransport { // No additional cleanup needed Ok(()) } -} \ No newline at end of file +} From 6ee1c0722986aedabff0ab063df7c5077a6ec5d0 Mon Sep 17 00:00:00 2001 From: Jeremiah Williams Date: Mon, 30 Jun 2025 17:51:54 -0700 Subject: [PATCH 3/3] Adding Streamable HTTP with dynamic client registration --- Cargo.lock | 7 + crates/mcp-client/Cargo.toml | 8 + crates/mcp-client/examples/test_auth.rs | 52 +++ crates/mcp-client/src/lib.rs | 2 + crates/mcp-client/src/oauth.rs | 419 ++++++++++++++++++ .../src/transport/streamable_http.rs | 67 ++- 6 files changed, 553 insertions(+), 2 deletions(-) create mode 100644 crates/mcp-client/examples/test_auth.rs create mode 100644 crates/mcp-client/src/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index cab8ca6cd998..5786ec3792df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5311,14 +5311,20 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "axum", + "base64 0.22.1", + "chrono", "eventsource-client", "futures", "mcp-core", + "nanoid", "nix 0.30.1", "rand 0.8.5", "reqwest 0.11.27", "serde", "serde_json", + "serde_urlencoded", + "sha2", "thiserror 1.0.69", "tokio", "tokio-util", @@ -5327,6 +5333,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "webbrowser 1.0.4", ] [[package]] diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 7188cf33792c..a678e8f20643 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -25,5 +25,13 @@ tower = { version = "0.4", features = ["timeout", "util"] } tower-service = "0.3" rand = "0.8" nix = { version = "0.30.1", features = ["process", "signal"] } +# OAuth dependencies +axum = { version = "0.8", features = ["query"] } +base64 = "0.22" +sha2 = "0.10" +chrono = { version = "0.4", features = ["serde"] } +nanoid = "0.4" +webbrowser = "1.0" +serde_urlencoded = "0.7" [dev-dependencies] diff --git a/crates/mcp-client/examples/test_auth.rs b/crates/mcp-client/examples/test_auth.rs new file mode 100644 index 000000000000..d4fba7d6f528 --- /dev/null +++ b/crates/mcp-client/examples/test_auth.rs @@ -0,0 +1,52 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; +use mcp_client::transport::{StreamableHttpTransport, Transport}; +use std::collections::HashMap; +use std::time::Duration; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=info".parse().unwrap()), + ) + .init(); + + println!("Testing Streamable HTTP transport with auto-authentication..."); + + // Create the Streamable HTTP transport for any MCP service that supports OAuth + // This example uses a hypothetical MCP endpoint - replace with actual service + let mcp_endpoint = + std::env::var("MCP_ENDPOINT").unwrap_or_else(|_| "https://example.com/mcp".to_string()); + + println!("Using MCP endpoint: {}", mcp_endpoint); + + let transport = StreamableHttpTransport::new(&mcp_endpoint, HashMap::new()); + + // Start transport + let handle = transport.start().await?; + + // Create client + let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?; + println!("Client created with Streamable HTTP transport\n"); + + // Initialize - this should trigger the OAuth flow if authentication is needed + let server_info = client + .initialize( + ClientInfo { + name: "streamable-http-auth-test".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + + println!("Connected to server: {server_info:?}\n"); + println!("Authentication test completed successfully!"); + + Ok(()) +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index f6ed51dc467b..01f55864e0ba 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,8 +1,10 @@ pub mod client; +pub mod oauth; pub mod service; pub mod transport; pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; +pub use oauth::{authenticate_service, ServiceConfig}; pub use service::McpService; pub use transport::{ SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, diff --git a/crates/mcp-client/src/oauth.rs b/crates/mcp-client/src/oauth.rs new file mode 100644 index 000000000000..fc6af957628f --- /dev/null +++ b/crates/mcp-client/src/oauth.rs @@ -0,0 +1,419 @@ +use anyhow::Result; +use axum::{extract::Query, response::Html, routing::get, Router}; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::Digest; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use tokio::sync::{oneshot, Mutex as TokioMutex}; +use url::Url; + +#[derive(Debug, Clone)] +struct OidcEndpoints { + authorization_endpoint: String, + token_endpoint: String, + registration_endpoint: Option, +} + +#[derive(Serialize, Deserialize)] +struct TokenData { + access_token: String, + refresh_token: Option, +} + +#[derive(Serialize, Deserialize)] +struct ClientRegistrationRequest { + redirect_uris: Vec, + token_endpoint_auth_method: String, + grant_types: Vec, + response_types: Vec, + client_name: String, + client_uri: String, +} + +#[derive(Serialize, Deserialize)] +struct ClientRegistrationResponse { + client_id: String, + client_id_issued_at: Option, + #[serde(default)] + client_secret: Option, +} + +/// OAuth configuration for any service +#[derive(Debug, Clone)] +pub struct ServiceConfig { + pub oauth_host: String, + pub redirect_uri: String, + pub client_name: String, + pub client_uri: String, + pub discovery_path: Option, +} + +impl ServiceConfig { + /// Create a generic OAuth configuration from an MCP endpoint URL + /// Extracts the base URL for OAuth discovery + pub fn from_mcp_endpoint(mcp_url: &str) -> Result { + let parsed_url = Url::parse(mcp_url.trim())?; + let oauth_host = format!( + "{}://{}{}", + parsed_url.scheme(), + parsed_url.host_str().ok_or_else(|| { + anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url) + })?, + if let Some(port) = parsed_url.port() { + format!(":{}", port) + } else { + String::new() + } + ); + + Ok(Self { + oauth_host, + redirect_uri: "http://localhost:8020".to_string(), + client_name: "Goose MCP Client".to_string(), + client_uri: "https://github.com/block/goose".to_string(), + discovery_path: None, // Use standard discovery + }) + } + + /// Create configuration with custom discovery path for non-standard services + pub fn with_custom_discovery(mut self, discovery_path: String) -> Self { + self.discovery_path = Some(discovery_path); + self + } +} + +struct OAuthFlow { + endpoints: OidcEndpoints, + client_id: String, + redirect_url: String, + state: String, + verifier: String, +} + +impl OAuthFlow { + fn new(endpoints: OidcEndpoints, client_id: String, redirect_url: String) -> Self { + Self { + endpoints, + client_id, + redirect_url, + state: nanoid::nanoid!(16), + verifier: nanoid::nanoid!(64), + } + } + + /// Register a dynamic client and return the client_id + async fn register_client(endpoints: &OidcEndpoints, config: &ServiceConfig) -> Result { + let Some(registration_endpoint) = &endpoints.registration_endpoint else { + return Err(anyhow::anyhow!("No registration endpoint available")); + }; + + let registration_request = ClientRegistrationRequest { + redirect_uris: vec![config.redirect_uri.clone()], + token_endpoint_auth_method: "none".to_string(), + grant_types: vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ], + response_types: vec!["code".to_string()], + client_name: config.client_name.clone(), + client_uri: config.client_uri.clone(), + }; + + tracing::info!("Registering dynamic client with OAuth server..."); + + let client = reqwest::Client::new(); + let resp = client + .post(registration_endpoint) + .header("Content-Type", "application/json") + .json(®istration_request) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err_text = resp.text().await?; + return Err(anyhow::anyhow!( + "Failed to register client: {} - {}", + status, + err_text + )); + } + + let registration_response: ClientRegistrationResponse = resp.json().await?; + + tracing::info!( + "Client registered successfully with ID: {}", + registration_response.client_id + ); + Ok(registration_response.client_id) + } + + fn get_authorization_url(&self) -> String { + let challenge = { + let digest = sha2::Sha256::digest(self.verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) + }; + + let params = [ + ("response_type", "code"), + ("client_id", &self.client_id), + ("redirect_uri", &self.redirect_url), + ("state", &self.state), + ("code_challenge", &challenge), + ("code_challenge_method", "S256"), + ]; + + format!( + "{}?{}", + self.endpoints.authorization_endpoint, + serde_urlencoded::to_string(params).unwrap() + ) + } + + async fn exchange_code_for_token(&self, code: &str) -> Result { + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", &self.redirect_url), + ("code_verifier", &self.verifier), + ("client_id", &self.client_id), + ]; + + let client = reqwest::Client::new(); + let resp = client + .post(&self.endpoints.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let err_text = resp.text().await?; + return Err(anyhow::anyhow!( + "Failed to exchange code for token: {}", + err_text + )); + } + + let token_response: Value = resp.json().await?; + + let access_token = token_response + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))? + .to_string(); + + let refresh_token = token_response + .get("refresh_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + Ok(TokenData { + access_token, + refresh_token, + }) + } + + async fn execute(&self) -> Result { + // Create a channel that will send the auth code from the callback + let (tx, rx) = oneshot::channel(); + let state = self.state.clone(); + let tx = Arc::new(TokioMutex::new(Some(tx))); + + // Setup a server that will receive the redirect and capture the code + let app = Router::new().route( + "/", + get(move |Query(params): Query>| { + let tx = Arc::clone(&tx); + let state = state.clone(); + async move { + let code = params.get("code").cloned(); + let received_state = params.get("state").cloned(); + + if let (Some(code), Some(received_state)) = (code, received_state) { + if received_state == state { + if let Some(sender) = tx.lock().await.take() { + if sender.send(code).is_ok() { + return Html( + "

Authentication Successful!

You can close this window and return to the application.

", + ); + } + } + Html("

Error

Authentication already completed.

") + } else { + Html("

Error

State mismatch - possible security issue.

") + } + } else { + Html("

Error

Authentication failed - missing parameters.

") + } + } + }), + ); + + // Start the callback server + let redirect_url = Url::parse(&self.redirect_url)?; + let port = redirect_url.port().unwrap_or(8020); + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + + let listener = tokio::net::TcpListener::bind(addr).await?; + + let server_handle = tokio::spawn(async move { + let server = axum::serve(listener, app); + server.await.unwrap(); + }); + + // Open the browser for OAuth + let authorization_url = self.get_authorization_url(); + tracing::info!("Opening browser for OAuth authentication..."); + + if webbrowser::open(&authorization_url).is_err() { + tracing::warn!("Could not open browser automatically. Please open this URL manually:"); + tracing::warn!("{}", authorization_url); + } + + // Wait for the authorization code with a timeout + let code = tokio::time::timeout( + std::time::Duration::from_secs(120), // 2 minute timeout + rx, + ) + .await + .map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??; + + // Stop the callback server + server_handle.abort(); + + // Exchange the code for a token + self.exchange_code_for_token(&code).await + } +} + +async fn get_oauth_endpoints( + host: &str, + custom_discovery_path: Option<&str>, +) -> Result { + let base_url = Url::parse(host)?; + let client = reqwest::Client::new(); + + // Define discovery paths to try, with custom path first if provided + let mut discovery_paths = Vec::new(); + if let Some(custom_path) = custom_discovery_path { + discovery_paths.push(custom_path); + } + discovery_paths.extend([ + "/.well-known/oauth-authorization-server", + "/.well-known/openid_configuration", + "/oauth/.well-known/oauth-authorization-server", + "/.well-known/oauth_authorization_server", // Some services use underscore + ]); + + let discovery_paths_for_error = discovery_paths.clone(); // Clone for error message + let mut last_error = None; + + // Try each discovery path until one works + for path in discovery_paths { + match base_url.join(path) { + Ok(discovery_url) => { + tracing::debug!("Trying OAuth discovery at: {}", discovery_url); + + match client.get(discovery_url.clone()).send().await { + Ok(resp) if resp.status().is_success() => { + match resp.json::().await { + Ok(oidc_config) => { + // Try to parse the OAuth configuration + match parse_oauth_config(oidc_config) { + Ok(endpoints) => { + tracing::info!( + "Successfully discovered OAuth endpoints at: {}", + discovery_url + ); + return Ok(endpoints); + } + Err(e) => { + tracing::debug!( + "Invalid OAuth config at {}: {}", + discovery_url, + e + ); + last_error = Some(e); + } + } + } + Err(e) => { + tracing::debug!( + "Failed to parse JSON from {}: {}", + discovery_url, + e + ); + last_error = Some(e.into()); + } + } + } + Ok(resp) => { + tracing::debug!("HTTP {} from {}", resp.status(), discovery_url); + } + Err(e) => { + tracing::debug!("Request failed to {}: {}", discovery_url, e); + last_error = Some(e.into()); + } + } + } + Err(e) => { + tracing::debug!("Invalid discovery URL {}{}: {}", host, path, e); + } + } + } + + Err(last_error.unwrap_or_else(|| { + anyhow::anyhow!( + "No OAuth discovery endpoint found at {}. Tried paths: {:?}", + host, + discovery_paths_for_error + ) + })) +} + +fn parse_oauth_config(oidc_config: Value) -> Result { + let authorization_endpoint = oidc_config + .get("authorization_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OAuth configuration"))? + .to_string(); + + let token_endpoint = oidc_config + .get("token_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OAuth configuration"))? + .to_string(); + + let registration_endpoint = oidc_config + .get("registration_endpoint") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + Ok(OidcEndpoints { + authorization_endpoint, + token_endpoint, + registration_endpoint, + }) +} + +/// Perform OAuth flow for a service +pub async fn authenticate_service(config: ServiceConfig) -> Result { + tracing::info!("Starting OAuth authentication for service..."); + + // Get OAuth endpoints using flexible discovery + let endpoints = + get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?; + + // Register dynamic client to get client_id + let client_id = OAuthFlow::register_client(&endpoints, &config).await?; + + // Create and execute OAuth flow with the dynamic client_id + let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri); + + let token_data = flow.execute().await?; + + tracing::info!("OAuth authentication successful!"); + Ok(token_data.access_token) +} diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index cc3f4fc5d172..0eb2b52a386f 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -1,3 +1,4 @@ +use crate::oauth::{authenticate_service, ServiceConfig}; use crate::transport::Error; use async_trait::async_trait; use eventsource_client::{Client, SSE}; @@ -8,7 +9,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::time::Duration; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use url::Url; use super::{serialize_and_send, Transport, TransportHandle}; @@ -91,13 +92,45 @@ impl StreamableHttpActor { JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) ); + // Try to send the request + match self.send_request(&message_str, expects_response).await { + Ok(()) => Ok(()), + Err(Error::HttpError { status, .. }) if status == 401 || status == 403 => { + // Authentication challenge - try to authenticate and retry + info!( + "Received authentication challenge ({}), attempting OAuth flow...", + status + ); + + if let Some(token) = self.attempt_authentication().await? { + info!("Authentication successful, retrying request..."); + self.headers + .insert("Authorization".to_string(), format!("Bearer {}", token)); + self.send_request(&message_str, expects_response).await + } else { + Err(Error::StreamableHttpError( + "Authentication failed - service not supported or OAuth flow failed" + .to_string(), + )) + } + } + Err(e) => Err(e), + } + } + + /// Send an HTTP request to the MCP endpoint + async fn send_request( + &mut self, + message_str: &str, + expects_response: bool, + ) -> Result<(), Error> { // Build the HTTP request let mut request = self .http_client .post(&self.mcp_endpoint) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") - .body(message_str); + .body(message_str.to_string()); // Add session ID header if we have one if let Some(session_id) = self.session_id.read().await.as_ref() { @@ -173,6 +206,36 @@ impl StreamableHttpActor { Ok(()) } + /// Attempt to authenticate with the service + async fn attempt_authentication(&self) -> Result, Error> { + info!("Attempting to authenticate with service..."); + + // Create a generic OAuth configuration from the MCP endpoint + match ServiceConfig::from_mcp_endpoint(&self.mcp_endpoint) { + Ok(config) => { + info!("Created OAuth config for endpoint: {}", self.mcp_endpoint); + + match authenticate_service(config).await { + Ok(token) => { + info!("OAuth authentication successful!"); + Ok(Some(token)) + } + Err(e) => { + warn!("OAuth authentication failed: {}", e); + Err(Error::StreamableHttpError(format!("OAuth failed: {}", e))) + } + } + } + Err(e) => { + warn!( + "Could not create OAuth config from MCP endpoint {}: {}", + self.mcp_endpoint, e + ); + Ok(None) + } + } + } + /// Handle streaming HTTP response that uses Server-Sent Events format /// /// This is called when the server responds to an HTTP POST with `text/event-stream`