Skip to content
8 changes: 8 additions & 0 deletions lib/reversetunnel/leaf_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ type leafCluster struct {
// appServerWatcher is a app server watcher.
appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer]

// databaseServerWatcher is a database server watcher.
databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer]

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
// state has been changed, the tunnel will reconnect to re-create the client
Expand Down Expand Up @@ -182,6 +185,11 @@ func (s *leafCluster) GitServerWatcher() (*services.GenericWatcher[types.Server,
return nil, trace.NotImplemented("GitServerWatcher not implemented for leafCluster")
}

// DatabaseServerWatcher returns the Database server watcher for the leaf cluster.
func (s *leafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
return s.databaseServerWatcher, nil
}

func (s *leafCluster) GetClient() (authclient.ClientI, error) {
return s.leafClient, nil
}
Expand Down
5 changes: 5 additions & 0 deletions lib/reversetunnel/local_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ func (s *localCluster) GitServerWatcher() (*services.GenericWatcher[types.Server
return s.srv.GitServerWatcher, nil
}

// DatabaseServerWatcher returns a Database server watcher for this cluster.
func (s *localCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
return s.srv.DatabaseServerWatcher, nil
}

// GetClient returns a client to the full Auth Server API.
func (s *localCluster) GetClient() (authclient.ClientI, error) {
return s.client, nil
Expand Down
13 changes: 13 additions & 0 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ func (p *expectedLeafClusters) GitServerWatcher() (*services.GenericWatcher[type
return cluster.GitServerWatcher()
}

// DatabaseServerWatcher returns a watcher for database servers in the leaf cluster.
func (p *expectedLeafClusters) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
cluster, err := p.pickCluster()
if err != nil {
return nil, trace.Wrap(err)
}
return cluster.DatabaseServerWatcher()
}

func (p *expectedLeafClusters) GetClient() (authclient.ClientI, error) {
cluster, err := p.pickCluster()
if err != nil {
Expand Down Expand Up @@ -227,6 +236,10 @@ func (s *expectedLeafCluster) GitServerWatcher() (*services.GenericWatcher[types
return nil, s.discoveryError("unable to fetch git server watcher for leaf cluster")
}

func (s *expectedLeafCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
return nil, s.discoveryError("unable to fetch database server watcher for leaf cluster")
}

func (s *expectedLeafCluster) GetClient() (authclient.ClientI, error) {
return nil, s.discoveryError("unable to fetch auth client for leaf cluster")
}
Expand Down
18 changes: 18 additions & 0 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ type Config struct {
// AppServerWatcher is a app server watcher.
AppServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer]

// DatabaseServerWatcher is a database server watcher.
DatabaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer]

// CircuitBreakerConfig configures the auth client circuit breaker
CircuitBreakerConfig breaker.Config

Expand Down Expand Up @@ -305,6 +308,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.AppServerWatcher == nil {
return trace.BadParameter("missing parameter AppServerWatcher")
}
if cfg.DatabaseServerWatcher == nil {
return trace.BadParameter("missing parameter DatabaseServerWatcher")
}

if cfg.EICEDialer == nil {
return trace.BadParameter("missing parameter EICEDialer")
Expand Down Expand Up @@ -1298,6 +1304,18 @@ func newLeafCluster(srv *server, domainName string, sconn ssh.Conn) (*leafCluste
}
leaf.appServerWatcher = appServerWatcher

databaseServerWatcher, err := services.NewDatabaseServerWatcher(closeContext, services.DatabaseServerWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: srv.Component,
Logger: srv.Logger,
Client: accessPoint,
},
})
if err != nil {
return nil, trace.Wrap(err)
}
leaf.databaseServerWatcher = databaseServerWatcher

// instantiate a cache of host certificates for the forwarding server. the
// certificate cache is created in each cluster (instead of creating it in
// reversetunnel.server and passing it along) so that the host certificate
Expand Down
2 changes: 2 additions & 0 deletions lib/reversetunnelclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ type Cluster interface {
GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error)
// AppServerWatcher returns the watcher that maintains the app server set for the cluster
AppServerWatcher() (*services.GenericWatcher[types.AppServer, readonly.AppServer], error)
// DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster
DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error)
// GetTunnelsCount returns the amount of active inbound tunnels
// from the remote cluster
GetTunnelsCount() int
Expand Down
23 changes: 19 additions & 4 deletions lib/reversetunnelclient/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type FakeCluster struct {
closed bool
// appServerWatcher ia a app server watcher to speed up app look up.
appServerWatcher *services.GenericWatcher[types.AppServer, readonly.AppServer]
// databaseServerWatcher is a database server watcher to speed up database server look up.
databaseServerWatcher *services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer]
}

// NewFakeCluster is a FakeCluster constructor.
Expand All @@ -84,11 +86,19 @@ func NewFakeCluster(clusterName string, accessPoint authclient.RemoteProxyAccess
},
})

