Skip to content

Commit 5a8cd65

Browse files
feat: add support for custom HTTP headers in client requests
This update introduces the ability to include custom HTTP headers in requests sent from the client. This enhancement facilitates more flexible and secure communication with servers by allowing clients to pass additional information in the header of each request, such as authentication tokens or custom metadata. This feature is crucial for integrating with APIs that require specific headers for access control, content negotiation, or tracking purposes. Signed-off-by: Matthis Holleville <[email protected]>
1 parent 9f16336 commit 5a8cd65

File tree

4 files changed

+84
-20
lines changed

4 files changed

+84
-20
lines changed

client/client.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"net/http"
89
"slices"
910
"sync"
1011
"sync/atomic"
@@ -130,6 +131,7 @@ func (c *Client) sendRequest(
130131
ctx context.Context,
131132
method string,
132133
params any,
134+
header http.Header,
133135
) (*json.RawMessage, error) {
134136
if !c.initialized && method != "initialize" {
135137
return nil, fmt.Errorf("client not initialized")
@@ -142,6 +144,7 @@ func (c *Client) sendRequest(
142144
ID: mcp.NewRequestId(id),
143145
Method: method,
144146
Params: params,
147+
Header: header,
145148
}
146149

147150
response, err := c.transport.SendRequest(ctx, request)
@@ -179,7 +182,7 @@ func (c *Client) Initialize(
179182
Capabilities: capabilities,
180183
}
181184

182-
response, err := c.sendRequest(ctx, "initialize", params)
185+
response, err := c.sendRequest(ctx, "initialize", params, request.Header)
183186
if err != nil {
184187
return nil, err
185188
}
@@ -224,7 +227,7 @@ func (c *Client) Initialize(
224227
}
225228

226229
func (c *Client) Ping(ctx context.Context) error {
227-
_, err := c.sendRequest(ctx, "ping", nil)
230+
_, err := c.sendRequest(ctx, "ping", nil, nil)
228231
return err
229232
}
230233

@@ -305,7 +308,7 @@ func (c *Client) ReadResource(
305308
ctx context.Context,
306309
request mcp.ReadResourceRequest,
307310
) (*mcp.ReadResourceResult, error) {
308-
response, err := c.sendRequest(ctx, "resources/read", request.Params)
311+
response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header)
309312
if err != nil {
310313
return nil, err
311314
}
@@ -317,15 +320,15 @@ func (c *Client) Subscribe(
317320
ctx context.Context,
318321
request mcp.SubscribeRequest,
319322
) error {
320-
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
323+
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header)
321324
return err
322325
}
323326

324327
func (c *Client) Unsubscribe(
325328
ctx context.Context,
326329
request mcp.UnsubscribeRequest,
327330
) error {
328-
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
331+
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header)
329332
return err
330333
}
331334

@@ -369,7 +372,7 @@ func (c *Client) GetPrompt(
369372
ctx context.Context,
370373
request mcp.GetPromptRequest,
371374
) (*mcp.GetPromptResult, error) {
372-
response, err := c.sendRequest(ctx, "prompts/get", request.Params)
375+
response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header)
373376
if err != nil {
374377
return nil, err
375378
}
@@ -417,7 +420,7 @@ func (c *Client) CallTool(
417420
ctx context.Context,
418421
request mcp.CallToolRequest,
419422
) (*mcp.CallToolResult, error) {
420-
response, err := c.sendRequest(ctx, "tools/call", request.Params)
423+
response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header)
421424
if err != nil {
422425
return nil, err
423426
}
@@ -429,15 +432,15 @@ func (c *Client) SetLevel(
429432
ctx context.Context,
430433
request mcp.SetLevelRequest,
431434
) error {
432-
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
435+
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header)
433436
return err
434437
}
435438

