Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions integration/appaccess/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@
package appaccess

import (
"net"
"net/http"
"testing"
"time"

"github.com/gravitational/trace"
mcpclient "github.com/mark3labs/mcp-go/client"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/client"
clientmcp "github.com/gravitational/teleport/lib/client/mcp"
libmcp "github.com/gravitational/teleport/lib/srv/mcp"
"github.com/gravitational/teleport/lib/utils/mcptest"
)
Expand All @@ -52,6 +56,10 @@ func testMCP(pack *Pack, t *testing.T) {
t.Run("proxy streamable HTTP success", func(t *testing.T) {
testMCPProxyStreamableHTTP(t, pack, "test-http")
})

t.Run("stdio to streamable HTTP success", func(t *testing.T) {
testMCPStdioToStreamableHTTP(t, pack, "test-http")
})
}

func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
Expand Down Expand Up @@ -111,3 +119,39 @@ func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
mcptest.MustInitializeClient(t, client)
mcptest.MustCallServerTool(t, client)
}

func testMCPStdioToStreamableHTTP(t *testing.T, pack *Pack, appName string) {
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
assert.NoError(t, trace.NewAggregate(clientConn.Close(), serverConn.Close()))
})

// Use clientmcp.ProxyStdioConn to handle transport conversion.
// Use special dialer (on pack.tc) to dial Proxy.
dialer := client.NewMCPServerDialer(pack.tc, appName)
proxyErrChan := make(chan error, 1)
go func() {
err := clientmcp.ProxyStdioConn(
t.Context(),
clientmcp.ProxyStdioConnConfig{
ClientStdio: serverConn,
GetApp: dialer.GetApp,
DialServer: dialer.DialALPN,
},
)
proxyErrChan <- err
}()

stdioClient := mcptest.NewStdioClientFromConn(t, clientConn)
mcptest.MustInitializeClient(t, stdioClient)
mcptest.MustCallServerTool(t, stdioClient)

// Shut done client and wait for proxy func to finish.
require.NoError(t, stdioClient.Close())
select {
case proxyErr := <-proxyErrChan:
require.NoError(t, proxyErr)
case <-time.After(time.Second * 5):
require.Fail(t, "timed out waiting for proxy to complete")
}
}
142 changes: 114 additions & 28 deletions lib/client/mcp/reconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,35 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"sync"

"github.com/gravitational/trace"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils/mcputils"
)

