Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api/proto/teleport/legacy/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8605,6 +8605,11 @@ message MCPSessionStart {
(gogoproto.embed) = true,
(gogoproto.jsontag) = ""
];

// McpSessionId is the session ID tracked by remote MCP servers.
string mcp_session_id = 7;
// ClientInfo stores reported client agent information, e.g. "claude-ai/0.1.0".
string client_info = 8;
}

// MCPSessionEnd is emitted when an MCP session ends.
Expand Down
16 changes: 13 additions & 3 deletions api/types/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ func (a *AppV3) CheckAndSetDefaults() error {
case a.Spec.Cloud != "":
a.Spec.URI = fmt.Sprintf("cloud://%v", a.Spec.Cloud)
case a.Spec.MCP != nil && a.Spec.MCP.Command != "":
a.Spec.URI = SchemaMCPStdio
a.Spec.URI = SchemeMCPStdio + "://"
default:
return trace.BadParameter("app %q URI is empty", a.GetName())
}
Expand Down Expand Up @@ -519,6 +519,9 @@ func (a *AppV3) checkMCP() error {
switch GetMCPServerTransportType(a.Spec.URI) {
case MCPTransportStdio:
return trace.Wrap(a.checkMCPStdio())
case MCPTransportSSE:
_, err := url.Parse(a.Spec.URI)
return trace.Wrap(err)
default:
return trace.BadParameter("unsupported MCP server %q with URI %q", a.GetName(), a.Spec.URI)
}
Expand Down Expand Up @@ -670,9 +673,16 @@ func (p *PortRange) String() string {
// the URI. If no MCP transport type can be determined from the URI, an empty
// string is returned.
func GetMCPServerTransportType(uri string) string {
switch {
case strings.HasPrefix(uri, SchemaMCPStdio):
parsed, err := url.Parse(uri)
if err != nil {
return ""
}

switch parsed.Scheme {
case SchemeMCPStdio:
return MCPTransportStdio
case SchemeMCPSSEHTTP, SchemeMCPSSEHTTPS:
return MCPTransportSSE
default:
return ""
}
Expand Down
32 changes: 32 additions & 0 deletions api/types/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,28 @@ func TestNewAppV3(t *testing.T) {
},
wantErr: require.NoError,
},
{
name: "mcp with uri",
meta: Metadata{
Name: "mcp-everything",
},
spec: AppSpecV3{
URI: "mcp+sse+http://localhost:12345/sse",
},
want: &AppV3{
Kind: "app",
SubKind: "mcp",
Version: "v3",
Metadata: Metadata{
Name: "mcp-everything",
Namespace: "default",
},
Spec: AppSpecV3{
URI: "mcp+sse+http://localhost:12345/sse",
},
},
wantErr: require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -758,6 +780,16 @@ func TestGetMCPServerTransportType(t *testing.T) {
uri: "http://localhost",
want: "",
},
{
name: "SSE HTTP",
uri: "mcp+sse+http://127.0.0.1:12345",
want: MCPTransportSSE,
},
{
name: "SSE HTTPS",
uri: "mcp+sse+httpS://some-domain:443",
want: MCPTransportSSE,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
12 changes: 10 additions & 2 deletions api/types/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,10 +937,18 @@ const (
// CloudGCP identifies that a resource was discovered in GCP.
CloudGCP = "GCP"

// SchemaMCPStdio is a URI schema for MCP servers using stdio transport.
SchemaMCPStdio = "mcp+stdio://"
// SchemeMCPStdio is a URI scheme for MCP servers using stdio transport.
SchemeMCPStdio = "mcp+stdio"
// MCPTransportStdio indicates the MCP server uses stdio transport.
MCPTransportStdio = "stdio"
// SchemeMCPSSEHTTP is a URI scheme for MCP servers using HTTP with SSE
// transport.
SchemeMCPSSEHTTP = "mcp+sse+http"
// SchemeMCPSSEHTTPS is a URI scheme for MCP servers using HTTPS with SSE
// transport.
SchemeMCPSSEHTTPS = "mcp+sse+https"
// MCPTransportSSE indicates the MCP server uses SSE transport.
MCPTransportSSE = "SSE"

// DiscoveredResourceNode identifies a discovered SSH node.
DiscoveredResourceNode = "node"
Expand Down
2,578 changes: 1,335 additions & 1,243 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion integration/appaccess/appaccess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/srv/app/common"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/mcptest"
"github.com/gravitational/teleport/lib/web/app"
)

Expand All @@ -58,7 +59,17 @@ import (
// It allows to make the entire cluster set up once, instead of per test,
// which speeds things up significantly.
func TestAppAccess(t *testing.T) {
pack := Setup(t)
// Enable MCP test servers.
sseServerURL := mcptest.MustStartSSETestServer(t)
extraApps := []servicecfg.App{{
Name: "test-sse",
URI: "mcp+sse+" + sseServerURL,
}}

// Reusing the pack as much as we can.
pack := SetupWithOptions(t, AppTestOptions{
ExtraRootApps: extraApps,
})

t.Run("Forward", bind(pack, testForward))
t.Run("Websockets", bind(pack, testWebsockets))
Expand All @@ -72,6 +83,7 @@ func TestAppAccess(t *testing.T) {
t.Run("NoHeaderOverrides", bind(pack, testNoHeaderOverrides))
t.Run("AuditEvents", bind(pack, testAuditEvents))

// MCP access tests.
t.Run("MCP", bind(pack, testMCP))

// This test should go last because it stops/starts app servers.
Expand Down
21 changes: 20 additions & 1 deletion integration/appaccess/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func testMCP(pack *Pack, t *testing.T) {
t.Run("DialMCPServer stdio success", func(t *testing.T) {
testMCPDialStdio(t, pack)
})

t.Run("DialMCPServer stdio to sse success", func(t *testing.T) {
testMCPDialStdioToSSE(t, pack, "test-sse")
})
}

func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
Expand All @@ -52,7 +56,7 @@ func testMCPDialStdio(t *testing.T, pack *Pack) {
serverConn, err := pack.tc.DialMCPServer(context.Background(), libmcp.DemoServerName)
require.NoError(t, err)

ctx := context.Background()
ctx := t.Context()
stdioClient := mcptest.NewStdioClientFromConn(t, serverConn)

_, err = mcptest.InitializeClient(ctx, stdioClient)
Expand All @@ -62,3 +66,18 @@ func testMCPDialStdio(t *testing.T, pack *Pack) {
require.NoError(t, err)
require.Len(t, listTools.Tools, 3)
}

func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
require.NoError(t, pack.tc.SaveProfile(false))

serverConn, err := pack.tc.DialMCPServer(context.Background(), appName)
require.NoError(t, err)

ctx := t.Context()
stdioClient := mcptest.NewStdioClientFromConn(t, serverConn)

_, err = mcptest.InitializeClient(ctx, stdioClient)
require.NoError(t, err)

mcptest.MustCallServerTool(t, ctx, stdioClient)
}
10 changes: 4 additions & 6 deletions lib/client/mcp/reconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ func ProxyStdioConnWithAutoReconnect(ctx context.Context, cfg ProxyStdioConnWith
}

clientRequestReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
Transport: mcputils.NewStdioReader(cfg.ClientStdio),
ParentContext: ctx,
Logger: cfg.Logger.With("client", "stdin"),
OnParseError: mcputils.ReplyParseError(cfg.clientResponseWriter),
Transport: mcputils.NewStdioReader(cfg.ClientStdio),
Logger: cfg.Logger.With("client", "stdin"),
OnParseError: mcputils.ReplyParseError(cfg.clientResponseWriter),
OnNotification: func(ctx context.Context, notification *mcputils.JSONRPCNotification) error {
// By spec, we should not reply to notifications. Try our best to
// send a notification with the error message. In practice, only the
Expand Down Expand Up @@ -203,8 +202,7 @@ func (r *serverConnWithAutoReconnect) getServerRequestWriterLocked(ctx context.C

// This should never fail as long the correct config is passed in.
serverResponseReader, err := mcputils.NewMessageReader(mcputils.MessageReaderConfig{
Transport: serverStdioReader,
ParentContext: r.parentCtx,
Transport: serverStdioReader,
// OnClose is called when server connection is dead.
// Teleport Proxy automatically closes the connection when tsh session
// is expired.
Expand Down
2 changes: 1 addition & 1 deletion lib/service/servicecfg/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (a *App) CheckAndSetDefaults() error {
case a.Cloud != "":
a.URI = fmt.Sprintf("cloud://%v", a.Cloud)
case a.MCP != nil && a.MCP.Command != "":
a.URI = types.SchemaMCPStdio
a.URI = types.SchemeMCPStdio + "://"
default:
return trace.BadParameter("missing application %q URI", a.Name)
}
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/app/connections_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler
HostID: c.cfg.HostID,
AccessPoint: c.cfg.AccessPoint,
EnableDemoServer: c.cfg.MCPDemoServer,
CipherSuites: c.cfg.CipherSuites,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -644,7 +645,7 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) {
AuthCtx: authCtx,
App: app,
}
return nil, trace.Wrap(c.mcpServer.HandleSession(ctx, sessionCtx))
return nil, trace.Wrap(c.mcpServer.HandleSession(ctx, &sessionCtx))

default:
cleanup := func() {
Expand Down
1 change: 1 addition & 0 deletions lib/srv/mcp/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func (a *sessionAuditor) emitStartEvent(ctx context.Context) {
UserMetadata: a.makeUserMetadata(),
ConnectionMetadata: a.makeConnectionMetadata(),
AppMetadata: a.makeAppMetadata(),
McpSessionId: a.sessionCtx.mcpSessionID.String(),
})
}

Expand Down
3 changes: 2 additions & 1 deletion lib/srv/mcp/demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"

Expand Down Expand Up @@ -52,7 +53,7 @@ func NewDemoServerApp() (types.Application, error) {
Labels: map[string]string{types.TeleportInternalResourceType: types.DemoResource},
Description: "A demo MCP server that shows current user and session information",
}, types.AppSpecV3{
URI: types.SchemaMCPStdio + DemoServerName,
URI: fmt.Sprintf("%s://%s", types.SchemeMCPStdio, DemoServerName),
})
return app, trace.Wrap(err)
}
Expand Down
27 changes: 21 additions & 6 deletions lib/srv/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
Expand All @@ -55,6 +56,9 @@ type ServerConfig struct {
AccessPoint AccessPoint
// EnableDemoServer enables the "Teleport Demo" MCP server.
EnableDemoServer bool
// CipherSuites is the list of TLS cipher suites that have been configured
// for this process.
CipherSuites []uint16

clock clockwork.Clock
}
Expand All @@ -73,6 +77,9 @@ func (c *ServerConfig) CheckAndSetDefaults() error {
if c.AccessPoint == nil {
return trace.BadParameter("missing AccessPoint")
}
if len(c.CipherSuites) == 0 {
return trace.BadParameter("missing CipherSuites")
}
if c.Log == nil {
c.Log = slog.With(teleport.ComponentKey, teleport.ComponentMCP)
}
Expand All @@ -99,14 +106,22 @@ func NewServer(cfg ServerConfig) (*Server, error) {
}

// HandleSession handles an authorized client connection.
func (s *Server) HandleSession(ctx context.Context, sessionCtx SessionCtx) error {
func (s *Server) HandleSession(ctx context.Context, sessionCtx *SessionCtx) error {
if err := sessionCtx.checkAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if s.cfg.EnableDemoServer && isDemoServerApp(sessionCtx.App) {
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeDemoServerRunner))
}
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner))
transportType := types.GetMCPServerTransportType(sessionCtx.App.GetURI())
switch transportType {
case types.MCPTransportStdio:
return trace.Wrap(s.handleStdio(ctx, sessionCtx, makeExecServerRunner))
case types.MCPTransportSSE:
return trace.Wrap(s.handleStdioToSSE(ctx, sessionCtx))
default:
return trace.BadParameter("unknown transport type: %v", transportType)
}
}

// HandleUnauthorizedConnection handles an unauthorized client connection.
Expand All @@ -118,7 +133,7 @@ func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn ne
return trace.Wrap(s.handleAuthErrStdio(ctx, clientConn, authErr))
}

func (s *Server) makeSessionAuditor(ctx context.Context, sessionCtx SessionCtx, logger *slog.Logger) (*sessionAuditor, error) {
func (s *Server) makeSessionAuditor(ctx context.Context, sessionCtx *SessionCtx, logger *slog.Logger) (*sessionAuditor, error) {
clusterName, err := s.cfg.AccessPoint.GetClusterName(ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -143,11 +158,11 @@ func (s *Server) makeSessionAuditor(ctx context.Context, sessionCtx SessionCtx,
logger: logger,
hostID: s.cfg.HostID,
preparer: preparer,
sessionCtx: &sessionCtx,
sessionCtx: sessionCtx,
})
}

func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx SessionCtx) (*sessionHandler, error) {
func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx *SessionCtx) (*sessionHandler, error) {
// Some extra info for debugging purpose.
logger := s.cfg.Log.With(
"client_ip", sessionCtx.ClientConn.RemoteAddr(),
Expand All @@ -162,7 +177,7 @@ func (s *Server) makeSessionHandler(ctx context.Context, sessionCtx SessionCtx)
}

return newSessionHandler(sessionHandlerConfig{
SessionCtx: &sessionCtx,
SessionCtx: sessionCtx,
sessionAuditor: sessionAuditor,
accessPoint: s.cfg.AccessPoint,
logger: logger,
Expand Down
16 changes: 16 additions & 0 deletions lib/srv/mcp/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"
"log/slog"
"net"
"sync/atomic"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -55,6 +56,9 @@ type SessionCtx struct {
// Note that for stdio-based MCP server, a new session ID is generated per
// connection instead of using the web session ID from the app route.
sessionID session.ID

// mcpSessionID is the MCP session ID tracked by remote MCP server.
mcpSessionID atomicString
}

func (c *SessionCtx) checkAndSetDefaults() error {
Expand Down Expand Up @@ -271,3 +275,15 @@ func makeToolAccessDeniedResponse(msg *mcputils.JSONRPCRequest, authErr error) m
authErr,
)
}

type atomicString struct {
atomic.Pointer[string]
}

// String loads the atomic string value. If the point is nil, empty is returned.
func (s *atomicString) String() string {
if loaded := s.Load(); loaded != nil {
return *loaded
}
return ""
}
Loading
Loading