Skip to content

Commit 0596c88

Browse files
committed
feat(srv/stream): Require Protocol Version Header in HTTP
1 parent 0fdb197 commit 0596c88

File tree

3 files changed

+254
-20
lines changed

3 files changed

+254
-20
lines changed

mcp/types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ type JSONRPCMessage any
9898
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
9999
const LATEST_PROTOCOL_VERSION = "2025-03-26"
100100

101+
// The default negotiated version of the Model Context Protocol when no version is specified
102+
// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header
103+
const DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
104+
101105
// ValidProtocolVersions lists all known valid MCP protocol versions.
102106
var ValidProtocolVersions = []string{
103107
"2024-11-05",

server/streamable_http.go

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"net/http/httptest"
10+
"slices"
1011
"strings"
1112
"sync"
1213
"sync/atomic"
@@ -205,7 +206,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
205206
// --- internal methods ---
206207

207208
const (
208-
headerKeySessionID = "Mcp-Session-Id"
209+
headerKeySessionID = "Mcp-Session-Id"
210+
headerKeyProtocolVersion = "MCP-Protocol-Version"
209211
)
210212

211213
func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
@@ -240,19 +242,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
240242
if isInitializeRequest {
241243
// generate a new one for initialize request
242244
sessionID = s.sessionIdManager.Generate()
243-
} else {
244-
// Get session ID from header.
245-
// Stateful servers need the client to carry the session ID.
246-
sessionID = r.Header.Get(headerKeySessionID)
247-
isTerminated, err := s.sessionIdManager.Validate(sessionID)
248-
if err != nil {
249-
http.Error(w, "Invalid session ID", http.StatusBadRequest)
250-
return
251-
}
252-
if isTerminated {
253-
http.Error(w, "Session terminated", http.StatusNotFound)
254-
return
255-
}
245+
} else if !s.validateRequestHeaders(w, r) {
246+
return
256247
}
257248

258249
session := newStreamableHttpSession(sessionID, s.sessionTools)
@@ -355,10 +346,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
355346
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
356347
// get request is for listening to notifications
357348
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
349+
if !s.validateRequestHeaders(w, r) {
350+
return
351+
}
358352

359353
sessionID := r.Header.Get(headerKeySessionID)
360-
// the specification didn't say we should validate the session id
361-
362354
if sessionID == "" {
363355
// It's a stateless server,
364356
// but the MCP server requires a unique ID for registering, so we use a random one
@@ -454,6 +446,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
454446
}
455447

456448
func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
449+
if !s.validateRequestHeaders(w, r) {
450+
return
451+
}
452+
457453
// delete request terminate the session
458454
sessionID := r.Header.Get(headerKeySessionID)
459455
notAllowed, err := s.sessionIdManager.Terminate(sessionID)
@@ -510,6 +506,55 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
510506
return counter.Add(1)
511507
}
512508

509+
func (s *StreamableHTTPServer) validateRequestHeaders(w http.ResponseWriter, r *http.Request) bool {
510+
if !s.validateSession(w, r) {
511+
return false
512+
}
513+
if !s.validateProtocolVersion(w, r) {
514+
return false
515+
}
516+
return true
517+
}
518+
519+
// validateSession validates the validates the session ID in the request.
520+
func (s *StreamableHTTPServer) validateSession(w http.ResponseWriter, r *http.Request) bool {
521+
// Get session ID from header.
522+
// Stateful servers need the client to carry the session ID.
523+
sessionID := r.Header.Get(headerKeySessionID)
524+
isTerminated, err := s.sessionIdManager.Validate(sessionID)
525+
if err != nil {
526+
http.Error(w, "Invalid session ID", http.StatusBadRequest)
527+
return false
528+
}
529+
if isTerminated {
530+
http.Error(w, "Session terminated", http.StatusNotFound)
531+
return false
532+
}
533+
return true
534+
}
535+
536+
// validateProtocolVersion validates the protocol version header in the request.
537+
func (s *StreamableHTTPServer) validateProtocolVersion(w http.ResponseWriter, r *http.Request) bool {
538+
protocolVersion := r.Header.Get(headerKeyProtocolVersion)
539+
540+
// If no protocol version provided, assume default version
541+
if protocolVersion == "" {
542+
protocolVersion = mcp.DEFAULT_NEGOTIATED_VERSION
543+
}
544+
545+
// Check if the protocol version is supported
546+
if !slices.Contains(mcp.ValidProtocolVersions, protocolVersion) {
547+
supportedVersion := strings.Join(mcp.ValidProtocolVersions, ",")
548+
http.Error(
549+
w,
550+
fmt.Sprintf("Unsupported protocol version: %s. Supported versions: %s", protocolVersion, supportedVersion),
551+
http.StatusBadRequest,
552+
)
553+
return false
554+
}
555+
return true
556+
}
557+
513558
// --- session ---
514559