// ProxyStdioConnWithAutoReconnectConfig is the config for ProxyStdioConnWithAutoReconnect.
type ProxyStdioConnWithAutoReconnectConfig struct {
// ProxyStdioConnConfig is the config for ProxyStdioConn.
type ProxyStdioConnConfig struct {
// ClientStdio is the client stdin and stdout.
ClientStdio io.ReadWriteCloser
// MakeReconnectUserMessage generates a user-friendly message based on the
// error.
MakeReconnectUserMessage func(error) string
// DialServer makes a new connection to the remote MCP server.
DialServer func(context.Context) (io.ReadWriteCloser, error)
DialServer func(context.Context) (net.Conn, error)
// GetApp returns the MCP application.
GetApp func(context.Context) (types.Application, error)
// AutoReconnect attempts to re-establish new MCP sessions with the remote
// server when encounter connection issues.
AutoReconnect bool

// Logger is the slog logger.
Logger *slog.Logger

Expand All @@ -51,15 +62,20 @@ type ProxyStdioConnWithAutoReconnectConfig struct {
}

// CheckAndSetDefaults validates the config and sets default values.
func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error {
func (cfg *ProxyStdioConnConfig) CheckAndSetDefaults() error {
if cfg.ClientStdio == nil {
return trace.BadParameter("missing ClientStdio")
}
if cfg.GetApp == nil {
return trace.BadParameter("missing GetApp")
}
if cfg.DialServer == nil {
return trace.BadParameter("missing DialServer")
}
if cfg.MakeReconnectUserMessage == nil {
return trace.BadParameter("missing MakeReconnectUserMessage")
cfg.MakeReconnectUserMessage = func(err error) string {
return err.Error()
}
}
if cfg.Logger == nil {
cfg.Logger = slog.With(
Expand All @@ -73,9 +89,10 @@ func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error {
return nil
}

// ProxyStdioConnWithAutoReconnect serves a stdio client with a consistent
// connection and reconnects to the remote server upon issues.
func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) error {
// ProxyStdioConn serves a stdio client and handles transport conversion to
// the remote MCP servers. When AutoConnect is set, it also reconnects to the
// remote server with new MCP sessions upon connection issues.
func ProxyStdioConn(ctx context.Context, cfg ProxyStdioConnConfig) error {
if err := cfg.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
Expand All @@ -86,6 +103,7 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
if err != nil {
return trace.Wrap(err)
}
defer serverConn.Close()

clientRequestReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
Transport: mcputils.NewStdioReader(cfg.ClientStdio),
Expand All @@ -97,6 +115,9 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
// initialize notification is sent from client after receiving the
// initialize response so it's unlikely to hit here.
if writeError := serverConn.WriteMessage(ctx, notification); writeError != nil {
if serverConn.shouldExitOnWriteError() {
return trace.Wrap(writeError)
}
cfg.Logger.WarnContext(ctx, "failed to write notification to server. Notification is dropped.", "error", writeError)
userMessage := cfg.MakeReconnectUserMessage(writeError)
errNotification := mcp.Notification{
Expand All @@ -113,6 +134,9 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
},
OnRequest: func(ctx context.Context, request *mcputils.JSONRPCRequest) error {
if writeError := serverConn.WriteMessage(ctx, request); writeError != nil {
if serverConn.shouldExitOnWriteError() {
return trace.Wrap(writeError)
}
cfg.Logger.WarnContext(ctx, "failed to write request to server", "error", writeError)
userMessage := cfg.MakeReconnectUserMessage(writeError)
errResp := mcp.NewJSONRPCError(request.ID, mcp.INTERNAL_ERROR, userMessage, writeError)
Expand All @@ -130,22 +154,22 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
}

type serverConnWithAutoReconnect struct {
ProxyStdioConnWithAutoReconnectConfig
ProxyStdioConnConfig
parentCtx context.Context

mu sync.Mutex
serverRequestWriter mcputils.MessageWriter
replayOnNextConn bool
firstConnectionDone bool
initRequest *mcputils.JSONRPCRequest
initResponse *mcp.InitializeResult
initNotification *mcputils.JSONRPCNotification
closeServerConn func()
}

func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) (*serverConnWithAutoReconnect, error) {
func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnConfig) (*serverConnWithAutoReconnect, error) {
return &serverConnWithAutoReconnect{
ProxyStdioConnWithAutoReconnectConfig: cfg,
parentCtx: parentCtx,
ProxyStdioConnConfig: cfg,
parentCtx: parentCtx,
}, nil
}

Expand All @@ -169,23 +193,80 @@ func (r *serverConnWithAutoReconnect) WriteMessage(ctx context.Context, msg mcp.
return trace.Wrap(writer.WriteMessage(ctx, msg))
}

func (r *serverConnWithAutoReconnect) makeServerTransport(ctx context.Context) (mcputils.TransportReader, mcputils.MessageWriter, error) {
r.Logger.InfoContext(ctx, "Making new transport to server")
app, err := r.GetApp(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
}

switch types.GetMCPServerTransportType(app.GetURI()) {
case types.MCPTransportHTTP:
transport, err := defaults.Transport()
if err != nil {
return nil, nil, trace.Wrap(err)
}
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return r.DialServer(ctx)
}
httpReaderWriter, err := mcputils.NewHTTPReaderWriter(
r.parentCtx,
"http://localhost", // does not matter with the custom transport.
mcpclienttransport.WithHTTPBasicClient(&http.Client{
Transport: transport,
}),
mcpclienttransport.WithContinuousListening(),
)
if err != nil {
return nil, nil, trace.Wrap(err)
}
return httpReaderWriter, httpReaderWriter, nil

default:
serverConn, err := r.DialServer(ctx)
if err != nil {
return nil, nil, trace.Wrap(err)
}
return mcputils.NewStdioReader(serverConn),
mcputils.NewStdioMessageWriter(serverConn),
nil
}
}

func (r *serverConnWithAutoReconnect) canRetryLocked() bool {
// When auto-reconnect is on, always retry without exiting.
// When auto-reconnect is off, see if we have made the first connection yet.
// If not, we could retry until the first connection is established.
return r.AutoReconnect || !r.firstConnectionDone
}

func (r *serverConnWithAutoReconnect) shouldExitOnWriteError() bool {
r.mu.Lock()
defer r.mu.Unlock()

// just exit if we cannot retry.
return !r.canRetryLocked()
}

func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.Context) (mcputils.MessageWriter, error) {
if r.serverRequestWriter != nil {
return r.serverRequestWriter, nil
}

r.Logger.InfoContext(ctx, "Connecting to server")
serverConn, err := r.DialServer(ctx)
if !r.canRetryLocked() {
// We shouldn't hit here as the proxy should have been ended.
// Double-check just in case.
return nil, trace.Errorf("mcp session finished")
}

serverTransportReader, serverWriter, err := r.makeServerTransport(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

serverStdioReader := mcputils.NewStdioReader(serverConn)
serverWriter := mcputils.NewStdioMessageWriter(serverConn)
if r.replayOnNextConn {
if r.firstConnectionDone {
// Replay initialize sequence. Any error here is likely permanent.
if err := r.replayInitializeLocked(ctx, serverStdioReader, serverWriter); err != nil {
serverConn.Close()
if err := r.replayInitializeLocked(ctx, serverTransportReader, serverWriter); err != nil {
serverTransportReader.Close()
return nil, trace.Wrap(err)
}
r.serverRequestWriter = serverWriter
Expand All @@ -197,18 +278,23 @@ func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.C
}),
serverWriter,
)
r.replayOnNextConn = true
r.firstConnectionDone = true
}

// This should never fail as long the correct config is passed in.
serverResponseReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
Transport: serverStdioReader,
// OnClose is called when server connection is dead.
// Teleport Proxy automatically closes the connection when tsh session
// is expired.
Transport: serverTransportReader,
// OnClose is called when server connection is dead or if any handler
// fails. Teleport Proxy automatically closes the connection when tsh
// session is expired.
OnClose: func() {
r.Logger.InfoContext(ctx, "Lost server connection, resetting...")
r.mu.Lock()
if r.canRetryLocked() {
r.Logger.InfoContext(ctx, "Lost server session, resetting...")
} else {
r.Logger.InfoContext(ctx, "Lost server session, closing...")
r.ClientStdio.Close()
}
r.serverRequestWriter = nil
if r.onServerConnClosed != nil {
r.onServerConnClosed()
Expand All @@ -226,7 +312,7 @@ func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.C
},
})
if err != nil {
serverConn.Close()
serverTransportReader.Close()
return nil, trace.Wrap(err)
}

Expand Down
Loading
Loading