Skip to content

Commit d0e3734

Browse files
authored
Merge branch 'main' into feat/sendlog
2 parents a7b38b8 + 17af676 commit d0e3734

29 files changed

+2877
-384
lines changed

README.md

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func main() {
5858
}
5959

6060
func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
61-
name, ok := request.Params.Arguments["name"].(string)
61+
name, ok := request.GetArguments()["name"].(string)
6262
if !ok {
6363
return nil, errors.New("name must be a string")
6464
}
@@ -97,6 +97,7 @@ MCP Go handles all the complex protocol details and server management, so you ca
9797
- [Session Management](#session-management)
9898
- [Request Hooks](#request-hooks)
9999
- [Tool Handler Middleware](#tool-handler-middleware)
100+
- [Regenerating Server Code](#regenerating-server-code)
100101
- [Contributing](/CONTRIBUTING.md)
101102

102103
## Installation
@@ -149,9 +150,21 @@ func main() {
149150

150151
// Add the calculator handler
151152
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
152-
op := request.Params.Arguments["operation"].(string)
153-
x := request.Params.Arguments["x"].(float64)
154-
y := request.Params.Arguments["y"].(float64)
153+
// Using helper functions for type-safe argument access
154+
op, err := request.RequireString("operation")
155+
if err != nil {
156+
return mcp.NewToolResultError(err.Error()), nil
157+
}
158+
159+
x, err := request.RequireFloat("x")
160+
if err != nil {
161+
return mcp.NewToolResultError(err.Error()), nil
162+
}
163+
164+
y, err := request.RequireFloat("y")
165+
if err != nil {
166+
return mcp.NewToolResultError(err.Error()), nil
167+
}
155168

156169
var result float64
157170
switch op {
@@ -312,9 +325,10 @@ calculatorTool := mcp.NewTool("calculate",
312325
)
313326

314327
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
315-
op := request.Params.Arguments["operation"].(string)
316-
x := request.Params.Arguments["x"].(float64)
317-
y := request.Params.Arguments["y"].(float64)
328+
args := request.GetArguments()
329+
op := args["operation"].(string)
330+
x := args["x"].(float64)
331+
y := args["y"].(float64)
318332

319333
var result float64
320334
switch op {
@@ -355,10 +369,11 @@ httpTool := mcp.NewTool("http_request",
355369
)
356370

357371
s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
358-
method := request.Params.Arguments["method"].(string)
359-
url := request.Params.Arguments["url"].(string)
372+
args := request.GetArguments()
373+
method := args["method"].(string)
374+
url := args["url"].(string)
360375
body := ""
361-
if b, ok := request.Params.Arguments["body"].(string); ok {
376+
if b, ok := args["body"].(string); ok {
362377
body = b
363378
}
364379

@@ -517,6 +532,10 @@ For examples, see the `examples/` directory.
517532

518533
## Extras
519534

535+
### Transports
536+
537+
MCP-Go supports stdio, SSE and streamable-HTTP transport layers.
538+
520539
### Session Management
521540

522541
MCP-Go provides a robust session management system that allows you to:
@@ -742,3 +761,14 @@ Add middleware to tool call handlers using the `server.WithToolHandlerMiddleware
742761

743762
A recovery middleware option is available to recover from panics in a tool call and can be added to the server with the `server.WithRecovery` option.
744763

764+
### Regenerating Server Code
765+
766+
Server hooks and request handlers are generated. Regenerate them by running:
767+
768+
```bash
769+
go generate ./...
770+
```
771+
772+
You need `go` installed and the `goimports` tool available. The generator runs
773+
`goimports` automatically to format and fix imports.
774+

client/inprocess_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func TestInProcessMCPClient(t *testing.T) {
3232
Content: []mcp.Content{
3333
mcp.TextContent{
3434
Type: "text",
35-
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
35+
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
3636
},
3737
mcp.AudioContent{
3838
Type: "audio",

client/sse.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func WithHeaders(headers map[string]string) transport.ClientOption {
1212
return transport.WithHeaders(headers)
1313
}
1414

15+
func WithHeaderFunc(headerFunc transport.HTTPHeaderFunc) transport.ClientOption {
16+
return transport.WithHeaderFunc(headerFunc)
17+
}
18+
1519
func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
1620
return transport.WithHTTPClient(httpClient)
1721
}

client/sse_test.go

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"net/http"
56
"testing"
67
"time"
78

@@ -11,6 +12,13 @@ import (
1112
"github.com/mark3labs/mcp-go/server"
1213
)
1314

15+
type contextKey string
16+
17+
const (
18+
testHeaderKey contextKey = "X-Test-Header"
19+
testHeaderFuncKey contextKey = "X-Test-Header-Func"
20+
)
21+
1422
func TestSSEMCPClient(t *testing.T) {
1523
// Create MCP server with capabilities
1624
mcpServer := server.NewMCPServer(
@@ -36,14 +44,34 @@ func TestSSEMCPClient(t *testing.T) {
3644
Content: []mcp.Content{
3745
mcp.TextContent{
3846
Type: "text",
39-
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
47+
Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string),
48+
},
49+
},
50+
}, nil
51+
})
52+
mcpServer.AddTool(mcp.NewTool(
53+
"test-tool-for-http-header",
54+
mcp.WithDescription("Test tool for http header"),
55+
), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
56+
// , X-Test-Header-Func
57+
return &mcp.CallToolResult{
58+
Content: []mcp.Content{
59+
mcp.TextContent{
60+
Type: "text",
61+
Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string),
4062
},
4163
},
4264
}, nil
4365
})
4466

4567
// Initialize
46-
testServer := server.NewTestServer(mcpServer)
68+
testServer := server.NewTestServer(mcpServer,
69+
server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context {
70+
ctx = context.WithValue(ctx, testHeaderKey, r.Header.Get("X-Test-Header"))
71+
ctx = context.WithValue(ctx, testHeaderFuncKey, r.Header.Get("X-Test-Header-Func"))
72+
return ctx
73+
}),
74+
)
4775
defer testServer.Close()
4876

4977
t.Run("Can create client", func(t *testing.T) {
@@ -250,4 +278,56 @@ func TestSSEMCPClient(t *testing.T) {
250278
t.Errorf("Expected 1 content item, got %d", len(result.Content))
251279
}
252280
})
281+
282+
t.Run("CallTool with customized header", func(t *testing.T) {
283+
client, err := NewSSEMCPClient(testServer.URL+"/sse",
284+
WithHeaders(map[string]string{
285+
"X-Test-Header": "test-header-value",
286+
}),
287+
WithHeaderFunc(func(ctx context.Context) map[string]string {
288+
return map[string]string{
289+
"X-Test-Header-Func": "test-header-func-value",
290+
}
291+
}),
292+
)
293+
if err != nil {
294+
t.Fatalf("Failed to create client: %v", err)
295+
}
296+
defer client.Close()
297+
298+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
299+
defer cancel()
300+
301+
if err := client.Start(ctx); err != nil {
302+
t.Fatalf("Failed to start client: %v", err)
303+
}
304+
305+
// Initialize
306+
initRequest := mcp.InitializeRequest{}
307+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
308+
initRequest.Params.ClientInfo = mcp.Implementation{
309+
Name: "test-client",
310+
Version: "1.0.0",
311+
}
312+
313+
_, err = client.Initialize(ctx, initRequest)
314+
if err != nil {
315+
t.Fatalf("Failed to initialize: %v", err)
316+
}
317+
318+
request := mcp.CallToolRequest{}
319+
request.Params.Name = "test-tool-for-http-header"
320+
321+
result, err := client.CallTool(ctx, request)
322+
if err != nil {
323+
t.Fatalf("CallTool failed: %v", err)
324+
}
325+
326+
if len(result.Content) != 1 {
327+
t.Errorf("Expected 1 content item, got %d", len(result.Content))
328+
}
329+
if result.Content[0].(mcp.TextContent).Text != "context from header: test-header-value, test-header-func-value" {
330+
t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value")
331+
}
332+
})
253333
}

client/transport/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ import (
77
"github.com/mark3labs/mcp-go/mcp"
88
)
99

10+
// HTTPHeaderFunc is a function that extracts header entries from the given context
11+
// and returns them as key-value pairs. This is typically used to add context values
12+
// as HTTP headers in outgoing requests.
13+
type HTTPHeaderFunc func(context.Context) map[string]string
14+
1015
// Interface for the transport layer.
1116
type Interface interface {
1217
// Start the connection. Start should only be called once.

client/transport/sse.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type SSE struct {
3131
notifyMu sync.RWMutex
3232
endpointChan chan struct{}
3333
headers map[string]string
34+
headerFunc HTTPHeaderFunc
3435

3536
started atomic.Bool
3637
closed atomic.Bool
@@ -45,6 +46,12 @@ func WithHeaders(headers map[string]string) ClientOption {
4546
}
4647
}
4748

49+
func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
50+
return func(sc *SSE) {
51+
sc.headerFunc = headerFunc
52+
}
53+
}
54+
4855
func WithHTTPClient(httpClient *http.Client) ClientOption {
4956
return func(sc *SSE) {
5057
sc.httpClient = httpClient
@@ -99,6 +106,11 @@ func (c *SSE) Start(ctx context.Context) error {
99106
for k, v := range c.headers {
100107
req.Header.Set(k, v)
101108
}
109+
if c.headerFunc != nil {
110+
for k, v := range c.headerFunc(ctx) {
111+
req.Header.Set(k, v)
112+
}
113+
}
102114

103115
resp, err := c.httpClient.Do(req)
104116
if err != nil {
@@ -269,6 +281,11 @@ func (c *SSE) SendRequest(
269281
for k, v := range c.headers {
270282
req.Header.Set(k, v)
271283
}
284+
if c.headerFunc != nil {
285+
for k, v := range c.headerFunc(ctx) {
286+
req.Header.Set(k, v)
287+
}
288+
}
272289

273290
// Create string key for map lookup
274291
idKey := request.ID.String()
@@ -310,8 +327,11 @@ func (c *SSE) SendRequest(
310327
case <-ctx.Done():
311328
deleteResponseChan()
312329
return nil, ctx.Err()
313-
case response := <-responseChan:
314-
return response, nil
330+
case response, ok := <-responseChan:
331+
if ok {
332+
return response, nil
333+
}
334+
return nil, fmt.Errorf("connection has been closed")
315335
}
316336
}
317337

@@ -365,6 +385,11 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
365385
for k, v := range c.headers {
366386
req.Header.Set(k, v)
367387
}
388+
if c.headerFunc != nil {
389+
for k, v := range c.headerFunc(ctx) {
390+
req.Header.Set(k, v)
391+
}
392+
}
368393

369394
resp, err := c.httpClient.Do(req)
370395
if err != nil {

client/transport/streamable_http.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
2626
}
2727
}
2828

29+
func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
30+
return func(sc *StreamableHTTP) {
31+
sc.headerFunc = headerFunc
32+
}
33+
}
34+
2935
// WithHTTPTimeout sets the timeout for a HTTP request and stream.
3036
func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
3137
return func(sc *StreamableHTTP) {
@@ -52,6 +58,7 @@ type StreamableHTTP struct {
5258
baseURL *url.URL
5359
httpClient *http.Client
5460
headers map[string]string
61+
headerFunc HTTPHeaderFunc
5562

5663
sessionID atomic.Value // string
5764

@@ -127,7 +134,6 @@ func (c *StreamableHTTP) Close() error {
127134
}
128135

129136
const (
130-
initializeMethod = "initialize"
131137
headerKeySessionID = "Mcp-Session-Id"
132138
)
133139

@@ -173,6 +179,11 @@ func (c *StreamableHTTP) SendRequest(
173179
for k, v := range c.headers {
174180
req.Header.Set(k, v)
175181
}
182+
if c.headerFunc != nil {
183+
for k, v := range c.headerFunc(ctx) {
184+
req.Header.Set(k, v)
185+
}
186+
}
176187

177188
// Send request
178189
resp, err := c.httpClient.Do(req)
@@ -198,7 +209,7 @@ func (c *StreamableHTTP) SendRequest(
198209
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
199210
}
200211

201-
if request.Method == initializeMethod {
212+
if request.Method == string(mcp.MethodInitialize) {
202213
// saved the received session ID in the response
203214
// empty session ID is allowed
204215
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
@@ -363,6 +374,11 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
363374
for k, v := range c.headers {
364375
req.Header.Set(k, v)
365376
}
377+
if c.headerFunc != nil {
378+
for k, v := range c.headerFunc(ctx) {
379+
req.Header.Set(k, v)
380+
}
381+
}
366382

367383
// Send request
368384
resp, err := c.httpClient.Do(req)

0 commit comments

Comments
 (0)