515560
type sessionToolsStore struct {

server/streamable_http_test.go

Lines changed: 189 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ type jsonRPCResponse struct {
2323
Error *mcp.JSONRPCError `json:"error"`
2424
}
2525

26+
const clientInitializedProtocolVersion = "2025-03-26"
27+
2628
var initRequest = map[string]any{
2729
"jsonrpc": "2.0",
2830
"id": 1,
2931
"method": "initialize",
3032
"params": map[string]any{
31-
"protocolVersion": "2025-03-26",
33+
"protocolVersion": clientInitializedProtocolVersion,
3234
"clientInfo": map[string]any{
3335
"name": "test-client",
3436
"version": "1.0.0",
@@ -505,7 +507,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) {
505507
func TestStreamableHTTP_GET(t *testing.T) {
506508
mcpServer := NewMCPServer("test-mcp-server", "1.0")
507509
addSSETool(mcpServer)
508-
server := NewTestStreamableHTTPServer(mcpServer)
510+
server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))
509511

510512
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
511513
defer cancel()
@@ -612,7 +614,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
612614
})
613615

614616
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
615-
testServer := NewTestStreamableHTTPServer(mcpServer)
617+
testServer := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))
616618
defer testServer.Close()
617619

618620
// send initialize request to trigger the session registration
@@ -625,19 +627,36 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
625627
// Watch the notification to ensure the session is registered
626628
// (Normal http request (post) will not trigger the session registration)
627629
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
630+
errChan := make(chan string, 1)
628631
defer cancel()
629632
go func() {
630633
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil)
631634
req.Header.Set("Content-Type", "text/event-stream")
632635
getResp, err := http.DefaultClient.Do(req)
633636
if err != nil {
634-
fmt.Printf("Failed to get: %v\n", err)
637+
errMsg := fmt.Sprintf("Failed to get: %v\n", err)
638+
errChan <- errMsg
635639
return
636640
}
637641
defer getResp.Body.Close()
642+
if getResp.StatusCode != http.StatusOK {
643+
bodyBytes, _ := io.ReadAll(getResp.Body)
644+
errMsg := fmt.Sprintf("Expected status 200, got %d, Response:%s", getResp.StatusCode, string(bodyBytes))
645+
errChan <- errMsg
646+
return
647+
}
648+
close(errChan)
638649
}()
639650

640651
// Verify we got a session
652+
select {
653+
case <-ctx.Done():
654+
t.Fatal("Timeout waiting for GET request to complete")
655+
case errMsg := <-errChan:
656+
if errMsg != "" {
657+
t.Fatal(errMsg)
658+
}
659+
}
641660
sessionRegistered.Wait()
642661
mu.Lock()
643662
if registeredSession == nil {
@@ -775,6 +794,172 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) {
775794
})
776795
}
777796

