diff --git a/client/http.go b/client/http.go new file mode 100644 index 000000000..cb3be35d6 --- /dev/null +++ b/client/http.go @@ -0,0 +1,17 @@ +package client + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client +// with the given base URL. Returns an error if the URL is invalid. +func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) { + trans, err := transport.NewStreamableHTTP(baseURL, options...) + if err != nil { + return nil, fmt.Errorf("failed to create SSE transport: %w", err) + } + return NewClient(trans), nil +} diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go new file mode 100644 index 000000000..4bc60a25f --- /dev/null +++ b/client/transport/streamable_http.go @@ -0,0 +1,387 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +type StreamableHTTPCOption func(*StreamableHTTP) + +func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.headers = headers + } +} + +// WithHTTPTimeout sets the timeout for a HTTP request and stream. +func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.httpClient.Timeout = timeout + } +} + +// StreamableHTTP implements Streamable HTTP transport. +// +// It transmits JSON-RPC messages over individual HTTP requests. One message per request. +// The HTTP response body can either be a single JSON-RPC response, +// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +// +// The current implementation does not support the following features: +// - batching +// - continuously listening for server notifications when no request is in flight +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) +// - resuming stream +// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) +// - server -> client request +type StreamableHTTP struct { + baseURL *url.URL + httpClient *http.Client + headers map[string]string + + sessionID atomic.Value // string + + notificationHandler func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + + closed chan struct{} +} + +// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. +// Returns an error if the URL is invalid. +func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &StreamableHTTP{ + baseURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + } + smc.sessionID.Store("") // set initial value to simplify later usage + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the HTTP connection to the server. +func (c *StreamableHTTP) Start(ctx context.Context) error { + // For Streamable HTTP, we don't need to establish a persistent connection + return nil +} + +// Close closes the all the HTTP connections to the server. +func (c *StreamableHTTP) Close() error { + select { + case <-c.closed: + return nil + default: + } + // Cancel all in-flight requests + close(c.closed) + + sessionId := c.sessionID.Load().(string) + if sessionId != "" { + c.sessionID.Store("") + + // notify server session closed + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) + if err != nil { + fmt.Printf("failed to create close request\n: %v", err) + return + } + req.Header.Set(headerKeySessionID, sessionId) + res, err := c.httpClient.Do(req) + if err != nil { + fmt.Printf("failed to send close request\n: %v", err) + return + } + res.Body.Close() + }() + } + + return nil +} + +const ( + initializeMethod = "initialize" + headerKeySessionID = "Mcp-Session-Id" +) + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *StreamableHTTP) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + // Create a combined context that could be canceled when the client is closed + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled, no need to do anything + } + }() + ctx = newCtx + + // Marshal request + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + sessionID := c.sessionID.Load() + if sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + // handle session closed + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, fmt.Errorf("session terminated (404). need to re-initialize") + } + + // handle error response + var errResponse JSONRPCResponse + body, _ := io.ReadAll(resp.Body) + if err := json.Unmarshal(body, &errResponse); err == nil { + return &errResponse, nil + } + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + if request.Method == initializeMethod { + // saved the received session ID in the response + // empty session ID is allowed + if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { + c.sessionID.Store(sessionID) + } + } + + // Handle different response types + switch resp.Header.Get("Content-Type") { + case "application/json": + // Single response + var response JSONRPCResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // should not be a notification + if response.ID == nil { + return nil, fmt.Errorf("response should contain RPC id: %v", response) + } + + return &response, nil + + case "text/event-stream": + // Server is using SSE for streaming responses + return c.handleSSEResponse(ctx, resp.Body) + + default: + return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) + } +} + +// handleSSEResponse processes an SSE stream for a specific request. +// It returns the final result for the request once received, or an error. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { + + // Create a channel for this specific request + responseChan := make(chan *JSONRPCResponse, 1) + defer close(responseChan) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Start a goroutine to process the SSE stream + go c.readSSE(ctx, reader, func(event, data string) { + + // (unsupported: batching) + + var message JSONRPCResponse + if err := json.Unmarshal([]byte(data), &message); err != nil { + fmt.Printf("failed to unmarshal message: %v\n", err) + return + } + + // Handle notification + if message.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + fmt.Printf("failed to unmarshal notification: %v\n", err) + return + } + c.notifyMu.RLock() + if c.notificationHandler != nil { + c.notificationHandler(notification) + } + c.notifyMu.RUnlock() + return + } + + responseChan <- &message + }) + + // Wait for the response or context cancellation + select { + case response := <-responseChan: + if response == nil { + return nil, fmt.Errorf("unexpected nil response") + } + return response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair. +// It will end when the reader is closed (or the context is done). +func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + handler(event, data) + } + return + } + select { + case <-ctx.Done(): + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + handler(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + + // Marshal request + requestBody, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + if sessionID := c.sessionID.Load(); sessionID != "" { + req.Header.Set(headerKeySessionID, sessionID.(string)) + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Send request + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notificationHandler = handler +} + +func (c *StreamableHTTP) GetSessionId() string { + return c.sessionID.Load().(string) +} diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go new file mode 100644 index 000000000..b7b76b96f --- /dev/null +++ b/client/transport/streamable_http_test.go @@ -0,0 +1,425 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockStreamableHTTPServer starts a test HTTP server that implements +// a minimal Streamable HTTP server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockStreamableHTTPServer() (string, func()) { + var sessionID string + var mu sync.Mutex + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]any + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + method := request["method"] + switch method { + case "initialize": + // Generate a new session ID + mu.Lock() + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) + mu.Unlock() + w.Header().Set("Mcp-Session-Id", sessionID) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": "initialized", + }) + + case "debug/echo": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Echo back the request as the response result + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + }) + + case "debug/echo_notification": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Send response and notification + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + notification := map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + } + notificationData, _ := json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + responseData, _ := json.Marshal(response) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", responseData) + + case "debug/echo_error_string": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Return an error response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(request) + json.NewEncoder(w).Encode(map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "error": map[string]interface{}{ + "code": -1, + "message": string(data), + }, + }) + } + }) + + // Start test server + testServer := httptest.NewServer(handler) + return testServer.URL, testServer.Close +} + +func TestStreamableHTTP(t *testing.T) { + // Start mock server + url, closeF := startMockStreamableHTTPServer() + defer closeF() + + // Create transport + trans, err := NewStreamableHTTP(url) + if err != nil { + t.Fatal(err) + } + defer trans.Close() + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Now run the tests + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a request that triggers a notification + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo_notification", + } + + _, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case notification := <-notificationChan: + // We received a notification + got := notification.Params.AdditionalFields + if got == nil { + t.Errorf("Notification handler did not send the expected notification: got nil") + } + if int64(got["id"].(float64)) != request.ID || + got["jsonrpc"] != request.JSONRPC || + got["method"] != request.Method { + + responseJson, _ := json.Marshal(got) + requestJson, _ := json.Marshal(request) + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + return + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) +} + +func TestStreamableHTTPErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with an invalid URL + _, err := NewStreamableHTTP("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new StreamableHTTP transport with a non-existent URL + trans, err := NewStreamableHTTP("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create StreamableHTTP transport: %v", err) + } + + // Send request should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request to non-existent URL, got nil") + } + }) + +}