Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mcp/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ type JSONRPCMessage any
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
const LATEST_PROTOCOL_VERSION = "2025-03-26"

// The default negotiated version of the Model Context Protocol when no version is specified
// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header
const DEFAULT_NEGOTIATED_VERSION = "2025-03-26"

// ValidProtocolVersions lists all known valid MCP protocol versions.
var ValidProtocolVersions = []string{
"2024-11-05",
Expand Down
77 changes: 61 additions & 16 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -205,7 +206,8 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
// --- internal methods ---

const (
headerKeySessionID = "Mcp-Session-Id"
headerKeySessionID = "Mcp-Session-Id"
headerKeyProtocolVersion = "MCP-Protocol-Version"
)

func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -240,19 +242,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
if isInitializeRequest {
// generate a new one for initialize request
sessionID = s.sessionIdManager.Generate()
} else {
// Get session ID from header.
// Stateful servers need the client to carry the session ID.
sessionID = r.Header.Get(headerKeySessionID)
isTerminated, err := s.sessionIdManager.Validate(sessionID)
if err != nil {
http.Error(w, "Invalid session ID", http.StatusBadRequest)
return
}
if isTerminated {
http.Error(w, "Session terminated", http.StatusNotFound)
return
}
} else if !s.validateRequestHeaders(w, r) {
return

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to store this new header in the session state during initialization if the server is stateful. From the mcp spec, I am interpreting that when the header is "invalid" is when the client uses a different protocol version header than what was initialized. Otherwise, what you have is good is my book

If using HTTP, the client MUST include the MCP-Protocol-Version: HTTP header on all subsequent requests to the MCP server, allowing the MCP server to respond based on the MCP protocol version.


The protocol version sent by the client SHOULD be the one negotiated during initialization.


If the server receives a request with an invalid or unsupported MCP-Protocol-Version, it MUST respond with 400 Bad Request.

}

session := newStreamableHttpSession(sessionID, s.sessionTools)
Expand Down Expand Up @@ -355,10 +346,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
// get request is for listening to notifications
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
if !s.validateRequestHeaders(w, r) {
return
}

sessionID := r.Header.Get(headerKeySessionID)
// the specification didn't say we should validate the session id

if sessionID == "" {
// It's a stateless server,
// but the MCP server requires a unique ID for registering, so we use a random one
Expand Down Expand Up @@ -454,6 +446,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
}

func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
if !s.validateRequestHeaders(w, r) {
return
}

// delete request terminate the session
sessionID := r.Header.Get(headerKeySessionID)
notAllowed, err := s.sessionIdManager.Terminate(sessionID)
Expand Down Expand Up @@ -510,6 +506,55 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
return counter.Add(1)
}

func (s *StreamableHTTPServer) validateRequestHeaders(w http.ResponseWriter, r *http.Request) bool {
if !s.validateSession(w, r) {
return false
}
if !s.validateProtocolVersion(w, r) {
return false
}
return true
}

// validateSession validates the validates the session ID in the request.
func (s *StreamableHTTPServer) validateSession(w http.ResponseWriter, r *http.Request) bool {
// Get session ID from header.
// Stateful servers need the client to carry the session ID.
sessionID := r.Header.Get(headerKeySessionID)
isTerminated, err := s.sessionIdManager.Validate(sessionID)
if err != nil {
http.Error(w, "Invalid session ID", http.StatusBadRequest)
return false
}
if isTerminated {
http.Error(w, "Session terminated", http.StatusNotFound)
return false
}
return true
}

// validateProtocolVersion validates the protocol version header in the request.
func (s *StreamableHTTPServer) validateProtocolVersion(w http.ResponseWriter, r *http.Request) bool {
protocolVersion := r.Header.Get(headerKeyProtocolVersion)

// If no protocol version provided, assume default version
if protocolVersion == "" {
protocolVersion = mcp.DEFAULT_NEGOTIATED_VERSION
}

// Check if the protocol version is supported
if !slices.Contains(mcp.ValidProtocolVersions, protocolVersion) {
supportedVersion := strings.Join(mcp.ValidProtocolVersions, ",")
http.Error(
w,
fmt.Sprintf("Unsupported protocol version: %s. Supported versions: %s", protocolVersion, supportedVersion),
http.StatusBadRequest,
)
return false
}
return true
}

