diff --git a/go.mod b/go.mod
index 628b87acee522..255f6c79d9db5 100644
--- a/go.mod
+++ b/go.mod
@@ -162,6 +162,7 @@ require (
github.com/keys-pub/go-libfido2 v1.5.3-0.20220306005615-8ab03fb1ec27 // replaced
github.com/lib/pq v1.10.9
github.com/mailgun/mailgun-go/v4 v4.23.0
+ github.com/mark3labs/mcp-go v0.30.1
github.com/mattn/go-shellwords v1.0.12
github.com/mattn/go-sqlite3 v1.14.28
github.com/mdlayher/netlink v1.7.2
@@ -544,6 +545,7 @@ require (
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
github.com/xlab/treeprint v1.2.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
+ github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/zeebo/errs v1.4.0 // indirect
diff --git a/go.sum b/go.sum
index 0754baec7c963..7db6bd831a248 100644
--- a/go.sum
+++ b/go.sum
@@ -1842,6 +1842,8 @@ github.com/mailgun/mailgun-go/v4 v4.23.0 h1:jPEMJzzin2s7lvehcfv/0UkyBu18GvcURPr2
github.com/mailgun/mailgun-go/v4 v4.23.0/go.mod h1:imTtizoFtpfZqPqGP8vltVBB6q9yWcv6llBhfFeElZU=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
+github.com/mark3labs/mcp-go v0.30.1 h1:3R1BPvNT/rC1iPpLx+EMXFy+gvux/Mz/Nio3c6XEU9E=
+github.com/mark3labs/mcp-go v0.30.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
@@ -2288,6 +2290,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg=
github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
diff --git a/lib/utils/mcputils/errors.go b/lib/utils/mcputils/errors.go
new file mode 100644
index 0000000000000..ea819e9c88b30
--- /dev/null
+++ b/lib/utils/mcputils/errors.go
@@ -0,0 +1,33 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "errors"
+ "io"
+
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// IsOKCloseError checks if provided error is a common close error that
+// indicates the connection is ended.
+func IsOKCloseError(err error) bool {
+ return errors.Is(err, io.ErrClosedPipe) ||
+ utils.IsOKNetworkError(err)
+}
diff --git a/lib/utils/mcputils/id_tracker.go b/lib/utils/mcputils/id_tracker.go
new file mode 100644
index 0000000000000..1d528bec5acef
--- /dev/null
+++ b/lib/utils/mcputils/id_tracker.go
@@ -0,0 +1,80 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "sync"
+
+ "github.com/gravitational/trace"
+ "github.com/hashicorp/golang-lru/v2/simplelru"
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// IDTracker tracks message information like method based on ID. IDTracker
+// internally uses an LRU cache to keep track the last X messages to avoid
+// growing infinitely. IDTracker is safe for concurrent use.
+type IDTracker struct {
+ mu sync.Mutex
+ lruCache *simplelru.LRU[mcp.RequestId, mcp.MCPMethod]
+}
+
+// NewIDTracker creates a new IDTracker with provided maximum size.
+func NewIDTracker(size int) (*IDTracker, error) {
+ lruCache, err := simplelru.NewLRU[mcp.RequestId, mcp.MCPMethod](size, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &IDTracker{
+ lruCache: lruCache,
+ }, nil
+}
+
+// PushRequest tracks a request. Returns true if the request has been added to
+// cache.
+func (t *IDTracker) PushRequest(msg *JSONRPCRequest) bool {
+ if msg == nil || msg.ID.IsNil() || msg.Method == "" {
+ return false
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.lruCache.Add(msg.ID, msg.Method)
+ return true
+}
+
+// PopByID retrieves the tracked information and remove it from the tracker.
+func (t *IDTracker) PopByID(id mcp.RequestId) (mcp.MCPMethod, bool) {
+ if id.IsNil() {
+ return "", false
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ retrieved, ok := t.lruCache.Get(id)
+ if !ok {
+ return "", false
+ }
+ t.lruCache.Remove(id)
+ return retrieved, true
+}
+
+// Len returns the size of the tracker cache.
+func (t *IDTracker) Len() int {
+ return t.lruCache.Len()
+}
diff --git a/lib/utils/mcputils/id_tracker_test.go b/lib/utils/mcputils/id_tracker_test.go
new file mode 100644
index 0000000000000..94743e31b0824
--- /dev/null
+++ b/lib/utils/mcputils/id_tracker_test.go
@@ -0,0 +1,109 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "fmt"
+ "slices"
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+)
+
+func TestIDTracker(t *testing.T) {
+ tracker, err := NewIDTracker(5)
+ require.NoError(t, err)
+ require.Empty(t, tracker.Len())
+
+ t.Run("request missing ID not tracked", func(t *testing.T) {
+ require.False(t, tracker.PushRequest(&JSONRPCRequest{
+ Method: "bad",
+ }))
+ require.Empty(t, tracker.Len())
+ })
+
+ t.Run("request tracked", func(t *testing.T) {
+ require.True(t, tracker.PushRequest(&JSONRPCRequest{
+ ID: mcp.NewRequestId(0),
+ Method: mcp.MethodToolsList,
+ }))
+ require.Equal(t, 1, tracker.Len())
+ })
+
+ t.Run("pop unknown id", func(t *testing.T) {
+ unknownIDs := []mcp.RequestId{
+ mcp.NewRequestId(5),
+ mcp.NewRequestId("0"),
+ mcp.NewRequestId(nil),
+ }
+ for id := range slices.Values(unknownIDs) {
+ t.Run(fmt.Sprintf("%T", id), func(t *testing.T) {
+ _, ok := tracker.PopByID(id)
+ require.False(t, ok)
+ require.Equal(t, 1, tracker.Len())
+ })
+ }
+ })
+
+ t.Run("pop tracked id", func(t *testing.T) {
+ method, ok := tracker.PopByID(mcp.NewRequestId(0))
+ require.True(t, ok)
+ require.Equal(t, mcp.MethodToolsList, method)
+ require.Empty(t, tracker.Len())
+ })
+
+ t.Run("track last 5", func(t *testing.T) {
+ for i := range 20 {
+ tracker.PushRequest(&JSONRPCRequest{
+ ID: mcp.NewRequestId(i + 1),
+ Method: mcp.MethodToolsCall,
+ })
+ require.LessOrEqual(t, tracker.Len(), 10)
+ }
+ for i := range 5 {
+ method, ok := tracker.PopByID(mcp.NewRequestId(20 - i))
+ require.True(t, ok)
+ require.Equal(t, mcp.MethodToolsCall, method)
+ }
+ require.Empty(t, tracker.Len())
+ })
+}
+
+func BenchmarkIDTracker(b *testing.B) {
+ idTracker, err := NewIDTracker(100)
+ require.NoError(b, err)
+
+ for i := 0; i < 100; i++ {
+ idTracker.PushRequest(&JSONRPCRequest{
+ ID: mcp.NewRequestId(i),
+ Method: mcp.MethodToolsList,
+ })
+ }
+
+ // cpu: Apple M3 Pro
+ // BenchmarkIDTracker-12 12267649 81.85 ns/op
+ for b.Loop() {
+ idTracker.PushRequest(&JSONRPCRequest{
+ ID: mcp.NewRequestId(2000),
+ Method: mcp.MethodToolsList,
+ })
+ idTracker.PopByID(mcp.NewRequestId(2000))
+ }
+}
diff --git a/lib/utils/mcputils/protocol.go b/lib/utils/mcputils/protocol.go
new file mode 100644
index 0000000000000..fadbcb75c1ce3
--- /dev/null
+++ b/lib/utils/mcputils/protocol.go
@@ -0,0 +1,155 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "encoding/json"
+
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+
+ apievents "github.com/gravitational/teleport/api/types/events"
+)
+
+// Type definitions from both mcp-go/client/transport or mcp-go are not suitable
+// for our reverse proxy use, thus this file redefines them.
+//
+// TODO(greedy52) switch to official golang lib or official go SDK if they offer
+// the level of handling we need. Same goes for other helpers like StdioXXX.
+
+// JSONRPCParams defines params for request or notification.
+// TODO(greedy52) handle metadata
+type JSONRPCParams map[string]any
+
+// GetEventParams returns the apievents.Struct for auditing.
+func (p JSONRPCParams) GetEventParams() *apievents.Struct {
+ if p == nil {
+ return nil
+ }
+
+ eventParams, _ := apievents.EncodeMap(p)
+ return eventParams
+}
+
+// GetName returns the "name" param.
+func (p JSONRPCParams) GetName() (string, bool) {
+ if p == nil {
+ return "", false
+ }
+ name, ok := p["name"].(string)
+ return name, ok
+}
+
+// baseJSONRPCMessage is a base message that includes all fields for MCP
+// protocol.
+//
+// Note that json.RawMessage is used to keep the original content when
+// marshaling it again. json.RawMessage can also be easily unmarshalled to user
+// defined types when needed. Same applies to other types in this file.
+type baseJSONRPCMessage struct {
+ // JSONRPC specifies the version of JSONRPC.
+ JSONRPC string `json:"jsonrpc"`
+ // ID is the ID for request and response. ID is nil for notification.
+ ID mcp.RequestId `json:"id,omitempty"`
+ // Method is the request or notification method. Method is empty for response.
+ Method mcp.MCPMethod `json:"method,omitempty"`
+ // Params is the params for request and notification.
+ Params JSONRPCParams `json:"params,omitempty"`
+ // Result is the response result.
+ Result json.RawMessage `json:"result,omitempty"`
+ // Error is the response error.
+ Error json.RawMessage `json:"error,omitempty"`
+}
+
+func (m *baseJSONRPCMessage) isNotification() bool {
+ return m.ID.IsNil()
+}
+func (m *baseJSONRPCMessage) isRequest() bool {
+ return !m.ID.IsNil() && m.Method != ""
+}
+func (m *baseJSONRPCMessage) isResponse() bool {
+ return !m.ID.IsNil() && (m.Result != nil || m.Error != nil)
+}
+
+func (m *baseJSONRPCMessage) makeNotification() *JSONRPCNotification {
+ return &JSONRPCNotification{
+ JSONRPC: m.JSONRPC,
+ Method: m.Method,
+ Params: m.Params,
+ }
+}
+func (m *baseJSONRPCMessage) makeRequest() *JSONRPCRequest {
+ return &JSONRPCRequest{
+ JSONRPC: m.JSONRPC,
+ ID: m.ID,
+ Method: m.Method,
+ Params: m.Params,
+ }
+}
+func (m *baseJSONRPCMessage) makeResponse() *JSONRPCResponse {
+ return &JSONRPCResponse{
+ JSONRPC: m.JSONRPC,
+ ID: m.ID,
+ Result: m.Result,
+ Error: m.Error,
+ }
+}
+
+// JSONRPCNotification defines a MCP notification.
+//
+// https://modelcontextprotocol.io/specification/2025-03-26/basic#notifications
+type JSONRPCNotification struct {
+ JSONRPC string `json:"jsonrpc"`
+ Method mcp.MCPMethod `json:"method"`
+ Params JSONRPCParams `json:"params,omitempty"`
+}
+
+// JSONRPCRequest defines a MCP request.
+//
+// https://modelcontextprotocol.io/specification/2025-03-26/basic#requests
+type JSONRPCRequest struct {
+ JSONRPC string `json:"jsonrpc"`
+ Method mcp.MCPMethod `json:"method"`
+ ID mcp.RequestId `json:"id,omitempty"`
+ Params JSONRPCParams `json:"params,omitempty"`
+}
+
+// JSONRPCResponse defines an MCP response.
+//
+// By protocol spec, responses are further sub-categorized as either successful
+// results or errors. Either a result or an error MUST be set. A response MUST
+// NOT set both.
+//
+// https://modelcontextprotocol.io/specification/2025-03-26/basic#responses
+type JSONRPCResponse struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID mcp.RequestId `json:"id"`
+ Result json.RawMessage `json:"result,omitempty"`
+ Error json.RawMessage `json:"error,omitempty"`
+}
+
+// GetListToolResult assumes the result is for mcp.MethodToolsList and returns
+// the corresponding go object.
+func (r *JSONRPCResponse) GetListToolResult() (*mcp.ListToolsResult, error) {
+ var listResult mcp.ListToolsResult
+ if err := json.Unmarshal([]byte(r.Result), &listResult); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &listResult, nil
+}
diff --git a/lib/utils/mcputils/protocol_test.go b/lib/utils/mcputils/protocol_test.go
new file mode 100644
index 0000000000000..cf7a61045ae58
--- /dev/null
+++ b/lib/utils/mcputils/protocol_test.go
@@ -0,0 +1,153 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestJSONRPCNotification(t *testing.T) {
+ inputJSON := []byte(`{
+ "jsonrpc": "2.0",
+ "method": "notifications/message",
+ "params": {
+ "level": "error",
+ "logger": "database",
+ "data": {
+ "error": "Connection failed",
+ "details": {
+ "host": "localhost",
+ "port": 5432
+ }
+ }
+ }
+}`)
+
+ var base baseJSONRPCMessage
+ require.NoError(t, json.Unmarshal(inputJSON, &base))
+ assert.True(t, base.isNotification())
+ assert.False(t, base.isRequest())
+ assert.False(t, base.isResponse())
+
+ m := base.makeNotification()
+ require.NotNil(t, m)
+ assert.Equal(t, mcp.MCPMethod("notifications/message"), m.Method)
+ assert.Len(t, base.Params, 3)
+
+ outputJSON, err := json.MarshalIndent(m, "", " ")
+ require.NoError(t, err)
+ assert.JSONEq(t, string(inputJSON), string(outputJSON))
+}
+
+func TestJSONRPCRequest(t *testing.T) {
+ inputJSON := []byte(`{
+ "jsonrpc": "2.0",
+ "id": 2,
+ "method": "tools/call",
+ "params": {
+ "name": "get_weather",
+ "arguments": {
+ "location": "New York"
+ }
+ }
+}`)
+ var base baseJSONRPCMessage
+ require.NoError(t, json.Unmarshal(inputJSON, &base))
+ assert.False(t, base.isNotification())
+ assert.True(t, base.isRequest())
+ assert.False(t, base.isResponse())
+
+ m := base.makeRequest()
+ require.NotNil(t, m)
+ assert.Equal(t, mcp.MethodToolsCall, m.Method)
+ assert.Equal(t, "int64:2", m.ID.String())
+ name, ok := m.Params.GetName()
+ assert.True(t, ok)
+ assert.Equal(t, "get_weather", name)
+
+ outputJSON, err := json.MarshalIndent(m, "", " ")
+ require.NoError(t, err)
+ assert.JSONEq(t, string(inputJSON), string(outputJSON))
+}
+
+func TestJSONRPCResponse(t *testing.T) {
+ inputJSON := []byte(`{
+ "jsonrpc": "2.0",
+ "id": 2,
+ "result": {
+ "tools": [
+ {
+ "name": "get_weather",
+ "description": "Get current weather information for a location",
+ "inputSchema": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "City name or zip code"
+ }
+ },
+ "required": ["location"]
+ }
+ }
+ ],
+ "nextCursor": "next-page-cursor"
+ }
+}`)
+ var base baseJSONRPCMessage
+ require.NoError(t, json.Unmarshal(inputJSON, &base))
+ assert.False(t, base.isNotification())
+ assert.False(t, base.isRequest())
+ assert.True(t, base.isResponse())
+
+ m := base.makeResponse()
+ require.NotNil(t, m)
+ assert.Equal(t, "int64:2", m.ID.String())
+
+ outputJSON, err := json.MarshalIndent(m, "", " ")
+ require.NoError(t, err)
+ assert.JSONEq(t, string(inputJSON), string(outputJSON))
+
+ toolList, err := m.GetListToolResult()
+ require.NoError(t, err)
+ require.Equal(t, &mcp.ListToolsResult{
+ PaginatedResult: mcp.PaginatedResult{
+ NextCursor: "next-page-cursor",
+ },
+ Tools: []mcp.Tool{{
+ Name: "get_weather",
+ Description: "Get current weather information for a location",
+ InputSchema: mcp.ToolInputSchema{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "City name or zip code",
+ },
+ },
+ Required: []string{"location"},
+ },
+ }},
+ }, toolList)
+}
diff --git a/lib/utils/mcputils/stdio.go b/lib/utils/mcputils/stdio.go
new file mode 100644
index 0000000000000..81972650eb2d8
--- /dev/null
+++ b/lib/utils/mcputils/stdio.go
@@ -0,0 +1,248 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "bufio"
+ "cmp"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+
+ "github.com/gravitational/teleport"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
+)
+
+// StderrTraceLogWriter implements io.Writer and logs the content at TRACE
+// level. Used for tracing stderr.
+type StderrTraceLogWriter struct {
+ ctx context.Context
+ log *slog.Logger
+}
+
+// NewStderrTraceLogWriter returns a new StderrTraceLogWriter.
+func NewStderrTraceLogWriter(ctx context.Context, log *slog.Logger) *StderrTraceLogWriter {
+ return &StderrTraceLogWriter{
+ ctx: ctx,
+ log: cmp.Or(log, slog.Default()),
+ }
+}
+
+// Write implements io.Writer and logs the given input p at trace level.
+// Note that the input p may contain arbitrary-length data, which can span
+// multiple lines or include partial lines.
+func (l *StderrTraceLogWriter) Write(p []byte) (int, error) {
+ l.log.Log(l.ctx, logutils.TraceLevel, "Trace stderr", "data", p)
+ return len(p), nil
+}
+
+// StdioMessageWriter writes a JSONRPC message in stdio transport.
+type StdioMessageWriter struct {
+ w io.Writer
+}
+
+// NewStdioMessageWriter returns a MessageWriter using stdio transport.
+func NewStdioMessageWriter(w io.Writer) *StdioMessageWriter {
+ return &StdioMessageWriter{
+ w: w,
+ }
+}
+
+// WriteMessage writes a JSONRPC message in stdio transport.
+func (w *StdioMessageWriter) WriteMessage(_ context.Context, resp mcp.JSONRPCMessage) error {
+ bytes, err := json.Marshal(resp)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ _, err = fmt.Fprintf(w.w, "%s\n", string(bytes))
+ return trace.Wrap(err)
+}
+
+// HandleParseErrorFunc handles parse errors.
+type HandleParseErrorFunc func(context.Context, *mcp.JSONRPCError) error
+
+// ReplyParseError returns a HandleParseErrorFunc that forwards the error to
+// provided writer.
+func ReplyParseError(w *StdioMessageWriter) HandleParseErrorFunc {
+ return func(ctx context.Context, parseError *mcp.JSONRPCError) error {
+ return trace.Wrap(w.WriteMessage(ctx, parseError))
+ }
+}
+
+// LogAndIgnoreParseError returns a HandleParseErrorFunc that logs the parse
+// error.
+func LogAndIgnoreParseError(log *slog.Logger) HandleParseErrorFunc {
+ return func(ctx context.Context, parseError *mcp.JSONRPCError) error {
+ log.DebugContext(ctx, "Ignore parse error", "error", parseError)
+ return nil
+ }
+}
+
+// StdioMessageReaderConfig is the config for StdioMessageReader.
+type StdioMessageReaderConfig struct {
+ // SourceReadCloser is the input to the read the message from.
+ // SourceReadCloser will be closed when reader finishes.
+ SourceReadCloser io.ReadCloser
+ // Logger is the slog.Logger.
+ Logger *slog.Logger
+ // ParentContext is the parent's context. Used for logging during tear down.
+ ParentContext context.Context
+
+ // OnClose is an optional callback when reader finishes.
+ OnClose func()
+ // OnParseError specifies the handler for handling parse error. Any error
+ // returned by the handler stops this message reader.
+ OnParseError HandleParseErrorFunc
+ // OnRequest specifies the handler for handling request. Any error by the
+ // handler stops this message reader.
+ OnRequest func(context.Context, *JSONRPCRequest) error
+ // OnResponse specifies the handler for handling response. Any error by the
+ // handler stops this message reader.
+ OnResponse func(context.Context, *JSONRPCResponse) error
+ // OnNotification specifies the handler for handling notification. Any error
+ // returned by the handler stops this message reader.
+ OnNotification func(context.Context, *JSONRPCNotification) error
+}
+
+// CheckAndSetDefaults checks values and sets defaults.
+func (c *StdioMessageReaderConfig) CheckAndSetDefaults() error {
+ if c.SourceReadCloser == nil {
+ return trace.BadParameter("missing parameter SourceReadCloser")
+ }
+ if c.OnParseError == nil {
+ return trace.BadParameter("missing parameter OnParseError")
+ }
+ if c.OnNotification == nil {
+ return trace.BadParameter("missing parameter OnNotification")
+ }
+ if c.OnRequest == nil && c.OnResponse == nil {
+ return trace.BadParameter("one of OnRequest or OnResponse must be set")
+ }
+ if c.ParentContext == nil {
+ return trace.BadParameter("missing parameter ParentContext")
+ }
+ if c.Logger == nil {
+ c.Logger = slog.With(teleport.ComponentKey, "mcp")
+ }
+ return nil
+}
+
+// StdioMessageReader reads requests from provided reader.
+type StdioMessageReader struct {
+ cfg StdioMessageReaderConfig
+}
+
+// NewStdioMessageReader creates a new StdioMessageReader. Must call "Start" to
+// start the processing.
+func NewStdioMessageReader(cfg StdioMessageReaderConfig) (*StdioMessageReader, error) {
+ if err := cfg.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &StdioMessageReader{
+ cfg: cfg,
+ }, nil
+}
+
+// Run starts reading requests from provided reader. Run blocks until an
+// error happens from the provided reader or any of the handler.
+func (r *StdioMessageReader) Run(ctx context.Context) {
+ r.cfg.Logger.InfoContext(ctx, "Start processing stdio messages")
+
+ finished := make(chan struct{})
+ go func() {
+ r.startProcess(ctx)
+ close(finished)
+ }()
+
+ select {
+ case <-finished:
+ case <-ctx.Done():
+ }
+
+ r.cfg.Logger.InfoContext(r.cfg.ParentContext, "Finished processing stdio messages")
+ if err := r.cfg.SourceReadCloser.Close(); err != nil && !IsOKCloseError(err) {
+ r.cfg.Logger.ErrorContext(r.cfg.ParentContext, "Failed to close reader", "error", err)
+ }
+ if r.cfg.OnClose != nil {
+ r.cfg.OnClose()
+ }
+}
+
+func (r *StdioMessageReader) startProcess(ctx context.Context) {
+ lineReader := bufio.NewReader(r.cfg.SourceReadCloser)
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ if err := r.processNextLine(ctx, lineReader); err != nil {
+ if !IsOKCloseError(err) {
+ r.cfg.Logger.ErrorContext(ctx, "Failed to process line", "error", err)
+ }
+ return
+ }
+ }
+}
+
+func (r *StdioMessageReader) processNextLine(ctx context.Context, lineReader *bufio.Reader) error {
+ line, err := lineReader.ReadString('\n')
+ if err != nil {
+ return trace.Wrap(err, "reading line")
+ }
+
+ r.cfg.Logger.Log(ctx, logutils.TraceLevel, "Trace stdio", "line", line)
+
+ var base baseJSONRPCMessage
+ if parseError := json.Unmarshal([]byte(line), &base); parseError != nil {
+ rpcError := mcp.NewJSONRPCError(mcp.NewRequestId(nil), mcp.PARSE_ERROR, parseError.Error(), nil)
+ if err := r.cfg.OnParseError(ctx, &rpcError); err != nil {
+ return trace.Wrap(err, "handling JSON unmarshal error")
+ }
+ }
+
+ switch {
+ case base.isNotification():
+ return trace.Wrap(r.cfg.OnNotification(ctx, base.makeNotification()), "handling notification")
+ case base.isRequest():
+ if r.cfg.OnRequest != nil {
+ return trace.Wrap(r.cfg.OnRequest(ctx, base.makeRequest()), "handling request")
+ }
+ // Should not happen. Log something just in case.
+ r.cfg.Logger.DebugContext(ctx, "Skipping request", "id", base.ID)
+ return nil
+ case base.isResponse():
+ if r.cfg.OnResponse != nil {
+ return trace.Wrap(r.cfg.OnResponse(ctx, base.makeResponse()), "handling response")
+ }
+ // Should not happen. Log something just in case.
+ r.cfg.Logger.DebugContext(ctx, "Skipping response", "id", base.ID)
+ return nil
+ default:
+ rpcError := mcp.NewJSONRPCError(base.ID, mcp.PARSE_ERROR, "unknown message type", line)
+ return trace.Wrap(
+ r.cfg.OnParseError(ctx, &rpcError),
+ "handling unknown message type error",
+ )
+ }
+}
diff --git a/lib/utils/mcputils/stdio_test.go b/lib/utils/mcputils/stdio_test.go
new file mode 100644
index 0000000000000..ec63799353ed5
--- /dev/null
+++ b/lib/utils/mcputils/stdio_test.go
@@ -0,0 +1,190 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package mcputils
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "log"
+ "log/slog"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ mcpclient "github.com/mark3labs/mcp-go/client"
+ mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
+ "github.com/mark3labs/mcp-go/mcp"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestStdioHelpers tests StdioMessageReader and StdioMessageWriter by
+// implementing a passthrough reverse proxy.
+//
+// The flow looks something like this:
+// request: MCP client --> client message reader --> server message writer --> MCP server
+// response: MCP client <-- client message writer <-- server message reader <-- MCP server
+func TestStdioHelpers(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ t.Cleanup(cancel)
+
+ // Set up some counters for verification.
+ var readClientNotifications int32
+ var readClientRequests int32
+ var readServerNotifications int32
+ var readServerResponses int32
+
+ // Pipes for hooking things up.
+ clientStdin, writeToClient := io.Pipe()
+ readFromClient, clientStdout := io.Pipe()
+ serverStdio, writeToServer := io.Pipe()
+ readFromServer, serverStdout := io.Pipe()
+ t.Cleanup(func() {
+ assert.NoError(t, trace.NewAggregate(
+ clientStdin.Close(), writeToClient.Close(),
+ readFromClient.Close(), clientStdout.Close(),
+ serverStdio.Close(), writeToServer.Close(),
+ readFromServer.Close(), serverStdout.Close(),
+ ))
+ })
+
+ // Make "low-level" message readers and writers for MITM proxy.
+ clientMessageWriter := NewStdioMessageWriter(writeToClient)
+ serverMessageWriter := NewStdioMessageWriter(writeToServer)
+
+ clientMessageReader, err := NewStdioMessageReader(StdioMessageReaderConfig{
+ ParentContext: context.Background(),
+ SourceReadCloser: readFromClient,
+ OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error {
+ atomic.AddInt32(&readClientNotifications, 1)
+ return trace.Wrap(serverMessageWriter.WriteMessage(ctx, notification))
+ },
+ OnRequest: func(ctx context.Context, request *JSONRPCRequest) error {
+ atomic.AddInt32(&readClientRequests, 1)
+ return trace.Wrap(serverMessageWriter.WriteMessage(ctx, request))
+ },
+ OnParseError: ReplyParseError(clientMessageWriter),
+ })
+ require.NoError(t, err)
+ clientMessageReaderClosed := make(chan struct{})
+ go func() {
+ clientMessageReader.Run(ctx)
+ close(clientMessageReaderClosed)
+ }()
+
+ serverMessageReader, err := NewStdioMessageReader(StdioMessageReaderConfig{
+ ParentContext: context.Background(),
+ SourceReadCloser: readFromServer,
+ OnNotification: func(ctx context.Context, notification *JSONRPCNotification) error {
+ atomic.AddInt32(&readServerNotifications, 1)
+ return trace.Wrap(clientMessageWriter.WriteMessage(ctx, notification))
+ },
+ OnResponse: func(ctx context.Context, response *JSONRPCResponse) error {
+ atomic.AddInt32(&readServerResponses, 1)
+ return trace.Wrap(clientMessageWriter.WriteMessage(ctx, response))
+ },
+ OnParseError: LogAndIgnoreParseError(slog.Default()),
+ })
+ require.NoError(t, err)
+ serverMessageReaderClosed := make(chan struct{})
+ serverMessageReaderCtx, serverMessageReaderCtxCancel := context.WithCancel(ctx)
+ go func() {
+ serverMessageReader.Run(serverMessageReaderCtx)
+ close(serverMessageReaderClosed)
+ }()
+
+ // Make "high-level" MCP client and server with stdio transport as the two
+ // ends.
+ stdioClientTransport := mcpclienttransport.NewIO(clientStdin, clientStdout, io.NopCloser(bytes.NewReader(nil)))
+ stdioClient := mcpclient.NewClient(stdioClientTransport)
+ defer stdioClient.Close()
+ require.NoError(t, stdioClient.Start(ctx))
+
+ stdioServer := mcpserver.NewStdioServer(makeTestMCPServer())
+ stdioServer.SetErrorLogger(log.New(io.Discard, "", log.LstdFlags))
+ go stdioServer.Listen(ctx, serverStdio, serverStdout)
+
+ // Test things out.
+ t.Run("client initialize", func(t *testing.T) {
+ initReq := mcp.InitializeRequest{}
+ initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
+ initReq.Params.ClientInfo = mcp.Implementation{
+ Name: "test-client",
+ Version: "1.0.0",
+ }
+ _, err = stdioClient.Initialize(ctx, initReq)
+ require.NoError(t, err)
+ })
+
+ t.Run("client call tool", func(t *testing.T) {
+ callToolRequest := mcp.CallToolRequest{}
+ callToolRequest.Params.Name = "hello-server"
+ callToolResult, err := stdioClient.CallTool(ctx, callToolRequest)
+ require.NoError(t, err)
+ require.NotNil(t, callToolResult)
+ require.Equal(t, []mcp.Content{
+ mcp.NewTextContent("hello client"),
+ }, callToolResult.Content)
+ })
+
+ t.Run("reader closed by closing stdin", func(t *testing.T) {
+ readFromClient.Close()
+ select {
+ case <-clientMessageReaderClosed:
+ case <-time.After(time.Second * 2):
+ require.Fail(t, "timeout waiting for reader closed by closing stdin")
+ }
+ })
+
+ t.Run("reader closed by canceling context", func(t *testing.T) {
+ serverMessageReaderCtxCancel()
+ select {
+ case <-serverMessageReaderClosed:
+ case <-time.After(time.Second * 2):
+ require.Fail(t, "timeout waiting for reader closed by canceling context")
+ }
+ })
+
+ t.Run("verify counters", func(t *testing.T) {
+ // client -> server: initialize request
+ // server -> client: initialize response
+ // client -> server: notifications/initialized
+ // client -> server: tools\call request
+ // server -> client: tools\call response
+ assert.Equal(t, int32(1), atomic.LoadInt32(&readClientNotifications))
+ assert.Equal(t, int32(2), atomic.LoadInt32(&readClientRequests))
+ assert.Equal(t, int32(0), atomic.LoadInt32(&readServerNotifications))
+ assert.Equal(t, int32(2), atomic.LoadInt32(&readServerResponses))
+ })
+}
+
+func makeTestMCPServer() *mcpserver.MCPServer {
+ server := mcpserver.NewMCPServer("test-server", "1.0.0")
+ server.AddTool(mcp.Tool{
+ Name: "hello-server",
+ }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{mcp.NewTextContent("hello client")},
+ }, nil
+ })
+ return server
+}