Skip to content

Commit 650f3c9

Browse files
committed
update session to new session tools
1 parent b464eae commit 650f3c9

File tree

2 files changed

+29
-49
lines changed

2 files changed

+29
-49
lines changed

server/streamable_http.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ type StreamableHTTPServer struct {
204204
baseURL string
205205
basePath string
206206
endpoint string
207-
sessions sync.Map
207+
sessions sync.Map // Maps sessionID to ClientSession
208208
srv *http.Server
209209
contextFunc SSEContextFunc
210210
sessionIDGenerator func() string
@@ -254,8 +254,10 @@ func (s *StreamableHTTPServer) Start(addr string) error {
254254
func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
255255
if s.srv != nil {
256256
s.sessions.Range(func(key, value interface{}) bool {
257-
if session, ok := value.(*streamableHTTPSession); ok {
258-
close(session.notificationChannel)
257+
if session, ok := value.(ClientSession); ok {
258+
if httpSession, ok := session.(*streamableHTTPSession); ok {
259+
close(httpSession.notificationChannel)
260+
}
259261
}
260262
s.sessions.Delete(key)
261263
return true
@@ -297,8 +299,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
297299
// Check if this is a request with a valid session
298300
if sessionID != "" {
299301
if sessionValue, ok := s.sessions.Load(sessionID); ok {
300-
if sess, ok := sessionValue.(*streamableHTTPSession); ok {
301-
session = sess
302+
if sess, ok := sessionValue.(SessionWithTools); ok {
303+
session = sess.(*streamableHTTPSession)
302304
} else {
303305
http.Error(w, "Invalid session", http.StatusBadRequest)
304306
return
@@ -449,7 +451,7 @@ func (s *StreamableHTTPServer) handleRequest(w http.ResponseWriter, r *http.Requ
449451
}
450452

451453
// handleSSEResponse sends the response as an SSE stream
452-
func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, initialResponse mcp.JSONRPCMessage, session *streamableHTTPSession) {
454+
func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, initialResponse mcp.JSONRPCMessage, session SessionWithTools) {
453455
// Set SSE headers
454456
w.Header().Set("Content-Type", "text/event-stream")
455457
w.Header().Set("Cache-Control", "no-cache, no-transform")
@@ -483,9 +485,10 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
483485

484486
// Check for Last-Event-ID header for resumability
485487
lastEventID := r.Header.Get("Last-Event-Id")
486-
if lastEventID != "" && session.eventStore != nil {
488+
httpSession, ok := session.(*streamableHTTPSession)
489+
if lastEventID != "" && ok && httpSession.eventStore != nil {
487490
// Replay events that occurred after the last event ID
488-
err := session.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error {
491+
err := httpSession.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error {
489492
data, err := json.Marshal(message)
490493
if err != nil {
491494
return err
@@ -516,9 +519,9 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
516519

517520
// Store the event if we have an event store
518521
var eventID string
519-
if session.eventStore != nil {
522+
if httpSession != nil && httpSession.eventStore != nil {
520523
var storeErr error
521-
eventID, storeErr = session.eventStore.StoreEvent(streamID, initialResponse)
524+
eventID, storeErr = httpSession.eventStore.StoreEvent(streamID, initialResponse)
522525
if storeErr != nil {
523526
// Log the error but continue
524527
fmt.Printf("Error storing event: %v\n", storeErr)
@@ -541,7 +544,7 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
541544
go func() {
542545
for {
543546
select {
544-
case notification, ok := <-session.notificationChannel:
547+
case notification, ok := <-httpSession.notificationChannel:
545548
if !ok {
546549
return
547550
}
@@ -553,9 +556,9 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
553556

554557
// Store the event if we have an event store
555558
var eventID string
556-
if session.eventStore != nil {
559+
if httpSession != nil && httpSession.eventStore != nil {
557560
var storeErr error
558-
eventID, storeErr = session.eventStore.StoreEvent(streamID, notification)
561+
eventID, storeErr = httpSession.eventStore.StoreEvent(streamID, notification)
559562
if storeErr != nil {
560563
// Log the error but continue
561564
fmt.Printf("Error storing event: %v\n", storeErr)
@@ -623,8 +626,8 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
623626
// Check if this is a request with a valid session
624627
if sessionID != "" {
625628
if sessionValue, ok := s.sessions.Load(sessionID); ok {
626-
if sess, ok := sessionValue.(*streamableHTTPSession); ok {
627-
session = sess
629+
if sess, ok := sessionValue.(SessionWithTools); ok {
630+
session = sess.(*streamableHTTPSession)
628631
} else {
629632
http.Error(w, "Invalid session", http.StatusBadRequest)
630633
return
@@ -663,7 +666,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
663666

664667
// Check for Last-Event-ID header for resumability
665668
lastEventID := r.Header.Get("Last-Event-Id")
666-
if lastEventID != "" && session.eventStore != nil {
669+
if lastEventID != "" && session != nil && session.eventStore != nil {
667670
// Replay events that occurred after the last event ID
668671
err := session.eventStore.ReplayEventsAfter(lastEventID, func(eventID string, message mcp.JSONRPCMessage) error {
669672
data, err := json.Marshal(message)
@@ -690,10 +693,13 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
690693
notifDone := make(chan struct{})
691694
defer close(notifDone)
692695

696+
// Get the concrete session type for notification channel access
697+
httpSession := session
698+
693699
go func() {
694700
for {
695701
select {
696-
case notification, ok := <-session.notificationChannel:
702+
case notification, ok := <-httpSession.notificationChannel:
697703
if !ok {
698704
return
699705
}
@@ -705,9 +711,9 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
705711

706712
// Store the event if we have an event store
707713
var eventID string
708-
if session.eventStore != nil {
714+
if httpSession != nil && httpSession.eventStore != nil {
709715
var storeErr error
710-
eventID, storeErr = session.eventStore.StoreEvent(s.standaloneStreamID, notification)
716+
eventID, storeErr = httpSession.eventStore.StoreEvent(s.standaloneStreamID, notification)
711717
if storeErr != nil {
712718
// Log the error but continue
713719
fmt.Printf("Error storing event: %v\n", storeErr)
@@ -865,7 +871,7 @@ func (s *StreamableHTTPServer) validateSession(sessionID string) bool {
865871
return false
866872
}
867873

868-
session, ok := sessionValue.(*streamableHTTPSession)
874+
session, ok := sessionValue.(ClientSession)
869875
if !ok {
870876
return false
871877
}

server/streamable_http_test.go

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99
"strings"
1010
"testing"
1111
"time"
12-
13-
"github.com/mark3labs/mcp-go/mcp"
1412
)
1513

1614
func TestStreamableHTTPServer(t *testing.T) {
@@ -256,34 +254,10 @@ func TestStreamableHTTPServer(t *testing.T) {
256254
// Wait a bit for the stream to be established
257255
time.Sleep(100 * time.Millisecond)
258256

259-
// Create a notification
260-
notification := mcp.JSONRPCNotification{
261-
JSONRPC: "2.0",
262-
Notification: mcp.Notification{
263-
Method: "test/notification",
264-
Params: mcp.NotificationParams{
265-
AdditionalFields: map[string]interface{}{
266-
"message": "Hello, world!",
267-
},
268-
},
269-
},
270-
}
271-
272-
// Find the session
273-
sessionValue, ok := streamableServer.sessions.Load(sessionID)
274-
if !ok {
275-
t.Errorf("Session not found: %s", sessionID)
276-
return
277-
}
278-
279257
// Send the notification
280-
session, ok := sessionValue.(*streamableHTTPSession)
281-
if !ok {
282-
t.Errorf("Invalid session type")
283-
return
284-
}
285-
286-
session.notificationChannel <- notification
258+
mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]interface{}{
259+
"message": "Hello, world!",
260+
})
287261
}()
288262

289263
// Read the response body

0 commit comments

Comments
 (0)