diff --git a/api/fixtures/fixtures.go b/api/fixtures/fixtures.go index 573e4d3c9f651..d888c4377e1ab 100644 --- a/api/fixtures/fixtures.go +++ b/api/fixtures/fixtures.go @@ -14,6 +14,10 @@ package fixtures +import ( + "time" +) + const ( TLSCACertPEM = `-----BEGIN CERTIFICATE----- MIIDKjCCAhKgAwIBAgIQJtJDJZZBkg/afM8d2ZJCTjANBgkqhkiG9w0BAQsFADBA @@ -62,3 +66,10 @@ LJxgC1GdoEz2ilXW802H9QrdKf9GPqxwi2TVzfO6pzWkdZcmbItu+QCCFz+co+r8 +ki49FmlfbR5YVPN+8X40aLQB4xDkCHwRwTkrigzWQhIOv8NAhDA -----END RSA PRIVATE KEY-----` ) + +var ( + // TLSCACertNotBefore is the "Not before" date of TLSCACertPEM. + TLSCACertNotBefore = time.Date(2017, time.May, 9, 19, 40, 36, 0, time.UTC) + // TLSCACertNotAfter is the "Not after" date of TLSCACertPEM. + TLSCACertNotAfter = time.Date(2027, time.May, 7, 19, 40, 36, 0, time.UTC) +) diff --git a/integration/appaccess/mcp_test.go b/integration/appaccess/mcp_test.go index cf513e33acf60..4bb56b4dcbe80 100644 --- a/integration/appaccess/mcp_test.go +++ b/integration/appaccess/mcp_test.go @@ -19,53 +19,53 @@ package appaccess import ( - "context" - "crypto/tls" - "fmt" "net/http" "testing" + "github.com/gravitational/trace" mcpclient "github.com/mark3labs/mcp-go/client" mcpclienttransport "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" libmcp "github.com/gravitational/teleport/lib/srv/mcp" "github.com/gravitational/teleport/lib/utils/mcptest" ) func testMCP(pack *Pack, t *testing.T) { - t.Run("DialMCPServer stdio no server found", func(t *testing.T) { + // SaveProfile before using the TeleportClient. + require.NoError(t, pack.tc.SaveProfile(false)) + + t.Run("stdio no server found", func(t *testing.T) { testMCPDialStdioNoServerFound(t, pack) }) - t.Run("DialMCPServer stdio success", func(t *testing.T) { + t.Run("stdio success", func(t *testing.T) { testMCPDialStdio(t, pack) }) - t.Run("DialMCPServer stdio to sse success", func(t *testing.T) { + t.Run("stdio to sse success", func(t *testing.T) { testMCPDialStdioToSSE(t, pack, "test-sse") }) - t.Run("proxy streamable HTTP requests with TLS cert", func(t *testing.T) { + t.Run("proxy streamable HTTP success", func(t *testing.T) { testMCPProxyStreamableHTTP(t, pack, "test-http") }) } func testMCPDialStdioNoServerFound(t *testing.T, pack *Pack) { - require.NoError(t, pack.tc.SaveProfile(false)) - - _, err := pack.tc.DialMCPServer(context.Background(), "not-found") + // Single connection dial for stdio. + dialer := client.NewMCPServerDialer(pack.tc, "not-found") + _, err := dialer.DialALPN(t.Context()) require.Error(t, err) + require.True(t, trace.IsNotFound(err)) } func testMCPDialStdio(t *testing.T, pack *Pack) { - require.NoError(t, pack.tc.SaveProfile(false)) - - serverConn, err := pack.tc.DialMCPServer(context.Background(), libmcp.DemoServerName) + // Single connection dial for stdio. + dialer := client.NewMCPServerDialer(pack.tc, libmcp.DemoServerName) + serverConn, err := dialer.DialALPN(t.Context()) require.NoError(t, err) ctx := t.Context() @@ -80,9 +80,9 @@ func testMCPDialStdio(t *testing.T, pack *Pack) { } func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) { - require.NoError(t, pack.tc.SaveProfile(false)) - - serverConn, err := pack.tc.DialMCPServer(context.Background(), appName) + // Single connection dial for stdio. + dialer := client.NewMCPServerDialer(pack.tc, appName) + serverConn, err := dialer.DialALPN(t.Context()) require.NoError(t, err) ctx := t.Context() @@ -95,38 +95,14 @@ func testMCPDialStdioToSSE(t *testing.T, pack *Pack, appName string) { } func testMCPProxyStreamableHTTP(t *testing.T, pack *Pack, appName string) { - require.NoError(t, pack.tc.SaveProfile(false)) - - // Find the MCP server. - filter := pack.tc.ResourceFilter(types.KindAppServer) - filter.PredicateExpression = fmt.Sprintf(`name == "%s"`, appName) - apps, err := pack.tc.ListApps(t.Context(), filter) - require.NoError(t, err) - require.Len(t, apps, 1) - - // Issue a TLS cert with app route. - keyRing, err := pack.tc.IssueUserCertsWithMFA(t.Context(), client.ReissueParams{ - RouteToCluster: pack.rootCluster.Secrets.SiteName, - RouteToApp: proto.RouteToApp{ - ClusterName: pack.rootCluster.Secrets.SiteName, - Name: apps[0].GetName(), - PublicAddr: apps[0].GetPublicAddr(), - }, - }) - require.NoError(t, err) - appCert, err := keyRing.AppTLSCert(appName) - require.NoError(t, err) - - // Create an MCP client with app cert. + // Use special dialer for HTTP client. ctx := t.Context() + dialer := client.NewMCPServerDialer(pack.tc, appName) mcpClientTransport, err := mcpclienttransport.NewStreamableHTTP( "https://"+pack.rootCluster.Web, mcpclienttransport.WithHTTPBasicClient(&http.Client{ Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{appCert}, - InsecureSkipVerify: true, - }, + DialTLSContext: dialer.DialContext, }, }), ) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 3a5fa017dc107..fb3af9cc04736 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1890,18 +1890,9 @@ func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResou // Perform the label/search/expr filtering here (instead of at the backend // `ListResources`) to ensure that it will be applied only to resources // the user has access to. - filter := services.MatchResourceFilter{ - ResourceKind: req.ResourceType, - Labels: req.Labels, - SearchKeywords: req.SearchKeywords, - } - - if req.PredicateExpression != "" { - expression, err := services.NewResourceExpression(req.PredicateExpression) - if err != nil { - return nil, trace.Wrap(err) - } - filter.PredicateExpression = expression + filter, err := services.MatchResourceFilterFromListResourceRequest(&req) + if err != nil { + return nil, trace.Wrap(err) } req.Labels = nil diff --git a/lib/cache/cache.go b/lib/cache/cache.go index d2a6b56136471..fbcd231f918dc 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -1672,18 +1672,9 @@ func (c *Cache) listResources(ctx context.Context, req authproto.ListResourcesRe _, span := c.Tracer.Start(ctx, "cache/listResources") defer span.End() - filter := services.MatchResourceFilter{ - ResourceKind: req.ResourceType, - Labels: req.Labels, - SearchKeywords: req.SearchKeywords, - } - - if req.PredicateExpression != "" { - expression, err := services.NewResourceExpression(req.PredicateExpression) - if err != nil { - return nil, trace.Wrap(err) - } - filter.PredicateExpression = expression + filter, err := services.MatchResourceFilterFromListResourceRequest(&req) + if err != nil { + return nil, trace.Wrap(err) } // Adjust page size, so it can't be empty. diff --git a/lib/client/api.go b/lib/client/api.go index 37abe98c5d5fe..c7a6f7f50f104 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -5537,78 +5537,6 @@ func (tc *TeleportClient) DialALPN(ctx context.Context, clientCert tls.Certifica return tlsConn, nil } -// DialMCPServer makes a connection to the remote MCP server. -func (tc *TeleportClient) DialMCPServer(ctx context.Context, appName string) (net.Conn, error) { - ctx, span := tc.Tracer.Start( - ctx, - "teleportClient/DialMCPServer", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - oteltrace.WithAttributes( - attribute.String("app", appName), - ), - ) - defer span.End() - - apps, err := tc.ListApps(ctx, &proto.ListResourcesRequest{ - ResourceType: types.KindAppServer, - Namespace: apidefaults.Namespace, - PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(appName)), - }) - if err != nil { - return nil, trace.Wrap(err) - } - switch len(apps) { - case 0: - return nil, trace.NotFound("no MCP servers found") - case 1: - default: - log.WarnContext(ctx, "multiple apps found, using the first one") - } - if !apps[0].IsMCP() { - return nil, trace.BadParameter("app %q is not a MCP server", appName) - } - - // TODO(greedy52) support streamable HTTP for "tsh mcp connect" before - // release. - if transport := types.GetMCPServerTransportType(apps[0].GetURI()); transport == types.MCPTransportHTTP { - return nil, trace.NotImplemented("MCP support for %s is not yet implemented", transport) - } - - cert, err := tc.issueMCPCertWithMFA(ctx, apps[0]) - if err != nil { - return nil, trace.Wrap(err) - } - return tc.DialALPN(ctx, cert, alpncommon.ProtocolMCP) -} - -func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) { - profile, err := tc.ProfileStatus() - if err != nil { - return tls.Certificate{}, trace.Wrap(err) - } - - appCertParams := ReissueParams{ - RouteToCluster: tc.SiteName, - RouteToApp: proto.RouteToApp{ - Name: mcpServer.GetName(), - PublicAddr: mcpServer.GetPublicAddr(), - ClusterName: tc.SiteName, - URI: mcpServer.GetURI(), - }, - AccessRequests: profile.ActiveRequests, - } - - // Do NOT write the keyring to avoid race condition when AI clients run - // multiple tsh at the same time. - keyRing, err := tc.IssueUserCertsWithMFA(ctx, appCertParams) - if err != nil { - return tls.Certificate{}, trace.Wrap(err) - } - - cert, err := keyRing.AppTLSCert(mcpServer.GetName()) - return cert, trace.Wrap(err) -} - // DialDatabase makes a remote connection to the database. // // TODO(gabrielcorado): support acccess requests connections. @@ -5648,6 +5576,11 @@ func (tc *TeleportClient) DialDatabase(ctx context.Context, route proto.RouteToD return tc.DialALPN(ctx, cert, alpnProtocol) } +// GetSiteName returns the cluster name this client instance is targeting. +func (tc *TeleportClient) GetSiteName() string { + return tc.SiteName +} + // CalculateSSHLogins returns the subset of the allowedLogins that exist in // the principals of the identity. This is required because SSH authorization // only allows using a login that exists in the certificates valid principals. diff --git a/lib/client/mcp.go b/lib/client/mcp.go new file mode 100644 index 0000000000000..f7ca3cfa0c0ad --- /dev/null +++ b/lib/client/mcp.go @@ -0,0 +1,178 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package client + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "strings" + "sync" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/utils" +) + +// MCPServerDialerClient defines a subset of TeleportClient functions that are +// used by MCPServerDialer. +type MCPServerDialerClient interface { + DialALPN(context.Context, tls.Certificate, alpncommon.Protocol) (net.Conn, error) + ListApps(context.Context, *proto.ListResourcesRequest) ([]types.Application, error) + IssueUserCertsWithMFA(context.Context, ReissueParams) (*KeyRing, error) + ProfileStatus() (*ProfileStatus, error) + GetSiteName() string +} + +// MCPServerDialer is a wrapper of TeleportClient for handling MCP connections +// to proxy. +type MCPServerDialer struct { + client MCPServerDialerClient + appName string + + mu sync.Mutex + app types.Application + cert tls.Certificate + clock clockwork.Clock + logger *slog.Logger +} + +// NewMCPServerDialer creates a new MCPServerDialer. +func NewMCPServerDialer(client MCPServerDialerClient, appName string) *MCPServerDialer { + return &MCPServerDialer{ + client: client, + appName: appName, + clock: clockwork.NewRealClock(), + logger: slog.With( + teleport.ComponentKey, + teleport.Component(teleport.ComponentMCP, "dialer"), + ), + } +} + +// GetApp returns the types.Application for the associated MCP server. +func (d *MCPServerDialer) GetApp(ctx context.Context) (types.Application, error) { + d.mu.Lock() + defer d.mu.Unlock() + return d.getAppLocked(ctx) +} + +// DialALPN dials Teleport Proxy to establish a TLS routing connection for the +// MCP server. +func (d *MCPServerDialer) DialALPN(ctx context.Context) (net.Conn, error) { + app, err := d.getAppLocked(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + cert, err := d.getCertLocked(ctx, app) + if err != nil { + return nil, trace.Wrap(err) + } + switch types.GetMCPServerTransportType(app.GetURI()) { + case types.MCPTransportHTTP: + return d.client.DialALPN(ctx, cert, alpncommon.ProtocolHTTP) + default: + return d.client.DialALPN(ctx, cert, alpncommon.ProtocolMCP) + } +} + +// DialContext is a simple wrapper of DialALPN. This function is defined to be +// compatible with common context dialer interfaces. +func (d *MCPServerDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + return d.DialALPN(ctx) +} + +func (d *MCPServerDialer) getAppLocked(ctx context.Context) (types.Application, error) { + if d.app != nil { + return d.app, nil + } + + apps, err := d.client.ListApps(ctx, &proto.ListResourcesRequest{ + ResourceType: types.KindAppServer, + Namespace: apidefaults.Namespace, + PredicateExpression: fmt.Sprintf("name == %q", strings.TrimSpace(d.appName)), + }) + if err != nil { + return nil, trace.Wrap(err) + } + switch len(apps) { + case 0: + return nil, trace.NotFound("no MCP servers found") + case 1: + default: + d.logger.WarnContext(ctx, "multiple appServers found, using the first one") + } + if !apps[0].IsMCP() { + return nil, trace.BadParameter("app %q is not a MCP server", d.appName) + } + + d.app = apps[0] + d.logger.InfoContext(ctx, "Successfully fetched app", + "name", d.app.GetName(), + "transport", types.GetMCPServerTransportType(d.app.GetURI()), + ) + return d.app, nil +} + +func (d *MCPServerDialer) getCertLocked(ctx context.Context, mcpServer types.Application) (tls.Certificate, error) { + if err := utils.VerifyTLSCertLeafExpiry(d.cert, d.clock); err == nil { + return d.cert, nil + } + + d.logger.InfoContext(ctx, "Reissuing certificate", "name", mcpServer.GetName()) + profile, err := d.client.ProfileStatus() + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + appCertParams := ReissueParams{ + RouteToCluster: d.client.GetSiteName(), + RouteToApp: proto.RouteToApp{ + Name: mcpServer.GetName(), + PublicAddr: mcpServer.GetPublicAddr(), + ClusterName: d.client.GetSiteName(), + URI: mcpServer.GetURI(), + }, + AccessRequests: profile.ActiveRequests, + } + + // Do NOT write the keyring to avoid race condition when AI clients run + // multiple tsh at the same time. + keyRing, err := d.client.IssueUserCertsWithMFA(ctx, appCertParams) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + cert, err := keyRing.AppTLSCert(mcpServer.GetName()) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + d.logger.InfoContext(ctx, "Successfully issued certificate", "name", mcpServer.GetName()) + d.cert = cert + return d.cert, nil +} diff --git a/lib/client/mcp_test.go b/lib/client/mcp_test.go new file mode 100644 index 0000000000000..0a6f82fa5e9c1 --- /dev/null +++ b/lib/client/mcp_test.go @@ -0,0 +1,240 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package client + +import ( + "cmp" + "context" + "crypto/tls" + "net" + "slices" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/services" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +type mockALPNConn struct { + net.Conn + cert tls.Certificate + protocol alpncommon.Protocol +} + +func newFakeALPNConn(protocol alpncommon.Protocol, cert tls.Certificate) *mockALPNConn { + return &mockALPNConn{ + Conn: nil, // currently not using the net.conn + cert: cert, + protocol: protocol, + } +} + +func (c *mockALPNConn) getCertSerialNumber() string { + leaf, err := utils.TLSCertLeaf(c.cert) + if err != nil { + return "" + } + return leaf.SerialNumber.String() +} + +type mockMCPServerDialerClient struct { + appServers types.AppServers + tlsCA *tlsca.CertAuthority + clock *clockwork.FakeClock + identity tlsca.Identity +} + +func (m *mockMCPServerDialerClient) DialALPN(_ context.Context, cert tls.Certificate, protocol alpncommon.Protocol) (net.Conn, error) { + if err := utils.VerifyTLSCertLeafExpiry(cert, m.clock); err != nil { + return nil, trace.Wrap(err) + } + return newFakeALPNConn(protocol, cert), nil +} + +func (m *mockMCPServerDialerClient) ListApps(_ context.Context, req *proto.ListResourcesRequest) ([]types.Application, error) { + filter, err := services.MatchResourceFilterFromListResourceRequest(req) + if err != nil { + return nil, trace.Wrap(err) + } + appServers, err := services.MatchResourcesByFilters(m.appServers, filter) + if err != nil { + return nil, trace.Wrap(err) + } + return slices.Collect(appServers.Applications()), nil +} + +func (m *mockMCPServerDialerClient) IssueUserCertsWithMFA(_ context.Context, params ReissueParams) (*KeyRing, error) { + if params.RouteToApp.Name == "" { + return nil, trace.BadParameter("missing app name") + } + if params.RouteToCluster != m.GetSiteName() { + return nil, trace.BadParameter("wrong cluster") + } + subject, err := m.identity.Subject() + if err != nil { + return nil, trace.Wrap(err) + } + tlsKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.Ed25519) + if err != nil { + return nil, trace.Wrap(err) + } + tlsCert, err := m.tlsCA.GenerateCertificate(tlsca.CertificateRequest{ + Clock: m.clock, + PublicKey: tlsKey.Public(), + Subject: subject, + NotAfter: m.clock.Now().Add(cmp.Or(params.TTL, time.Minute)), + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &KeyRing{ + AppTLSCredentials: map[string]TLSCredential{ + params.RouteToApp.Name: { + PrivateKey: tlsKey, + Cert: tlsCert, + }, + }, + }, nil +} + +func (m *mockMCPServerDialerClient) ProfileStatus() (*ProfileStatus, error) { + return &ProfileStatus{}, nil +} + +func (m *mockMCPServerDialerClient) GetSiteName() string { + return "teleport.example.com" +} + +func TestMCPServerDialer(t *testing.T) { + tlsCA, _, err := newSelfSignedCA(CAPriv, "localhost") + require.NoError(t, err) + + mockClient := &mockMCPServerDialerClient{ + appServers: types.AppServers{ + mustMakeAppServer(t, "http-app", "http://localhost:1234"), + mustMakeAppServer(t, "http-mcp", "mcp+http://localhost:1234"), + mustMakeAppServer(t, "sse-mcp", "mcp+sse+http://localhost:1234"), + }, + clock: clockwork.NewFakeClock(), + tlsCA: tlsCA, + identity: tlsca.Identity{ + Username: "test", + }, + } + + t.Run("GetApp", func(t *testing.T) { + tests := []struct { + name string + checkResult require.ErrorAssertionFunc + }{ + { + name: "http-app", + checkResult: require.Error, + }, + { + name: "http-mcp", + checkResult: require.NoError, + }, + { + name: "sse-mcp", + checkResult: require.NoError, + }, + { + name: "not-found", + checkResult: require.Error, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialer := NewMCPServerDialer(mockClient, test.name) + _, err := dialer.GetApp(t.Context()) + test.checkResult(t, err) + }) + } + }) + + t.Run("DialALPN", func(t *testing.T) { + tests := []struct { + name string + wantALPN alpncommon.Protocol + }{ + { + name: "http-mcp", + wantALPN: alpncommon.ProtocolHTTP, + }, + { + name: "sse-mcp", + wantALPN: alpncommon.ProtocolMCP, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialer := NewMCPServerDialer(mockClient, test.name) + dialer.clock = mockClient.clock + + // Verify ALPN used. + firstConn, err := dialer.DialALPN(t.Context()) + require.NoError(t, err) + firstALPNConn, ok := firstConn.(*mockALPNConn) + require.True(t, ok) + require.Equal(t, test.wantALPN, firstALPNConn.protocol) + + // Advance time to trigger issue cert again. + mockClient.clock.Advance(time.Hour) + secondConn, err := dialer.DialALPN(t.Context()) + require.NoError(t, err) + secondALPNConn, ok := secondConn.(*mockALPNConn) + require.True(t, ok) + + // Double-check a new cert is issued. + firstSerial := firstALPNConn.getCertSerialNumber() + secondSerial := secondALPNConn.getCertSerialNumber() + require.NotEmpty(t, firstSerial) + require.NotEqual(t, firstSerial, secondSerial) + }) + } + }) +} + +func mustMakeAppServer(t *testing.T, name, uri string) types.AppServer { + t.Helper() + app := mustMakeApp(t, name, uri) + appServer, err := types.NewAppServerV3FromApp(app, "test", "test") + require.NoError(t, err) + return appServer +} + +func mustMakeApp(t *testing.T, name, uri string) *types.AppV3 { + t.Helper() + app, err := types.NewAppV3( + types.Metadata{Name: name}, + types.AppSpecV3{URI: uri}, + ) + require.NoError(t, err) + return app +} diff --git a/lib/services/matchers.go b/lib/services/matchers.go index d2f0329b8798f..24ec87329f3cb 100644 --- a/lib/services/matchers.go +++ b/lib/services/matchers.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" azureutils "github.com/gravitational/teleport/api/utils/azure" @@ -145,6 +146,20 @@ func (r *resourceWithTargetHealth) GetTargetHealthStatus() types.TargetHealthSta // only applies to resource Application. type ResourceSeenKey struct{ name, kind, addr string } +// MatchResourcesByFilters filters provided resources with profiled filter. +func MatchResourcesByFilters[E types.ResourceWithLabels, S ~[]E](all S, filter MatchResourceFilter) (S, error) { + var filtered S + for _, r := range all { + match, err := MatchResourceByFilters(r, filter, nil) + if err != nil { + return nil, trace.Wrap(err) + } else if match { + filtered = append(filtered, r) + } + } + return filtered, nil +} + // MatchResourceByFilters returns true if all filter values given matched against the resource. // // If no filters were provided, we will treat that as a match. @@ -315,3 +330,21 @@ func (m *MatchResourceFilter) IsSimple() bool { m.PredicateExpression == nil && len(m.Kinds) == 0 } + +// MatchResourceFilterFromListResourceRequest converts a +// proto.ListResourcesRequest to MatchResourceFilter. +func MatchResourceFilterFromListResourceRequest(req *proto.ListResourcesRequest) (MatchResourceFilter, error) { + filter := MatchResourceFilter{ + ResourceKind: req.ResourceType, + Labels: req.Labels, + SearchKeywords: req.SearchKeywords, + } + if req.PredicateExpression != "" { + expression, err := NewResourceExpression(req.PredicateExpression) + if err != nil { + return MatchResourceFilter{}, trace.Wrap(err) + } + filter.PredicateExpression = expression + } + return filter, nil +} diff --git a/lib/services/matchers_test.go b/lib/services/matchers_test.go index 4328d1ad42ed5..d5d0867a1569c 100644 --- a/lib/services/matchers_test.go +++ b/lib/services/matchers_test.go @@ -749,6 +749,38 @@ func TestResourceMatchersToTypes(t *testing.T) { } } +func TestMatchResourcesByFilters(t *testing.T) { + appServers := make(types.AppServers, 5) + oddOrEven := func(i int) string { + if i%2 == 1 { + return "odd" + } + return "even" + } + for i := range len(appServers) { + app, err := types.NewAppV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + Labels: map[string]string{"group": oddOrEven(i)}, + }, types.AppSpecV3{ + URI: "http://localhost:8888", + }) + require.NoError(t, err) + appServers[i] = newAppServerFromApp(t, app) + } + + evenAppServers, err := MatchResourcesByFilters(appServers, MatchResourceFilter{ + ResourceKind: types.KindAppServer, + Labels: map[string]string{"group": "even"}, + }) + + require.NoError(t, err) + require.IsType(t, types.AppServers{}, evenAppServers) + require.Equal(t, + []string{"app-0", "app-2", "app-4"}, + slices.Collect(types.ResourceNames(evenAppServers)), + ) +} + func newMCPServerApp(t *testing.T, name string) *types.AppV3 { t.Helper() app, err := types.NewAppV3(types.Metadata{ diff --git a/lib/utils/certs.go b/lib/utils/certs.go index b18f1afb00cad..798ec7f2be9b0 100644 --- a/lib/utils/certs.go +++ b/lib/utils/certs.go @@ -142,6 +142,15 @@ func VerifyCertificateExpiry(c *x509.Certificate, clock clockwork.Clock) error { return nil } +// VerifyTLSCertLeafExpiry checks a TLS certificate's expiration status. +func VerifyTLSCertLeafExpiry(cert tls.Certificate, clock clockwork.Clock) error { + leaf, err := TLSCertLeaf(cert) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(VerifyCertificateExpiry(leaf, clock)) +} + // VerifyCertificateChain reads in chain of certificates and makes sure the // chain from leaf to root is valid. This ensures that clients (web browsers // and CLI) won't have problem validating the chain. diff --git a/lib/utils/certs_test.go b/lib/utils/certs_test.go index 672e8336d9af3..84e445b9fd430 100644 --- a/lib/utils/certs_test.go +++ b/lib/utils/certs_test.go @@ -19,13 +19,17 @@ package utils import ( + "crypto/tls" "runtime" "testing" + "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/fixtures" ) func TestRejectsInvalidPEMData(t *testing.T) { @@ -60,3 +64,49 @@ func TestNewCertPoolFromPath(t *testing.T) { //nolint:staticcheck // Pool not returned by SystemCertPool require.Len(t, pool.Subjects(), 1) } + +func TestVerifyTLSCertLeafExpiry(t *testing.T) { + tlsCert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) + require.NoError(t, err) + emptyCert := tls.Certificate{} + + tests := []struct { + name string + input tls.Certificate + fakeTime time.Time + checkResult require.ErrorAssertionFunc + }{ + { + name: "empty", + input: emptyCert, + fakeTime: time.Now(), + checkResult: require.Error, + }, + { + name: "valid", + input: tlsCert, + fakeTime: fixtures.TLSCACertNotAfter.Add(-time.Minute), + checkResult: require.NoError, + }, + { + name: "not valid yet", + input: tlsCert, + fakeTime: fixtures.TLSCACertNotBefore.Add(-time.Minute), + checkResult: require.Error, + }, + { + name: "expired", + input: tlsCert, + fakeTime: fixtures.TLSCACertNotAfter.Add(time.Minute), + checkResult: require.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clock := clockwork.NewFakeClockAt(tt.fakeTime) + err := VerifyTLSCertLeafExpiry(tt.input, clock) + tt.checkResult(t, err) + }) + } +} diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go index 8e4edaf57b82b..4796d5e676726 100644 --- a/tool/tsh/common/db_exec_test.go +++ b/tool/tsh/common/db_exec_test.go @@ -334,29 +334,11 @@ func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, re } func matchResources[R types.ResourceWithLabels](req *proto.ListResourcesRequest, s []R) ([]R, error) { - filter := services.MatchResourceFilter{ - ResourceKind: req.ResourceType, - Labels: req.Labels, - SearchKeywords: req.SearchKeywords, - } - if req.PredicateExpression != "" { - expression, err := services.NewResourceExpression(req.PredicateExpression) - if err != nil { - return nil, trace.Wrap(err) - } - filter.PredicateExpression = expression - } - - var filtered []R - for _, r := range s { - match, err := services.MatchResourceByFilters(r, filter, nil) - if err != nil { - return nil, trace.Wrap(err) - } else if match { - filtered = append(filtered, r) - } + filter, err := services.MatchResourceFilterFromListResourceRequest(req) + if err != nil { + return nil, trace.Wrap(err) } - return filtered, nil + return services.MatchResourcesByFilters(s, filter) } func mustMakeDatabaseServer(t *testing.T, db types.Database) types.DatabaseServer { diff --git a/tool/tsh/common/mcp_app.go b/tool/tsh/common/mcp_app.go index 0cca38e389df5..41377f527bc4f 100644 --- a/tool/tsh/common/mcp_app.go +++ b/tool/tsh/common/mcp_app.go @@ -471,13 +471,14 @@ func (c *mcpConnectCommand) run() error { } tc.NonInteractive = true + dialer := client.NewMCPServerDialer(tc, c.cf.AppName) if c.autoReconnect { return clientmcp.ProxyStdioConnWithAutoReconnect( c.cf.Context, clientmcp.ProxyStdioConnWithAutoReconnectConfig{ ClientStdio: utils.CombinedStdio{}, DialServer: func(ctx context.Context) (io.ReadWriteCloser, error) { - conn, err := tc.DialMCPServer(ctx, c.cf.AppName) + conn, err := dialer.DialALPN(ctx) return conn, trace.Wrap(err) }, MakeReconnectUserMessage: makeMCPReconnectUserMessage, @@ -485,7 +486,7 @@ func (c *mcpConnectCommand) run() error { ) } - serverConn, err := tc.DialMCPServer(c.cf.Context, c.cf.AppName) + serverConn, err := dialer.DialALPN(c.cf.Context) if err != nil { return trace.Wrap(err) }