diff --git a/constants.go b/constants.go index cb729e15c9e35..01836d37b5186 100644 --- a/constants.go +++ b/constants.go @@ -298,6 +298,9 @@ const ( // ComponentForwardingGit represents the SSH proxy that forwards Git commands. ComponentForwardingGit = "git:forward" + // ComponentMCP represents the MCP server handler. + ComponentMCP = "mcp" + // VerboseLogsEnvVar forces all logs to be verbose (down to DEBUG level) VerboseLogsEnvVar = "TELEPORT_DEBUG" diff --git a/integration/appaccess/appaccess_test.go b/integration/appaccess/appaccess_test.go index 8bb73e091754b..7f1cf582e7df9 100644 --- a/integration/appaccess/appaccess_test.go +++ b/integration/appaccess/appaccess_test.go @@ -48,6 +48,7 @@ import ( "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/srv/app/common" + libmcp "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" ) @@ -57,6 +58,8 @@ import ( // It allows to make the entire cluster set up once, instead of per test, // which speeds things up significantly. func TestAppAccess(t *testing.T) { + t.Setenv(libmcp.InMemoryServerEnvVar, "true") + pack := Setup(t) t.Run("Forward", bind(pack, testForward)) @@ -71,6 +74,8 @@ func TestAppAccess(t *testing.T) { t.Run("NoHeaderOverrides", bind(pack, testNoHeaderOverrides)) t.Run("AuditEvents", bind(pack, testAuditEvents)) + t.Run("MCP", bind(pack, testMCP)) + // This test should go last because it stops/starts app servers. t.Run("TestAppServersHA", bind(pack, testServersHA)) } diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go new file mode 100644 index 0000000000000..f1e45479d8114 --- /dev/null +++ b/integration/appaccess/mcp_test.go @@ -0,0 +1,76 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package appaccess + +import ( + "bytes" + "context" + "io" + "testing" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + libmcp "github.com/gravitational/teleport/lib/srv/mcp" +) + +func testMCP(pack *Pack, t *testing.T) { + t.Run("DialMCPServer stdio no server found", func(t *testing.T) { + testMCPDialStdioNoServerFound(t, pack) + }) + + t.Run("DialMCPSererver stdio success", func(t *testing.T) { + testMCPDialStdio(t, pack) + }) +} + +func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) { + require.NoError(t, pack.tc.SaveProfile(false)) + + _, err := pack.tc.DialMCPServer(context.Background(), "not-found") + require.Error(t, err) +} + +func testMCPDialStdio(t *testing.T, pack *Pack) { + require.NoError(t, pack.tc.SaveProfile(false)) + + serverConn, err := pack.tc.DialMCPServer(context.Background(), libmcp.InMemoryServerName) + require.NoError(t, err) + + ctx := context.Background() + clientTransport := transport.NewIO(serverConn, serverConn, io.NopCloser(bytes.NewReader(nil))) + stdioClient := mcpclient.NewClient(clientTransport) + defer stdioClient.Close() + require.NoError(t, stdioClient.Start(ctx)) + + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + _, err = stdioClient.Initialize(ctx, initReq) + require.NoError(t, err) + + listTools, err := stdioClient.ListTools(ctx, mcp.ListToolsRequest{}) + require.NoError(t, err) + require.Len(t, listTools.Tools, 2) +} diff --git a/integration/appaccess/pack.go b/integration/appaccess/pack.go index b7eaf2f1f4e48..5a223695f5d13 100644 --- a/integration/appaccess/pack.go +++ b/integration/appaccess/pack.go @@ -29,6 +29,7 @@ import ( "net" "net/http" "net/url" + "os" "testing" "time" @@ -57,6 +58,7 @@ import ( "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/srv/app/common" + libmcp "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" "github.com/gravitational/teleport/lib/web/app" @@ -1062,6 +1064,12 @@ func (p *Pack) startLeafAppServers(t *testing.T, count int, opts AppTestOptions) } func waitForAppRegInRemoteSiteCache(t *testing.T, tunnel reversetunnelclient.Server, clusterName string, cfgApps []servicecfg.App, hostUUID string) { + if os.Getenv(libmcp.InMemoryServerEnvVar) == "true" { + cfgApps = append(cfgApps, servicecfg.App{ + Name: libmcp.InMemoryServerName, + }) + } + require.EventuallyWithT(t, func(t *assert.CollectT) { site, err := tunnel.GetSite(clusterName) assert.NoError(t, err) diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 37f809f827041..414fbf9026051 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -960,6 +960,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.31.0 h1:4UxSV8aM770OPmTvaVe/b1rA2oZAjBMhGBfUgOGut+4= +github.com/mark3labs/mcp-go v0.31.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= @@ -1310,6 +1312,8 @@ github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8 github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= github.com/xlab/treeprint v1.2.0 h1:HzHnuAF1plUN2zGlAFHbSQP2qJ0ZAD3XF5XD7OesXRQ= github.com/xlab/treeprint v1.2.0/go.mod h1:gj5Gd3gPdKtR1ikdDK6fnFLdmIS0X30kTTuNd/WEJu0= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/lib/client/api.go b/lib/client/api.go index 07e04f8b41ef6..1d2371f2d108b 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -72,6 +72,7 @@ import ( "github.com/gravitational/teleport/api/utils/grpc/interceptors" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/keys/hardwarekey" + "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/api/utils/prompt" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/touchid" @@ -5462,3 +5463,101 @@ func (tc *TeleportClient) HeadlessApprove(ctx context.Context, headlessAuthentic err = rootClient.UpdateHeadlessAuthenticationState(ctx, headlessAuthenticationID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, mfaResp) return trace.Wrap(err) } + +// DialALPN dials the Proxy with provided client certificate and ALPN protocol. +func (tc *TeleportClient) DialALPN(ctx context.Context, clientCert tls.Certificate, protocol alpncommon.Protocol) (net.Conn, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/DialALPN", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("protocol", string(protocol)), + ), + ) + defer span.End() + + dialConfig := client.ALPNDialerConfig{ + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + TLSConfig: &tls.Config{ + NextProtos: alpncommon.ProtocolToStringsWithPing(protocol), + InsecureSkipVerify: tc.InsecureSkipVerify, + Certificates: []tls.Certificate{clientCert}, + }, + GetClusterCAs: tc.RootClusterCACertPool, + } + + tlsConn, err := client.DialALPN(ctx, tc.WebProxyAddr, dialConfig) + if err != nil { + return nil, trace.Wrap(err) + } + if alpncommon.IsPingProtocol(alpncommon.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) { + return pingconn.NewTLS(tlsConn), nil + } + return tlsConn, nil +} + +// DialMCPServer makes a connection to the remote MCP server. +func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (net.Conn, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/DialMCPServer", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("app", appName), + ), + ) + defer span.End() + + apps, err := tc.ListApps(ctx, &proto.ListResourcesRequest{ + ResourceType: types.KindAppServer, + Namespace: apidefaults.Namespace, + PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(appName)), + }) + if err != nil { + return nil, trace.Wrap(err) + } + switch len(apps) { + case 0: + return nil, trace.NotFound("no MCP servers found") + case 1: + default: + log.WarnContext(ctx, "multiple apps found, using the first one") + } + if !apps[0].IsMCP() { + return nil, trace.BadParameter("app %q is not a MCP server", appName) + } + + cert, err := tc.issueMCPCertWithMFA(ctx, apps[0]) + if err != nil { + return nil, trace.Wrap(err) + } + return tc.DialALPN(ctx, cert, alpncommon.ProtocolMCP) +} + +func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) { + profile, err := tc.ProfileStatus() + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + appCertParams := ReissueParams{ + RouteToCluster: tc.SiteName, + RouteToApp: proto.RouteToApp{ + Name: mcpServer.GetName(), + PublicAddr: mcpServer.GetPublicAddr(), + ClusterName: tc.SiteName, + URI: mcpServer.GetURI(), + }, + AccessRequests: profile.ActiveRequests, + } + + // Do NOT write the keyring to avoid race condition when AI clients run + // multiple tsh at the same time. + keyRing, err := tc.IssueUserCertsWithMFA(ctx, appCertParams) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + cert, err := keyRing.AppTLSCert(mcpServer.GetName()) + return cert, trace.Wrap(err) +} diff --git a/lib/service/service.go b/lib/service/service.go index fb78ccdcfd3ef..68e08e9d24e3d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -160,6 +160,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db" "github.com/gravitational/teleport/lib/srv/desktop" "github.com/gravitational/teleport/lib/srv/ingress" + "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/srv/transport/transportv1" "github.com/gravitational/teleport/lib/sshutils" @@ -5045,11 +5046,16 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // Register ALPN handler that will be accepting connections for plain // TCP applications. + // Use the same handler for MCP protocols, for now. if alpnRouter != nil { alpnRouter.Add(alpnproxy.HandlerDecs{ MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolTCP), Handler: webServer.HandleConnection, }) + alpnRouter.Add(alpnproxy.HandlerDecs{ + MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolMCP), + Handler: webServer.HandleConnection, + }) } var peerAddrString string @@ -6174,6 +6180,14 @@ func (process *TeleportProcess) initApps() { applications = append(applications, a) } + if os.Getenv(mcp.InMemoryServerEnvVar) == "true" { + if mcpInMemoryServer, err := mcp.NewInMemoryServerApp(); err != nil { + logger.ErrorContext(process.ExitContext(), "Failed to create in-memory MCP server app") + } else { + applications = append(applications, mcpInMemoryServer) + } + } + lockWatcher, err := services.NewLockWatcher(process.ExitContext(), services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentApp, diff --git a/lib/service/service_test.go b/lib/service/service_test.go index b7bd9663856ea..872a14165c7ca 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -831,6 +831,7 @@ func TestSetupProxyTLSConfig(t *testing.T) { "h2", "acme-tls/1", "teleport-tcp-ping", + "teleport-mcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", "teleport-mongodb-ping", @@ -851,6 +852,7 @@ func TestSetupProxyTLSConfig(t *testing.T) { "teleport-proxy-ssh-grpc", "teleport-proxy-grpc", "teleport-proxy-grpc-mtls", + "teleport-mcp", "teleport-postgres", "teleport-mysql", "teleport-mongodb", @@ -871,6 +873,7 @@ func TestSetupProxyTLSConfig(t *testing.T) { acmeEnabled: false, wantNextProtos: []string{ "teleport-tcp-ping", + "teleport-mcp-ping", "teleport-postgres-ping", "teleport-mysql-ping", "teleport-mongodb-ping", @@ -894,6 +897,7 @@ func TestSetupProxyTLSConfig(t *testing.T) { "teleport-proxy-ssh-grpc", "teleport-proxy-grpc", "teleport-proxy-grpc-mtls", + "teleport-mcp", "teleport-postgres", "teleport-mysql", "teleport-mongodb", diff --git a/lib/services/role.go b/lib/services/role.go index 151d81800ec42..b2aa02fcb4063 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -174,6 +174,9 @@ func RoleWithVersionForUser(u types.User, v string) types.Role { KubernetesLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, DatabaseServiceLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + MCP: &types.MCPPermissions{ + Tools: []string{types.Wildcard}, + }, Rules: []types.Rule{ types.NewRule(types.KindRole, RW()), types.NewRule(types.KindAuthConnector, RW()), @@ -612,6 +615,11 @@ func ApplyTraits(r types.Role, traits map[string][]string) (types.Role, error) { outCond.Roles = apiutils.Deduplicate(outCond.Roles) outCond.Where = inCond.Where r.SetImpersonateConditions(condition, outCond) + + if mcp := r.GetMCPPermissions(condition); mcp != nil { + mcp.Tools = applyValueTraitsSlice(mcp.Tools, traits, "mcp.tools") + r.SetMCPPermissions(condition, mcp) + } } return r, nil diff --git a/lib/services/role_test.go b/lib/services/role_test.go index 1f312befc001e..957df9387173f 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -2965,6 +2965,8 @@ func TestApplyTraits(t *testing.T) { outKubeResources []types.KubernetesResource inGitHubPermissions []types.GitHubPermission outGitHubPermissions []types.GitHubPermission + inMCPPermissions *types.MCPPermissions + outMCPPermissions *types.MCPPermissions } tests := []struct { comment string @@ -3761,6 +3763,34 @@ func TestApplyTraits(t *testing.T) { }}, }, }, + { + comment: "MCP permissions in allow rule", + inTraits: map[string][]string{ + "mcp_tools": {"get_*", "search_files"}, + }, + allow: rule{ + inMCPPermissions: &types.MCPPermissions{ + Tools: []string{"{{internal.mcp_tools}}"}, + }, + outMCPPermissions: &types.MCPPermissions{ + Tools: []string{"get_*", "search_files"}, + }, + }, + }, + { + comment: "MCP permissions in deny rule", + inTraits: map[string][]string{ + "mcp_tools": {"get_*", "search_files"}, + }, + deny: rule{ + inMCPPermissions: &types.MCPPermissions{ + Tools: []string{"{{internal.mcp_tools}}"}, + }, + outMCPPermissions: &types.MCPPermissions{ + Tools: []string{"get_*", "search_files"}, + }, + }, + }, } for _, tt := range tests { t.Run(tt.comment, func(t *testing.T) { @@ -3793,6 +3823,7 @@ func TestApplyTraits(t *testing.T) { HostSudoers: tt.allow.inSudoers, KubernetesResources: tt.allow.inKubeResources, GitHubPermissions: tt.allow.inGitHubPermissions, + MCP: tt.allow.inMCPPermissions, }, Deny: types.RoleConditions{ Logins: tt.deny.inLogins, @@ -3815,6 +3846,7 @@ func TestApplyTraits(t *testing.T) { HostSudoers: tt.deny.outSudoers, KubernetesResources: tt.deny.inKubeResources, GitHubPermissions: tt.deny.inGitHubPermissions, + MCP: tt.deny.inMCPPermissions, }, }, } diff --git a/lib/srv/alpnproxy/common/protocols.go b/lib/srv/alpnproxy/common/protocols.go index d4cbbcfa4c190..054391d13f3ff 100644 --- a/lib/srv/alpnproxy/common/protocols.go +++ b/lib/srv/alpnproxy/common/protocols.go @@ -118,6 +118,9 @@ const ( // ProtocolPingSuffix is TLS ALPN suffix used to wrap connections with // Ping. ProtocolPingSuffix Protocol = "-ping" + + // ProtocolMCP is TLS ALPN protocol value used to indicate MCP connections. + ProtocolMCP Protocol = "teleport-mcp" ) // SupportedProtocols is the list of supported ALPN protocols. @@ -138,6 +141,7 @@ var SupportedProtocols = WithPingProtocols( ProtocolProxySSHGRPC, ProtocolProxyGRPCInsecure, ProtocolProxyGRPCSecure, + ProtocolMCP, }, DatabaseProtocols...), ) @@ -150,6 +154,18 @@ func ProtocolsToString(protocols []Protocol) []string { return out } +// ProtocolToStringsWithPing converts Protocol to a list of strings, adding the +// ping version if the protocol supports it. +func ProtocolToStringsWithPing(protocol Protocol) []string { + if HasPingSupport(protocol) { + return []string{ + string(ProtocolWithPing(protocol)), + string(protocol), + } + } + return []string{string(protocol)} +} + // ToALPNProtocol maps provided database protocol to ALPN protocol. func ToALPNProtocol(dbProtocol string) (Protocol, error) { switch dbProtocol { @@ -231,6 +247,7 @@ var DatabaseProtocols = []Protocol{ var ProtocolsWithPingSupport = append( DatabaseProtocols, ProtocolTCP, + ProtocolMCP, ) // WithPingProtocols adds Ping protocols to the list for each protocol that diff --git a/lib/srv/alpnproxy/common/protocols_test.go b/lib/srv/alpnproxy/common/protocols_test.go index f7d24c56d4a20..79a2a81183e1e 100644 --- a/lib/srv/alpnproxy/common/protocols_test.go +++ b/lib/srv/alpnproxy/common/protocols_test.go @@ -50,3 +50,8 @@ func TestIsDBTLSProtocol(t *testing.T) { require.False(t, IsDBTLSProtocol("teleport-tcp")) require.False(t, IsDBTLSProtocol("")) } + +func TestProtocolToStringsWithPing(t *testing.T) { + require.Equal(t, []string{"teleport-proxy-grpc-mtls"}, ProtocolToStringsWithPing(ProtocolProxyGRPCSecure)) + require.Equal(t, []string{"teleport-mcp-ping", "teleport-mcp"}, ProtocolToStringsWithPing(ProtocolMCP)) +} diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go index 3cd60725c5797..73d986b59e536 100644 --- a/lib/srv/app/connections_handler.go +++ b/lib/srv/app/connections_handler.go @@ -54,6 +54,7 @@ import ( appazure "github.com/gravitational/teleport/lib/srv/app/azure" "github.com/gravitational/teleport/lib/srv/app/common" appgcp "github.com/gravitational/teleport/lib/srv/app/gcp" + "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -175,6 +176,7 @@ type ConnectionsHandler struct { httpServer *http.Server tlsConfig *tls.Config tcpServer *tcpServer + mcpServer *mcp.Server // cache holds sessionChunk objects for in-flight app sessions. cache *utils.FnCache @@ -273,6 +275,17 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler } c.tcpServer = tcpServer + // Handle MCP servers. + c.mcpServer, err = mcp.NewServer(mcp.ServerConfig{ + Emitter: c.cfg.Emitter, + ParentContext: c.closeContext, + HostID: c.cfg.HostID, + AccessPoint: c.cfg.AccessPoint, + }) + if err != nil { + return nil, trace.Wrap(err) + } + // Make copy of server's TLS configuration and update it with the specific // functionality this server needs, like requiring client certificates. c.tlsConfig = CopyAndConfigureTLS(c.log, c.cfg.AccessPoint, c.cfg.TLSConfig) @@ -579,14 +592,19 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) { // The behavior here is a little hard to track. To be clear here, if authorization fails // the following will occur: // 1. If the application is a TCP application, error out immediately as expected. - // 2. If the application is an HTTP application, store the error and let the HTTP handler + // 2. If the application is an MCP application, let the MCP server handler + // returns the error on first request. + // 3. If the application is an HTTP application, store the error and let the HTTP handler // serve the error directly so that it's properly converted to an HTTP status code. // This will ensure users will get a 403 when authorization fails. if err != nil { - if !app.IsTCP() { - c.setConnAuth(tlsConn, err) - } else { + switch { + case app.IsTCP(): return nil, trace.Wrap(err) + case app.IsMCP(): + return nil, trace.Wrap(c.mcpServer.HandleUnauthorizedConnection(ctx, conn, err)) + default: + c.setConnAuth(tlsConn, err) } } else { // Monitor the connection an update the context. @@ -602,17 +620,28 @@ func (c *ConnectionsHandler) handleConnection(conn net.Conn) (func(), error) { // Application access supports plain TCP connections which are handled // differently than HTTP requests from web apps. - if app.IsTCP() { + switch { + case app.IsTCP(): identity := authCtx.Identity.GetIdentity() defer cancel(nil) return nil, trace.Wrap(c.handleTCPApp(ctx, tlsConn, &identity, app)) - } - cleanup := func() { - cancel(nil) - c.deleteConnAuth(tlsConn) + case app.IsMCP(): + defer cancel(nil) + sessionCtx := mcp.SessionCtx{ + ClientConn: tlsConn, + AuthCtx: authCtx, + App: app, + } + return nil, trace.Wrap(c.mcpServer.HandleSession(ctx, sessionCtx)) + + default: + cleanup := func() { + cancel(nil) + c.deleteConnAuth(tlsConn) + } + return cleanup, trace.Wrap(c.handleHTTPApp(ctx, tlsConn)) } - return cleanup, trace.Wrap(c.handleHTTPApp(ctx, tlsConn)) } // handleHTTPApp handles connection for an HTTP application. diff --git a/lib/srv/mcp/memory.go b/lib/srv/mcp/memory.go new file mode 100644 index 0000000000000..49126399f2a79 --- /dev/null +++ b/lib/srv/mcp/memory.go @@ -0,0 +1,111 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "context" + "log/slog" + + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +const ( + // InMemoryServerEnvVar enables an in-memory MCP server for testing + // purposes. The test app enables a stdio MCP server that has a + // "teleport-hello-test" tool and a "teleport-echo-test" tool. + InMemoryServerEnvVar = "TELEPORT_UNSTABLE_MCP_IN_MEMORY_SERVER" + + // InMemoryServerName is the name of the in-memory MCP server. + InMemoryServerName = "teleport-mcp-test-server" +) + +// NewInMemoryServerApp returns the app definition for the in-memory test server. +func NewInMemoryServerApp() (types.Application, error) { + app, err := types.NewAppV3(types.Metadata{ + Name: InMemoryServerName, + Labels: map[string]string{ + types.TeleportInternalLabelPrefix + "mcp-in-memory-server": "true", + }, + }, types.AppSpecV3{ + MCP: &types.MCP{ + Command: "in-memory-server", + RunAsHostUser: "in-memory-server", + }, + }) + return app, trace.Wrap(err) +} + +func isInMemoryServerApp(app types.Application) bool { + value, ok := app.GetLabel(types.TeleportInternalLabelPrefix + "mcp-in-memory-server") + return ok && value == "true" +} + +func (s *Server) handleInMemoryServerSession(ctx context.Context, sessionCtx SessionCtx) error { + s.cfg.Log.DebugContext(ctx, "Started in-memory server session") + defer s.cfg.Log.DebugContext(ctx, "Completed in-memory server session") + + server := mcpserver.NewMCPServer("hello-test-server", "1.0.0") + stdioServer := mcpserver.NewStdioServer(server) + stdioServer.SetErrorLogger(slog.NewLogLogger(s.cfg.Log.Handler(), slog.LevelDebug)) + + checkAccess := func(toolName string) bool { + return sessionCtx.AuthCtx.Checker.CheckAccess( + sessionCtx.App, + services.AccessState{ + MFAVerified: true, + }, + &services.MCPToolMatcher{ + Name: toolName, + }, + ) == nil + } + + helloTool := mcp.NewTool("teleport-hello-test", + mcp.WithDescription("this is simple hello test and it always return \"hello client\""), + ) + if checkAccess(helloTool.GetName()) { + server.AddTool(helloTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent("hello client")}, + }, nil + }) + } + + echoTool := mcp.NewTool("teleport-echo-test", + mcp.WithDescription("this is simple echo and it always return the input back"), + mcp.WithString("input", mcp.Required(), mcp.Description("input for echo")), + ) + if checkAccess(echoTool.GetName()) { + server.AddTool(echoTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + input, err := request.RequireString("input") + if err != nil { + return nil, trace.Wrap(err) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent(input)}, + }, nil + }) + } + return stdioServer.Listen(ctx, sessionCtx.ClientConn, sessionCtx.ClientConn) +} diff --git a/lib/srv/mcp/server.go b/lib/srv/mcp/server.go new file mode 100644 index 0000000000000..67b7470620266 --- /dev/null +++ b/lib/srv/mcp/server.go @@ -0,0 +1,115 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "context" + "log/slog" + "net" + "os" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/services" +) + +// AccessPoint defines functions that the MCP server requires from the caching +// client to the Auth Server. +type AccessPoint interface { + services.AuthPreferenceGetter + services.ClusterNameGetter +} + +// ServerConfig is the config for the MCP forward server. +type ServerConfig struct { + // Emitter is used for emitting audit events. + Emitter apievents.Emitter + // Log is the slog logger. + Log *slog.Logger + // ParentContext is parent's context for logging. + ParentContext context.Context + // HostID is the host ID of the teleport service. + HostID string + // AccessPoint is a caching client connected to the Auth Server. + AccessPoint AccessPoint + + clock clockwork.Clock + inMemoryServer bool +} + +// CheckAndSetDefaults checks values and sets defaults +func (c *ServerConfig) CheckAndSetDefaults() error { + if c.Emitter == nil { + return trace.BadParameter("missing Emitter") + } + if c.ParentContext == nil { + return trace.BadParameter("missing ParentContext") + } + if c.HostID == "" { + return trace.BadParameter("missing HostID") + } + if c.AccessPoint == nil { + return trace.BadParameter("missing AccessPoint") + } + if c.Log == nil { + c.Log = slog.With(teleport.ComponentKey, teleport.ComponentMCP) + } + if c.clock == nil { + c.clock = clockwork.NewRealClock() + } + c.inMemoryServer = os.Getenv(InMemoryServerEnvVar) == "true" + return nil +} + +// Server handles forwarding client connections to MCP servers. +// TODO(greedy52) add server metrics. +type Server struct { + cfg ServerConfig +} + +// NewServer creates a new Server. +func NewServer(cfg ServerConfig) (*Server, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &Server{ + cfg: cfg, + }, nil +} + +// HandleSession handles an authorized client connection. +func (s *Server) HandleSession(ctx context.Context, sessionCtx SessionCtx) error { + if err := sessionCtx.checkAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if s.cfg.inMemoryServer && isInMemoryServerApp(sessionCtx.App) { + return trace.Wrap(s.handleInMemoryServerSession(ctx, sessionCtx)) + } + // TODO(greedy52) handle stdio + return trace.NotImplemented("not implemented") +} + +// HandleUnauthorizedConnection handles an unauthorized client connection. +func (s *Server) HandleUnauthorizedConnection(ctx context.Context, clientConn net.Conn, authErr error) error { + // TODO(greedy52) handle stdio + return trace.NotImplemented("not implemented") +} diff --git a/lib/srv/mcp/session.go b/lib/srv/mcp/session.go new file mode 100644 index 0000000000000..1c1261d2d86ce --- /dev/null +++ b/lib/srv/mcp/session.go @@ -0,0 +1,65 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package mcp + +import ( + "net" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/tlsca" +) + +// SessionCtx contains basic information of an MCP session. +type SessionCtx struct { + // ClientConn is the incoming client connection. + ClientConn net.Conn + // AuthCtx is the authorization context. + AuthCtx *authz.Context + // App is the MCP server application being accessed. + App types.Application + // Identity is the user identity. + Identity tlsca.Identity + + // sessionID is the session ID. + sessionID session.ID +} + +func (c *SessionCtx) checkAndSetDefaults() error { + if c.ClientConn == nil { + return trace.BadParameter("missing ClientConn") + } + if c.AuthCtx == nil { + return trace.BadParameter("missing AuthCtx") + } + if c.App == nil { + return trace.BadParameter("missing App") + } + if c.Identity.Username == "" { + c.Identity = c.AuthCtx.Identity.GetIdentity() + } + if c.sessionID == "" { + // Do not use web session ID from the app route. + c.sessionID = session.NewID() + } + return nil +} diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go index bbe574d7a802f..e9675ee9a2a1e 100644 --- a/tool/tsh/common/mcp.go +++ b/tool/tsh/common/mcp.go @@ -32,8 +32,9 @@ import ( type mcpCommands struct { dbStart *mcpDBStartCommand - config *mcpConfigCommand - list *mcpListCommand + config *mcpConfigCommand + list *mcpListCommand + connect *mcpConnectCommand } func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands { @@ -42,8 +43,9 @@ func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands { return &mcpCommands{ dbStart: newMCPDBCommand(db), - list: newMCPListCommand(mcp, cf), - config: newMCPConfigCommand(mcp, cf), + list: newMCPListCommand(mcp, cf), + config: newMCPConfigCommand(mcp, cf), + connect: newMCPConnectCommand(mcp, cf), } } diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go index 5397d73a26b71..56b645b6c8e56 100644 --- a/tool/tsh/common/mcp_app.go +++ b/tool/tsh/common/mcp_app.go @@ -40,9 +40,20 @@ import ( "github.com/gravitational/teleport/lib/client/mcp/claude" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/tool/common" ) +func newMCPConnectCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpConnectCommand { + cmd := &mcpConnectCommand{ + CmdClause: parent.Command("connect", "Connect to an MCP server.").Hidden(), + cf: cf, + } + + cmd.Arg("name", "Name of the MCP server").Required().StringVar(&cf.AppName) + return cmd +} + func newMCPListCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpListCommand { cmd := &mcpListCommand{ CmdClause: parent.Command("ls", "List available MCP server applications."), @@ -399,3 +410,28 @@ restart your client after logging in a new tsh session. } const mcpServerAppConfigPrefix = "teleport-mcp-" + +// mcpConnectCommand implements `tsh mcp connect` command. +type mcpConnectCommand struct { + *kingpin.CmdClause + cf *CLIConf +} + +func (c *mcpConnectCommand) run() error { + _, err := initLogger(c.cf, utils.LoggingForMCP, getLoggingOptsForMCPServer(c.cf)) + if err != nil { + return trace.Wrap(err) + } + + tc, err := makeClient(c.cf) + if err != nil { + return trace.Wrap(err) + } + tc.NonInteractive = true + + serverConn, err := tc.DialMCPServer(c.cf.Context, c.cf.AppName) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(utils.ProxyConn(c.cf.Context, utils.CombinedStdio{}, serverConn)) +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 22c6071f651c0..bc6fd18c350a7 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1755,6 +1755,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = pivCmd.agent.run(&cf) case mcpCmd.dbStart.FullCommand(): err = mcpCmd.dbStart.run(&cf) + case mcpCmd.connect.FullCommand(): + err = mcpCmd.connect.run() case mcpCmd.list.FullCommand(): err = mcpCmd.list.run() case mcpCmd.config.FullCommand():