From 69f565af2c6e8ab333ded2058cd0210bc6977e78 Mon Sep 17 00:00:00 2001
From: "STeve (Xin) Huang"
Date: Thu, 9 Oct 2025 16:07:22 -0400
Subject: [PATCH 1/2] implement "tsh mcp connect" for streamable HTTP

---
 integration/appaccess/mcp_test.go |  44 ++++++++
 lib/client/mcp/reconnect.go       | 142 +++++++++++++++++++-----
 lib/client/mcp/reconnect_test.go  | 176 +++++++++++++++++++++++++-----
 lib/utils/mcptest/test.go         |  12 +-
 tool/tsh/common/mcp_app.go        |  30 ++---
 5 files changed, 328 insertions(+), 76 deletions(-)

diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go
index e3b492bb8ceff..263c1c1d8d978 100644
--- a/integration/appaccess/mcp_test.go
+++ b/integration/appaccess/mcp_test.go
@@ -19,16 +19,20 @@ package appaccess
 import (
+	"net"
 	"net/http"
 	"testing"
+	"time"

 	"github.com/gravitational/trace"
 	mcpclient "github.com/mark3labs/mcp-go/client"
 	mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
 	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"

 	"github.com/gravitational/teleport/lib/client"
+	clientmcp "github.com/gravitational/teleport/lib/client/mcp"
 	libmcp "github.com/gravitational/teleport/lib/srv/mcp"
 	"github.com/gravitational/teleport/lib/utils/mcptest"
 )
@@ -52,6 +56,10 @@ func testMCP(pack *Pack, t *testing.T) {
 	t.Run("proxy streamable HTTP success", func(t *testing.T) {
 		testMCPProxyStreamableHTTP(t, pack, "test-http")
 	})
+
+	t.Run("stdio to streamable HTTP success", func(t *testing.T) {
+		testMCPStdioToStreamableHTTP(t, pack, "test-http")
+	})
 }

 func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
@@ -111,3 +119,39 @@ func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
 	mcptest.MustInitializeClient(t, client)
 	mcptest.MustCallServerTool(t, client)
 }
+
+func testMCPStdioToStreamableHTTP(t *testing.T, pack *Pack, appName string) {
+	clientConn, serverConn := net.Pipe()
+	t.Cleanup(func() {
+		assert.NoError(t, trace.NewAggregate(clientConn.Close(), serverConn.Close()))
+	})
+
+	// Use clientmcp.ProxyStdioConn to handle transport conversion.
+	// Use special dialer (on pack.tc) to dial Proxy.
+	dialer := client.NewMCPServerDialer(pack.tc, appName)
+	proxyErrChan := make(chan error, 1)
+	go func() {
+		err := clientmcp.ProxyStdioConn(
+			t.Context(),
+			clientmcp.ProxyStdioConnConfig{
+				ClientStdio: serverConn,
+				GetApp:      dialer.GetApp,
+				DialServer:  dialer.DialALPN,
+			},
+		)
+		proxyErrChan <- err
+	}()
+
+	stdioClient := mcptest.NewStdioClientFromConn(t, clientConn)
+	mcptest.MustInitializeClient(t, stdioClient)
+	mcptest.MustCallServerTool(t, stdioClient)
+
+	// Shut down the client and wait for the proxy func to finish.
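+	// Closing the stdio client closes clientConn, which the proxy observes
+	// as EOF on its client transport and exits.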
+	require.NoError(t, stdioClient.Close())
+	select {
+	case proxyErr := <-proxyErrChan:
+		require.NoError(t, proxyErr)
+	case <-time.After(time.Second * 5):
+		require.Fail(t, "timed out waiting for proxy to complete")
+	}
+}
diff --git a/lib/client/mcp/reconnect.go b/lib/client/mcp/reconnect.go
index 27624d0f92a58..5b06c1f0b5ff8 100644
--- a/lib/client/mcp/reconnect.go
+++ b/lib/client/mcp/reconnect.go
@@ -23,24 +23,35 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"net"
+	"net/http"
 	"sync"

 	"github.com/gravitational/trace"
+	mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
 	"github.com/mark3labs/mcp-go/mcp"

 	"github.com/gravitational/teleport"
+	"github.com/gravitational/teleport/api/types"
+	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/utils/mcputils"
 )

-// ProxyStdioConnWithAutoReconnectConfig is the config for ProxyStdioConnWithAutoReconnect.
-type ProxyStdioConnWithAutoReconnectConfig struct {
+// ProxyStdioConnConfig is the config for ProxyStdioConn.
+type ProxyStdioConnConfig struct {
 	// ClientStdio is the client stdin and stdout.
 	ClientStdio io.ReadWriteCloser
 	// MakeReconnectUserMessage generates a user-friendly message based on the
 	// error.
 	MakeReconnectUserMessage func(error) string
 	// DialServer makes a new connection to the remote MCP server.
-	DialServer func(context.Context) (io.ReadWriteCloser, error)
+	DialServer func(context.Context) (net.Conn, error)
+	// GetApp returns the MCP application.
+	GetApp func(context.Context) (types.Application, error)
+	// AutoReconnect attempts to re-establish new MCP sessions with the remote
+	// server when it encounters connection issues.
+	AutoReconnect bool
+
 	// Logger is the slog logger.
 	Logger *slog.Logger
@@ -51,15 +62,20 @@
 }

 // CheckAndSetDefaults validates the config and sets default values.
-func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error {
+func (cfg *ProxyStdioConnConfig) CheckAndSetDefaults() error {
 	if cfg.ClientStdio == nil {
 		return trace.BadParameter("missing ClientStdio")
 	}
+	if cfg.GetApp == nil {
+		return trace.BadParameter("missing GetApp")
+	}
 	if cfg.DialServer == nil {
 		return trace.BadParameter("missing DialServer")
 	}
 	if cfg.MakeReconnectUserMessage == nil {
-		return trace.BadParameter("missing MakeReconnectUserMessage")
+		cfg.MakeReconnectUserMessage = func(err error) string {
+			return err.Error()
+		}
 	}
 	if cfg.Logger == nil {
 		cfg.Logger = slog.With(
@@ -73,9 +89,10 @@ func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error {
 	return nil
 }

-// ProxyStdioConnWithAutoReconnect serves a stdio client with a consistent
-// connection and reconnects to the remote server upon issues.
-func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) error {
+// ProxyStdioConn serves a stdio client and handles transport conversion to
+// the remote MCP server. When AutoReconnect is set, it also reconnects to the
+// remote server with new MCP sessions upon connection issues.
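+//
+// A minimal usage sketch, mirroring how tsh wires this up in mcp_app.go
+// (the dialer and stdio helpers come from other packages):
+//
+//	dialer := client.NewMCPServerDialer(tc, appName)
+//	err := ProxyStdioConn(ctx, ProxyStdioConnConfig{
+//		ClientStdio:   utils.CombinedStdio{},
+//		GetApp:        dialer.GetApp,
+//		DialServer:    dialer.DialALPN,
+//		AutoReconnect: true,
+//	})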
+func ProxyStdioConn(ctx context.Context, cfg ProxyStdioConnConfig) error {
 	if err := cfg.CheckAndSetDefaults(); err != nil {
 		return trace.Wrap(err)
 	}
@@ -86,6 +103,7 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
 	if err != nil {
 		return trace.Wrap(err)
 	}
+	defer serverConn.Close()

 	clientRequestReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
 		Transport: mcputils.NewStdioReader(cfg.ClientStdio),
@@ -97,6 +115,9 @@
 			// initialize notification is sent from client after receiving the
 			// initialize response so it's unlikely to hit here.
 			if writeError := serverConn.WriteMessage(ctx, notification); writeError != nil {
+				if serverConn.shouldExitOnWriteError() {
+					return trace.Wrap(writeError)
+				}
 				cfg.Logger.WarnContext(ctx, "failed to write notification to server. Notification is dropped.", "error", writeError)
 				userMessage := cfg.MakeReconnectUserMessage(writeError)
 				errNotification := mcp.Notification{
@@ -113,6 +134,9 @@
 		},
 		OnRequest: func(ctx context.Context, request *mcputils.JSONRPCRequest) error {
 			if writeError := serverConn.WriteMessage(ctx, request); writeError != nil {
+				if serverConn.shouldExitOnWriteError() {
+					return trace.Wrap(writeError)
+				}
 				cfg.Logger.WarnContext(ctx, "failed to write request to server", "error", writeError)
 				userMessage := cfg.MakeReconnectUserMessage(writeError)
 				errResp := mcp.NewJSONRPCError(request.ID, mcp.INTERNAL_ERROR, userMessage, writeError)
@@ -130,22 +154,22 @@
 }

 type serverConnWithAutoReconnect struct {
-	ProxyStdioConnWithAutoReconnectConfig
+	ProxyStdioConnConfig
 	parentCtx context.Context

 	mu                  sync.Mutex
 	serverRequestWriter mcputils.MessageWriter
-	replayOnNextConn    bool
+	firstConnectionDone bool
 	initRequest         *mcputils.JSONRPCRequest
 	initResponse        *mcp.InitializeResult
 	initNotification    *mcputils.JSONRPCNotification
 	closeServerConn     func()
 }

-func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) (*serverConnWithAutoReconnect, error) {
+func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnConfig) (*serverConnWithAutoReconnect, error) {
 	return &serverConnWithAutoReconnect{
-		ProxyStdioConnWithAutoReconnectConfig: cfg,
-		parentCtx:                             parentCtx,
+		ProxyStdioConnConfig: cfg,
+		parentCtx:            parentCtx,
 	}, nil
 }

@@ -169,23 +193,80 @@ func (r *serverConnWithAutoReconnect) WriteMessage(ctx context.Context, msg mcp.
 	return trace.Wrap(writer.WriteMessage(ctx, msg))
 }

+func (r *serverConnWithAutoReconnect) makeServerTransport(ctx context.Context) (mcputils.TransportReader, mcputils.MessageWriter, error) {
+	r.Logger.InfoContext(ctx, "Making new transport to server")
+	app, err := r.GetApp(ctx)
+	if err != nil {
+		return nil, nil, trace.Wrap(err)
+	}
+
+	switch types.GetMCPServerTransportType(app.GetURI()) {
+	case types.MCPTransportHTTP:
+		transport, err := defaults.Transport()
+		if err != nil {
+			return nil, nil, trace.Wrap(err)
+		}
+		transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+			return r.DialServer(ctx)
+		}
+		httpReaderWriter, err := mcputils.NewHTTPReaderWriter(
+			r.parentCtx,
+			"http://localhost", // does not matter with the custom transport.
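+			// The options below install the custom transport, so every request
+			// is tunneled through DialServer, and enable continuous listening
+			// for server-to-client notifications.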
+			mcpclienttransport.WithHTTPBasicClient(&http.Client{
+				Transport: transport,
+			}),
+			mcpclienttransport.WithContinuousListening(),
+		)
+		if err != nil {
+			return nil, nil, trace.Wrap(err)
+		}
+		return httpReaderWriter, httpReaderWriter, nil
+
+	default:
+		serverConn, err := r.DialServer(ctx)
+		if err != nil {
+			return nil, nil, trace.Wrap(err)
+		}
+		return mcputils.NewStdioReader(serverConn),
+			mcputils.NewStdioMessageWriter(serverConn),
+			nil
+	}
+}
+
+func (r *serverConnWithAutoReconnect) canRetryLocked() bool {
+	// When auto-reconnect is on, always retry without exiting.
+	// When auto-reconnect is off, see if we have made the first connection yet.
+	// If not, we can retry until the first connection is established.
+	return r.AutoReconnect || !r.firstConnectionDone
+}
+
+func (r *serverConnWithAutoReconnect) shouldExitOnWriteError() bool {
+	r.mu.Lock()
+	defer r.mu.Unlock()

+	// Just exit if we cannot retry.
+	return !r.canRetryLocked()
+}
+
 func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.Context) (mcputils.MessageWriter, error) {
 	if r.serverRequestWriter != nil {
 		return r.serverRequestWriter, nil
 	}

-	r.Logger.InfoContext(ctx, "Connecting to server")
-	serverConn, err := r.DialServer(ctx)
+	if !r.canRetryLocked() {
+		// We shouldn't get here as the proxy should have already ended.
+		// Double-check just in case.
+		return nil, trace.Errorf("mcp session finished")
+	}
+
+	serverTransportReader, serverWriter, err := r.makeServerTransport(ctx)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-
-	serverStdioReader := mcputils.NewStdioReader(serverConn)
-	serverWriter := mcputils.NewStdioMessageWriter(serverConn)
-	if r.replayOnNextConn {
+	if r.firstConnectionDone {
 		// Replay initialize sequence. Any error here is likely permanent.
-		if err := r.replayInitializeLocked(ctx, serverStdioReader, serverWriter); err != nil {
-			serverConn.Close()
+		if err := r.replayInitializeLocked(ctx, serverTransportReader, serverWriter); err != nil {
+			serverTransportReader.Close()
 			return nil, trace.Wrap(err)
 		}
 		r.serverRequestWriter = serverWriter
@@ -197,18 +278,23 @@ func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.C
 			}),
 			serverWriter,
 		)
-		r.replayOnNextConn = true
+		r.firstConnectionDone = true
 	}

 	// This should never fail as long the correct config is passed in.
 	serverResponseReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
-		Transport: serverStdioReader,
-		// OnClose is called when server connection is dead.
-		// Teleport Proxy automatically closes the connection when tsh session
-		// is expired.
+		Transport: serverTransportReader,
+		// OnClose is called when the server connection is dead or when any
+		// handler fails. Teleport Proxy automatically closes the connection
+		// when the tsh session expires.
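+		// Depending on AutoReconnect, OnClose below either resets the writer
+		// so that the next message dials a new session, or closes the client
+		// stdio to end the proxy.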
 		OnClose: func() {
-			r.Logger.InfoContext(ctx, "Lost server connection, resetting...")
 			r.mu.Lock()
+			if r.canRetryLocked() {
+				r.Logger.InfoContext(ctx, "Lost server session, resetting...")
+			} else {
+				r.Logger.InfoContext(ctx, "Lost server session, closing...")
+				r.ClientStdio.Close()
+			}
 			r.serverRequestWriter = nil
 			if r.onServerConnClosed != nil {
 				r.onServerConnClosed()
@@ -226,7 +312,7 @@
 		},
 	})
 	if err != nil {
-		serverConn.Close()
+		serverTransportReader.Close()
 		return nil, trace.Wrap(err)
 	}

diff --git a/lib/client/mcp/reconnect_test.go b/lib/client/mcp/reconnect_test.go
index 6e98a8fa235ad..91760c08597b0 100644
--- a/lib/client/mcp/reconnect_test.go
+++ b/lib/client/mcp/reconnect_test.go
@@ -21,25 +21,28 @@ package mcp
 import (
 	"context"
 	"io"
+	"net"
+	"net/http"
 	"sync/atomic"
 	"testing"
 	"time"

-	"github.com/mark3labs/mcp-go/mcp"
 	mcpserver "github.com/mark3labs/mcp-go/server"
 	"github.com/stretchr/testify/require"

+	"github.com/gravitational/teleport/api/types"
+	listenerutils "github.com/gravitational/teleport/lib/utils/listener"
 	"github.com/gravitational/teleport/lib/utils/mcptest"
-	"github.com/gravitational/teleport/lib/utils/uds"
 )

-func TestProxyStdioConnWithAutoReconnect(t *testing.T) {
+func TestProxyStdioConn_autoReconnect(t *testing.T) {
 	ctx := t.Context()
+	app := newAppFromURI(t, "some-mcp", "mcp+stdio://")

 	var serverStdioSource atomic.Value
 	prepServerWithVersion := func(version string) {
 		testServerV1 := mcptest.NewServerWithVersion(version)
-		testServerSource, testServerDest := mustMakeSocketPair(t)
+		testServerSource, testServerDest := mustMakeConnPair(t)
 		serverStdioSource.Store(testServerSource)
 		go func() {
 			mcpserver.NewStdioServer(testServerV1).Listen(t.Context(), testServerDest, testServerDest)
 		}()
 	}
 	prepServerWithVersion("1.0.0")

-	clientStdioSource, clientStdioDest := mustMakeSocketPair(t)
+	clientStdioSource, clientStdioDest := mustMakeConnPair(t)
 	stdioClient := mcptest.NewStdioClientFromConn(t, clientStdioSource)
 	proxyError := make(chan error, 1)
 	serverConnClosed := make(chan struct{}, 1)

 	// Start proxy.
 	go func() {
-		proxyError <- ProxyStdioConnWithAutoReconnect(ctx, ProxyStdioConnWithAutoReconnectConfig{
+		proxyError <- ProxyStdioConn(ctx, ProxyStdioConnConfig{
 			ClientStdio: clientStdioDest,
-			MakeReconnectUserMessage: func(err error) string {
-				return err.Error()
+			GetApp: func(ctx context.Context) (types.Application, error) {
+				return app, nil
 			},
-			DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) {
-				return serverStdioSource.Load().(io.ReadWriteCloser), nil
+			DialServer: func(ctx context.Context) (net.Conn, error) {
+				return serverStdioSource.Load().(net.Conn), nil
 			},
+			AutoReconnect: true,
 			onServerConnClosed: func() {
 				serverConnClosed <- struct{}{}
 			},
 		})
 	}()

 	// Initialize.
-	_, err := mcptest.InitializeClient(ctx, stdioClient)
-	require.NoError(t, err)
+	mcptest.MustInitializeClient(t, stdioClient)

 	// Call tool success.
-	callToolRequest := mcp.CallToolRequest{}
-	callToolRequest.Params.Name = "hello-server"
-	_, err = stdioClient.CallTool(ctx, callToolRequest)
-	require.NoError(t, err)
+	mcptest.MustCallServerTool(t, stdioClient)

 	// Let's kill the server, CallTool should fail.
 	serverStdioSource.Load().(io.ReadWriteCloser).Close()
 	select {
 	case <-serverConnClosed:
 	case <-time.After(time.Second):
 		t.Fatal("timed out waiting for server connection to close")
 	}
-	_, err = stdioClient.CallTool(ctx, callToolRequest)
-	require.ErrorContains(t, err, "use of closed network connection")
+	_, err := mcptest.CallServerTool(ctx, stdioClient)
+	require.ErrorContains(t, err, "on closed pipe")

 	// Let it try again with a successful reconnect.
 	prepServerWithVersion("1.0.0")
-	_, err = stdioClient.CallTool(ctx, callToolRequest)
-	require.NoError(t, err)
+	mcptest.MustCallServerTool(t, stdioClient)

 	// Let's kill the server again, and prepare a different version.
 	serverStdioSource.Load().(io.ReadWriteCloser).Close()
 	select {
 	case <-serverConnClosed:
 	case <-time.After(time.Second):
-		t.Fatal("timed out waiting for server connection to close")
+		require.Fail(t, "timed out waiting for server connection to close")
 	}
 	prepServerWithVersion("2.0.0")
-	_, err = stdioClient.CallTool(ctx, callToolRequest)
+	_, err = mcptest.CallServerTool(ctx, stdioClient)
 	require.ErrorContains(t, err, "server info has changed")

 	// Cleanup.
@@ -110,17 +109,142 @@
 	case proxyErr := <-proxyError:
 		require.NoError(t, proxyErr)
 	case <-time.After(time.Second):
-		t.Fatal("timed out waiting for proxy connection")
+		require.Fail(t, "timed out waiting for proxy to complete")
 	}
 }

-func mustMakeSocketPair(t *testing.T) (io.ReadWriteCloser, io.ReadWriteCloser) {
+func TestProxyStdioConn_http(t *testing.T) {
+	ctx := t.Context()
+	app := newAppFromURI(t, "some-mcp", "mcp+http://127.0.0.1:8888/mcp")
+
+	// Remote MCP server.
+	mcpServer := mcptest.NewServer()
+	listener := listenerutils.NewInMemoryListener()
+	t.Cleanup(func() { listener.Close() })
+	go http.Serve(listener, mcpserver.NewStreamableHTTPServer(mcpServer))
+
+	// Start proxy.
+	clientStdioSource, clientStdioDest := mustMakeConnPair(t)
+	proxyError := make(chan error, 1)
+	go func() {
+		proxyError <- ProxyStdioConn(ctx, ProxyStdioConnConfig{
+			ClientStdio: clientStdioDest,
+			GetApp: func(ctx context.Context) (types.Application, error) {
+				return app, nil
+			},
+			DialServer: func(ctx context.Context) (net.Conn, error) {
+				return listener.DialContext(ctx, "tcp", "")
+			},
+			AutoReconnect: true,
+		})
+	}()
+
+	// Local stdio client.
+	stdioClient := mcptest.NewStdioClientFromConn(t, clientStdioSource)
+	mcptest.MustInitializeClient(t, stdioClient)
+	mcptest.MustCallServerTool(t, stdioClient)
+
+	// Shut down.
+	stdioClient.Close()
+	select {
+	case proxyErr := <-proxyError:
+		require.NoError(t, proxyErr)
+	case <-time.After(time.Second):
+		require.Fail(t, "timed out waiting for proxy to complete")
+	}
+}
+
+func TestProxyStdioConn_autoReconnectDisabled(t *testing.T) {
+	ctx := t.Context()
+	app := newAppFromURI(t, "some-mcp", "mcp+stdio://")
+
+	var mcpServerConnCount atomic.Uint32
+	var mcpServerConn atomic.Value
+	listener := listenerutils.NewInMemoryListener()
+	t.Cleanup(func() { listener.Close() })
+	go func() {
+		for {
+			conn, err := listener.Accept()
+			if err != nil {
+				return
+			}
+			mcpServerConnCount.Add(1)
+			mcpServerConn.Store(conn)
+			go mcpserver.NewStdioServer(mcptest.NewServer()).Listen(t.Context(), conn, conn)
+		}
+	}()
+
+	// Start proxy.
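+	// With AutoReconnect disabled, losing the first server connection should
+	// end the proxy instead of triggering a new dial.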
+	clientStdioSource, clientStdioDest := mustMakeConnPair(t)
+	serverConnClosed := make(chan struct{}, 1)
+	proxyError := make(chan error, 1)
+	go func() {
+		proxyError <- ProxyStdioConn(ctx, ProxyStdioConnConfig{
+			ClientStdio: clientStdioDest,
+			GetApp: func(ctx context.Context) (types.Application, error) {
+				return app, nil
+			},
+			DialServer: func(ctx context.Context) (net.Conn, error) {
+				return listener.DialContext(ctx, "tcp", "")
+			},
+			AutoReconnect: false,
+			onServerConnClosed: func() {
+				serverConnClosed <- struct{}{}
+			},
+		})
+	}()
+
+	// Local stdio client.
+	stdioClient := mcptest.NewStdioClientFromConn(t, clientStdioSource)
+	mcptest.MustInitializeClient(t, stdioClient)
+	mcptest.MustCallServerTool(t, stdioClient)
+
+	// Let's kill the server conn.
+	connCloser, ok := mcpServerConn.Load().(io.Closer)
+	require.True(t, ok)
+	require.NoError(t, connCloser.Close())
+	select {
+	case <-serverConnClosed:
+	case <-time.After(time.Second):
+		require.Fail(t, "timed out waiting for server connection to close")
+	}
+
+	// Check proxy has ended.
+	select {
+	case proxyErr := <-proxyError:
+		require.NoError(t, proxyErr)
+	case <-time.After(time.Second):
+		require.Fail(t, "timed out waiting for proxy to complete")
+	}
+
+	// New request should fail and no retry is performed.
+	_, err := mcptest.CallServerTool(t.Context(), stdioClient)
+	require.ErrorContains(t, err, "on closed pipe")
+	require.Equal(t, uint32(1), mcpServerConnCount.Load())
+}
+
+func mustMakeConnPair(t *testing.T) (net.Conn, net.Conn) {
 	t.Helper()
-	source, dest, err := uds.NewSocketpair(uds.SocketTypeStream)
-	require.NoError(t, err)
+	source, dest := net.Pipe()
 	t.Cleanup(func() {
 		source.Close()
 		dest.Close()
 	})
 	return source, dest
 }
+
+func newAppFromURI(t *testing.T, name, uri string) types.Application {
+	t.Helper()
+	spec := types.AppSpecV3{
+		URI: uri,
+	}
+	if types.GetMCPServerTransportType(uri) == types.MCPTransportStdio {
+		spec.MCP = &types.MCP{
+			Command:       "test",
+			RunAsHostUser: "test",
+		}
+	}
+	app, err := types.NewAppV3(types.Metadata{Name: name}, spec)
+	require.NoError(t, err)
+	return app
+}
diff --git a/lib/utils/mcptest/test.go b/lib/utils/mcptest/test.go
index 573b18b6f413d..7f8c36beb9316 100644
--- a/lib/utils/mcptest/test.go
+++ b/lib/utils/mcptest/test.go
@@ -104,12 +104,18 @@ func MustInitializeClient(t *testing.T, client *mcpclient.Client) *mcp.Initializ
 // MustCallServerTool calls the "hello-server" tool and verifies the result.
 func MustCallServerTool(t *testing.T, client *mcpclient.Client) {
 	t.Helper()
-	callToolRequest := mcp.CallToolRequest{}
-	callToolRequest.Params.Name = "hello-server"
-	callToolResult, err := client.CallTool(t.Context(), callToolRequest)
+	callToolResult, err := CallServerTool(t.Context(), client)
 	require.NoError(t, err)
 	require.NotNil(t, callToolResult)
 	require.Equal(t, []mcp.Content{
 		mcp.NewTextContent("hello client"),
 	}, callToolResult.Content)
 }
+
+// CallServerTool calls the "hello-server" tool and returns the result.
+func CallServerTool(ctx context.Context, client *mcpclient.Client) (*mcp.CallToolResult, error) {
+	callToolRequest := mcp.CallToolRequest{}
+	callToolRequest.Params.Name = "hello-server"
+	callToolResult, err := client.CallTool(ctx, callToolRequest)
+	return callToolResult, trace.Wrap(err)
+}
diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go
index 41377f527bc4f..812c31f588ef2 100644
--- a/tool/tsh/common/mcp_app.go
+++ b/tool/tsh/common/mcp_app.go
@@ -53,6 +53,7 @@ func newMCPConnectCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConnectCom
 	cmd.Arg("name", "Name of the MCP server.").Required().StringVar(&cf.AppName)

 	cmd.Flag("auto-reconnect", mcpAutoReconnectHelp).Default("true").BoolVar(&cmd.autoReconnect)
+	// TODO(greedy52) support providing extra headers for streamable HTTP.
 	return cmd
 }

@@ -472,25 +473,16 @@ func (c *mcpConnectCommand) run() error {
 	tc.NonInteractive = true

 	dialer := client.NewMCPServerDialer(tc, c.cf.AppName)
-	if c.autoReconnect {
-		return clientmcp.ProxyStdioConnWithAutoReconnect(
-			c.cf.Context,
-			clientmcp.ProxyStdioConnWithAutoReconnectConfig{
-				ClientStdio: utils.CombinedStdio{},
-				DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) {
-					conn, err := dialer.DialALPN(ctx)
-					return conn, trace.Wrap(err)
-				},
-				MakeReconnectUserMessage: makeMCPReconnectUserMessage,
-			},
-		)
-	}
-
-	serverConn, err := dialer.DialALPN(c.cf.Context)
-	if err != nil {
-		return trace.Wrap(err)
-	}
-	return trace.Wrap(utils.ProxyConn(c.cf.Context, utils.CombinedStdio{}, serverConn))
+	return clientmcp.ProxyStdioConn(
+		c.cf.Context,
+		clientmcp.ProxyStdioConnConfig{
+			ClientStdio:              utils.CombinedStdio{},
+			GetApp:                   dialer.GetApp,
+			DialServer:               dialer.DialALPN,
+			MakeReconnectUserMessage: makeMCPReconnectUserMessage,
+			AutoReconnect:            c.autoReconnect,
+		},
+	)
 }

 func makeMCPReconnectUserMessage(err error) string {

From 7a127ffa12c48d3e137a05c99fcea87a60851ed1 Mon Sep 17 00:00:00 2001
From: "STeve (Xin) Huang"
Date: Wed, 22 Oct 2025 14:15:38 -0400
Subject: [PATCH 2/2] wait for 5s just to be conservative

---
 lib/client/mcp/reconnect_test.go | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/lib/client/mcp/reconnect_test.go b/lib/client/mcp/reconnect_test.go
index 91760c08597b0..cc3b02fc86e9b 100644
--- a/lib/client/mcp/reconnect_test.go
+++ b/lib/client/mcp/reconnect_test.go
@@ -82,7 +82,7 @@ func TestProxyStdioConn_autoReconnect(t *testing.T) {
 	serverStdioSource.Load().(io.ReadWriteCloser).Close()
 	select {
 	case <-serverConnClosed:
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		t.Fatal("timed out waiting for server connection to close")
 	}
 	_, err := mcptest.CallServerTool(ctx, stdioClient)
 	require.ErrorContains(t, err, "on closed pipe")
@@ -96,7 +96,7 @@ func TestProxyStdioConn_autoReconnect(t *testing.T) {
 	serverStdioSource.Load().(io.ReadWriteCloser).Close()
 	select {
 	case <-serverConnClosed:
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		require.Fail(t, "timed out waiting for server connection to close")
 	}
 	prepServerWithVersion("2.0.0")
@@ -108,7 +108,7 @@ func TestProxyStdioConn_autoReconnect(t *testing.T) {
 	select {
 	case proxyErr := <-proxyError:
 		require.NoError(t, proxyErr)
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		require.Fail(t, "timed out waiting for proxy to complete")
 	}
 }
@@ -149,7 +149,7 @@ func TestProxyStdioConn_http(t *testing.T) {
 	select {
 	case proxyErr := <-proxyError:
 		require.NoError(t, proxyErr)
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		require.Fail(t, "timed out waiting for proxy to complete")
 	}
 }
@@ -205,7 +205,7 @@ func TestProxyStdioConn_autoReconnectDisabled(t *testing.T) {
 	require.NoError(t, connCloser.Close())
 	select {
 	case <-serverConnClosed:
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		require.Fail(t, "timed out waiting for server connection to close")
 	}

@@ -213,7 +213,7 @@
 	select {
 	case proxyErr := <-proxyError:
 		require.NoError(t, proxyErr)
-	case <-time.After(time.Second):
+	case <-time.After(time.Second * 5):
 		require.Fail(t, "timed out waiting for proxy to complete")
 	}