Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions api/fixtures/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

package fixtures

import (
"time"
)

const (
TLSCACertPEM = `-----BEGIN CERTIFICATE-----
MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA
Expand Down Expand Up @@ -62,3 +66,10 @@ LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8
+ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA
-----END RSA PRIVATE KEY-----`
)

var (
// TLSCACertNotBefore is the "Not before" date of TLSCACertPEM.
TLSCACertNotBefore = time.Date(2017, time.May, 9, 19, 40, 36, 0, time.UTC)
// TLSCACertNotAfter is the "Not after" date of TLSCACertPEM.
TLSCACertNotAfter = time.Date(2027, time.May, 7, 19, 40, 36, 0, time.UTC)
)
66 changes: 21 additions & 45 deletions integration/appaccess/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,53 @@
package appaccess

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"testing"

"github.com/gravitational/trace"
mcpclient "github.com/mark3labs/mcp-go/client"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client"
libmcp "github.com/gravitational/teleport/lib/srv/mcp"
"github.com/gravitational/teleport/lib/utils/mcptest"
)

func testMCP(pack *Pack, t *testing.T) {
t.Run("DialMCPServer stdio no server found", func(t *testing.T) {
// SaveProfile before using the TeleportClient.
require.NoError(t, pack.tc.SaveProfile(false))

t.Run("stdio no server found", func(t *testing.T) {
testMCPDialStdioNoServerFound(t, pack)
})

t.Run("DialMCPServer stdio success", func(t *testing.T) {
t.Run("stdio success", func(t *testing.T) {
testMCPDialStdio(t, pack)
})

t.Run("DialMCPServer stdio to sse success", func(t *testing.T) {
t.Run("stdio to sse success", func(t *testing.T) {
testMCPDialStdioToSSE(t, pack, "test-sse")
})

t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) {
t.Run("proxy streamable HTTP success", func(t *testing.T) {
testMCPProxyStreamableHTTP(t, pack, "test-http")
})
}

func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
require.NoError(t, pack.tc.SaveProfile(false))

_, err := pack.tc.DialMCPServer(context.Background(), "not-found")
// Single connection dial for stdio.
dialer := client.NewMCPServerDialer(pack.tc, "not-found")
_, err := dialer.DialALPN(t.Context())
require.Error(t, err)
require.True(t, trace.IsNotFound(err))
}

func testMCPDialStdio(t *testing.T, pack *Pack) {
require.NoError(t, pack.tc.SaveProfile(false))

serverConn, err := pack.tc.DialMCPServer(context.Background(), libmcp.DemoServerName)
// Single connection dial for stdio.
dialer := client.NewMCPServerDialer(pack.tc, libmcp.DemoServerName)
serverConn, err := dialer.DialALPN(t.Context())
require.NoError(t, err)

ctx := t.Context()
Expand All @@ -80,9 +80,9 @@ func testMCPDialStdio(t *testing.T, pack *Pack) {
}

func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
require.NoError(t, pack.tc.SaveProfile(false))

serverConn, err := pack.tc.DialMCPServer(context.Background(), appName)
// Single connection dial for stdio.
dialer := client.NewMCPServerDialer(pack.tc, appName)
serverConn, err := dialer.DialALPN(t.Context())
require.NoError(t, err)

ctx := t.Context()
Expand All @@ -95,38 +95,14 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
}

func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
require.NoError(t, pack.tc.SaveProfile(false))

// Find the MCP server.
filter := pack.tc.ResourceFilter(types.KindAppServer)
filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName)
apps, err := pack.tc.ListApps(t.Context(), filter)
require.NoError(t, err)
require.Len(t, apps, 1)

// Issue a TLS cert with app route.
keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{
RouteToCluster: pack.rootCluster.Secrets.SiteName,
RouteToApp: proto.RouteToApp{
ClusterName: pack.rootCluster.Secrets.SiteName,
Name: apps[0].GetName(),
PublicAddr: apps[0].GetPublicAddr(),
},
})
require.NoError(t, err)
appCert, err := keyRing.AppTLSCert(appName)
require.NoError(t, err)

// Create an MCP client with app cert.
// Use special dialer for HTTP client.
ctx := t.Context()
dialer := client.NewMCPServerDialer(pack.tc, appName)
mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
"https://"+pack.rootCluster.Web,
mcpclienttransport.WithHTTPBasicClient(&http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{appCert},
InsecureSkipVerify: true,
},
DialTLSContext: dialer.DialContext,
},
}),
)
Expand Down
15 changes: 3 additions & 12 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1890,18 +1890,9 @@ func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResou
// Perform the label/search/expr filtering here (instead of at the backend
// `ListResources`) to ensure that it will be applied only to resources
// the user has access to.
filter := services.MatchResourceFilter{
ResourceKind: req.ResourceType,
Labels: req.Labels,
SearchKeywords: req.SearchKeywords,
}

