diff --git a/lib/client/db/mcp/errors.go b/lib/client/db/mcp/errors.go
index 19fdc90dbb2ac..04a00c0043850 100644
--- a/lib/client/db/mcp/errors.go
+++ b/lib/client/db/mcp/errors.go
@@ -33,7 +33,7 @@ import (
func FormatErrorMessage(err error) error {
switch {
case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired) || utils.IsCertExpiredError(err):
- return trace.BadParameter(ReloginRequiredErrorMessage)
+ return trace.BadParameter(mcp.ReloginRequiredErrorMessage)
case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe):
return trace.BadParameter(LocalProxyConnectionErrorMessage)
}
@@ -42,12 +42,6 @@ func FormatErrorMessage(err error) error {
}
const (
- // ReloginRequiredErrorMessage is the message returned to the MCP client
- // when the tsh session expired.
- ReloginRequiredErrorMessage = `It looks like your Teleport session expired,
-you must relogin (using "tsh login" on a terminal) before continue using this
-tool. After that, there is no need to update or relaunch the MCP client - just
-try using it again.`
// LocalProxyConnectionErrorMessage is the message returned to the MCP client when
// the database client cannot connect to the local proxy.
LocalProxyConnectionErrorMessage = `Teleport MCP server is having issue while
diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go
index 7e018d6ef5917..698baeb499200 100644
--- a/lib/client/db/postgres/mcp/mcp_test.go
+++ b/lib/client/db/postgres/mcp/mcp_test.go
@@ -140,7 +140,7 @@ func TestFormatErrors(t *testing.T) {
},
},
expectErrorMessage: func(tt require.TestingT, i1 any, i2 ...any) {
- require.Equal(t, dbmcp.ReloginRequiredErrorMessage, i1)
+ require.Equal(t, clientmcp.ReloginRequiredErrorMessage, i1)
},
},
} {
diff --git a/lib/client/mcp/errors.go b/lib/client/mcp/errors.go
new file mode 100644
index 0000000000000..969a4297e23be
--- /dev/null
+++ b/lib/client/mcp/errors.go
@@ -0,0 +1,84 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "syscall"
+
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+const (
+ // ReloginRequiredErrorMessage is the message returned to the MCP client
+ // when the tsh session expired.
+ ReloginRequiredErrorMessage = `It looks like your Teleport session expired,
+you must relogin (using "tsh login" on a terminal) before continue using this
+tool. After that, there is no need to update or relaunch the MCP client - just
+try using it again.`
+)
+
+// IsLikelyTemporaryNetworkError returns true if the error is likely a temporary
+// network error.
+func IsLikelyTemporaryNetworkError(err error) bool {
+ if trace.IsConnectionProblem(err) ||
+ isTemporarySyscallNetError(err) {
+ return true
+ }
+
+ var dnsErr *net.DNSError
+ if errors.As(err, &dnsErr) {
+ return dnsErr.Temporary()
+ }
+
+ var netErr net.Error
+ if errors.As(err, &netErr) {
+ return netErr.Timeout()
+ }
+
+ return false
+}
+
+func isTemporarySyscallNetError(err error) bool {
+ return errors.Is(err, syscall.EHOSTUNREACH) ||
+ errors.Is(err, syscall.ENETUNREACH) ||
+ errors.Is(err, syscall.ETIMEDOUT) ||
+ errors.Is(err, syscall.ECONNREFUSED)
+}
+
+// IsServerInfoChangedError returns true if the error indicates the remote MCP
+// server's info has changed from previous connections. Auto-reconnection
+// reports this scenario as an error case to be on the safe side in case things
+// like tools have changed.
+func IsServerInfoChangedError(err error) bool {
+ var serverInfoChangedError *serverInfoChangedError
+ return errors.As(err, &serverInfoChangedError)
+}
+
+type serverInfoChangedError struct {
+ expectedInfo mcp.Implementation
+ currentInfo mcp.Implementation
+}
+
+func (e *serverInfoChangedError) Error() string {
+ return fmt.Sprintf("server info has changed, expected %v, got %v", e.expectedInfo, e.currentInfo)
+}
diff --git a/lib/client/mcp/errors_test.go b/lib/client/mcp/errors_test.go
new file mode 100644
index 0000000000000..86e11d1cf8a29
--- /dev/null
+++ b/lib/client/mcp/errors_test.go
@@ -0,0 +1,40 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsServerInfoChangedError(t *testing.T) {
+ err := &serverInfoChangedError{
+ expectedInfo: mcp.Implementation{
+ Name: "i-am-mcp",
+ Version: "1.0.0",
+ },
+ currentInfo: mcp.Implementation{
+ Name: "i-am-mcp",
+ Version: "1.1.0",
+ },
+ }
+ require.True(t, IsServerInfoChangedError(err))
+}
diff --git a/lib/client/mcp/reconnect.go b/lib/client/mcp/reconnect.go
new file mode 100644
index 0000000000000..11a7d6a6e00cd
--- /dev/null
+++ b/lib/client/mcp/reconnect.go
@@ -0,0 +1,327 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "sync"
+
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/utils/mcputils"
+)
+
+// ProxyStdioConnWithAutoReconnectConfig is the config for ProxyStdioConnWithAutoReconnect.
+type ProxyStdioConnWithAutoReconnectConfig struct {
+ // ClientStdio is the client stdin and stdout.
+ ClientStdio io.ReadWriteCloser
+ // MakeReconnectUserMessage generates a user-friendly message based on the
+ // error.
+ MakeReconnectUserMessage func(error) string
+ // DialServer makes a new connection to the remote MCP server.
+ DialServer func(context.Context) (io.ReadWriteCloser, error)
+ // Logger is the slog logger.
+ Logger *slog.Logger
+
+ // clientResponseWriter replies to ClientStdio.
+ clientResponseWriter mcputils.MessageWriter
+ // onServerConnClosed is a callback when remote server connection is dead.
+ onServerConnClosed func()
+}
+
+// CheckAndSetDefaults validates the config and sets default values.
+func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error {
+ if cfg.ClientStdio == nil {
+ return trace.BadParameter("missing ClientStdio")
+ }
+ if cfg.DialServer == nil {
+ return trace.BadParameter("missing DialServer")
+ }
+ if cfg.MakeReconnectUserMessage == nil {
+ return trace.BadParameter("missing MakeReconnectUserMessage")
+ }
+ if cfg.Logger == nil {
+ cfg.Logger = slog.With(
+ teleport.ComponentKey,
+ teleport.Component(teleport.ComponentMCP, "autoreconnect"),
+ )
+ }
+ if cfg.clientResponseWriter == nil {
+ cfg.clientResponseWriter = mcputils.NewSyncStdioMessageWriter(cfg.ClientStdio)
+ }
+ return nil
+}
+
+// ProxyStdioConnWithAutoReconnect serves a stdio client with a consistent
+// connection and reconnects to the remote server upon issues.
+func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) error {
+ if err := cfg.CheckAndSetDefaults(); err != nil {
+ return trace.Wrap(err)
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ serverConn, err := newServerConnWithAutoReconnect(ctx, cfg)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ clientRequestReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
+ Transport: mcputils.NewStdioReader(cfg.ClientStdio),
+ ParentContext: ctx,
+ Logger: cfg.Logger.With("client", "stdin"),
+ OnParseError: mcputils.ReplyParseError(cfg.clientResponseWriter),
+ OnNotification: func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
+ // By spec, we should not reply to notifications. Try our best to
+ // send a notification with the error message. In practice, only the
+ // initialize notification is sent from client after receiving the
+ // initialize response so it's unlikely to hit here.
+ if writeError := serverConn.WriteMessage(ctx, notification); writeError != nil {
+ cfg.Logger.WarnContext(ctx, "failed to write notification to server. Notification is dropped.", "error", writeError)
+ userMessage := cfg.MakeReconnectUserMessage(writeError)
+ errNotification := mcp.Notification{
+ Method: "notifications/tsherr",
+ Params: mcp.NotificationParams{
+ AdditionalFields: map[string]any{
+ "error": fmt.Sprintf("Notification %q was dropped. %s", notification.Method, userMessage),
+ },
+ },
+ }
+ return trace.Wrap(cfg.clientResponseWriter.WriteMessage(ctx, errNotification))
+ }
+ return nil
+ },
+ OnRequest: func(ctx context.Context, request *mcputils.JSONRPCRequest) error {
+ if writeError := serverConn.WriteMessage(ctx, request); writeError != nil {
+ cfg.Logger.WarnContext(ctx, "failed to write request to server", "error", writeError)
+ userMessage := cfg.MakeReconnectUserMessage(writeError)
+ errResp := mcp.NewJSONRPCError(request.ID, mcp.INTERNAL_ERROR, userMessage, writeError)
+ return trace.Wrap(cfg.clientResponseWriter.WriteMessage(ctx, errResp))
+ }
+ return nil
+ },
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ clientRequestReader.Run(ctx)
+ return nil
+
+}
+
+type serverConnWithAutoReconnect struct {
+ ProxyStdioConnWithAutoReconnectConfig
+ parentCtx context.Context
+
+ mu sync.Mutex
+ serverRequestWriter mcputils.MessageWriter
+ replayOnNextConn bool
+ initRequest *mcputils.JSONRPCRequest
+ initResponse *mcp.InitializeResult
+ initNotification *mcputils.JSONRPCNotification
+ closeServerConn func()
+}
+
+func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) (*serverConnWithAutoReconnect, error) {
+ return &serverConnWithAutoReconnect{
+ ProxyStdioConnWithAutoReconnectConfig: cfg,
+ parentCtx: parentCtx,
+ }, nil
+}
+
+func (r *serverConnWithAutoReconnect) Close() error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.closeServerConn != nil {
+ r.closeServerConn()
+ }
+ return nil
+}
+
+func (r *serverConnWithAutoReconnect) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ writer, err := r.getServerRequestWriterLocked(ctx)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(writer.WriteMessage(ctx, msg))
+}
+
+func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.Context) (mcputils.MessageWriter, error) {
+ if r.serverRequestWriter != nil {
+ return r.serverRequestWriter, nil
+ }
+
+ r.Logger.InfoContext(ctx, "Connecting to server")
+ serverConn, err := r.DialServer(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ serverStdioReader := mcputils.NewStdioReader(serverConn)
+ serverWriter := mcputils.NewStdioMessageWriter(serverConn)
+ if r.replayOnNextConn {
+ // Replay initialize sequence. Any error here is likely permanent.
+ if err := r.replayInitializeLocked(ctx, serverStdioReader, serverWriter); err != nil {
+ serverConn.Close()
+ return nil, trace.Wrap(err)
+ }
+ r.serverRequestWriter = serverWriter
+ } else {
+ r.serverRequestWriter = mcputils.NewMultiMessageWriter(
+ mcputils.MessageWriterFunc(func(ctx context.Context, msg mcp.JSONRPCMessage) error {
+ r.cacheMessageLocked(ctx, msg)
+ return nil
+ }),
+ serverWriter,
+ )
+ r.replayOnNextConn = true
+ }
+
+ // This should never fail as long the correct config is passed in.
+ serverResponseReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
+ Transport: serverStdioReader,
+ ParentContext: r.parentCtx,
+ // OnClose is called when server connection is dead.
+ // Teleport Proxy automatically closes the connection when tsh session
+ // is expired.
+ OnClose: func() {
+ r.Logger.InfoContext(ctx, "Lost server connection, resetting...")
+ r.mu.Lock()
+ r.serverRequestWriter = nil
+ if r.onServerConnClosed != nil {
+ r.onServerConnClosed()
+ }
+ r.mu.Unlock()
+ },
+ Logger: r.Logger.With("server", "stdout"),
+ OnParseError: mcputils.LogAndIgnoreParseError(r.Logger),
+ OnNotification: func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
+ return trace.Wrap(r.clientResponseWriter.WriteMessage(ctx, notification))
+ },
+ OnResponse: func(ctx context.Context, response *mcputils.JSONRPCResponse) error {
+ r.cacheMessageLocked(ctx, response)
+ return trace.Wrap(r.clientResponseWriter.WriteMessage(ctx, response))
+ },
+ })
+ if err != nil {
+ serverConn.Close()
+ return nil, trace.Wrap(err)
+ }
+
+ readerCtx, readerCancel := context.WithCancel(r.parentCtx)
+ r.closeServerConn = readerCancel
+ go serverResponseReader.Run(readerCtx)
+
+ r.Logger.InfoContext(ctx, "Started a new MCP server connection")
+ return r.serverRequestWriter, nil
+}
+
+func (r *serverConnWithAutoReconnect) initializedLocked() bool {
+ return r.initRequest != nil && r.initResponse != nil && r.initNotification != nil
+}
+
+func (r *serverConnWithAutoReconnect) replayInitializeLocked(ctx context.Context, serverReader mcputils.TransportReader, serverWriter mcputils.MessageWriter) error {
+ if !r.initializedLocked() {
+ return trace.Errorf("client has not initialized yet")
+ }
+
+ r.Logger.DebugContext(ctx, "Replaying initialize request")
+ if err := serverWriter.WriteMessage(ctx, r.initRequest); err != nil {
+ return trace.Wrap(err)
+ }
+
+ r.Logger.DebugContext(ctx, "Reading and comparing initialize response")
+ msg, err := mcputils.ReadOneResponse(ctx, serverReader)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ if err := r.checkReplyResponseLocked(msg); err != nil {
+ return trace.Wrap(err)
+ }
+
+ r.Logger.DebugContext(ctx, "Replaying initialized notification")
+ if err := serverWriter.WriteMessage(ctx, r.initNotification); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+func (r *serverConnWithAutoReconnect) checkReplyResponseLocked(msg mcp.JSONRPCMessage) error {
+ resp, ok := msg.(*mcputils.JSONRPCResponse)
+ if !ok {
+ return trace.Errorf("expected initialize response, got %T", resp)
+ }
+ if resp.Error != nil {
+ return trace.Errorf("expected initialize result but got error")
+ }
+ if resp.ID.String() != r.initRequest.ID.String() {
+ return trace.CompareFailed("expected initialize response with ID %s, got %s", r.initRequest.ID, resp.ID.String())
+ }
+
+ newResult, err := resp.GetInitializeResult()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if newResult.ServerInfo != r.initResponse.ServerInfo {
+ return trace.Wrap(&serverInfoChangedError{
+ expectedInfo: r.initResponse.ServerInfo,
+ currentInfo: newResult.ServerInfo,
+ })
+ }
+ return nil
+}
+
+// cacheMessageLocked caches client init request and notification.
+func (r *serverConnWithAutoReconnect) cacheMessageLocked(ctx context.Context, msg mcp.JSONRPCMessage) {
+ if r.initializedLocked() {
+ return
+ }
+
+ switch m := msg.(type) {
+ case *mcputils.JSONRPCRequest:
+ if r.initRequest == nil && m.Method == mcp.MethodInitialize {
+ r.initRequest = m
+ r.Logger.DebugContext(ctx, "Cached initialize", "request", m)
+ }
+ case *mcputils.JSONRPCNotification:
+ if r.initNotification == nil && m.Method == mcputils.MethodNotificationInitialized {
+ r.initNotification = m
+ r.Logger.DebugContext(ctx, "Cached notification", "notification", m)
+ }
+ case *mcputils.JSONRPCResponse:
+ if r.initResponse == nil && r.initRequest != nil && r.initRequest.ID.String() == m.ID.String() {
+ initResponse, err := m.GetInitializeResult()
+ if err != nil {
+ r.Logger.DebugContext(ctx, "Error parsing init response", "error", err)
+ } else {
+ r.initResponse = initResponse
+ r.Logger.DebugContext(ctx, "Cached response", "response", m)
+ }
+ }
+ }
+}
diff --git a/lib/client/mcp/reconnect_test.go b/lib/client/mcp/reconnect_test.go
new file mode 100644
index 0000000000000..6e98a8fa235ad
--- /dev/null
+++ b/lib/client/mcp/reconnect_test.go
@@ -0,0 +1,126 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcp
+
+import (
+ "context"
+ "io"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/lib/utils/mcptest"
+ "github.com/gravitational/teleport/lib/utils/uds"
+)
+
+func TestProxyStdioConnWithAutoReconnect(t *testing.T) {
+ ctx := t.Context()
+
+ var serverStdioSource atomic.Value
+ prepServerWithVersion := func(version string) {
+ testServerV1 := mcptest.NewServerWithVersion(version)
+ testServerSource, testServerDest := mustMakeSocketPair(t)
+ serverStdioSource.Store(testServerSource)
+ go func() {
+ mcpserver.NewStdioServer(testServerV1).Listen(t.Context(), testServerDest, testServerDest)
+ }()
+ }
+ prepServerWithVersion("1.0.0")
+
+ clientStdioSource, clientStdioDest := mustMakeSocketPair(t)
+ stdioClient := mcptest.NewStdioClientFromConn(t, clientStdioSource)
+ proxyError := make(chan error, 1)
+ serverConnClosed := make(chan struct{}, 1)
+
+ // Start proxy.
+ go func() {
+ proxyError <- ProxyStdioConnWithAutoReconnect(ctx, ProxyStdioConnWithAutoReconnectConfig{
+ ClientStdio: clientStdioDest,
+ MakeReconnectUserMessage: func(err error) string {
+ return err.Error()
+ },
+ DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) {
+ return serverStdioSource.Load().(io.ReadWriteCloser), nil
+ },
+ onServerConnClosed: func() {
+ serverConnClosed <- struct{}{}
+ },
+ })
+ }()
+
+ // Initialize.
+ _, err := mcptest.InitializeClient(ctx, stdioClient)
+ require.NoError(t, err)
+
+ // Call tool success.
+ callToolRequest := mcp.CallToolRequest{}
+ callToolRequest.Params.Name = "hello-server"
+ _, err = stdioClient.CallTool(ctx, callToolRequest)
+ require.NoError(t, err)
+
+ // Let's kill the server, CallTool should fail.
+ serverStdioSource.Load().(io.ReadWriteCloser).Close()
+ select {
+ case <-serverConnClosed:
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for server connection to close")
+ }
+ _, err = stdioClient.CallTool(ctx, callToolRequest)
+ require.ErrorContains(t, err, "use of closed network connection")
+
+ // Let it try again with a successful reconnect.
+ prepServerWithVersion("1.0.0")
+ _, err = stdioClient.CallTool(ctx, callToolRequest)
+ require.NoError(t, err)
+
+ // Let's kill the server again, and prepare a different version.
+ serverStdioSource.Load().(io.ReadWriteCloser).Close()
+ select {
+ case <-serverConnClosed:
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for server connection to close")
+ }
+ prepServerWithVersion("2.0.0")
+ _, err = stdioClient.CallTool(ctx, callToolRequest)
+ require.ErrorContains(t, err, "server info has changed")
+
+ // Cleanup.
+ clientStdioSource.Close()
+ select {
+ case proxyErr := <-proxyError:
+ require.NoError(t, proxyErr)
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for proxy connection")
+ }
+}
+
+func mustMakeSocketPair(t *testing.T) (io.ReadWriteCloser, io.ReadWriteCloser) {
+ t.Helper()
+ source, dest, err := uds.NewSocketpair(uds.SocketTypeStream)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ source.Close()
+ dest.Close()
+ })
+ return source, dest
+}
diff --git a/lib/utils/mcputils/protocol.go b/lib/utils/mcputils/protocol.go
index 5a8d8537167fd..3c458c2c7bf45 100644
--- a/lib/utils/mcputils/protocol.go
+++ b/lib/utils/mcputils/protocol.go
@@ -148,8 +148,25 @@ type JSONRPCResponse struct {
// the corresponding go object.
func (r *JSONRPCResponse) GetListToolResult() (*mcp.ListToolsResult, error) {
var listResult mcp.ListToolsResult
- if err := json.Unmarshal([]byte(r.Result), &listResult); err != nil {
+ if err := json.Unmarshal(r.Result, &listResult); err != nil {
return nil, trace.Wrap(err)
}
return &listResult, nil
}
+
+// GetInitializeResult assumes the result is for mcp.MethodInitialize and
+// returns the corresponding go object.
+func (r *JSONRPCResponse) GetInitializeResult() (*mcp.InitializeResult, error) {
+ var result mcp.InitializeResult
+ if err := json.Unmarshal(r.Result, &result); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &result, nil
+}
+
+const (
+ // MethodNotificationInitialized defines the method used for "initialized"
+ // notification. This notification is sent by the client after it receives
+ // the initialize response.
+ MethodNotificationInitialized = "notifications/initialized"
+)
diff --git a/lib/utils/mcputils/protocol_test.go b/lib/utils/mcputils/protocol_test.go
index acf0f4c721641..4be664ac35e59 100644
--- a/lib/utils/mcputils/protocol_test.go
+++ b/lib/utils/mcputils/protocol_test.go
@@ -27,8 +27,8 @@ import (
"github.com/stretchr/testify/require"
)
-func TestJSONRPCNotification(t *testing.T) {
- inputJSON := []byte(`{
+var (
+ sampleNotificationJSON = []byte(`{
"jsonrpc": "2.0",
"method": "notifications/message",
"params": {
@@ -43,9 +43,46 @@ func TestJSONRPCNotification(t *testing.T) {
}
}
}`)
+ sampleRequestJSON = []byte(`{
+ "jsonrpc": "2.0",
+ "id": 2,
+ "method": "tools/call",
+ "params": {
+ "name": "get_weather",
+ "arguments": {
+ "location": "New York"
+ }
+ }
+}`)
+
+ sampleResponseJSON = []byte(`{
+ "jsonrpc": "2.0",
+ "id": 2,
+ "result": {
+ "tools": [
+ {
+ "name": "get_weather",
+ "description": "Get current weather information for a location",
+ "inputSchema": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "City name or zip code"
+ }
+ },
+ "required": ["location"]
+ }
+ }
+ ],
+ "nextCursor": "next-page-cursor"
+ }
+}`)
+)
+func TestJSONRPCNotification(t *testing.T) {
var base baseJSONRPCMessage
- require.NoError(t, json.Unmarshal(inputJSON, &base))
+ require.NoError(t, json.Unmarshal(sampleNotificationJSON, &base))
assert.True(t, base.isNotification())
assert.False(t, base.isRequest())
assert.False(t, base.isResponse())
@@ -57,23 +94,12 @@ func TestJSONRPCNotification(t *testing.T) {
outputJSON, err := json.MarshalIndent(m, "", " ")
require.NoError(t, err)
- assert.JSONEq(t, string(inputJSON), string(outputJSON))
+ assert.JSONEq(t, string(sampleNotificationJSON), string(outputJSON))
}
func TestJSONRPCRequest(t *testing.T) {
- inputJSON := []byte(`{
- "jsonrpc": "2.0",
- "id": 2,
- "method": "tools/call",
- "params": {
- "name": "get_weather",
- "arguments": {
- "location": "New York"
- }
- }
-}`)
var base baseJSONRPCMessage
- require.NoError(t, json.Unmarshal(inputJSON, &base))
+ require.NoError(t, json.Unmarshal(sampleRequestJSON, &base))
assert.False(t, base.isNotification())
assert.True(t, base.isRequest())
assert.False(t, base.isResponse())
@@ -88,35 +114,12 @@ func TestJSONRPCRequest(t *testing.T) {
outputJSON, err := json.MarshalIndent(m, "", " ")
require.NoError(t, err)
- assert.JSONEq(t, string(inputJSON), string(outputJSON))
+ assert.JSONEq(t, string(sampleRequestJSON), string(outputJSON))
}
func TestJSONRPCResponse(t *testing.T) {
- inputJSON := []byte(`{
- "jsonrpc": "2.0",
- "id": 2,
- "result": {
- "tools": [
- {
- "name": "get_weather",
- "description": "Get current weather information for a location",
- "inputSchema": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "City name or zip code"
- }
- },
- "required": ["location"]
- }
- }
- ],
- "nextCursor": "next-page-cursor"
- }
-}`)
var base baseJSONRPCMessage
- require.NoError(t, json.Unmarshal(inputJSON, &base))
+ require.NoError(t, json.Unmarshal(sampleResponseJSON, &base))
assert.False(t, base.isNotification())
assert.False(t, base.isRequest())
assert.True(t, base.isResponse())
@@ -127,7 +130,7 @@ func TestJSONRPCResponse(t *testing.T) {
outputJSON, err := json.MarshalIndent(m, "", " ")
require.NoError(t, err)
- assert.JSONEq(t, string(inputJSON), string(outputJSON))
+ assert.JSONEq(t, string(sampleResponseJSON), string(outputJSON))
toolList, err := m.GetListToolResult()
require.NoError(t, err)
diff --git a/lib/utils/mcputils/reader.go b/lib/utils/mcputils/reader.go
index f51596739d85a..069c582bb08ba 100644
--- a/lib/utils/mcputils/reader.go
+++ b/lib/utils/mcputils/reader.go
@@ -217,3 +217,22 @@ func (r *MessageReader) processNextLine(ctx context.Context) error {
)
}
}
+
+// ReadOneResponse reads one message from the reader and marshals it to a
+// response.
+func ReadOneResponse(ctx context.Context, reader TransportReader) (*JSONRPCResponse, error) {
+ rawMessage, err := reader.ReadMessage(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ var base baseJSONRPCMessage
+ if parseError := json.Unmarshal([]byte(rawMessage), &base); parseError != nil {
+ return nil, trace.Wrap(parseError)
+ }
+
+ if !base.isResponse() {
+ return nil, trace.BadParameter("message is not a response")
+ }
+ return base.makeResponse(), nil
+}
diff --git a/lib/utils/mcputils/reader_test.go b/lib/utils/mcputils/reader_test.go
new file mode 100644
index 0000000000000..eac495d62dafa
--- /dev/null
+++ b/lib/utils/mcputils/reader_test.go
@@ -0,0 +1,84 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type mockTransportReader struct {
+ message string
+}
+
+func (m mockTransportReader) ReadMessage(context.Context) (string, error) {
+ return m.message, nil
+}
+func (m mockTransportReader) Type() string {
+ return "mock"
+}
+func (m mockTransportReader) Close() error {
+ return nil
+}
+
+func TestReadOneResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ rawMessage string
+ checkError require.ErrorAssertionFunc
+ checkResponse func(*testing.T, *JSONRPCResponse)
+ }{
+ {
+ name: "bad json",
+ rawMessage: "not JSON RPC message",
+ checkError: require.Error,
+ },
+ {
+ name: "notification",
+ rawMessage: string(sampleNotificationJSON),
+ checkError: require.Error,
+ },
+ {
+ name: "request",
+ rawMessage: string(sampleRequestJSON),
+ checkError: require.Error,
+ },
+ {
+ name: "response",
+ rawMessage: string(sampleResponseJSON),
+ checkError: require.NoError,
+ checkResponse: func(t *testing.T, response *JSONRPCResponse) {
+ require.NotNil(t, response)
+ _, err := response.GetListToolResult()
+ require.NoError(t, err)
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ resp, err := ReadOneResponse(t.Context(), mockTransportReader{test.rawMessage})
+ test.checkError(t, err)
+ if test.checkResponse != nil {
+ test.checkResponse(t, resp)
+ }
+ })
+ }
+}
diff --git a/lib/utils/mcputils/writer.go b/lib/utils/mcputils/writer.go
index b62a1d35ff56d..29fd2053dac8b 100644
--- a/lib/utils/mcputils/writer.go
+++ b/lib/utils/mcputils/writer.go
@@ -22,6 +22,7 @@ import (
"context"
"sync"
+ "github.com/gravitational/trace"
"github.com/mark3labs/mcp-go/mcp"
)
@@ -50,3 +51,37 @@ func (s *SyncMessageWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMes
defer s.mu.Unlock()
return s.w.WriteMessage(ctx, msg)
}
+
+// MessageWriterFunc defines a message writer function that implements
+// MessageWriter.
+type MessageWriterFunc func(context.Context, mcp.JSONRPCMessage) error
+
+// WriteMessage writes an JSON RPC message.
+func (f MessageWriterFunc) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error {
+ return f(ctx, msg)
+}
+
+// MultiMessageWriter creates a writer that duplicates its writes to all the
+// provided writers.
+//
+// Each write is written to each listed writer, one at a time. If a listed
+// writer returns an error, that overall writes operation stops and returns the
+// error; it does not continue down the list.
+type MultiMessageWriter struct {
+ writers []MessageWriter
+}
+
+// NewMultiMessageWriter creates a new MultiMessageWriter.
+func NewMultiMessageWriter(writers ...MessageWriter) *MultiMessageWriter {
+ return &MultiMessageWriter{writers: writers}
+}
+
+// WriteMessage writes the message to each listed writer, one at a time.
+func (w *MultiMessageWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error {
+ for _, writer := range w.writers {
+ if err := writer.WriteMessage(ctx, msg); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+ return nil
+}
diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go
index 5a67648e4c36b..cf8222a6e9f11 100644
--- a/tool/tsh/common/mcp_app.go
+++ b/tool/tsh/common/mcp_app.go
@@ -37,6 +37,7 @@ import (
"github.com/gravitational/teleport/api/utils/iterutils"
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/client"
+ clientmcp "github.com/gravitational/teleport/lib/client/mcp"
"github.com/gravitational/teleport/lib/client/mcp/claude"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
@@ -51,6 +52,7 @@ func newMCPConnectCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConnectCom
}
cmd.Arg("name", "Name of the MCP server.").Required().StringVar(&cf.AppName)
+ cmd.Flag("auto-reconnect", mcpAutoReconnectHelp).Default("true").BoolVar(&cmd.autoReconnect)
return cmd
}
@@ -77,6 +79,7 @@ func newMCPConfigCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConfigComma
cmd.Flag("all", "Select all MCP servers. Mutually exclusive with --labels or --query.").Short('R').BoolVar(&cf.ListAll)
cmd.Flag("labels", labelHelp).StringVar(&cf.Labels)
cmd.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
+ cmd.Flag("auto-reconnect", mcpAutoReconnectHelp).IsSetByUser(&cmd.autoReconnectSetByUser).BoolVar(&cmd.autoReconnect)
cmd.Arg("name", "Name of the MCP server.").StringVar(&cf.AppName)
cmd.clientConfig.addToCmd(cmd.CmdClause)
cmd.Alias(mcpConfigHelp)
@@ -260,8 +263,10 @@ func printMCPServersInVerboseText(w io.Writer, mcpServers iter.Seq[mcpServerWith
type mcpConfigCommand struct {
*kingpin.CmdClause
- clientConfig mcpClientConfigFlags
- cf *CLIConf
+ clientConfig mcpClientConfigFlags
+ cf *CLIConf
+ autoReconnect bool
+ autoReconnectSetByUser bool
mcpServerApps []types.Application
@@ -347,7 +352,9 @@ func (c *mcpConfigCommand) fetch() error {
func (c *mcpConfigCommand) addMCPServersToConfig(config claudeConfig) error {
for _, app := range c.mcpServerApps {
localName := mcpServerAppConfigPrefix + app.GetName()
- err := config.PutMCPServer(localName, makeLocalMCPServer(c.cf, []string{"mcp", "connect", app.GetName()}))
+ args := []string{"mcp", "connect", app.GetName()}
+ args = c.maybeAddAutoReconnect(args)
+ err := config.PutMCPServer(localName, makeLocalMCPServer(c.cf, args))
if err != nil {
return trace.Wrap(err)
}
@@ -355,6 +362,16 @@ func (c *mcpConfigCommand) addMCPServersToConfig(config claudeConfig) error {
return nil
}
+func (c *mcpConfigCommand) maybeAddAutoReconnect(args []string) []string {
+ if !c.autoReconnectSetByUser {
+ return args
+ }
+ if c.autoReconnect {
+ return append(args, "--auto-reconnect")
+ }
+ return append(args, "--no-auto-reconnect")
+}
+
func (c *mcpConfigCommand) printJSONWithHint() error {
if err := c.fetchAndPrintResult(); err != nil {
return trace.Wrap(err)
@@ -372,7 +389,12 @@ func (c *mcpConfigCommand) printJSONWithHint() error {
if err := config.Write(w, claude.FormatJSONOption(c.clientConfig.jsonFormat)); err != nil {
return trace.Wrap(err)
}
- if _, err := fmt.Fprintln(w, ""); err != nil {
+ if !c.autoReconnectSetByUser {
+ if err := c.printAutoReconnectHint(w); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+ if _, err := fmt.Fprintln(w); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(c.clientConfig.printHint(w))
@@ -395,26 +417,39 @@ func (c *mcpConfigCommand) updateClientConfig() error {
return trace.Wrap(err)
}
- // TODO(greedy52) update hint once auto-reconnection is handled.
_, err = fmt.Fprintf(c.cf.Stdout(), `Updated client configuration at:
%s
Teleport MCP servers will be prefixed with "teleport-mcp-" in this
-configuration.
-
-You may need to restart your client to reload these new configurations. If you
-encounter a "disconnected" error when tsh session expires, you may also need to
-restart your client after logging in a new tsh session.
+configuration. You may need to restart your client to reload these new
+configurations.
`, config.Path())
return trace.Wrap(err)
}
-const mcpServerAppConfigPrefix = "teleport-mcp-"
+func (c *mcpConfigCommand) printAutoReconnectHint(w io.Writer) error {
+ _, err := fmt.Fprintln(w, `
+By default, tsh automatically starts a new remote MCP session if the previous
+one is interrupted by network issues or tsh session expiration.
+Auto-reconnection is recommended when MCP sessions are stateless across
+requests. To disable it, use the --no-auto-reconnect flag. If disabled, you may
+need to manually restart your client when encountering "disconnected" errors.`)
+ return trace.Wrap(err)
+}
+
+const (
+ mcpServerAppConfigPrefix = "teleport-mcp-"
+ mcpAutoReconnectHelp = "Automatically starts a new remote MCP session " +
+ "when the previous remote session is interrupted " +
+ "by network issues or tsh session expirations. " +
+ "Recommended for stateless MCP sessions. Defaults to true."
+)
// mcpConnectCommand implements `tsh mcp connect` command.
type mcpConnectCommand struct {
*kingpin.CmdClause
- cf *CLIConf
+ cf *CLIConf
+ autoReconnect bool
}
func (c *mcpConnectCommand) run() error {
@@ -429,9 +464,46 @@ func (c *mcpConnectCommand) run() error {
}
tc.NonInteractive = true
+ if c.autoReconnect {
+ return clientmcp.ProxyStdioConnWithAutoReconnect(
+ c.cf.Context,
+ clientmcp.ProxyStdioConnWithAutoReconnectConfig{
+ ClientStdio: utils.CombinedStdio{},
+ DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) {
+ conn, err := tc.DialMCPServer(ctx, c.cf.AppName)
+ return conn, trace.Wrap(err)
+ },
+ MakeReconnectUserMessage: makeMCPReconnectUserMessage,
+ },
+ )
+ }
+
serverConn, err := tc.DialMCPServer(c.cf.Context, c.cf.AppName)
if err != nil {
return trace.Wrap(err)
}
return trace.Wrap(utils.ProxyConn(c.cf.Context, utils.CombinedStdio{}, serverConn))
}
+
+func makeMCPReconnectUserMessage(err error) string {
+ var userMessage string
+ switch {
+ case clientmcp.IsLikelyTemporaryNetworkError(err):
+ userMessage = "A network error occurred while trying to connect to Teleport." +
+ " This issue is likely temporary — the server may be unavailable, or your internet connection may be unstable." +
+ " Please check your network and try again in a few moments." +
+ " If your network appears to be working, try restarting your MCP client to see if the problem is resolved."
+ case client.IsErrorResolvableWithRelogin(err):
+ userMessage = clientmcp.ReloginRequiredErrorMessage
+ case clientmcp.IsServerInfoChangedError(err):
+ userMessage = "The remote MCP server information has changed after the reconnection. " +
+ " Please restart your MCP client to use the new version."
+ default:
+ userMessage = "An error was encountered while sending the request to Teleport." +
+ " This does not appear to be a transient error." +
+ " Please ensure your tsh session is valid and restart your MCP client to see if the problem is resolved."
+ }
+
+ userMessage += " If the issue persists, check the MCP logs for more details or contact your Teleport admin."
+ return userMessage
+}
diff --git a/tool/tsh/common/mcp_app_test.go b/tool/tsh/common/mcp_app_test.go
index 263d65370c4b5..5eb1e8c9bf99a 100644
--- a/tool/tsh/common/mcp_app_test.go
+++ b/tool/tsh/common/mcp_app_test.go
@@ -452,8 +452,16 @@ func mustMakeNewAppServer(t *testing.T, app *types.AppV3, host string) types.App
return appServer
}
-func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string) *types.AppV3 {
+func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string, opts ...func(*types.MCP)) *types.AppV3 {
t.Helper()
+ mcpSpec := &types.MCP{
+ Command: "test",
+ Args: []string{"arg"},
+ RunAsHostUser: "test",
+ }
+ for _, opt := range opts {
+ opt(mcpSpec)
+ }
return mustMakeNewAppV3(t,
types.Metadata{
Name: name,
@@ -461,11 +469,7 @@ func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[strin
Labels: labels,
},
types.AppSpecV3{
- MCP: &types.MCP{
- Command: "test",
- Args: []string{"arg"},
- RunAsHostUser: "test",
- },
+ MCP: mcpSpec,
},
)
}