Skip to content

Commit cde14f0

Browse files
committed
mcp: be strict about returning the Mcp-Session-Id header
Rather than returning the Mcp-Session-Id header for all responses, only return it from initialize, per the spec. Fixes #412
1 parent 2c40bdc commit cde14f0

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

mcp/streamable.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
424424
// It is always text/event-stream, since it must carry arbitrarily many
425425
// messages.
426426
var err error
427-
t.connection.streams[""], err = t.connection.newStream(ctx, "", false)
427+
t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false)
428428
if err != nil {
429429
return nil, err
430430
}
@@ -485,6 +485,10 @@ type stream struct {
485485
// an empty string is used for messages that don't correlate with an incoming request.
486486
id StreamID
487487

488+
// If isInitialize is set, the stream is in response to an initialize request,
489+
// and therefore should include the session ID header.
490+
isInitialize bool
491+
488492
// jsonResponse records whether this stream should respond with application/json
489493
// instead of text/event-stream.
490494
//
@@ -513,12 +517,13 @@ type stream struct {
513517
requests map[jsonrpc.ID]struct{}
514518
}
515519

516-
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) {
520+
func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) {
517521
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
518522
return nil, err
519523
}
520524
return &stream{
521525
id: id,
526+
isInitialize: isInitialize,
522527
jsonResponse: jsonResponse,
523528
requests: make(map[jsonrpc.ID]struct{}),
524529
}, nil
@@ -647,6 +652,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
647652
}
648653
requests := make(map[jsonrpc.ID]struct{})
649654
tokenInfo := auth.TokenInfoFromContext(req.Context())
655+
isInitialize := false
650656
for _, msg := range incoming {
651657
if jreq, ok := msg.(*jsonrpc.Request); ok {
652658
// Preemptively check that this is a valid request, so that we can fail
@@ -656,6 +662,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
656662
http.Error(w, err.Error(), http.StatusBadRequest)
657663
return
658664
}
665+
if jreq.Method == methodInitialize {
666+
isInitialize = true
667+
}
659668
jreq.Extra = &RequestExtra{
660669
TokenInfo: tokenInfo,
661670
Header: req.Header,
@@ -672,7 +681,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
672681
// notifications or server->client requests made in the course of handling.
673682
// Update accounting for this incoming payload.
674683
if len(requests) > 0 {
675-
stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse)
684+
stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse)
676685
if err != nil {
677686
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
678687
return
@@ -708,7 +717,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
708717
func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) {
709718
w.Header().Set("Cache-Control", "no-cache, no-transform")
710719
w.Header().Set("Content-Type", "application/json")
711-
if c.sessionID != "" {
720+
if c.sessionID != "" && stream.isInitialize {
712721
w.Header().Set(sessionIDHeader, c.sessionID)
713722
}
714723

@@ -747,7 +756,7 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
747756
w.Header().Set("Cache-Control", "no-cache, no-transform")
748757
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
749758
w.Header().Set("Connection", "keep-alive")
750-
if c.sessionID != "" {
759+
if c.sessionID != "" && stream.isInitialize {
751760
w.Header().Set(sessionIDHeader, c.sessionID)
752761
}
753762
if persistent {

mcp/streamable_test.go

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func TestStreamableTransports(t *testing.T) {
133133
defer session.Close()
134134
sid := session.ID()
135135
if sid == "" {
136-
t.Error("empty session ID")
136+
t.Fatalf("empty session ID")
137137
}
138138
if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w {
139139
t.Fatalf("got protocol version %q, want %q", g, w)
@@ -475,6 +475,8 @@ func resp(id int64, result any, err error) *jsonrpc.Response {
475475
}
476476
}
477477

478+
var ()
479+
478480
func TestStreamableServerTransport(t *testing.T) {
479481
// This test checks detailed behavior of the streamable server transport, by
480482
// faking the behavior of a streamable client using a sequence of HTTP
@@ -502,7 +504,6 @@ func TestStreamableServerTransport(t *testing.T) {
502504
method: "POST",
503505
messages: []jsonrpc.Message{initializedMsg},
504506
wantStatusCode: http.StatusAccepted,
505-
wantSessionID: false, // TODO: should this be true?
506507
}
507508

508509
tests := []struct {
@@ -520,7 +521,6 @@ func TestStreamableServerTransport(t *testing.T) {
520521
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
521522
wantStatusCode: http.StatusOK,
522523
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)},
523-
wantSessionID: true,
524524
},
525525
},
526526
},
@@ -535,30 +535,26 @@ func TestStreamableServerTransport(t *testing.T) {
535535
headers: http.Header{"Accept": {"text/plain", "application/*"}},
536536
messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})},
537537
wantStatusCode: http.StatusBadRequest, // missing text/event-stream
538-
wantSessionID: false,
539538
},
540539
{
541540
method: "POST",
542541
headers: http.Header{"Accept": {"text/event-stream"}},
543542
messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})},
544543
wantStatusCode: http.StatusBadRequest, // missing application/json
545-
wantSessionID: false,
546544
},
547545
{
548546
method: "POST",
549547
headers: http.Header{"Accept": {"text/plain", "*/*"}},
550548
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})},
551549
wantStatusCode: http.StatusOK,
552550
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)},
553-
wantSessionID: true,
554551
},
555552
{
556553
method: "POST",
557554
headers: http.Header{"Accept": {"text/*, application/*"}},
558555
messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})},
559556
wantStatusCode: http.StatusOK,
560557
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)},
561-
wantSessionID: true,
562558
},
563559
},
564560
},
@@ -598,7 +594,6 @@ func TestStreamableServerTransport(t *testing.T) {
598594
req(0, "notifications/progress", &ProgressNotificationParams{}),
599595
resp(2, &CallToolResult{}, nil),
600596
},
601-
wantSessionID: true,
602597
},
603598
},
604599
},
@@ -620,7 +615,6 @@ func TestStreamableServerTransport(t *testing.T) {
620615
resp(1, &ListRootsResult{}, nil),
621616
},
622617
wantStatusCode: http.StatusAccepted,
623-
wantSessionID: false,
624618
},
625619
{
626620
method: "POST",
@@ -632,7 +626,6 @@ func TestStreamableServerTransport(t *testing.T) {
632626
req(1, "roots/list", &ListRootsParams{}),
633627
resp(2, &CallToolResult{}, nil),
634628
},
635-
wantSessionID: true,
636629
},
637630
},
638631
},
@@ -663,7 +656,6 @@ func TestStreamableServerTransport(t *testing.T) {
663656
resp(1, &ListRootsResult{}, nil),
664657
},
665658
wantStatusCode: http.StatusAccepted,
666-
wantSessionID: false,
667659
},
668660
{
669661
method: "GET",
@@ -674,7 +666,6 @@ func TestStreamableServerTransport(t *testing.T) {
674666
req(0, "notifications/progress", &ProgressNotificationParams{}),
675667
req(1, "roots/list", &ListRootsParams{}),
676668
},
677-
wantSessionID: true,
678669
},
679670
{
680671
method: "POST",
@@ -685,7 +676,6 @@ func TestStreamableServerTransport(t *testing.T) {
685676
wantMessages: []jsonrpc.Message{
686677
resp(2, &CallToolResult{}, nil),
687678
},
688-
wantSessionID: true,
689679
},
690680
{
691681
method: "DELETE",
@@ -724,7 +714,6 @@ func TestStreamableServerTransport(t *testing.T) {
724714
wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{
725715
Message: `method "tools/call" is invalid during session initialization`,
726716
})},
727-
wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session
728717
},
729718
},
730719
},
@@ -951,7 +940,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
951940
return "", 0, nil, fmt.Errorf("creating request: %w", err)
952941
}
953942
if sessionID != "" {
954-
req.Header.Set("Mcp-Session-Id", sessionID)
943+
req.Header.Set(sessionIDHeader, sessionID)
955944
}
956945
req.Header.Set("Content-Type", "application/json")
957946
req.Header.Set("Accept", "application/json, text/event-stream")
@@ -963,7 +952,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string,
963952
}
964953
defer resp.Body.Close()
965954