// --- session ---

type sessionToolsStore struct {
Expand Down
203 changes: 199 additions & 4 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ type jsonRPCResponse struct {
Error *mcp.JSONRPCError `json:"error"`
}

const clientInitializedProtocolVersion = "2025-03-26"

var initRequest = map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-03-26",
"protocolVersion": clientInitializedProtocolVersion,
"clientInfo": map[string]any{
"name": "test-client",
"version": "1.0.0",
Expand Down Expand Up @@ -505,7 +507,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) {
func TestStreamableHTTP_GET(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0")
addSSETool(mcpServer)
server := NewTestStreamableHTTPServer(mcpServer)
server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
Expand Down Expand Up @@ -612,7 +614,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
})

mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
testServer := NewTestStreamableHTTPServer(mcpServer)
testServer := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))
defer testServer.Close()

// send initialize request to trigger the session registration
Expand All @@ -625,19 +627,36 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
// Watch the notification to ensure the session is registered
// (Normal http request (post) will not trigger the session registration)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
errChan := make(chan string, 1)
defer cancel()
go func() {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil)
req.Header.Set("Content-Type", "text/event-stream")
getResp, err := http.DefaultClient.Do(req)
if err != nil {
fmt.Printf("Failed to get: %v\n", err)
errMsg := fmt.Sprintf("Failed to get: %v\n", err)
errChan <- errMsg
return
}
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(getResp.Body)
errMsg := fmt.Sprintf("Expected status 200, got %d, Response:%s", getResp.StatusCode, string(bodyBytes))
errChan <- errMsg
return
}
close(errChan)
}()

// Verify we got a session
select {
case <-ctx.Done():
t.Fatal("Timeout waiting for GET request to complete")
case errMsg := <-errChan:
if errMsg != "" {
t.Fatal(errMsg)
}
}
sessionRegistered.Wait()
mu.Lock()
if registeredSession == nil {
Expand Down Expand Up @@ -775,6 +794,182 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) {
})
}

func TestStreamableHTTPServer_ProtocolVersionHeader(t *testing.T) {
// If using HTTP, the client MUST include the MCP-Protocol-Version: <protocol-version> HTTP header on all subsequent requests to the MCP server
// The protocol version sent by the client SHOULD be the one negotiated during initialization.
// If the server receives a request with an invalid or unsupported MCP-Protocol-Version, it MUST respond with 400 Bad Request.
// Reference: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header
mcpServer := NewMCPServer("test-mcp-server", "1.0")
server := NewTestStreamableHTTPServer(mcpServer, WithStateLess(true))

// Send initialize request
resp, err := postJSON(server.URL, initRequest)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
bodyBytes, _ := io.ReadAll(resp.Body)
var responseMessage jsonRPCResponse
if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if responseMessage.Result["protocolVersion"] != clientInitializedProtocolVersion {
t.Errorf("Expected protocol version %s, got %s", clientInitializedProtocolVersion, responseMessage.Result["protocolVersion"])
}
// no session id from header
sessionID := resp.Header.Get(headerKeySessionID)
if sessionID != "" {
t.Fatalf("Expected no session id in header, got %s", sessionID)
}

t.Run("POST Request", func(t *testing.T) {
// send notification
notification := mcp.JSONRPCNotification{
JSONRPC: "2.0",
Notification: mcp.Notification{
Method: "testNotification",
Params: mcp.NotificationParams{
AdditionalFields: map[string]interface{}{"param1": "value1"},
},
},
}
rawNotification, _ := json.Marshal(notification)

// Request with protocol version to be the one negotiated during initialization.
req1, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification))
req1.Header.Set("Content-Type", "application/json")
req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
resp1, err := server.Client().Do(req1)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp1.Body.Close()
if resp1.StatusCode != http.StatusAccepted {
t.Errorf("Expected status 202, got %d", resp1.StatusCode)
}
bodyBytes, _ := io.ReadAll(resp1.Body)
if len(bodyBytes) > 0 {
t.Errorf("Expected empty body, got %s", string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but supported by server.
req2, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification))
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set(headerKeyProtocolVersion, "2024-11-05")
resp2, err := server.Client().Do(req2)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusAccepted {
t.Errorf("Expected status 202, got %d", resp2.StatusCode)
}
bodyBytes, _ = io.ReadAll(resp2.Body)
if len(bodyBytes) > 0 {
t.Errorf("Expected empty body, got %s", string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
req3, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification))
req3.Header.Set("Content-Type", "application/json")
req3.Header.Set(headerKeyProtocolVersion, "2024-06-18")
resp3, err := server.Client().Do(req3)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp3.Body.Close()
if resp3.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp3.StatusCode)
}
})

