diff --git a/server/streamable_http.go b/server/streamable_http.go index 24ec1c95a..c97d9b747 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -8,6 +8,7 @@ import ( "mime" "net/http" "net/http/httptest" + "os" "strings" "sync" "sync/atomic" @@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption { } } +// WithTLSCert sets the TLS certificate and key files for HTTPS support. +// Both certFile and keyFile must be provided to enable TLS. +func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.tlsCertFile = certFile + s.tlsKeyFile = keyFile + } +} + // StreamableHTTPServer implements a Streamable-http based MCP server. // It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http @@ -131,6 +141,9 @@ type StreamableHTTPServer struct { listenHeartbeatInterval time.Duration logger util.Logger sessionLogLevels *sessionLogLevelsStore + + tlsCertFile string + tlsKeyFile string } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error { srv := s.httpServer s.mu.Unlock() + if s.tlsCertFile != "" || s.tlsKeyFile != "" { + if s.tlsCertFile == "" || s.tlsKeyFile == "" { + return fmt.Errorf("both TLS cert and key must be provided") + } + if _, err := os.Stat(s.tlsCertFile); err != nil { + return fmt.Errorf("failed to find TLS certificate file: %w", err) + } + if _, err := os.Stat(s.tlsKeyFile); err != nil { + return fmt.Errorf("failed to find TLS key file: %w", err) + } + return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) + } + return srv.ListenAndServe() } @@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && (jsonMessage.Result != nil || jsonMessage.Error != nil) - + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize // Handle sampling responses separately @@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) - + // Register session for sampling response delivery s.activeSessions.Store(sessionID, session) defer s.activeSessions.Delete(sessionID) @@ -743,18 +769,18 @@ type streamableHttpSession struct { logLevels *sessionLogLevelsStore // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), } return s } @@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { // Generate unique request ID requestID := s.requestIDCounter.Add(1) - + // Create response channel for this specific request responseChan := make(chan samplingResponseItem, 1) - + // Create the sampling request item samplingRequest := samplingRequestItem{ requestID: requestID, request: request, response: responseChan, } - + // Store the pending request s.samplingRequests.Store(requestID, responseChan) defer s.samplingRequests.Delete(requestID) - + // Send the sampling request via the channel (non-blocking) select { case s.samplingRequestChan <- samplingRequest: @@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp default: return nil, fmt.Errorf("sampling request queue is full - server overloaded") } - + // Wait for response or context cancellation select { case response := <-responseChan: diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 105fd18ce..b464e1bdd 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -894,6 +894,26 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { } } +func TestStreamableHTTPServer_TLS(t *testing.T) { + t.Run("TLS options are set correctly", func(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0.0") + certFile := "/path/to/cert.pem" + keyFile := "/path/to/key.pem" + + server := NewStreamableHTTPServer( + mcpServer, + WithTLSCert(certFile, keyFile), + ) + + if server.tlsCertFile != certFile { + t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile) + } + if server.tlsKeyFile != keyFile { + t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile) + } + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) diff --git a/www/docs/pages/servers/basics.mdx b/www/docs/pages/servers/basics.mdx index 7b33f33ce..b83bbdfab 100644 --- a/www/docs/pages/servers/basics.mdx +++ b/www/docs/pages/servers/basics.mdx @@ -182,6 +182,7 @@ Configure transport-specific options: httpServer := server.NewStreamableHTTPServer(s, server.WithEndpointPath("/mcp"), server.WithStateless(true), + server.WithTLSCert("/path/to/cert.pem", "/path/to/key.pem"), ) if err := httpServer.Start(":8080"); err != nil {