Skip to content

Commit 824376e

Browse files
authored
handle initializations per-session rather than globally (#60)
1 parent 2bd7076 commit 824376e

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

server/server.go

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"fmt"
99
"sort"
1010
"sync"
11-
"sync/atomic"
1211

1312
"github.com/mark3labs/mcp-go/mcp"
1413
)
@@ -48,6 +47,10 @@ type ServerTool struct {
4847

4948
// ClientSession represents an active session that can be used by MCPServer to interact with client.
5049
type ClientSession interface {
50+
// Initialize marks session as fully initialized and ready for notifications
51+
Initialize()
52+
// Initialized returns if session is ready to accept notifications
53+
Initialized() bool
5154
// NotificationChannel provides a channel suitable for sending notifications to client.
5255
NotificationChannel() chan<- mcp.JSONRPCNotification
5356
// SessionID is a unique identifier used to track user session.
@@ -145,7 +148,6 @@ type MCPServer struct {
145148
notificationHandlers map[string]NotificationHandlerFunc
146149
capabilities serverCapabilities
147150
sessions sync.Map
148-
initialized atomic.Bool // Use atomic for the initialized flag
149151
hooks *Hooks
150152
}
151153

@@ -202,7 +204,7 @@ func (s *MCPServer) sendNotificationToAllClients(
202204
}
203205

204206
s.sessions.Range(func(k, v any) bool {
205-
if session, ok := v.(ClientSession); ok {
207+
if session, ok := v.(ClientSession); ok && session.Initialized() {
206208
select {
207209
case session.NotificationChannel() <- notification:
208210
default:
@@ -220,7 +222,7 @@ func (s *MCPServer) SendNotificationToClient(
220222
params map[string]any,
221223
) error {
222224
session := ClientSessionFromContext(ctx)
223-
if session == nil {
225+
if session == nil || !session.Initialized() {
224226
return fmt.Errorf("notification channel not initialized")
225227
}
226228

@@ -406,13 +408,10 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
406408
for _, entry := range tools {
407409
s.tools[entry.Tool.Name] = entry
408410
}
409-
initialized := s.initialized.Load()
410411
s.mu.Unlock()
411412

412-
// Send notification if server is already initialized
413-
if initialized {
414-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
415-
}
413+
// Send notification to all initialized sessions
414+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
416415
}
417416

418417
// SetTools replaces all existing tools with the provided list
@@ -429,13 +428,10 @@ func (s *MCPServer) DeleteTools(names ...string) {
429428
for _, name := range names {
430429
delete(s.tools, name)
431430
}
432-
initialized := s.initialized.Load()
433431
s.mu.Unlock()
434432

435-
// Send notification if server is already initialized
436-
if initialized {
437-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
438-
}
433+
// Send notification to all initialized sessions
434+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
439435
}
440436

441437
// AddNotificationHandler registers a new handler for incoming notifications
@@ -498,7 +494,9 @@ func (s *MCPServer) handleInitialize(
498494
Instructions: s.instructions,
499495
}
500496

501-
s.initialized.Store(true)
497+
if session := ClientSessionFromContext(ctx); session != nil {
498+
session.Initialize()
499+
}
502500
return &result, nil
503501
}
504502

server/server_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ func TestMCPServer_Tools(t *testing.T) {
181181
err := server.RegisterSession(&fakeSession{
182182
sessionID: "test",
183183
notificationChannel: notificationChannel,
184+
initialized: true,
184185
})
185186
require.NoError(t, err)
186187
server.SetTools(ServerTool{
@@ -211,6 +212,16 @@ func TestMCPServer_Tools(t *testing.T) {
211212
err := server.RegisterSession(&fakeSession{
212213
sessionID: fmt.Sprintf("test%d", i),
213214
notificationChannel: notificationChannel,
215+
initialized: true,
216+
})
217+
require.NoError(t, err)
218+
}
219+
// also let's register inactive sessions
220+
for i := range 5 {
221+
err := server.RegisterSession(&fakeSession{
222+
sessionID: fmt.Sprintf("test%d", i+5),
223+
notificationChannel: notificationChannel,
224+
initialized: false,
214225
})
215226
require.NoError(t, err)
216227
}
@@ -243,6 +254,7 @@ func TestMCPServer_Tools(t *testing.T) {
243254
err := server.RegisterSession(&fakeSession{
244255
sessionID: "test",
245256
notificationChannel: notificationChannel,
257+
initialized: true,
246258
})
247259
require.NoError(t, err)
248260
server.AddTool(mcp.NewTool("test-tool-1"),
@@ -270,6 +282,7 @@ func TestMCPServer_Tools(t *testing.T) {
270282
err := server.RegisterSession(&fakeSession{
271283
sessionID: "test",
272284
notificationChannel: notificationChannel,
285+
initialized: true,
273286
})
274287
require.NoError(t, err)
275288
server.SetTools(
@@ -489,12 +502,28 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) {
489502
require.Error(t, srv.SendNotificationToClient(ctx, "method", nil))
490503
},
491504
},
505+
{
506+
name: "uninit session",
507+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
508+
return srv.WithContext(ctx, fakeSession{
509+
sessionID: "test",
510+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
511+
initialized: false,
512+
})
513+
},
514+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
515+
require.Error(t, srv.SendNotificationToClient(ctx, "method", nil))
516+
_, ok := ClientSessionFromContext(ctx).(fakeSession)
517+
require.True(t, ok, "session not found or of incorrect type")
518+
},
519+
},
492520
{
493521
name: "active session",
494522
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
495523
return srv.WithContext(ctx, fakeSession{
496524
sessionID: "test",
497525
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
526+
initialized: true,
498527
})
499528
},
500529
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
@@ -519,6 +548,7 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) {
519548
return srv.WithContext(ctx, fakeSession{
520549
sessionID: "test",
521550
notificationChannel: make(chan mcp.JSONRPCNotification, 1),
551+
initialized: true,
522552
})
523553
},
524554
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
@@ -1136,6 +1166,7 @@ func createTestServer() *MCPServer {
11361166
type fakeSession struct {
11371167
sessionID string
11381168
notificationChannel chan mcp.JSONRPCNotification
1169+
initialized bool
11391170
}
11401171

11411172
func (f fakeSession) SessionID() string {
@@ -1146,6 +1177,13 @@ func (f fakeSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
11461177
return f.notificationChannel
11471178
}
11481179

1180+
func (f fakeSession) Initialize() {
1181+
}
1182+
1183+
func (f fakeSession) Initialized() bool {
1184+
return f.initialized
1185+
}
1186+
11491187
var _ ClientSession = fakeSession{}
11501188

11511189
func TestMCPServer_WithHooks(t *testing.T) {

server/sse.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/url"
1010
"strings"
1111
"sync"
12+
"sync/atomic"
1213

1314
"github.com/google/uuid"
1415
"github.com/mark3labs/mcp-go/mcp"
@@ -22,6 +23,7 @@ type sseSession struct {
2223
eventQueue chan string // Channel for queuing events
2324
sessionID string
2425
notificationChannel chan mcp.JSONRPCNotification
26+
initialized atomic.Bool
2527
}
2628

2729
// SSEContextFunc is a function that takes an existing context and the current
@@ -37,6 +39,14 @@ func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
3739
return s.notificationChannel
3840
}
3941

42+
func (s *sseSession) Initialize() {
43+
s.initialized.Store(true)
44+
}
45+
46+
func (s *sseSession) Initialized() bool {
47+
return s.initialized.Load()
48+
}
49+
4050
var _ ClientSession = (*sseSession)(nil)
4151

4252
// SSEServer implements a Server-Sent Events (SSE) based MCP server.

server/stdio.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"log"
1010
"os"
1111
"os/signal"
12+
"sync/atomic"
1213
"syscall"
1314

1415
"github.com/mark3labs/mcp-go/mcp"
@@ -51,6 +52,7 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
5152
// stdioSession is a static client session, since stdio has only one client.
5253
type stdioSession struct {
5354
notifications chan mcp.JSONRPCNotification
55+
initialized atomic.Bool
5456
}
5557

5658
func (s *stdioSession) SessionID() string {
@@ -61,6 +63,14 @@ func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
6163
return s.notifications
6264
}
6365

66+
func (s *stdioSession) Initialize() {
67+
s.initialized.Store(true)
68+
}
69+
70+
func (s *stdioSession) Initialized() bool {
71+
return s.initialized.Load()
72+
}
73+
6474
var _ ClientSession = (*stdioSession)(nil)
6575

6676
var stdioSessionInstance = stdioSession{

0 commit comments

Comments
 (0)