Skip to content

Commit b7f2605

Browse files
committed
merge main and fix conflict
2 parents de256af + f47e2bc commit b7f2605

File tree

22 files changed

+2633
-285
lines changed

22 files changed

+2633
-285
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
.aider*
22
.env
3-
.idea
3+
.idea
4+
.opencode
5+
.claude

README.md

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ MCP Go handles all the complex protocol details and server management, so you ca
9191
- [Tools](#tools)
9292
- [Prompts](#prompts)
9393
- [Examples](#examples)
94+
- [Extras](#extras)
95+
- [Session Management](#session-management)
96+
- [Request Hooks](#request-hooks)
97+
- [Tool Handler Middleware](#tool-handler-middleware)
9498
- [Contributing](#contributing)
9599
- [Prerequisites](#prerequisites)
96100
- [Installation](#installation-1)
@@ -516,6 +520,214 @@ For examples, see the `examples/` directory.
516520

517521
## Extras
518522

523+
### Session Management
524+
525+
MCP-Go provides a robust session management system that allows you to:
526+
- Maintain separate state for each connected client
527+
- Register and track client sessions
528+
- Send notifications to specific clients
529+
- Provide per-session tool customization
530+
531+
<details>
532+
<summary>Show Session Management Examples</summary>
533+
534+
#### Basic Session Handling
535+
536+
```go
537+
// Create a server with session capabilities
538+
s := server.NewMCPServer(
539+
"Session Demo",
540+
"1.0.0",
541+
server.WithToolCapabilities(true),
542+
)
543+
544+
// Implement your own ClientSession
545+
type MySession struct {
546+
id string
547+
notifChannel chan mcp.JSONRPCNotification
548+
isInitialized bool
549+
// Add custom fields for your application
550+
}
551+
552+
// Implement the ClientSession interface
553+
func (s *MySession) SessionID() string {
554+
return s.id
555+
}
556+
557+
func (s *MySession) NotificationChannel() chan<- mcp.JSONRPCNotification {
558+
return s.notifChannel
559+
}
560+
561+
func (s *MySession) Initialize() {
562+
s.isInitialized = true
563+
}
564+
565+
func (s *MySession) Initialized() bool {
566+
return s.isInitialized
567+
}
568+
569+
// Register a session
570+
session := &MySession{
571+
id: "user-123",
572+
notifChannel: make(chan mcp.JSONRPCNotification, 10),
573+
}
574+
if err := s.RegisterSession(context.Background(), session); err != nil {
575+
log.Printf("Failed to register session: %v", err)
576+
}
577+
578+
// Send notification to a specific client
579+
err := s.SendNotificationToSpecificClient(
580+
session.SessionID(),
581+
"notification/update",
582+
map[string]any{"message": "New data available!"},
583+
)
584+
if err != nil {
585+
log.Printf("Failed to send notification: %v", err)
586+
}
587+
588+
// Unregister session when done
589+
s.UnregisterSession(context.Background(), session.SessionID())
590+
```
591+
592+
#### Per-Session Tools
593+
594+
For more advanced use cases, you can implement the `SessionWithTools` interface to support per-session tool customization:
595+
596+
```go
597+
// Implement SessionWithTools interface for per-session tools
598+
type MyAdvancedSession struct {
599+
MySession // Embed the basic session
600+
sessionTools map[string]server.ServerTool
601+
}
602+
603+
// Implement additional methods for SessionWithTools
604+
func (s *MyAdvancedSession) GetSessionTools() map[string]server.ServerTool {
605+
return s.sessionTools
606+
}
607+
608+
func (s *MyAdvancedSession) SetSessionTools(tools map[string]server.ServerTool) {
609+
s.sessionTools = tools
610+
}
611+
612+
// Create and register a session with tools support
613+
advSession := &MyAdvancedSession{
614+
MySession: MySession{
615+
id: "user-456",
616+
notifChannel: make(chan mcp.JSONRPCNotification, 10),
617+
},
618+
sessionTools: make(map[string]server.ServerTool),
619+
}
620+
if err := s.RegisterSession(context.Background(), advSession); err != nil {
621+
log.Printf("Failed to register session: %v", err)
622+
}
623+
624+
// Add session-specific tools
625+
userSpecificTool := mcp.NewTool(
626+
"user_data",
627+
mcp.WithDescription("Access user-specific data"),
628+
)
629+
// You can use AddSessionTool (similar to AddTool)
630+
err := s.AddSessionTool(
631+
advSession.SessionID(),
632+
userSpecificTool,
633+
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
634+
// This handler is only available to this specific session
635+
return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil
636+
},
637+
)
638+
if err != nil {
639+
log.Printf("Failed to add session tool: %v", err)
640+
}
641+
642+
// Or use AddSessionTools directly with ServerTool
643+
/*
644+
err := s.AddSessionTools(
645+
advSession.SessionID(),
646+
server.ServerTool{
647+
Tool: userSpecificTool,
648+
Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
649+
// This handler is only available to this specific session
650+
return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil
651+
},
652+
},
653+
)
654+
if err != nil {
655+
log.Printf("Failed to add session tool: %v", err)
656+
}
657+
*/
658+
659+
// Delete session-specific tools when no longer needed
660+
err = s.DeleteSessionTools(advSession.SessionID(), "user_data")
661+
if err != nil {
662+
log.Printf("Failed to delete session tool: %v", err)
663+
}
664+
```
665+
666+
#### Tool Filtering
667+
668+
You can also apply filters to control which tools are available to certain sessions:
669+
670+
```go
671+
// Add a tool filter that only shows tools with certain prefixes
672+
s := server.NewMCPServer(
673+
"Tool Filtering Demo",
674+
"1.0.0",
675+
server.WithToolCapabilities(true),
676+
server.WithToolFilter(func(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
677+
// Get session from context
678+
session := server.ClientSessionFromContext(ctx)
679+
if session == nil {
680+
return tools // Return all tools if no session
681+
}
682+
683+
// Example: filter tools based on session ID prefix
684+
if strings.HasPrefix(session.SessionID(), "admin-") {
685+
// Admin users get all tools
686+
return tools
687+
} else {
688+
// Regular users only get tools with "public-" prefix
689+
var filteredTools []mcp.Tool
690+
for _, tool := range tools {
691+
if strings.HasPrefix(tool.Name, "public-") {
692+
filteredTools = append(filteredTools, tool)
693+
}
694+
}
695+
return filteredTools
696+
}
697+
}),
698+
)
699+
```
700+
701+
#### Working with Context
702+
703+
The session context is automatically passed to tool and resource handlers:
704+
705+
```go
706+
s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
707+
// Get the current session from context
708+
session := server.ClientSessionFromContext(ctx)
709+
if session == nil {
710+
return mcp.NewToolResultError("No active session"), nil
711+
}
712+
713+
return mcp.NewToolResultText("Hello, session " + session.SessionID()), nil
714+
})
715+
716+
// When using handlers in HTTP/SSE servers, you need to pass the context with the session
717+
httpHandler := func(w http.ResponseWriter, r *http.Request) {
718+
// Get session from somewhere (like a cookie or header)
719+
session := getSessionFromRequest(r)
720+
721+
// Add session to context
722+
ctx := s.WithContext(r.Context(), session)
723+
724+
// Use this context when handling requests
725+
// ...
726+
}
727+
```
728+
729+
</details>
730+
519731
### Request Hooks
520732

521733
Hook into the request lifecycle by creating a `Hooks` object with your

client/sse.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@ package client
22

33
import (
44
"fmt"
5-
"github.com/mark3labs/mcp-go/client/transport"
5+
"net/http"
66
"net/url"
7+
8+
"github.com/mark3labs/mcp-go/client/transport"
79
)
810

911
func WithHeaders(headers map[string]string) transport.ClientOption {
1012
return transport.WithHeaders(headers)
1113
}
1214

15+
func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
16+
return transport.WithHTTPClient(httpClient)
17+
}
18+
1319
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
1420
// Returns an error if the URL is invalid.
1521
func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) {

client/stdio_test.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"log/slog"
88
"os"
99
"os/exec"
10-
"path/filepath"
10+
"runtime"
1111
"sync"
1212
"testing"
1313
"time"
@@ -19,21 +19,41 @@ func compileTestServer(outputPath string) error {
1919
cmd := exec.Command(
2020
"go",
2121
"build",
22+
"-buildmode=pie",
2223
"-o",
2324
outputPath,
2425
"../testdata/mockstdio_server.go",
2526
)
27+
tmpCache, _ := os.MkdirTemp("", "gocache")
28+
cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache)
29+
2630
if output, err := cmd.CombinedOutput(); err != nil {
2731
return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output)
2832
}
33+
// Verify the binary was actually created
34+
if _, err := os.Stat(outputPath); os.IsNotExist(err) {
35+
return fmt.Errorf("mock server binary not found at %s after compilation", outputPath)
36+
}
2937
return nil
3038
}
3139

3240
func TestStdioMCPClient(t *testing.T) {
33-
// Compile mock server
34-
mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server")
35-
if err := compileTestServer(mockServerPath); err != nil {
36-
t.Fatalf("Failed to compile mock server: %v", err)
41+
// Create a temporary file for the mock server
42+
tempFile, err := os.CreateTemp("", "mockstdio_server")
43+
if err != nil {
44+
t.Fatalf("Failed to create temp file: %v", err)
45+
}
46+
tempFile.Close()
47+
mockServerPath := tempFile.Name()
48+
49+
// Add .exe suffix on Windows
50+
if runtime.GOOS == "windows" {
51+
os.Remove(mockServerPath) // Remove the empty file first
52+
mockServerPath += ".exe"
53+
}
54+
55+
if compileErr := compileTestServer(mockServerPath); compileErr != nil {
56+
t.Fatalf("Failed to compile mock server: %v", compileErr)
3757
}
3858
defer os.Remove(mockServerPath)
3959

client/transport/sse.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func WithHeaders(headers map[string]string) ClientOption {
4545
}
4646
}
4747

48+
func WithHTTPClient(httpClient *http.Client) ClientOption {
49+
return func(sc *SSE) {
50+
sc.httpClient = httpClient
51+
}
52+
}
53+
4854
// NewSSE creates a new SSE-based MCP client with the given base URL.
4955
// Returns an error if the URL is invalid.
5056
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
@@ -226,7 +232,7 @@ func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotifi
226232
c.onNotification = handler
227233
}
228234

229-
// sendRequest sends a JSON-RPC request to the server and waits for a response.
235+
// SendRequest sends a JSON-RPC request to the server and waits for a response.
230236
// Returns the raw JSON response message or an error if the request fails.
231237
func (c *SSE) SendRequest(
232238
ctx context.Context,
@@ -278,13 +284,19 @@ func (c *SSE) SendRequest(
278284
deleteResponseChan()
279285
return nil, fmt.Errorf("failed to send request: %w", err)
280286
}
281-
defer resp.Body.Close()
287+
288+
// Drain any outstanding io
289+
body, err := io.ReadAll(resp.Body)
290+
resp.Body.Close()
291+
292+
if err != nil {
293+
return nil, fmt.Errorf("failed to read response body: %w", err)
294+
}
282295

283296
// Check if we got an error response
284297
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
285298
deleteResponseChan()
286299

287-
body, _ := io.ReadAll(resp.Body)
288300
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
289301
}
290302

client/transport/sse_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,31 @@ func TestSSEErrors(t *testing.T) {
415415
}
416416
})
417417

418+
t.Run("WithHTTPClient", func(t *testing.T) {
419+
// Create a custom client with a very short timeout
420+
customClient := &http.Client{Timeout: 1 * time.Nanosecond}
421+
422+
url, closeF := startMockSSEEchoServer()
423+
defer closeF()
424+
// Initialize SSE transport with the custom HTTP client
425+
trans, err := NewSSE(url, WithHTTPClient(customClient))
426+
if err != nil {
427+
t.Fatalf("Failed to create SSE with custom client: %v", err)
428+
}
429+
430+
// Starting should immediately error due to timeout
431+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
432+
defer cancel()
433+
err = trans.Start(ctx)
434+
if err == nil {
435+
t.Error("Expected Start to fail with custom timeout, got nil")
436+
}
437+
if !errors.Is(err, context.DeadlineExceeded) {
438+
t.Errorf("Expected error 'context deadline exceeded', got '%s'", err.Error())
439+
}
440+
trans.Close()
441+
})
442+
418443
t.Run("RequestBeforeStart", func(t *testing.T) {
419444
url, closeF := startMockSSEEchoServer()
420445
defer closeF()

0 commit comments

Comments
 (0)