Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"mime"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
}
}

// WithTLSCert sets the TLS certificate and key files for HTTPS support.
// Both certFile and keyFile must be provided to enable TLS.
func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
s.tlsCertFile = certFile
s.tlsKeyFile = keyFile
}
}

// StreamableHTTPServer implements a Streamable-http based MCP server.
// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
Expand Down Expand Up @@ -131,6 +141,9 @@ type StreamableHTTPServer struct {
listenHeartbeatInterval time.Duration
logger util.Logger
sessionLogLevels *sessionLogLevelsStore

tlsCertFile string
tlsKeyFile string
}

// NewStreamableHTTPServer creates a new streamable-http server instance
Expand Down Expand Up @@ -188,6 +201,10 @@ func (s *StreamableHTTPServer) Start(addr string) error {
srv := s.httpServer
s.mu.Unlock()

if s.canUseTLS() {
return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile)
}

return srv.ListenAndServe()
}

Expand Down Expand Up @@ -237,9 +254,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
}

// Check if this is a sampling response (has result/error but no method)
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
(jsonMessage.Result != nil || jsonMessage.Error != nil)

isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize

// Handle sampling responses separately
Expand Down Expand Up @@ -390,7 +407,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
return
}
defer s.server.UnregisterSession(r.Context(), sessionID)

// Register session for sampling response delivery
s.activeSessions.Store(sessionID, session)
defer s.activeSessions.Delete(sessionID)
Expand Down Expand Up @@ -656,6 +673,28 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
return counter.Add(1)
}

// canUseTLS checks if TLS is properly configured and files are valid
func (s *StreamableHTTPServer) canUseTLS() bool {
// Not configured
if s.tlsCertFile == "" || s.tlsKeyFile == "" {
return false
}

// Check certificate file
if _, err := os.Stat(s.tlsCertFile); err != nil {
s.logger.Errorf("TLS certificate file error: %v", err)
return false
}

// Check key file
if _, err := os.Stat(s.tlsKeyFile); err != nil {
s.logger.Errorf("TLS key file error: %v", err)
return false
}

return true
}

// --- session ---
type sessionLogLevelsStore struct {
mu sync.RWMutex
Expand Down Expand Up @@ -743,18 +782,18 @@ type streamableHttpSession struct {
logLevels *sessionLogLevelsStore

// Sampling support for bidirectional communication
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
samplingRequests sync.Map // requestID -> pending sampling request context
requestIDCounter atomic.Int64 // for generating unique request IDs
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
samplingRequests sync.Map // requestID -> pending sampling request context
requestIDCounter atomic.Int64 // for generating unique request IDs
}

func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
s := &streamableHttpSession{
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
logLevels: levels,
samplingRequestChan: make(chan samplingRequestItem, 10),
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
logLevels: levels,
samplingRequestChan: make(chan samplingRequestItem, 10),
}
return s
}
Expand Down Expand Up @@ -810,21 +849,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
// Generate unique request ID
requestID := s.requestIDCounter.Add(1)

// Create response channel for this specific request
responseChan := make(chan samplingResponseItem, 1)

// Create the sampling request item
samplingRequest := samplingRequestItem{
requestID: requestID,
request: request,
response: responseChan,
}

// Store the pending request
s.samplingRequests.Store(requestID, responseChan)
defer s.samplingRequests.Delete(requestID)

// Send the sampling request via the channel (non-blocking)
select {
case s.samplingRequestChan <- samplingRequest:
Expand All @@ -834,7 +873,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
default:
return nil, fmt.Errorf("sampling request queue is full - server overloaded")
}

// Wait for response or context cancellation
select {
case response := <-responseChan:
Expand Down
131 changes: 131 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -894,6 +895,136 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) {
}
}

func TestStreamableHTTPServer_TLS(t *testing.T) {
t.Run("TLS options are set correctly", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
certFile := "/path/to/cert.pem"
keyFile := "/path/to/key.pem"

server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert(certFile, keyFile),
)

if server.tlsCertFile != certFile {
t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile)
}
if server.tlsKeyFile != keyFile {
t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile)
}
})

t.Run("canUseTLS returns false when TLS not configured", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(mcpServer)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when TLS is not configured")
}
})

t.Run("canUseTLS returns false when only cert file is set", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert("/path/to/cert.pem", ""),
)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when only cert file is set")
}
})

t.Run("canUseTLS returns false when only key file is set", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert("", "/path/to/key.pem"),
)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when only key file is set")
}
})

t.Run("canUseTLS returns false when cert file doesn't exist", func(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert("/nonexistent/cert.pem", "/nonexistent/key.pem"),
)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when cert file doesn't exist")
}
})

t.Run("canUseTLS returns true with valid temp cert and key files", func(t *testing.T) {
// Create temporary cert and key files for testing
certFile, err := os.CreateTemp("", "test-cert-*.pem")
if err != nil {
t.Fatalf("Failed to create temp cert file: %v", err)
}
defer os.Remove(certFile.Name())
certFile.Close()

keyFile, err := os.CreateTemp("", "test-key-*.pem")
if err != nil {
t.Fatalf("Failed to create temp key file: %v", err)
}
defer os.Remove(keyFile.Name())
keyFile.Close()

mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert(certFile.Name(), keyFile.Name()),
)

if !server.canUseTLS() {
t.Error("Expected canUseTLS to return true with valid cert and key files")
}
})

t.Run("canUseTLS returns false when key file exists but cert doesn't", func(t *testing.T) {
keyFile, err := os.CreateTemp("", "test-key-*.pem")
if err != nil {
t.Fatalf("Failed to create temp key file: %v", err)
}
defer os.Remove(keyFile.Name())
keyFile.Close()

mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert("/nonexistent/cert.pem", keyFile.Name()),
)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when cert file doesn't exist")
}
})

t.Run("canUseTLS returns false when cert file exists but key doesn't", func(t *testing.T) {
certFile, err := os.CreateTemp("", "test-cert-*.pem")
if err != nil {
t.Fatalf("Failed to create temp cert file: %v", err)
}
defer os.Remove(certFile.Name())
certFile.Close()

mcpServer := NewMCPServer("test-mcp-server", "1.0.0")
server := NewStreamableHTTPServer(
mcpServer,
WithTLSCert(certFile.Name(), "/nonexistent/key.pem"),
)

if server.canUseTLS() {
t.Error("Expected canUseTLS to return false when key file doesn't exist")
}
})
}

func postJSON(url string, bodyObject any) (*http.Response, error) {
jsonBody, _ := json.Marshal(bodyObject)
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
Expand Down