diff --git a/Cargo.lock b/Cargo.lock index 3ae1dbbd31a0..cab8ca6cd998 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5321,6 +5321,7 @@ dependencies = [ "serde_json", "thiserror 1.0.69", "tokio", + "tokio-util", "tower 0.4.13", "tower-service", "tracing", diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index 3b8f8dbcb387..64a941c167c2 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -56,6 +56,24 @@ enum ExtensionConfigRequest { display_name: Option, timeout: Option, }, + /// Streamable HTTP extension using MCP Streamable HTTP specification. + #[serde(rename = "streamable_http")] + StreamableHttp { + /// The name to identify this extension + name: String, + /// The URI endpoint for the streamable HTTP extension. + uri: String, + #[serde(default)] + /// Map of environment variable key to values. + envs: Envs, + /// List of environment variable keys. The server will fetch their values from the keyring. + #[serde(default)] + env_keys: Vec, + /// Custom headers to include in requests. + #[serde(default)] + headers: std::collections::HashMap, + timeout: Option, + }, /// Frontend extension that provides tools to be executed by the frontend. #[serde(rename = "frontend")] Frontend { @@ -176,6 +194,23 @@ async fn add_extension( timeout, bundled: None, }, + ExtensionConfigRequest::StreamableHttp { + name, + uri, + envs, + env_keys, + headers, + timeout, + } => ExtensionConfig::StreamableHttp { + name, + uri, + envs, + env_keys, + headers, + description: None, + timeout, + bundled: None, + }, ExtensionConfigRequest::Stdio { name, cmd, diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index 30ac368828fc..545c4c222743 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -15,7 +15,7 @@ use crate::config::permission::PermissionLevel; #[derive(Error, Debug)] pub enum ExtensionError { #[error("Failed to start the MCP server from configuration `{0}` `{1}`")] - Initialization(ExtensionConfig, ClientError), + Initialization(Box, ClientError), #[error("Failed a client call to an MCP server: {0}")] Client(#[from] ClientError), #[error("User Message exceeded context-limit. History could not be truncated to accommodate.")] @@ -54,7 +54,7 @@ impl Envs { "LD_AUDIT", // Loads a monitoring library that can intercept execution "LD_DEBUG", // Enables verbose linker logging (information disclosure risk) "LD_BIND_NOW", // Forces immediate symbol resolution, affecting ASLR - "LD_ASSUME_KERNEL", // Tricks linker into thinking itโ€™s running on an older kernel + "LD_ASSUME_KERNEL", // Tricks linker into thinking it's running on an older kernel // ๐ŸŽ macOS dynamic linker variables "DYLD_LIBRARY_PATH", // Same as LD_LIBRARY_PATH but for macOS "DYLD_INSERT_LIBRARIES", // macOS equivalent of LD_PRELOAD @@ -168,6 +168,26 @@ pub enum ExtensionConfig { #[serde(default)] bundled: Option, }, + /// Streamable HTTP client with a URI endpoint using MCP Streamable HTTP specification + #[serde(rename = "streamable_http")] + StreamableHttp { + /// The name used to identify this extension + name: String, + uri: String, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + #[serde(default)] + headers: HashMap, + description: Option, + // NOTE: set timeout to be optional for compatibility. + // However, new configurations should include this field. + timeout: Option, + /// Whether this extension is bundled with Goose + #[serde(default)] + bundled: Option, + }, /// Frontend-provided tools that will be called through the frontend #[serde(rename = "frontend")] Frontend { @@ -207,6 +227,24 @@ impl ExtensionConfig { } } + pub fn streamable_http, T: Into>( + name: S, + uri: S, + description: S, + timeout: T, + ) -> Self { + Self::StreamableHttp { + name: name.into(), + uri: uri.into(), + envs: Envs::default(), + env_keys: Vec::new(), + headers: HashMap::new(), + description: Some(description.into()), + timeout: Some(timeout.into()), + bundled: None, + } + } + pub fn stdio, T: Into>( name: S, cmd: S, @@ -263,6 +301,7 @@ impl ExtensionConfig { pub fn name(&self) -> String { match self { Self::Sse { name, .. } => name, + Self::StreamableHttp { name, .. } => name, Self::Stdio { name, .. } => name, Self::Builtin { name, .. } => name, Self::Frontend { name, .. } => name, @@ -275,6 +314,9 @@ impl std::fmt::Display for ExtensionConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ExtensionConfig::Sse { name, uri, .. } => write!(f, "SSE({}: {})", name, uri), + ExtensionConfig::StreamableHttp { name, uri, .. } => { + write!(f, "StreamableHttp({}: {})", name, uri) + } ExtensionConfig::Stdio { name, cmd, args, .. } => { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index aa8d117297e0..a418fe112a9e 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -18,7 +18,7 @@ use crate::agents::extension::Envs; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, StdioTransport, Transport}; +use mcp_client::transport::{SseTransport, StdioTransport, StreamableHttpTransport, Transport}; use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError}; use serde_json::Value; @@ -195,6 +195,28 @@ impl ExtensionManager { .await?, ) } + ExtensionConfig::StreamableHttp { + uri, + envs, + env_keys, + headers, + timeout, + .. + } => { + let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; + let transport = + StreamableHttpTransport::with_headers(uri, all_envs, headers.clone()); + let handle = transport.start().await?; + Box::new( + McpClient::connect( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ) + .await?, + ) + } ExtensionConfig::Stdio { cmd, args, @@ -256,7 +278,7 @@ impl ExtensionManager { let init_result = client .initialize(info, capabilities) .await - .map_err(|e| ExtensionError::Initialization(config.clone(), e))?; + .map_err(|e| ExtensionError::Initialization(Box::new(config.clone()), e))?; if let Some(instructions) = init_result.instructions { self.instructions @@ -752,10 +774,13 @@ impl ExtensionManager { ExtensionConfig::Sse { description, name, .. } + | ExtensionConfig::StreamableHttp { + description, name, .. + } | ExtensionConfig::Stdio { description, name, .. } => { - // For SSE/Stdio, use description if available + // For SSE/StreamableHttp/Stdio, use description if available description .as_ref() .map(|s| s.to_string()) diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 95b6728f4f20..7188cf33792c 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7", features = ["io"] } reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls-native-roots"] } eventsource-client = "0.12.0" futures = "0.3" diff --git a/crates/mcp-client/examples/integration_test.rs b/crates/mcp-client/examples/integration_test.rs index 9d8909ead592..d5de80abfc4f 100644 --- a/crates/mcp-client/examples/integration_test.rs +++ b/crates/mcp-client/examples/integration_test.rs @@ -1,7 +1,7 @@ use anyhow::Result; use futures::lock::Mutex; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, Transport}; +use mcp_client::transport::{SseTransport, StreamableHttpTransport, Transport}; use mcp_client::StdioTransport; use std::collections::HashMap; use std::sync::Arc; @@ -20,6 +20,7 @@ async fn main() -> Result<()> { .init(); test_transport(sse_transport().await?).await?; + test_transport(streamable_http_transport().await?).await?; test_transport(stdio_transport().await?).await?; // Test broken transport @@ -52,6 +53,22 @@ async fn sse_transport() -> Result { )) } +async fn streamable_http_transport() -> Result { + let port = "60054"; + + tokio::process::Command::new("npx") + .env("PORT", port) + .arg("@modelcontextprotocol/server-everything") + .arg("streamable-http") + .spawn()?; + tokio::time::sleep(Duration::from_secs(1)).await; + + Ok(StreamableHttpTransport::new( + format!("http://localhost:{}/mcp", port), + HashMap::new(), + )) +} + async fn stdio_transport() -> Result { Ok(StdioTransport::new( "npx", diff --git a/crates/mcp-client/examples/streamable_http.rs b/crates/mcp-client/examples/streamable_http.rs new file mode 100644 index 000000000000..0fd856ba661b --- /dev/null +++ b/crates/mcp-client/examples/streamable_http.rs @@ -0,0 +1,93 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; +use mcp_client::transport::{StreamableHttpTransport, Transport}; +use std::collections::HashMap; +use std::time::Duration; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=info".parse().unwrap()), + ) + .init(); + + // Create example headers + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "example-value".to_string()); + headers.insert( + "User-Agent".to_string(), + "MCP-StreamableHttp-Client/1.0".to_string(), + ); + + // Create the Streamable HTTP transport with headers + let transport = + StreamableHttpTransport::with_headers("http://localhost:8000/mcp", HashMap::new(), headers); + + // Start transport + let handle = transport.start().await?; + + // Create client + let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?; + println!("Client created with Streamable HTTP transport\n"); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "streamable-http-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Give the server a moment to fully initialize + tokio::time::sleep(Duration::from_millis(500)).await; + + // List tools + let tools = client.list_tools(None).await?; + println!("Available tools: {tools:?}\n"); + + // Call tool if available + if !tools.tools.is_empty() { + let tool_result = client + .call_tool( + &tools.tools[0].name, + serde_json::json!({ "message": "Hello from Streamable HTTP transport!" }), + ) + .await?; + println!("Tool result: {tool_result:?}\n"); + } + + // List resources + let resources = client.list_resources(None).await?; + println!("Resources: {resources:?}\n"); + + // Read resource if available + if !resources.resources.is_empty() { + let resource = client.read_resource(&resources.resources[0].uri).await?; + println!("Resource content: {resource:?}\n"); + } + + // List prompts + let prompts = client.list_prompts(None).await?; + println!("Available prompts: {prompts:?}\n"); + + // Get prompt if available + if !prompts.prompts.is_empty() { + let prompt_result = client + .get_prompt(&prompts.prompts[0].name, serde_json::json!({})) + .await?; + println!("Prompt result: {prompt_result:?}\n"); + } + + println!("Streamable HTTP transport example completed successfully!"); + + Ok(()) +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index 985d89d16046..f6ed51dc467b 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -4,4 +4,6 @@ pub mod transport; pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; pub use service::McpService; -pub use transport::{SseTransport, StdioTransport, Transport, TransportHandle}; +pub use transport::{ + SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, +}; diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 28e6d929c633..76895d5126c2 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -30,6 +30,12 @@ pub enum Error { #[error("HTTP error: {status} - {message}")] HttpError { status: u16, message: String }, + + #[error("Streamable HTTP error: {0}")] + StreamableHttpError(String), + + #[error("Session error: {0}")] + SessionError(String), } /// A message that can be sent through the transport @@ -78,3 +84,6 @@ pub use stdio::StdioTransport; pub mod sse; pub use sse::SseTransport; + +pub mod streamable_http; +pub use streamable_http::StreamableHttpTransport; diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs new file mode 100644 index 000000000000..cc3f4fc5d172 --- /dev/null +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -0,0 +1,447 @@ +use crate::transport::Error; +use async_trait::async_trait; +use eventsource_client::{Client, SSE}; +use futures::TryStreamExt; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use reqwest::Client as HttpClient; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::time::Duration; +use tracing::{debug, error, warn}; +use url::Url; + +use super::{serialize_and_send, Transport, TransportHandle}; + +// Default timeout for HTTP requests +const HTTP_TIMEOUT_SECS: u64 = 30; + +/// The Streamable HTTP transport actor that handles: +/// - HTTP POST requests to send messages to the server +/// - Optional streaming responses for receiving multiple responses and server-initiated messages +/// - Session management with session IDs +pub struct StreamableHttpActor { + /// Receives messages (requests/notifications) from the handle + receiver: mpsc::Receiver, + /// Sends messages (responses) back to the handle + sender: mpsc::Sender, + /// MCP endpoint URL + mcp_endpoint: String, + /// HTTP client for sending requests + http_client: HttpClient, + /// Optional session ID for stateful connections + session_id: Arc>>, + /// Environment variables to set + env: HashMap, + /// Custom headers to include in requests + headers: HashMap, +} + +impl StreamableHttpActor { + pub fn new( + receiver: mpsc::Receiver, + sender: mpsc::Sender, + mcp_endpoint: String, + session_id: Arc>>, + env: HashMap, + headers: HashMap, + ) -> Self { + Self { + receiver, + sender, + mcp_endpoint, + http_client: HttpClient::builder() + .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) + .build() + .unwrap(), + session_id, + env, + headers, + } + } + + /// Main entry point for the actor + pub async fn run(mut self) { + // Set environment variables + for (key, value) in &self.env { + std::env::set_var(key, value); + } + + // Handle outgoing messages + while let Some(message_str) = self.receiver.recv().await { + if let Err(e) = self.handle_outgoing_message(message_str).await { + error!("Error handling outgoing message: {}", e); + break; + } + } + + debug!("StreamableHttpActor shut down"); + } + + /// Handle an outgoing message by sending it via HTTP POST + async fn handle_outgoing_message(&mut self, message_str: String) -> Result<(), Error> { + debug!("Sending message to MCP endpoint: {}", message_str); + + // Parse the message to determine if it's a request that expects a response + let parsed_message: JsonRpcMessage = + serde_json::from_str(&message_str).map_err(Error::Serialization)?; + + let expects_response = matches!( + parsed_message, + JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) + ); + + // Build the HTTP request + let mut request = self + .http_client + .post(&self.mcp_endpoint) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(message_str); + + // Add session ID header if we have one + if let Some(session_id) = self.session_id.read().await.as_ref() { + request = request.header("Mcp-Session-Id", session_id); + } + + // Add custom headers + for (key, value) in &self.headers { + request = request.header(key, value); + } + + // Send the request + let response = request + .send() + .await + .map_err(|e| Error::StreamableHttpError(format!("HTTP request failed: {}", e)))?; + + // Handle HTTP error status codes + if !response.status().is_success() { + let status = response.status(); + if status.as_u16() == 404 { + // Session not found - clear our session ID + *self.session_id.write().await = None; + return Err(Error::SessionError( + "Session expired or not found".to_string(), + )); + } + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::HttpError { + status: status.as_u16(), + message: error_text, + }); + } + + // Check for session ID in response headers + if let Some(session_id_header) = response.headers().get("Mcp-Session-Id") { + if let Ok(session_id) = session_id_header.to_str() { + debug!("Received session ID: {}", session_id); + *self.session_id.write().await = Some(session_id.to_string()); + } + } + + // Handle the response based on content type + let content_type = response + .headers() + .get("content-type") + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + + if content_type.starts_with("text/event-stream") { + // Handle streaming HTTP response (server chose to stream multiple messages back) + if expects_response { + self.handle_streaming_response(response).await?; + } + } else if content_type.starts_with("application/json") || expects_response { + // Handle single JSON response + let response_text = response.text().await.map_err(|e| { + Error::StreamableHttpError(format!("Failed to read response: {}", e)) + })?; + + if !response_text.is_empty() { + let json_message: JsonRpcMessage = + serde_json::from_str(&response_text).map_err(Error::Serialization)?; + + let _ = self.sender.send(json_message).await; + } + } + // For notifications and responses, we get 202 Accepted with no body + + Ok(()) + } + + /// Handle streaming HTTP response that uses Server-Sent Events format + /// + /// This is called when the server responds to an HTTP POST with `text/event-stream` + /// content-type, indicating it wants to stream multiple JSON-RPC messages back + /// rather than sending a single response. This is part of the Streamable HTTP + /// specification, not a separate SSE transport. + async fn handle_streaming_response( + &mut self, + response: reqwest::Response, + ) -> Result<(), Error> { + use futures::StreamExt; + use tokio::io::AsyncBufReadExt; + use tokio_util::io::StreamReader; + + // Convert the response body to a stream reader + let stream = response + .bytes_stream() + .map(|result| result.map_err(std::io::Error::other)); + let reader = StreamReader::new(stream); + let mut lines = tokio::io::BufReader::new(reader).lines(); + + let mut event_type = String::new(); + let mut event_data = String::new(); + let mut event_id = String::new(); + + while let Ok(Some(line)) = lines.next_line().await { + if line.is_empty() { + // Empty line indicates end of event + if !event_data.is_empty() { + // Parse the streamed data as JSON-RPC message + match serde_json::from_str::(&event_data) { + Ok(message) => { + debug!("Received streaming HTTP response message: {:?}", message); + let _ = self.sender.send(message).await; + } + Err(err) => { + warn!("Failed to parse streaming HTTP response message: {}", err); + } + } + } + // Reset for next event + event_type.clear(); + event_data.clear(); + event_id.clear(); + } else if let Some(field_data) = line.strip_prefix("data: ") { + if !event_data.is_empty() { + event_data.push('\n'); + } + event_data.push_str(field_data); + } else if let Some(field_data) = line.strip_prefix("event: ") { + event_type = field_data.to_string(); + } else if let Some(field_data) = line.strip_prefix("id: ") { + event_id = field_data.to_string(); + } + // Ignore other fields (retry, etc.) - we only care about data + } + + Ok(()) + } +} + +#[derive(Clone)] +pub struct StreamableHttpTransportHandle { + sender: mpsc::Sender, + receiver: Arc>>, + session_id: Arc>>, + mcp_endpoint: String, + http_client: HttpClient, + headers: HashMap, +} + +#[async_trait::async_trait] +impl TransportHandle for StreamableHttpTransportHandle { + async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { + serialize_and_send(&self.sender, message).await + } + + async fn receive(&self) -> Result { + let mut receiver = self.receiver.lock().await; + receiver.recv().await.ok_or(Error::ChannelClosed) + } +} + +impl StreamableHttpTransportHandle { + /// Manually terminate the session by sending HTTP DELETE + pub async fn terminate_session(&self) -> Result<(), Error> { + if let Some(session_id) = self.session_id.read().await.as_ref() { + let mut request = self + .http_client + .delete(&self.mcp_endpoint) + .header("Mcp-Session-Id", session_id); + + // Add custom headers + for (key, value) in &self.headers { + request = request.header(key, value); + } + + match request.send().await { + Ok(response) => { + if response.status().as_u16() == 405 { + // Method not allowed - server doesn't support session termination + debug!("Server doesn't support session termination"); + } + } + Err(e) => { + warn!("Failed to terminate session: {}", e); + } + } + } + Ok(()) + } + + /// Create a GET request to establish a streaming connection for server-initiated messages + pub async fn listen_for_server_messages(&self) -> Result<(), Error> { + let mut request = self + .http_client + .get(&self.mcp_endpoint) + .header("Accept", "text/event-stream"); + + // Add session ID header if we have one + if let Some(session_id) = self.session_id.read().await.as_ref() { + request = request.header("Mcp-Session-Id", session_id); + } + + // Add custom headers + for (key, value) in &self.headers { + request = request.header(key, value); + } + + let response = request.send().await.map_err(|e| { + Error::StreamableHttpError(format!("Failed to start GET streaming connection: {}", e)) + })?; + + if !response.status().is_success() { + if response.status().as_u16() == 405 { + // Method not allowed - server doesn't support GET streaming connections + debug!("Server doesn't support GET streaming connections"); + return Ok(()); + } + return Err(Error::HttpError { + status: response.status().as_u16(), + message: "Failed to establish GET streaming connection".to_string(), + }); + } + + // Handle the streaming connection in a separate task + let receiver = self.receiver.clone(); + let url = response.url().clone(); + + tokio::spawn(async move { + let client = match eventsource_client::ClientBuilder::for_url(url.as_str()) { + Ok(builder) => builder.build(), + Err(e) => { + error!( + "Failed to create streaming client for GET connection: {}", + e + ); + return; + } + }; + + let mut stream = client.stream(); + while let Ok(Some(event)) = stream.try_next().await { + match event { + SSE::Event(e) if e.event_type == "message" || e.event_type.is_empty() => { + match serde_json::from_str::(&e.data) { + Ok(message) => { + debug!("Received GET streaming message: {:?}", message); + let receiver_guard = receiver.lock().await; + // We can't send through the receiver since it's for outbound messages + // This would need a different channel for server-initiated messages + drop(receiver_guard); + } + Err(err) => { + warn!("Failed to parse GET streaming message: {}", err); + } + } + } + _ => {} + } + } + }); + + Ok(()) + } +} + +#[derive(Clone)] +pub struct StreamableHttpTransport { + mcp_endpoint: String, + env: HashMap, + headers: HashMap, +} + +impl StreamableHttpTransport { + pub fn new>(mcp_endpoint: S, env: HashMap) -> Self { + Self { + mcp_endpoint: mcp_endpoint.into(), + env, + headers: HashMap::new(), + } + } + + pub fn with_headers>( + mcp_endpoint: S, + env: HashMap, + headers: HashMap, + ) -> Self { + Self { + mcp_endpoint: mcp_endpoint.into(), + env, + headers, + } + } + + /// Validate that the URL is a valid MCP endpoint + pub fn validate_endpoint(endpoint: &str) -> Result<(), Error> { + Url::parse(endpoint) + .map_err(|e| Error::StreamableHttpError(format!("Invalid MCP endpoint URL: {}", e)))?; + Ok(()) + } +} + +#[async_trait] +impl Transport for StreamableHttpTransport { + type Handle = StreamableHttpTransportHandle; + + async fn start(&self) -> Result { + // Validate the endpoint URL + Self::validate_endpoint(&self.mcp_endpoint)?; + + // Create channels for communication + let (tx, rx) = mpsc::channel(32); + let (otx, orx) = mpsc::channel(32); + + let session_id: Arc>> = Arc::new(RwLock::new(None)); + let session_id_clone = Arc::clone(&session_id); + + // Create and spawn the actor + let actor = StreamableHttpActor::new( + rx, + otx, + self.mcp_endpoint.clone(), + session_id, + self.env.clone(), + self.headers.clone(), + ); + + tokio::spawn(actor.run()); + + // Create the handle + let handle = StreamableHttpTransportHandle { + sender: tx, + receiver: Arc::new(Mutex::new(orx)), + session_id: session_id_clone, + mcp_endpoint: self.mcp_endpoint.clone(), + http_client: HttpClient::builder() + .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) + .build() + .unwrap(), + headers: self.headers.clone(), + }; + + Ok(handle) + } + + async fn close(&self) -> Result<(), Error> { + // The transport is closed when the actor task completes + // No additional cleanup needed + Ok(()) + } +} diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 88132ea0e2f9..9359fa4e2bd9 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -1315,6 +1315,54 @@ } } }, + { + "type": "object", + "description": "Streamable HTTP client with a URI endpoint using MCP Streamable HTTP specification", + "required": [ + "name", + "uri", + "type" + ], + "properties": { + "bundled": { + "type": "boolean", + "description": "Whether this extension is bundled with Goose", + "nullable": true + }, + "description": { + "type": "string", + "nullable": true + }, + "env_keys": { + "type": "array", + "items": { + "type": "string" + } + }, + "envs": { + "$ref": "#/components/schemas/Envs" + }, + "name": { + "type": "string", + "description": "The name used to identify this extension" + }, + "timeout": { + "type": "integer", + "format": "int64", + "nullable": true, + "minimum": 0 + }, + "type": { + "type": "string", + "enum": [ + "streamable_http" + ] + }, + "uri": { + "type": "string" + } + } + }, { "type": "object", "description": "Frontend-provided tools that will be called through the frontend", diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 2cc73e43aed7..2f821de044b6 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -124,6 +124,21 @@ export type ExtensionConfig = { name: string; timeout?: number | null; type: 'builtin'; +} | { + /** + * Whether this extension is bundled with Goose + */ + bundled?: boolean | null; + description?: string | null; + env_keys?: Array; + envs?: Envs; + /** + * The name used to identify this extension + */ + name: string; + timeout?: number | null; + type: 'streamable_http'; + uri: string; } | { /** * Whether this extension is bundled with Goose diff --git a/ui/desktop/src/components/schedule/CreateScheduleModal.tsx b/ui/desktop/src/components/schedule/CreateScheduleModal.tsx index d94735ee01f6..6b6b28c6f862 100644 --- a/ui/desktop/src/components/schedule/CreateScheduleModal.tsx +++ b/ui/desktop/src/components/schedule/CreateScheduleModal.tsx @@ -36,7 +36,7 @@ interface CreateScheduleModalProps { // Interface for clean extension in YAML interface CleanExtension { name: string; - type: 'stdio' | 'sse' | 'builtin' | 'frontend'; + type: 'stdio' | 'sse' | 'builtin' | 'frontend' | 'streamable_http'; cmd?: string; args?: string[]; uri?: string; @@ -160,6 +160,8 @@ function recipeToYaml(recipe: Recipe, executionMode: ExecutionMode): string { if (ext.type === 'sse' && extAny.uri) { cleanExt.uri = extAny.uri as string; + } else if (ext.type === 'streamable_http' && extAny.uri) { + cleanExt.uri = extAny.uri as string; } else if (ext.type === 'stdio') { if (extAny.cmd) { cleanExt.cmd = extAny.cmd as string; @@ -195,7 +197,8 @@ function recipeToYaml(recipe: Recipe, executionMode: ExecutionMode): string { cleanExt.type = 'stdio'; cleanExt.cmd = extAny.command as string; } else if (extAny.uri) { - cleanExt.type = 'sse'; + // Default to streamable_http for URI-based extensions for forward compatibility + cleanExt.type = 'streamable_http'; cleanExt.uri = extAny.uri as string; } else if (extAny.tools) { cleanExt.type = 'frontend'; diff --git a/ui/desktop/src/components/settings/extensions/deeplink.ts b/ui/desktop/src/components/settings/extensions/deeplink.ts index 5c647d0783ef..0436cc932bbd 100644 --- a/ui/desktop/src/components/settings/extensions/deeplink.ts +++ b/ui/desktop/src/components/settings/extensions/deeplink.ts @@ -72,6 +72,26 @@ function getSseConfig(remoteUrl: string, name: string, description: string, time return config; } +/** + * Build an extension config for Streamable HTTP from the deeplink URL + */ +function getStreamableHttpConfig( + remoteUrl: string, + name: string, + description: string, + timeout: number +) { + const config: ExtensionConfig = { + name, + type: 'streamable_http', + uri: remoteUrl, + description, + timeout: timeout, + }; + + return config; +} + /** * Handles adding an extension from a deeplink URL */ @@ -120,9 +140,12 @@ export async function addExtensionFromDeepLink( const cmd = parsedUrl.searchParams.get('cmd'); const remoteUrl = parsedUrl.searchParams.get('url'); + const transportType = parsedUrl.searchParams.get('transport') || 'sse'; // Default to SSE for backward compatibility const config = remoteUrl - ? getSseConfig(remoteUrl, name, description || '', timeout) + ? transportType === 'streamable_http' + ? getStreamableHttpConfig(remoteUrl, name, description || '', timeout) + : getSseConfig(remoteUrl, name, description || '', timeout) : getStdioConfig(cmd!, parsedUrl, name, description || '', timeout); // Check if extension requires env vars and go to settings if so diff --git a/ui/desktop/src/components/settings/extensions/modal/ExtensionConfigFields.tsx b/ui/desktop/src/components/settings/extensions/modal/ExtensionConfigFields.tsx index 9acb085f76a4..52542b5708e2 100644 --- a/ui/desktop/src/components/settings/extensions/modal/ExtensionConfigFields.tsx +++ b/ui/desktop/src/components/settings/extensions/modal/ExtensionConfigFields.tsx @@ -1,7 +1,7 @@ import { Input } from '../../../ui/input'; interface ExtensionConfigFieldsProps { - type: 'stdio' | 'sse' | 'builtin'; + type: 'stdio' | 'sse' | 'streamable_http' | 'builtin'; full_cmd: string; endpoint: string; onChange: (key: string, value: string) => void; diff --git a/ui/desktop/src/components/settings/extensions/modal/ExtensionInfoFields.tsx b/ui/desktop/src/components/settings/extensions/modal/ExtensionInfoFields.tsx index e779505a41d4..691404f99388 100644 --- a/ui/desktop/src/components/settings/extensions/modal/ExtensionInfoFields.tsx +++ b/ui/desktop/src/components/settings/extensions/modal/ExtensionInfoFields.tsx @@ -3,7 +3,7 @@ import { Select } from '../../../ui/Select'; interface ExtensionInfoFieldsProps { name: string; - type: 'stdio' | 'sse' | 'builtin'; + type: 'stdio' | 'sse' | 'streamable_http' | 'builtin'; description: string; onChange: (key: string, value: string) => void; submitAttempted: boolean; @@ -43,7 +43,17 @@ export default function ExtensionInfoFields({
onChange(index, 'key', e.target.value)} + placeholder="Header name" + className={cn( + 'w-full text-textStandard border-borderSubtle hover:border-borderStandard', + isFieldInvalid(index, 'key') && 'border-red-500 focus:border-red-500' + )} + /> +
+
+ onChange(index, 'value', e.target.value)} + placeholder="Value" + className={cn( + 'w-full text-textStandard border-borderSubtle hover:border-borderStandard', + isFieldInvalid(index, 'value') && 'border-red-500 focus:border-red-500' + )} + /> +
+ + + ))} + + {/* Empty row with Add button */} + { + setNewKey(e.target.value); + clearValidation(); + }} + placeholder="Header name" + className={cn( + 'w-full text-textStandard border-borderSubtle hover:border-borderStandard', + invalidFields.key && 'border-red-500 focus:border-red-500' + )} + /> + { + setNewValue(e.target.value); + clearValidation(); + }} + placeholder="Value" + className={cn( + 'w-full text-textStandard border-borderSubtle hover:border-borderStandard', + invalidFields.value && 'border-red-500 focus:border-red-500' + )} + /> + + + {validationError &&
{validationError}
} + + ); +} diff --git a/ui/desktop/src/components/settings/extensions/subcomponents/ExtensionList.tsx b/ui/desktop/src/components/settings/extensions/subcomponents/ExtensionList.tsx index 93f1e2e458b4..ebc7f148be06 100644 --- a/ui/desktop/src/components/settings/extensions/subcomponents/ExtensionList.tsx +++ b/ui/desktop/src/components/settings/extensions/subcomponents/ExtensionList.tsx @@ -93,6 +93,14 @@ export function getSubtitle(config: ExtensionConfig): SubtitleParts { return { description, command }; } + if (config.type === 'streamable_http') { + const description = config.description + ? `Streamable HTTP extension: ${config.description}` + : 'Streamable HTTP extension'; + const command = config.uri || null; + return { description, command }; + } + return { description: 'Unknown type of extension', command: null, diff --git a/ui/desktop/src/components/settings/extensions/utils.ts b/ui/desktop/src/components/settings/extensions/utils.ts index b7205cd9ec4e..5c7f6f173655 100644 --- a/ui/desktop/src/components/settings/extensions/utils.ts +++ b/ui/desktop/src/components/settings/extensions/utils.ts @@ -21,7 +21,7 @@ import { ExtensionConfig } from '../../../api/types.gen'; export interface ExtensionFormData { name: string; description: string; - type: 'stdio' | 'sse' | 'builtin'; + type: 'stdio' | 'sse' | 'streamable_http' | 'builtin'; cmd?: string; endpoint?: string; enabled: boolean; @@ -31,6 +31,11 @@ export interface ExtensionFormData { value: string; isEdited?: boolean; }[]; + headers: { + key: string; + value: string; + isEdited?: boolean; + }[]; } export function getDefaultFormData(): ExtensionFormData { @@ -43,12 +48,14 @@ export function getDefaultFormData(): ExtensionFormData { enabled: true, timeout: 300, envVars: [], + headers: [], }; } export function extensionToFormData(extension: FixedExtensionEntry): ExtensionFormData { // Type guard: Check if 'envs' property exists for this variant - const hasEnvs = extension.type === 'sse' || extension.type === 'stdio'; + const hasEnvs = + extension.type === 'sse' || extension.type === 'streamable_http' || extension.type === 'stdio'; // Handle both envs (legacy) and env_keys (new secrets) let envVars = []; @@ -75,16 +82,32 @@ export function extensionToFormData(extension: FixedExtensionEntry): ExtensionFo ); } + // Handle headers for streamable_http + let headers = []; + if (extension.type === 'streamable_http' && 'headers' in extension && extension.headers) { + headers.push( + ...Object.entries(extension.headers).map(([key, value]) => ({ + key, + value: value as string, + isEdited: false, // Mark as not edited initially + })) + ); + } + return { name: extension.name || '', description: - extension.type === 'stdio' || extension.type === 'sse' ? extension.description || '' : '', + extension.type === 'stdio' || extension.type === 'sse' || extension.type === 'streamable_http' + ? extension.description || '' + : '', type: extension.type === 'frontend' ? 'stdio' : extension.type, cmd: extension.type === 'stdio' ? combineCmdAndArgs(extension.cmd, extension.args) : undefined, - endpoint: extension.type === 'sse' ? extension.uri : undefined, + endpoint: + extension.type === 'sse' || extension.type === 'streamable_http' ? extension.uri : undefined, enabled: extension.enabled, timeout: 'timeout' in extension ? (extension.timeout ?? undefined) : undefined, envVars, + headers, }; } @@ -114,6 +137,27 @@ export function createExtensionConfig(formData: ExtensionFormData): ExtensionCon uri: formData.endpoint || '', ...(env_keys.length > 0 ? { env_keys } : {}), }; + } else if (formData.type === 'streamable_http') { + // Extract headers + const headers = formData.headers + .filter(({ key, value }) => key.length > 0 && value.length > 0) + .reduce( + (acc, header) => { + acc[header.key] = header.value; + return acc; + }, + {} as Record + ); + + return { + type: 'streamable_http', + name: formData.name, + description: formData.description, + timeout: formData.timeout, + uri: formData.endpoint || '', + ...(env_keys.length > 0 ? { env_keys } : {}), + ...(Object.keys(headers).length > 0 ? { headers } : {}), + }; } else { // For other types return { diff --git a/ui/desktop/src/extensions.tsx b/ui/desktop/src/extensions.tsx index 774e61c36344..fbe5515a474a 100644 --- a/ui/desktop/src/extensions.tsx +++ b/ui/desktop/src/extensions.tsx @@ -17,6 +17,14 @@ export type ExtensionConfig = env_keys?: string[]; timeout?: number; } + | { + type: 'streamable_http'; + name: string; + uri: string; + env_keys?: string[]; + headers?: Record; + timeout?: number; + } | { type: 'stdio'; name: string; @@ -73,6 +81,10 @@ export async function addExtension( name: sanitizeName(extension.name), uri: extension.uri, }), + ...(extension.type === 'streamable_http' && { + name: sanitizeName(extension.name), + uri: extension.uri, + }), ...(extension.type === 'builtin' && { name: sanitizeName(extension.name), }),