436439
func (c *Client) Complete(
437440
ctx context.Context,
438441
request mcp.CompleteRequest,
439442
) (*mcp.CompleteResult, error) {
440-
response, err := c.sendRequest(ctx, "completion/complete", request.Params)
443+
response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header)
441444
if err != nil {
442445
return nil, err
443446
}
@@ -514,7 +517,7 @@ func listByPage[T any](
514517
request mcp.PaginatedRequest,
515518
method string,
516519
) (*T, error) {
517-
response, err := client.sendRequest(ctx, method, request.Params)
520+
response, err := client.sendRequest(ctx, method, request.Params, nil)
518521
if err != nil {
519522
return nil, err
520523
}

client/transport/interface.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package transport
33
import (
44
"context"
55
"encoding/json"
6+
"net/http"
67

78
"github.com/mark3labs/mcp-go/mcp"
89
)
@@ -59,6 +60,7 @@ type JSONRPCRequest struct {
5960
ID mcp.RequestId `json:"id"`
6061
Method string `json:"method"`
6162
Params any `json:"params,omitempty"`
63+
Header http.Header `json:"-"`
6264
}
6365

6466
type JSONRPCResponse struct {

client/transport/streamable_http.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ func (c *StreamableHTTP) SendRequest(
258258
ctx, cancel := c.contextAwareOfClientClose(ctx)
259259
defer cancel()
260260

261-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
261+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", request.Header)
262262
if err != nil {
263263
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
264264
// If the request is initialize, should not return a SessionTerminated error
@@ -339,13 +339,19 @@ func (c *StreamableHTTP) sendHTTP(
339339
method string,
340340
body io.Reader,
341341
acceptType string,
342+
header http.Header,
342343
) (resp *http.Response, err error) {
343344
// Create HTTP request
344345
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
345346
if err != nil {
346347
return nil, fmt.Errorf("failed to create request: %w", err)
347348
}
348349

350+
// request headers
351+
if header != nil {
352+
req.Header = header
353+
}
354+
349355
// Set headers
350356
req.Header.Set("Content-Type", "application/json")
351357
req.Header.Set("Accept", acceptType)
@@ -546,7 +552,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
546552
ctx, cancel := c.contextAwareOfClientClose(ctx)
547553
defer cancel()
548554

549-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
555+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", nil)
550556
if err != nil {
551557
return fmt.Errorf("failed to send request: %w", err)
552558
}
@@ -605,7 +611,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
605611
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
606612
err := c.createGETConnectionToServer(connectCtx)
607613
cancel()
608-
614+
609615
if errors.Is(err, ErrGetMethodNotAllowed) {
610616
// server does not support listening
611617
c.logger.Errorf("server does not support listening")
@@ -621,7 +627,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
621627
if err != nil {
622628
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
623629
}
624-
630+
625631
// Use context-aware sleep
626632
select {
627633
case <-time.After(retryInterval):
@@ -639,7 +645,7 @@ var (
639645
)
640646

641647
func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
642-
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
648+
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream", nil)
643649
if err != nil {
644650
return fmt.Errorf("failed to send request: %w", err)
645651
}
@@ -704,15 +710,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
704710
// Create a new context with timeout for request handling, respecting parent context
705711
requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
706712
defer cancel()
707-
713+
708714
response, err := handler(requestCtx, request)
709715
if err != nil {
710716
c.logger.Errorf("error handling request %s: %v", request.Method, err)
711-
717+
712718
// Determine appropriate JSON-RPC error code based on error type
713719
var errorCode int
714720
var errorMessage string
715-
721+
716722
// Check for specific sampling-related errors
717723
if errors.Is(err, context.Canceled) {
718724
errorCode = -32800 // Request cancelled
@@ -731,7 +737,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
731737
errorMessage = err.Error()
732738
}
733739
}
734-
740+
735741
// Send error response
736742
errorResponse := &JSONRPCResponse{
737743
JSONRPC: "2.0",
@@ -771,7 +777,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO
771777
ctx, cancel := c.contextAwareOfClientClose(ctx)
772778
defer cancel()
773779

774-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json")
780+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json", nil)
775781
if err != nil {
776782
c.logger.Errorf("failed to send response to server: %v", err)
777783
return

client/transport/streamable_http_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ func startMockStreamableHTTPServer() (string, func()) {
7070
"jsonrpc": "2.0",
7171
"id": request["id"],
7272
"result": request,
73+
"headers": r.Header,
7374
}); err != nil {
7475
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
7576
return
@@ -122,6 +123,24 @@ func startMockStreamableHTTPServer() (string, func()) {
122123
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
123124
return
124125
}
126+
case "debug/echo_header":
127+
// Check session ID
128+
if r.Header.Get("Mcp-Session-Id") != sessionID {
129+
http.Error(w, "Invalid session ID", http.StatusNotFound)
130+
return
131+
}
132+
133+
// Echo back the request headersas the response result
134+
w.Header().Set("Content-Type", "application/json")
135+
w.WriteHeader(http.StatusOK)
136+
if err := json.NewEncoder(w).Encode(map[string]any{
137+
"jsonrpc": "2.0",
138+
"id": request["id"],
139+
"result": r.Header,
140+
}); err != nil {
141+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
142+
return
143+
}
125144
}
126145
})
127146

@@ -215,6 +234,40 @@ func TestStreamableHTTP(t *testing.T) {
215234
}
216235
})
217236

237+
t.Run("SendRequestWithHeader", func(t *testing.T) {
238+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
239+
defer cancel()
240+
241+
params := map[string]any{
242+
"string": "hello world",
243+
"array": []any{1, 2, 3},
244+
}
245+
246+
request := JSONRPCRequest{
247+
JSONRPC: "2.0",
248+
ID: mcp.NewRequestId(int64(1)),
249+
Method: "debug/echo_header",
250+
Params: params,
251+
Header: http.Header{"X-Test-Header": {"test-header-value"}},
252+
}
253+
254+
// Send the request
255+
response, err := trans.SendRequest(ctx, request)
256+
if err != nil {
257+
t.Fatalf("SendRequest failed: %v", err)
258+
}
259+
260+
// Parse the result to verify echo
261+
var result map[string]any
262+
if err := json.Unmarshal(response.Result, &result); err != nil {
263+
t.Fatalf("Failed to unmarshal result: %v", err)
264+
}
265+
266+
if hdr, ok := result["X-Test-Header"].([]any); !ok || len(hdr) == 0 || hdr[0] != "test-header-value" {
267+
t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"])
268+
}
269+
})
270+
218271
t.Run("SendRequestWithTimeout", func(t *testing.T) {
219272
// Create a context that's already canceled
220273
ctx, cancel := context.WithCancel(context.Background())

0 commit comments

Comments
 (0)