diff --git a/lib/services/local/externalauditstorage_watcher.go b/lib/services/local/externalauditstorage_watcher.go index 5cf898cfa6d84..af44537fe6ce9 100644 --- a/lib/services/local/externalauditstorage_watcher.go +++ b/lib/services/local/externalauditstorage_watcher.go @@ -63,15 +63,15 @@ func (cfg *ClusterExternalAuditStorageWatcherConfig) CheckAndSetDefaults() error // ClusterExternalAuditWatcher is a light weight backend watcher for the cluster external audit resource. type ClusterExternalAuditWatcher struct { - backend backend.Backend - log logrus.FieldLogger - clock clockwork.Clock - onChange func() - retry retryutils.Retry - initialized chan struct{} - closed chan struct{} - closeOnce sync.Once - done chan struct{} + backend backend.Backend + log logrus.FieldLogger + clock clockwork.Clock + onChange func() + retry retryutils.Retry + running chan struct{} + closed chan struct{} + closeOnce sync.Once + done chan struct{} } // NewClusterExternalAuditWatcher creates a new cluster external audit resource watcher. @@ -93,14 +93,14 @@ func NewClusterExternalAuditWatcher(ctx context.Context, cfg ClusterExternalAudi } w := &ClusterExternalAuditWatcher{ - backend: cfg.Backend, - log: cfg.Log, - clock: cfg.Clock, - onChange: cfg.OnChange, - retry: retry, - initialized: make(chan struct{}), - closed: make(chan struct{}), - done: make(chan struct{}), + backend: cfg.Backend, + log: cfg.Log, + clock: cfg.Clock, + onChange: cfg.OnChange, + retry: retry, + running: make(chan struct{}), + closed: make(chan struct{}), + done: make(chan struct{}), } go w.runWatchLoop(ctx) @@ -111,13 +111,13 @@ func NewClusterExternalAuditWatcher(ctx context.Context, cfg ClusterExternalAudi // WaitInit waits for the watch loop to initialize. func (w *ClusterExternalAuditWatcher) WaitInit(ctx context.Context) error { select { - case <-w.initialized: + case <-w.running: + return nil case <-w.done: - return errors.New("watcher closed") + return trace.Errorf("watcher closed") case <-ctx.Done(): return trace.Wrap(ctx.Err()) } - return nil } // close stops the watcher and waits for the watch loop to exit @@ -155,8 +155,9 @@ func (w *ClusterExternalAuditWatcher) watch(ctx context.Context) error { case <-watcher.Events(): w.log.Infof("Detected change to cluster ExternalAuditStorage config") w.onChange() + case w.running <- struct{}{}: case <-watcher.Done(): - return errors.New("watcher closed") + return trace.Errorf("watcher closed") case <-ctx.Done(): return ctx.Err() case <-w.closed: @@ -185,7 +186,6 @@ func (w *ClusterExternalAuditWatcher) newWatcher(ctx context.Context) (backend.W if event.Type != types.OpInit { return nil, trace.BadParameter("expected init event, got %v instead", event.Type) } - close(w.initialized) } w.retry.Reset() diff --git a/lib/services/local/externalauditstorage_watcher_test.go b/lib/services/local/externalauditstorage_watcher_test.go index 5360229b4f414..00552dc564f3d 100644 --- a/lib/services/local/externalauditstorage_watcher_test.go +++ b/lib/services/local/externalauditstorage_watcher_test.go @@ -17,10 +17,14 @@ package local import ( "context" "testing" + "time" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/defaults" ) func TestClusterExternalAuditWatcher(t *testing.T) { @@ -34,7 +38,6 @@ func TestClusterExternalAuditWatcher(t *testing.T) { require.NoError(t, err) svc := NewExternalAuditStorageService(bk) - require.NotNil(t, svc) ch := make(chan string) @@ -120,3 +123,80 @@ func TestClusterExternalAuditWatcher(t *testing.T) { }) } } + +// TestClusterExternalAuditWatcher_WatcherClosed tests that the +// ExternalAuditWatcher can recover from the underlying backend watcher closing. +func TestClusterExternalAuditWatcher_WatcherClosed(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + bk, err := memory.New(memory.Config{ + Context: ctx, + }) + require.NoError(t, err) + + svc := NewExternalAuditStorageService(bk) + + interceptor := &watcherInterceptor{ + Backend: bk, + watchers: make(chan backend.Watcher, 1), + } + + changes := make(chan struct{}) + clock := clockwork.NewFakeClock() + + auditWatcher, err := NewClusterExternalAuditWatcher(ctx, ClusterExternalAuditStorageWatcherConfig{ + Backend: interceptor, + OnChange: func() { + changes <- struct{}{} + }, + Clock: clock, + }) + require.NoError(t, err) + + require.NoError(t, auditWatcher.WaitInit(ctx)) + + // Sanity test a change is detected + _, err = svc.GenerateDraftExternalAuditStorage(ctx, "test-integration", "us-west-2") + require.NoError(t, err) + err = svc.PromoteToClusterExternalAuditStorage(ctx) + require.NoError(t, err) + select { + case <-changes: + case <-time.After(5 * time.Second): + t.Fatal("watcher failed to detect change") + } + + // Close the backend watcher and make sure the audit watcher recovers + w := <-interceptor.watchers + w.Close() + clock.BlockUntil(1) + clock.Advance(defaults.LowResPollingPeriod) + require.NoError(t, auditWatcher.WaitInit(ctx)) + + // It should still detect changes + err = svc.DisableClusterExternalAuditStorage(ctx) + require.NoError(t, err) + select { + case <-changes: + case <-time.After(5 * time.Second): + t.Fatal("watcher failed to detect change") + } +} + +// watcherInterceptor wraps a backend.Backend and writes all backend watchers +// returned from NewWatcher to a channel. +type watcherInterceptor struct { + backend.Backend + watchers chan backend.Watcher +} + +func (i *watcherInterceptor) NewWatcher(ctx context.Context, watch backend.Watch) (backend.Watcher, error) { + w, err := i.Backend.NewWatcher(ctx, watch) + if err != nil { + return nil, err + } + i.watchers <- w + return w, nil +}