diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 27f9caed..0507dd60 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -65,6 +65,6 @@ func main() { default: return nil } - }) + }, nil) log.Fatal(http.ListenAndServe(addr, handler)) } diff --git a/mcp/sse.go b/mcp/sse.go index f39a0397..7f644918 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -43,12 +43,18 @@ import ( // [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { getServer func(request *http.Request) *Server + opts SSEOptions onConnection func(*ServerSession) // for testing; must not block mu sync.Mutex sessions map[string]*SSEServerTransport } +// SSEOptions specifies options for an [SSEHandler]. +// for now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP // sessions created via incoming HTTP requests. // @@ -62,13 +68,17 @@ type SSEHandler struct { // The getServer function may return a distinct [Server] for each new // request, or reuse an existing server. If it returns nil, the handler // will return a 400 Bad Request. -// -// TODO(rfindley): add options. -func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { - return &SSEHandler{ +func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ getServer: getServer, sessions: make(map[string]*SSEServerTransport), } + + if opts != nil { + s.opts = *opts + } + + return s } // A SSEServerTransport is a logical SSE session created through a hanging GET diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index d06ea62b..6132d31e 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -31,7 +31,7 @@ func ExampleSSEHandler() { server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) - handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) + handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }, nil) httpServer := httptest.NewServer(handler) defer httpServer.Close() diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 32a20bf3..25435ff3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -24,7 +24,7 @@ func TestSSEServer(t *testing.T) { server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet"}, sayHi) - sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) serverSessions := make(chan *ServerSession, 1) sseHandler.onConnection = func(ss *ServerSession) {