@@ -2,6 +2,7 @@ package client
22
33import  (
44	"context" 
5+ 	"net/http" 
56	"testing" 
67	"time" 
78
@@ -11,6 +12,13 @@ import (
1112	"github.com/mark3labs/mcp-go/server" 
1213)
1314
15+ type  contextKey  string 
16+ 
17+ const  (
18+ 	testHeaderKey      contextKey  =  "X-Test-Header" 
19+ 	testHeaderFuncKey  contextKey  =  "X-Test-Header-Func" 
20+ )
21+ 
1422func  TestSSEMCPClient (t  * testing.T ) {
1523	// Create MCP server with capabilities 
1624	mcpServer  :=  server .NewMCPServer (
@@ -36,14 +44,34 @@ func TestSSEMCPClient(t *testing.T) {
3644			Content : []mcp.Content {
3745				mcp.TextContent {
3846					Type : "text" ,
39- 					Text : "Input parameter: "  +  request .Params .Arguments ["parameter-1" ].(string ),
47+ 					Text : "Input parameter: "  +  request .GetArguments ()["parameter-1" ].(string ),
48+ 				},
49+ 			},
50+ 		}, nil 
51+ 	})
52+ 	mcpServer .AddTool (mcp .NewTool (
53+ 		"test-tool-for-http-header" ,
54+ 		mcp .WithDescription ("Test tool for http header" ),
55+ 	), func (ctx  context.Context , request  mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
56+ 		//  , X-Test-Header-Func 
57+ 		return  & mcp.CallToolResult {
58+ 			Content : []mcp.Content {
59+ 				mcp.TextContent {
60+ 					Type : "text" ,
61+ 					Text : "context from header: "  +  ctx .Value (testHeaderKey ).(string ) +  ", "  +  ctx .Value (testHeaderFuncKey ).(string ),
4062				},
4163			},
4264		}, nil 
4365	})
4466
4567	// Initialize 
46- 	testServer  :=  server .NewTestServer (mcpServer )
68+ 	testServer  :=  server .NewTestServer (mcpServer ,
69+ 		server .WithSSEContextFunc (func (ctx  context.Context , r  * http.Request ) context.Context  {
70+ 			ctx  =  context .WithValue (ctx , testHeaderKey , r .Header .Get ("X-Test-Header" ))
71+ 			ctx  =  context .WithValue (ctx , testHeaderFuncKey , r .Header .Get ("X-Test-Header-Func" ))
72+ 			return  ctx 
73+ 		}),
74+ 	)
4775	defer  testServer .Close ()
4876
4977	t .Run ("Can create client" , func (t  * testing.T ) {
@@ -250,4 +278,56 @@ func TestSSEMCPClient(t *testing.T) {
250278			t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
251279		}
252280	})
281+ 
282+ 	t .Run ("CallTool with customized header" , func (t  * testing.T ) {
283+ 		client , err  :=  NewSSEMCPClient (testServer .URL + "/sse" ,
284+ 			WithHeaders (map [string ]string {
285+ 				"X-Test-Header" : "test-header-value" ,
286+ 			}),
287+ 			WithHeaderFunc (func (ctx  context.Context ) map [string ]string  {
288+ 				return  map [string ]string {
289+ 					"X-Test-Header-Func" : "test-header-func-value" ,
290+ 				}
291+ 			}),
292+ 		)
293+ 		if  err  !=  nil  {
294+ 			t .Fatalf ("Failed to create client: %v" , err )
295+ 		}
296+ 		defer  client .Close ()
297+ 
298+ 		ctx , cancel  :=  context .WithTimeout (context .Background (), 5 * time .Second )
299+ 		defer  cancel ()
300+ 
301+ 		if  err  :=  client .Start (ctx ); err  !=  nil  {
302+ 			t .Fatalf ("Failed to start client: %v" , err )
303+ 		}
304+ 
305+ 		// Initialize 
306+ 		initRequest  :=  mcp.InitializeRequest {}
307+ 		initRequest .Params .ProtocolVersion  =  mcp .LATEST_PROTOCOL_VERSION 
308+ 		initRequest .Params .ClientInfo  =  mcp.Implementation {
309+ 			Name :    "test-client" ,
310+ 			Version : "1.0.0" ,
311+ 		}
312+ 
313+ 		_ , err  =  client .Initialize (ctx , initRequest )
314+ 		if  err  !=  nil  {
315+ 			t .Fatalf ("Failed to initialize: %v" , err )
316+ 		}
317+ 
318+ 		request  :=  mcp.CallToolRequest {}
319+ 		request .Params .Name  =  "test-tool-for-http-header" 
320+ 
321+ 		result , err  :=  client .CallTool (ctx , request )
322+ 		if  err  !=  nil  {
323+ 			t .Fatalf ("CallTool failed: %v" , err )
324+ 		}
325+ 
326+ 		if  len (result .Content ) !=  1  {
327+ 			t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
328+ 		}
329+ 		if  result .Content [0 ].(mcp.TextContent ).Text  !=  "context from header: test-header-value, test-header-func-value"  {
330+ 			t .Errorf ("Got %q, want %q" , result .Content [0 ].(mcp.TextContent ).Text , "context from header: test-header-value, test-header-func-value" )
331+ 		}
332+ 	})
253333}
0 commit comments