797+
func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) {
798+
// If using HTTP, the client MUST include the MCP-Protocol-Version: <protocol-version> HTTP header on all subsequent requests to the MCP server
799+
// The protocol version sent by the client SHOULD be the one negotiated during initialization.
800+
// If the server receives a request with an invalid or unsupported MCP-Protocol-Version, it MUST respond with 400 Bad Request.
801+
// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header
802+
mcpServer := NewMCPServer("test-mcp-server", "1.0")
803+
server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))
804+
805+
// Send initialize request
806+
resp, err := postJSON(server.URL, initRequest)
807+
if err != nil {
808+
t.Fatalf("Failed to send message: %v", err)
809+
}
810+
defer resp.Body.Close()
811+
if resp.StatusCode != http.StatusOK {
812+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
813+
}
814+
bodyBytes, _ := io.ReadAll(resp.Body)
815+
var responseMessage jsonRPCResponse
816+
if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil {
817+
t.Fatalf("Failed to unmarshal response: %v", err)
818+
}
819+
if responseMessage.Result["protocolVersion"] != clientInitializedProtocolVersion {
820+
t.Errorf("Expected protocol version %s, got %s", clientInitializedProtocolVersion, responseMessage.Result["protocolVersion"])
821+
}
822+
// no session id from header
823+
sessionID := resp.Header.Get(headerKeySessionID)
824+
if sessionID != "" {
825+
t.Fatalf("Expected no session id in header, got %s", sessionID)
826+
}
827+
828+
t.Run("POST Request", func(t *testing.T) {
829+
// send notification
830+
notification := mcp.JSONRPCNotification{
831+
JSONRPC: "2.0",
832+
Notification: mcp.Notification{
833+
Method: "testNotification",
834+
Params: mcp.NotificationParams{
835+
AdditionalFields: map[string]interface{}{"param1": "value1"},
836+
},
837+
},
838+
}
839+
rawNotification, _ := json.Marshal(notification)
840+
841+
// Request with protocol version to be the one negotiated during initialization.
842+
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification))
843+
req.Header.Set("Content-Type", "application/json")
844+
req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
845+
resp, err := server.Client().Do(req)
846+
if err != nil {
847+
t.Fatalf("Failed to send message: %v", err)
848+
}
849+
if resp.StatusCode != http.StatusAccepted {
850+
t.Errorf("Expected status 202, got %d", resp.StatusCode)
851+
}
852+
bodyBytes, _ := io.ReadAll(resp.Body)
853+
if len(bodyBytes) > 0 {
854+
t.Errorf("Expected empty body, got %s", string(bodyBytes))
855+
}
856+
resp.Body.Close()
857+
858+
// Request with protocol version not to be the one negotiated during initialization but supported by server.
859+
req.Header.Set(headerKeyProtocolVersion, "2024-11-05")
860+
resp, err = server.Client().Do(req)
861+
if err != nil {
862+
t.Fatalf("Failed to send message: %v", err)
863+
}
864+
if resp.StatusCode != http.StatusAccepted {
865+
t.Errorf("Expected status 202, got %d", resp.StatusCode)
866+
}
867+
bodyBytes, _ = io.ReadAll(resp.Body)
868+
if len(bodyBytes) > 0 {
869+
t.Errorf("Expected empty body, got %s", string(bodyBytes))
870+
}
871+
resp.Body.Close()
872+
873+
// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
874+
req.Header.Set(headerKeyProtocolVersion, "2024-06-18")
875+
resp, err = server.Client().Do(req)
876+
if err != nil {
877+
t.Fatalf("Failed to send message: %v", err)
878+
}
879+
if resp.StatusCode != http.StatusBadRequest {
880+
t.Errorf("Expected status 400, got %d", resp.StatusCode)
881+
}
882+
resp.Body.Close()
883+
})
884+
885+
t.Run("GET Request", func(t *testing.T) {
886+
// Request with protocol version to be the one negotiated during initialization.
887+
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
888+
req.Header.Set("Content-Type", "text/event-stream")
889+
req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
890+
resp, err := http.DefaultClient.Do(req)
891+
if err != nil {
892+
t.Fatalf("Failed to send message: %v\n", err)
893+
}
894+
if resp.StatusCode != http.StatusOK {
895+
bodyBytes, _ := io.ReadAll(resp.Body)
896+
t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes))
897+
}
898+
resp.Body.Close()
899+
900+
// Request with protocol version not to be the one negotiated during initialization but supported by server.
901+
req.Header.Set(headerKeyProtocolVersion, "2024-11-05")
902+
resp, err = http.DefaultClient.Do(req)
903+
if err != nil {
904+
t.Fatalf("Failed to send message: %v\n", err)
905+
}
906+
if resp.StatusCode != http.StatusOK {
907+
bodyBytes, _ := io.ReadAll(resp.Body)
908+
t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes))
909+
}
910+
resp.Body.Close()
911+
912+
// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
913+
req.Header.Set(headerKeyProtocolVersion, "2024-06-18")
914+
resp, err = server.Client().Do(req)
915+
if err != nil {
916+
t.Fatalf("Failed to send message: %v", err)
917+
}
918+
if resp.StatusCode != http.StatusBadRequest {
919+
t.Errorf("Expected status 400, got %d", resp.StatusCode)
920+
}
921+
resp.Body.Close()
922+
})
923+
924+
t.Run("DELETE Request", func(t *testing.T) {
925+
// Request with protocol version to be the one negotiated during initialization.
926+
req, _ := http.NewRequest(http.MethodDelete, server.URL, nil)
927+
req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
928+
resp, err := http.DefaultClient.Do(req)
929+
if err != nil {
930+
t.Fatalf("Failed to send message: %v\n", err)
931+
}
932+
if resp.StatusCode != http.StatusOK {
933+
bodyBytes, _ := io.ReadAll(resp.Body)
934+
t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes))
935+
}
936+
resp.Body.Close()
937+
938+
// Request with protocol version not to be the one negotiated during initialization but supported by server.
939+
req.Header.Set(headerKeyProtocolVersion, "2024-11-05")
940+
resp, err = http.DefaultClient.Do(req)
941+
if err != nil {
942+
t.Fatalf("Failed to send message: %v\n", err)
943+
}
944+
if resp.StatusCode != http.StatusOK {
945+
bodyBytes, _ := io.ReadAll(resp.Body)
946+
t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes))
947+
}
948+
resp.Body.Close()
949+
950+
// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
951+
req.Header.Set(headerKeyProtocolVersion, "2024-06-18")
952+
resp, err = server.Client().Do(req)
953+
if err != nil {
954+
t.Fatalf("Failed to send message: %v", err)
955+
}
956+
if resp.StatusCode != http.StatusBadRequest {
957+
t.Errorf("Expected status 400, got %d", resp.StatusCode)
958+
}
959+
resp.Body.Close()
960+
})
961+
}
962+
778963
func postJSON(url string, bodyObject any) (*http.Response, error) {
779964
jsonBody, _ := json.Marshal(bodyObject)
780965
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

0 commit comments

Comments
 (0)