Skip to content

Commit 40ce109

Browse files
authored
feat: add tls support for streamable-http (#568)
* tests tls * log * fail * clean * docs * nit * nit
1 parent ef80a50 commit 40ce109

File tree

3 files changed

+63
-16
lines changed

3 files changed

+63
-16
lines changed

server/streamable_http.go

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"mime"
99
"net/http"
1010
"net/http/httptest"
11+
"os"
1112
"strings"
1213
"sync"
1314
"sync/atomic"
@@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
9394
}
9495
}
9596

97+
// WithTLSCert sets the TLS certificate and key files for HTTPS support.
98+
// Both certFile and keyFile must be provided to enable TLS.
99+
func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
100+
return func(s *StreamableHTTPServer) {
101+
s.tlsCertFile = certFile
102+
s.tlsKeyFile = keyFile
103+
}
104+
}
105+
96106
// StreamableHTTPServer implements a Streamable-http based MCP server.
97107
// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
98108
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
@@ -131,6 +141,9 @@ type StreamableHTTPServer struct {
131141
listenHeartbeatInterval time.Duration
132142
logger util.Logger
133143
sessionLogLevels *sessionLogLevelsStore
144+
145+
tlsCertFile string
146+
tlsKeyFile string
134147
}
135148

136149
// NewStreamableHTTPServer creates a new streamable-http server instance
@@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error {
188201
srv := s.httpServer
189202
s.mu.Unlock()
190203

204+
if s.tlsCertFile != "" || s.tlsKeyFile != "" {
205+
if s.tlsCertFile == "" || s.tlsKeyFile == "" {
206+
return fmt.Errorf("both TLS cert and key must be provided")
207+
}
208+
if _, err := os.Stat(s.tlsCertFile); err != nil {
209+
return fmt.Errorf("failed to find TLS certificate file: %w", err)
210+
}
211+
if _, err := os.Stat(s.tlsKeyFile); err != nil {
212+
return fmt.Errorf("failed to find TLS key file: %w", err)
213+
}
214+
return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile)
215+
}
216+
191217
return srv.ListenAndServe()
192218
}
193219

@@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
237263
}
238264

239265
// Check if this is a sampling response (has result/error but no method)
240-
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
266+
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
241267
(jsonMessage.Result != nil || jsonMessage.Error != nil)
242-
268+
243269
isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize
244270

245271
// Handle sampling responses separately
@@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
390416
return
391417
}
392418
defer s.server.UnregisterSession(r.Context(), sessionID)
393-
419+
394420
// Register session for sampling response delivery
395421
s.activeSessions.Store(sessionID, session)
396422
defer s.activeSessions.Delete(sessionID)
@@ -743,18 +769,18 @@ type streamableHttpSession struct {
743769
logLevels *sessionLogLevelsStore
744770

745771
// Sampling support for bidirectional communication
746-
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
747-
samplingRequests sync.Map // requestID -> pending sampling request context
748-
requestIDCounter atomic.Int64 // for generating unique request IDs
772+
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
773+
samplingRequests sync.Map // requestID -> pending sampling request context
774+
requestIDCounter atomic.Int64 // for generating unique request IDs
749775
}
750776

751777
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
752778
s := &streamableHttpSession{
753-
sessionID: sessionID,
754-
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
755-
tools: toolStore,
756-
logLevels: levels,
757-
samplingRequestChan: make(chan samplingRequestItem, 10),
779+
sessionID: sessionID,
780+
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
781+
tools: toolStore,
782+
logLevels: levels,
783+
samplingRequestChan: make(chan samplingRequestItem, 10),
758784
}
759785
return s
760786
}
@@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
810836
func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
811837
// Generate unique request ID
812838
requestID := s.requestIDCounter.Add(1)
813-
839+
814840
// Create response channel for this specific request
815841
responseChan := make(chan samplingResponseItem, 1)
816-
842+
817843
// Create the sampling request item
818844
samplingRequest := samplingRequestItem{
819845
requestID: requestID,
820846
request: request,
821847
response: responseChan,
822848
}
823-
849+
824850
// Store the pending request
825851
s.samplingRequests.Store(requestID, responseChan)
826852
defer s.samplingRequests.Delete(requestID)
827-
853+
828854
// Send the sampling request via the channel (non-blocking)
829855
select {
830856
case s.samplingRequestChan <- samplingRequest:
@@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
834860
default:
835861
return nil, fmt.Errorf("sampling request queue is full - server overloaded")
836862
}
837-
863+
838864
// Wait for response or context cancellation
839865
select {
840866
case response := <-responseChan:

server/streamable_http_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,26 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) {
894894
}
895895
}
896896

897+
func TestStreamableHTTPServer_TLS(t *testing.T) {
898+
t.Run("TLS options are set correctly", func(t *testing.T) {
899+
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
900+
certFile := "/path/to/cert.pem"
901+
keyFile := "/path/to/key.pem"
902+
903+
server := NewStreamableHTTPServer(
904+
mcpServer,
905+
WithTLSCert(certFile, keyFile),
906+
)
907+
908+
if server.tlsCertFile != certFile {
909+
t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile)
910+
}
911+
if server.tlsKeyFile != keyFile {
912+
t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile)
913+
}
914+
})
915+
}
916+
897917
func postJSON(url string, bodyObject any) (*http.Response, error) {
898918
jsonBody, _ := json.Marshal(bodyObject)
899919
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

www/docs/pages/servers/basics.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ Configure transport-specific options:
182182
httpServer := server.NewStreamableHTTPServer(s,
183183
server.WithEndpointPath("/mcp"),
184184
server.WithStateless(true),
185+
server.WithTLSCert("/path/to/cert.pem", "/path/to/key.pem"),
185186
)
186187

187188
if err := httpServer.Start(":8080"); err != nil {

0 commit comments

Comments
 (0)