diff --git a/go.mod b/go.mod index 628b87acee522..255f6c79d9db5 100644 --- a/go.mod +++ b/go.mod @@ -162,6 +162,7 @@ require ( github.com/keys-pub/go-libfido2 v1.5.3-0.20220306005615-8ab03fb1ec27 // replaced github.com/lib/pq v1.10.9 github.com/mailgun/mailgun-go/v4 v4.23.0 + github.com/mark3labs/mcp-go v0.30.1 github.com/mattn/go-shellwords v1.0.12 github.com/mattn/go-sqlite3 v1.14.28 github.com/mdlayher/netlink v1.7.2 @@ -544,6 +545,7 @@ require ( github.com/xhit/go-str2duration/v2 v2.1.0 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect github.com/zeebo/errs v1.4.0 // indirect diff --git a/go.sum b/go.sum index 0754baec7c963..7db6bd831a248 100644 --- a/go.sum +++ b/go.sum @@ -1842,6 +1842,8 @@ github.com/mailgun/mailgun-go/v4 v4.23.0 h1:jPEMJzzin2s7lvehcfv/0UkyBu18GvcURPr2 github.com/mailgun/mailgun-go/v4 v4.23.0/go.mod h1:imTtizoFtpfZqPqGP8vltVBB6q9yWcv6llBhfFeElZU= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.30.1 h1:3R1BPvNT/rC1iPpLx+EMXFy+gvux/Mz/Nio3c6XEU9E= +github.com/mark3labs/mcp-go v0.30.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= @@ -2288,6 +2290,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg= github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= diff --git a/lib/utils/mcputils/errors.go b/lib/utils/mcputils/errors.go new file mode 100644 index 0000000000000..ea819e9c88b30 --- /dev/null +++ b/lib/utils/mcputils/errors.go @@ -0,0 +1,33 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "errors" + "io" + + "github.com/gravitational/teleport/lib/utils" +) + +// IsOKCloseError checks if provided error is a common close error that +// indicates the connection is ended. +func IsOKCloseError(err error) bool { + return errors.Is(err, io.ErrClosedPipe) || + utils.IsOKNetworkError(err) +} diff --git a/lib/utils/mcputils/id_tracker.go b/lib/utils/mcputils/id_tracker.go new file mode 100644 index 0000000000000..1d528bec5acef --- /dev/null +++ b/lib/utils/mcputils/id_tracker.go @@ -0,0 +1,80 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "sync" + + "github.com/gravitational/trace" + "github.com/hashicorp/golang-lru/v2/simplelru" + "github.com/mark3labs/mcp-go/mcp" +) + +// IDTracker tracks message information like method based on ID. IDTracker +// internally uses an LRU cache to keep track the last X messages to avoid +// growing infinitely. IDTracker is safe for concurrent use. +type IDTracker struct { + mu sync.Mutex + lruCache *simplelru.LRU[mcp.RequestId, mcp.MCPMethod] +} + +// NewIDTracker creates a new IDTracker with provided maximum size. +func NewIDTracker(size int) (*IDTracker, error) { + lruCache, err := simplelru.NewLRU[mcp.RequestId, mcp.MCPMethod](size, nil) + if err != nil { + return nil, trace.Wrap(err) + } + return &IDTracker{ + lruCache: lruCache, + }, nil +} + +// PushRequest tracks a request. Returns true if the request has been added to +// cache. +func (t *IDTracker) PushRequest(msg *JSONRPCRequest) bool { + if msg == nil || msg.ID.IsNil() || msg.Method == "" { + return false + } + t.mu.Lock() + defer t.mu.Unlock() + t.lruCache.Add(msg.ID, msg.Method) + return true +} + +// PopByID retrieves the tracked information and remove it from the tracker. +func (t *IDTracker) PopByID(id mcp.RequestId) (mcp.MCPMethod, bool) { + if id.IsNil() { + return "", false + } + + t.mu.Lock() + defer t.mu.Unlock() + + retrieved, ok := t.lruCache.Get(id) + if !ok { + return "", false + } + t.lruCache.Remove(id) + return retrieved, true +} + +// Len returns the size of the tracker cache. +func (t *IDTracker) Len() int { + return t.lruCache.Len() +} diff --git a/lib/utils/mcputils/id_tracker_test.go b/lib/utils/mcputils/id_tracker_test.go new file mode 100644 index 0000000000000..94743e31b0824 --- /dev/null +++ b/lib/utils/mcputils/id_tracker_test.go @@ -0,0 +1,109 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "fmt" + "slices" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +func TestIDTracker(t *testing.T) { + tracker, err := NewIDTracker(5) + require.NoError(t, err) + require.Empty(t, tracker.Len()) + + t.Run("request missing ID not tracked", func(t *testing.T) { + require.False(t, tracker.PushRequest(&JSONRPCRequest{ + Method: "bad", + })) + require.Empty(t, tracker.Len()) + }) + + t.Run("request tracked", func(t *testing.T) { + require.True(t, tracker.PushRequest(&JSONRPCRequest{ + ID: mcp.NewRequestId(0), + Method: mcp.MethodToolsList, + })) + require.Equal(t, 1, tracker.Len()) + }) + + t.Run("pop unknown id", func(t *testing.T) { + unknownIDs := []mcp.RequestId{ + mcp.NewRequestId(5), + mcp.NewRequestId("0"), + mcp.NewRequestId(nil), + } + for id := range slices.Values(unknownIDs) { + t.Run(fmt.Sprintf("%T", id), func(t *testing.T) { + _, ok := tracker.PopByID(id) + require.False(t, ok) + require.Equal(t, 1, tracker.Len()) + }) + } + }) + + t.Run("pop tracked id", func(t *testing.T) { + method, ok := tracker.PopByID(mcp.NewRequestId(0)) + require.True(t, ok) + require.Equal(t, mcp.MethodToolsList, method) + require.Empty(t, tracker.Len()) + }) + + t.Run("track last 5", func(t *testing.T) { + for i := range 20 { + tracker.PushRequest(&JSONRPCRequest{ + ID: mcp.NewRequestId(i + 1), + Method: mcp.MethodToolsCall, + }) + require.LessOrEqual(t, tracker.Len(), 10) + } + for i := range 5 { + method, ok := tracker.PopByID(mcp.NewRequestId(20 - i)) + require.True(t, ok) + require.Equal(t, mcp.MethodToolsCall, method) + } + require.Empty(t, tracker.Len()) + }) +} + +func BenchmarkIDTracker(b *testing.B) { + idTracker, err := NewIDTracker(100) + require.NoError(b, err) + + for i := 0; i < 100; i++ { + idTracker.PushRequest(&JSONRPCRequest{ + ID: mcp.NewRequestId(i), + Method: mcp.MethodToolsList, + }) + } + + // cpu: Apple M3 Pro + // BenchmarkIDTracker-12 12267649 81.85 ns/op + for b.Loop() { + idTracker.PushRequest(&JSONRPCRequest{ + ID: mcp.NewRequestId(2000), + Method: mcp.MethodToolsList, + }) + idTracker.PopByID(mcp.NewRequestId(2000)) + } +} diff --git a/lib/utils/mcputils/protocol.go b/lib/utils/mcputils/protocol.go new file mode 100644 index 0000000000000..fadbcb75c1ce3 --- /dev/null +++ b/lib/utils/mcputils/protocol.go @@ -0,0 +1,155 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "encoding/json" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + + apievents "github.com/gravitational/teleport/api/types/events" +) + +// Type definitions from both mcp-go/client/transport or mcp-go are not suitable +// for our reverse proxy use, thus this file redefines them. +// +// TODO(greedy52) switch to official golang lib or official go SDK if they offer +// the level of handling we need. Same goes for other helpers like StdioXXX. + +// JSONRPCParams defines params for request or notification. +// TODO(greedy52) handle metadata +type JSONRPCParams map[string]any + +// GetEventParams returns the apievents.Struct for auditing. +func (p JSONRPCParams) GetEventParams() *apievents.Struct { + if p == nil { + return nil + } + + eventParams, _ := apievents.EncodeMap(p) + return eventParams +} + +// GetName returns the "name" param. +func (p JSONRPCParams) GetName() (string, bool) { + if p == nil { + return "", false + } + name, ok := p["name"].(string) + return name, ok +} + +// baseJSONRPCMessage is a base message that includes all fields for MCP +// protocol. +// +// Note that json.RawMessage is used to keep the original content when +// marshaling it again. json.RawMessage can also be easily unmarshalled to user +// defined types when needed. Same applies to other types in this file. +type baseJSONRPCMessage struct { + // JSONRPC specifies the version of JSONRPC. + JSONRPC string `json:"jsonrpc"` + // ID is the ID for request and response. ID is nil for notification. + ID mcp.RequestId `json:"id,omitempty"` + // Method is the request or notification method. Method is empty for response. + Method mcp.MCPMethod `json:"method,omitempty"` + // Params is the params for request and notification. + Params JSONRPCParams `json:"params,omitempty"` + // Result is the response result. + Result json.RawMessage `json:"result,omitempty"` + // Error is the response error. + Error json.RawMessage `json:"error,omitempty"` +} + +func (m *baseJSONRPCMessage) isNotification() bool { + return m.ID.IsNil() +} +func (m *baseJSONRPCMessage) isRequest() bool { + return !m.ID.IsNil() && m.Method != "" +} +func (m *baseJSONRPCMessage) isResponse() bool { + return !m.ID.IsNil() && (m.Result != nil || m.Error != nil) +} + +func (m *baseJSONRPCMessage) makeNotification() *JSONRPCNotification { + return &JSONRPCNotification{ + JSONRPC: m.JSONRPC, + Method: m.Method, + Params: m.Params, + } +} +func (m *baseJSONRPCMessage) makeRequest() *JSONRPCRequest { + return &JSONRPCRequest{ + JSONRPC: m.JSONRPC, + ID: m.ID, + Method: m.Method, + Params: m.Params, + } +} +func (m *baseJSONRPCMessage) makeResponse() *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: m.JSONRPC, + ID: m.ID, + Result: m.Result, + Error: m.Error, + } +} + +// JSONRPCNotification defines a MCP notification. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic#notifications +type JSONRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + Params JSONRPCParams `json:"params,omitempty"` +} + +// JSONRPCRequest defines a MCP request. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic#requests +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID mcp.RequestId `json:"id,omitempty"` + Params JSONRPCParams `json:"params,omitempty"` +} + +// JSONRPCResponse defines an MCP response. +// +// By protocol spec, responses are further sub-categorized as either successful +// results or errors. Either a result or an error MUST be set. A response MUST +// NOT set both. +// +// https://modelcontextprotocol.io/specification/2025-03-26/basic#responses +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` +} + +// GetListToolResult assumes the result is for mcp.MethodToolsList and returns +// the corresponding go object. +func (r *JSONRPCResponse) GetListToolResult() (*mcp.ListToolsResult, error) { + var listResult mcp.ListToolsResult + if err := json.Unmarshal([]byte(r.Result), &listResult); err != nil { + return nil, trace.Wrap(err) + } + return &listResult, nil +} diff --git a/lib/utils/mcputils/protocol_test.go b/lib/utils/mcputils/protocol_test.go new file mode 100644 index 0000000000000..cf7a61045ae58 --- /dev/null +++ b/lib/utils/mcputils/protocol_test.go @@ -0,0 +1,153 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "encoding/json" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONRPCNotification(t *testing.T) { + inputJSON := []byte(`{ + "jsonrpc": "2.0", + "method": "notifications/message", + "params": { + "level": "error", + "logger": "database", + "data": { + "error": "Connection failed", + "details": { + "host": "localhost", + "port": 5432 + } + } + } +}`) + + var base baseJSONRPCMessage + require.NoError(t, json.Unmarshal(inputJSON, &base)) + assert.True(t, base.isNotification()) + assert.False(t, base.isRequest()) + assert.False(t, base.isResponse()) + + m := base.makeNotification() + require.NotNil(t, m) + assert.Equal(t, mcp.MCPMethod("notifications/message"), m.Method) + assert.Len(t, base.Params, 3) + + outputJSON, err := json.MarshalIndent(m, "", " ") + require.NoError(t, err) + assert.JSONEq(t, string(inputJSON), string(outputJSON)) +} + +func TestJSONRPCRequest(t *testing.T) { + inputJSON := []byte(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "get_weather", + "arguments": { + "location": "New York" + } + } +}`) + var base baseJSONRPCMessage + require.NoError(t, json.Unmarshal(inputJSON, &base)) + assert.False(t, base.isNotification()) + assert.True(t, base.isRequest()) + assert.False(t, base.isResponse()) + + m := base.makeRequest() + require.NotNil(t, m) + assert.Equal(t, mcp.MethodToolsCall, m.Method) + assert.Equal(t, "int64:2", m.ID.String()) + name, ok := m.Params.GetName() + assert.True(t, ok) + assert.Equal(t, "get_weather", name) + + outputJSON, err := json.MarshalIndent(m, "", " ") + require.NoError(t, err) + assert.JSONEq(t, string(inputJSON), string(outputJSON)) +} + +func TestJSONRPCResponse(t *testing.T) { + inputJSON := []byte(`{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + } + ], + "nextCursor": "next-page-cursor" + } +}`) + var base baseJSONRPCMessage + require.NoError(t, json.Unmarshal(inputJSON, &base)) + assert.False(t, base.isNotification()) + assert.False(t, base.isRequest()) + assert.True(t, base.isResponse()) + + m := base.makeResponse() + require.NotNil(t, m) + assert.Equal(t, "int64:2", m.ID.String()) + + outputJSON, err := json.MarshalIndent(m, "", " ") + require.NoError(t, err) + assert.JSONEq(t, string(inputJSON), string(outputJSON)) + + toolList, err := m.GetListToolResult() + require.NoError(t, err) + require.Equal(t, &mcp.ListToolsResult{ + PaginatedResult: mcp.PaginatedResult{ + NextCursor: "next-page-cursor", + }, + Tools: []mcp.Tool{{ + Name: "get_weather", + Description: "Get current weather information for a location", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "City name or zip code", + }, + }, + Required: []string{"location"}, + }, + }}, + }, toolList) +} diff --git a/lib/utils/mcputils/stdio.go b/lib/utils/mcputils/stdio.go new file mode 100644 index 0000000000000..81972650eb2d8 --- /dev/null +++ b/lib/utils/mcputils/stdio.go @@ -0,0 +1,248 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "bufio" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/gravitational/teleport" + logutils "github.com/gravitational/teleport/lib/utils/log" +) + +// StderrTraceLogWriter implements io.Writer and logs the content at TRACE +// level. Used for tracing stderr. +type StderrTraceLogWriter struct { + ctx context.Context + log *slog.Logger +} + +// NewStderrTraceLogWriter returns a new StderrTraceLogWriter. +func NewStderrTraceLogWriter(ctx context.Context, log *slog.Logger) *StderrTraceLogWriter { + return &StderrTraceLogWriter{ + ctx: ctx, + log: cmp.Or(log, slog.Default()), + } +} + +// Write implements io.Writer and logs the given input p at trace level. +// Note that the input p may contain arbitrary-length data, which can span +// multiple lines or include partial lines. +func (l *StderrTraceLogWriter) Write(p []byte) (int, error) { + l.log.Log(l.ctx, logutils.TraceLevel, "Trace stderr", "data", p) + return len(p), nil +} + +// StdioMessageWriter writes a JSONRPC message in stdio transport. +type StdioMessageWriter struct { + w io.Writer +} + +// NewStdioMessageWriter returns a MessageWriter using stdio transport. +func NewStdioMessageWriter(w io.Writer) *StdioMessageWriter { + return &StdioMessageWriter{ + w: w, + } +} + +// WriteMessage writes a JSONRPC message in stdio transport. +func (w *StdioMessageWriter) WriteMessage(_ context.Context, resp mcp.JSONRPCMessage) error { + bytes, err := json.Marshal(resp) + if err != nil { + return trace.Wrap(err) + } + _, err = fmt.Fprintf(w.w, "%s\n", string(bytes)) + return trace.Wrap(err) +} + +// HandleParseErrorFunc handles parse errors. +type HandleParseErrorFunc func(context.Context, *mcp.JSONRPCError) error + +// ReplyParseError returns a HandleParseErrorFunc that forwards the error to +// provided writer. +func ReplyParseError(w *StdioMessageWriter) HandleParseErrorFunc { + return func(ctx context.Context, parseError *mcp.JSONRPCError) error { + return trace.Wrap(w.WriteMessage(ctx, parseError)) + } +} + +// LogAndIgnoreParseError returns a HandleParseErrorFunc that logs the parse +// error. +func LogAndIgnoreParseError(log *slog.Logger) HandleParseErrorFunc { + return func(ctx context.Context, parseError *mcp.JSONRPCError) error { + log.DebugContext(ctx, "Ignore parse error", "error", parseError) + return nil + } +} + +// StdioMessageReaderConfig is the config for StdioMessageReader. +type StdioMessageReaderConfig struct { + // SourceReadCloser is the input to the read the message from. + // SourceReadCloser will be closed when reader finishes. + SourceReadCloser io.ReadCloser + // Logger is the slog.Logger. + Logger *slog.Logger + // ParentContext is the parent's context. Used for logging during tear down. + ParentContext context.Context + + // OnClose is an optional callback when reader finishes. + OnClose func() + // OnParseError specifies the handler for handling parse error. Any error + // returned by the handler stops this message reader. + OnParseError HandleParseErrorFunc + // OnRequest specifies the handler for handling request. Any error by the + // handler stops this message reader. + OnRequest func(context.Context, *JSONRPCRequest) error + // OnResponse specifies the handler for handling response. Any error by the + // handler stops this message reader. + OnResponse func(context.Context, *JSONRPCResponse) error + // OnNotification specifies the handler for handling notification. Any error + // returned by the handler stops this message reader. + OnNotification func(context.Context, *JSONRPCNotification) error +} + +// CheckAndSetDefaults checks values and sets defaults. +func (c *StdioMessageReaderConfig) CheckAndSetDefaults() error { + if c.SourceReadCloser == nil { + return trace.BadParameter("missing parameter SourceReadCloser") + } + if c.OnParseError == nil { + return trace.BadParameter("missing parameter OnParseError") + } + if c.OnNotification == nil { + return trace.BadParameter("missing parameter OnNotification") + } + if c.OnRequest == nil && c.OnResponse == nil { + return trace.BadParameter("one of OnRequest or OnResponse must be set") + } + if c.ParentContext == nil { + return trace.BadParameter("missing parameter ParentContext") + } + if c.Logger == nil { + c.Logger = slog.With(teleport.ComponentKey, "mcp") + } + return nil +} + +// StdioMessageReader reads requests from provided reader. +type StdioMessageReader struct { + cfg StdioMessageReaderConfig +} + +// NewStdioMessageReader creates a new StdioMessageReader. Must call "Start" to +// start the processing. +func NewStdioMessageReader(cfg StdioMessageReaderConfig) (*StdioMessageReader, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &StdioMessageReader{ + cfg: cfg, + }, nil +} + +// Run starts reading requests from provided reader. Run blocks until an +// error happens from the provided reader or any of the handler. +func (r *StdioMessageReader) Run(ctx context.Context) { + r.cfg.Logger.InfoContext(ctx, "Start processing stdio messages") + + finished := make(chan struct{}) + go func() { + r.startProcess(ctx) + close(finished) + }() + + select { + case <-finished: + case <-ctx.Done(): + } + + r.cfg.Logger.InfoContext(r.cfg.ParentContext, "Finished processing stdio messages") + if err := r.cfg.SourceReadCloser.Close(); err != nil && !IsOKCloseError(err) { + r.cfg.Logger.ErrorContext(r.cfg.ParentContext, "Failed to close reader", "error", err) + } + if r.cfg.OnClose != nil { + r.cfg.OnClose() + } +} + +func (r *StdioMessageReader) startProcess(ctx context.Context) { + lineReader := bufio.NewReader(r.cfg.SourceReadCloser) + for { + if ctx.Err() != nil { + return + } + + if err := r.processNextLine(ctx, lineReader); err != nil { + if !IsOKCloseError(err) { + r.cfg.Logger.ErrorContext(ctx, "Failed to process line", "error", err) + } + return + } + } +} + +func (r *StdioMessageReader) processNextLine(ctx context.Context, lineReader *bufio.Reader) error { + line, err := lineReader.ReadString('\n') + if err != nil { + return trace.Wrap(err, "reading line") + } + + r.cfg.Logger.Log(ctx, logutils.TraceLevel, "Trace stdio", "line", line) + + var base baseJSONRPCMessage + if parseError := json.Unmarshal([]byte(line), &base); parseError != nil { + rpcError := mcp.NewJSONRPCError(mcp.NewRequestId(nil), mcp.PARSE_ERROR, parseError.Error(), nil) + if err := r.cfg.OnParseError(ctx, &rpcError); err != nil { + return trace.Wrap(err, "handling JSON unmarshal error") + } + } + + switch { + case base.isNotification(): + return trace.Wrap(r.cfg.OnNotification(ctx, base.makeNotification()), "handling notification") + case base.isRequest(): + if r.cfg.OnRequest != nil { + return trace.Wrap(r.cfg.OnRequest(ctx, base.makeRequest()), "handling request") + } + // Should not happen. Log something just in case. + r.cfg.Logger.DebugContext(ctx, "Skipping request", "id", base.ID) + return nil + case base.isResponse(): + if r.cfg.OnResponse != nil { + return trace.Wrap(r.cfg.OnResponse(ctx, base.makeResponse()), "handling response") + } + // Should not happen. Log something just in case. + r.cfg.Logger.DebugContext(ctx, "Skipping response", "id", base.ID) + return nil + default: + rpcError := mcp.NewJSONRPCError(base.ID, mcp.PARSE_ERROR, "unknown message type", line) + return trace.Wrap( + r.cfg.OnParseError(ctx, &rpcError), + "handling unknown message type error", + ) + } +} diff --git a/lib/utils/mcputils/stdio_test.go b/lib/utils/mcputils/stdio_test.go new file mode 100644 index 0000000000000..ec63799353ed5 --- /dev/null +++ b/lib/utils/mcputils/stdio_test.go @@ -0,0 +1,190 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcputils + +import ( + "bytes" + "context" + "io" + "log" + "log/slog" + "sync/atomic" + "testing" + "time" + + "github.com/gravitational/trace" + mcpclient "github.com/mark3labs/mcp-go/client" + mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestStdioHelpers tests StdioMessageReader and StdioMessageWriter by +// implementing a passthrough reverse proxy. +// +// The flow looks something like this: +// request: MCP client --> client message reader --> server message writer --> MCP server +// response: MCP client <-- client message writer <-- server message reader <-- MCP server +func TestStdioHelpers(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Set up some counters for verification. + var readClientNotifications int32 + var readClientRequests int32 + var readServerNotifications int32 + var readServerResponses int32 + + // Pipes for hooking things up. + clientStdin, writeToClient := io.Pipe() + readFromClient, clientStdout := io.Pipe() + serverStdio, writeToServer := io.Pipe() + readFromServer, serverStdout := io.Pipe() + t.Cleanup(func() { + assert.NoError(t, trace.NewAggregate( + clientStdin.Close(), writeToClient.Close(), + readFromClient.Close(), clientStdout.Close(), + serverStdio.Close(), writeToServer.Close(), + readFromServer.Close(), serverStdout.Close(), + )) + }) + + // Make "low-level" message readers and writers for MITM proxy. + clientMessageWriter := NewStdioMessageWriter(writeToClient) + serverMessageWriter := NewStdioMessageWriter(writeToServer) + + clientMessageReader, err := NewStdioMessageReader(StdioMessageReaderConfig{ + ParentContext: context.Background(), + SourceReadCloser: readFromClient, + OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error { + atomic.AddInt32(&readClientNotifications, 1) + return trace.Wrap(serverMessageWriter.WriteMessage(ctx, notification)) + }, + OnRequest: func(ctx context.Context, request *JSONRPCRequest) error { + atomic.AddInt32(&readClientRequests, 1) + return trace.Wrap(serverMessageWriter.WriteMessage(ctx, request)) + }, + OnParseError: ReplyParseError(clientMessageWriter), + }) + require.NoError(t, err) + clientMessageReaderClosed := make(chan struct{}) + go func() { + clientMessageReader.Run(ctx) + close(clientMessageReaderClosed) + }() + + serverMessageReader, err := NewStdioMessageReader(StdioMessageReaderConfig{ + ParentContext: context.Background(), + SourceReadCloser: readFromServer, + OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error { + atomic.AddInt32(&readServerNotifications, 1) + return trace.Wrap(clientMessageWriter.WriteMessage(ctx, notification)) + }, + OnResponse: func(ctx context.Context, response *JSONRPCResponse) error { + atomic.AddInt32(&readServerResponses, 1) + return trace.Wrap(clientMessageWriter.WriteMessage(ctx, response)) + }, + OnParseError: LogAndIgnoreParseError(slog.Default()), + }) + require.NoError(t, err) + serverMessageReaderClosed := make(chan struct{}) + serverMessageReaderCtx, serverMessageReaderCtxCancel := context.WithCancel(ctx) + go func() { + serverMessageReader.Run(serverMessageReaderCtx) + close(serverMessageReaderClosed) + }() + + // Make "high-level" MCP client and server with stdio transport as the two + // ends. + stdioClientTransport := mcpclienttransport.NewIO(clientStdin, clientStdout, io.NopCloser(bytes.NewReader(nil))) + stdioClient := mcpclient.NewClient(stdioClientTransport) + defer stdioClient.Close() + require.NoError(t, stdioClient.Start(ctx)) + + stdioServer := mcpserver.NewStdioServer(makeTestMCPServer()) + stdioServer.SetErrorLogger(log.New(io.Discard, "", log.LstdFlags)) + go stdioServer.Listen(ctx, serverStdio, serverStdout) + + // Test things out. + t.Run("client initialize", func(t *testing.T) { + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + _, err = stdioClient.Initialize(ctx, initReq) + require.NoError(t, err) + }) + + t.Run("client call tool", func(t *testing.T) { + callToolRequest := mcp.CallToolRequest{} + callToolRequest.Params.Name = "hello-server" + callToolResult, err := stdioClient.CallTool(ctx, callToolRequest) + require.NoError(t, err) + require.NotNil(t, callToolResult) + require.Equal(t, []mcp.Content{ + mcp.NewTextContent("hello client"), + }, callToolResult.Content) + }) + + t.Run("reader closed by closing stdin", func(t *testing.T) { + readFromClient.Close() + select { + case <-clientMessageReaderClosed: + case <-time.After(time.Second * 2): + require.Fail(t, "timeout waiting for reader closed by closing stdin") + } + }) + + t.Run("reader closed by canceling context", func(t *testing.T) { + serverMessageReaderCtxCancel() + select { + case <-serverMessageReaderClosed: + case <-time.After(time.Second * 2): + require.Fail(t, "timeout waiting for reader closed by canceling context") + } + }) + + t.Run("verify counters", func(t *testing.T) { + // client -> server: initialize request + // server -> client: initialize response + // client -> server: notifications/initialized + // client -> server: tools\call request + // server -> client: tools\call response + assert.Equal(t, int32(1), atomic.LoadInt32(&readClientNotifications)) + assert.Equal(t, int32(2), atomic.LoadInt32(&readClientRequests)) + assert.Equal(t, int32(0), atomic.LoadInt32(&readServerNotifications)) + assert.Equal(t, int32(2), atomic.LoadInt32(&readServerResponses)) + }) +} + +func makeTestMCPServer() *mcpserver.MCPServer { + server := mcpserver.NewMCPServer("test-server", "1.0.0") + server.AddTool(mcp.Tool{ + Name: "hello-server", + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent("hello client")}, + }, nil + }) + return server +}