diff --git a/lib/srv/mcp/http.go b/lib/srv/mcp/http.go index 9f59d6864c89e..1ff6c030b70d8 100644 --- a/lib/srv/mcp/http.go +++ b/lib/srv/mcp/http.go @@ -143,9 +143,13 @@ func (t *streamableHTTPTransport) RoundTrip(r *http.Request) (*http.Response, er default: t.emitInvalidHTTPRequest(t.parentCtx, r) + + statusText := http.StatusText(http.StatusMethodNotAllowed) return &http.Response{ Request: r, + Status: statusText, StatusCode: http.StatusMethodNotAllowed, + Body: io.NopCloser(bytes.NewReader(nil)), // Body must not be nil. }, nil } } diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go index bd16a1ba938d3..b05b9795d3827 100644 --- a/lib/srv/mcp/http_test.go +++ b/lib/srv/mcp/http_test.go @@ -168,6 +168,18 @@ func Test_handleStreamableHTTP(t *testing.T) { require.False(t, lastEvent.Success) require.Equal(t, "HTTP 404 Not Found", lastEvent.Error) }) + + t.Run("unsupported method", func(t *testing.T) { + emitter.Reset() + httpClient := listener.MakeHTTPClient() + request, err := http.NewRequestWithContext(t.Context(), http.MethodOptions, "http://localhost/", nil) + require.NoError(t, err) + response, err := httpClient.Do(request) + require.NoError(t, err) + defer response.Body.Close() + require.Equal(t, http.StatusMethodNotAllowed, response.StatusCode) + require.Equal(t, libevents.MCPSessionInvalidHTTPRequest, emitter.LastEvent().GetType()) + }) } func Test_handleAuthErrHTTP(t *testing.T) { diff --git a/lib/utils/mcputils/http.go b/lib/utils/mcputils/http.go index d3e1f414ebd50..8995faa2d6514 100644 --- a/lib/utils/mcputils/http.go +++ b/lib/utils/mcputils/http.go @@ -26,6 +26,7 @@ import ( "log/slog" "mime" "net/http" + "strconv" "github.com/gravitational/trace" mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" @@ -76,7 +77,11 @@ func ReplaceHTTPResponse(ctx context.Context, resp *http.Response, processor Ser return trace.Wrap(err) } resp.Body = io.NopCloser(bytes.NewReader(respToClientAsBody)) + + // Make sure content length in both the response field and the header + // are updated. resp.ContentLength = int64(len(respToClientAsBody)) + resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(respToClientAsBody)), 10)) return nil case "text/event-stream": @@ -88,6 +93,11 @@ func ReplaceHTTPResponse(ctx context.Context, resp *http.Response, processor Ser SSEResponseReader: NewSSEResponseReader(resp.Body), processor: processor, } + + // Content-Length should be -1 from server for streams. Force to -1 again just to + // be sure. + resp.ContentLength = -1 + resp.Header.Del("Content-Length") return nil default: return trace.BadParameter("unsupported response type %s", mediaType) diff --git a/lib/utils/mcputils/http_test.go b/lib/utils/mcputils/http_test.go index ee06617f0f815..4af9ee756ed64 100644 --- a/lib/utils/mcputils/http_test.go +++ b/lib/utils/mcputils/http_test.go @@ -19,7 +19,9 @@ package mcputils import ( + "bytes" "context" + "fmt" "io" "log/slog" "maps" @@ -49,7 +51,7 @@ func TestReplaceHTTPResponse(t *testing.T) { ctx := t.Context() // Set up a server. Use InMemoryListener for synctest. - mcpServer := mcptest.NewServer() + mcpServer := mcptest.NewServerWithVersion("11.22.33") listener := listenerutils.NewInMemoryListener() httpServer := http.Server{ Handler: mcpserver.NewStreamableHTTPServer(mcpServer), @@ -77,7 +79,8 @@ func TestReplaceHTTPResponse(t *testing.T) { require.NoError(t, client.Start(ctx)) // Initialize client and call a tool. - mcptest.MustInitializeClient(t, client) + result := mcptest.MustInitializeClient(t, client) + require.Equal(t, "111.222.333", result.ServerInfo.Version) mcptest.MustCallServerTool(t, client) require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load()) @@ -136,10 +139,27 @@ func (t *testReplaceHTTPResponseTransport) RoundTrip(r *http.Request) (*http.Res if err := ReplaceHTTPResponse(r.Context(), resp, t); err != nil { return nil, trace.Wrap(err) } + + if resp != nil { + switch { + case resp.ContentLength >= 0: + if resp.Header.Get("Content-Length") != fmt.Sprintf("%d", resp.ContentLength) { + return nil, trace.CompareFailed("Content-Length does not match Content-Length header") + } + + default: + if resp.Header.Get("Content-Length") != "" { + return nil, trace.CompareFailed("Content-Length does not match Content-Length header") + } + } + } + return resp, nil } func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, response *JSONRPCResponse) mcp.JSONRPCMessage { + // Replace server version. + response.Result = bytes.ReplaceAll(response.Result, []byte("11.22.33"), []byte("111.222.333")) t.countMCPResponse.Add(1) return response }