diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go index cf513e33acf60..45bca027d42a8 100644 --- a/integration/appaccess/mcp_test.go +++ b/integration/appaccess/mcp_test.go @@ -85,13 +85,9 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) { serverConn, err := pack.tc.DialMCPServer(context.Background(), appName) require.NoError(t, err) - ctx := t.Context() stdioClient := mcptest.NewStdioClientFromConn(t, serverConn) - - _, err = mcptest.InitializeClient(ctx, stdioClient) - require.NoError(t, err) - - mcptest.MustCallServerTool(t, ctx, stdioClient) + mcptest.MustInitializeClient(t, stdioClient) + mcptest.MustCallServerTool(t, stdioClient) } func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) { @@ -136,7 +132,6 @@ func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) { defer client.Close() // Initialize client and call a tool. - _, err = mcptest.InitializeClient(ctx, client) - require.NoError(t, err) - mcptest.MustCallServerTool(t, ctx, client) + mcptest.MustInitializeClient(t, client) + mcptest.MustCallServerTool(t, client) } diff --git a/lib/srv/mcp/http_test.go b/lib/srv/mcp/http_test.go index 89455e4325246..bd16a1ba938d3 100644 --- a/lib/srv/mcp/http_test.go +++ b/lib/srv/mcp/http_test.go @@ -125,9 +125,8 @@ func Test_handleStreamableHTTP(t *testing.T) { getEventCode := func(e apievents.AuditEvent) string { return e.GetCode() } - _, err = mcptest.InitializeClient(ctx, client) - require.NoError(t, err) - mcptest.MustCallServerTool(t, ctx, client) + mcptest.MustInitializeClient(t, client) + mcptest.MustCallServerTool(t, client) require.EventuallyWithT(t, func(t *assert.CollectT) { require.ElementsMatch(t, []string{ libevents.MCPSessionStartCode, diff --git a/lib/srv/mcp/sse_test.go b/lib/srv/mcp/sse_test.go index 8ec15e972f685..da25590562c67 100644 --- a/lib/srv/mcp/sse_test.go +++ b/lib/srv/mcp/sse_test.go @@ -75,12 +75,11 @@ func Test_handleStdioToSSE(t *testing.T) { }, time.Second*5, time.Millisecond*100, "expect session start") require.NotEmpty(t, startEvent.McpSessionId) - resp, err := mcptest.InitializeClient(ctx, stdioClient) - require.NoError(t, err) + resp := mcptest.MustInitializeClient(t, stdioClient) require.Equal(t, "test-server", resp.ServerInfo.Name) // Make a tools call. - mcptest.MustCallServerTool(t, ctx, stdioClient) + mcptest.MustCallServerTool(t, stdioClient) // Now close the client. stdioClient.Close() diff --git a/lib/utils/mcptest/test.go b/lib/utils/mcptest/test.go index 8c02d6ab2ebca..573b18b6f413d 100644 --- a/lib/utils/mcptest/test.go +++ b/lib/utils/mcptest/test.go @@ -94,12 +94,19 @@ func InitializeClient(ctx context.Context, client *mcpclient.Client) (*mcp.Initi return resp, trace.Wrap(err) } +func MustInitializeClient(t *testing.T, client *mcpclient.Client) *mcp.InitializeResult { + t.Helper() + result, err := InitializeClient(t.Context(), client) + require.NoError(t, err) + return result +} + // MustCallServerTool calls the "hello-server" tool and verifies the result. -func MustCallServerTool(t *testing.T, ctx context.Context, client *mcpclient.Client) { +func MustCallServerTool(t *testing.T, client *mcpclient.Client) { t.Helper() callToolRequest := mcp.CallToolRequest{} callToolRequest.Params.Name = "hello-server" - callToolResult, err := client.CallTool(ctx, callToolRequest) + callToolResult, err := client.CallTool(t.Context(), callToolRequest) require.NoError(t, err) require.NotNil(t, callToolResult) require.Equal(t, []mcp.Content{ diff --git a/lib/utils/mcputils/http.go b/lib/utils/mcputils/http.go index d3aa0703bbb16..d3e1f414ebd50 100644 --- a/lib/utils/mcputils/http.go +++ b/lib/utils/mcputils/http.go @@ -23,13 +23,16 @@ import ( "context" "encoding/json" "io" + "log/slog" "mime" "net/http" "github.com/gravitational/trace" + mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/utils" ) @@ -141,3 +144,106 @@ func (r *httpSSEResponseReplacer) Read(p []byte) (int, error) { r.buf = e.marshal() return r.Read(p) } + +// HTTPReaderWriter implements MessageWriter and TransportReader for +// streamable HTTP transport. +type HTTPReaderWriter struct { + targetClient *mcpclienttransport.StreamableHTTP + messagesToRead chan string +} + +// NewHTTPReaderWriter creates a new HTTPReaderWriter that implements +// MessageWriter and TransportReader that connects to provided serverURL in +// streamable HTTP transport. +func NewHTTPReaderWriter( + ctx context.Context, + serverURL string, + opts ...mcpclienttransport.StreamableHTTPCOption, +) (*HTTPReaderWriter, error) { + // Use a real client transport from mcp-go to avoid writing custom logic. + targetClient, err := mcpclienttransport.NewStreamableHTTP(serverURL, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + + h := &HTTPReaderWriter{ + targetClient: targetClient, + // Normally only one message at a time. Use a small buffer just in case. + messagesToRead: make(chan string, 10), + } + + // Notification will only be received if mcpclienttransport.WithContinuousListening + // is set and the listen (GET) request is successful. + h.targetClient.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + if err := h.sendMessageToRead(notification); err != nil { + // Error should never happen. Log a warning just in case. + slog.WarnContext(ctx, "failed to marshal msg", "error", err) + } + }) + if err := h.targetClient.Start(ctx); err != nil { + return nil, trace.Wrap(err) + } + return h, nil +} + +func (h *HTTPReaderWriter) sendMessageToRead(msg any) error { + data, err := json.Marshal(msg) + if err != nil { + return trace.Wrap(err) + } + h.messagesToRead <- string(data) + return nil +} + +// WriteMessage sends out a HTTP request to target. WriteMessage implements +// MessageWriter. +func (h *HTTPReaderWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error { + switch v := msg.(type) { + case *JSONRPCRequest: + resp, err := h.targetClient.SendRequest(ctx, mcpclienttransport.JSONRPCRequest{ + JSONRPC: v.JSONRPC, + ID: v.ID, + Method: string(v.Method), + Params: v.Params, + }) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(h.sendMessageToRead(resp)) + + case *JSONRPCNotification: + return trace.Wrap(h.targetClient.SendNotification(ctx, mcp.JSONRPCNotification{ + JSONRPC: v.JSONRPC, + Notification: mcp.Notification{ + Method: string(v.Method), + Params: mcp.NotificationParams{ + AdditionalFields: v.Params, + }, + }, + })) + + default: + return trace.BadParameter("unrecognized message type: %T", msg) + } +} + +// Type implements TransportReader. +func (h *HTTPReaderWriter) Type() string { + return types.MCPTransportHTTP +} + +// ReadMessage returns responses and notifications received from the target. +// ReadMessage implements TransportReader. +func (h *HTTPReaderWriter) ReadMessage(ctx context.Context) (string, error) { + select { + case <-ctx.Done(): + return "", io.EOF + case msg := <-h.messagesToRead: + return msg, nil + } +} + +// Close implements TransportReader. +func (h *HTTPReaderWriter) Close() error { + return trace.Wrap(h.targetClient.Close()) +} diff --git a/lib/utils/mcputils/http_test.go b/lib/utils/mcputils/http_test.go index a5fa3ddd5feb2..ee06617f0f815 100644 --- a/lib/utils/mcputils/http_test.go +++ b/lib/utils/mcputils/http_test.go @@ -20,18 +20,22 @@ package mcputils import ( "context" + "io" + "log/slog" "maps" "net/http" "sync" "sync/atomic" "testing" "testing/synctest" + "time" "github.com/gravitational/trace" mcpclient "github.com/mark3labs/mcp-go/client" mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" listenerutils "github.com/gravitational/teleport/lib/utils/listener" @@ -73,9 +77,8 @@ func TestReplaceHTTPResponse(t *testing.T) { require.NoError(t, client.Start(ctx)) // Initialize client and call a tool. - _, err = mcptest.InitializeClient(ctx, client) - require.NoError(t, err) - mcptest.MustCallServerTool(t, ctx, client) + mcptest.MustInitializeClient(t, client) + mcptest.MustCallServerTool(t, client) require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load()) // Send notifications from server. Notifications will be sent through SSE. @@ -144,3 +147,67 @@ func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, re func (t *testReplaceHTTPResponseTransport) ProcessNotification(_ context.Context, notification *JSONRPCNotification) mcp.JSONRPCMessage { return notification } + +func TestHTTPReaderWriter(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Set up an MCP server. + mcpServer := mcptest.NewServer() + httpServer := mcpserver.NewTestStreamableHTTPServer(mcpServer) + t.Cleanup(httpServer.Close) + + // Create a proxy that converts from stdio to HTTP. + clientStdin, writeToClient := io.Pipe() + readFromClient, clientStdout := io.Pipe() + t.Cleanup(func() { + assert.NoError(t, trace.NewAggregate( + clientStdin.Close(), writeToClient.Close(), + readFromClient.Close(), clientStdout.Close(), + )) + }) + + serverReaderWriter, err := NewHTTPReaderWriter(ctx, httpServer.URL, mcpclienttransport.WithContinuousListening()) + require.NoError(t, err) + defer serverReaderWriter.Close() // Send DELETE before server is shutdown + + clientTransportReader := NewStdioReader(readFromClient) + clientWriter := NewStdioMessageWriter(writeToClient) + proxyReaderWriter(t, clientTransportReader, clientWriter, serverReaderWriter, serverReaderWriter) + + // Make a "high-level" stdio MCP client and test the proxy. + notificationsChan := make(chan mcp.JSONRPCNotification, 1) + stdioClient := mcptest.NewStdioClient(t, clientStdin, clientStdout) + stdioClient.OnNotification(func(notification mcp.JSONRPCNotification) { + notificationsChan <- notification + }) + mcptest.MustInitializeClient(t, stdioClient) + mcptest.MustCallServerTool(t, stdioClient) + + // Test listening notifications from server. + mcpServer.SendNotificationToAllClients("notifications/test", nil) + select { + case notification := <-notificationsChan: + require.NotNil(t, notification) + require.Equal(t, "notifications/test", notification.Notification.Method) + case <-time.After(time.Second): + require.Fail(t, "timeout waiting for notification") + } +} + +func proxyReaderWriter( + t *testing.T, + clientTransportReader TransportReader, + clientWriter MessageWriter, + serverTransportReader TransportReader, + serverWriter MessageWriter, +) { + t.Helper() + + clientMessageReader, err := NewForwardMessageReader(slog.Default(), clientTransportReader, serverWriter) + require.NoError(t, err) + serverMessageReader, err := NewForwardMessageReader(slog.Default(), serverTransportReader, clientWriter) + require.NoError(t, err) + go clientMessageReader.Run(t.Context()) + go serverMessageReader.Run(t.Context()) +} diff --git a/lib/utils/mcputils/reader.go b/lib/utils/mcputils/reader.go index 6099c901c9f0e..57074ddf53604 100644 --- a/lib/utils/mcputils/reader.go +++ b/lib/utils/mcputils/reader.go @@ -226,3 +226,22 @@ func ReadOneResponse(ctx context.Context, reader TransportReader) (*JSONRPCRespo return unmarshalResponse(rawMessage) } + +// NewForwardMessageReader creates a MessageReader that simply forwards every +// message read from the provided reader to the provided writer. +func NewForwardMessageReader(logger *slog.Logger, reader TransportReader, writer MessageWriter) (*MessageReader, error) { + return NewMessageReader(MessageReaderConfig{ + Logger: logger, + Transport: reader, + OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error { + return trace.Wrap(writer.WriteMessage(ctx, notification)) + }, + OnRequest: func(ctx context.Context, request *JSONRPCRequest) error { + return trace.Wrap(writer.WriteMessage(ctx, request)) + }, + OnResponse: func(ctx context.Context, response *JSONRPCResponse) error { + return trace.Wrap(writer.WriteMessage(ctx, response)) + }, + OnParseError: LogAndIgnoreParseError(logger), + }) +} diff --git a/lib/utils/mcputils/stdio_test.go b/lib/utils/mcputils/stdio_test.go index 113655e605adc..7f53051cd8cab 100644 --- a/lib/utils/mcputils/stdio_test.go +++ b/lib/utils/mcputils/stdio_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,6 +36,32 @@ import ( "github.com/gravitational/teleport/lib/utils/mcptest" ) +type countingMessageWriter struct { + m MessageWriter + + notifications atomic.Int32 + requests atomic.Int32 + responses atomic.Int32 +} + +func newCountingMessageWriter(m MessageWriter) *countingMessageWriter { + return &countingMessageWriter{ + m: m, + } +} + +func (c *countingMessageWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error { + switch msg.(type) { + case *JSONRPCRequest: + c.requests.Add(1) + case *JSONRPCResponse: + c.responses.Add(1) + case *JSONRPCNotification: + c.notifications.Add(1) + } + return trace.Wrap(c.m.WriteMessage(ctx, msg)) +} + // TestStdioHelpers tests MessageReader and StdioMessageWriter by // implementing a passthrough reverse proxy. // @@ -45,12 +72,6 @@ func TestStdioHelpers(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - // Set up some counters for verification. - var readClientNotifications int32 - var readClientRequests int32 - var readServerNotifications int32 - var readServerResponses int32 - // Pipes for hooking things up. clientStdin, writeToClient := io.Pipe() readFromClient, clientStdout := io.Pipe() @@ -66,21 +87,10 @@ func TestStdioHelpers(t *testing.T) { }) // Make "low-level" message readers and writers for MITM proxy. - clientMessageWriter := NewStdioMessageWriter(writeToClient) - serverMessageWriter := NewStdioMessageWriter(writeToServer) - - clientMessageReader, err := NewMessageReader(MessageReaderConfig{ - Transport: NewStdioReader(readFromClient), - OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error { - atomic.AddInt32(&readClientNotifications, 1) - return trace.Wrap(serverMessageWriter.WriteMessage(ctx, notification)) - }, - OnRequest: func(ctx context.Context, request *JSONRPCRequest) error { - atomic.AddInt32(&readClientRequests, 1) - return trace.Wrap(serverMessageWriter.WriteMessage(ctx, request)) - }, - OnParseError: ReplyParseError(clientMessageWriter), - }) + clientMessageWriter := newCountingMessageWriter(NewStdioMessageWriter(writeToClient)) + serverMessageWriter := newCountingMessageWriter(NewStdioMessageWriter(writeToServer)) + + clientMessageReader, err := NewForwardMessageReader(slog.Default(), NewStdioReader(readFromClient), serverMessageWriter) require.NoError(t, err) clientMessageReaderClosed := make(chan struct{}) go func() { @@ -88,18 +98,7 @@ func TestStdioHelpers(t *testing.T) { close(clientMessageReaderClosed) }() - serverMessageReader, err := NewMessageReader(MessageReaderConfig{ - Transport: NewStdioReader(readFromServer), - OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error { - atomic.AddInt32(&readServerNotifications, 1) - return trace.Wrap(clientMessageWriter.WriteMessage(ctx, notification)) - }, - OnResponse: func(ctx context.Context, response *JSONRPCResponse) error { - atomic.AddInt32(&readServerResponses, 1) - return trace.Wrap(clientMessageWriter.WriteMessage(ctx, response)) - }, - OnParseError: LogAndIgnoreParseError(slog.Default()), - }) + serverMessageReader, err := NewForwardMessageReader(slog.Default(), NewStdioReader(readFromServer), clientMessageWriter) require.NoError(t, err) serverMessageReaderClosed := make(chan struct{}) serverMessageReaderCtx, serverMessageReaderCtxCancel := context.WithCancel(ctx) @@ -118,12 +117,11 @@ func TestStdioHelpers(t *testing.T) { // Test things out. t.Run("client initialize", func(t *testing.T) { - _, err := mcptest.InitializeClient(ctx, stdioClient) - require.NoError(t, err) + mcptest.MustInitializeClient(t, stdioClient) }) t.Run("client call tool", func(t *testing.T) { - mcptest.MustCallServerTool(t, ctx, stdioClient) + mcptest.MustCallServerTool(t, stdioClient) }) t.Run("reader closed by closing stdin", func(t *testing.T) { @@ -150,9 +148,9 @@ func TestStdioHelpers(t *testing.T) { // client -> server: notifications/initialized // client -> server: tools\call request // server -> client: tools\call response - assert.Equal(t, int32(1), atomic.LoadInt32(&readClientNotifications)) - assert.Equal(t, int32(2), atomic.LoadInt32(&readClientRequests)) - assert.Equal(t, int32(0), atomic.LoadInt32(&readServerNotifications)) - assert.Equal(t, int32(2), atomic.LoadInt32(&readServerResponses)) + assert.Equal(t, int32(1), serverMessageWriter.notifications.Load()) + assert.Equal(t, int32(2), serverMessageWriter.requests.Load()) + assert.Equal(t, int32(0), clientMessageWriter.notifications.Load()) + assert.Equal(t, int32(2), clientMessageWriter.responses.Load()) }) }