@@ -204,7 +204,7 @@ type StreamableHTTPServer struct {
204204 baseURL string
205205 basePath string
206206 endpoint string
207- sessions sync.Map
207+ sessions sync.Map // Maps sessionID to ClientSession
208208 srv * http.Server
209209 contextFunc SSEContextFunc
210210 sessionIDGenerator func () string
@@ -254,8 +254,10 @@ func (s *StreamableHTTPServer) Start(addr string) error {
254254func (s * StreamableHTTPServer ) Shutdown (ctx context.Context ) error {
255255 if s .srv != nil {
256256 s .sessions .Range (func (key , value interface {}) bool {
257- if session , ok := value .(* streamableHTTPSession ); ok {
258- close (session .notificationChannel )
257+ if session , ok := value .(ClientSession ); ok {
258+ if httpSession , ok := session .(* streamableHTTPSession ); ok {
259+ close (httpSession .notificationChannel )
260+ }
259261 }
260262 s .sessions .Delete (key )
261263 return true
@@ -297,8 +299,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
297299 // Check if this is a request with a valid session
298300 if sessionID != "" {
299301 if sessionValue , ok := s .sessions .Load (sessionID ); ok {
300- if sess , ok := sessionValue .(* streamableHTTPSession ); ok {
301- session = sess
302+ if sess , ok := sessionValue .(SessionWithTools ); ok {
303+ session = sess .( * streamableHTTPSession )
302304 } else {
303305 http .Error (w , "Invalid session" , http .StatusBadRequest )
304306 return
@@ -449,7 +451,7 @@ func (s *StreamableHTTPServer) handleRequest(w http.ResponseWriter, r *http.Requ
449451}
450452
451453// handleSSEResponse sends the response as an SSE stream
452- func (s * StreamableHTTPServer ) handleSSEResponse (w http.ResponseWriter , r * http.Request , ctx context.Context , initialResponse mcp.JSONRPCMessage , session * streamableHTTPSession ) {
454+ func (s * StreamableHTTPServer ) handleSSEResponse (w http.ResponseWriter , r * http.Request , ctx context.Context , initialResponse mcp.JSONRPCMessage , session SessionWithTools ) {
453455 // Set SSE headers
454456 w .Header ().Set ("Content-Type" , "text/event-stream" )
455457 w .Header ().Set ("Cache-Control" , "no-cache, no-transform" )
@@ -483,9 +485,10 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
483485
484486 // Check for Last-Event-ID header for resumability
485487 lastEventID := r .Header .Get ("Last-Event-Id" )
486- if lastEventID != "" && session .eventStore != nil {
488+ httpSession , ok := session .(* streamableHTTPSession )
489+ if lastEventID != "" && ok && httpSession .eventStore != nil {
487490 // Replay events that occurred after the last event ID
488- err := session .eventStore .ReplayEventsAfter (lastEventID , func (eventID string , message mcp.JSONRPCMessage ) error {
491+ err := httpSession .eventStore .ReplayEventsAfter (lastEventID , func (eventID string , message mcp.JSONRPCMessage ) error {
489492 data , err := json .Marshal (message )
490493 if err != nil {
491494 return err
@@ -516,9 +519,9 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
516519
517520 // Store the event if we have an event store
518521 var eventID string
519- if session .eventStore != nil {
522+ if httpSession != nil && httpSession .eventStore != nil {
520523 var storeErr error
521- eventID , storeErr = session .eventStore .StoreEvent (streamID , initialResponse )
524+ eventID , storeErr = httpSession .eventStore .StoreEvent (streamID , initialResponse )
522525 if storeErr != nil {
523526 // Log the error but continue
524527 fmt .Printf ("Error storing event: %v\n " , storeErr )
@@ -541,7 +544,7 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
541544 go func () {
542545 for {
543546 select {
544- case notification , ok := <- session .notificationChannel :
547+ case notification , ok := <- httpSession .notificationChannel :
545548 if ! ok {
546549 return
547550 }
@@ -553,9 +556,9 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
553556
554557 // Store the event if we have an event store
555558 var eventID string
556- if session .eventStore != nil {
559+ if httpSession != nil && httpSession .eventStore != nil {
557560 var storeErr error
558- eventID , storeErr = session .eventStore .StoreEvent (streamID , notification )
561+ eventID , storeErr = httpSession .eventStore .StoreEvent (streamID , notification )
559562 if storeErr != nil {
560563 // Log the error but continue
561564 fmt .Printf ("Error storing event: %v\n " , storeErr )
@@ -623,8 +626,8 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
623626 // Check if this is a request with a valid session
624627 if sessionID != "" {
625628 if sessionValue , ok := s .sessions .Load (sessionID ); ok {
626- if sess , ok := sessionValue .(* streamableHTTPSession ); ok {
627- session = sess
629+ if sess , ok := sessionValue .(SessionWithTools ); ok {
630+ session = sess .( * streamableHTTPSession )
628631 } else {
629632 http .Error (w , "Invalid session" , http .StatusBadRequest )
630633 return
@@ -663,7 +666,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
663666
664667 // Check for Last-Event-ID header for resumability
665668 lastEventID := r .Header .Get ("Last-Event-Id" )
666- if lastEventID != "" && session .eventStore != nil {
669+ if lastEventID != "" && session != nil && session .eventStore != nil {
667670 // Replay events that occurred after the last event ID
668671 err := session .eventStore .ReplayEventsAfter (lastEventID , func (eventID string , message mcp.JSONRPCMessage ) error {
669672 data , err := json .Marshal (message )
@@ -690,10 +693,13 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
690693 notifDone := make (chan struct {})
691694 defer close (notifDone )
692695
696+ // Get the concrete session type for notification channel access
697+ httpSession := session
698+
693699 go func () {
694700 for {
695701 select {
696- case notification , ok := <- session .notificationChannel :
702+ case notification , ok := <- httpSession .notificationChannel :
697703 if ! ok {
698704 return
699705 }
@@ -705,9 +711,9 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
705711
706712 // Store the event if we have an event store
707713 var eventID string
708- if session .eventStore != nil {
714+ if httpSession != nil && httpSession .eventStore != nil {
709715 var storeErr error
710- eventID , storeErr = session .eventStore .StoreEvent (s .standaloneStreamID , notification )
716+ eventID , storeErr = httpSession .eventStore .StoreEvent (s .standaloneStreamID , notification )
711717 if storeErr != nil {
712718 // Log the error but continue
713719 fmt .Printf ("Error storing event: %v\n " , storeErr )
@@ -865,7 +871,7 @@ func (s *StreamableHTTPServer) validateSession(sessionID string) bool {
865871 return false
866872 }
867873
868- session , ok := sessionValue .(* streamableHTTPSession )
874+ session , ok := sessionValue .(ClientSession )
869875 if ! ok {
870876 return false
871877 }
0 commit comments