databaseServerWatcher, _ := services.NewDatabaseServerWatcher(context.TODO(), services.DatabaseServerWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "FakeCluster",
Client: accessPoint,
},
})

return &FakeCluster{
Name: clusterName,
connCh: make(chan net.Conn),
AccessPoint: accessPoint,
appServerWatcher: appServerWatcher,
Name: clusterName,
connCh: make(chan net.Conn),
AccessPoint: accessPoint,
appServerWatcher: appServerWatcher,
databaseServerWatcher: databaseServerWatcher,
}
}

Expand All @@ -97,6 +107,11 @@ func (s *FakeCluster) AppServerWatcher() (*services.GenericWatcher[types.AppServ
return s.appServerWatcher, nil
}

// DatabaseServerWatcher returns the watcher that maintains the database server set for the cluster
func (s *FakeCluster) DatabaseServerWatcher() (*services.GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
return s.databaseServerWatcher, nil
}

// CachingAccessPoint returns caching auth server client.
func (s *FakeCluster) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) {
return s.AccessPoint, nil
Expand Down
13 changes: 13 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4974,6 +4974,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

databaseServerWatcher, err := services.NewDatabaseServerWatcher(process.ExitContext(), services.DatabaseServerWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentProxy),
Client: accessPoint,
},
Comment thread
wethreetrees marked this conversation as resolved.
DatabaseServersGetter: accessPoint,
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := process.ServerTLSConfig(conn)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -5246,6 +5258,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
NodeWatcher: nodeWatcher,
AppServerWatcher: appServerWatcher,
GitServerWatcher: gitServerWatcher,
DatabaseServerWatcher: databaseServerWatcher,
CertAuthorityWatcher: caWatcher,
CircuitBreakerConfig: process.Config.CircuitBreakerConfig,
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServerAddresses()),
Expand Down
3 changes: 2 additions & 1 deletion lib/services/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ type NodesGetter interface {

// DatabaseServersGetter is a service that gets database servers.
type DatabaseServersGetter interface {
GetDatabaseServers(context.Context, string, ...MarshalOption) ([]types.DatabaseServer, error)
// GetDatabaseServers returns all registered database proxy servers.
GetDatabaseServers(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.DatabaseServer, error)
}

// AppServersGetter is a service that gets application servers.
Expand Down
22 changes: 22 additions & 0 deletions lib/services/readonly/readonly.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,28 @@ type AppServer interface {

var _ AppServer = types.AppServer(nil)

// DatabaseServer is a read only variant of [types.DatabaseServer]
type DatabaseServer struct {
inner types.DatabaseServer
}

// GetDatabaseName returns the name of the database this server is proxying.
func (d DatabaseServer) GetDatabaseName() string {
if d.inner == nil {
return ""
}
db := d.inner.GetDatabase()
if db == nil {
return ""
}
return db.GetName()
}

// NewDatabaseServer returns a new read-only DatabaseServer.
func NewDatabaseServer(server types.DatabaseServer) DatabaseServer {
return DatabaseServer{inner: server}
}

// KubeServer is a read only variant of [types.KubeServer].
type KubeServer interface {
// ResourceWithLabels provides common resource methods.
Expand Down
54 changes: 54 additions & 0 deletions lib/services/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,60 @@ func NewAppServersWatcher(ctx context.Context, cfg AppServersWatcherConfig) (*Ge
return w, trace.Wrap(err)
}

type DatabaseServerWatcherConfig struct {
DatabaseServersGetter
ResourceWatcherConfig
}

// CheckAndSetDefaults checks parameters and sets default values.
func (cfg *DatabaseServerWatcherConfig) CheckAndSetDefaults() error {
if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}

if cfg.MaxStaleness == 0 {
const databaseServerMaxStaleness = time.Minute
Comment thread
wethreetrees marked this conversation as resolved.
cfg.MaxStaleness = databaseServerMaxStaleness
}

if cfg.DatabaseServersGetter == nil {
getter, ok := cfg.Client.(DatabaseServersGetter)
if !ok {
return trace.BadParameter("missing parameter DatabaseServersGetter and Client not usable as DatabaseServersGetter")
}
cfg.DatabaseServersGetter = getter
}

return nil
}

func NewDatabaseServerWatcher(ctx context.Context, cfg DatabaseServerWatcherConfig) (*GenericWatcher[types.DatabaseServer, readonly.DatabaseServer], error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.DatabaseServer, readonly.DatabaseServer]{
ResourceWatcherConfig: cfg.ResourceWatcherConfig,
ResourceKind: types.KindDatabaseServer,
ResourceKey: func(r types.DatabaseServer) string {
// the host ID is guaranteed not to contain "/"
return r.GetHostID() + "/" + r.GetName()
},
DeleteKey: func(r types.Resource) string {
// database servers put the host ID in the description in delete events
return r.GetMetadata().Description + "/" + r.GetName()
},
ResourceGetter: func(ctx context.Context) ([]types.DatabaseServer, error) {
return cfg.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace)
},
DisableUpdateBroadcast: true,
CloneFunc: types.DatabaseServer.Copy,
ReadOnlyFunc: readonly.NewDatabaseServer,
})

