Skip to content

Commit 25a4798

Browse files
committed
feat: add ping for sse server
1 parent d3dc35c commit 25a4798

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

server/sse.go

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"sync"
1212
"sync/atomic"
13+
"time"
1314

1415
"github.com/google/uuid"
1516
"github.com/mark3labs/mcp-go/mcp"
@@ -60,6 +61,9 @@ type SSEServer struct {
6061
sessions sync.Map
6162
srv *http.Server
6263
contextFunc SSEContextFunc
64+
65+
keepAlive bool
66+
keepAliveInterval time.Duration
6367
}
6468

6569
// SSEOption defines a function type for configuring SSEServer
@@ -120,6 +124,19 @@ func WithHTTPServer(srv *http.Server) SSEOption {
120124
}
121125
}
122126

127+
func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
128+
return func(s *SSEServer) {
129+
s.keepAlive = true
130+
s.keepAliveInterval = keepAliveInterval
131+
}
132+
}
133+
134+
func WithKeepAlive(keepAlive bool) SSEOption {
135+
return func(s *SSEServer) {
136+
s.keepAlive = keepAlive
137+
}
138+
}
139+
123140
// WithContextFunc sets a function that will be called to customise the context
124141
// to the server using the incoming request.
125142
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
@@ -131,9 +148,11 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
131148
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
132149
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
133150
s := &SSEServer{
134-
server: server,
135-
sseEndpoint: "/sse",
136-
messageEndpoint: "/message",
151+
server: server,
152+
sseEndpoint: "/sse",
153+
messageEndpoint: "/message",
154+
keepAlive: false,
155+
keepAliveInterval: 10 * time.Second,
137156
}
138157

139158
// Apply all options
@@ -244,6 +263,24 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
244263
}
245264
}()
246265

266+
// Start keep alive : ping
267+
if s.keepAlive {
268+
go func() {
269+
ticker := time.NewTicker(s.keepAliveInterval)
270+
defer ticker.Stop()
271+
for {
272+
select {
273+
case <-ticker.C:
274+
session.eventQueue <- fmt.Sprintf("event: ping\ndata: %s\n\n", time.Now().Format(time.RFC3339))
275+
case <-session.done:
276+
return
277+
case <-r.Context().Done():
278+
return
279+
}
280+
}
281+
}()
282+
}
283+
247284
messageEndpoint := fmt.Sprintf("%s?sessionId=%s", s.CompleteMessageEndpoint(), sessionID)
248285

249286
// Send the initial endpoint event

0 commit comments

Comments
 (0)