88 "mime"
99 "net/http"
1010 "net/http/httptest"
11+ "os"
1112 "strings"
1213 "sync"
1314 "sync/atomic"
@@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption {
9394 }
9495}
9596
97+ // WithTLSCert sets the TLS certificate and key files for HTTPS support.
98+ // Both certFile and keyFile must be provided to enable TLS.
99+ func WithTLSCert (certFile , keyFile string ) StreamableHTTPOption {
100+ return func (s * StreamableHTTPServer ) {
101+ s .tlsCertFile = certFile
102+ s .tlsKeyFile = keyFile
103+ }
104+ }
105+
96106// StreamableHTTPServer implements a Streamable-http based MCP server.
97107// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
98108// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
@@ -131,6 +141,9 @@ type StreamableHTTPServer struct {
131141 listenHeartbeatInterval time.Duration
132142 logger util.Logger
133143 sessionLogLevels * sessionLogLevelsStore
144+
145+ tlsCertFile string
146+ tlsKeyFile string
134147}
135148
136149// NewStreamableHTTPServer creates a new streamable-http server instance
@@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error {
188201 srv := s .httpServer
189202 s .mu .Unlock ()
190203
204+ if s .tlsCertFile != "" || s .tlsKeyFile != "" {
205+ if s .tlsCertFile == "" || s .tlsKeyFile == "" {
206+ return fmt .Errorf ("both TLS cert and key must be provided" )
207+ }
208+ if _ , err := os .Stat (s .tlsCertFile ); err != nil {
209+ return fmt .Errorf ("failed to find TLS certificate file: %w" , err )
210+ }
211+ if _ , err := os .Stat (s .tlsKeyFile ); err != nil {
212+ return fmt .Errorf ("failed to find TLS key file: %w" , err )
213+ }
214+ return srv .ListenAndServeTLS (s .tlsCertFile , s .tlsKeyFile )
215+ }
216+
191217 return srv .ListenAndServe ()
192218}
193219
@@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
237263 }
238264
239265 // Check if this is a sampling response (has result/error but no method)
240- isSamplingResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
266+ isSamplingResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
241267 (jsonMessage .Result != nil || jsonMessage .Error != nil )
242-
268+
243269 isInitializeRequest := jsonMessage .Method == mcp .MethodInitialize
244270
245271 // Handle sampling responses separately
@@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
390416 return
391417 }
392418 defer s .server .UnregisterSession (r .Context (), sessionID )
393-
419+
394420 // Register session for sampling response delivery
395421 s .activeSessions .Store (sessionID , session )
396422 defer s .activeSessions .Delete (sessionID )
@@ -743,18 +769,18 @@ type streamableHttpSession struct {
743769 logLevels * sessionLogLevelsStore
744770
745771 // Sampling support for bidirectional communication
746- samplingRequestChan chan samplingRequestItem // server -> client sampling requests
747- samplingRequests sync.Map // requestID -> pending sampling request context
748- requestIDCounter atomic.Int64 // for generating unique request IDs
772+ samplingRequestChan chan samplingRequestItem // server -> client sampling requests
773+ samplingRequests sync.Map // requestID -> pending sampling request context
774+ requestIDCounter atomic.Int64 // for generating unique request IDs
749775}
750776
751777func newStreamableHttpSession (sessionID string , toolStore * sessionToolsStore , levels * sessionLogLevelsStore ) * streamableHttpSession {
752778 s := & streamableHttpSession {
753- sessionID : sessionID ,
754- notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
755- tools : toolStore ,
756- logLevels : levels ,
757- samplingRequestChan : make (chan samplingRequestItem , 10 ),
779+ sessionID : sessionID ,
780+ notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
781+ tools : toolStore ,
782+ logLevels : levels ,
783+ samplingRequestChan : make (chan samplingRequestItem , 10 ),
758784 }
759785 return s
760786}
@@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
810836func (s * streamableHttpSession ) RequestSampling (ctx context.Context , request mcp.CreateMessageRequest ) (* mcp.CreateMessageResult , error ) {
811837 // Generate unique request ID
812838 requestID := s .requestIDCounter .Add (1 )
813-
839+
814840 // Create response channel for this specific request
815841 responseChan := make (chan samplingResponseItem , 1 )
816-
842+
817843 // Create the sampling request item
818844 samplingRequest := samplingRequestItem {
819845 requestID : requestID ,
820846 request : request ,
821847 response : responseChan ,
822848 }
823-
849+
824850 // Store the pending request
825851 s .samplingRequests .Store (requestID , responseChan )
826852 defer s .samplingRequests .Delete (requestID )
827-
853+
828854 // Send the sampling request via the channel (non-blocking)
829855 select {
830856 case s .samplingRequestChan <- samplingRequest :
@@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
834860 default :
835861 return nil , fmt .Errorf ("sampling request queue is full - server overloaded" )
836862 }
837-
863+
838864 // Wait for response or context cancellation
839865 select {
840866 case response := <- responseChan :
0 commit comments