return w, trace.Wrap(err)
}

// KubeServerWatcherConfig is an KubeServerWatcher configuration.
type KubeServerWatcherConfig struct {
// KubernetesServerGetter is responsible for fetching kube_server resources.
Expand Down
103 changes: 103 additions & 0 deletions lib/services/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,109 @@ func newApp(t *testing.T, name string) *types.AppV3 {
return app
}

// TestDatabaseServerWatcher tests that database server resource watcher properly
// receives and dispatches updates.
func TestDatabaseServerWatcher(t *testing.T) {
synctest.Test(t, syncTestDatabaseServerWatcher)
}

func syncTestDatabaseServerWatcher(t *testing.T) {
ctx := t.Context()

bk, err := memory.New(memory.Config{Context: ctx})
require.NoError(t, err)

type client struct {
services.DatabaseServersGetter
types.Events
}

presenceService := local.NewPresenceService(bk)
w, err := services.NewDatabaseServerWatcher(ctx, services.DatabaseServerWatcherConfig{
DatabaseServersGetter: presenceService,
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
MaxRetryPeriod: 200 * time.Millisecond,
Client: &client{
DatabaseServersGetter: presenceService,
Events: local.NewEventsService(bk),
},
},
})
require.NoError(t, err)
t.Cleanup(w.Close)

// Wait for initial load.
require.NoError(t, w.WaitInitialization())

// Initially there are no database servers.
servers, err := w.CurrentResources(ctx)
require.NoError(t, err)
require.Empty(t, servers)

// Add a database server and wait for the watcher to process the event.
server1 := newDatabaseServer(t, "db1", "host1")
_, err = presenceService.UpsertDatabaseServer(ctx, server1)
require.NoError(t, err)
synctest.Wait()

servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool {
return ds.GetDatabaseName() == "db1"
})
require.NoError(t, err)
require.Len(t, servers, 1)
require.Equal(t, server1.GetName(), servers[0].GetName())

// Add a second database server and wait for the watcher to process the event.
server2 := newDatabaseServer(t, "db2", "host2")
_, err = presenceService.UpsertDatabaseServer(ctx, server2)
require.NoError(t, err)
synctest.Wait()

servers, err = w.CurrentResources(ctx)
require.NoError(t, err)
require.Len(t, servers, 2)

servers, err = w.CurrentResourcesWithFilter(ctx, func(ds readonly.DatabaseServer) bool {
return ds.GetDatabaseName() == "db2"
})
require.NoError(t, err)
require.Len(t, servers, 1)
require.Equal(t, server2.GetName(), servers[0].GetName())

// Delete the first database server and wait for the watcher to process the event.
err = presenceService.DeleteDatabaseServer(ctx, apidefaults.Namespace, server1.GetHostID(), server1.GetName())
require.NoError(t, err)
synctest.Wait()

// Verify the remaining server is server2.
servers, err = w.CurrentResources(ctx)
require.NoError(t, err)
require.Len(t, servers, 1)
require.Equal(t, server2.GetName(), servers[0].GetName())
}

func newDatabaseServer(t *testing.T, dbName, hostID string) types.DatabaseServer {
t.Helper()
server, err := types.NewDatabaseServerV3(types.Metadata{
Name: dbName,
}, types.DatabaseServerSpecV3{
Database: &types.DatabaseV3{
Metadata: types.Metadata{
Name: dbName,
},
Spec: types.DatabaseSpecV3{
Protocol: defaults.ProtocolPostgres,
URI: "localhost:5432",
},
},
HostID: hostID,
Hostname: dbName,
})
require.NoError(t, err)
return server
}

func TestCertAuthorityWatcher(t *testing.T) {
t.Parallel()

Expand Down
Loading
Loading