diff --git a/lib/client/db/mcp/errors.go b/lib/client/db/mcp/errors.go index 19fdc90dbb2ac..04a00c0043850 100644 --- a/lib/client/db/mcp/errors.go +++ b/lib/client/db/mcp/errors.go @@ -33,7 +33,7 @@ import ( func FormatErrorMessage(err error) error { switch { case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired) || utils.IsCertExpiredError(err): - return trace.BadParameter(ReloginRequiredErrorMessage) + return trace.BadParameter(mcp.ReloginRequiredErrorMessage) case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe): return trace.BadParameter(LocalProxyConnectionErrorMessage) } @@ -42,12 +42,6 @@ func FormatErrorMessage(err error) error { } const ( - // ReloginRequiredErrorMessage is the message returned to the MCP client - // when the tsh session expired. - ReloginRequiredErrorMessage = `It looks like your Teleport session expired, -you must relogin (using "tsh login" on a terminal) before continue using this -tool. After that, there is no need to update or relaunch the MCP client - just -try using it again.` // LocalProxyConnectionErrorMessage is the message returned to the MCP client when // the database client cannot connect to the local proxy. 
LocalProxyConnectionErrorMessage = `Teleport MCP server is having issue while diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go index 7e018d6ef5917..698baeb499200 100644 --- a/lib/client/db/postgres/mcp/mcp_test.go +++ b/lib/client/db/postgres/mcp/mcp_test.go @@ -140,7 +140,7 @@ func TestFormatErrors(t *testing.T) { }, }, expectErrorMessage: func(tt require.TestingT, i1 any, i2 ...any) { - require.Equal(t, dbmcp.ReloginRequiredErrorMessage, i1) + require.Equal(t, clientmcp.ReloginRequiredErrorMessage, i1) }, }, } { diff --git a/lib/client/mcp/errors.go b/lib/client/mcp/errors.go new file mode 100644 index 0000000000000..969a4297e23be --- /dev/null +++ b/lib/client/mcp/errors.go @@ -0,0 +1,84 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "errors" + "fmt" + "net" + "syscall" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + // ReloginRequiredErrorMessage is the message returned to the MCP client + // when the tsh session expired. + ReloginRequiredErrorMessage = `It looks like your Teleport session expired, +you must relogin (using "tsh login" on a terminal) before continue using this +tool. 
After that, there is no need to update or relaunch the MCP client - just
+try using it again.`
+)
+
+// IsLikelyTemporaryNetworkError reports whether err looks like a transient
+// network failure: a connection problem, one of a few well-known syscall
+// errnos, a temporary DNS failure, or a network timeout.
+func IsLikelyTemporaryNetworkError(err error) bool {
+	if trace.IsConnectionProblem(err) {
+		return true
+	}
+	if isTemporarySyscallNetError(err) {
+		return true
+	}
+
+	// A DNS error decides the outcome on its own, even though it would also
+	// satisfy the generic net.Error check below.
+	var dnsFailure *net.DNSError
+	if errors.As(err, &dnsFailure) {
+		return dnsFailure.Temporary()
+	}
+
+	// Any other network error only counts when it is a timeout.
+	var networkFailure net.Error
+	if errors.As(err, &networkFailure) {
+		return networkFailure.Timeout()
+	}
+
+	return false
+}
+
+// isTemporarySyscallNetError reports whether err wraps one of the errnos
+// treated as transient: host/network unreachable, timed out, or connection
+// refused.
+func isTemporarySyscallNetError(err error) bool {
+	transientErrnos := []syscall.Errno{
+		syscall.EHOSTUNREACH,
+		syscall.ENETUNREACH,
+		syscall.ETIMEDOUT,
+		syscall.ECONNREFUSED,
+	}
+	for _, errno := range transientErrnos {
+		if errors.Is(err, errno) {
+			return true
+		}
+	}
+	return false
+}
+
+// IsServerInfoChangedError returns true if the error indicates the remote MCP
+// server's info has changed from previous connections. Auto-reconnection
+// reports this scenario as an error case to be on the safe side in case things
+// like tools have changed.
+func IsServerInfoChangedError(err error) bool {
+	// Named "target" so the variable does not shadow its own type.
+	var target *serverInfoChangedError
+	return errors.As(err, &target)
+}
+
+// serverInfoChangedError records the server info seen on the original
+// connection and the differing info seen after a reconnection.
+type serverInfoChangedError struct {
+	expectedInfo mcp.Implementation
+	currentInfo  mcp.Implementation
+}
+
+// Error implements the error interface.
+func (e *serverInfoChangedError) Error() string {
+	expected, current := e.expectedInfo, e.currentInfo
+	return fmt.Sprintf("server info has changed, expected %v, got %v", expected, current)
+}
diff --git a/lib/client/mcp/errors_test.go b/lib/client/mcp/errors_test.go
new file mode 100644
index 0000000000000..86e11d1cf8a29
--- /dev/null
+++ b/lib/client/mcp/errors_test.go
@@ -0,0 +1,40 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +func TestIsServerInfoChangedError(t *testing.T) { + err := &serverInfoChangedError{ + expectedInfo: mcp.Implementation{ + Name: "i-am-mcp", + Version: "1.0.0", + }, + currentInfo: mcp.Implementation{ + Name: "i-am-mcp", + Version: "1.1.0", + }, + } + require.True(t, IsServerInfoChangedError(err)) +} diff --git a/lib/client/mcp/reconnect.go b/lib/client/mcp/reconnect.go new file mode 100644 index 0000000000000..11a7d6a6e00cd --- /dev/null +++ b/lib/client/mcp/reconnect.go @@ -0,0 +1,327 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package mcp + +import ( + "context" + "fmt" + "io" + "log/slog" + "sync" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/utils/mcputils" +) + +// ProxyStdioConnWithAutoReconnectConfig is the config for ProxyStdioConnWithAutoReconnect. +type ProxyStdioConnWithAutoReconnectConfig struct { + // ClientStdio is the client stdin and stdout. + ClientStdio io.ReadWriteCloser + // MakeReconnectUserMessage generates a user-friendly message based on the + // error. + MakeReconnectUserMessage func(error) string + // DialServer makes a new connection to the remote MCP server. + DialServer func(context.Context) (io.ReadWriteCloser, error) + // Logger is the slog logger. + Logger *slog.Logger + + // clientResponseWriter replies to ClientStdio. + clientResponseWriter mcputils.MessageWriter + // onServerConnClosed is a callback when remote server connection is dead. + onServerConnClosed func() +} + +// CheckAndSetDefaults validates the config and sets default values. +func (cfg *ProxyStdioConnWithAutoReconnectConfig) CheckAndSetDefaults() error { + if cfg.ClientStdio == nil { + return trace.BadParameter("missing ClientStdio") + } + if cfg.DialServer == nil { + return trace.BadParameter("missing DialServer") + } + if cfg.MakeReconnectUserMessage == nil { + return trace.BadParameter("missing MakeReconnectUserMessage") + } + if cfg.Logger == nil { + cfg.Logger = slog.With( + teleport.ComponentKey, + teleport.Component(teleport.ComponentMCP, "autoreconnect"), + ) + } + if cfg.clientResponseWriter == nil { + cfg.clientResponseWriter = mcputils.NewSyncStdioMessageWriter(cfg.ClientStdio) + } + return nil +} + +// ProxyStdioConnWithAutoReconnect serves a stdio client with a consistent +// connection and reconnects to the remote server upon issues. 
+func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) error {
+	if err := cfg.CheckAndSetDefaults(); err != nil {
+		return trace.Wrap(err)
+	}
+
+	// Everything started below (including server response readers) is tied to
+	// this context and stops when the function returns.
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	serverConn, err := newServerConnWithAutoReconnect(ctx, cfg)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	clientRequestReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
+		Transport:     mcputils.NewStdioReader(cfg.ClientStdio),
+		ParentContext: ctx,
+		Logger:        cfg.Logger.With("client", "stdin"),
+		OnParseError:  mcputils.ReplyParseError(cfg.clientResponseWriter),
+		OnNotification: func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
+			// By spec, we should not reply to notifications. Try our best to
+			// send a notification with the error message. In practice, only the
+			// initialized notification is sent from client after receiving the
+			// initialize response so it's unlikely to hit here.
+			if writeError := serverConn.WriteMessage(ctx, notification); writeError != nil {
+				cfg.Logger.WarnContext(ctx, "failed to write notification to server. Notification is dropped.", "error", writeError)
+				userMessage := cfg.MakeReconnectUserMessage(writeError)
+				errNotification := mcp.Notification{
+					Method: "notifications/tsherr",
+					Params: mcp.NotificationParams{
+						AdditionalFields: map[string]any{
+							"error": fmt.Sprintf("Notification %q was dropped. %s", notification.Method, userMessage),
+						},
+					},
+				}
+				return trace.Wrap(cfg.clientResponseWriter.WriteMessage(ctx, errNotification))
+			}
+			return nil
+		},
+		OnRequest: func(ctx context.Context, request *mcputils.JSONRPCRequest) error {
+			// A failed write (including a failed reconnect) is answered
+			// locally with a JSON-RPC error carrying the user-facing message,
+			// so the client is not left waiting for a response.
+			if writeError := serverConn.WriteMessage(ctx, request); writeError != nil {
+				cfg.Logger.WarnContext(ctx, "failed to write request to server", "error", writeError)
+				userMessage := cfg.MakeReconnectUserMessage(writeError)
+				errResp := mcp.NewJSONRPCError(request.ID, mcp.INTERNAL_ERROR, userMessage, writeError)
+				return trace.Wrap(cfg.clientResponseWriter.WriteMessage(ctx, errResp))
+			}
+			return nil
+		},
+	})
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	clientRequestReader.Run(ctx)
+	return nil
+
+}
+
+// serverConnWithAutoReconnect is a message writer towards the remote MCP
+// server that dials lazily on write and, after the first connection, replays
+// the cached initialize handshake on every new connection.
+type serverConnWithAutoReconnect struct {
+	ProxyStdioConnWithAutoReconnectConfig
+	// parentCtx is the parent of every server response reader's context.
+	parentCtx context.Context
+
+	// mu is held by WriteMessage, Close and the reader's OnClose callback
+	// while they touch the fields below.
+	mu sync.Mutex
+	// serverRequestWriter is the writer for the current server connection, or
+	// nil when a new connection must be dialed.
+	serverRequestWriter mcputils.MessageWriter
+	// replayOnNextConn is set once the first connection has been established.
+	replayOnNextConn bool
+	// initRequest, initResponse and initNotification cache the client's
+	// initialize handshake for replay.
+	initRequest      *mcputils.JSONRPCRequest
+	initResponse     *mcp.InitializeResult
+	initNotification *mcputils.JSONRPCNotification
+	// closeServerConn cancels the context of the current response reader.
+	closeServerConn func()
+}
+
+// newServerConnWithAutoReconnect builds the wrapper without dialing; the first
+// connection is made by the first WriteMessage call.
+func newServerConnWithAutoReconnect(parentCtx context.Context, cfg ProxyStdioConnWithAutoReconnectConfig) (*serverConnWithAutoReconnect, error) {
+	return &serverConnWithAutoReconnect{
+		ProxyStdioConnWithAutoReconnectConfig: cfg,
+		parentCtx:                             parentCtx,
+	}, nil
+}
+
+// Close cancels the context of the current server response reader, if any.
+// It always returns nil.
+func (r *serverConnWithAutoReconnect) Close() error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.closeServerConn != nil {
+		r.closeServerConn()
+	}
+	return nil
+}
+
+// WriteMessage sends msg to the remote server, dialing (and replaying the
+// initialize handshake) first when there is no live connection. Dial, replay
+// and write errors are returned to the caller.
+func (r *serverConnWithAutoReconnect) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	writer, err := r.getServerRequestWriterLocked(ctx)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	return trace.Wrap(writer.WriteMessage(ctx, msg))
+}
+
+// getServerRequestWriterLocked returns the writer for the current server
+// connection, establishing a new connection when none is cached. The caller
+// must hold r.mu.
+func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.Context) (mcputils.MessageWriter, error) {
+	if r.serverRequestWriter != nil {
+		return r.serverRequestWriter, nil
+	}
+
+	r.Logger.InfoContext(ctx, "Connecting to server")
+	serverConn, err := r.DialServer(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	serverStdioReader := mcputils.NewStdioReader(serverConn)
+	serverWriter := mcputils.NewStdioMessageWriter(serverConn)
+
+	// The writer is kept local until the connection is fully set up, so a
+	// failure below never leaves a writer for a closed connection cached on r.
+	var requestWriter mcputils.MessageWriter = serverWriter
+	if r.replayOnNextConn {
+		// Replay initialize sequence. Any error here is likely permanent.
+		if err := r.replayInitializeLocked(ctx, serverStdioReader, serverWriter); err != nil {
+			serverConn.Close()
+			return nil, trace.Wrap(err)
+		}
+	} else {
+		// First connection: tee outgoing messages through the cache so the
+		// client's initialize request and initialized notification are
+		// recorded for later replay.
+		requestWriter = mcputils.NewMultiMessageWriter(
+			mcputils.MessageWriterFunc(func(ctx context.Context, msg mcp.JSONRPCMessage) error {
+				r.cacheMessageLocked(ctx, msg)
+				return nil
+			}),
+			serverWriter,
+		)
+	}
+
+	// This should never fail as long the correct config is passed in.
+	serverResponseReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
+		Transport:     serverStdioReader,
+		ParentContext: r.parentCtx,
+		// OnClose is called when server connection is dead.
+		// Teleport Proxy automatically closes the connection when tsh session
+		// is expired.
+		OnClose: func() {
+			r.Logger.InfoContext(ctx, "Lost server connection, resetting...")
+			r.mu.Lock()
+			// Dropping the writer makes the next WriteMessage dial again.
+			r.serverRequestWriter = nil
+			if r.onServerConnClosed != nil {
+				r.onServerConnClosed()
+			}
+			r.mu.Unlock()
+		},
+		Logger:       r.Logger.With("server", "stdout"),
+		OnParseError: mcputils.LogAndIgnoreParseError(r.Logger),
+		OnNotification: func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
+			return trace.Wrap(r.clientResponseWriter.WriteMessage(ctx, notification))
+		},
+		OnResponse: func(ctx context.Context, response *mcputils.JSONRPCResponse) error {
+			// NOTE(review): cacheMessageLocked is called here without r.mu
+			// held, while the write path reads the same cached fields under
+			// r.mu. Confirm whether this needs synchronization.
+			r.cacheMessageLocked(ctx, response)
+			return trace.Wrap(r.clientResponseWriter.WriteMessage(ctx, response))
+		},
+	})
+	if err != nil {
+		serverConn.Close()
+		return nil, trace.Wrap(err)
+	}
+
+	// Release the previous reader's context before replacing its cancel
+	// function; otherwise it stays registered with parentCtx until the parent
+	// is canceled.
+	if r.closeServerConn != nil {
+		r.closeServerConn()
+	}
+	readerCtx, readerCancel := context.WithCancel(r.parentCtx)
+	r.closeServerConn = readerCancel
+	r.serverRequestWriter = requestWriter
+	r.replayOnNextConn = true
+	go serverResponseReader.Run(readerCtx)
+
+	r.Logger.InfoContext(ctx, "Started a new MCP server connection")
+	return r.serverRequestWriter, nil
+}
+
+// initializedLocked reports whether the full initialize handshake (request,
+// response and initialized notification) has been cached.
+func (r *serverConnWithAutoReconnect) initializedLocked() bool {
+	return r.initRequest != nil && r.initResponse != nil && r.initNotification != nil
+}
+
+// replayInitializeLocked re-sends the cached initialize request on a fresh
+// connection, validates the server's response against the cached one, and
+// then re-sends the cached initialized notification. It fails when the
+// handshake has not been fully cached yet.
+func (r *serverConnWithAutoReconnect) replayInitializeLocked(ctx context.Context, serverReader mcputils.TransportReader, serverWriter mcputils.MessageWriter) error {
+	if !r.initializedLocked() {
+		return trace.Errorf("client has not initialized yet")
+	}
+
+	r.Logger.DebugContext(ctx, "Replaying initialize request")
+	if err := serverWriter.WriteMessage(ctx, r.initRequest); err != nil {
+		return trace.Wrap(err)
+	}
+
+	r.Logger.DebugContext(ctx, "Reading and comparing initialize response")
+	msg, err := mcputils.ReadOneResponse(ctx, serverReader)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	if err := r.checkReplyResponseLocked(msg); err != nil {
+		return trace.Wrap(err)
+	}
+
+	r.Logger.DebugContext(ctx, "Replaying initialized notification")
+	if err := serverWriter.WriteMessage(ctx, r.initNotification); err != nil {
+		return trace.Wrap(err)
+	}
+	return nil
+}
+
+// checkReplyResponseLocked validates the response to a replayed initialize
+// request: it must be a non-error response with the same ID as the cached
+// request, and its ServerInfo must equal the cached one. A ServerInfo
+// mismatch is reported as a serverInfoChangedError.
+func (r *serverConnWithAutoReconnect) checkReplyResponseLocked(msg mcp.JSONRPCMessage) error {
+	resp, ok := msg.(*mcputils.JSONRPCResponse)
+	if !ok {
+		// Report the type actually received; resp is only the failed
+		// assertion's zero value and would always print the expected type.
+		return trace.Errorf("expected initialize response, got %T", msg)
+	}
+	if resp.Error != nil {
+		return trace.Errorf("expected initialize result but got error")
+	}
+	if resp.ID.String() != r.initRequest.ID.String() {
+		return trace.CompareFailed("expected initialize response with ID %s, got %s", r.initRequest.ID, resp.ID.String())
+	}
+
+	newResult, err := resp.GetInitializeResult()
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	if newResult.ServerInfo != r.initResponse.ServerInfo {
+		return trace.Wrap(&serverInfoChangedError{
+			expectedInfo: r.initResponse.ServerInfo,
+			currentInfo:  newResult.ServerInfo,
+		})
+	}
+	return nil
+}
+
+// cacheMessageLocked caches the client's initialize request, the matching
+// initialize response and the initialized notification. Each is stored at
+// most once, and nothing is stored after all three are present.
+func (r *serverConnWithAutoReconnect) cacheMessageLocked(ctx context.Context, msg mcp.JSONRPCMessage) {
+	if r.initializedLocked() {
+		return
+	}
+
+	switch m := msg.(type) {
+	case *mcputils.JSONRPCRequest:
+		if r.initRequest == nil && m.Method == mcp.MethodInitialize {
+			r.initRequest = m
+			r.Logger.DebugContext(ctx, "Cached initialize", "request", m)
+		}
+	case *mcputils.JSONRPCNotification:
+		if r.initNotification == nil && m.Method == mcputils.MethodNotificationInitialized {
+			r.initNotification = m
+			r.Logger.DebugContext(ctx, "Cached notification", "notification", m)
+		}
+	case *mcputils.JSONRPCResponse:
+		// Only the response whose ID matches the cached initialize request
+		// is of interest.
+		if r.initResponse == nil && r.initRequest != nil && r.initRequest.ID.String() == m.ID.String() {
+			initResponse, err := m.GetInitializeResult()
+			if err != nil {
+				r.Logger.DebugContext(ctx, "Error parsing init response", "error", err)
+			} else {
+				r.initResponse = initResponse
+				r.Logger.DebugContext(ctx, "Cached response", "response", m)
+			}
+		}
+	}
+}
diff --git a/lib/client/mcp/reconnect_test.go
b/lib/client/mcp/reconnect_test.go new file mode 100644 index 0000000000000..6e98a8fa235ad --- /dev/null +++ b/lib/client/mcp/reconnect_test.go @@ -0,0 +1,126 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "context" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils/mcptest" + "github.com/gravitational/teleport/lib/utils/uds" +) + +func TestProxyStdioConnWithAutoReconnect(t *testing.T) { + ctx := t.Context() + + var serverStdioSource atomic.Value + prepServerWithVersion := func(version string) { + testServerV1 := mcptest.NewServerWithVersion(version) + testServerSource, testServerDest := mustMakeSocketPair(t) + serverStdioSource.Store(testServerSource) + go func() { + mcpserver.NewStdioServer(testServerV1).Listen(t.Context(), testServerDest, testServerDest) + }() + } + prepServerWithVersion("1.0.0") + + clientStdioSource, clientStdioDest := mustMakeSocketPair(t) + stdioClient := mcptest.NewStdioClientFromConn(t, clientStdioSource) + proxyError := make(chan error, 1) + serverConnClosed := make(chan struct{}, 1) + + // Start proxy. 
+ go func() { + proxyError <- ProxyStdioConnWithAutoReconnect(ctx, ProxyStdioConnWithAutoReconnectConfig{ + ClientStdio: clientStdioDest, + MakeReconnectUserMessage: func(err error) string { + return err.Error() + }, + DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) { + return serverStdioSource.Load().(io.ReadWriteCloser), nil + }, + onServerConnClosed: func() { + serverConnClosed <- struct{}{} + }, + }) + }() + + // Initialize. + _, err := mcptest.InitializeClient(ctx, stdioClient) + require.NoError(t, err) + + // Call tool success. + callToolRequest := mcp.CallToolRequest{} + callToolRequest.Params.Name = "hello-server" + _, err = stdioClient.CallTool(ctx, callToolRequest) + require.NoError(t, err) + + // Let's kill the server, CallTool should fail. + serverStdioSource.Load().(io.ReadWriteCloser).Close() + select { + case <-serverConnClosed: + case <-time.After(time.Second): + t.Fatal("timed out waiting for server connection to close") + } + _, err = stdioClient.CallTool(ctx, callToolRequest) + require.ErrorContains(t, err, "use of closed network connection") + + // Let it try again with a successful reconnect. + prepServerWithVersion("1.0.0") + _, err = stdioClient.CallTool(ctx, callToolRequest) + require.NoError(t, err) + + // Let's kill the server again, and prepare a different version. + serverStdioSource.Load().(io.ReadWriteCloser).Close() + select { + case <-serverConnClosed: + case <-time.After(time.Second): + t.Fatal("timed out waiting for server connection to close") + } + prepServerWithVersion("2.0.0") + _, err = stdioClient.CallTool(ctx, callToolRequest) + require.ErrorContains(t, err, "server info has changed") + + // Cleanup. 
+ clientStdioSource.Close() + select { + case proxyErr := <-proxyError: + require.NoError(t, proxyErr) + case <-time.After(time.Second): + t.Fatal("timed out waiting for proxy connection") + } +} + +func mustMakeSocketPair(t *testing.T) (io.ReadWriteCloser, io.ReadWriteCloser) { + t.Helper() + source, dest, err := uds.NewSocketpair(uds.SocketTypeStream) + require.NoError(t, err) + t.Cleanup(func() { + source.Close() + dest.Close() + }) + return source, dest +} diff --git a/lib/utils/mcputils/protocol.go b/lib/utils/mcputils/protocol.go index 5a8d8537167fd..3c458c2c7bf45 100644 --- a/lib/utils/mcputils/protocol.go +++ b/lib/utils/mcputils/protocol.go @@ -148,8 +148,25 @@ type JSONRPCResponse struct { // the corresponding go object. func (r *JSONRPCResponse) GetListToolResult() (*mcp.ListToolsResult, error) { var listResult mcp.ListToolsResult - if err := json.Unmarshal([]byte(r.Result), &listResult); err != nil { + if err := json.Unmarshal(r.Result, &listResult); err != nil { return nil, trace.Wrap(err) } return &listResult, nil } + +// GetInitializeResult assumes the result is for mcp.MethodInitialize and +// returns the corresponding go object. +func (r *JSONRPCResponse) GetInitializeResult() (*mcp.InitializeResult, error) { + var result mcp.InitializeResult + if err := json.Unmarshal(r.Result, &result); err != nil { + return nil, trace.Wrap(err) + } + return &result, nil +} + +const ( + // MethodNotificationInitialized defines the method used for "initialized" + // notification. This notification is sent by the client after it receives + // the initialize response. 
+ MethodNotificationInitialized = "notifications/initialized" +) diff --git a/lib/utils/mcputils/protocol_test.go b/lib/utils/mcputils/protocol_test.go index acf0f4c721641..4be664ac35e59 100644 --- a/lib/utils/mcputils/protocol_test.go +++ b/lib/utils/mcputils/protocol_test.go @@ -27,8 +27,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestJSONRPCNotification(t *testing.T) { - inputJSON := []byte(`{ +var ( + sampleNotificationJSON = []byte(`{ "jsonrpc": "2.0", "method": "notifications/message", "params": { @@ -43,9 +43,46 @@ func TestJSONRPCNotification(t *testing.T) { } } }`) + sampleRequestJSON = []byte(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "get_weather", + "arguments": { + "location": "New York" + } + } +}`) + + sampleResponseJSON = []byte(`{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + } + ], + "nextCursor": "next-page-cursor" + } +}`) +) +func TestJSONRPCNotification(t *testing.T) { var base baseJSONRPCMessage - require.NoError(t, json.Unmarshal(inputJSON, &base)) + require.NoError(t, json.Unmarshal(sampleNotificationJSON, &base)) assert.True(t, base.isNotification()) assert.False(t, base.isRequest()) assert.False(t, base.isResponse()) @@ -57,23 +94,12 @@ func TestJSONRPCNotification(t *testing.T) { outputJSON, err := json.MarshalIndent(m, "", " ") require.NoError(t, err) - assert.JSONEq(t, string(inputJSON), string(outputJSON)) + assert.JSONEq(t, string(sampleNotificationJSON), string(outputJSON)) } func TestJSONRPCRequest(t *testing.T) { - inputJSON := []byte(`{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": "get_weather", - "arguments": { - "location": "New York" - } - } -}`) var base 
baseJSONRPCMessage - require.NoError(t, json.Unmarshal(inputJSON, &base)) + require.NoError(t, json.Unmarshal(sampleRequestJSON, &base)) assert.False(t, base.isNotification()) assert.True(t, base.isRequest()) assert.False(t, base.isResponse()) @@ -88,35 +114,12 @@ func TestJSONRPCRequest(t *testing.T) { outputJSON, err := json.MarshalIndent(m, "", " ") require.NoError(t, err) - assert.JSONEq(t, string(inputJSON), string(outputJSON)) + assert.JSONEq(t, string(sampleRequestJSON), string(outputJSON)) } func TestJSONRPCResponse(t *testing.T) { - inputJSON := []byte(`{ - "jsonrpc": "2.0", - "id": 2, - "result": { - "tools": [ - { - "name": "get_weather", - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - } - ], - "nextCursor": "next-page-cursor" - } -}`) var base baseJSONRPCMessage - require.NoError(t, json.Unmarshal(inputJSON, &base)) + require.NoError(t, json.Unmarshal(sampleResponseJSON, &base)) assert.False(t, base.isNotification()) assert.False(t, base.isRequest()) assert.True(t, base.isResponse()) @@ -127,7 +130,7 @@ func TestJSONRPCResponse(t *testing.T) { outputJSON, err := json.MarshalIndent(m, "", " ") require.NoError(t, err) - assert.JSONEq(t, string(inputJSON), string(outputJSON)) + assert.JSONEq(t, string(sampleResponseJSON), string(outputJSON)) toolList, err := m.GetListToolResult() require.NoError(t, err) diff --git a/lib/utils/mcputils/reader.go b/lib/utils/mcputils/reader.go index f51596739d85a..069c582bb08ba 100644 --- a/lib/utils/mcputils/reader.go +++ b/lib/utils/mcputils/reader.go @@ -217,3 +217,22 @@ func (r *MessageReader) processNextLine(ctx context.Context) error { ) } } + +// ReadOneResponse reads one message from the reader and marshals it to a +// response. 
+func ReadOneResponse(ctx context.Context, reader TransportReader) (*JSONRPCResponse, error) {
+	// Pull exactly one raw message off the transport.
+	line, readErr := reader.ReadMessage(ctx)
+	if readErr != nil {
+		return nil, trace.Wrap(readErr)
+	}
+
+	// Decode into the generic envelope first so the message kind can be
+	// checked before converting it.
+	var envelope baseJSONRPCMessage
+	if err := json.Unmarshal([]byte(line), &envelope); err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	// Anything that is not a response is rejected.
+	if envelope.isResponse() {
+		return envelope.makeResponse(), nil
+	}
+	return nil, trace.BadParameter("message is not a response")
+}
diff --git a/lib/utils/mcputils/reader_test.go b/lib/utils/mcputils/reader_test.go
new file mode 100644
index 0000000000000..eac495d62dafa
--- /dev/null
+++ b/lib/utils/mcputils/reader_test.go
@@ -0,0 +1,84 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */ + +package mcputils + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type mockTransportReader struct { + message string +} + +func (m mockTransportReader) ReadMessage(context.Context) (string, error) { + return m.message, nil +} +func (m mockTransportReader) Type() string { + return "mock" +} +func (m mockTransportReader) Close() error { + return nil +} + +func TestReadOneResponse(t *testing.T) { + tests := []struct { + name string + rawMessage string + checkError require.ErrorAssertionFunc + checkResponse func(*testing.T, *JSONRPCResponse) + }{ + { + name: "bad json", + rawMessage: "not JSON RPC message", + checkError: require.Error, + }, + { + name: "notification", + rawMessage: string(sampleNotificationJSON), + checkError: require.Error, + }, + { + name: "request", + rawMessage: string(sampleRequestJSON), + checkError: require.Error, + }, + { + name: "response", + rawMessage: string(sampleResponseJSON), + checkError: require.NoError, + checkResponse: func(t *testing.T, response *JSONRPCResponse) { + require.NotNil(t, response) + _, err := response.GetListToolResult() + require.NoError(t, err) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resp, err := ReadOneResponse(t.Context(), mockTransportReader{test.rawMessage}) + test.checkError(t, err) + if test.checkResponse != nil { + test.checkResponse(t, resp) + } + }) + } +} diff --git a/lib/utils/mcputils/writer.go b/lib/utils/mcputils/writer.go index b62a1d35ff56d..29fd2053dac8b 100644 --- a/lib/utils/mcputils/writer.go +++ b/lib/utils/mcputils/writer.go @@ -22,6 +22,7 @@ import ( "context" "sync" + "github.com/gravitational/trace" "github.com/mark3labs/mcp-go/mcp" ) @@ -50,3 +51,37 @@ func (s *SyncMessageWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMes defer s.mu.Unlock() return s.w.WriteMessage(ctx, msg) } + +// MessageWriterFunc defines a message writer function that implements +// MessageWriter. 
+type MessageWriterFunc func(context.Context, mcp.JSONRPCMessage) error + +// WriteMessage writes an JSON RPC message. +func (f MessageWriterFunc) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error { + return f(ctx, msg) +} + +// MultiMessageWriter creates a writer that duplicates its writes to all the +// provided writers. +// +// Each write is written to each listed writer, one at a time. If a listed +// writer returns an error, that overall writes operation stops and returns the +// error; it does not continue down the list. +type MultiMessageWriter struct { + writers []MessageWriter +} + +// NewMultiMessageWriter creates a new MultiMessageWriter. +func NewMultiMessageWriter(writers ...MessageWriter) *MultiMessageWriter { + return &MultiMessageWriter{writers: writers} +} + +// WriteMessage writes the message to each listed writer, one at a time. +func (w *MultiMessageWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMessage) error { + for _, writer := range w.writers { + if err := writer.WriteMessage(ctx, msg); err != nil { + return trace.Wrap(err) + } + } + return nil +} diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go index 5a67648e4c36b..cf8222a6e9f11 100644 --- a/tool/tsh/common/mcp_app.go +++ b/tool/tsh/common/mcp_app.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/api/utils/iterutils" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/client" + clientmcp "github.com/gravitational/teleport/lib/client/mcp" "github.com/gravitational/teleport/lib/client/mcp/claude" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" @@ -51,6 +52,7 @@ func newMCPConnectCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConnectCom } cmd.Arg("name", "Name of the MCP server.").Required().StringVar(&cf.AppName) + cmd.Flag("auto-reconnect", mcpAutoReconnectHelp).Default("true").BoolVar(&cmd.autoReconnect) return cmd } @@ -77,6 +79,7 @@ func 
newMCPConfigCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConfigComma cmd.Flag("all", "Select all MCP servers. Mutually exclusive with --labels or --query.").Short('R').BoolVar(&cf.ListAll) cmd.Flag("labels", labelHelp).StringVar(&cf.Labels) cmd.Flag("query", queryHelp).StringVar(&cf.PredicateExpression) + cmd.Flag("auto-reconnect", mcpAutoReconnectHelp).IsSetByUser(&cmd.autoReconnectSetByUser).BoolVar(&cmd.autoReconnect) cmd.Arg("name", "Name of the MCP server.").StringVar(&cf.AppName) cmd.clientConfig.addToCmd(cmd.CmdClause) cmd.Alias(mcpConfigHelp) @@ -260,8 +263,10 @@ func printMCPServersInVerboseText(w io.Writer, mcpServers iter.Seq[mcpServerWith type mcpConfigCommand struct { *kingpin.CmdClause - clientConfig mcpClientConfigFlags - cf *CLIConf + clientConfig mcpClientConfigFlags + cf *CLIConf + autoReconnect bool + autoReconnectSetByUser bool mcpServerApps []types.Application @@ -347,7 +352,9 @@ func (c *mcpConfigCommand) fetch() error { func (c *mcpConfigCommand) addMCPServersToConfig(config claudeConfig) error { for _, app := range c.mcpServerApps { localName := mcpServerAppConfigPrefix + app.GetName() - err := config.PutMCPServer(localName, makeLocalMCPServer(c.cf, []string{"mcp", "connect", app.GetName()})) + args := []string{"mcp", "connect", app.GetName()} + args = c.maybeAddAutoReconnect(args) + err := config.PutMCPServer(localName, makeLocalMCPServer(c.cf, args)) if err != nil { return trace.Wrap(err) } @@ -355,6 +362,16 @@ func (c *mcpConfigCommand) addMCPServersToConfig(config claudeConfig) error { return nil } +func (c *mcpConfigCommand) maybeAddAutoReconnect(args []string) []string { + if !c.autoReconnectSetByUser { + return args + } + if c.autoReconnect { + return append(args, "--auto-reconnect") + } + return append(args, "--no-auto-reconnect") +} + func (c *mcpConfigCommand) printJSONWithHint() error { if err := c.fetchAndPrintResult(); err != nil { return trace.Wrap(err) @@ -372,7 +389,12 @@ func (c *mcpConfigCommand) printJSONWithHint() 
error { if err := config.Write(w, claude.FormatJSONOption(c.clientConfig.jsonFormat)); err != nil { return trace.Wrap(err) } - if _, err := fmt.Fprintln(w, ""); err != nil { + if !c.autoReconnectSetByUser { + if err := c.printAutoReconnectHint(w); err != nil { + return trace.Wrap(err) + } + } + if _, err := fmt.Fprintln(w); err != nil { return trace.Wrap(err) } return trace.Wrap(c.clientConfig.printHint(w)) @@ -395,26 +417,39 @@ func (c *mcpConfigCommand) updateClientConfig() error { return trace.Wrap(err) } - // TODO(greedy52) update hint once auto-reconnection is handled. _, err = fmt.Fprintf(c.cf.Stdout(), `Updated client configuration at: %s Teleport MCP servers will be prefixed with "teleport-mcp-" in this -configuration. - -You may need to restart your client to reload these new configurations. If you -encounter a "disconnected" error when tsh session expires, you may also need to -restart your client after logging in a new tsh session. +configuration. You may need to restart your client to reload these new +configurations. `, config.Path()) return trace.Wrap(err) } -const mcpServerAppConfigPrefix = "teleport-mcp-" +func (c *mcpConfigCommand) printAutoReconnectHint(w io.Writer) error { + _, err := fmt.Fprintln(w, ` +By default, tsh automatically starts a new remote MCP session if the previous +one is interrupted by network issues or tsh session expiration. +Auto-reconnection is recommended when MCP sessions are stateless across +requests. To disable it, use the --no-auto-reconnect flag. If disabled, you may +need to manually restart your client when encountering "disconnected" errors.`) + return trace.Wrap(err) +} + +const ( + mcpServerAppConfigPrefix = "teleport-mcp-" + mcpAutoReconnectHelp = "Automatically starts a new remote MCP session " + + "when the previous remote session is interrupted " + + "by network issues or tsh session expirations. " + + "Recommended for stateless MCP sessions. Defaults to true." 
+) // mcpConnectCommand implements `tsh mcp connect` command. type mcpConnectCommand struct { *kingpin.CmdClause - cf *CLIConf + cf *CLIConf + autoReconnect bool } func (c *mcpConnectCommand) run() error { @@ -429,9 +464,46 @@ func (c *mcpConnectCommand) run() error { } tc.NonInteractive = true + if c.autoReconnect { + return clientmcp.ProxyStdioConnWithAutoReconnect( + c.cf.Context, + clientmcp.ProxyStdioConnWithAutoReconnectConfig{ + ClientStdio: utils.CombinedStdio{}, + DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) { + conn, err := tc.DialMCPServer(ctx, c.cf.AppName) + return conn, trace.Wrap(err) + }, + MakeReconnectUserMessage: makeMCPReconnectUserMessage, + }, + ) + } + serverConn, err := tc.DialMCPServer(c.cf.Context, c.cf.AppName) if err != nil { return trace.Wrap(err) } return trace.Wrap(utils.ProxyConn(c.cf.Context, utils.CombinedStdio{}, serverConn)) } + +func makeMCPReconnectUserMessage(err error) string { + var userMessage string + switch { + case clientmcp.IsLikelyTemporaryNetworkError(err): + userMessage = "A network error occurred while trying to connect to Teleport." + + " This issue is likely temporary — the server may be unavailable, or your internet connection may be unstable." + + " Please check your network and try again in a few moments." + + " If your network appears to be working, try restarting your MCP client to see if the problem is resolved." + case client.IsErrorResolvableWithRelogin(err): + userMessage = clientmcp.ReloginRequiredErrorMessage + case clientmcp.IsServerInfoChangedError(err): + userMessage = "The remote MCP server information has changed after the reconnection. " + + " Please restart your MCP client to use the new version." + default: + userMessage = "An error was encountered while sending the request to Teleport." + + " This does not appear to be a transient error." + + " Please ensure your tsh session is valid and restart your MCP client to see if the problem is resolved." 
+ } + + userMessage += " If the issue persists, check the MCP logs for more details or contact your Teleport admin." + return userMessage +} diff --git a/tool/tsh/common/mcp_app_test.go b/tool/tsh/common/mcp_app_test.go index 263d65370c4b5..5eb1e8c9bf99a 100644 --- a/tool/tsh/common/mcp_app_test.go +++ b/tool/tsh/common/mcp_app_test.go @@ -452,8 +452,16 @@ func mustMakeNewAppServer(t *testing.T, app *types.AppV3, host string) types.App return appServer } -func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string) *types.AppV3 { +func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[string]string, opts ...func(*types.MCP)) *types.AppV3 { t.Helper() + mcpSpec := &types.MCP{ + Command: "test", + Args: []string{"arg"}, + RunAsHostUser: "test", + } + for _, opt := range opts { + opt(mcpSpec) + } return mustMakeNewAppV3(t, types.Metadata{ Name: name, @@ -461,11 +469,7 @@ func mustMakeMCPAppWithNameAndLabels(t *testing.T, name string, labels map[strin Labels: labels, }, types.AppSpecV3{ - MCP: &types.MCP{ - Command: "test", - Args: []string{"arg"}, - RunAsHostUser: "test", - }, + MCP: mcpSpec, }, ) }