966-
newSessionID := resp.Header.Get("Mcp-Session-Id")
955+
newSessionID := resp.Header.Get(sessionIDHeader)
967956

968957
contentType := resp.Header.Get("Content-Type")
969958
var respBody []byte
@@ -1079,6 +1068,15 @@ func TestEventID(t *testing.T) {
10791068
}
10801069

10811070
func TestStreamableStateless(t *testing.T) {
1071+
initReq := req(1, methodInitialize, &InitializeParams{})
1072+
initResp := resp(1, &InitializeResult{
1073+
Capabilities: &ServerCapabilities{
1074+
Logging: &LoggingCapabilities{},
1075+
Tools: &ToolCapabilities{ListChanged: true},
1076+
},
1077+
ProtocolVersion: latestProtocolVersion,
1078+
ServerInfo: &Implementation{Name: "test", Version: "v1.0.0"},
1079+
}, nil)
10821080
// This version of sayHi expects
10831081
// that request from our client).
10841082
sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) {
@@ -1092,17 +1090,22 @@ func TestStreamableStateless(t *testing.T) {
10921090
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
10931091

10941092
requests := []streamableRequest{
1093+
{
1094+
method: "POST",
1095+
messages: []jsonrpc.Message{initReq},
1096+
wantStatusCode: http.StatusOK,
1097+
wantMessages: []jsonrpc.Message{initResp},
1098+
wantSessionID: false, // sessionless
1099+
},
10951100
{
10961101
method: "POST",
10971102
wantStatusCode: http.StatusOK,
10981103
messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})},
10991104
wantBodyContaining: "greet",
1100-
wantSessionID: false,
11011105
},
11021106
{
11031107
method: "GET",
11041108
wantStatusCode: http.StatusMethodNotAllowed,
1105-
wantSessionID: false,
11061109
},
11071110
{
11081111
method: "POST",
@@ -1116,7 +1119,6 @@ func TestStreamableStateless(t *testing.T) {
11161119
StructuredContent: json.RawMessage("null"),
11171120
}, nil),
11181121
},
1119-
wantSessionID: false,
11201122
},
11211123
{
11221124
method: "POST",
@@ -1130,7 +1132,6 @@ func TestStreamableStateless(t *testing.T) {
11301132
StructuredContent: json.RawMessage("null"),
11311133
}, nil),
11321134
},
1133-
wantSessionID: false,
11341135
},
11351136
}
11361137

@@ -1166,13 +1167,7 @@ func TestStreamableStateless(t *testing.T) {
11661167
//
11671168
// This can be used by tools to look up application state preserved across
11681169
// subsequent requests.
1169-
for i, req := range requests {
1170-
// Now, we want a session for all (valid) requests.
1171-
if req.wantStatusCode != http.StatusMethodNotAllowed {
1172-
req.wantSessionID = true
1173-
}
1174-
requests[i] = req
1175-
}
1170+
requests[0].wantSessionID = true // now expect a session ID for initialize
11761171
statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
11771172
Stateless: true,
11781173
})

0 commit comments

Comments
 (0)