t.Run("GET Request", func(t *testing.T) {
// Request with protocol version to be the one negotiated during initialization.
req1, _ := http.NewRequest(http.MethodGet, server.URL, nil)
req1.Header.Set("Content-Type", "text/event-stream")
req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
resp1, err := http.DefaultClient.Do(req1)
if err != nil {
t.Fatalf("Failed to send message: %v\n", err)
}
defer resp1.Body.Close()
if resp1.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp1.Body)
t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but supported by server.
req2, _ := http.NewRequest(http.MethodGet, server.URL, nil)
req2.Header.Set("Content-Type", "text/event-stream")
req2.Header.Set(headerKeyProtocolVersion, "2024-11-05")
resp2, err := http.DefaultClient.Do(req2)
if err != nil {
t.Fatalf("Failed to send message: %v\n", err)
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp2.Body)
t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
req3, _ := http.NewRequest(http.MethodGet, server.URL, nil)
req3.Header.Set("Content-Type", "text/event-stream")
req3.Header.Set(headerKeyProtocolVersion, "2024-06-18")
resp3, err := http.DefaultClient.Do(req3)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp3.Body.Close()
if resp3.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp3.StatusCode)
}
})

t.Run("DELETE Request", func(t *testing.T) {
// Request with protocol version to be the one negotiated during initialization.
req1, _ := http.NewRequest(http.MethodDelete, server.URL, nil)
req1.Header.Set(headerKeyProtocolVersion, clientInitializedProtocolVersion)
resp1, err := http.DefaultClient.Do(req1)
if err != nil {
t.Fatalf("Failed to send message: %v\n", err)
}
defer resp1.Body.Close()
if resp1.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp1.Body)
t.Fatalf("Expected status 200, got %d, Response:%s", resp1.StatusCode, string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but supported by server.
req2, _ := http.NewRequest(http.MethodDelete, server.URL, nil)
req2.Header.Set(headerKeyProtocolVersion, "2024-11-05")
resp2, err := http.DefaultClient.Do(req2)
if err != nil {
t.Fatalf("Failed to send message: %v\n", err)
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp2.Body)
t.Fatalf("Expected status 200, got %d, Response:%s", resp2.StatusCode, string(bodyBytes))
}

// Request with protocol version not to be the one negotiated during initialization but also invalid/unsupported by server.
req3, _ := http.NewRequest(http.MethodDelete, server.URL, nil)
req3.Header.Set(headerKeyProtocolVersion, "2024-06-18")
resp3, err := http.DefaultClient.Do(req3)
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp3.Body.Close()
if resp3.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp3.StatusCode)
}
})
}

func postJSON(url string, bodyObject any) (*http.Response, error) {
jsonBody, _ := json.Marshal(bodyObject)
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
Expand Down