diff --git a/api/fixtures/fixtures.go b/api/fixtures/fixtures.go
index 573e4d3c9f651..d888c4377e1ab 100644
--- a/api/fixtures/fixtures.go
+++ b/api/fixtures/fixtures.go
@@ -14,6 +14,10 @@
package fixtures
+import (
+ "time"
+)
+
const (
TLSCACertPEM = `-----BEGIN CERTIFICATE-----
MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA
@@ -62,3 +66,10 @@ LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8
+ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA
-----END RSA PRIVATE KEY-----`
)
+
+var (
+ // TLSCACertNotBefore is the "Not before" date of TLSCACertPEM.
+ TLSCACertNotBefore = time.Date(2017, time.May, 9, 19, 40, 36, 0, time.UTC)
+ // TLSCACertNotAfter is the "Not after" date of TLSCACertPEM.
+ TLSCACertNotAfter = time.Date(2027, time.May, 7, 19, 40, 36, 0, time.UTC)
+)
diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go
index cf513e33acf60..4bb56b4dcbe80 100644
--- a/integration/appaccess/mcp_test.go
+++ b/integration/appaccess/mcp_test.go
@@ -19,53 +19,53 @@
package appaccess
import (
- "context"
- "crypto/tls"
- "fmt"
"net/http"
"testing"
+ "github.com/gravitational/trace"
mcpclient "github.com/mark3labs/mcp-go/client"
mcpclienttransport "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
- "github.com/gravitational/teleport/api/client/proto"
- "github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client"
libmcp "github.com/gravitational/teleport/lib/srv/mcp"
"github.com/gravitational/teleport/lib/utils/mcptest"
)
func testMCP(pack *Pack, t *testing.T) {
- t.Run("DialMCPServer stdio no server found", func(t *testing.T) {
+ // SaveProfile before using the TeleportClient.
+ require.NoError(t, pack.tc.SaveProfile(false))
+
+ t.Run("stdio no server found", func(t *testing.T) {
testMCPDialStdioNoServerFound(t, pack)
})
- t.Run("DialMCPServer stdio success", func(t *testing.T) {
+ t.Run("stdio success", func(t *testing.T) {
testMCPDialStdio(t, pack)
})
- t.Run("DialMCPServer stdio to sse success", func(t *testing.T) {
+ t.Run("stdio to sse success", func(t *testing.T) {
testMCPDialStdioToSSE(t, pack, "test-sse")
})
- t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) {
+ t.Run("proxy streamable HTTP success", func(t *testing.T) {
testMCPProxyStreamableHTTP(t, pack, "test-http")
})
}
func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) {
- require.NoError(t, pack.tc.SaveProfile(false))
-
- _, err := pack.tc.DialMCPServer(context.Background(), "not-found")
+ // Single connection dial for stdio.
+ dialer := client.NewMCPServerDialer(pack.tc, "not-found")
+ _, err := dialer.DialALPN(t.Context())
require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
}
func testMCPDialStdio(t *testing.T, pack *Pack) {
- require.NoError(t, pack.tc.SaveProfile(false))
-
- serverConn, err := pack.tc.DialMCPServer(context.Background(), libmcp.DemoServerName)
+ // Single connection dial for stdio.
+ dialer := client.NewMCPServerDialer(pack.tc, libmcp.DemoServerName)
+ serverConn, err := dialer.DialALPN(t.Context())
require.NoError(t, err)
ctx := t.Context()
@@ -80,9 +80,9 @@ func testMCPDialStdio(t *testing.T, pack *Pack) {
}
func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
- require.NoError(t, pack.tc.SaveProfile(false))
-
- serverConn, err := pack.tc.DialMCPServer(context.Background(), appName)
+ // Single connection dial for stdio.
+ dialer := client.NewMCPServerDialer(pack.tc, appName)
+ serverConn, err := dialer.DialALPN(t.Context())
require.NoError(t, err)
ctx := t.Context()
@@ -95,38 +95,14 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) {
}
func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) {
- require.NoError(t, pack.tc.SaveProfile(false))
-
- // Find the MCP server.
- filter := pack.tc.ResourceFilter(types.KindAppServer)
- filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName)
- apps, err := pack.tc.ListApps(t.Context(), filter)
- require.NoError(t, err)
- require.Len(t, apps, 1)
-
- // Issue a TLS cert with app route.
- keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{
- RouteToCluster: pack.rootCluster.Secrets.SiteName,
- RouteToApp: proto.RouteToApp{
- ClusterName: pack.rootCluster.Secrets.SiteName,
- Name: apps[0].GetName(),
- PublicAddr: apps[0].GetPublicAddr(),
- },
- })
- require.NoError(t, err)
- appCert, err := keyRing.AppTLSCert(appName)
- require.NoError(t, err)
-
- // Create an MCP client with app cert.
+ // Use special dialer for HTTP client.
ctx := t.Context()
+ dialer := client.NewMCPServerDialer(pack.tc, appName)
mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP(
"https://"+pack.rootCluster.Web,
mcpclienttransport.WithHTTPBasicClient(&http.Client{
Transport: &http.Transport{
- TLSClientConfig: &tls.Config{
- Certificates: []tls.Certificate{appCert},
- InsecureSkipVerify: true,
- },
+ DialTLSContext: dialer.DialContext,
},
}),
)
diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go
index 3a5fa017dc107..fb3af9cc04736 100644
--- a/lib/auth/auth_with_roles.go
+++ b/lib/auth/auth_with_roles.go
@@ -1890,18 +1890,9 @@ func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResou
// Perform the label/search/expr filtering here (instead of at the backend
// `ListResources`) to ensure that it will be applied only to resources
// the user has access to.
- filter := services.MatchResourceFilter{
- ResourceKind: req.ResourceType,
- Labels: req.Labels,
- SearchKeywords: req.SearchKeywords,
- }
-
- if req.PredicateExpression != "" {
- expression, err := services.NewResourceExpression(req.PredicateExpression)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- filter.PredicateExpression = expression
+ filter, err := services.MatchResourceFilterFromListResourceRequest(&req)
+ if err != nil {
+ return nil, trace.Wrap(err)
}
req.Labels = nil
diff --git a/lib/cache/cache.go b/lib/cache/cache.go
index d2a6b56136471..fbcd231f918dc 100644
--- a/lib/cache/cache.go
+++ b/lib/cache/cache.go
@@ -1672,18 +1672,9 @@ func (c *Cache) listResources(ctx context.Context, req authproto.ListResourcesRe
_, span := c.Tracer.Start(ctx, "cache/listResources")
defer span.End()
- filter := services.MatchResourceFilter{
- ResourceKind: req.ResourceType,
- Labels: req.Labels,
- SearchKeywords: req.SearchKeywords,
- }
-
- if req.PredicateExpression != "" {
- expression, err := services.NewResourceExpression(req.PredicateExpression)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- filter.PredicateExpression = expression
+ filter, err := services.MatchResourceFilterFromListResourceRequest(&req)
+ if err != nil {
+ return nil, trace.Wrap(err)
}
// Adjust page size, so it can't be empty.
diff --git a/lib/client/api.go b/lib/client/api.go
index 37abe98c5d5fe..c7a6f7f50f104 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -5537,78 +5537,6 @@ func (tc *TeleportClient) DialALPN(ctx context.Context, clientCert tls.Certifica
return tlsConn, nil
}
-// DialMCPServer makes a connection to the remote MCP server.
-func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (net.Conn, error) {
- ctx, span := tc.Tracer.Start(
- ctx,
- "teleportClient/DialMCPServer",
- oteltrace.WithSpanKind(oteltrace.SpanKindClient),
- oteltrace.WithAttributes(
- attribute.String("app", appName),
- ),
- )
- defer span.End()
-
- apps, err := tc.ListApps(ctx, &proto.ListResourcesRequest{
- ResourceType: types.KindAppServer,
- Namespace: apidefaults.Namespace,
- PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(appName)),
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
- switch len(apps) {
- case 0:
- return nil, trace.NotFound("no MCP servers found")
- case 1:
- default:
- log.WarnContext(ctx, "multiple apps found, using the first one")
- }
- if !apps[0].IsMCP() {
- return nil, trace.BadParameter("app %q is not a MCP server", appName)
- }
-
- // TODO(greedy52) support streamable HTTP for "tsh mcp connect" before
- // release.
- if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP {
- return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport)
- }
-
- cert, err := tc.issueMCPCertWithMFA(ctx, apps[0])
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return tc.DialALPN(ctx, cert, alpncommon.ProtocolMCP)
-}
-
-func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) {
- profile, err := tc.ProfileStatus()
- if err != nil {
- return tls.Certificate{}, trace.Wrap(err)
- }
-
- appCertParams := ReissueParams{
- RouteToCluster: tc.SiteName,
- RouteToApp: proto.RouteToApp{
- Name: mcpServer.GetName(),
- PublicAddr: mcpServer.GetPublicAddr(),
- ClusterName: tc.SiteName,
- URI: mcpServer.GetURI(),
- },
- AccessRequests: profile.ActiveRequests,
- }
-
- // Do NOT write the keyring to avoid race condition when AI clients run
- // multiple tsh at the same time.
- keyRing, err := tc.IssueUserCertsWithMFA(ctx, appCertParams)
- if err != nil {
- return tls.Certificate{}, trace.Wrap(err)
- }
-
- cert, err := keyRing.AppTLSCert(mcpServer.GetName())
- return cert, trace.Wrap(err)
-}
-
// DialDatabase makes a remote connection to the database.
//
// TODO(gabrielcorado): support acccess requests connections.
@@ -5648,6 +5576,11 @@ func (tc *TeleportClient) DialDatabase(ctx context.Context, route proto.RouteToD
return tc.DialALPN(ctx, cert, alpnProtocol)
}
+// GetSiteName returns the cluster name this client instance is targeting.
+func (tc *TeleportClient) GetSiteName() string {
+ return tc.SiteName
+}
+
// CalculateSSHLogins returns the subset of the allowedLogins that exist in
// the principals of the identity. This is required because SSH authorization
// only allows using a login that exists in the certificates valid principals.
diff --git a/lib/client/mcp.go b/lib/client/mcp.go
new file mode 100644
index 0000000000000..f7ca3cfa0c0ad
--- /dev/null
+++ b/lib/client/mcp.go
@@ -0,0 +1,178 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package client
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "log/slog"
+ "net"
+ "strings"
+ "sync"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
+ apidefaults "github.com/gravitational/teleport/api/defaults"
+ "github.com/gravitational/teleport/api/types"
+ alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// MCPServerDialerClient defines a subset of TeleportClient functions that are
+// used by MCPServerDialer.
+type MCPServerDialerClient interface {
+ DialALPN(context.Context, tls.Certificate, alpncommon.Protocol) (net.Conn, error)
+ ListApps(context.Context, *proto.ListResourcesRequest) ([]types.Application, error)
+ IssueUserCertsWithMFA(context.Context, ReissueParams) (*KeyRing, error)
+ ProfileStatus() (*ProfileStatus, error)
+ GetSiteName() string
+}
+
+// MCPServerDialer is a wrapper of TeleportClient for handling MCP connections
+// to proxy.
+type MCPServerDialer struct {
+ client MCPServerDialerClient
+ appName string
+
+ mu sync.Mutex
+ app types.Application
+ cert tls.Certificate
+ clock clockwork.Clock
+ logger *slog.Logger
+}
+
+// NewMCPServerDialer creates a new MCPServerDialer.
+func NewMCPServerDialer(client MCPServerDialerClient, appName string) *MCPServerDialer {
+ return &MCPServerDialer{
+ client: client,
+ appName: appName,
+ clock: clockwork.NewRealClock(),
+ logger: slog.With(
+ teleport.ComponentKey,
+ teleport.Component(teleport.ComponentMCP, "dialer"),
+ ),
+ }
+}
+
+// GetApp returns the types.Application for the associated MCP server.
+func (d *MCPServerDialer) GetApp(ctx context.Context) (types.Application, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.getAppLocked(ctx)
+}
+
+// DialALPN dials Teleport Proxy to establish a TLS routing connection for the
+// MCP server.
+func (d *MCPServerDialer) DialALPN(ctx context.Context) (net.Conn, error) {
+ app, err := d.getAppLocked(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ cert, err := d.getCertLocked(ctx, app)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ switch types.GetMCPServerTransportType(app.GetURI()) {
+ case types.MCPTransportHTTP:
+ return d.client.DialALPN(ctx, cert, alpncommon.ProtocolHTTP)
+ default:
+ return d.client.DialALPN(ctx, cert, alpncommon.ProtocolMCP)
+ }
+}
+
+// DialContext is a simple wrapper of DialALPN. This function is defined to be
+// compatible with common context dialer interfaces.
+func (d *MCPServerDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) {
+ return d.DialALPN(ctx)
+}
+
+func (d *MCPServerDialer) getAppLocked(ctx context.Context) (types.Application, error) {
+ if d.app != nil {
+ return d.app, nil
+ }
+
+ apps, err := d.client.ListApps(ctx, &proto.ListResourcesRequest{
+ ResourceType: types.KindAppServer,
+ Namespace: apidefaults.Namespace,
+ PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(d.appName)),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ switch len(apps) {
+ case 0:
+ return nil, trace.NotFound("no MCP servers found")
+ case 1:
+ default:
+ d.logger.WarnContext(ctx, "multiple appServers found, using the first one")
+ }
+ if !apps[0].IsMCP() {
+ return nil, trace.BadParameter("app %q is not a MCP server", d.appName)
+ }
+
+ d.app = apps[0]
+ d.logger.InfoContext(ctx, "Successfully fetched app",
+ "name", d.app.GetName(),
+ "transport", types.GetMCPServerTransportType(d.app.GetURI()),
+ )
+ return d.app, nil
+}
+
+func (d *MCPServerDialer) getCertLocked(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) {
+ if err := utils.VerifyTLSCertLeafExpiry(d.cert, d.clock); err == nil {
+ return d.cert, nil
+ }
+
+ d.logger.InfoContext(ctx, "Reissuing certificate", "name", mcpServer.GetName())
+ profile, err := d.client.ProfileStatus()
+ if err != nil {
+ return tls.Certificate{}, trace.Wrap(err)
+ }
+
+ appCertParams := ReissueParams{
+ RouteToCluster: d.client.GetSiteName(),
+ RouteToApp: proto.RouteToApp{
+ Name: mcpServer.GetName(),
+ PublicAddr: mcpServer.GetPublicAddr(),
+ ClusterName: d.client.GetSiteName(),
+ URI: mcpServer.GetURI(),
+ },
+ AccessRequests: profile.ActiveRequests,
+ }
+
+ // Do NOT write the keyring to avoid race condition when AI clients run
+ // multiple tsh at the same time.
+ keyRing, err := d.client.IssueUserCertsWithMFA(ctx, appCertParams)
+ if err != nil {
+ return tls.Certificate{}, trace.Wrap(err)
+ }
+
+ cert, err := keyRing.AppTLSCert(mcpServer.GetName())
+ if err != nil {
+ return tls.Certificate{}, trace.Wrap(err)
+ }
+
+ d.logger.InfoContext(ctx, "Successfully issued certificate", "name", mcpServer.GetName())
+ d.cert = cert
+ return d.cert, nil
+}
diff --git a/lib/client/mcp_test.go b/lib/client/mcp_test.go
new file mode 100644
index 0000000000000..0a6f82fa5e9c1
--- /dev/null
+++ b/lib/client/mcp_test.go
@@ -0,0 +1,240 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package client
+
+import (
+ "cmp"
+ "context"
+ "crypto/tls"
+ "net"
+ "slices"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/cryptosuites"
+ "github.com/gravitational/teleport/lib/services"
+ alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+type mockALPNConn struct {
+ net.Conn
+ cert tls.Certificate
+ protocol alpncommon.Protocol
+}
+
+func newFakeALPNConn(protocol alpncommon.Protocol, cert tls.Certificate) *mockALPNConn {
+ return &mockALPNConn{
+ Conn: nil, // currently not using the net.conn
+ cert: cert,
+ protocol: protocol,
+ }
+}
+
+func (c *mockALPNConn) getCertSerialNumber() string {
+ leaf, err := utils.TLSCertLeaf(c.cert)
+ if err != nil {
+ return ""
+ }
+ return leaf.SerialNumber.String()
+}
+
+type mockMCPServerDialerClient struct {
+ appServers types.AppServers
+ tlsCA *tlsca.CertAuthority
+ clock *clockwork.FakeClock
+ identity tlsca.Identity
+}
+
+func (m *mockMCPServerDialerClient) DialALPN(_ context.Context, cert tls.Certificate, protocol alpncommon.Protocol) (net.Conn, error) {
+ if err := utils.VerifyTLSCertLeafExpiry(cert, m.clock); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return newFakeALPNConn(protocol, cert), nil
+}
+
+func (m *mockMCPServerDialerClient) ListApps(_ context.Context, req *proto.ListResourcesRequest) ([]types.Application, error) {
+ filter, err := services.MatchResourceFilterFromListResourceRequest(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ appServers, err := services.MatchResourcesByFilters(m.appServers, filter)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return slices.Collect(appServers.Applications()), nil
+}
+
+func (m *mockMCPServerDialerClient) IssueUserCertsWithMFA(_ context.Context, params ReissueParams) (*KeyRing, error) {
+ if params.RouteToApp.Name == "" {
+ return nil, trace.BadParameter("missing app name")
+ }
+ if params.RouteToCluster != m.GetSiteName() {
+ return nil, trace.BadParameter("wrong cluster")
+ }
+ subject, err := m.identity.Subject()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ tlsKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.Ed25519)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ tlsCert, err := m.tlsCA.GenerateCertificate(tlsca.CertificateRequest{
+ Clock: m.clock,
+ PublicKey: tlsKey.Public(),
+ Subject: subject,
+ NotAfter: m.clock.Now().Add(cmp.Or(params.TTL, time.Minute)),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &KeyRing{
+ AppTLSCredentials: map[string]TLSCredential{
+ params.RouteToApp.Name: {
+ PrivateKey: tlsKey,
+ Cert: tlsCert,
+ },
+ },
+ }, nil
+}
+
+func (m *mockMCPServerDialerClient) ProfileStatus() (*ProfileStatus, error) {
+ return &ProfileStatus{}, nil
+}
+
+func (m *mockMCPServerDialerClient) GetSiteName() string {
+ return "teleport.example.com"
+}
+
+func TestMCPServerDialer(t *testing.T) {
+ tlsCA, _, err := newSelfSignedCA(CAPriv, "localhost")
+ require.NoError(t, err)
+
+ mockClient := &mockMCPServerDialerClient{
+ appServers: types.AppServers{
+ mustMakeAppServer(t, "http-app", "http://localhost:1234"),
+ mustMakeAppServer(t, "http-mcp", "mcp+http://localhost:1234"),
+ mustMakeAppServer(t, "sse-mcp", "mcp+sse+http://localhost:1234"),
+ },
+ clock: clockwork.NewFakeClock(),
+ tlsCA: tlsCA,
+ identity: tlsca.Identity{
+ Username: "test",
+ },
+ }
+
+ t.Run("GetApp", func(t *testing.T) {
+ tests := []struct {
+ name string
+ checkResult require.ErrorAssertionFunc
+ }{
+ {
+ name: "http-app",
+ checkResult: require.Error,
+ },
+ {
+ name: "http-mcp",
+ checkResult: require.NoError,
+ },
+ {
+ name: "sse-mcp",
+ checkResult: require.NoError,
+ },
+ {
+ name: "not-found",
+ checkResult: require.Error,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dialer := NewMCPServerDialer(mockClient, test.name)
+ _, err := dialer.GetApp(t.Context())
+ test.checkResult(t, err)
+ })
+ }
+ })
+
+ t.Run("DialALPN", func(t *testing.T) {
+ tests := []struct {
+ name string
+ wantALPN alpncommon.Protocol
+ }{
+ {
+ name: "http-mcp",
+ wantALPN: alpncommon.ProtocolHTTP,
+ },
+ {
+ name: "sse-mcp",
+ wantALPN: alpncommon.ProtocolMCP,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dialer := NewMCPServerDialer(mockClient, test.name)
+ dialer.clock = mockClient.clock
+
+ // Verify ALPN used.
+ firstConn, err := dialer.DialALPN(t.Context())
+ require.NoError(t, err)
+ firstALPNConn, ok := firstConn.(*mockALPNConn)
+ require.True(t, ok)
+ require.Equal(t, test.wantALPN, firstALPNConn.protocol)
+
+ // Advance time to trigger issue cert again.
+ mockClient.clock.Advance(time.Hour)
+ secondConn, err := dialer.DialALPN(t.Context())
+ require.NoError(t, err)
+ secondALPNConn, ok := secondConn.(*mockALPNConn)
+ require.True(t, ok)
+
+ // Double-check a new cert is issued.
+ firstSerial := firstALPNConn.getCertSerialNumber()
+ secondSerial := secondALPNConn.getCertSerialNumber()
+ require.NotEmpty(t, firstSerial)
+ require.NotEqual(t, firstSerial, secondSerial)
+ })
+ }
+ })
+}
+
+func mustMakeAppServer(t *testing.T, name, uri string) types.AppServer {
+ t.Helper()
+ app := mustMakeApp(t, name, uri)
+ appServer, err := types.NewAppServerV3FromApp(app, "test", "test")
+ require.NoError(t, err)
+ return appServer
+}
+
+func mustMakeApp(t *testing.T, name, uri string) *types.AppV3 {
+ t.Helper()
+ app, err := types.NewAppV3(
+ types.Metadata{Name: name},
+ types.AppSpecV3{URI: uri},
+ )
+ require.NoError(t, err)
+ return app
+}
diff --git a/lib/services/matchers.go b/lib/services/matchers.go
index d2f0329b8798f..24ec87329f3cb 100644
--- a/lib/services/matchers.go
+++ b/lib/services/matchers.go
@@ -26,6 +26,7 @@ import (
"github.com/gravitational/trace"
+ "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
azureutils "github.com/gravitational/teleport/api/utils/azure"
@@ -145,6 +146,20 @@ func (r *resourceWithTargetHealth) GetTargetHealthStatus() types.TargetHealthSta
// only applies to resource Application.
type ResourceSeenKey struct{ name, kind, addr string }
+// MatchResourcesByFilters filters provided resources with profiled filter.
+func MatchResourcesByFilters[E types.ResourceWithLabels, S ~[]E](all S, filter MatchResourceFilter) (S, error) {
+ var filtered S
+ for _, r := range all {
+ match, err := MatchResourceByFilters(r, filter, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ } else if match {
+ filtered = append(filtered, r)
+ }
+ }
+ return filtered, nil
+}
+
// MatchResourceByFilters returns true if all filter values given matched against the resource.
//
// If no filters were provided, we will treat that as a match.
@@ -315,3 +330,21 @@ func (m *MatchResourceFilter) IsSimple() bool {
m.PredicateExpression == nil &&
len(m.Kinds) == 0
}
+
+// MatchResourceFilterFromListResourceRequest converts a
+// proto.ListResourcesRequest to MatchResourceFilter.
+func MatchResourceFilterFromListResourceRequest(req *proto.ListResourcesRequest) (MatchResourceFilter, error) {
+ filter := MatchResourceFilter{
+ ResourceKind: req.ResourceType,
+ Labels: req.Labels,
+ SearchKeywords: req.SearchKeywords,
+ }
+ if req.PredicateExpression != "" {
+ expression, err := NewResourceExpression(req.PredicateExpression)
+ if err != nil {
+ return MatchResourceFilter{}, trace.Wrap(err)
+ }
+ filter.PredicateExpression = expression
+ }
+ return filter, nil
+}
diff --git a/lib/services/matchers_test.go b/lib/services/matchers_test.go
index 4328d1ad42ed5..d5d0867a1569c 100644
--- a/lib/services/matchers_test.go
+++ b/lib/services/matchers_test.go
@@ -749,6 +749,38 @@ func TestResourceMatchersToTypes(t *testing.T) {
}
}
+func TestMatchResourcesByFilters(t *testing.T) {
+ appServers := make(types.AppServers, 5)
+ oddOrEven := func(i int) string {
+ if i%2 == 1 {
+ return "odd"
+ }
+ return "even"
+ }
+ for i := range len(appServers) {
+ app, err := types.NewAppV3(types.Metadata{
+ Name: fmt.Sprintf("app-%d", i),
+ Labels: map[string]string{"group": oddOrEven(i)},
+ }, types.AppSpecV3{
+ URI: "http://localhost:8888",
+ })
+ require.NoError(t, err)
+ appServers[i] = newAppServerFromApp(t, app)
+ }
+
+ evenAppServers, err := MatchResourcesByFilters(appServers, MatchResourceFilter{
+ ResourceKind: types.KindAppServer,
+ Labels: map[string]string{"group": "even"},
+ })
+
+ require.NoError(t, err)
+ require.IsType(t, types.AppServers{}, evenAppServers)
+ require.Equal(t,
+ []string{"app-0", "app-2", "app-4"},
+ slices.Collect(types.ResourceNames(evenAppServers)),
+ )
+}
+
func newMCPServerApp(t *testing.T, name string) *types.AppV3 {
t.Helper()
app, err := types.NewAppV3(types.Metadata{
diff --git a/lib/utils/certs.go b/lib/utils/certs.go
index b18f1afb00cad..798ec7f2be9b0 100644
--- a/lib/utils/certs.go
+++ b/lib/utils/certs.go
@@ -142,6 +142,15 @@ func VerifyCertificateExpiry(c *x509.Certificate, clock clockwork.Clock) error {
return nil
}
+// VerifyTLSCertLeafExpiry checks a TLS certificate's expiration status.
+func VerifyTLSCertLeafExpiry(cert tls.Certificate, clock clockwork.Clock) error {
+ leaf, err := TLSCertLeaf(cert)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(VerifyCertificateExpiry(leaf, clock))
+}
+
// VerifyCertificateChain reads in chain of certificates and makes sure the
// chain from leaf to root is valid. This ensures that clients (web browsers
// and CLI) won't have problem validating the chain.
diff --git a/lib/utils/certs_test.go b/lib/utils/certs_test.go
index 672e8336d9af3..84e445b9fd430 100644
--- a/lib/utils/certs_test.go
+++ b/lib/utils/certs_test.go
@@ -19,13 +19,17 @@
package utils
import (
+ "crypto/tls"
"runtime"
"testing"
+ "time"
"github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/api/fixtures"
)
func TestRejectsInvalidPEMData(t *testing.T) {
@@ -60,3 +64,49 @@ func TestNewCertPoolFromPath(t *testing.T) {
//nolint:staticcheck // Pool not returned by SystemCertPool
require.Len(t, pool.Subjects(), 1)
}
+
+func TestVerifyTLSCertLeafExpiry(t *testing.T) {
+ tlsCert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
+ require.NoError(t, err)
+ emptyCert := tls.Certificate{}
+
+ tests := []struct {
+ name string
+ input tls.Certificate
+ fakeTime time.Time
+ checkResult require.ErrorAssertionFunc
+ }{
+ {
+ name: "empty",
+ input: emptyCert,
+ fakeTime: time.Now(),
+ checkResult: require.Error,
+ },
+ {
+ name: "valid",
+ input: tlsCert,
+ fakeTime: fixtures.TLSCACertNotAfter.Add(-time.Minute),
+ checkResult: require.NoError,
+ },
+ {
+ name: "not valid yet",
+ input: tlsCert,
+ fakeTime: fixtures.TLSCACertNotBefore.Add(-time.Minute),
+ checkResult: require.Error,
+ },
+ {
+ name: "expired",
+ input: tlsCert,
+ fakeTime: fixtures.TLSCACertNotAfter.Add(time.Minute),
+ checkResult: require.Error,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ clock := clockwork.NewFakeClockAt(tt.fakeTime)
+ err := VerifyTLSCertLeafExpiry(tt.input, clock)
+ tt.checkResult(t, err)
+ })
+ }
+}
diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go
index 8e4edaf57b82b..4796d5e676726 100644
--- a/tool/tsh/common/db_exec_test.go
+++ b/tool/tsh/common/db_exec_test.go
@@ -334,29 +334,11 @@ func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, re
}
func matchResources[R types.ResourceWithLabels](req *proto.ListResourcesRequest, s []R) ([]R, error) {
- filter := services.MatchResourceFilter{
- ResourceKind: req.ResourceType,
- Labels: req.Labels,
- SearchKeywords: req.SearchKeywords,
- }
- if req.PredicateExpression != "" {
- expression, err := services.NewResourceExpression(req.PredicateExpression)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- filter.PredicateExpression = expression
- }
-
- var filtered []R
- for _, r := range s {
- match, err := services.MatchResourceByFilters(r, filter, nil)
- if err != nil {
- return nil, trace.Wrap(err)
- } else if match {
- filtered = append(filtered, r)
- }
+ filter, err := services.MatchResourceFilterFromListResourceRequest(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
}
- return filtered, nil
+ return services.MatchResourcesByFilters(s, filter)
}
func mustMakeDatabaseServer(t *testing.T, db types.Database) types.DatabaseServer {
diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go
index 0cca38e389df5..41377f527bc4f 100644
--- a/tool/tsh/common/mcp_app.go
+++ b/tool/tsh/common/mcp_app.go
@@ -471,13 +471,14 @@ func (c *mcpConnectCommand) run() error {
}
tc.NonInteractive = true
+ dialer := client.NewMCPServerDialer(tc, c.cf.AppName)
if c.autoReconnect {
return clientmcp.ProxyStdioConnWithAutoReconnect(
c.cf.Context,
clientmcp.ProxyStdioConnWithAutoReconnectConfig{
ClientStdio: utils.CombinedStdio{},
DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) {
- conn, err := tc.DialMCPServer(ctx, c.cf.AppName)
+ conn, err := dialer.DialALPN(ctx)
return conn, trace.Wrap(err)
},
MakeReconnectUserMessage: makeMCPReconnectUserMessage,
@@ -485,7 +486,7 @@ func (c *mcpConnectCommand) run() error {
)
}
- serverConn, err := tc.DialMCPServer(c.cf.Context, c.cf.AppName)
+ serverConn, err := dialer.DialALPN(c.cf.Context)
if err != nil {
return trace.Wrap(err)
}