Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[branch/v9] Improve CertAuthorityWatcher (#10403) #12724

Merged
merged 3 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 20 additions & 40 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4100,44 +4100,27 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// waitForPhase waits until aux cluster detects the rotation
waitForPhase := func(phase string) error {
ctx, cancel := context.WithTimeout(context.Background(), tconf.PollingPeriod*10)
defer cancel()

watcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Clock: tconf.Clock,
Client: aux.GetSiteAPI(clusterAux),
},
WatchHostCA: true,
})
if err != nil {
return err
}
defer watcher.Close()
waitForPhase := func(phase string) {
require.Eventually(t, func() bool {
ca, err := aux.Process.GetAuthServer().GetCertAuthority(
ctx,
types.CertAuthID{
Type: types.HostCA,
DomainName: clusterMain,
}, false)
if err != nil {
return false
}

var lastPhase string
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)
case cas := <-watcher.CertAuthorityC:
for _, ca := range cas {
if ca.GetClusterName() == clusterMain &&
ca.GetType() == types.HostCA &&
ca.GetRotation().Phase == phase {
return nil
}
lastPhase = ca.GetRotation().Phase
}
if ca.GetRotation().Phase == phase {
return true
}
}
return trace.CompareFailed("failed to converge to phase %q, last phase %q", phase, lastPhase)

return false
}, 30*time.Second, 250*time.Millisecond, "failed to converge to phase %q", phase)
}

err = waitForPhase(types.RotationPhaseInit)
require.NoError(t, err)
waitForPhase(types.RotationPhaseInit)

// update clients
err = svc.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{
Expand All @@ -4150,8 +4133,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = suite.waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateClients)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateClients)

// old client should work as is
err = runAndMatch(clt, 8, []string{"echo", "hello world"}, ".*hello world.*")
Expand All @@ -4170,8 +4152,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
svc, err = suite.waitForReload(serviceC, svc)
require.NoError(t, err)

err = waitForPhase(types.RotationPhaseUpdateServers)
require.NoError(t, err)
waitForPhase(types.RotationPhaseUpdateServers)

// new credentials will work from this phase to others
newCreds, err := GenerateUserCreds(UserCredsRequest{Process: svc, Username: suite.me.Username})
Expand Down Expand Up @@ -4199,8 +4180,7 @@ func testRotateTrustedClusters(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)
t.Log("Service reload completed, waiting for phase.")

err = waitForPhase(types.RotationPhaseStandby)
require.NoError(t, err)
waitForPhase(types.RotationPhaseStandby)
t.Log("Phase completed.")

// new client still works
Expand Down
105 changes: 61 additions & 44 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,12 @@ func (s *remoteSite) compareAndSwapCertAuthority(ca types.CertAuthority) error {
return trace.CompareFailed("remote certificate authority rotation has been updated")
}

func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
s.Debugf("Watching for cert authority changes.")
func (s *remoteSite) updateCertAuthorities(retry utils.Retry, remoteWatcher *services.CertAuthorityWatcher, remoteVersion string) {
defer remoteWatcher.Close()

cas := make(map[types.CertAuthType]types.CertAuthority)
for {
startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}

err := s.watchCertAuthorities()
err := s.watchCertAuthorities(remoteWatcher, remoteVersion, cas)
if err != nil {
switch {
case trace.IsNotFound(err):
Expand All @@ -456,67 +448,92 @@ func (s *remoteSite) updateCertAuthorities(retry utils.Retry) {
}
}

startedWaiting := s.clock.Now()
select {
case t := <-retry.After():
s.Debugf("Initiating new cert authority watch after waiting %v.", t.Sub(startedWaiting))
retry.Inc()
case <-s.ctx.Done():
return
}
}
}

func (s *remoteSite) watchCertAuthorities() error {
localWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.localAccessPoint,
func (s *remoteSite) watchCertAuthorities(remoteWatcher *services.CertAuthorityWatcher, remoteVersion string, cas map[types.CertAuthType]types.CertAuthority) error {
localWatch, err := s.srv.CertAuthorityWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
Type: types.HostCA,
ClusterName: s.srv.ClusterName,
},
WatchUserCA: true,
WatchHostCA: true,
})
services.CertAuthorityTarget{
Type: types.UserCA,
ClusterName: s.srv.ClusterName,
})
if err != nil {
return trace.Wrap(err)
}
defer localWatcher.Close()
defer func() {
if err := localWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close local ca watcher subscription.")
}
}()

remoteWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: s,
Clock: s.clock,
Client: s.remoteAccessPoint,
remoteWatch, err := remoteWatcher.Subscribe(
s.ctx,
services.CertAuthorityTarget{
ClusterName: s.domainName,
Type: types.HostCA,
},
WatchHostCA: true,
})
)
if err != nil {
return trace.Wrap(err)
}
defer remoteWatcher.Close()
defer func() {
if err := remoteWatch.Close(); err != nil {
s.WithError(err).Warn("Failed to close remote ca watcher subscription.")
}
}()

