diff --git a/server/http_transport_options.go b/server/http_transport_options.go new file mode 100644 index 000000000..91dd875dc --- /dev/null +++ b/server/http_transport_options.go @@ -0,0 +1,189 @@ +package server + +import ( + "context" + "net/http" + "net/url" + "strings" + "time" +) + +// HTTPContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context + +// httpTransportConfigurable is an internal interface for shared HTTP transport configuration. +type httpTransportConfigurable interface { + setBasePath(string) + setDynamicBasePath(DynamicBasePathFunc) + setKeepAliveInterval(time.Duration) + setKeepAlive(bool) + setContextFunc(HTTPContextFunc) + setHTTPServer(*http.Server) + setBaseURL(string) +} + +// HTTPTransportOption is a function that configures an httpTransportConfigurable. +type HTTPTransportOption func(httpTransportConfigurable) + +// Option interfaces and wrappers for server configuration +// Base option interface +type HTTPServerOption interface { + isHTTPServerOption() +} + +// SSE-specific option interface +type SSEOption interface { + HTTPServerOption + applyToSSE(*SSEServer) +} + +// StreamableHTTP-specific option interface +type StreamableHTTPOption interface { + HTTPServerOption + applyToStreamableHTTP(*StreamableHTTPServer) +} + +// Common options that work with both server types +type CommonHTTPServerOption interface { + SSEOption + StreamableHTTPOption +} + +// Wrapper for SSE-specific functional options +type sseOption func(*SSEServer) + +func (o sseOption) isHTTPServerOption() {} +func (o sseOption) applyToSSE(s *SSEServer) { o(s) } + +// Wrapper for StreamableHTTP-specific functional options +type streamableHTTPOption func(*StreamableHTTPServer) + +func (o streamableHTTPOption) isHTTPServerOption() {} +func (o streamableHTTPOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o(s) } + +// Refactor commonOption to use a single apply func(httpTransportConfigurable) +type commonOption struct { + apply func(httpTransportConfigurable) +} + +func (o commonOption) isHTTPServerOption() {} +func (o commonOption) applyToSSE(s *SSEServer) { o.apply(s) } +func (o commonOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o.apply(s) } + +// TODO: This is a stub implementation of StreamableHTTPServer just to show how +// to use it with the new options interfaces. +type StreamableHTTPServer struct{} + +// Add stub methods to satisfy httpTransportConfigurable + +func (s *StreamableHTTPServer) setBasePath(string) {} +func (s *StreamableHTTPServer) setDynamicBasePath(DynamicBasePathFunc) {} +func (s *StreamableHTTPServer) setKeepAliveInterval(time.Duration) {} +func (s *StreamableHTTPServer) setKeepAlive(bool) {} +func (s *StreamableHTTPServer) setContextFunc(HTTPContextFunc) {} +func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {} +func (s *StreamableHTTPServer) setBaseURL(baseURL string) {} + +// Ensure the option types implement the correct interfaces +var ( + _ httpTransportConfigurable = (*StreamableHTTPServer)(nil) + _ SSEOption = sseOption(nil) + _ StreamableHTTPOption = streamableHTTPOption(nil) + _ CommonHTTPServerOption = commonOption{} +) + +// WithStaticBasePath adds a new option for setting a static base path. +// This is useful for mounting the server at a known, fixed path. +func WithStaticBasePath(basePath string) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setBasePath(basePath) + }, + } +} + +// DynamicBasePathFunc allows the user to provide a function to generate the +// base path for a given request and sessionID. This is useful for cases where +// the base path is not known at the time of SSE server creation, such as when +// using a reverse proxy or when the base path is dynamically generated. The +// function should return the base path (e.g., "/mcp/tenant123"). +type DynamicBasePathFunc func(r *http.Request, sessionID string) string + +// WithDynamicBasePath accepts a function for generating the base path. +// This is useful for cases where the base path is not known at the time of server creation, +// such as when using a reverse proxy or when the server is mounted at a dynamic path. +func WithDynamicBasePath(fn DynamicBasePathFunc) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setDynamicBasePath(fn) + }, + } +} + +// WithKeepAliveInterval sets the keep-alive interval for the transport. +// When enabled, the server will periodically send ping events to keep the connection alive. +func WithKeepAliveInterval(interval time.Duration) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setKeepAliveInterval(interval) + }, + } +} + +// WithKeepAlive enables or disables keep-alive for the transport. +// When enabled, the server will send periodic keep-alive events to clients. +func WithKeepAlive(keepAlive bool) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setKeepAlive(keepAlive) + }, + } +} + +// WithHTTPContextFunc sets a function that will be called to customize the context +// for the server using the incoming request. This is useful for injecting +// context values from headers or other request properties. +func WithHTTPContextFunc(fn HTTPContextFunc) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setContextFunc(fn) + }, + } +} + +// WithBaseURL sets the base URL for the HTTP transport server. +// This is useful for configuring the externally visible base URL for clients. +func WithBaseURL(baseURL string) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + if baseURL != "" { + u, err := url.Parse(baseURL) + if err != nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + if u.Host == "" || strings.HasPrefix(u.Host, ":") { + return + } + if len(u.Query()) > 0 { + return + } + } + c.setBaseURL(strings.TrimSuffix(baseURL, "/")) + }, + } +} + +// WithHTTPServer sets the HTTP server instance for the transport. +// This is useful for advanced scenarios where you want to provide your own http.Server. +func WithHTTPServer(srv *http.Server) CommonHTTPServerOption { + return commonOption{ + apply: func(c httpTransportConfigurable) { + c.setHTTPServer(srv) + }, + } +} diff --git a/server/sse.go b/server/sse.go index 018657e6f..81e48d0d9 100644 --- a/server/sse.go +++ b/server/sse.go @@ -36,13 +36,6 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context -// DynamicBasePathFunc allows the user to provide a function to generate the -// base path for a given request and sessionID. This is useful for cases where -// the base path is not known at the time of SSE server creation, such as when -// using a reverse proxy or when the base path is dynamically generated. The -// function should return the base path (e.g., "/mcp/tenant123"). -type DynamicBasePathFunc func(r *http.Request, sessionID string) string - func (s *sseSession) SessionID() string { return s.sessionID } @@ -100,7 +93,7 @@ type SSEServer struct { sseEndpoint string sessions sync.Map srv *http.Server - contextFunc SSEContextFunc + contextFunc HTTPContextFunc dynamicBasePathFunc DynamicBasePathFunc keepAlive bool @@ -109,37 +102,41 @@ type SSEServer struct { mu sync.RWMutex } -// SSEOption defines a function type for configuring SSEServer -type SSEOption func(*SSEServer) +// Ensure SSEServer implements httpTransportConfigurable +var _ httpTransportConfigurable = (*SSEServer)(nil) -// WithBaseURL sets the base URL for the SSE server -func WithBaseURL(baseURL string) SSEOption { - return func(s *SSEServer) { - if baseURL != "" { - u, err := url.Parse(baseURL) - if err != nil { - return - } - if u.Scheme != "http" && u.Scheme != "https" { - return - } - // Check if the host is empty or only contains a port - if u.Host == "" || strings.HasPrefix(u.Host, ":") { - return - } - if len(u.Query()) > 0 { - return - } +func (s *SSEServer) setBasePath(basePath string) { + s.basePath = normalizeURLPath(basePath) +} + +func (s *SSEServer) setDynamicBasePath(fn DynamicBasePathFunc) { + if fn != nil { + s.dynamicBasePathFunc = func(r *http.Request, sid string) string { + bp := fn(r, sid) + return normalizeURLPath(bp) } - s.baseURL = strings.TrimSuffix(baseURL, "/") } } -// WithStaticBasePath adds a new option for setting a static base path -func WithStaticBasePath(basePath string) SSEOption { - return func(s *SSEServer) { - s.basePath = normalizeURLPath(basePath) - } +func (s *SSEServer) setKeepAliveInterval(interval time.Duration) { + s.keepAlive = true + s.keepAliveInterval = interval +} + +func (s *SSEServer) setKeepAlive(keepAlive bool) { + s.keepAlive = keepAlive +} + +func (s *SSEServer) setContextFunc(fn HTTPContextFunc) { + s.contextFunc = fn +} + +func (s *SSEServer) setHTTPServer(srv *http.Server) { + s.srv = srv +} + +func (s *SSEServer) setBaseURL(baseURL string) { + s.baseURL = baseURL } // WithBasePath adds a new option for setting a static base path. @@ -151,26 +148,11 @@ func WithBasePath(basePath string) SSEOption { return WithStaticBasePath(basePath) } -// WithDynamicBasePath accepts a function for generating the base path. This is -// useful for cases where the base path is not known at the time of SSE server -// creation, such as when using a reverse proxy or when the server is mounted -// at a dynamic path. -func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { - return func(s *SSEServer) { - if fn != nil { - s.dynamicBasePathFunc = func(r *http.Request, sid string) string { - bp := fn(r, sid) - return normalizeURLPath(bp) - } - } - } -} - // WithMessageEndpoint sets the message endpoint path func WithMessageEndpoint(endpoint string) SSEOption { - return func(s *SSEServer) { + return sseOption(func(s *SSEServer) { s.messageEndpoint = endpoint - } + }) } // WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's @@ -179,53 +161,37 @@ func WithMessageEndpoint(endpoint string) SSEOption { // SSE connection request and carry them over to subsequent message requests, maintaining // context or authentication details across the communication channel. func WithAppendQueryToMessageEndpoint() SSEOption { - return func(s *SSEServer) { + return sseOption(func(s *SSEServer) { s.appendQueryToMessageEndpoint = true - } + }) } // WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) // or just the path portion for the message endpoint. Set to false when clients will concatenate // the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { - return func(s *SSEServer) { + return sseOption(func(s *SSEServer) { s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint - } + }) } // WithSSEEndpoint sets the SSE endpoint path func WithSSEEndpoint(endpoint string) SSEOption { - return func(s *SSEServer) { + return sseOption(func(s *SSEServer) { s.sseEndpoint = endpoint - } -} - -// WithHTTPServer sets the HTTP server instance -func WithHTTPServer(srv *http.Server) SSEOption { - return func(s *SSEServer) { - s.srv = srv - } -} - -func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption { - return func(s *SSEServer) { - s.keepAlive = true - s.keepAliveInterval = keepAliveInterval - } -} - -func WithKeepAlive(keepAlive bool) SSEOption { - return func(s *SSEServer) { - s.keepAlive = keepAlive - } + }) } // WithSSEContextFunc sets a function that will be called to customise the context // to the server using the incoming request. +// +// Deprecated: Use WithContextFunc instead. This will be removed in a future version. +// +//go:deprecated func WithSSEContextFunc(fn SSEContextFunc) SSEOption { - return func(s *SSEServer) { - s.contextFunc = fn - } + return sseOption(func(s *SSEServer) { + WithHTTPContextFunc(HTTPContextFunc(fn)).applyToSSE(s) + }) } // NewSSEServer creates a new SSE server instance with the given MCP server and options. @@ -241,16 +207,15 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { // Apply all options for _, opt := range opts { - opt(s) + opt.applyToSSE(s) } return s } -// NewTestServer creates a test server for testing purposes +// NewTestServer creates a test server for testing purposes. func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { sseServer := NewSSEServer(server, opts...) - testServer := httptest.NewServer(sseServer) sseServer.baseURL = testServer.URL return testServer