From 64b81d0805d5bd4fe395bdb4f5d91c11d57bf2fe Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Tue, 17 May 2022 15:06:41 -0400 Subject: [PATCH] Improve CertAuthorityWatcher (#10403) * Improve CertAuthorityWatcher CertAuthorityWatcher and its usage are refactored to allow for all the following: - eliminate retransmission of the same CAs - reduce memory usage by having one local watcher per proxy - add the ability to filter only the CAs that are desired - reduce the time required to send the first CAs watchCertAuthorities now compares all CAs it receives from the watcher with the previous CA of the same type and only sends to the remote site if they are not identical. This is to reduce unnecessary network traffic, which can be problematic for a root cluster with a large number of leaf clusters. The CertAuthorityWatcher is refactored to leverage a fanout to emit events to any number of watchers; each subscription can be for a subset of the configured CA types. The proxy now has only one CertAuthorityWatcher that is passed around similarly to the LockWatcher. This reduces the memory usage for proxies, which prior to this had one local CAWatcher per remote site. updateCertAuthorities no longer waits on the utils.Retry it is provided with before starting to watch CAs. 
(cherry picked from commit 1ac0957d0e0094b75b3d70df311391b7c4a2b999) --- integration/integration_test.go | 60 ++++------ lib/reversetunnel/remotesite.go | 105 +++++++++-------- lib/reversetunnel/srv.go | 35 ++++-- lib/service/service.go | 36 ++++-- lib/services/watcher.go | 182 ++++++++++++++++++------------ lib/services/watcher_test.go | 158 ++++++++++++++------------ lib/srv/regular/sshserver_test.go | 19 ++++ lib/web/apiserver_test.go | 22 ++++ 8 files changed, 372 insertions(+), 245 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index fabe6300142ee..160b84c9934ab 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4100,44 +4100,27 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) // waitForPhase waits until aux cluster detects the rotation - waitForPhase := func(phase string) error { - ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10) - defer cancel() - - watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Clock: tconf.Clock, - Client: aux.GetSiteAPI(clusterAux), - }, - WatchHostCA: true, - }) - if err != nil { - return err - } - defer watcher.Close() + waitForPhase := func(phase string) { + require.Eventually(t, func() bool { + ca, err := aux.Process.GetAuthServer().GetCertAuthority( + ctx, + types.CertAuthID{ + Type: types.HostCA, + DomainName: clusterMain, + }, false) + if err != nil { + return false + } - var lastPhase string - for i := 0; i < 10; i++ { - select { - case <-ctx.Done(): - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) - case cas := <-watcher.CertAuthorityC: - for _, ca := range cas { - if ca.GetClusterName() == clusterMain && - ca.GetType() == types.HostCA && - ca.GetRotation().Phase == phase { - 
return nil - } - lastPhase = ca.GetRotation().Phase - } + if ca.GetRotation().Phase == phase { + return true } - } - return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase) + + return false + }, 30*time.Second, 250*time.Millisecond, "failed to converge to phase %q", phase) } - err = waitForPhase(types.RotationPhaseInit) - require.NoError(t, err) + waitForPhase(types.RotationPhaseInit) // update clients err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -4150,8 +4133,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = suite.waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateClients) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateClients) // old client should work as is err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*") @@ -4170,8 +4152,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { svc, err = suite.waitForReload(serviceC, svc) require.NoError(t, err) - err = waitForPhase(types.RotationPhaseUpdateServers) - require.NoError(t, err) + waitForPhase(types.RotationPhaseUpdateServers) // new credentials will work from this phase to others newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username}) @@ -4199,8 +4180,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) { require.NoError(t, err) t.Log("Service reload completed, waiting for phase.") - err = waitForPhase(types.RotationPhaseStandby) - require.NoError(t, err) + waitForPhase(types.RotationPhaseStandby) t.Log("Phase completed.") // new client still works diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 524f830315463..d20745554b04f 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca 
types.CertAuthority) error { return trace.CompareFailed("remote certificate authority rotation has been updated") } -func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { - s.Debugf("Watching for cert authority changes.") +func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) { + defer remoteWatcher.Close() + cas := make(map[types.CertAuthType]types.CertAuthority) for { - startedWaiting := s.clock.Now() - select { - case t := <-retry.After(): - s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) - retry.Inc() - case <-s.ctx.Done(): - return - } - - err := s.watchCertAuthorities() + err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas) if err != nil { switch { case trace.IsNotFound(err): @@ -456,67 +448,92 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) { } } + startedWaiting := s.clock.Now() + select { + case t := <-retry.After(): + s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting)) + retry.Inc() + case <-s.ctx.Done(): + return + } } } -func (s *remoteSite) watchCertAuthorities() error { - localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.localAccessPoint, +func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error { + localWatch, err := s.srv.CertAuthorityWatcher.Subscribe( + s.ctx, + services.CertAuthorityTarget{ + Type: types.HostCA, + ClusterName: s.srv.ClusterName, }, - WatchUserCA: true, - WatchHostCA: true, - }) + services.CertAuthorityTarget{ + Type: types.UserCA, + ClusterName: s.srv.ClusterName, + }) if err != nil { return trace.Wrap(err) } - defer localWatcher.Close() + defer func() { + 
if err := localWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close local ca watcher subscription.") + } + }() - remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Log: s, - Clock: s.clock, - Client: s.remoteAccessPoint, + remoteWatch, err := remoteWatcher.Subscribe( + s.ctx, + services.CertAuthorityTarget{ + ClusterName: s.domainName, + Type: types.HostCA, }, - WatchHostCA: true, - }) + ) if err != nil { return trace.Wrap(err) } - defer remoteWatcher.Close() + defer func() { + if err := remoteWatch.Close(); err != nil { + s.WithError(err).Warn("Failed to close remote ca watcher subscription.") + } + }() + s.Debugf("Watching for cert authority changes.") for { select { case <-s.ctx.Done(): s.WithError(s.ctx.Err()).Debug("Context is closing.") return trace.Wrap(s.ctx.Err()) - case <-localWatcher.Done(): + case <-localWatch.Done(): s.Warn("Local CertAuthority watcher subscription has closed") return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName) - case <-remoteWatcher.Done(): + case <-remoteWatch.Done(): s.Warn("Remote CertAuthority watcher subscription has closed") return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName) - case cas := <-localWatcher.CertAuthorityC: - for _, localCA := range cas { - if localCA.GetClusterName() != s.srv.ClusterName || - (localCA.GetType() != types.HostCA && - localCA.GetType() != types.UserCA) { + case evt := <-localWatch.Events(): + switch evt.Type { + case types.OpPut: + localCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } + ca, ok := cas[localCA.GetType()] + if ok && services.CertAuthoritiesEquivalent(ca, localCA) { + continue + } + + // clone to prevent a race with watcher filtering + localCA = localCA.Clone() if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil { - 
s.WithError(err).Warn("Failed to rotate external ca") + log.WithError(err).Warn("Failed to rotate external ca") return trace.Wrap(err) } + + cas[localCA.GetType()] = localCA } - case cas := <-remoteWatcher.CertAuthorityC: - for _, remoteCA := range cas { - if remoteCA.GetType() != types.HostCA || - remoteCA.GetClusterName() != s.domainName { + case evt := <-remoteWatch.Events(): + switch evt.Type { + case types.OpPut: + remoteCA, ok := evt.Resource.(types.CertAuthority) + if !ok { continue } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 4d2ac27c28853..c1ded366ca279 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -205,6 +205,9 @@ type Config struct { // NodeWatcher is a node watcher. NodeWatcher *services.NodeWatcher + + // CertAuthorityWatcher is a cert authority watcher. + CertAuthorityWatcher *services.CertAuthorityWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.NodeWatcher == nil { return trace.BadParameter("missing parameter NodeWatcher") } + if cfg.CertAuthorityWatcher == nil { + return trace.BadParameter("missing parameter CertAuthorityWatcher") + } return nil } @@ -1040,6 +1046,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold)) closeContext, cancel := context.WithCancel(srv.ctx) + defer func() { + if err != nil { + cancel() + } + }() remoteSite := &remoteSite{ srv: srv, domainName: domainName, @@ -1063,20 +1074,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, clt, _, err := remoteSite.getRemoteClient() if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.remoteClient = clt remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { - cancel() return nil, trace.Wrap(err) } accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, 
domainName) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.remoteAccessPoint = accessPoint @@ -1088,7 +1096,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, }, }) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.nodeWatcher = nodeWatcher @@ -1098,7 +1105,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, // is signed by the correct certificate authority. certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) if err != nil { - cancel() return nil, trace.Wrap(err) } remoteSite.certificateCache = certificateCache @@ -1111,11 +1117,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - cancel() return nil, trace.Wrap(err) } - go remoteSite.updateCertAuthorities(caRetry) + remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: srv.log, + Clock: srv.Clock, + Client: remoteSite.remoteAccessPoint, + }, + Types: []types.CertAuthType{types.HostCA}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + go func() { + remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion) + }() lockRetry, err := utils.NewLinear(utils.LinearConfig{ First: utils.HalfJitter(srv.Config.PollingPeriod), @@ -1125,7 +1145,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - cancel() return nil, trace.Wrap(err) } diff --git a/lib/service/service.go b/lib/service/service.go index fca3ea37c42e3..46a647555fe67 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2874,6 +2874,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + caWatcher, err := 
services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: conn.Client, + }, + AuthorityGetter: accessPoint, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites) if err != nil { return trace.Wrap(err) @@ -2903,17 +2916,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Client: conn.Client, }, }, - KeyGen: cfg.Keygen, - Ciphers: cfg.Ciphers, - KEXAlgorithms: cfg.KEXAlgorithms, - MACAlgorithms: cfg.MACAlgorithms, - DataDir: process.Config.DataDir, - PollingPeriod: process.Config.PollingPeriod, - FIPS: cfg.FIPS, - Emitter: streamEmitter, - Log: process.log, - LockWatcher: lockWatcher, - NodeWatcher: nodeWatcher, + KeyGen: cfg.Keygen, + Ciphers: cfg.Ciphers, + KEXAlgorithms: cfg.KEXAlgorithms, + MACAlgorithms: cfg.MACAlgorithms, + DataDir: process.Config.DataDir, + PollingPeriod: process.Config.PollingPeriod, + FIPS: cfg.FIPS, + Emitter: streamEmitter, + Log: process.log, + LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/watcher.go b/lib/services/watcher.go index f035c8a5ab6e0..43cdd58cd58b2 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -919,12 +919,8 @@ type CertAuthorityWatcherConfig struct { ResourceWatcherConfig // AuthorityGetter is responsible for fetching cert authority resources. AuthorityGetter - // CertAuthorityC receives up-to-date list of all cert authority resources. 
- CertAuthorityC chan []types.CertAuthority - // WatchHostCA indicates that the watcher should monitor types.HostCA - WatchHostCA bool - // WatchUserCA indicates that the watcher should monitor types.UserCA - WatchUserCA bool + // Types restricts which cert authority types are retrieved via the AuthorityGetter. + Types []types.CertAuthType } // CheckAndSetDefaults checks parameters and sets default values. @@ -939,12 +935,19 @@ func (cfg *CertAuthorityWatcherConfig) CheckAndSetDefaults() error { } cfg.AuthorityGetter = getter } - if cfg.CertAuthorityC == nil { - cfg.CertAuthorityC = make(chan []types.CertAuthority) - } return nil } +// IsWatched return true if the given certificate auth type is being observer by the watcher. +func (cfg *CertAuthorityWatcherConfig) IsWatched(certType types.CertAuthType) bool { + for _, observedType := range cfg.Types { + if observedType == certType { + return true + } + } + return false +} + // NewCertAuthorityWatcher returns a new instance of CertAuthorityWatcher. func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig) (*CertAuthorityWatcher, error) { if err := cfg.CheckAndSetDefaults(); err != nil { @@ -953,6 +956,12 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig collector := &caCollector{ CertAuthorityWatcherConfig: cfg, + fanout: NewFanout(), + cas: make(map[types.CertAuthType]map[string]types.CertAuthority, len(cfg.Types)), + } + + for _, t := range cfg.Types { + collector.cas[t] = make(map[string]types.CertAuthority) } watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) @@ -960,6 +969,7 @@ func NewCertAuthorityWatcher(ctx context.Context, cfg CertAuthorityWatcherConfig return nil, trace.Wrap(err) } + collector.fanout.SetInit() return &CertAuthorityWatcher{watcher, collector}, nil } @@ -972,9 +982,66 @@ type CertAuthorityWatcher struct { // caCollector accompanies resourceWatcher when monitoring cert authority resources. 
type caCollector struct { CertAuthorityWatcherConfig - host map[string]types.CertAuthority - user map[string]types.CertAuthority + fanout *Fanout + + // lock protects concurrent access to cas lock sync.RWMutex + // cas maps ca type -> cluster -> ca + cas map[types.CertAuthType]map[string]types.CertAuthority +} + +// CertAuthorityTarget lists the attributes of interactions to be disabled. +type CertAuthorityTarget struct { + // ClusterName specifies the name of the cluster to watch. + ClusterName string + // Type specifies the ca types to watch for. + Type types.CertAuthType +} + +// Subscribe is used to subscribe to the lock updates. +func (c *caCollector) Subscribe(ctx context.Context, targets ...CertAuthorityTarget) (types.Watcher, error) { + watchKinds, err := caTargetToWatchKinds(targets) + if err != nil { + return nil, trace.Wrap(err) + } + sub, err := c.fanout.NewWatcher(ctx, types.Watch{Kinds: watchKinds}) + if err != nil { + return nil, trace.Wrap(err) + } + select { + case event := <-sub.Events(): + if event.Type != types.OpInit { + return nil, trace.BadParameter("expected init event, got %v instead", event.Type) + } + case <-sub.Done(): + return nil, trace.Wrap(sub.Error()) + } + return sub, nil +} + +func caTargetToWatchKinds(targets []CertAuthorityTarget) ([]types.WatchKind, error) { + watchKinds := make([]types.WatchKind, 0, len(targets)) + for _, target := range targets { + kind := types.WatchKind{ + Kind: types.KindCertAuthority, + // Note that watching SubKind doesn't work for types.WatchKind - to do so it would + // require a custom filter, which was recently added but - we can't use yet due to + // older clients not supporting the filter. 
+ SubKind: string(target.Type), + } + + if target.ClusterName != "" { + kind.Name = target.ClusterName + } + + watchKinds = append(watchKinds, kind) + } + + if len(watchKinds) == 0 { + watchKinds = []types.WatchKind{{Kind: types.KindCertAuthority}} + } + + return watchKinds, nil } // resourceKind specifies the resource kind to watch. @@ -984,42 +1051,27 @@ func (c *caCollector) resourceKind() string { // getResourcesAndUpdateCurrent refreshes the list of current resources. func (c *caCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - var ( - newHost map[string]types.CertAuthority - newUser map[string]types.CertAuthority - ) + var cas []types.CertAuthority - if c.WatchHostCA { - host, err := c.AuthorityGetter.GetCertAuthorities(ctx, types.HostCA, false) + for _, t := range c.Types { + authorities, err := c.AuthorityGetter.GetCertAuthorities(ctx, t, false) if err != nil { return trace.Wrap(err) } - newHost = make(map[string]types.CertAuthority, len(host)) - for _, ca := range host { - newHost[ca.GetName()] = ca - } - } - if c.WatchUserCA { - user, err := c.AuthorityGetter.GetCertAuthorities(ctx, types.UserCA, false) - if err != nil { - return trace.Wrap(err) - } - newUser = make(map[string]types.CertAuthority, len(user)) - for _, ca := range user { - newUser[ca.GetName()] = ca - } + cas = append(cas, authorities...) 
} c.lock.Lock() - c.host = newHost - c.user = newUser - c.lock.Unlock() + defer c.lock.Unlock() - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case c.CertAuthorityC <- casToSlice(newHost, newUser): + for _, ca := range cas { + if !c.watchingType(ca.GetType()) { + continue + } + + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(types.Event{Type: types.OpPut, Resource: ca.Clone()}) } return nil } @@ -1034,17 +1086,13 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty defer c.lock.Unlock() switch event.Type { case types.OpDelete: - if c.WatchHostCA && event.Resource.GetSubKind() == string(types.HostCA) { - delete(c.host, event.Resource.GetName()) - } - if c.WatchUserCA && event.Resource.GetSubKind() == string(types.UserCA) { - delete(c.user, event.Resource.GetName()) + caType := types.CertAuthType(event.Resource.GetSubKind()) + if !c.watchingType(caType) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- casToSlice(c.host, c.user): - } + delete(c.cas[caType], event.Resource.GetName()) + c.fanout.Emit(event) case types.OpPut: ca, ok := event.Resource.(types.CertAuthority) if !ok { @@ -1052,43 +1100,35 @@ func (c *caCollector) processEventAndUpdateCurrent(ctx context.Context, event ty return } - if c.WatchHostCA && ca.GetType() == types.HostCA { - c.host[ca.GetName()] = ca - } - if c.WatchUserCA && ca.GetType() == types.UserCA { - c.user[ca.GetName()] = ca + if !c.watchingType(ca.GetType()) { + return } - select { - case <-ctx.Done(): - case c.CertAuthorityC <- casToSlice(c.host, c.user): + authority, ok := c.cas[ca.GetType()][ca.GetName()] + if ok && CertAuthoritiesEquivalent(authority, ca) { + return } + + c.cas[ca.GetType()][ca.GetName()] = ca + c.fanout.Emit(event) default: c.Log.Warnf("Unsupported event type %s.", event.Type) return } } -// GetCurrent returns the currently stored authorities. 
-func (c *caCollector) GetCurrent() []types.CertAuthority { - c.lock.RLock() - defer c.lock.RUnlock() - return casToSlice(c.host, c.user) +func (c *caCollector) watchingType(t types.CertAuthType) bool { + for _, caType := range c.Types { + if caType == t { + return true + } + } + + return false } func (c *caCollector) notifyStale() {} -func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertAuthority) []types.CertAuthority { - slice := make([]types.CertAuthority, 0, len(host)+len(user)) - for _, ca := range host { - slice = append(slice, ca) - } - for _, ca := range user { - slice = append(slice, ca) - } - return slice -} - // NodeWatcherConfig is a NodeWatcher configuration. type NodeWatcherConfig struct { ResourceWatcherConfig diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 6521e65aa30d9..a16b69913b714 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" @@ -520,9 +521,10 @@ func resourceDiff(res1, res2 types.Resource) string { func caDiff(ca1, ca2 types.CertAuthority) string { return cmp.Diff(ca1, ca2, cmpopts.IgnoreFields(types.Metadata{}, "ID"), - cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs"), + cmpopts.IgnoreFields(types.CertAuthoritySpecV2{}, "CheckingKeys", "TLSKeyPairs", "JWTKeyPairs"), cmpopts.IgnoreFields(types.SSHKeyPair{}, "PrivateKey"), cmpopts.IgnoreFields(types.TLSKeyPair{}, "Key"), + cmpopts.IgnoreFields(types.JWTKeyPair{}, "PrivateKey"), cmpopts.EquateEmpty(), ) } @@ -723,10 +725,12 @@ func newApp(t *testing.T, name string) types.Application { func 
TestCertAuthorityWatcher(t *testing.T) { t.Parallel() ctx := context.Background() + clock := clockwork.NewFakeClock() bk, err := lite.NewWithConfig(ctx, lite.Config{ Path: t.TempDir(), PollStreamPeriod: 200 * time.Millisecond, + Clock: clock, }) require.NoError(t, err) @@ -744,86 +748,88 @@ func TestCertAuthorityWatcher(t *testing.T) { Trust: caService, Events: local.NewEventsService(bk), }, + Clock: clock, }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - WatchUserCA: true, - WatchHostCA: true, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, }) require.NoError(t, err) t.Cleanup(w.Close) - nothingWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - MaxRetryPeriod: 200 * time.Millisecond, - Client: &client{ - Trust: caService, - Events: local.NewEventsService(bk), - }, - }, - CertAuthorityC: make(chan []types.CertAuthority, 10), - }) + target := services.CertAuthorityTarget{ClusterName: "test"} + sub, err := w.Subscribe(ctx, target) require.NoError(t, err) - t.Cleanup(nothingWatcher.Close) - - require.Empty(t, w.GetCurrent()) - require.Empty(t, nothingWatcher.GetCurrent()) + t.Cleanup(func() { require.NoError(t, sub.Close()) }) - // Initially there are no cas so watcher should send an empty list. 
+ // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca := newCertAuthority(t, "test", types.HostCA) + require.NoError(t, caService.UpsertCertAuthority(ca)) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 0) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Add an authority. - ca1 := newCertAuthority(t, "ca1", types.HostCA) - require.NoError(t, caService.CreateCertAuthority(ca1)) - - // The first event is always the current list of apps. + // create a CA with a type we are filtering for another cluster that we are NOT filtering for + // and ensure that we DO NOT receive the event + require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.UserCA))) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca1)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the first event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } - // Add a second ca. 
- ca2 := newCertAuthority(t, "ca2", types.UserCA) - require.NoError(t, caService.CreateCertAuthority(ca2)) + // create a CA for the cluster and a type we are filtering for + // and ensure we receive the event + ca2 := newCertAuthority(t, "test", types.UserCA) + require.NoError(t, caService.UpsertCertAuthority(ca2)) + select { + case event := <-sub.Events(): + caFromEvent, ok := event.Resource.(types.CertAuthority) + require.True(t, ok) + require.Empty(t, caDiff(ca2, caFromEvent)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") + } - // Watcher should detect the ca list change. + // delete a CA with type being watched in the cluster we are filtering for + // and ensure we receive the event + require.NoError(t, caService.DeleteCertAuthority(ca.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 2) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + require.Equal(t, types.KindCertAuthority, event.Resource.GetKind()) + require.Equal(t, string(types.HostCA), event.Resource.GetSubKind()) + require.Equal(t, "test", event.Resource.GetName()) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") } - // Delete the first ca. - require.NoError(t, caService.DeleteCertAuthority(ca1.GetID())) + // create a CA with a type we are NOT filtering for but for a cluster we are filtering for + // and ensure we DO NOT receive the event + signer := newCertAuthority(t, "test", types.JWTSigner) + require.NoError(t, caService.UpsertCertAuthority(signer)) + select { + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): + } - // Watcher should detect the ca list change. 
+ // delete a CA with a name we are filtering for but a type we are NOT filtering for + // and ensure we do NOT receive the event + require.NoError(t, caService.DeleteCertAuthority(signer.GetID())) select { - case changeset := <-w.CertAuthorityC: - require.Len(t, changeset, 1) - require.Empty(t, caDiff(changeset[0], ca2)) - require.Empty(t, nothingWatcher.GetCurrent()) - case <-w.Done(): - t.Fatal("Watcher has unexpectedly exited.") - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for the update event.") + case event := <-sub.Events(): + t.Fatalf("Unexpected event: %v.", event) + case <-sub.Done(): + t.Fatal("CA watcher subscription has unexpectedly exited.") + case <-time.After(time.Second): } } @@ -840,15 +846,25 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type Type: caType, ClusterName: name, ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{ - PrivateKey: priv, - PrivateKeyType: types.PrivateKeyType_RAW, - PublicKey: pub, - }}, - TLS: []*types.TLSKeyPair{{ - Cert: cert, - Key: key, - }}, + SSH: []*types.SSHKeyPair{ + { + PrivateKey: priv, + PrivateKeyType: types.PrivateKeyType_RAW, + PublicKey: pub, + }, + }, + TLS: []*types.TLSKeyPair{ + { + Cert: cert, + Key: key, + }, + }, + JWT: []*types.JWTKeyPair{ + { + PublicKey: []byte(fixtures.JWTSignerPublicKey), + PrivateKey: []byte(fixtures.JWTSignerPrivateKey), + }, + }, }, Roles: nil, SigningAlg: types.CertAuthoritySpecV2_RSA_SHA2_256, diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 33c50fbce4cf7..80a17b747070e 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1126,6 +1126,7 @@ func TestProxyRoundRobin(t *testing.T) { defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ 
ClusterName: f.testSrv.ClusterName(), @@ -1143,6 +1144,7 @@ func TestProxyRoundRobin(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1252,6 +1254,7 @@ func TestProxyDirectAccess(t *testing.T) { proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1269,6 +1272,7 @@ func TestProxyDirectAccess(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -1863,6 +1867,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + caWatcher := newCertAuthorityWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1880,6 +1885,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { Log: logger, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -2098,6 +2104,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return nodeWatcher } +func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher { + caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + 
require.NoError(t, err) + t.Cleanup(caWatcher.Close) + return caWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. // diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index ec1411dacb591..9b4e1f93dc386 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -275,6 +275,16 @@ func newWebSuite(t *testing.T) *WebSuite { }) require.NoError(t, err) + caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + defer caWatcher.Close() + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -289,6 +299,7 @@ func newWebSuite(t *testing.T) *WebSuite { DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) s.proxyTunnel = revTunServer @@ -3922,6 +3933,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, @@ -3945,6 +3966,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: 
proxyNodeWatcher, + CertAuthorityWatcher: proxyCAWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) })