From 0596c88830e4f0ef62169e5d86ad866a9b064c6c Mon Sep 17 00:00:00 2001 From: cryo Date: Fri, 20 Jun 2025 04:25:29 +0000 Subject: [PATCH 1/3] feat(srv/stream): Require Protocol Version Header in HTTP --- mcp/types.go | 4 + server/streamable_http.go | 77 ++++++++++--- server/streamable_http_test.go | 193 ++++++++++++++++++++++++++++++++- 3 files changed, 254 insertions(+), 20 deletions(-) diff --git a/mcp/types.go b/mcp/types.go index d4f6132c8..f039bbaa3 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -98,6 +98,10 @@ type JSONRPCMessage any // LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. const LATEST_PROTOCOL_VERSION = "2025-03-26" +// The default negotiated version of the Model Context Protocol when no version is specified +// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header +const DEFAULT_NEGOTIATED_VERSION = "2025-03-26" + // ValidProtocolVersions lists all known valid MCP protocol versions. var ValidProtocolVersions = []string{ "2024-11-05", diff --git a/server/streamable_http.go b/server/streamable_http.go index 1312c9753..f7bed1daa 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "slices" "strings" "sync" "sync/atomic" @@ -205,7 +206,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { // --- internal methods --- const ( - headerKeySessionID = "Mcp-Session-Id" + headerKeySessionID = "Mcp-Session-Id" + headerKeyProtocolVersion = "MCP-Protocol-Version" ) func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { @@ -240,19 +242,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request if isInitializeRequest { // generate a new one for initialize request sessionID = s.sessionIdManager.Generate() - } else { - // Get session ID from header. - // Stateful servers need the client to carry the session ID. - sessionID = r.Header.Get(headerKeySessionID) - isTerminated, err := s.sessionIdManager.Validate(sessionID) - if err != nil { - http.Error(w, "Invalid session ID", http.StatusBadRequest) - return - } - if isTerminated { - http.Error(w, "Session terminated", http.StatusNotFound) - return - } + } else if !s.validateRequestHeaders(w, r) { + return } session := newStreamableHttpSession(sessionID, s.sessionTools) @@ -355,10 +346,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { // get request is for listening to notifications // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + if !s.validateRequestHeaders(w, r) { + return + } sessionID := r.Header.Get(headerKeySessionID) - // the specification didn't say we should validate the session id - if sessionID == "" { // It's a stateless server, // but the MCP server requires a unique ID for registering, so we use a random one @@ -454,6 +446,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) } func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + if !s.validateRequestHeaders(w, r) { + return + } + // delete request terminate the session sessionID := r.Header.Get(headerKeySessionID) notAllowed, err := s.sessionIdManager.Terminate(sessionID) @@ -510,6 +506,55 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { return counter.Add(1) } +func (s *StreamableHTTPServer) validateRequestHeaders(w http.ResponseWriter, r *http.Request) bool { + if !s.validateSession(w, r) { + return false + } + if !s.validateProtocolVersion(w, r) { + return false + } + return true +} + +// validateSession validates the validates the session ID in the request. +func (s *StreamableHTTPServer) validateSession(w http.ResponseWriter, r *http.Request) bool { + // Get session ID from header. + // Stateful servers need the client to carry the session ID. + sessionID := r.Header.Get(headerKeySessionID) + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return false + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return false + } + return true +} + +// validateProtocolVersion validates the protocol version header in the request. +func (s *StreamableHTTPServer) validateProtocolVersion(w http.ResponseWriter, r *http.Request) bool { + protocolVersion := r.Header.Get(headerKeyProtocolVersion) + + // If no protocol version provided, assume default version + if protocolVersion == "" { + protocolVersion = mcp.DEFAULT_NEGOTIATED_VERSION + } + + // Check if the protocol version is supported + if !slices.Contains(mcp.ValidProtocolVersions, protocolVersion) { + supportedVersion := strings.Join(mcp.ValidProtocolVersions, ",") + http.Error( + w, + fmt.Sprintf("Unsupported protocol version: %s. Supported versions: %s", protocolVersion, supportedVersion), + http.StatusBadRequest, + ) + return false + } + return true +} + // --- session --- type sessionToolsStore struct { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 0e1c7a65b..21bb19ee2 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -23,12 +23,14 @@ type jsonRPCResponse struct { Error *mcp.JSONRPCError `json:"error"` } +const clientInitializedProtocolVersion = "2025-03-26" + var initRequest = map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", + "protocolVersion": clientInitializedProtocolVersion, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", @@ -505,7 +507,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { func TestStreamableHTTP_GET(t *testing.T) { mcpServer := NewMCPServer("test-mcp-server", "1.0") addSSETool(mcpServer) - server := NewTestStreamableHTTPServer(mcpServer) + server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -612,7 +614,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) defer testServer.Close() // send initialize request to trigger the session registration @@ -625,19 +627,36 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { // Watch the notification to ensure the session is registered // (Normal http request (post) will not trigger the session registration) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + errChan := make(chan string, 1) defer cancel() go func() { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) req.Header.Set("Content-Type", "text/event-stream") getResp, err := http.DefaultClient.Do(req) if err != nil { - fmt.Printf("Failed to get: %v\n", err) + errMsg := fmt.Sprintf("Failed to get: %v\n", err) + errChan <- errMsg return } defer getResp.Body.Close() + if getResp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(getResp.Body) + errMsg := fmt.Sprintf("Expected status 200, got %d, Response:%s", getResp.StatusCode, string(bodyBytes)) + errChan <- errMsg + return + } + close(errChan) }() // Verify we got a session + select { + case <-ctx.Done(): + t.Fatal("Timeout waiting for GET request to complete") + case errMsg := <-errChan: + if errMsg != "" { + t.Fatal(errMsg) + } + } sessionRegistered.Wait() mu.Lock() if registeredSession == nil { @@ -775,6 +794,172 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) { }) } +func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) { + // If using HTTP, the client MUST include the MCP-Protocol-Version: HTTP header on all subsequent requests to the MCP server + // The protocol version sent by the client SHOULD be the one negotiated during initialization. + // If the server receives a request with an invalid or unsupported MCP-Protocol-Version, it MUST respond with 400 Bad Request. + // Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + mcpServer := NewMCPServer("test-mcp-server", "1.0") + server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true)) + + // Send initialize request + resp, err := postJSON(server.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + bodyBytes, _ := io.ReadAll(resp.Body) + var responseMessage jsonRPCResponse + if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if responseMessage.Result["protocolVersion"] != clientInitializedProtocolVersion { + t.Errorf("Expected protocol version %s, got %s", clientInitializedProtocolVersion, responseMessage.Result["protocolVersion"]) + } + // no session id from header + sessionID := resp.Header.Get(headerKeySessionID) + if sessionID != "" { + t.Fatalf("Expected no session id in header, got %s", sessionID) + } + + t.Run("POST Request", func(t *testing.T) { + // send notification + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "testNotification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"param1": "value1"}, + }, + }, + } + rawNotification, _ := json.Marshal(notification) + + // Request with protocol version to be the one negotiated during initialization. + req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + bodyBytes, _ := io.ReadAll(resp.Body) + if len(bodyBytes) > 0 { + t.Errorf("Expected empty body, got %s", string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + bodyBytes, _ = io.ReadAll(resp.Body) + if len(bodyBytes) > 0 { + t.Errorf("Expected empty body, got %s", string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + resp.Body.Close() + }) + + t.Run("GET Request", func(t *testing.T) { + // Request with protocol version to be the one negotiated during initialization. + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req.Header.Set("Content-Type", "text/event-stream") + req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + resp.Body.Close() + }) + + t.Run("DELETE Request", func(t *testing.T) { + // Request with protocol version to be the one negotiated during initialization. + req, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but supported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v\n", err) + } + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + } + resp.Body.Close() + + // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. + req.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + resp.Body.Close() + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) From 337925b0302840366f407bb6020c108935bf1920 Mon Sep 17 00:00:00 2001 From: cryo Date: Fri, 20 Jun 2025 08:00:25 +0000 Subject: [PATCH 2/3] chore(tst/streamSrv): update ProtocolVersionHeader test --- server/streamable_http_test.go | 120 ++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 55 deletions(-) diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 21bb19ee2..a65c25d52 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -839,124 +839,134 @@ func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) { rawNotification, _ := json.Marshal(notification) // Request with protocol version to be the one negotiated during initialization. - req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) - resp, err := server.Client().Do(req) + req1, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := server.Client().Do(req1) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if resp.StatusCode != http.StatusAccepted { - t.Errorf("Expected status 202, got %d", resp.StatusCode) + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp1.StatusCode) } - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, _ := io.ReadAll(resp1.Body) if len(bodyBytes) > 0 { t.Errorf("Expected empty body, got %s", string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but supported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-11-05") - resp, err = server.Client().Do(req) + req2, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := server.Client().Do(req2) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if resp.StatusCode != http.StatusAccepted { - t.Errorf("Expected status 202, got %d", resp.StatusCode) + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp2.StatusCode) } - bodyBytes, _ = io.ReadAll(resp.Body) + bodyBytes, _ = io.ReadAll(resp2.Body) if len(bodyBytes) > 0 { t.Errorf("Expected empty body, got %s", string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-06-18") - resp, err = server.Client().Do(req) + req3, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) + req3.Header.Set("Content-Type", "application/json") + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := server.Client().Do(req3) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) } - resp.Body.Close() }) t.Run("GET Request", func(t *testing.T) { // Request with protocol version to be the one negotiated during initialization. - req, _ := http.NewRequest(http.MethodGet, server.URL, nil) - req.Header.Set("Content-Type", "text/event-stream") - req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) - resp, err := http.DefaultClient.Do(req) + req1, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req1.Header.Set("Content-Type", "text/event-stream") + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := http.DefaultClient.Do(req1) if err != nil { t.Fatalf("Failed to send message: %v\n", err) } - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp1.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but supported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-11-05") - resp, err = http.DefaultClient.Do(req) + req2, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req2.Header.Set("Content-Type", "text/event-stream") + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := http.DefaultClient.Do(req2) if err != nil { t.Fatalf("Failed to send message: %v\n", err) } - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp2.Body) t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-06-18") - resp, err = server.Client().Do(req) + req3, _ := http.NewRequest(http.MethodGet, server.URL, nil) + req3.Header.Set("Content-Type", "text/event-stream") + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := http.DefaultClient.Do(req3) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) } - resp.Body.Close() }) t.Run("DELETE Request", func(t *testing.T) { // Request with protocol version to be the one negotiated during initialization. - req, _ := http.NewRequest(http.MethodDelete, server.URL, nil) - req.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) - resp, err := http.DefaultClient.Do(req) + req1, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion) + resp1, err := http.DefaultClient.Do(req1) if err != nil { t.Fatalf("Failed to send message: %v\n", err) } - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + defer resp1.Body.Close() + if resp1.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp1.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but supported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-11-05") - resp, err = http.DefaultClient.Do(req) + req2, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req2.Header.Set(headerKeyProtocolVersion, "2024-11-05") + resp2, err := http.DefaultClient.Do(req2) if err != nil { t.Fatalf("Failed to send message: %v\n", err) } - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + defer resp2.Body.Close() + if resp2.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp2.Body) + t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes)) } - resp.Body.Close() // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server. - req.Header.Set(headerKeyProtocolVersion, "2024-06-18") - resp, err = server.Client().Do(req) + req3, _ := http.NewRequest(http.MethodDelete, server.URL, nil) + req3.Header.Set(headerKeyProtocolVersion, "2024-06-18") + resp3, err := http.DefaultClient.Do(req3) if err != nil { t.Fatalf("Failed to send message: %v", err) } - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + defer resp3.Body.Close() + if resp3.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp3.StatusCode) } - resp.Body.Close() }) } From 6d1ca6d58c1fae9d9555e9976a1c4c30092eb704 Mon Sep 17 00:00:00 2001 From: cryo Date: Sun, 22 Jun 2025 17:33:26 +0800 Subject: [PATCH 3/3] Fix: server/streamable_http_test.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- server/streamable_http_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index a65c25d52..c7dd9c1d3 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -912,7 +912,7 @@ func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) { defer resp2.Body.Close() if resp2.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp2.Body) - t.Fatalf("Expected status 200, got %d, Response:%s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes)) } // Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.