if req.PredicateExpression != "" {
expression, err := services.NewResourceExpression(req.PredicateExpression)
if err != nil {
return nil, trace.Wrap(err)
}
filter.PredicateExpression = expression
filter, err := services.MatchResourceFilterFromListResourceRequest(&req)
if err != nil {
return nil, trace.Wrap(err)
}

req.Labels = nil
Expand Down
15 changes: 3 additions & 12 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -1672,18 +1672,9 @@ func (c *Cache) listResources(ctx context.Context, req authproto.ListResourcesRe
_, span := c.Tracer.Start(ctx, "cache/listResources")
defer span.End()

filter := services.MatchResourceFilter{
ResourceKind: req.ResourceType,
Labels: req.Labels,
SearchKeywords: req.SearchKeywords,
}

if req.PredicateExpression != "" {
expression, err := services.NewResourceExpression(req.PredicateExpression)
if err != nil {
return nil, trace.Wrap(err)
}
filter.PredicateExpression = expression
filter, err := services.MatchResourceFilterFromListResourceRequest(&req)
if err != nil {
return nil, trace.Wrap(err)
}

// Adjust page size, so it can't be empty.
Expand Down
77 changes: 5 additions & 72 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5537,78 +5537,6 @@ func (tc *TeleportClient) DialALPN(ctx context.Context, clientCert tls.Certifica
return tlsConn, nil
}

// DialMCPServer makes a connection to the remote MCP server.
func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (net.Conn, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/DialMCPServer",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.String("app", appName),
),
)
defer span.End()

apps, err := tc.ListApps(ctx, &proto.ListResourcesRequest{
ResourceType: types.KindAppServer,
Namespace: apidefaults.Namespace,
PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(appName)),
})
if err != nil {
return nil, trace.Wrap(err)
}
switch len(apps) {
case 0:
return nil, trace.NotFound("no MCP servers found")
case 1:
default:
log.WarnContext(ctx, "multiple apps found, using the first one")
}
if !apps[0].IsMCP() {
return nil, trace.BadParameter("app %q is not a MCP server", appName)
}

// TODO(greedy52) support streamable HTTP for "tsh mcp connect" before
// release.
if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP {
return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport)
}

cert, err := tc.issueMCPCertWithMFA(ctx, apps[0])
if err != nil {
return nil, trace.Wrap(err)
}
return tc.DialALPN(ctx, cert, alpncommon.ProtocolMCP)
}

func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) {
profile, err := tc.ProfileStatus()
if err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

appCertParams := ReissueParams{
RouteToCluster: tc.SiteName,
RouteToApp: proto.RouteToApp{
Name: mcpServer.GetName(),
PublicAddr: mcpServer.GetPublicAddr(),
ClusterName: tc.SiteName,
URI: mcpServer.GetURI(),
},
AccessRequests: profile.ActiveRequests,
}

// Do NOT write the keyring to avoid race condition when AI clients run
// multiple tsh at the same time.
keyRing, err := tc.IssueUserCertsWithMFA(ctx, appCertParams)
if err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

cert, err := keyRing.AppTLSCert(mcpServer.GetName())
return cert, trace.Wrap(err)
}

// DialDatabase makes a remote connection to the database.
//
// TODO(gabrielcorado): support acccess requests connections.
Expand Down Expand Up @@ -5648,6 +5576,11 @@ func (tc *TeleportClient) DialDatabase(ctx context.Context, route proto.RouteToD
return tc.DialALPN(ctx, cert, alpnProtocol)
}

// GetSiteName returns the cluster name this client instance is targeting.
func (tc *TeleportClient) GetSiteName() string {
return tc.SiteName
}

// CalculateSSHLogins returns the subset of the allowedLogins that exist in
// the principals of the identity. This is required because SSH authorization
// only allows using a login that exists in the certificates valid principals.
Expand Down
Loading
Loading