diff --git a/lib/reversetunnel/leaf_cluster.go b/lib/reversetunnel/leaf_cluster.go index 604b95f0f9553..15e11ccddf471 100644 --- a/lib/reversetunnel/leaf_cluster.go +++ b/lib/reversetunnel/leaf_cluster.go @@ -92,6 +92,9 @@ type leafCluster struct { // appServerWatcher is a app server watcher. appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // databaseServerWatcher is a database server watcher. + databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] + // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation // state has been changed, the tunnel will reconnect to re-create the client @@ -182,6 +185,11 @@ func (s *leafCluster) GitServerWatcher() (*services.GenericWatcher[types.Server, return nil, trace.NotImplemented("GitServerWatcher not implemented for leafCluster") } +// DatabaseServerWatcher returns the Database server watcher for the leaf cluster. +func (s *leafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.databaseServerWatcher, nil +} + func (s *leafCluster) GetClient() (authclient.ClientI, error) { return s.leafClient, nil } diff --git a/lib/reversetunnel/local_cluster.go b/lib/reversetunnel/local_cluster.go index 6f9cf897d4bf3..6e289c82c85b8 100644 --- a/lib/reversetunnel/local_cluster.go +++ b/lib/reversetunnel/local_cluster.go @@ -193,6 +193,11 @@ func (s *localCluster) GitServerWatcher() (*services.GenericWatcher[types.Server return s.srv.GitServerWatcher, nil } +// DatabaseServerWatcher returns a Database server watcher for this cluster. +func (s *localCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.srv.DatabaseServerWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localCluster) GetClient() (authclient.ClientI, error) { return s.client, nil diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 3c351e1af612b..c8ff1764ec7b8 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -114,6 +114,15 @@ func (p *expectedLeafClusters) GitServerWatcher() (*services.GenericWatcher[type return cluster.GitServerWatcher() } +// DatabaseServerWatcher returns a watcher for database servers in the leaf cluster. +func (p *expectedLeafClusters) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + cluster, err := p.pickCluster() + if err != nil { + return nil, trace.Wrap(err) + } + return cluster.DatabaseServerWatcher() +} + func (p *expectedLeafClusters) GetClient() (authclient.ClientI, error) { cluster, err := p.pickCluster() if err != nil { @@ -227,6 +236,10 @@ func (s *expectedLeafCluster) GitServerWatcher() (*services.GenericWatcher[types return nil, s.discoveryError("unable to fetch git server watcher for leaf cluster") } +func (s *expectedLeafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return nil, s.discoveryError("unable to fetch database server watcher for leaf cluster") +} + func (s *expectedLeafCluster) GetClient() (authclient.ClientI, error) { return nil, s.discoveryError("unable to fetch auth client for leaf cluster") } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index bc05edeff5829..54080baca3cc2 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -226,6 +226,9 @@ type Config struct { // AppServerWatcher is a app server watcher. AppServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // DatabaseServerWatcher is a database server watcher. + DatabaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] + // CircuitBreakerConfig configures the auth client circuit breaker CircuitBreakerConfig breaker.Config @@ -305,6 +308,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.AppServerWatcher == nil { return trace.BadParameter("missing parameter AppServerWatcher") } + if cfg.DatabaseServerWatcher == nil { + return trace.BadParameter("missing parameter DatabaseServerWatcher") + } if cfg.EICEDialer == nil { return trace.BadParameter("missing parameter EICEDialer") @@ -1298,6 +1304,18 @@ func newLeafCluster(srv *server, domainName string, sconn ssh.Conn) (*leafCluste } leaf.appServerWatcher = appServerWatcher + databaseServerWatcher, err := services.NewDatabaseServerWatcher(closeContext, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: srv.Component, + Logger: srv.Logger, + Client: accessPoint, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + leaf.databaseServerWatcher = databaseServerWatcher + // instantiate a cache of host certificates for the forwarding server. the // certificate cache is created in each cluster (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index 4c7ad5e9c0f30..5803770ed7f06 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -133,6 +133,8 @@ type Cluster interface { GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) // AppServerWatcher returns the watcher that maintains the app server set for the cluster AppServerWatcher() (*services.GenericWatcher[types.AppServer, readonly.AppServer], error) + // DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster + DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/reversetunnelclient/fake.go b/lib/reversetunnelclient/fake.go index c4f24ed04e302..efc3d954bc91f 100644 --- a/lib/reversetunnelclient/fake.go +++ b/lib/reversetunnelclient/fake.go @@ -73,6 +73,8 @@ type FakeCluster struct { closed bool // appServerWatcher ia a app server watcher to speed up app look up. appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer] + // databaseServerWatcher is a database server watcher to speed up database server look up. + databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] } // NewFakeCluster is a FakeCluster constructor. @@ -84,11 +86,19 @@ func NewFakeCluster(clusterName string, accessPoint authclient.RemoteProxyAccess }, }) + databaseServerWatcher, _ := services.NewDatabaseServerWatcher(context.TODO(), services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "FakeCluster", + Client: accessPoint, + }, + }) + return &FakeCluster{ - Name: clusterName, - connCh: make(chan net.Conn), - AccessPoint: accessPoint, - appServerWatcher: appServerWatcher, + Name: clusterName, + connCh: make(chan net.Conn), + AccessPoint: accessPoint, + appServerWatcher: appServerWatcher, + databaseServerWatcher: databaseServerWatcher, } } @@ -97,6 +107,11 @@ func (s *FakeCluster) AppServerWatcher() (*services.GenericWatcher[types.AppServ return s.appServerWatcher, nil } +// DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster +func (s *FakeCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + return s.databaseServerWatcher, nil +} + // CachingAccessPoint returns caching auth server client. func (s *FakeCluster) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) { return s.AccessPoint, nil diff --git a/lib/service/service.go b/lib/service/service.go index 95eb933aaa0cf..60e1a11be80d8 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4974,6 +4974,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + databaseServerWatcher, err := services.NewDatabaseServerWatcher(process.ExitContext(), services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentProxy), + Client: accessPoint, + }, + DatabaseServersGetter: accessPoint, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := process.ServerTLSConfig(conn) if err != nil { return trace.Wrap(err) @@ -5246,6 +5258,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { NodeWatcher: nodeWatcher, AppServerWatcher: appServerWatcher, GitServerWatcher: gitServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: process.Config.CircuitBreakerConfig, LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServerAddresses()), diff --git a/lib/services/presence.go b/lib/services/presence.go index a5fbcde41079c..2dee0f598c8ca 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -48,7 +48,8 @@ type NodesGetter interface { // DatabaseServersGetter is a service that gets database servers. type DatabaseServersGetter interface { - GetDatabaseServers(context.Context, string, ...MarshalOption) ([]types.DatabaseServer, error) + // GetDatabaseServers returns all registered database proxy servers. + GetDatabaseServers(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.DatabaseServer, error) } // AppServersGetter is a service that gets application servers. diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index 3100073cb03db..c5c0ea85e73b5 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -292,6 +292,28 @@ type AppServer interface { var _ AppServer = types.AppServer(nil) +// DatabaseServer is a read only variant of [types.DatabaseServer] +type DatabaseServer struct { + inner types.DatabaseServer +} + +// GetDatabaseName returns the name of the database this server is proxying. +func (d DatabaseServer) GetDatabaseName() string { + if d.inner == nil { + return "" + } + db := d.inner.GetDatabase() + if db == nil { + return "" + } + return db.GetName() +} + +// NewDatabaseServer returns a new read-only DatabaseServer. +func NewDatabaseServer(server types.DatabaseServer) DatabaseServer { + return DatabaseServer{inner: server} +} + // KubeServer is a read only variant of [types.KubeServer]. type KubeServer interface { // ResourceWithLabels provides common resource methods. diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 982f74fa317d3..57860b82d3644 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -546,6 +546,60 @@ func NewAppServersWatcher(ctx context.Context, cfg AppServersWatcherConfig) (*Ge return w, trace.Wrap(err) } +type DatabaseServerWatcherConfig struct { + DatabaseServersGetter + ResourceWatcherConfig +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *DatabaseServerWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + if cfg.MaxStaleness == 0 { + const databaseServerMaxStaleness = time.Minute + cfg.MaxStaleness = databaseServerMaxStaleness + } + + if cfg.DatabaseServersGetter == nil { + getter, ok := cfg.Client.(DatabaseServersGetter) + if !ok { + return trace.BadParameter("missing parameter DatabaseServersGetter and Client not usable as DatabaseServersGetter") + } + cfg.DatabaseServersGetter = getter + } + + return nil +} + +func NewDatabaseServerWatcher(ctx context.Context, cfg DatabaseServerWatcherConfig) (*GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.DatabaseServer, readonly.DatabaseServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindDatabaseServer, + ResourceKey: func(r types.DatabaseServer) string { + // the host ID is guaranteed not to contain "/" + return r.GetHostID() + "/" + r.GetName() + }, + DeleteKey: func(r types.Resource) string { + // database servers put the host ID in the description in delete events + return r.GetMetadata().Description + "/" + r.GetName() + }, + ResourceGetter: func(ctx context.Context) ([]types.DatabaseServer, error) { + return cfg.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) + }, + DisableUpdateBroadcast: true, + CloneFunc: types.DatabaseServer.Copy, + ReadOnlyFunc: readonly.NewDatabaseServer, + }) + + return w, trace.Wrap(err) +} + // KubeServerWatcherConfig is an KubeServerWatcher configuration. type KubeServerWatcherConfig struct { // KubernetesServerGetter is responsible for fetching kube_server resources. diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 3ef18656481d9..861bb7e0d9809 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -757,6 +757,109 @@ func newApp(t *testing.T, name string) *types.AppV3 { return app } +// TestDatabaseServerWatcher tests that database server resource watcher properly +// receives and dispatches updates. +func TestDatabaseServerWatcher(t *testing.T) { + synctest.Test(t, syncTestDatabaseServerWatcher) +} + +func syncTestDatabaseServerWatcher(t *testing.T) { + ctx := t.Context() + + bk, err := memory.New(memory.Config{Context: ctx}) + require.NoError(t, err) + + type client struct { + services.DatabaseServersGetter + types.Events + } + + presenceService := local.NewPresenceService(bk) + w, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + MaxRetryPeriod: 200 * time.Millisecond, + Client: &client{ + DatabaseServersGetter: presenceService, + Events: local.NewEventsService(bk), + }, + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + // Wait for initial load. + require.NoError(t, w.WaitInitialization()) + + // Initially there are no database servers. + servers, err := w.CurrentResources(ctx) + require.NoError(t, err) + require.Empty(t, servers) + + // Add a database server and wait for the watcher to process the event. + server1 := newDatabaseServer(t, "db1", "host1") + _, err = presenceService.UpsertDatabaseServer(ctx, server1) + require.NoError(t, err) + synctest.Wait() + + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabaseName() == "db1" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server1.GetName(), servers[0].GetName()) + + // Add a second database server and wait for the watcher to process the event. + server2 := newDatabaseServer(t, "db2", "host2") + _, err = presenceService.UpsertDatabaseServer(ctx, server2) + require.NoError(t, err) + synctest.Wait() + + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 2) + + servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabaseName() == "db2" + }) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) + + // Delete the first database server and wait for the watcher to process the event. + err = presenceService.DeleteDatabaseServer(ctx, apidefaults.Namespace, server1.GetHostID(), server1.GetName()) + require.NoError(t, err) + synctest.Wait() + + // Verify the remaining server is server2. + servers, err = w.CurrentResources(ctx) + require.NoError(t, err) + require.Len(t, servers, 1) + require.Equal(t, server2.GetName(), servers[0].GetName()) +} + +func newDatabaseServer(t *testing.T, dbName, hostID string) types.DatabaseServer { + t.Helper() + server, err := types.NewDatabaseServerV3(types.Metadata{ + Name: dbName, + }, types.DatabaseServerSpecV3{ + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: dbName, + }, + Spec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }, + }, + HostID: hostID, + Hostname: dbName, + }) + require.NoError(t, err) + return server +} + func TestCertAuthorityWatcher(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/common/connect/connect.go b/lib/srv/db/common/connect/connect.go index 1be2ee49ab776..b27bea26451ca 100644 --- a/lib/srv/db/common/connect/connect.go +++ b/lib/srv/db/common/connect/connect.go @@ -31,22 +31,20 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" ) -// DatabaseServersGetter is an interface for retrieving information about -// database proxy servers within a specific namespace. -type DatabaseServersGetter interface { - // GetDatabaseServers returns all registered database proxy servers. - GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) +// DatabaseServerWatcher defines an interface for watching database servers in a cluster. +type DatabaseServerWatcher interface { + // CurrentResourcesWithFilter returns the current list of database servers in the cluster that match the provided filter function. + CurrentResourcesWithFilter(ctx context.Context, filter func(readonly.DatabaseServer) bool) ([]types.DatabaseServer, error) } // GetDatabaseServersParams contains the parameters required to retrieve @@ -55,30 +53,24 @@ type GetDatabaseServersParams struct { Logger *slog.Logger // ClusterName is the cluster name to which the database belongs. ClusterName string - // DatabaseServersGetter used to fetch the list of database servers. - DatabaseServersGetter DatabaseServersGetter + // Watcher is used to retrieve database servers registered in the cluster. + Watcher DatabaseServerWatcher // Identity contains the identity information. Identity tlsca.Identity } // GetDatabaseServers returns a list of database servers in a cluster that match -// the routing information from the provided identity. +// the routing information from the provided identity. It uses the cluster's +// DatabaseServerWatcher for fast in-memory lookup. func GetDatabaseServers(ctx context.Context, params GetDatabaseServersParams) ([]types.DatabaseServer, error) { - servers, err := params.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) + result, err := params.Watcher.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool { + return ds.GetDatabaseName() == params.Identity.RouteToDatabase.ServiceName + }) if err != nil { return nil, trace.Wrap(err) } - // Find out which database servers proxy the database a user is - // connecting to using routing information from identity. - var result []types.DatabaseServer - for _, server := range servers { - if server.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName { - result = append(result, server) - } - } - - params.Logger.DebugContext(ctx, "Available database servers", + params.Logger.DebugContext(ctx, "Retrieved database servers from watcher", "cluster", params.ClusterName, "servers", logutils.StringerSliceAttr(result), ) diff --git a/lib/srv/db/common/connect/connect_bench_test.go b/lib/srv/db/common/connect/connect_bench_test.go new file mode 100644 index 0000000000000..aec1428f79dec --- /dev/null +++ b/lib/srv/db/common/connect/connect_bench_test.go @@ -0,0 +1,119 @@ +// Teleport +// Copyright (C) 2026 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connect + +import ( + "fmt" + "log/slog" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/tlsca" +) + +func createBenchmarkDatabaseServers(b *testing.B, total int, targetName string) []types.DatabaseServer { + b.Helper() + + servers := make([]types.DatabaseServer, 0, total) + for i := range total { + dbName := fmt.Sprintf("db-%d", i) + if i == total-1 { + dbName = targetName + } + + server, err := types.NewDatabaseServerV3(types.Metadata{ + Name: fmt.Sprintf("db-server-%d", i), + }, types.DatabaseServerSpecV3{ + Hostname: "localhost", + HostID: fmt.Sprintf("host-%d", i), + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: dbName, + }, + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolPostgreSQL, + URI: "localhost", + }, + }, + }) + require.NoError(b, err) + servers = append(servers, server) + } + + return servers +} + +// BenchmarkConnectGetDatabaseServers measures the memory usage of GetDatabaseServers +// with varying numbers of database servers in the cluster. +func BenchmarkConnectGetDatabaseServers(b *testing.B) { + const ( + matchCount = 1 + clusterName = "cluster-1" + targetName = "db-target" + total = 1000 + ) + + b.Run(fmt.Sprintf("total=%d", total), func(sb *testing.B) { + sb.ReportAllocs() + + servers := createBenchmarkDatabaseServers(sb, total, targetName) + + backend, err := memory.New(memory.Config{Context: sb.Context()}) + require.NoError(sb, err) + + presenceService := local.NewPresenceService(backend) + for _, server := range servers { + _, err = presenceService.UpsertDatabaseServer(sb.Context(), server) + require.NoError(sb, err) + } + + watcher, err := services.NewDatabaseServerWatcher(sb.Context(), services.DatabaseServerWatcherConfig{ + DatabaseServersGetter: presenceService, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "bench", + MaxRetryPeriod: 200 * time.Millisecond, + Client: local.NewEventsService(backend), + }, + }) + require.NoError(sb, err) + sb.Cleanup(watcher.Close) + + require.NoError(sb, watcher.WaitInitialization()) + + params := GetDatabaseServersParams{ + Logger: slog.New(slog.DiscardHandler), + ClusterName: clusterName, + Watcher: watcher, + Identity: tlsca.Identity{ + RouteToDatabase: tlsca.RouteToDatabase{ServiceName: targetName}, + RouteToCluster: clusterName, + }, + } + + for sb.Loop() { + result, err := GetDatabaseServers(b.Context(), params) + require.NoError(sb, err) + require.Len(sb, result, matchCount) + } + }) +} diff --git a/lib/srv/db/common/connect/connect_test.go b/lib/srv/db/common/connect/connect_test.go index b5fdbe916db49..df4becdaa0f6c 100644 --- a/lib/srv/db/common/connect/connect_test.go +++ b/lib/srv/db/common/connect/connect_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "net" "net/netip" + "os" "strings" "testing" "time" @@ -36,44 +37,57 @@ import ( "github.com/gravitational/teleport/lib/auth/authtest" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils/log/logtest" ) +func TestMain(m *testing.M) { + logtest.InitLogger(testing.Verbose) + os.Exit(m.Run()) +} + func TestGetDatabaseServers(t *testing.T) { + const clusterName = "root" + for name, tc := range map[string]struct { identity tlsca.Identity - getter *databaseServersMock + watcher *mockDatabaseServerWatcher expectErrorFunc require.ErrorAssertionFunc expectedServersLen int }{ "match": { identity: identityWithDatabase("matched-db", "root", "alice", nil), - getter: newDatabaseServersWithServers("no-match", "matched-db", "another-db"), + watcher: newMockWatcherWithServers("matched-db", "other-db"), expectErrorFunc: require.NoError, expectedServersLen: 1, }, + "multiple agents for same database": { + identity: identityWithDatabase("matched-db", "root", "alice", nil), + watcher: newMockWatcherWithServers("matched-db", "matched-db", "other-db"), + expectErrorFunc: require.NoError, + expectedServersLen: 2, + }, "no match": { identity: identityWithDatabase("no-match", "root", "alice", nil), - getter: newDatabaseServersWithServers("first", "second", "third"), + watcher: newMockWatcherWithServers("matched-db", "other-db"), expectErrorFunc: func(tt require.TestingT, err error, i ...any) { require.Error(t, err) require.True(t, trace.IsNotFound(err), "expected trace.NotFound error but got %T", err) }, }, - "get server error": { - identity: identityWithDatabase("no-match", "root", "alice", nil), - getter: newDatabaseServersWithErr(trace.Errorf("failure")), + "watcher error": { + identity: identityWithDatabase("matched-db", "root", "alice", nil), + watcher: newMockWatcherWithErr(trace.Errorf("failure")), expectErrorFunc: require.Error, }, } { t.Run(name, func(t *testing.T) { servers, err := GetDatabaseServers(context.Background(), GetDatabaseServersParams{ - Logger: logtest.NewLogger(), - ClusterName: "root", - DatabaseServersGetter: tc.getter, - Identity: tc.identity, + Logger: logtest.NewLogger(), + ClusterName: clusterName, + Watcher: tc.watcher, + Identity: tc.identity, }) tc.expectErrorFunc(t, err) require.Len(t, servers, tc.expectedServersLen) @@ -226,11 +240,26 @@ func identityWithDatabase(name, clusterName, user string, roles []string) tlsca. } } -type databaseServersMock struct { +// mockDatabaseServerWatcher implements DatabaseServerWatcher for tests. +type mockDatabaseServerWatcher struct { servers []types.DatabaseServer err error } +func (m *mockDatabaseServerWatcher) CurrentResourcesWithFilter(_ context.Context, filter func(readonly.DatabaseServer) bool) ([]types.DatabaseServer, error) { + if m.err != nil { + return nil, m.err + } + + var out []types.DatabaseServer + for _, s := range m.servers { + if filter(readonly.NewDatabaseServer(s)) { + out = append(out, s) + } + } + return out, nil +} + func databaseServerWithName(name, hostId string) types.DatabaseServer { return &types.DatabaseServerV3{ Spec: types.DatabaseServerSpecV3{ @@ -245,21 +274,17 @@ func databaseServerWithName(name, hostId string) types.DatabaseServer { } } -func newDatabaseServersWithServers(dbNames ...string) *databaseServersMock { +func newMockWatcherWithServers(dbNames ...string) *mockDatabaseServerWatcher { var servers []types.DatabaseServer for _, name := range dbNames { servers = append(servers, databaseServerWithName(name, uuid.New().String())) } - return &databaseServersMock{servers: servers} -} - -func newDatabaseServersWithErr(err error) *databaseServersMock { - return &databaseServersMock{err: err} + return &mockDatabaseServerWatcher{servers: servers} } -func (d *databaseServersMock) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) { - return d.servers, d.err +func newMockWatcherWithErr(err error) *mockDatabaseServerWatcher { + return &mockDatabaseServerWatcher{err: err} } func newDialerMock(t *testing.T, authServer *auth.Server, dbName string, availableServers []string, unavailableServers []string) *dialerMock { diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 1f26c82a254ee..bc0ffba2361f9 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -522,15 +522,15 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para if err != nil { return nil, trace.Wrap(err) } - accessPoint, err := cluster.CachingAccessPoint() + watcher, err := cluster.DatabaseServerWatcher() if err != nil { return nil, trace.Wrap(err) } servers, err := connect.GetDatabaseServers(ctx, connect.GetDatabaseServersParams{ - Logger: s.log, - ClusterName: cluster.GetName(), - DatabaseServersGetter: accessPoint, - Identity: identity, + Logger: s.log, + Watcher: watcher, + ClusterName: cluster.GetName(), + Identity: identity, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index b663efed6307a..7eb3af4d099ae 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1803,6 +1803,7 @@ func TestProxyRoundRobin(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { @@ -1947,6 +1948,7 @@ func TestProxyDirectAccess(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { @@ -2626,6 +2628,7 @@ func TestParseSubsystemRequest(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { return nil, errors.New("eice disabled in tests") @@ -2887,6 +2890,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { return nil, errors.New("eice disabled in tests") @@ -3264,6 +3268,21 @@ func newAppServerWatcher(ctx context.Context, t *testing.T, client *authclient.C return appServerWatcher } +func newDatabaseServerWatcher(ctx context.Context, t *testing.T, client *authclient.Client) *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer] { + t.Helper() + + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + }) + + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + return databaseServerWatcher +} + // newSigner creates a new SSH signer that can be used by the Server. func newSigner(t testing.TB, ctx context.Context, testServer *authtest.Server) ssh.Signer { t.Helper() @@ -3336,6 +3355,7 @@ func TestHostUserCreationProxy(t *testing.T) { NodeWatcher: nodeWatcher, GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), AppServerWatcher: newAppServerWatcher(ctx, t, proxyClient), + DatabaseServerWatcher: newDatabaseServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 92d0647a51673..84835eec61300 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -466,6 +466,15 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) t.Cleanup(appServerWatcher.Close) + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + }) + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -483,6 +492,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { NodeWatcher: proxyNodeWatcher, GitServerWatcher: proxyGitServerWatcher, AppServerWatcher: appServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{s.server.TLS.Listener.Addr().String()}, @@ -8607,6 +8617,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(appServerWatcher.Close) + databaseServerWatcher, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(databaseServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -8625,6 +8644,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula GitServerWatcher: proxyGitServerWatcher, CertAuthorityWatcher: proxyCAWatcher, AppServerWatcher: appServerWatcher, + DatabaseServerWatcher: databaseServerWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{authServer.Listener.Addr().String()}, EICESigner: func(ctx context.Context, target types.Server, integration types.Integration, login, token string, ap cryptosuites.AuthPreferenceGetter) (ssh.Signer, error) {