@@ -10,6 +10,7 @@ import (
1010 "strings"
1111 "sync"
1212 "sync/atomic"
13+ "time"
1314
1415 "github.com/google/uuid"
1516 "github.com/mark3labs/mcp-go/mcp"
@@ -60,6 +61,9 @@ type SSEServer struct {
6061 sessions sync.Map
6162 srv * http.Server
6263 contextFunc SSEContextFunc
64+
65+ keepAlive bool
66+ keepAliveInterval time.Duration
6367}
6468
6569// SSEOption defines a function type for configuring SSEServer
@@ -120,6 +124,19 @@ func WithHTTPServer(srv *http.Server) SSEOption {
120124 }
121125}
122126
127+ func WithKeepAliveInterval (keepAliveInterval time.Duration ) SSEOption {
128+ return func (s * SSEServer ) {
129+ s .keepAlive = true
130+ s .keepAliveInterval = keepAliveInterval
131+ }
132+ }
133+
134+ func WithKeepAlive (keepAlive bool ) SSEOption {
135+ return func (s * SSEServer ) {
136+ s .keepAlive = keepAlive
137+ }
138+ }
139+
123140// WithContextFunc sets a function that will be called to customise the context
124141// to the server using the incoming request.
125142func WithSSEContextFunc (fn SSEContextFunc ) SSEOption {
@@ -131,9 +148,11 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
131148// NewSSEServer creates a new SSE server instance with the given MCP server and options.
132149func NewSSEServer (server * MCPServer , opts ... SSEOption ) * SSEServer {
133150 s := & SSEServer {
134- server : server ,
135- sseEndpoint : "/sse" ,
136- messageEndpoint : "/message" ,
151+ server : server ,
152+ sseEndpoint : "/sse" ,
153+ messageEndpoint : "/message" ,
154+ keepAlive : false ,
155+ keepAliveInterval : 10 * time .Second ,
137156 }
138157
139158 // Apply all options
@@ -244,6 +263,24 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
244263 }
245264 }()
246265
266+ // Start keep alive : ping
267+ if s .keepAlive {
268+ go func () {
269+ ticker := time .NewTicker (s .keepAliveInterval )
270+ defer ticker .Stop ()
271+ for {
272+ select {
273+ case <- ticker .C :
274+ session .eventQueue <- fmt .Sprintf ("event: ping\n data: %s\n \n " , time .Now ().Format (time .RFC3339 ))
275+ case <- session .done :
276+ return
277+ case <- r .Context ().Done ():
278+ return
279+ }
280+ }
281+ }()
282+ }
283+
247284 messageEndpoint := fmt .Sprintf ("%s?sessionId=%s" , s .CompleteMessageEndpoint (), sessionID )
248285
249286 // Send the initial endpoint event
0 commit comments