Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions integration/appaccess/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,9 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
serverConn, err := pack.tc.DialMCPServer(context.Background(), appName)
require.NoError(t, err)

ctx := t.Context()
stdioClient := mcptest.NewStdioClientFromConn(t, serverConn)

_, err = mcptest.InitializeClient(ctx, stdioClient)
require.NoError(t, err)

mcptest.MustCallServerTool(t, ctx, stdioClient)
mcptest.MustInitializeClient(t, stdioClient)
mcptest.MustCallServerTool(t, stdioClient)
}

func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
Expand Down Expand Up @@ -136,7 +132,6 @@ func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
defer client.Close()

// Initialize client and call a tool.
_, err = mcptest.InitializeClient(ctx, client)
require.NoError(t, err)
mcptest.MustCallServerTool(t, ctx, client)
mcptest.MustInitializeClient(t, client)
mcptest.MustCallServerTool(t, client)
}
5 changes: 2 additions & 3 deletions lib/srv/mcp/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,8 @@ func Test_handleStreamableHTTP(t *testing.T) {
getEventCode := func(e apievents.AuditEvent) string {
return e.GetCode()
}
_, err = mcptest.InitializeClient(ctx, client)
require.NoError(t, err)
mcptest.MustCallServerTool(t, ctx, client)
mcptest.MustInitializeClient(t, client)
mcptest.MustCallServerTool(t, client)
require.EventuallyWithT(t, func(t *assert.CollectT) {
require.ElementsMatch(t, []string{
libevents.MCPSessionStartCode,
Expand Down
5 changes: 2 additions & 3 deletions lib/srv/mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,11 @@ func Test_handleStdioToSSE(t *testing.T) {
}, time.Second*5, time.Millisecond*100, "expect session start")
require.NotEmpty(t, startEvent.McpSessionId)

resp, err := mcptest.InitializeClient(ctx, stdioClient)
require.NoError(t, err)
resp := mcptest.MustInitializeClient(t, stdioClient)
require.Equal(t, "test-server", resp.ServerInfo.Name)

// Make a tools call.
mcptest.MustCallServerTool(t, ctx, stdioClient)
mcptest.MustCallServerTool(t, stdioClient)

// Now close the client.
stdioClient.Close()
Expand Down
11 changes: 9 additions & 2 deletions lib/utils/mcptest/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,19 @@ func InitializeClient(ctx context.Context, client *mcpclient.Client) (*mcp.Initi
return resp, trace.Wrap(err)
}

func MustInitializeClient(t *testing.T, client *mcpclient.Client) *mcp.InitializeResult {
t.Helper()
result, err := InitializeClient(t.Context(), client)
require.NoError(t, err)
return result
}

// MustCallServerTool calls the "hello-server" tool and verifies the result.
func MustCallServerTool(t *testing.T, ctx context.Context, client *mcpclient.Client) {
func MustCallServerTool(t *testing.T, client *mcpclient.Client) {
t.Helper()
callToolRequest := mcp.CallToolRequest{}
callToolRequest.Params.Name = "hello-server"
callToolResult, err := client.CallTool(ctx, callToolRequest)
callToolResult, err := client.CallTool(t.Context(), callToolRequest)
require.NoError(t, err)
require.NotNil(t, callToolResult)
require.Equal(t, []mcp.Content{
Expand Down
106 changes: 106 additions & 0 deletions lib/utils/mcputils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ import (
"context"
"encoding/json"
"io"
"log/slog"
"mime"
"net/http"

"github.com/gravitational/trace"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -141,3 +144,106 @@ func (r *httpSSEResponseReplacer) Read(p []byte) (int, error) {
r.buf = e.marshal()
return r.Read(p)
}

// HTTPReaderWriter implements MessageWriter and TransportReader for
// streamable HTTP transport.
type HTTPReaderWriter struct {
targetClient *mcpclienttransport.StreamableHTTP
messagesToRead chan string
}

// NewHTTPReaderWriter creates a new HTTPReaderWriter that implements
// MessageWriter and TransportReader that connects to provided serverURL in
// streamable HTTP transport.
func NewHTTPReaderWriter(
ctx context.Context,
serverURL string,
opts ...mcpclienttransport.StreamableHTTPCOption,
) (*HTTPReaderWriter, error) {
// Use a real client transport from mcp-go to avoid writing custom logic.
targetClient, err := mcpclienttransport.NewStreamableHTTP(serverURL, opts...)
if err != nil {
return nil, trace.Wrap(err)
}

h := &HTTPReaderWriter{
targetClient: targetClient,
// Normally only one message at a time. Use a small buffer just in case.
messagesToRead: make(chan string, 10),
}

// Notification will only be received if mcpclienttransport.WithContinuousListening
// is set and the listen (GET) request is successful.
h.targetClient.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
if err := h.sendMessageToRead(notification); err != nil {
// Error should never happen. Log a warning just in case.
slog.WarnContext(ctx, "failed to marshal msg", "error", err)
}
})
if err := h.targetClient.Start(ctx); err != nil {
return nil, trace.Wrap(err)
}
return h, nil
}

func (h *HTTPReaderWriter) sendMessageToRead(msg any) error {
data, err := json.Marshal(msg)
if err != nil {
return trace.Wrap(err)
}
h.messagesToRead <- string(data)
return nil
}

// WriteMessage sends out a HTTP request to target. WriteMessage implements
// MessageWriter.
func (h *HTTPReaderWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error {
switch v := msg.(type) {
case *JSONRPCRequest:
resp, err := h.targetClient.SendRequest(ctx, mcpclienttransport.JSONRPCRequest{
JSONRPC: v.JSONRPC,
ID: v.ID,
Method: string(v.Method),
Params: v.Params,
})
if err != nil {
return trace.Wrap(err)
}
return trace.Wrap(h.sendMessageToRead(resp))

case *JSONRPCNotification:
return trace.Wrap(h.targetClient.SendNotification(ctx, mcp.JSONRPCNotification{
JSONRPC: v.JSONRPC,
Notification: mcp.Notification{
Method: string(v.Method),
Params: mcp.NotificationParams{
AdditionalFields: v.Params,
},
},
}))

default:
return trace.BadParameter("unrecognized message type: %T", msg)
}
}

// Type implements TransportReader.
func (h *HTTPReaderWriter) Type() string {
return types.MCPTransportHTTP
}

// ReadMessage returns responses and notifications received from the target.
// ReadMessage implements TransportReader.
func (h *HTTPReaderWriter) ReadMessage(ctx context.Context) (string, error) {
select {
case <-ctx.Done():
return "", io.EOF
case msg := <-h.messagesToRead:
return msg, nil
}
}

// Close implements TransportReader.
func (h *HTTPReaderWriter) Close() error {
return trace.Wrap(h.targetClient.Close())
}
73 changes: 70 additions & 3 deletions lib/utils/mcputils/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ package mcputils

import (
"context"
"io"
"log/slog"
"maps"
"net/http"
"sync"
"sync/atomic"
"testing"
"testing/synctest"
"time"

"github.com/gravitational/trace"
mcpclient "github.com/mark3labs/mcp-go/client"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

listenerutils "github.com/gravitational/teleport/lib/utils/listener"
Expand Down Expand Up @@ -73,9 +77,8 @@ func TestReplaceHTTPResponse(t *testing.T) {
require.NoError(t, client.Start(ctx))

// Initialize client and call a tool.
_, err = mcptest.InitializeClient(ctx, client)
require.NoError(t, err)
mcptest.MustCallServerTool(t, ctx, client)
mcptest.MustInitializeClient(t, client)
mcptest.MustCallServerTool(t, client)
require.Equal(t, uint32(2), httpClientTransport.countMCPResponse.Load())

// Send notifications from server. Notifications will be sent through SSE.
Expand Down Expand Up @@ -144,3 +147,67 @@ func (t *testReplaceHTTPResponseTransport) ProcessResponse(_ context.Context, re
func (t *testReplaceHTTPResponseTransport) ProcessNotification(_ context.Context, notification *JSONRPCNotification) mcp.JSONRPCMessage {
return notification
}

func TestHTTPReaderWriter(t *testing.T) {
t.Parallel()
ctx := t.Context()

// Set up an MCP server.
mcpServer := mcptest.NewServer()
httpServer := mcpserver.NewTestStreamableHTTPServer(mcpServer)
t.Cleanup(httpServer.Close)

// Create a proxy that converts from stdio to HTTP.
clientStdin, writeToClient := io.Pipe()
readFromClient, clientStdout := io.Pipe()
t.Cleanup(func() {
assert.NoError(t, trace.NewAggregate(
clientStdin.Close(), writeToClient.Close(),
readFromClient.Close(), clientStdout.Close(),
))
})

serverReaderWriter, err := NewHTTPReaderWriter(ctx, httpServer.URL, mcpclienttransport.WithContinuousListening())
require.NoError(t, err)
defer serverReaderWriter.Close() // Send DELETE before server is shutdown

clientTransportReader := NewStdioReader(readFromClient)
clientWriter := NewStdioMessageWriter(writeToClient)
proxyReaderWriter(t, clientTransportReader, clientWriter, serverReaderWriter, serverReaderWriter)

// Make a "high-level" stdio MCP client and test the proxy.
notificationsChan := make(chan mcp.JSONRPCNotification, 1)
stdioClient := mcptest.NewStdioClient(t, clientStdin, clientStdout)
stdioClient.OnNotification(func(notification mcp.JSONRPCNotification) {
notificationsChan <- notification
})
mcptest.MustInitializeClient(t, stdioClient)
mcptest.MustCallServerTool(t, stdioClient)

// Test listening notifications from server.
mcpServer.SendNotificationToAllClients("notifications/test", nil)
select {
case notification := <-notificationsChan:
require.NotNil(t, notification)
require.Equal(t, "notifications/test", notification.Notification.Method)
case <-time.After(time.Second):
require.Fail(t, "timeout waiting for notification")
}
}

func proxyReaderWriter(
t *testing.T,
clientTransportReader TransportReader,
clientWriter MessageWriter,
serverTransportReader TransportReader,
serverWriter MessageWriter,
) {
t.Helper()

clientMessageReader, err := NewForwardMessageReader(slog.Default(), clientTransportReader, serverWriter)
require.NoError(t, err)
serverMessageReader, err := NewForwardMessageReader(slog.Default(), serverTransportReader, clientWriter)
require.NoError(t, err)
go clientMessageReader.Run(t.Context())
go serverMessageReader.Run(t.Context())
}
19 changes: 19 additions & 0 deletions lib/utils/mcputils/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,22 @@ func ReadOneResponse(ctx context.Context, reader TransportReader) (*JSONRPCRespo

return unmarshalResponse(rawMessage)
}

// NewForwardMessageReader creates a MessageReader that simply forwards every
// message read from the provided reader to the provided writer.
func NewForwardMessageReader(logger *slog.Logger, reader TransportReader, writer MessageWriter) (*MessageReader, error) {
return NewMessageReader(MessageReaderConfig{
Logger: logger,
Transport: reader,
OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error {
return trace.Wrap(writer.WriteMessage(ctx, notification))
},
OnRequest: func(ctx context.Context, request *JSONRPCRequest) error {
return trace.Wrap(writer.WriteMessage(ctx, request))
},
OnResponse: func(ctx context.Context, response *JSONRPCResponse) error {
return trace.Wrap(writer.WriteMessage(ctx, response))
},
OnParseError: LogAndIgnoreParseError(logger),
})
}
Loading
Loading