From 5a8cd659b7529abaf3fe57adf675104db86d021f Mon Sep 17 00:00:00 2001 From: Matthis Holleville Date: Wed, 13 Aug 2025 16:34:59 +0200 Subject: [PATCH 1/5] feat: add support for custom HTTP headers in client requests This update introduces the ability to include custom HTTP headers in requests sent from the client. This enhancement facilitates more flexible and secure communication with servers by allowing clients to pass additional information in the header of each request, such as authentication tokens or custom metadata. This feature is crucial for integrating with APIs that require specific headers for access control, content negotiation, or tracking purposes. Signed-off-by: Matthis Holleville --- client/client.go | 23 +++++----- client/transport/interface.go | 2 + client/transport/streamable_http.go | 26 +++++++----- client/transport/streamable_http_test.go | 53 ++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 20 deletions(-) diff --git a/client/client.go b/client/client.go index cda7665e..3ba2bc3c 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "slices" "sync" "sync/atomic" @@ -130,6 +131,7 @@ func (c *Client) sendRequest( ctx context.Context, method string, params any, + header http.Header, ) (*json.RawMessage, error) { if !c.initialized && method != "initialize" { return nil, fmt.Errorf("client not initialized") @@ -142,6 +144,7 @@ func (c *Client) sendRequest( ID: mcp.NewRequestId(id), Method: method, Params: params, + Header: header, } response, err := c.transport.SendRequest(ctx, request) @@ -179,7 +182,7 @@ func (c *Client) Initialize( Capabilities: capabilities, } - response, err := c.sendRequest(ctx, "initialize", params) + response, err := c.sendRequest(ctx, "initialize", params, request.Header) if err != nil { return nil, err } @@ -224,7 +227,7 @@ func (c *Client) Initialize( } func (c *Client) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) + _, err := c.sendRequest(ctx, "ping", nil, nil) return err } @@ -305,7 +308,7 @@ func (c *Client) ReadResource( ctx context.Context, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) + response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header) if err != nil { return nil, err } @@ -317,7 +320,7 @@ func (c *Client) Subscribe( ctx context.Context, request mcp.SubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header) return err } @@ -325,7 +328,7 @@ func (c *Client) Unsubscribe( ctx context.Context, request mcp.UnsubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header) return err } @@ -369,7 +372,7 @@ func (c *Client) GetPrompt( ctx context.Context, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) + response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header) if err != nil { return nil, err } @@ -417,7 +420,7 @@ func (c *Client) CallTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) + response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header) if err != nil { return nil, err } @@ -429,7 +432,7 @@ func (c *Client) SetLevel( ctx context.Context, request mcp.SetLevelRequest, ) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header) return err } @@ -437,7 +440,7 @@ func (c *Client) Complete( ctx context.Context, request mcp.CompleteRequest, ) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) + response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header) if err != nil { return nil, err } @@ -514,7 +517,7 @@ func listByPage[T any]( request mcp.PaginatedRequest, method string, ) (*T, error) { - response, err := client.sendRequest(ctx, method, request.Params) + response, err := client.sendRequest(ctx, method, request.Params, nil) if err != nil { return nil, err } diff --git a/client/transport/interface.go b/client/transport/interface.go index e6feeb74..e5c8dec5 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -3,6 +3,7 @@ package transport import ( "context" "encoding/json" + "net/http" "github.com/mark3labs/mcp-go/mcp" ) @@ -59,6 +60,7 @@ type JSONRPCRequest struct { ID mcp.RequestId `json:"id"` Method string `json:"method"` Params any `json:"params,omitempty"` + Header http.Header `json:"-"` } type JSONRPCResponse struct { diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 268aeb34..6d7c6815 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -258,7 +258,7 @@ func (c *StreamableHTTP) SendRequest( ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", request.Header) if err != nil { if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { // If the request is initialize, should not return a SessionTerminated error @@ -339,6 +339,7 @@ func (c *StreamableHTTP) sendHTTP( method string, body io.Reader, acceptType string, + header http.Header, ) (resp *http.Response, err error) { // Create HTTP request req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) @@ -346,6 +347,11 @@ func (c *StreamableHTTP) sendHTTP( return nil, fmt.Errorf("failed to create request: %w", err) } + // request headers + if header != nil { + req.Header = header + } + // Set headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", acceptType) @@ -546,7 +552,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", nil) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -605,7 +611,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) err := c.createGETConnectionToServer(connectCtx) cancel() - + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -621,7 +627,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) } - + // Use context-aware sleep select { case <-time.After(retryInterval): @@ -639,7 +645,7 @@ var ( ) func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { - resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream", nil) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -704,15 +710,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Create a new context with timeout for request handling, respecting parent context requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - + response, err := handler(requestCtx, request) if err != nil { c.logger.Errorf("error handling request %s: %v", request.Method, err) - + // Determine appropriate JSON-RPC error code based on error type var errorCode int var errorMessage string - + // Check for specific sampling-related errors if errors.Is(err, context.Canceled) { errorCode = -32800 // Request cancelled @@ -731,7 +737,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON errorMessage = err.Error() } } - + // Send error response errorResponse := &JSONRPCResponse{ JSONRPC: "2.0", @@ -771,7 +777,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json", nil) if err != nil { c.logger.Errorf("failed to send response to server: %v", err) return diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 5208cb9c..c384d8b3 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -70,6 +70,7 @@ func startMockStreamableHTTPServer() (string, func()) { "jsonrpc": "2.0", "id": request["id"], "result": request, + "headers": r.Header, }); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) return @@ -122,6 +123,24 @@ func startMockStreamableHTTPServer() (string, func()) { http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } + case "debug/echo_header": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Echo back the request headersas the response result + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": r.Header, + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } }) @@ -215,6 +234,40 @@ func TestStreamableHTTP(t *testing.T) { } }) + t.Run("SendRequestWithHeader", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]any{ + "string": "hello world", + "array": []any{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: "debug/echo_header", + Params: params, + Header: http.Header{"X-Test-Header": {"test-header-value"}}, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result map[string]any + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if hdr, ok := result["X-Test-Header"].([]any); !ok || len(hdr) == 0 || hdr[0] != "test-header-value" { + t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"]) + } + }) + t.Run("SendRequestWithTimeout", func(t *testing.T) { // Create a context that's already canceled ctx, cancel := context.WithCancel(context.Background()) From 8eaa5047407ffbe5eba7369a306419d3b0f177bc Mon Sep 17 00:00:00 2001 From: Matthis Holleville Date: Thu, 2 Oct 2025 15:35:07 +0200 Subject: [PATCH 2/5] feat(client/transport): enhance HTTP request flexibility Enhanced the flexibility of HTTP requests in the streamable HTTP client by allowing additional headers to be specified. This change aims to support more diverse server requirements and improve the adaptability of our client transport layer. Signed-off-by: Matthis Holleville --- client/transport/streamable_http.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 91427c1b..3a581b82 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -752,7 +752,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream", nil) if err != nil { c.logger.Errorf("failed to send response to server: %v", err) return From 660ac1a57b99004a8f773c81b195d5d2d561d676 Mon Sep 17 00:00:00 2001 From: Matthis Holleville Date: Tue, 21 Oct 2025 18:28:08 +0200 Subject: [PATCH 3/5] fix: Improve OAuth error handling and test readability in HTTP transport - Enhanced OAuth error detection by using `errors.Is` for more reliable error handling. - Corrected a typo in a comment and improved code readability in tests by using a variable for headers. Signed-off-by: Matthis Holleville --- client/transport/streamable_http.go | 2 +- client/transport/streamable_http_test.go | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 3a581b82..ed56fe05 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -374,7 +374,7 @@ func (c *StreamableHTTP) sendHTTP( authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) if err != nil { // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { + if errors.Is(err, ErrOAuthAuthorizationRequired) { return nil, &OAuthAuthorizationRequiredError{ Handler: c.oauthHandler, } diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 8067fa03..002ea672 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -130,7 +130,7 @@ func startMockStreamableHTTPServer() (string, func()) { return } - // Echo back the request headersas the response result + // Echo back the request headers as the response result w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) if err := json.NewEncoder(w).Encode(map[string]any{ @@ -243,12 +243,13 @@ func TestStreamableHTTP(t *testing.T) { "array": []any{1, 2, 3}, } + hdr := http.Header{"X-Test-Header": {"test-header-value"}} request := JSONRPCRequest{ JSONRPC: "2.0", ID: mcp.NewRequestId(int64(1)), Method: "debug/echo_header", Params: params, - Header: http.Header{"X-Test-Header": {"test-header-value"}}, + Header: hdr, } // Send the request From 99ccc6ef5edcfb5343bd34c35c71c951e3e79b1e Mon Sep 17 00:00:00 2001 From: Matthis Holleville Date: Tue, 21 Oct 2025 18:39:56 +0200 Subject: [PATCH 4/5] fix: improve variable naming for clarity in streamable_http_test Signed-off-by: Matthis Holleville --- client/transport/streamable_http_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 002ea672..5499f464 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -264,7 +264,7 @@ func TestStreamableHTTP(t *testing.T) { t.Fatalf("Failed to unmarshal result: %v", err) } - if hdr, ok := result["X-Test-Header"].([]any); !ok || len(hdr) == 0 || hdr[0] != "test-header-value" { + if headerValues, ok := result["X-Test-Header"].([]any); !ok || len(headerValues) == 0 || headerValues[0] != "test-header-value" { t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"]) } }) From 01c33458d3885b75ba03bbf2512c14a700b136fd Mon Sep 17 00:00:00 2001 From: Matthis Holleville Date: Tue, 21 Oct 2025 18:53:51 +0200 Subject: [PATCH 5/5] feat: Ensure system headers are preserved in streamable HTTP tests To maintain consistency and ensure the integrity of HTTP headers during tests, system headers like Content-Type are now verified to be preserved. This change enhances the reliability of our testing framework by ensuring essential headers are not inadvertently removed or altered during the testing process. Signed-off-by: Matthis Holleville --- client/transport/streamable_http_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 5499f464..cdd5a93e 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -267,6 +267,11 @@ func TestStreamableHTTP(t *testing.T) { if headerValues, ok := result["X-Test-Header"].([]any); !ok || len(headerValues) == 0 || headerValues[0] != "test-header-value" { t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"]) } + + // Verify system headers are still present + if contentType, ok := result["Content-Type"].([]any); !ok || len(contentType) == 0 { + t.Errorf("Expected Content-Type header to be preserved") + } }) t.Run("SendRequestWithTimeout", func(t *testing.T) {