diff --git a/mcp/types.go b/mcp/types.go index d4f6132c8..f039bbaa3 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -98,6 +98,10 @@ type JSONRPCMessage any // LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. const LATEST_PROTOCOL_VERSION = "2025-03-26" +// The default negotiated version of the Model Context Protocol when no version is specified +// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header +const DEFAULT_NEGOTIATED_VERSION = "2025-03-26" + // ValidProtocolVersions lists all known valid MCP protocol versions. var ValidProtocolVersions = []string{ "2024-11-05", diff --git a/server/streamable_http.go b/server/streamable_http.go index 1312c9753..f7bed1daa 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "slices" "strings" "sync" "sync/atomic" @@ -205,7 +206,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { // --- internal methods --- const ( - headerKeySessionID = "Mcp-Session-Id" + headerKeySessionID = "Mcp-Session-Id" + headerKeyProtocolVersion = "MCP-Protocol-Version" ) func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { @@ -240,19 +242,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request if isInitializeRequest { // generate a new one for initialize request sessionID = s.sessionIdManager.Generate() - } else { - // Get session ID from header. - // Stateful servers need the client to carry the session ID. - sessionID = r.Header.Get(headerKeySessionID) - isTerminated, err := s.sessionIdManager.Validate(sessionID) - if err != nil { - http.Error(w, "Invalid session ID", http.StatusBadRequest) - return - } - if isTerminated { - http.Error(w, "Session terminated", http.StatusNotFound) - return - } + } else if !s.validateRequestHeaders(w, r) { + return } session := newStreamableHttpSession(sessionID, s.sessionTools) @@ -355,10 +346,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { // get request is for listening to notifications // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + if !s.validateRequestHeaders(w, r) { + return + } sessionID := r.Header.Get(headerKeySessionID) - // the specification didn't say we should validate the session id - if sessionID == "" { // It's a stateless server, // but the MCP server requires a unique ID for registering, so we use a random one @@ -454,6 +446,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) } func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + if !s.validateRequestHeaders(w, r) { + return + } + // delete request terminate the session sessionID := r.Header.Get(headerKeySessionID) notAllowed, err := s.sessionIdManager.Terminate(sessionID) @@ -510,6 +506,55 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { return counter.Add(1) } +func (s *StreamableHTTPServer) validateRequestHeaders(w http.ResponseWriter, r *http.Request) bool { + if !s.validateSession(w, r) { + return false + } + if !s.validateProtocolVersion(w, r) { + return false + } + return true +} + +// validateSession validates the validates the session ID in the request. +func (s *StreamableHTTPServer) validateSession(w http.ResponseWriter, r *http.Request) bool { + // Get session ID from header. + // Stateful servers need the client to carry the session ID. + sessionID := r.Header.Get(headerKeySessionID) + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return false + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return false + } + return true +} + +// validateProtocolVersion validates the protocol version header in the request. +func (s *StreamableHTTPServer) validateProtocolVersion(w http.ResponseWriter, r *http.Request) bool { + protocolVersion := r.Header.Get(headerKeyProtocolVersion) + + // If no protocol version provided, assume default version + if protocolVersion == "" { + protocolVersion = mcp.DEFAULT_NEGOTIATED_VERSION + } + + // Check if the protocol version is supported + if !slices.Contains(mcp.ValidProtocolVersions, protocolVersion) { + supportedVersion := strings.Join(mcp.ValidProtocolVersions, ",") + http.Error( + w, + fmt.Sprintf("Unsupported protocol version: %s. Supported versions: %s", protocolVersion, supportedVersion), + http.StatusBadRequest, + ) + return false + } + return true +} + // --- session --- type sessionToolsStore struct { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 0e1c7a65b..c7dd9c1d3 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -23,12 +23,14 @@ type jsonRPCResponse struct { Error *mcp.JSONRPCError `json:"error"` } +const clientInitializedProtocolVersion = "2025-03-26" + var initRequest = map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", + "protocolVersion": clientInitializedProtocolVersion, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", @@ -505,7 +507,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { func TestStreamableHTTP_GET(t *testing.T) { mcpServer := NewMCPServer("test-mcp-server", "1.0") addSSETool(mcpServer) - server := NewTestStreamableHTTPServer(mcpServer) + server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -612,7 +614,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) defer testServer.Close() // send initialize request to trigger the session registration @@ -625,19 +627,36 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { // Watch the notification to ensure the session is registered // (Normal http request (post) will not trigger the session registration) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + errChan := make(chan string, 1) defer cancel() go func() { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) req.Header.Set("Content-Type", "text/event-stream") getResp, err := http.DefaultClient.Do(req) if err != nil { - fmt.Printf("Failed to get: %v\n", err) + errMsg := fmt.Sprintf("Failed to get: %v\n", err) + errChan <- errMsg return } defer getResp.Body.Close() + if getResp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(getResp.Body) + errMsg := fmt.Sprintf("Expected status 200, got %d, Response:%s", getResp.StatusCode, string(bodyBytes)) + errChan <- errMsg + return + } + close(errChan) }() // Verify we got a session + select { + case <-ctx.Done(): + t.Fatal("Timeout waiting for GET request to complete") + case errMsg := <-errChan: + if errMsg != "" { + t.Fatal(errMsg) + } + } sessionRegistered.Wait() mu.Lock() if registeredSession == nil { @@ -775,6 +794,182 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) { }) } +func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) { + // If using HTTP, the client MUST include the MCP-Protocol-Version: HTTP header on all subsequent requests to the MCP server + // The protocol version sent by the client SHOULD be the one negotiated during initialization. + // If the server receives a request with an invalid or unsupported MCP-Protocol-Version, it MUST respond with 400 Bad Request. + // Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + mcpServer := NewMCPServer("test-mcp-server", "1.0") + server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) + + // Send initialize request + resp, err := postJSON(server.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + bodyBytes, _ := io.ReadAll(resp.Body) + var responseMessage jsonRPCResponse + if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if responseMessage.Result["protocolVersion"] != clientInitializedProtocolVersion { + t.Errorf("Expected protocol version %s, got %s", clientInitializedProtocolVersion, responseMessage.Result["protocolVersion"]) + } + // no session id from header + sessionID := resp.Header.Get(headerKeySessionID) + if sessionID != "" { + t.Fatalf("Expected no session id in header, got %s", sessionID) + } + + t.Run("POST Request", func(t *testing.T) { + // send notification + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "testNotification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"param1": "value1"}, + }, + }, + } + rawNotification, _ := json.Marshal(notification) + + // Request with protocol version to be the one negotiated during initialization. + req1, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := server.Client().Do(req1) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp1.StatusCode) + } + bodyBytes, _ := io.ReadAll(resp1.Body) + if len(bodyBytes) > 0 { + t.Errorf("Expected empty body, got %s", string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req2, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := server.Client().Do(req2) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp2.StatusCode) + } + bodyBytes, _ = io.ReadAll(resp2.Body) + if len(bodyBytes) > 0 { + t.Errorf("Expected empty body, got %s", string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req3, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req3.Header.Set("Content-Type", "application/json") + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := server.Client().Do(req3) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) + } + }) + + t.Run("GET Request", func(t *testing.T) { + // Request with protocol version to be the one negotiated during initialization. + req1, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req1.Header.Set("Content-Type", "text/event-stream") + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := http.DefaultClient.Do(req1) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp1.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req2, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req2.Header.Set("Content-Type", "text/event-stream") + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := http.DefaultClient.Do(req2) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp2.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req3, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req3.Header.Set("Content-Type", "text/event-stream") + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := http.DefaultClient.Do(req3) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) + } + }) + + t.Run("DELETE Request", func(t *testing.T) { + // Request with protocol version to be the one negotiated during initialization. + req1, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := http.DefaultClient.Do(req1) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp1.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req2, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := http.DefaultClient.Do(req2) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp2.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes)) + } + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req3, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := http.DefaultClient.Do(req3) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) + } + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))