diff --git a/client/client.go b/client/client.go index 220786b68..86fcbcf98 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "slices" "sync" "sync/atomic" @@ -145,6 +146,7 @@ func (c *Client) sendRequest( ctx context.Context, method string, params any, + header http.Header, ) (*json.RawMessage, error) { if !c.initialized && method != "initialize" { return nil, fmt.Errorf("client not initialized") @@ -157,6 +159,7 @@ func (c *Client) sendRequest( ID: mcp.NewRequestId(id), Method: method, Params: params, + Header: header, } response, err := c.transport.SendRequest(ctx, request) @@ -198,7 +201,7 @@ func (c *Client) Initialize( Capabilities: capabilities, } - response, err := c.sendRequest(ctx, "initialize", params) + response, err := c.sendRequest(ctx, "initialize", params, request.Header) if err != nil { return nil, err } @@ -243,7 +246,7 @@ func (c *Client) Initialize( } func (c *Client) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) + _, err := c.sendRequest(ctx, "ping", nil, nil) return err } @@ -324,7 +327,7 @@ func (c *Client) ReadResource( ctx context.Context, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) + response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header) if err != nil { return nil, err } @@ -336,7 +339,7 @@ func (c *Client) Subscribe( ctx context.Context, request mcp.SubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header) return err } @@ -344,7 +347,7 @@ func (c *Client) Unsubscribe( ctx context.Context, request mcp.UnsubscribeRequest, ) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header) return err } @@ -388,7 +391,7 @@ func (c *Client) GetPrompt( ctx context.Context, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) + response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header) if err != nil { return nil, err } @@ -436,7 +439,7 @@ func (c *Client) CallTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) + response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header) if err != nil { return nil, err } @@ -448,7 +451,7 @@ func (c *Client) SetLevel( ctx context.Context, request mcp.SetLevelRequest, ) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header) return err } @@ -456,7 +459,7 @@ func (c *Client) Complete( ctx context.Context, request mcp.CompleteRequest, ) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) + response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header) if err != nil { return nil, err } @@ -583,7 +586,7 @@ func listByPage[T any]( request mcp.PaginatedRequest, method string, ) (*T, error) { - response, err := client.sendRequest(ctx, method, request.Params) + response, err := client.sendRequest(ctx, method, request.Params, nil) if err != nil { return nil, err } diff --git a/client/transport/interface.go b/client/transport/interface.go index b00210e52..e35a5f318 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -3,6 +3,7 @@ package transport import ( "context" "encoding/json" + "net/http" "github.com/mark3labs/mcp-go/mcp" ) @@ -59,6 +60,7 @@ type JSONRPCRequest struct { ID mcp.RequestId `json:"id"` Method string `json:"method"` Params any `json:"params,omitempty"` + Header http.Header `json:"-"` } // JSONRPCResponse represents a JSON-RPC 2.0 response message. diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 9d5218139..ed56fe055 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -258,7 +258,7 @@ func (c *StreamableHTTP) SendRequest( ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", request.Header) if err != nil { if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { // If the request is initialize, should not return a SessionTerminated error @@ -339,6 +339,7 @@ func (c *StreamableHTTP) sendHTTP( method string, body io.Reader, acceptType string, + header http.Header, ) (resp *http.Response, err error) { // Create HTTP request req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) @@ -346,6 +347,11 @@ func (c *StreamableHTTP) sendHTTP( return nil, fmt.Errorf("failed to create request: %w", err) } + // request headers + if header != nil { + req.Header = header + } + // Set headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", acceptType) @@ -368,7 +374,7 @@ func (c *StreamableHTTP) sendHTTP( authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) if err != nil { // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { + if errors.Is(err, ErrOAuthAuthorizationRequired) { return nil, &OAuthAuthorizationRequiredError{ Handler: c.oauthHandler, } @@ -539,7 +545,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", nil) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -631,7 +637,7 @@ var ( ) func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { - resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream", nil) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -746,7 +752,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO ctx, cancel := c.contextAwareOfClientClose(ctx) defer cancel() - resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream") + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream", nil) if err != nil { c.logger.Errorf("failed to send response to server: %v", err) return diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 04abc2450..cdd5a93ee 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -70,6 +70,7 @@ func startMockStreamableHTTPServer() (string, func()) { "jsonrpc": "2.0", "id": request["id"], "result": request, + "headers": r.Header, }); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) return @@ -122,6 +123,24 @@ func startMockStreamableHTTPServer() (string, func()) { http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } + case "debug/echo_header": + // Check session ID + if r.Header.Get("Mcp-Session-Id") != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Echo back the request headers as the response result + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": r.Header, + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } } }) @@ -215,6 +234,46 @@ func TestStreamableHTTP(t *testing.T) { } }) + t.Run("SendRequestWithHeader", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]any{ + "string": "hello world", + "array": []any{1, 2, 3}, + } + + hdr := http.Header{"X-Test-Header": {"test-header-value"}} + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: "debug/echo_header", + Params: params, + Header: hdr, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result map[string]any + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if headerValues, ok := result["X-Test-Header"].([]any); !ok || len(headerValues) == 0 || headerValues[0] != "test-header-value" { + t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"]) + } + + // Verify system headers are still present + if contentType, ok := result["Content-Type"].([]any); !ok || len(contentType) == 0 { + t.Errorf("Expected Content-Type header to be preserved") + } + }) + t.Run("SendRequestWithTimeout", func(t *testing.T) { // Create a context that's already canceled ctx, cancel := context.WithCancel(context.Background())