Skip to content

Commit 47011cd

Browse files
committed
handle initializations per-session rather than globally
1 parent 88fce3a commit 47011cd

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

server/server.go

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"sort"
99
"sync"
10-
"sync/atomic"
1110

1211
"github.com/mark3labs/mcp-go/mcp"
1312
)
@@ -47,6 +46,10 @@ type ServerTool struct {
4746

4847
// ClientSession represents an active session that can be used by MCPServer to interact with client.
4948
type ClientSession interface {
49+
// Initialize marks session as fully initialized and ready for notifications
50+
Initialize()
51+
// Initialized returns if session is ready to accept notifications
52+
Initialized() bool
5053
// NotificationChannel provides a channel suitable for sending notifications to client.
5154
NotificationChannel() chan<- mcp.JSONRPCNotification
5255
// SessionID is a unique identifier used to track user session.
@@ -82,7 +85,6 @@ type MCPServer struct {
8285
notificationHandlers map[string]NotificationHandlerFunc
8386
capabilities serverCapabilities
8487
sessions sync.Map
85-
initialized atomic.Bool // Use atomic for the initialized flag
8688
}
8789

8890
// serverKey is the context key for storing the server instance
@@ -138,7 +140,7 @@ func (s *MCPServer) sendNotificationToAllClients(
138140
}
139141

140142
s.sessions.Range(func(k, v any) bool {
141-
if session, ok := v.(ClientSession); ok {
143+
if session, ok := v.(ClientSession); ok && session.Initialized() {
142144
select {
143145
case session.NotificationChannel() <- notification:
144146
default:
@@ -156,7 +158,7 @@ func (s *MCPServer) SendNotificationToClient(
156158
params map[string]any,
157159
) error {
158160
session := ClientSessionFromContext(ctx)
159-
if session == nil {
161+
if session == nil || !session.Initialized() {
160162
return fmt.Errorf("notification channel not initialized")
161163
}
162164

@@ -526,13 +528,10 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
526528
for _, entry := range tools {
527529
s.tools[entry.Tool.Name] = entry
528530
}
529-
initialized := s.initialized.Load()
530531
s.mu.Unlock()
531532

532-
// Send notification if server is already initialized
533-
if initialized {
534-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
535-
}
533+
// Send notification to all initialized sessions
534+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
536535
}
537536

538537
// SetTools replaces all existing tools with the provided list
@@ -549,13 +548,10 @@ func (s *MCPServer) DeleteTools(names ...string) {
549548
for _, name := range names {
550549
delete(s.tools, name)
551550
}
552-
initialized := s.initialized.Load()
553551
s.mu.Unlock()
554552

555-
// Send notification if server is already initialized
556-
if initialized {
557-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
558-
}
553+
// Send notification to all initialized sessions
554+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
559555
}
560556

561557
// AddNotificationHandler registers a new handler for incoming notifications
@@ -618,7 +614,9 @@ func (s *MCPServer) handleInitialize(
618614
Instructions: s.instructions,
619615
}
620616

621-
s.initialized.Store(true)
617+
if session := ClientSessionFromContext(ctx); session != nil {
618+
session.Initialize()
619+
}
622620
return createResponse(id, result)
623621
}
624622

server/server_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ func TestMCPServer_Tools(t *testing.T) {
180180
err := server.RegisterSession(&fakeSession{
181181
sessionID: "test",
182182
notificationChannel: notificationChannel,
183+
initialized: true,
183184
})
184185
require.NoError(t, err)
185186
server.SetTools(ServerTool{
@@ -210,6 +211,16 @@ func TestMCPServer_Tools(t *testing.T) {
210211
err := server.RegisterSession(&fakeSession{
211212
sessionID: fmt.Sprintf("test%d", i),
212213
notificationChannel: notificationChannel,
214+
initialized: true,
215+
})
216+
require.NoError(t, err)
217+
}
218+
// also let's register inactive sessions
219+
for i := range 5 {
220+
err := server.RegisterSession(&fakeSession{
221+
sessionID: fmt.Sprintf("test%d", i+5),
222+
notificationChannel: notificationChannel,
223+
initialized: false,
213224
})
214225
require.NoError(t, err)
215226
}
@@ -242,6 +253,7 @@ func TestMCPServer_Tools(t *testing.T) {
242253
err := server.RegisterSession(&fakeSession{
243254
sessionID: "test",
244255
notificationChannel: notificationChannel,
256+
initialized: true,
245257
})
246258
require.NoError(t, err)
247259
server.AddTool(mcp.NewTool("test-tool-1"),
@@ -269,6 +281,7 @@ func TestMCPServer_Tools(t *testing.T) {
269281
err := server.RegisterSession(&fakeSession{
270282
sessionID: "test",
271283
notificationChannel: notificationChannel,
284+
initialized: true,
272285
})
273286
require.NoError(t, err)
274287
server.SetTools(
@@ -488,12 +501,28 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) {
488501
require.Error(t, srv.SendNotificationToClient(ctx, "method", nil))
489502
},
490503
},
504+
{
505+
name: "uninit session",
506+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
507+
return srv.WithContext(ctx, fakeSession{
508+
sessionID: "test",
509+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
510+
initialized: false,
511+
})
512+
},
513+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
514+
require.Error(t, srv.SendNotificationToClient(ctx, "method", nil))
515+
_, ok := ClientSessionFromContext(ctx).(fakeSession)
516+
require.True(t, ok, "session not found or of incorrect type")
517+
},
518+
},
491519
{
492520
name: "active session",
493521
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
494522
return srv.WithContext(ctx, fakeSession{
495523
sessionID: "test",
496524
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
525+
initialized: true,
497526
})
498527
},
499528
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
@@ -518,6 +547,7 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) {
518547
return srv.WithContext(ctx, fakeSession{
519548
sessionID: "test",
520549
notificationChannel: make(chan mcp.JSONRPCNotification, 1),
550+
initialized: true,
521551
})
522552
},
523553
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
@@ -1051,6 +1081,7 @@ func createTestServer() *MCPServer {
10511081
type fakeSession struct {
10521082
sessionID string
10531083
notificationChannel chan mcp.JSONRPCNotification
1084+
initialized bool
10541085
}
10551086

10561087
func (f fakeSession) SessionID() string {
@@ -1061,4 +1092,11 @@ func (f fakeSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
10611092
return f.notificationChannel
10621093
}
10631094

1095+
func (f fakeSession) Initialize() {
1096+
}
1097+
1098+
func (f fakeSession) Initialized() bool {
1099+
return f.initialized
1100+
}
1101+
10641102
var _ ClientSession = fakeSession{}

server/sse.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http/httptest"
99
"strings"
1010
"sync"
11+
"sync/atomic"
1112

1213
"github.com/google/uuid"
1314
"github.com/mark3labs/mcp-go/mcp"
@@ -21,6 +22,7 @@ type sseSession struct {
2122
eventQueue chan string // Channel for queuing events
2223
sessionID string
2324
notificationChannel chan mcp.JSONRPCNotification
25+
initialized atomic.Bool
2426
}
2527

2628
// SSEContextFunc is a function that takes an existing context and the current
@@ -36,6 +38,14 @@ func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
3638
return s.notificationChannel
3739
}
3840

41+
func (s *sseSession) Initialize() {
42+
s.initialized.Store(true)
43+
}
44+
45+
func (s *sseSession) Initialized() bool {
46+
return s.initialized.Load()
47+
}
48+
3949
var _ ClientSession = (*sseSession)(nil)
4050

4151
// SSEServer implements a Server-Sent Events (SSE) based MCP server.

server/stdio.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"log"
1010
"os"
1111
"os/signal"
12+
"sync/atomic"
1213
"syscall"
1314

1415
"github.com/mark3labs/mcp-go/mcp"
@@ -51,6 +52,7 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
5152
// stdioSession is a static client session, since stdio has only one client.
5253
type stdioSession struct {
5354
notifications chan mcp.JSONRPCNotification
55+
initialized atomic.Bool
5456
}
5557

5658
func (s *stdioSession) SessionID() string {
@@ -61,6 +63,14 @@ func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
6163
return s.notifications
6264
}
6365

66+
func (s *stdioSession) Initialize() {
67+
s.initialized.Store(true)
68+
}
69+
70+
func (s *stdioSession) Initialized() bool {
71+
return s.initialized.Load()
72+
}
73+
6474
var _ ClientSession = (*stdioSession)(nil)
6575

6676
var stdioSessionInstance = stdioSession{

0 commit comments

Comments
 (0)