Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/srv/mcp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ func (t *streamableHTTPTransport) RoundTrip(r *http.Request) (*http.Response, er

default:
t.emitInvalidHTTPRequest(t.parentCtx, r)

statusText := http.StatusText(http.StatusMethodNotAllowed)
return &http.Response{
Request: r,
Status: statusText,
StatusCode: http.StatusMethodNotAllowed,
Body: io.NopCloser(bytes.NewReader(nil)), // Body must not be nil.
}, nil
}
}
Expand Down
12 changes: 12 additions & 0 deletions lib/srv/mcp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func Test_handleStreamableHTTP(t *testing.T) {
require.False(t, lastEvent.Success)
require.Equal(t, "HTTP 404 Not Found", lastEvent.Error)
})

t.Run("unsupported method", func(t *testing.T) {
emitter.Reset()
httpClient := listener.MakeHTTPClient()
request, err := http.NewRequestWithContext(t.Context(), http.MethodOptions, "http://localhost/", nil)
require.NoError(t, err)
response, err := httpClient.Do(request)
require.NoError(t, err)
defer response.Body.Close()
require.Equal(t, http.StatusMethodNotAllowed, response.StatusCode)
require.Equal(t, libevents.MCPSessionInvalidHTTPRequest, emitter.LastEvent().GetType())
})
}

func Test_handleAuthErrHTTP(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions lib/utils/mcputils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"log/slog"
"mime"
"net/http"
"strconv"

"github.com/gravitational/trace"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
Expand Down Expand Up @@ -76,7 +77,11 @@ func ReplaceHTTPResponse(ctx context.Context, resp *http.Response, processor Ser
return trace.Wrap(err)
}
resp.Body = io.NopCloser(bytes.NewReader(respToClientAsBody))

// Make sure content length in both the response field and the header
// are updated.
resp.ContentLength = int64(len(respToClientAsBody))
resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(respToClientAsBody)), 10))
return nil

case "text/event-stream":
Expand All @@ -88,6 +93,11 @@ func ReplaceHTTPResponse(ctx context.Context, resp *http.Response, processor Ser
SSEResponseReader: NewSSEResponseReader(resp.Body),
processor: processor,
}

// Content-Length should be -1 from server for streams. Force to -1 again just to
// be sure.
resp.ContentLength = -1
resp.Header.Del("Content-Length")
return nil
default:
return trace.BadParameter("unsupported response type %s", mediaType)
Expand Down
24 changes: 22 additions & 2 deletions lib/utils/mcputils/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
package mcputils

import (
"bytes"
"context"
"fmt"
"io"
"log/slog"
"maps"
Expand Down Expand Up @@ -49,7 +51,7 @@ func TestReplaceHTTPResponse(t *testing.T) {
ctx := t.Context()

// Set up a server. Use InMemoryListener for synctest.
mcpServer := mcptest.NewServer()
mcpServer := mcptest.NewServerWithVersion("11.22.33")
listener := listenerutils.NewInMemoryListener()
httpServer := http.Server{
Handler: mcpserver.NewStreamableHTTPServer(mcpServer),
Expand Down Expand Up @@ -77,7 +79,8 @@ func TestReplaceHTTPResponse(t *testing.T) {
require.NoError(t, client.Start(ctx))

// Initialize client and call a tool.
mcptest.MustInitializeClient(t, client)
result := mcptest.MustInitializeClient(t, client)
require.Equal(t, "111.222.333", result.ServerInfo.Version)
mcptest.MustCallServerTool(t, client)
require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load())

Expand Down Expand Up @@ -136,10 +139,27 @@ func (t *testReplaceHTTPResponseTransport) RoundTrip(r *http.Request) (*http.Res
if err := ReplaceHTTPResponse(r.Context(), resp, t); err != nil {
return nil, trace.Wrap(err)
}

if resp != nil {
switch {
case resp.ContentLength >= 0:
if resp.Header.Get("Content-Length") != fmt.Sprintf("%d", resp.ContentLength) {
return nil, trace.CompareFailed("Content-Length does not match Content-Length header")
}

default:
if resp.Header.Get("Content-Length") != "" {
return nil, trace.CompareFailed("Content-Length does not match Content-Length header")
}
}
}

return resp, nil
}

func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, response *JSONRPCResponse) mcp.JSONRPCMessage {
// Replace server version.
response.Result = bytes.ReplaceAll(response.Result, []byte("11.22.33"), []byte("111.222.333"))
t.countMCPResponse.Add(1)
return response
}
Expand Down
Loading