s.Debugf("Watching for cert authority changes.")
for {
select {
case <-s.ctx.Done():
s.WithError(s.ctx.Err()).Debug("Context is closing.")
return trace.Wrap(s.ctx.Err())
case <-localWatcher.Done():
case <-localWatch.Done():
s.Warn("Local CertAuthority watcher subscription has closed")
return fmt.Errorf("local ca watcher for cluster %s has closed", s.srv.ClusterName)
case <-remoteWatcher.Done():
case <-remoteWatch.Done():
s.Warn("Remote CertAuthority watcher subscription has closed")
return fmt.Errorf("remote ca watcher for cluster %s has closed", s.domainName)
case cas := <-localWatcher.CertAuthorityC:
for _, localCA := range cas {
if localCA.GetClusterName() != s.srv.ClusterName ||
(localCA.GetType() != types.HostCA &&
localCA.GetType() != types.UserCA) {
case evt := <-localWatch.Events():
switch evt.Type {
case types.OpPut:
localCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

ca, ok := cas[localCA.GetType()]
if ok && services.CertAuthoritiesEquivalent(ca, localCA) {
continue
}

// clone to prevent a race with watcher filtering
localCA = localCA.Clone()
if err := s.remoteClient.RotateExternalCertAuthority(s.ctx, localCA); err != nil {
s.WithError(err).Warn("Failed to rotate external ca")
log.WithError(err).Warn("Failed to rotate external ca")
return trace.Wrap(err)
}

cas[localCA.GetType()] = localCA
}
case cas := <-remoteWatcher.CertAuthorityC:
for _, remoteCA := range cas {
if remoteCA.GetType() != types.HostCA ||
remoteCA.GetClusterName() != s.domainName {
case evt := <-remoteWatch.Events():
switch evt.Type {
case types.OpPut:
remoteCA, ok := evt.Resource.(types.CertAuthority)
if !ok {
continue
}

Expand Down
35 changes: 27 additions & 8 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ type Config struct {

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher

// CertAuthorityWatcher is a cert authority watcher.
CertAuthorityWatcher *services.CertAuthorityWatcher
}

// CheckAndSetDefaults checks parameters and sets default values
Expand Down Expand Up @@ -259,6 +262,9 @@ func (cfg *Config) CheckAndSetDefaults() error {
if cfg.NodeWatcher == nil {
return trace.BadParameter("missing parameter NodeWatcher")
}
if cfg.CertAuthorityWatcher == nil {
return trace.BadParameter("missing parameter CertAuthorityWatcher")
}
return nil
}

Expand Down Expand Up @@ -1040,6 +1046,11 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
connInfo.SetExpiry(srv.Clock.Now().Add(srv.offlineThreshold))

closeContext, cancel := context.WithCancel(srv.ctx)
defer func() {
if err != nil {
cancel()
}
}()
remoteSite := &remoteSite{
srv: srv,
domainName: domainName,
Expand All @@ -1063,20 +1074,17 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,

clt, _, err := remoteSite.getRemoteClient()
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteClient = clt

remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.remoteAccessPoint = accessPoint
Expand All @@ -1088,7 +1096,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
},
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.nodeWatcher = nodeWatcher
Expand All @@ -1098,7 +1105,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
// is signed by the correct certificate authority.
certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient)
if err != nil {
cancel()
return nil, trace.Wrap(err)
}
remoteSite.certificateCache = certificateCache
Expand All @@ -1111,11 +1117,25 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

go remoteSite.updateCertAuthorities(caRetry)
remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: srv.log,
Clock: srv.Clock,
Client: remoteSite.remoteAccessPoint,
},
Types: []types.CertAuthType{types.HostCA},
})
if err != nil {
return nil, trace.Wrap(err)
}

go func() {
remoteSite.updateCertAuthorities(caRetry, remoteWatcher, remoteVersion)
}()

lockRetry, err := utils.NewLinear(utils.LinearConfig{
First: utils.HalfJitter(srv.Config.PollingPeriod),
Expand All @@ -1125,7 +1145,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
Clock: srv.Clock,
})
if err != nil {
cancel()
return nil, trace.Wrap(err)
}

Expand Down
36 changes: 25 additions & 11 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentProxy,
Log: process.log.WithField(trace.Component, teleport.ComponentProxy),
Client: conn.Client,
},
AuthorityGetter: accessPoint,
Types: []types.CertAuthType{types.HostCA, types.UserCA},
})
if err != nil {
return trace.Wrap(err)
}

serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2903,17 +2916,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
Client: conn.Client,
},
},
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
if err != nil {
return trace.Wrap(err)
Expand Down
Loading