From 96d8fcd925f4603209f389d9ad0c60648f85d21e Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Thu, 19 Feb 2026 12:37:21 -0500 Subject: [PATCH 1/9] Use DatabaseServerWatcher for filtered db server lookups in proxy Replace the call to CachingAccessPoint.GetDatabaseServers in ProxyServer.Authorize with a DatabaseServerWatcher lookup using CurrentResourcesWithFilter. Previously, every inbound database connection allocated and iterated a full copy of all database servers in the cache, then discarded all but the matching entries. Under high concurrency with large numbers of registered databases this caused significant GC pressure and OOM. The watcher lookup only allocates the servers matching the requested database name. The watcher is initialized once at startup in service.go and plumbed through the Cluster interface, following the same pattern as AppServerWatcher. --- lib/reversetunnel/leaf_cluster.go | 8 ++ lib/reversetunnel/local_cluster.go | 5 + lib/reversetunnel/peer.go | 12 ++ lib/reversetunnel/srv.go | 18 +++ lib/reversetunnelclient/api.go | 2 + lib/reversetunnelclient/fake.go | 23 +++- lib/service/service.go | 12 ++ lib/services/presence.go | 3 +- lib/services/readonly/readonly.go | 38 ++++++ lib/services/watcher.go | 49 +++++++ lib/services/watcher_test.go | 102 +++++++++++++++ lib/srv/db/common/connect/connect.go | 34 ++--- .../db/common/connect/connect_bench_test.go | 121 ++++++++++++++++++ lib/srv/db/common/connect/connect_test.go | 65 +++++++--- lib/srv/db/proxyserver.go | 10 +- lib/srv/regular/sshserver_test.go | 20 +++ lib/web/apiserver_test.go | 20 +++ 17 files changed, 491 insertions(+), 51 deletions(-) create mode 100644 lib/srv/db/common/connect/connect_bench_test.go diff --git a/lib/reversetunnel/leaf_cluster.go b/lib/reversetunnel/leaf_cluster.go index 604b95f0f9553..15e11ccddf471 100644 --- a/lib/reversetunnel/leaf_cluster.go +++ b/lib/reversetunnel/leaf_cluster.go @@ -92,6 +92,9 @@ type leafCluster struct { // appServerWatcher is a app server watcher. appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // databaseServerWatcher is a database server watcher. + databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] + // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation // state has been changed, the tunnel will reconnect to re-create the client @@ -182,6 +185,11 @@ func (s *leafCluster) GitServerWatcher() (*services.GenericWatcher[types.Server, return nil, trace.NotImplemented("GitServerWatcher not implemented for leafCluster") } +// DatabaseServerWatcher returns the Database server watcher for the leaf cluster. +func (s *leafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.databaseServerWatcher, nil +} + func (s *leafCluster) GetClient() (authclient.ClientI, error) { return s.leafClient, nil } diff --git a/lib/reversetunnel/local_cluster.go b/lib/reversetunnel/local_cluster.go index 6f9cf897d4bf3..6e289c82c85b8 100644 --- a/lib/reversetunnel/local_cluster.go +++ b/lib/reversetunnel/local_cluster.go @@ -193,6 +193,11 @@ func (s *localCluster) GitServerWatcher() (*services.GenericWatcher[types.Server return s.srv.GitServerWatcher, nil } +// DatabaseServerWatcher returns a Database server watcher for this cluster. +func (s *localCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.srv.DatabaseServerWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localCluster) GetClient() (authclient.ClientI, error) { return s.client, nil diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 3c351e1af612b..d3abcaec338f9 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -114,6 +114,14 @@ func (p *expectedLeafClusters) GitServerWatcher() (*services.GenericWatcher[type return cluster.GitServerWatcher() } +func (p *expectedLeafClusters) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + cluster, err := p.pickCluster() + if err != nil { + return nil, trace.Wrap(err) + } + return cluster.DatabaseServerWatcher() +} + func (p *expectedLeafClusters) GetClient() (authclient.ClientI, error) { cluster, err := p.pickCluster() if err != nil { @@ -227,6 +235,10 @@ func (s *expectedLeafCluster) GitServerWatcher() (*services.GenericWatcher[types return nil, s.discoveryError("unable to fetch git server watcher for leaf cluster") } +func (s *expectedLeafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return nil, s.discoveryError("unable to fetch database server watcher for leaf cluster") +} + func (s *expectedLeafCluster) GetClient() (authclient.ClientI, error) { return nil, s.discoveryError("unable to fetch auth client for leaf cluster") } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index bc05edeff5829..54080baca3cc2 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -226,6 +226,9 @@ type Config struct { // AppServerWatcher is a app server watcher. AppServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // DatabaseServerWatcher is a database server watcher. + DatabaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] + // CircuitBreakerConfig configures the auth client circuit breaker CircuitBreakerConfig breaker.Config @@ -305,6 +308,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.AppServerWatcher == nil { return trace.BadParameter("missing parameter AppServerWatcher") } + if cfg.DatabaseServerWatcher == nil { + return trace.BadParameter("missing parameter DatabaseServerWatcher") + } if cfg.EICEDialer == nil { return trace.BadParameter("missing parameter EICEDialer") @@ -1298,6 +1304,18 @@ func newLeafCluster(srv *server, domainName string, sconn ssh.Conn) (*leafCluste } leaf.appServerWatcher = appServerWatcher + databaseServerWatcher, err := services.NewDatabaseServerWatcher(closeContext, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: srv.Component, + Logger: srv.Logger, + Client: accessPoint, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + leaf.databaseServerWatcher = databaseServerWatcher + // instantiate a cache of host certificates for the forwarding server. the // certificate cache is created in each cluster (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index 4c7ad5e9c0f30..5803770ed7f06 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -133,6 +133,8 @@ type Cluster interface { GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) // AppServerWatcher returns the watcher that maintains the app server set for the cluster AppServerWatcher() (*services.GenericWatcher[types.AppServer, readonly.AppServer], error) + // DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster + DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/reversetunnelclient/fake.go b/lib/reversetunnelclient/fake.go index c4f24ed04e302..efc3d954bc91f 100644 --- a/lib/reversetunnelclient/fake.go +++ b/lib/reversetunnelclient/fake.go @@ -73,6 +73,8 @@ type FakeCluster struct { closed bool // appServerWatcher ia a app server watcher to speed up app look up. appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // databaseServerWatcher is a database server watcher to speed up database server look up. + databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] } // NewFakeCluster is a FakeCluster constructor. @@ -84,11 +86,19 @@ func NewFakeCluster(clusterName string, accessPoint authclient.RemoteProxyAccess }, }) + databaseServerWatcher, _ := services.NewDatabaseServerWatcher(context.TODO(), services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "FakeCluster", + Client: accessPoint, + }, + }) + return &FakeCluster{ - Name: clusterName, - connCh: make(chan net.Conn), - AccessPoint: accessPoint, - appServerWatcher: appServerWatcher, + Name: clusterName, + connCh: make(chan net.Conn), + AccessPoint: accessPoint, + appServerWatcher: appServerWatcher, + databaseServerWatcher: databaseServerWatcher, } } @@ -97,6 +107,11 @@ func (s *FakeCluster) AppServerWatcher() (*services.GenericWatcher[types.AppServ return s.appServerWatcher, nil } +// DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster +func (s *FakeCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.databaseServerWatcher, nil +} + // CachingAccessPoint returns caching auth server client. func (s *FakeCluster) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) { return s.AccessPoint, nil diff --git a/lib/service/service.go b/lib/service/service.go index 95eb933aaa0cf..ba39ac400f74a 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4974,6 +4974,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + databaseServerWatcher, err := services.NewDatabaseServerWatcher(process.ExitContext(), services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentProxy), + Client: accessPoint, + }, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := process.ServerTLSConfig(conn) if err != nil { return trace.Wrap(err) @@ -5246,6 +5257,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { NodeWatcher: nodeWatcher, AppServerWatcher: appServerWatcher, GitServerWatcher: gitServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: process.Config.CircuitBreakerConfig, LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServerAddresses()), diff --git a/lib/services/presence.go b/lib/services/presence.go index a5fbcde41079c..2dee0f598c8ca 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -48,7 +48,8 @@ type NodesGetter interface { // DatabaseServersGetter is a service that gets database servers. type DatabaseServersGetter interface { - GetDatabaseServers(context.Context, string, ...MarshalOption) ([]types.DatabaseServer, error) + // GetDatabaseServers returns all registered database proxy servers. + GetDatabaseServers(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.DatabaseServer, error) } // AppServersGetter is a service that gets application servers. diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index 3100073cb03db..0cb5ce357e6f1 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -292,6 +292,44 @@ type AppServer interface { var _ AppServer = types.AppServer(nil) +// DatabaseServer is a read only variant of [types.DatabaseServer] +type DatabaseServer interface { + // ResourceWithLabels provides common resource methods. + ResourceWithLabels + // GetNamespace returns server namespace. + GetNamespace() string + // GetTeleportVersion returns the teleport version the server is running on. + GetTeleportVersion() string + // GetHostname returns the server hostname. + GetHostname() string + // GetHostID returns ID of the host the server is running on. + GetHostID() string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // String returns string representation of the server. + String() string + // Copy returns a copy of this database server object. + Copy() types.DatabaseServer + // GetDatabase returns the database this database server proxies. + GetDatabase() types.Database + // ProxiedService provides common methods for a proxied service. + ProxiedService + // GetRelayGroup returns the name of the Relay group that the database + // server is connected to. + GetRelayGroup() string + // GetRelayIDs returns the list of Relay host IDs that the database server + // is connected to. + GetRelayIDs() []string + // GetTargetHealth returns the database server's target health. + GetTargetHealth() types.TargetHealth + // GetTargetHealthStatus returns target health status + GetTargetHealthStatus() types.TargetHealthStatus + // GetScope returns the scope this server belongs to. + GetScope() string +} + +var _ DatabaseServer = types.DatabaseServer(nil) + // KubeServer is a read only variant of [types.KubeServer]. type KubeServer interface { // ResourceWithLabels provides common resource methods. diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 982f74fa317d3..9783ad32201c9 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -546,6 +546,55 @@ func NewAppServersWatcher(ctx context.Context, cfg AppServersWatcherConfig) (*Ge return w, trace.Wrap(err) } +type DatabaseServerWatcherConfig struct { + DatabaseServersGetter + ResourceWatcherConfig +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *DatabaseServerWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + if cfg.MaxStaleness == 0 { + const databaseServerMaxStaleness = time.Minute + cfg.MaxStaleness = databaseServerMaxStaleness + } + + if cfg.DatabaseServersGetter == nil { + getter, ok := cfg.Client.(DatabaseServersGetter) + if !ok { + return trace.BadParameter("missing parameter DatabaseServersGetter and Client not usable as DatabaseServersGetter") + } + cfg.DatabaseServersGetter = getter + } + + return nil +} + +func NewDatabaseServerWatcher(ctx context.Context, cfg DatabaseServerWatcherConfig) (*GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.DatabaseServer, readonly.DatabaseServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindDatabaseServer, + ResourceKey: func(r types.DatabaseServer) string { return r.GetHostID() + r.GetName() }, + ResourceGetter: func(ctx context.Context) ([]types.DatabaseServer, error) { + return cfg.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) + }, + DisableUpdateBroadcast: true, + CloneFunc: types.DatabaseServer.Copy, + ReadOnlyFunc: func(resource types.DatabaseServer) readonly.DatabaseServer { + return resource + }, + }) + + return w, trace.Wrap(err) +} + // KubeServerWatcherConfig is an KubeServerWatcher configuration. type KubeServerWatcherConfig struct { // KubernetesServerGetter is responsible for fetching kube_server resources. diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 3ef18656481d9..165915824c22a 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -757,6 +757,108 @@ func newApp(t *testing.T, name string) *types.AppV3 { return app } +// TestDatabaseServerWatcher tests that database server resource watcher properly +// receives and dispatches updates. +func TestDatabaseServerWatcher(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + ctx := t.Context() + + bk, err := memory.New(memory.Config{Context: ctx}) + require.NoError(t, err) + + type client struct { + services.DatabaseServersGetter + types.Events + } + + presenceService := local.NewPresenceService(bk) + w, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + DatabaseServersGetter: presenceService, + Events: local.NewEventsService(bk), + }, + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + // Wait for initial load. + require.NoError(t, w.WaitInitialization()) + + // Initially there are no database servers. + servers, err := w.CurrentResources(ctx) + require.NoError(t, err) + require.Empty(t, servers) + + // Add a database server and wait for the watcher to process the event. + server1 := newDatabaseServer(t, "db1", "host1") + _, err = presenceService.UpsertDatabaseServer(ctx, server1) + require.NoError(t, err) + synctest.Wait() + + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabase().GetName() == "db1" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server1.GetName(), servers[0].GetName()) + + // Add a second database server and wait for the watcher to process the event. + server2 := newDatabaseServer(t, "db2", "host2") + _, err = presenceService.UpsertDatabaseServer(ctx, server2) + require.NoError(t, err) + synctest.Wait() + + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 2) + + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabase().GetName() == "db2" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) + + // Delete the first database server and wait for the watcher to process the event. + err = presenceService.DeleteDatabaseServer(ctx, apidefaults.Namespace, server1.GetHostID(), server1.GetName()) + require.NoError(t, err) + synctest.Wait() + + // Verify the remaining server is server2. + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) + }) +} + +func newDatabaseServer(t *testing.T, dbName, hostID string) types.DatabaseServer { + t.Helper() + server, err := types.NewDatabaseServerV3(types.Metadata{ + Name: dbName, + }, types.DatabaseServerSpecV3{ + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: dbName, + }, + Spec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }, + }, + HostID: hostID, + Hostname: dbName, + }) + require.NoError(t, err) + return server +} + func TestCertAuthorityWatcher(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/common/connect/connect.go b/lib/srv/db/common/connect/connect.go index 1be2ee49ab776..5f6c2734c2cb3 100644 --- a/lib/srv/db/common/connect/connect.go +++ b/lib/srv/db/common/connect/connect.go @@ -31,22 +31,20 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" ) -// DatabaseServersGetter is an interface for retrieving information about -// database proxy servers within a specific namespace. -type DatabaseServersGetter interface { - // GetDatabaseServers returns all registered database proxy servers. - GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) +// DatabaseServerWatcher defines an interface for watching database servers in a cluster. +type DatabaseServerWatcher interface { + // CurrentResourcesWithFilter returns the current list of database servers in the cluster that match the provided filter function. + CurrentResourcesWithFilter(ctx context.Context, filter func(readonly.DatabaseServer) bool) ([]types.DatabaseServer, error) } // GetDatabaseServersParams contains the parameters required to retrieve @@ -55,30 +53,24 @@ type GetDatabaseServersParams struct { Logger *slog.Logger // ClusterName is the cluster name to which the database belongs. ClusterName string - // DatabaseServersGetter used to fetch the list of database servers. - DatabaseServersGetter DatabaseServersGetter + // Watcher is used to retrieve database servers registered in the cluster. + Watcher DatabaseServerWatcher // Identity contains the identity information. Identity tlsca.Identity } // GetDatabaseServers returns a list of database servers in a cluster that match -// the routing information from the provided identity. +// the routing information from the provided identity. It uses the cluster's +// DatabaseServerWatcher for fast in-memory lookup. func GetDatabaseServers(ctx context.Context, params GetDatabaseServersParams) ([]types.DatabaseServer, error) { - servers, err := params.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) + result, err := params.Watcher.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName + }) if err != nil { return nil, trace.Wrap(err) } - // Find out which database servers proxy the database a user is - // connecting to using routing information from identity. - var result []types.DatabaseServer - for _, server := range servers { - if server.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName { - result = append(result, server) - } - } - - params.Logger.DebugContext(ctx, "Available database servers", + params.Logger.DebugContext(ctx, "Retrieved database servers from watcher", "cluster", params.ClusterName, "servers", logutils.StringerSliceAttr(result), ) diff --git a/lib/srv/db/common/connect/connect_bench_test.go b/lib/srv/db/common/connect/connect_bench_test.go new file mode 100644 index 0000000000000..60a761194bf20 --- /dev/null +++ b/lib/srv/db/common/connect/connect_bench_test.go @@ -0,0 +1,121 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connect + +import ( + "fmt" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/authtest" + "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/tlsca" +) + +func createBenchmarkDatabaseServers(b *testing.B, total int, targetName string) []types.DatabaseServer { + b.Helper() + + servers := make([]types.DatabaseServer, 0, total) + for i := range total { + dbName := fmt.Sprintf("db-%d", i) + if i == total-1 { + dbName = targetName + } + + server, err := types.NewDatabaseServerV3(types.Metadata{ + Name: fmt.Sprintf("db-server-%d", i), + }, types.DatabaseServerSpecV3{ + Hostname: "localhost", + HostID: fmt.Sprintf("host-%d", i), + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: dbName, + }, + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolPostgreSQL, + URI: "localhost", + }, + }, + }) + require.NoError(b, err) + servers = append(servers, server) + } + + return servers +} + +// BenchmarkConnectGetDatabaseServers measures the memory usage of GetDatabaseServers +// with varying numbers of database servers in the cluster. +func BenchmarkConnectGetDatabaseServers(b *testing.B) { + const ( + matchCount = 1 + clusterName = "cluster-1" + targetName = "db-target" + ) + + totals := []int{ + 1000, + 5000, + 10000, + } + + for _, total := range totals { + b.Run(fmt.Sprintf("total=%d", total), func(sb *testing.B) { + sb.ReportAllocs() + + servers := createBenchmarkDatabaseServers(sb, total, targetName) + + authServer, err := authtest.NewAuthServer(authtest.AuthServerConfig{ + ClusterName: clusterName, + Dir: sb.TempDir(), + }) + require.NoError(sb, err) + sb.Cleanup(func() { require.NoError(sb, authServer.Close()) }) + + for _, server := range servers { + _, err := authServer.AuthServer.UpsertDatabaseServer(b.Context(), server) + require.NoError(sb, err) + } + + cluster := reversetunnelclient.NewFakeCluster(clusterName, authServer.AuthServer) + sb.Cleanup(func() { require.NoError(sb, cluster.Close()) }) + + watcher, err := cluster.DatabaseServerWatcher() + require.NoError(sb, err) + require.NoError(sb, watcher.WaitInitialization()) + + params := GetDatabaseServersParams{ + Logger: slog.Default(), + ClusterName: clusterName, + Watcher: watcher, + Identity: tlsca.Identity{ + RouteToDatabase: tlsca.RouteToDatabase{ServiceName: targetName}, + RouteToCluster: clusterName, + }, + } + + for sb.Loop() { + result, err := GetDatabaseServers(b.Context(), params) + require.NoError(sb, err) + require.Len(sb, result, matchCount) + } + }) + } +} diff --git a/lib/srv/db/common/connect/connect_test.go b/lib/srv/db/common/connect/connect_test.go index b5fdbe916db49..1751b288b52b0 100644 --- a/lib/srv/db/common/connect/connect_test.go +++ b/lib/srv/db/common/connect/connect_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "net" "net/netip" + "os" "strings" "testing" "time" @@ -36,44 +37,57 @@ import ( "github.com/gravitational/teleport/lib/auth/authtest" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils/log/logtest" ) +func TestMain(m *testing.M) { + logtest.InitLogger(testing.Verbose) + os.Exit(m.Run()) +} + func TestGetDatabaseServers(t *testing.T) { + const clusterName = "root" + for name, tc := range map[string]struct { identity tlsca.Identity - getter *databaseServersMock + watcher *mockDatabaseServerWatcher expectErrorFunc require.ErrorAssertionFunc expectedServersLen int }{ "match": { identity: identityWithDatabase("matched-db", "root", "alice", nil), - getter: newDatabaseServersWithServers("no-match", "matched-db", "another-db"), + watcher: newMockWatcherWithServers("matched-db", "other-db"), expectErrorFunc: require.NoError, expectedServersLen: 1, }, + "multiple agents for same database": { + identity: identityWithDatabase("matched-db", "root", "alice", nil), + watcher: newMockWatcherWithServers("matched-db", "matched-db", "other-db"), + expectErrorFunc: require.NoError, + expectedServersLen: 2, + }, "no match": { identity: identityWithDatabase("no-match", "root", "alice", nil), - getter: newDatabaseServersWithServers("first", "second", "third"), + watcher: newMockWatcherWithServers("matched-db", "other-db"), expectErrorFunc: func(tt require.TestingT, err error, i ...any) { require.Error(t, err) require.True(t, trace.IsNotFound(err), "expected trace.NotFound error but got %T", err) }, }, - "get server error": { - identity: identityWithDatabase("no-match", "root", "alice", nil), - getter: newDatabaseServersWithErr(trace.Errorf("failure")), + "watcher error": { + identity: identityWithDatabase("matched-db", "root", "alice", nil), + watcher: newMockWatcherWithErr(trace.Errorf("failure")), expectErrorFunc: require.Error, }, } { t.Run(name, func(t *testing.T) { servers, err := GetDatabaseServers(context.Background(), GetDatabaseServersParams{ - Logger: logtest.NewLogger(), - ClusterName: "root", - DatabaseServersGetter: tc.getter, - Identity: tc.identity, + Logger: logtest.NewLogger(), + ClusterName: clusterName, + Watcher: tc.watcher, + Identity: tc.identity, }) tc.expectErrorFunc(t, err) require.Len(t, servers, tc.expectedServersLen) @@ -226,11 +240,26 @@ func identityWithDatabase(name, clusterName, user string, roles []string) tlsca. } } -type databaseServersMock struct { +// mockDatabaseServerWatcher implements DatabaseServerWatcher for tests. +type mockDatabaseServerWatcher struct { servers []types.DatabaseServer err error } +func (m *mockDatabaseServerWatcher) CurrentResourcesWithFilter(_ context.Context, filter func(readonly.DatabaseServer) bool) ([]types.DatabaseServer, error) { + if m.err != nil { + return nil, m.err + } + + var out []types.DatabaseServer + for _, s := range m.servers { + if filter(s) { + out = append(out, s) + } + } + return out, nil +} + func databaseServerWithName(name, hostId string) types.DatabaseServer { return &types.DatabaseServerV3{ Spec: types.DatabaseServerSpecV3{ @@ -245,21 +274,17 @@ func databaseServerWithName(name, hostId string) types.DatabaseServer { } } -func newDatabaseServersWithServers(dbNames ...string) *databaseServersMock { +func newMockWatcherWithServers(dbNames ...string) *mockDatabaseServerWatcher { var servers []types.DatabaseServer for _, name := range dbNames { servers = append(servers, databaseServerWithName(name, uuid.New().String())) } - return &databaseServersMock{servers: servers} -} - -func newDatabaseServersWithErr(err error) *databaseServersMock { - return &databaseServersMock{err: err} + return &mockDatabaseServerWatcher{servers: servers} } -func (d *databaseServersMock) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) { - return d.servers, d.err +func newMockWatcherWithErr(err error) *mockDatabaseServerWatcher { + return &mockDatabaseServerWatcher{err: err} } func newDialerMock(t *testing.T, authServer *auth.Server, dbName string, availableServers []string, unavailableServers []string) *dialerMock { diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 1f26c82a254ee..bc0ffba2361f9 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -522,15 +522,15 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para if err != nil { return nil, trace.Wrap(err) } - accessPoint, err := cluster.CachingAccessPoint() + watcher, err := cluster.DatabaseServerWatcher() if err != nil { return nil, trace.Wrap(err) } servers, err := connect.GetDatabaseServers(ctx, connect.GetDatabaseServersParams{ - Logger: s.log, - ClusterName: cluster.GetName(), - DatabaseServersGetter: accessPoint, - Identity: identity, + Logger: s.log, + Watcher: watcher, + ClusterName: cluster.GetName(), + Identity: identity, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index b663efed6307a..7eb3af4d099ae 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1803,6 +1803,7 @@ func TestProxyRoundRobin(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { @@ -1947,6 +1948,7 @@ func TestProxyDirectAccess(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { @@ -2626,6 +2628,7 @@ func TestParseSubsystemRequest(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { return nil, errors.New("eice disabled in tests") @@ -2887,6 +2890,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { return nil, errors.New("eice disabled in tests") @@ -3264,6 +3268,21 @@ func newAppServerWatcher(ctx context.Context, t *testing.T, client *authclient.C return appServerWatcher } +func newDatabaseServerWatcher(ctx context.Context, t *testing.T, client *authclient.Client) *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] { + t.Helper() + + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + }) + + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + return databaseServerWatcher +} + // newSigner creates a new SSH signer that can be used by the Server. func newSigner(t testing.TB, ctx context.Context, testServer *authtest.Server) ssh.Signer { t.Helper() @@ -3336,6 +3355,7 @@ func TestHostUserCreationProxy(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 92d0647a51673..84835eec61300 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -466,6 +466,15 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) t.Cleanup(appServerWatcher.Close) + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + }) + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -483,6 +492,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { NodeWatcher: proxyNodeWatcher, GitServerWatcher: proxyGitServerWatcher, AppServerWatcher: appServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{s.server.TLS.Listener.Addr().String()}, @@ -8607,6 +8617,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(appServerWatcher.Close) + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -8625,6 +8644,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula GitServerWatcher: proxyGitServerWatcher, CertAuthorityWatcher: proxyCAWatcher, AppServerWatcher: appServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{authServer.Listener.Addr().String()}, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { From a8a5ba7f3b2e68472da4372aaa0bac6826842c1f Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Thu, 19 Feb 2026 15:24:53 -0500 Subject: [PATCH 2/9] Populate entire config to better express intent and not rely on defaults --- lib/service/service.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/service/service.go b/lib/service/service.go index ba39ac400f74a..60e1a11be80d8 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4980,6 +4980,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentProxy), Client: accessPoint, }, + DatabaseServersGetter: accessPoint, }) if err != nil { return trace.Wrap(err) From 51a3014cfe03d0a512a504b106642f0fa716098b Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Thu, 19 Feb 2026 15:27:10 -0500 Subject: [PATCH 3/9] Redirect log output to /dev/null --- lib/srv/db/common/connect/connect_bench_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/srv/db/common/connect/connect_bench_test.go b/lib/srv/db/common/connect/connect_bench_test.go index 60a761194bf20..f0bfde4c48b10 100644 --- a/lib/srv/db/common/connect/connect_bench_test.go +++ b/lib/srv/db/common/connect/connect_bench_test.go @@ -102,7 +102,7 @@ func BenchmarkConnectGetDatabaseServers(b *testing.B) { require.NoError(sb, watcher.WaitInitialization()) params := GetDatabaseServersParams{ - Logger: slog.Default(), + Logger: slog.New(slog.DiscardHandler), ClusterName: clusterName, Watcher: watcher, Identity: tlsca.Identity{ From e1e96148a81e3fd1892c0d416a9c7ceeb231bcb5 Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Thu, 19 Feb 2026 15:54:15 -0500 Subject: [PATCH 4/9] Simplify Benchmark test harness Remove auth server and fake cluster. Instantiate an in-memory backend, presence service, event service, and watcher directly in the benchmark. --- .../db/common/connect/connect_bench_test.go | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/lib/srv/db/common/connect/connect_bench_test.go b/lib/srv/db/common/connect/connect_bench_test.go index f0bfde4c48b10..cd06c2f88ce87 100644 --- a/lib/srv/db/common/connect/connect_bench_test.go +++ b/lib/srv/db/common/connect/connect_bench_test.go @@ -20,12 +20,14 @@ import ( "fmt" "log/slog" "testing" + "time" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" ) @@ -82,23 +84,26 @@ func BenchmarkConnectGetDatabaseServers(b *testing.B) { servers := createBenchmarkDatabaseServers(sb, total, targetName) - authServer, err := authtest.NewAuthServer(authtest.AuthServerConfig{ - ClusterName: clusterName, - Dir: sb.TempDir(), - }) + backend, err := memory.New(memory.Config{Context: sb.Context()}) require.NoError(sb, err) - sb.Cleanup(func() { require.NoError(sb, authServer.Close()) }) + presenceService := local.NewPresenceService(backend) for _, server := range servers { - _, err := authServer.AuthServer.UpsertDatabaseServer(b.Context(), server) + _, err = presenceService.UpsertDatabaseServer(sb.Context(), server) require.NoError(sb, err) } - cluster := reversetunnelclient.NewFakeCluster(clusterName, authServer.AuthServer) - sb.Cleanup(func() { require.NoError(sb, cluster.Close()) }) - - watcher, err := cluster.DatabaseServerWatcher() + watcher, err := services.NewDatabaseServerWatcher(sb.Context(), services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "bench", + MaxRetryPeriod: 200 * time.Millisecond, + Client: local.NewEventsService(backend), + }, + }) require.NoError(sb, err) + sb.Cleanup(watcher.Close) + require.NoError(sb, watcher.WaitInitialization()) params := GetDatabaseServersParams{ From 1586c0547f0351bcad226eb1a00860ebc6317f8d Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Fri, 20 Feb 2026 12:00:37 -0500 Subject: [PATCH 5/9] Remove test cases to lower pressure in CI --- .../db/common/connect/connect_bench_test.go | 83 +++++++++---------- 1 file changed, 38 insertions(+), 45 deletions(-) diff --git a/lib/srv/db/common/connect/connect_bench_test.go b/lib/srv/db/common/connect/connect_bench_test.go index cd06c2f88ce87..aec1428f79dec 100644 --- a/lib/srv/db/common/connect/connect_bench_test.go +++ b/lib/srv/db/common/connect/connect_bench_test.go @@ -70,57 +70,50 @@ func BenchmarkConnectGetDatabaseServers(b *testing.B) { matchCount = 1 clusterName = "cluster-1" targetName = "db-target" + total = 1000 ) - totals := []int{ - 1000, - 5000, - 10000, - } - - for _, total := range totals { - b.Run(fmt.Sprintf("total=%d", total), func(sb *testing.B) { - sb.ReportAllocs() + b.Run(fmt.Sprintf("total=%d", total), func(sb *testing.B) { + sb.ReportAllocs() - servers := createBenchmarkDatabaseServers(sb, total, targetName) + servers := createBenchmarkDatabaseServers(sb, total, targetName) - backend, err := memory.New(memory.Config{Context: sb.Context()}) - require.NoError(sb, err) + backend, err := memory.New(memory.Config{Context: sb.Context()}) + require.NoError(sb, err) - presenceService := local.NewPresenceService(backend) - for _, server := range servers { - _, err = presenceService.UpsertDatabaseServer(sb.Context(), server) - require.NoError(sb, err) - } - - watcher, err := services.NewDatabaseServerWatcher(sb.Context(), services.DatabaseServerWatcherConfig{ - DatabaseServersGetter: presenceService, - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "bench", - MaxRetryPeriod: 200 * time.Millisecond, - Client: local.NewEventsService(backend), - }, - }) + presenceService := local.NewPresenceService(backend) + for _, server := range servers { + _, err = presenceService.UpsertDatabaseServer(sb.Context(), server) require.NoError(sb, err) - sb.Cleanup(watcher.Close) - - require.NoError(sb, watcher.WaitInitialization()) - - params := GetDatabaseServersParams{ - Logger: slog.New(slog.DiscardHandler), - ClusterName: clusterName, - Watcher: watcher, - Identity: tlsca.Identity{ - RouteToDatabase: tlsca.RouteToDatabase{ServiceName: targetName}, - RouteToCluster: clusterName, - }, - } + } - for sb.Loop() { - result, err := GetDatabaseServers(b.Context(), params) - require.NoError(sb, err) - require.Len(sb, result, matchCount) - } + watcher, err := services.NewDatabaseServerWatcher(sb.Context(), services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "bench", + MaxRetryPeriod: 200 * time.Millisecond, + Client: local.NewEventsService(backend), + }, }) - } + require.NoError(sb, err) + sb.Cleanup(watcher.Close) + + require.NoError(sb, watcher.WaitInitialization()) + + params := GetDatabaseServersParams{ + Logger: slog.New(slog.DiscardHandler), + ClusterName: clusterName, + Watcher: watcher, + Identity: tlsca.Identity{ + RouteToDatabase: tlsca.RouteToDatabase{ServiceName: targetName}, + RouteToCluster: clusterName, + }, + } + + for sb.Loop() { + result, err := GetDatabaseServers(b.Context(), params) + require.NoError(sb, err) + require.Len(sb, result, matchCount) + } + }) } From ad1d62ad7f30ea436bc70f58eb6bc09f2109eb1a Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Fri, 20 Feb 2026 12:46:27 -0500 Subject: [PATCH 6/9] Add godoc --- lib/reversetunnel/peer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index d3abcaec338f9..c8ff1764ec7b8 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -114,6 +114,7 @@ func (p *expectedLeafClusters) GitServerWatcher() (*services.GenericWatcher[type return cluster.GitServerWatcher() } +// DatabaseServerWatcher returns a watcher for database servers in the leaf cluster. func (p *expectedLeafClusters) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { cluster, err := p.pickCluster() if err != nil { From 0d4d47ef7b5b447a1a5b5f28dae3f81f38eeb0d7 Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Mon, 2 Mar 2026 22:34:48 -0500 Subject: [PATCH 7/9] Update key funcs to align with actual usage --- lib/services/watcher.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 9783ad32201c9..c992fa3903532 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -581,7 +581,14 @@ func NewDatabaseServerWatcher(ctx context.Context, cfg DatabaseServerWatcherConf w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.DatabaseServer, readonly.DatabaseServer]{ ResourceWatcherConfig: cfg.ResourceWatcherConfig, ResourceKind: types.KindDatabaseServer, - ResourceKey: func(r types.DatabaseServer) string { return r.GetHostID() + r.GetName() }, + ResourceKey: func(r types.DatabaseServer) string { + // the host ID is guaranteed not to contain "/" + return r.GetHostID() + "/" + r.GetName() + }, + DeleteKey: func(r types.Resource) string { + // database servers put the host ID in the description in delete events + return r.GetMetadata().Description + "/" + r.GetName() + }, ResourceGetter: func(ctx context.Context) ([]types.DatabaseServer, error) { return cfg.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) }, From 8b6f34031da321cb0542104f90b4f7d6a3185041 Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Tue, 3 Mar 2026 18:45:35 -0500 Subject: [PATCH 8/9] Move test body into named function, remove parallel --- lib/services/watcher_test.go | 127 ++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 63 deletions(-) diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 165915824c22a..82c720ea53c3f 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -760,82 +760,83 @@ func newApp(t *testing.T, name string) *types.AppV3 { // TestDatabaseServerWatcher tests that database server resource watcher properly // receives and dispatches updates. func TestDatabaseServerWatcher(t *testing.T) { - t.Parallel() - synctest.Test(t, func(t *testing.T) { - ctx := t.Context() + synctest.Test(t, syncTestDatabaseServerWatcher) +} - bk, err := memory.New(memory.Config{Context: ctx}) - require.NoError(t, err) +func syncTestDatabaseServerWatcher(t *testing.T) { + ctx := t.Context() - type client struct { - services.DatabaseServersGetter - types.Events - } + bk, err := memory.New(memory.Config{Context: ctx}) + require.NoError(t, err) - presenceService := local.NewPresenceService(bk) - w, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ - DatabaseServersGetter: presenceService, - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - MaxRetryPeriod: 200 * time.Millisecond, - Client: &client{ - DatabaseServersGetter: presenceService, - Events: local.NewEventsService(bk), - }, + type client struct { + services.DatabaseServersGetter + types.Events + } + + presenceService := local.NewPresenceService(bk) + w, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + DatabaseServersGetter: presenceService, + Events: local.NewEventsService(bk), }, - }) - require.NoError(t, err) - t.Cleanup(w.Close) + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) - // Wait for initial load. - require.NoError(t, w.WaitInitialization()) + // Wait for initial load. + require.NoError(t, w.WaitInitialization()) - // Initially there are no database servers. - servers, err := w.CurrentResources(ctx) - require.NoError(t, err) - require.Empty(t, servers) + // Initially there are no database servers. + servers, err := w.CurrentResources(ctx) + require.NoError(t, err) + require.Empty(t, servers) - // Add a database server and wait for the watcher to process the event. - server1 := newDatabaseServer(t, "db1", "host1") - _, err = presenceService.UpsertDatabaseServer(ctx, server1) - require.NoError(t, err) - synctest.Wait() + // Add a database server and wait for the watcher to process the event. + server1 := newDatabaseServer(t, "db1", "host1") + _, err = presenceService.UpsertDatabaseServer(ctx, server1) + require.NoError(t, err) + synctest.Wait() - servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { - return ds.GetDatabase().GetName() == "db1" - }) - require.NoError(t, err) - require.Len(t, servers, 1) - require.Equal(t, server1.GetName(), servers[0].GetName()) + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabase().GetName() == "db1" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server1.GetName(), servers[0].GetName()) - // Add a second database server and wait for the watcher to process the event. - server2 := newDatabaseServer(t, "db2", "host2") - _, err = presenceService.UpsertDatabaseServer(ctx, server2) - require.NoError(t, err) - synctest.Wait() + // Add a second database server and wait for the watcher to process the event. + server2 := newDatabaseServer(t, "db2", "host2") + _, err = presenceService.UpsertDatabaseServer(ctx, server2) + require.NoError(t, err) + synctest.Wait() - servers, err = w.CurrentResources(ctx) - require.NoError(t, err) - require.Len(t, servers, 2) + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 2) - servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { - return ds.GetDatabase().GetName() == "db2" - }) - require.NoError(t, err) - require.Len(t, servers, 1) - require.Equal(t, server2.GetName(), servers[0].GetName()) + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabase().GetName() == "db2" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) - // Delete the first database server and wait for the watcher to process the event. - err = presenceService.DeleteDatabaseServer(ctx, apidefaults.Namespace, server1.GetHostID(), server1.GetName()) - require.NoError(t, err) - synctest.Wait() + // Delete the first database server and wait for the watcher to process the event. + err = presenceService.DeleteDatabaseServer(ctx, apidefaults.Namespace, server1.GetHostID(), server1.GetName()) + require.NoError(t, err) + synctest.Wait() - // Verify the remaining server is server2. - servers, err = w.CurrentResources(ctx) - require.NoError(t, err) - require.Len(t, servers, 1) - require.Equal(t, server2.GetName(), servers[0].GetName()) - }) + // Verify the remaining server is server2. + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) } func newDatabaseServer(t *testing.T, dbName, hostID string) types.DatabaseServer { From 45855cc7882e40c60deb553117bc91eac109f662 Mon Sep 17 00:00:00 2001 From: Tyler Richardson Date: Wed, 4 Mar 2026 14:37:41 -0500 Subject: [PATCH 9/9] Refactor DatabaseServer interface to struct + constructor for read-only variant --- lib/services/readonly/readonly.go | 52 ++++++++--------------- lib/services/watcher.go | 4 +- lib/services/watcher_test.go | 4 +- lib/srv/db/common/connect/connect.go | 2 +- lib/srv/db/common/connect/connect_test.go | 2 +- 5 files changed, 23 insertions(+), 41 deletions(-) diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index 0cb5ce357e6f1..c5c0ea85e73b5 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -293,42 +293,26 @@ type AppServer interface { var _ AppServer = types.AppServer(nil) // DatabaseServer is a read only variant of [types.DatabaseServer] -type DatabaseServer interface { - // ResourceWithLabels provides common resource methods. - ResourceWithLabels - // GetNamespace returns server namespace. - GetNamespace() string - // GetTeleportVersion returns the teleport version the server is running on. - GetTeleportVersion() string - // GetHostname returns the server hostname. - GetHostname() string - // GetHostID returns ID of the host the server is running on. - GetHostID() string - // GetRotation gets the state of certificate authority rotation. - GetRotation() types.Rotation - // String returns string representation of the server. - String() string - // Copy returns a copy of this database server object. - Copy() types.DatabaseServer - // GetDatabase returns the database this database server proxies. - GetDatabase() types.Database - // ProxiedService provides common methods for a proxied service. - ProxiedService - // GetRelayGroup returns the name of the Relay group that the database - // server is connected to. - GetRelayGroup() string - // GetRelayIDs returns the list of Relay host IDs that the database server - // is connected to. - GetRelayIDs() []string - // GetTargetHealth returns the database server's target health. - GetTargetHealth() types.TargetHealth - // GetTargetHealthStatus returns target health status - GetTargetHealthStatus() types.TargetHealthStatus - // GetScope returns the scope this server belongs to. - GetScope() string +type DatabaseServer struct { + inner types.DatabaseServer +} + +// GetDatabaseName returns the name of the database this server is proxying. +func (d DatabaseServer) GetDatabaseName() string { + if d.inner == nil { + return "" + } + db := d.inner.GetDatabase() + if db == nil { + return "" + } + return db.GetName() } -var _ DatabaseServer = types.DatabaseServer(nil) +// NewDatabaseServer returns a new read-only DatabaseServer. +func NewDatabaseServer(server types.DatabaseServer) DatabaseServer { + return DatabaseServer{inner: server} +} // KubeServer is a read only variant of [types.KubeServer]. type KubeServer interface { diff --git a/lib/services/watcher.go b/lib/services/watcher.go index c992fa3903532..57860b82d3644 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -594,9 +594,7 @@ func NewDatabaseServerWatcher(ctx context.Context, cfg DatabaseServerWatcherConf }, DisableUpdateBroadcast: true, CloneFunc: types.DatabaseServer.Copy, - ReadOnlyFunc: func(resource types.DatabaseServer) readonly.DatabaseServer { - return resource - }, + ReadOnlyFunc: readonly.NewDatabaseServer, }) return w, trace.Wrap(err) diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 82c720ea53c3f..861bb7e0d9809 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -804,7 +804,7 @@ func syncTestDatabaseServerWatcher(t *testing.T) { synctest.Wait() servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { - return ds.GetDatabase().GetName() == "db1" + return ds.GetDatabaseName() == "db1" }) require.NoError(t, err) require.Len(t, servers, 1) @@ -821,7 +821,7 @@ func syncTestDatabaseServerWatcher(t *testing.T) { require.Len(t, servers, 2) servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { - return ds.GetDatabase().GetName() == "db2" + return ds.GetDatabaseName() == "db2" }) require.NoError(t, err) require.Len(t, servers, 1) diff --git a/lib/srv/db/common/connect/connect.go b/lib/srv/db/common/connect/connect.go index 5f6c2734c2cb3..b27bea26451ca 100644 --- a/lib/srv/db/common/connect/connect.go +++ b/lib/srv/db/common/connect/connect.go @@ -64,7 +64,7 @@ type GetDatabaseServersParams struct { // DatabaseServerWatcher for fast in-memory lookup. func GetDatabaseServers(ctx context.Context, params GetDatabaseServersParams) ([]types.DatabaseServer, error) { result, err := params.Watcher.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { - return ds.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName + return ds.GetDatabaseName() == params.Identity.RouteToDatabase.ServiceName }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/common/connect/connect_test.go b/lib/srv/db/common/connect/connect_test.go index 1751b288b52b0..df4becdaa0f6c 100644 --- a/lib/srv/db/common/connect/connect_test.go +++ b/lib/srv/db/common/connect/connect_test.go @@ -253,7 +253,7 @@ func (m *mockDatabaseServerWatcher) CurrentResourcesWithFilter(_ context.Context var out []types.DatabaseServer for _, s := range m.servers { - if filter(s) { + if filter(readonly.NewDatabaseServer(s)) { out = append(out, s) } }