From 233464e48006615bd9d59ac1efd3f2a7213484a4 Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 21 Mar 2023 18:38:37 -0700 Subject: [PATCH 1/4] Fix race condition in test by using a helper function instead of complex channel mechanisms. --- lib/services/local/headlessauthn_watcher.go | 18 +++++++++++ .../local/headlessauthn_watcher_test.go | 31 +++++++++---------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/lib/services/local/headlessauthn_watcher.go b/lib/services/local/headlessauthn_watcher.go index 52a0e841b9d58..18717f28db473 100644 --- a/lib/services/local/headlessauthn_watcher.go +++ b/lib/services/local/headlessauthn_watcher.go @@ -224,6 +224,24 @@ func (h *HeadlessAuthenticationWatcher) CheckWaiter(name string) bool { return false } +// CheckWaiterStale checks if the active waiter with the given +// headless authentication ID is marked as stale. Used in tests. +func (h *HeadlessAuthenticationWatcher) CheckWaiterStale(name string) (bool, error) { + h.mux.Lock() + defer h.mux.Unlock() + for i := range h.waiters { + if h.waiters[i].name == name { + select { + case <-h.waiters[i].stale: + return true, nil + default: + return false, nil + } + } + } + return false, trace.NotFound("no waiter found with ID %v", name) +} + // Wait watches for the headless authentication with the given id to be added/updated // in the backend, and waits for the given condition to be met, to result in an error, // or for the given context to close. diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index bb4346b419782..6f6d6486f5336 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -114,40 +114,39 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) defer waitCancel() - // Create two waiters - a blocked consumer and a free consumer. - blockWait := make(chan struct{}) - _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - <-blockWait - return false, nil - }) - + // Create a waiter that we can block/unblock. notifyReceived := make(chan struct{}, 1) - waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + blockWaiter := make(chan struct{}) + _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { notifyReceived <- struct{}{} + <-blockWaiter return false, nil }) - // Create stub and wait for it to be caught by the free waiter. + // Create stub and wait for it to be caught by the waiter. stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) <-notifyReceived + // perform a put to mark the blocked waiter as stale and replace := *stub replace.PublicKey = []byte(sshPubKey) replace.User = "user" - - // perform a put to mark the blocked waiter as stale and - // wait for it to be caught by the free waiter. _, err = identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) require.NoError(t, err) - <-notifyReceived - // delete the headless authentication and unblock. + require.Eventually(t, func() bool { + ok, err := w.CheckWaiterStale(pubUUID) + require.NoError(t, err) + return ok + }, time.Second, time.Millisecond, "Expected waiter to be marked as stale") + + // delete the headless authentication. err = identity.DeleteHeadlessAuthentication(ctx, pubUUID) require.NoError(t, err) - close(blockWait) - // the blocked waiter should perform a stale check and return a not found error. + // unblock the waiter. It should perform a stale check and return a not found error. + close(blockWaiter) err = <-errC require.True(t, trace.IsNotFound(err), "Expected a not found error from Wait but got %v", err) }) From 47e7cab7b757e9c8624997c6c9b56b43ffc54691 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 22 Mar 2023 13:25:04 -0700 Subject: [PATCH 2/4] Avoid creating new methods solely for testing; resolve other comments. --- api/types/headlessauthn.go | 17 + lib/services/local/headlessauthn.go | 19 +- lib/services/local/headlessauthn_watcher.go | 41 +-- .../local/headlessauthn_watcher_test.go | 310 +++++++++++++----- 4 files changed, 265 insertions(+), 122 deletions(-) diff --git a/api/types/headlessauthn.go b/api/types/headlessauthn.go index b49d3d637f248..60cd458f8a9ae 100644 --- a/api/types/headlessauthn.go +++ b/api/types/headlessauthn.go @@ -17,9 +17,26 @@ limitations under the License. package types import ( + "time" + "github.com/gravitational/trace" ) +// NewHeadlessAuthenticationStub creates a new headless authentication stub, which is +// a headless authentication resource with limited data. This stub is used to initiate +// headless login. +func NewHeadlessAuthenticationStub(name string, expires time.Time) (*HeadlessAuthentication, error) { + ha := &HeadlessAuthentication{ + ResourceHeader: ResourceHeader{ + Metadata: Metadata{ + Name: name, + Expires: &expires, + }, + }, + } + return ha, ha.CheckAndSetDefaults() +} + // CheckAndSetDefaults does basic validation and default setting. func (h *HeadlessAuthentication) CheckAndSetDefaults() error { h.setStaticFields() diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index c8b0a70bb3cf4..9537761020aa7 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -31,17 +31,12 @@ import ( // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - expires := s.Clock().Now().Add(defaults.CallbackTimeout) - headlessAuthn := &types.HeadlessAuthentication{ - ResourceHeader: types.ResourceHeader{ - Metadata: types.Metadata{ - Name: name, - Expires: &expires, - }, - }, + headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, s.Clock().Now().Add(defaults.CallbackTimeout)) + if err != nil { + return nil, trace.Wrap(err) } - item, err := marshalHeadlessAuthenticationToItem(headlessAuthn) + item, err := MarshalHeadlessAuthenticationToItem(headlessAuthn) if err != nil { return nil, trace.Wrap(err) } @@ -60,12 +55,12 @@ func (s *IdentityService) CompareAndSwapHeadlessAuthentication(ctx context.Conte return nil, trace.Wrap(err) } - oldItem, err := marshalHeadlessAuthenticationToItem(old) + oldItem, err := MarshalHeadlessAuthenticationToItem(old) if err != nil { return nil, trace.Wrap(err) } - newItem, err := marshalHeadlessAuthenticationToItem(new) + newItem, err := MarshalHeadlessAuthenticationToItem(new) if err != nil { return nil, trace.Wrap(err) } @@ -120,7 +115,7 @@ func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name return trace.Wrap(err) } -func marshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) { +func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) { if err := headlessAuthn.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/local/headlessauthn_watcher.go b/lib/services/local/headlessauthn_watcher.go index 18717f28db473..7ff88be174c75 100644 --- a/lib/services/local/headlessauthn_watcher.go +++ b/lib/services/local/headlessauthn_watcher.go @@ -48,6 +48,11 @@ var watcherClosedErr = trace.Errorf("headless authentication watcher closed") type HeadlessAuthenticationWatcherConfig struct { // Backend is the storage backend used to create watchers. Backend backend.Backend + // WatcherService is a service used to create new watchers. + // If nil, Backend will be used as the watcher service. + WatcherService interface { + NewWatcher(ctx context.Context, watch backend.Watch) (backend.Watcher, error) + } // Log is a logger. Log logrus.FieldLogger // Clock is used to control time. @@ -61,6 +66,9 @@ func (cfg *HeadlessAuthenticationWatcherConfig) CheckAndSetDefaults() error { if cfg.Backend == nil { return trace.BadParameter("missing parameter Backend") } + if cfg.WatcherService == nil { + cfg.WatcherService = cfg.Backend + } if cfg.Log == nil { cfg.Log = logrus.StandardLogger() cfg.Log.WithField("resource-kind", types.KindHeadlessAuthentication) @@ -145,7 +153,7 @@ func (h *HeadlessAuthenticationWatcher) runWatchLoop(ctx context.Context) { } func (h *HeadlessAuthenticationWatcher) watch(ctx context.Context) error { - watcher, err := h.Backend.NewWatcher(ctx, backend.Watch{ + watcher, err := h.WatcherService.NewWatcher(ctx, backend.Watch{ Name: types.KindHeadlessAuthentication, MetricComponent: types.KindHeadlessAuthentication, Prefixes: [][]byte{headlessAuthenticationKey("")}, @@ -211,37 +219,6 @@ func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.Headless } } -// CheckWaiter checks if there is an active waiter matching the given -// headless authentication ID. Used in tests. -func (h *HeadlessAuthenticationWatcher) CheckWaiter(name string) bool { - h.mux.Lock() - defer h.mux.Unlock() - for i := range h.waiters { - if h.waiters[i].name == name { - return true - } - } - return false -} - -// CheckWaiterStale checks if the active waiter with the given -// headless authentication ID is marked as stale. Used in tests. -func (h *HeadlessAuthenticationWatcher) CheckWaiterStale(name string) (bool, error) { - h.mux.Lock() - defer h.mux.Unlock() - for i := range h.waiters { - if h.waiters[i].name == name { - select { - case <-h.waiters[i].stale: - return true, nil - default: - return false, nil - } - } - } - return false, trace.NotFound("no waiter found with ID %v", name) -} - // Wait watches for the headless authentication with the given id to be added/updated // in the backend, and waits for the given condition to be met, to result in an error, // or for the given context to close. diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index 6f6d6486f5336..6af822c91e861 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -26,162 +26,316 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" ) func TestHeadlessAuthenticationWatcher(t *testing.T) { + t.Parallel() ctx := context.Background() + pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - t.Parallel() - identity := newIdentityService(t, clockwork.NewFakeClock()) + type testSuite struct { + watcher *local.HeadlessAuthenticationWatcher + watcherClock clockwork.FakeClock + watcherCancel context.CancelFunc + identity *local.IdentityService + buf *backend.CircularBuffer + } - watcherCtx, watcherCancel := context.WithCancel(ctx) - defer watcherCancel() + newSuite := func(t *testing.T) *testSuite { + identity := newIdentityService(t, clockwork.NewFakeClock()) - watcherClock := clockwork.NewFakeClock() - w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ - Clock: watcherClock, - Backend: identity.Backend, - }) - require.NoError(t, err) + // use a standalone buffer as a watcher service. + buf := backend.NewCircularBuffer() + buf.SetInit() - pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + watcherCtx, watcherCancel := context.WithCancel(ctx) + t.Cleanup(func() { + watcherCancel() + }) + + watcherClock := clockwork.NewFakeClock() + w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ + Clock: watcherClock, + WatcherService: buf, + Backend: identity.Backend, + }) + require.NoError(t, err) - waitInGoroutine := func(ctx context.Context, t *testing.T, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (chan *types.HeadlessAuthentication, chan error) { + return &testSuite{ + watcher: w, + watcherClock: watcherClock, + watcherCancel: watcherCancel, + identity: identity, + buf: buf, + } + } + + waitInGoroutine := func(ctx context.Context, t *testing.T, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (chan *types.HeadlessAuthentication, chan error) { headlessAuthnCh := make(chan *types.HeadlessAuthentication, 1) errC := make(chan error, 1) go func() { - headlessAuthn, err := w.Wait(ctx, name, cond) + headlessAuthn, err := watcher.Wait(ctx, name, cond) errC <- err headlessAuthnCh <- headlessAuthn }() - require.Eventually(t, func() bool { return w.CheckWaiter(name) }, time.Millisecond*100, time.Millisecond*10) - return headlessAuthnCh, errC } - t.Run("WaitTimeout", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Millisecond*10) + t.Run("WaitEventWithConditionMet", func(t *testing.T) { + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() - _, err = w.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - require.Error(t, err) - require.Equal(t, waitCtx.Err(), err) + firstEmitReceived := make(chan struct{}) + headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEmitReceived: + default: + close(firstEmitReceived) + } + + return ha.User != "", nil + }) + + // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + require.NoError(t, err) + stub.User = "user" + item, err := local.MarshalHeadlessAuthenticationToItem(stub) + require.NoError(t, err) + + // The waiter may miss put events during initialization, so we continuously emit them until one is caught. + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + select { + case <-firstEmitReceived: + return true + default: + return false + } + }, time.Second, time.Millisecond*100) + + require.NoError(t, <-errC) + require.Equal(t, stub, <-headlessAuthnCh) }) - t.Run("WaitCreateStub", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) + t.Run("WaitEventWithConditionUnmet", func(t *testing.T) { + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil + firstEmitReceived := make(chan struct{}) + headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEmitReceived: + default: + close(firstEmitReceived) + } + + return ha.User != "", nil }) - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + require.NoError(t, err) + item, err := local.MarshalHeadlessAuthenticationToItem(stub) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) - require.NoError(t, <-errC) - require.Equal(t, stub, <-headlessAuthnCh) + // The waiter may miss put events during initialization, so we continuously emit them until one is caught. + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + select { + case <-firstEmitReceived: + return true + default: + return false + } + }, time.Second, time.Millisecond*100) + + // Ensure that the waiter did not finish with the condition unmet. + select { + case err := <-errC: + t.Errorf("Expected waiter to continue but instead the waiter returned with err: %v", err) + default: + waitCancel() + } + + require.Error(t, <-errC) + require.Nil(t, <-headlessAuthnCh) }) - t.Run("WaitCompareAndSwap", func(t *testing.T) { - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + t.Run("WaitBackend", func(t *testing.T) { + t.Parallel() + s := newSuite(t) + + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return ha.State == types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, nil + // Wait should immediately check the backend and return the existing headless authentication stub. + headlessAuthn, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return true, nil }) - replace := *stub - replace.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED - replace.PublicKey = []byte(sshPubKey) - replace.User = "user" - - swapped, err := identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) require.NoError(t, err) + require.Equal(t, stub, headlessAuthn) + }) - require.NoError(t, <-errC) - require.Equal(t, swapped, <-headlessAuthnCh) + t.Run("WaitTimeout", func(t *testing.T) { + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Millisecond*10) + defer waitCancel() + + _, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + require.Error(t, err) + require.Equal(t, waitCtx.Err(), err) }) t.Run("StaleCheck", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() // Create a waiter that we can block/unblock. - notifyReceived := make(chan struct{}, 1) + firstEmitReceived := make(chan struct{}) blockWaiter := make(chan struct{}) - _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - notifyReceived <- struct{}{} - <-blockWaiter + _, blockedWaiterErrC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEmitReceived: + default: + close(firstEmitReceived) + // we only block the first event received, incase additional events get to this waiter. + <-blockWaiter + } return false, nil }) - // Create stub and wait for it to be caught by the waiter. - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // Emit stub put event and wait for it to be caught by the waiter. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - <-notifyReceived - - // perform a put to mark the blocked waiter as stale and - replace := *stub - replace.PublicKey = []byte(sshPubKey) - replace.User = "user" - _, err = identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) + item, err := local.MarshalHeadlessAuthenticationToItem(stub) require.NoError(t, err) require.Eventually(t, func() bool { - ok, err := w.CheckWaiterStale(pubUUID) - require.NoError(t, err) - return ok - }, time.Second, time.Millisecond, "Expected waiter to be marked as stale") + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + select { + case <-firstEmitReceived: + return true + default: + return false + } + }, time.Second, time.Millisecond*100) + + // Create a second waiter to catch a second put event. + _, freeWaiterErrC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return true, nil + }) - // delete the headless authentication. - err = identity.DeleteHeadlessAuthentication(ctx, pubUUID) - require.NoError(t, err) + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + select { + case <-freeWaiterErrC: + return true + default: + return false + } + }, time.Second, time.Millisecond*100) // unblock the waiter. It should perform a stale check and return a not found error. close(blockWaiter) - err = <-errC + err = <-blockedWaiterErrC require.True(t, trace.IsNotFound(err), "Expected a not found error from Wait but got %v", err) }) t.Run("WatchReset", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil + firstEmitReceived := make(chan struct{}) + headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEmitReceived: + default: + close(firstEmitReceived) + } + + return services.ValidateHeadlessAuthentication(ha) == nil, nil }) + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + require.NoError(t, err) + + item, err := local.MarshalHeadlessAuthenticationToItem(stub) + require.NoError(t, err) + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + select { + case <-firstEmitReceived: + return true + default: + return false + } + }, time.Second, time.Millisecond*100) + // closed watchers should be handled gracefully and reset. - identity.Backend.CloseWatchers() - watcherClock.BlockUntil(1) + s.buf.Clear() + s.watcherClock.BlockUntil(1) - // The watcher should notify waiters of missed events. - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // The watcher should notify waiters of backend state on watcher reset. + replace := *stub + replace.PublicKey = []byte(sshPubKey) + replace.User = "user" + swapped, err := s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) - watcherClock.Advance(w.MaxRetryPeriod) + s.watcherClock.Advance(s.watcher.MaxRetryPeriod) require.NoError(t, <-errC) - require.Equal(t, stub, <-headlessAuthnCh) + require.Equal(t, swapped, <-headlessAuthnCh) }) t.Run("WatcherClosed", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) + t.Parallel() + s := newSuite(t) + + waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) defer waitCancel() - _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - watcherCancel() + s.watcherCancel() // waiters should be notified to close and result in ctx error waitErr := <-errC @@ -189,7 +343,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { require.Equal(t, waitErr.Error(), "headless authentication watcher closed") // New waiters should be prevented. - _, err = w.Wait(ctx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + _, err := s.watcher.Wait(ctx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) require.Error(t, err) }) } From baeb8851fcc114e720385bba28a3dedceaa2661a Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 23 Mar 2023 12:13:18 -0700 Subject: [PATCH 3/4] Reuse more code; resolve other comments. --- api/types/headlessauthn.go | 5 +- .../local/headlessauthn_watcher_test.go | 210 ++++++------------ 2 files changed, 74 insertions(+), 141 deletions(-) diff --git a/api/types/headlessauthn.go b/api/types/headlessauthn.go index 60cd458f8a9ae..c37a8cd50d28c 100644 --- a/api/types/headlessauthn.go +++ b/api/types/headlessauthn.go @@ -22,9 +22,8 @@ import ( "github.com/gravitational/trace" ) -// NewHeadlessAuthenticationStub creates a new headless authentication stub, which is -// a headless authentication resource with limited data. This stub is used to initiate -// headless login. +// NewHeadlessAuthenticationStub creates a new a headless authentication resource with limited data. +// The stub is used to initiate headless login. func NewHeadlessAuthenticationStub(name string, expires time.Time) (*HeadlessAuthentication, error) { ha := &HeadlessAuthentication{ ResourceHeader: ResourceHeader{ diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index 6af822c91e861..9b1d4ae83cfd4 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -36,7 +36,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { ctx := context.Background() pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - type testSuite struct { + type testEnv struct { watcher *local.HeadlessAuthenticationWatcher watcherClock clockwork.FakeClock watcherCancel context.CancelFunc @@ -44,7 +44,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { buf *backend.CircularBuffer } - newSuite := func(t *testing.T) *testSuite { + newTestEnv := func(t *testing.T) *testEnv { identity := newIdentityService(t, clockwork.NewFakeClock()) // use a standalone buffer as a watcher service. @@ -52,9 +52,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { buf.SetInit() watcherCtx, watcherCancel := context.WithCancel(ctx) - t.Cleanup(func() { - watcherCancel() - }) + t.Cleanup(watcherCancel) watcherClock := clockwork.NewFakeClock() w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ @@ -64,7 +62,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { }) require.NoError(t, err) - return &testSuite{ + return &testEnv{ watcher: w, watcherClock: watcherClock, watcherCancel: watcherCancel, @@ -73,55 +71,64 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { } } - waitInGoroutine := func(ctx context.Context, t *testing.T, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (chan *types.HeadlessAuthentication, chan error) { - headlessAuthnCh := make(chan *types.HeadlessAuthentication, 1) - errC := make(chan error, 1) + waitInGoroutine := func(ctx context.Context, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error), + ) (headlessAuthnC chan *types.HeadlessAuthentication, firstEventReceivedC chan struct{}, errC chan error) { + waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(waitCancel) + + headlessAuthnC = make(chan *types.HeadlessAuthentication, 1) + errC = make(chan error, 1) + firstEventReceivedC = make(chan struct{}) go func() { - headlessAuthn, err := watcher.Wait(ctx, name, cond) + headlessAuthn, err := watcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEventReceivedC: + default: + close(firstEventReceivedC) + } + return cond(ha) + }) errC <- err - headlessAuthnCh <- headlessAuthn + headlessAuthnC <- headlessAuthn }() - return headlessAuthnCh, errC + return headlessAuthnC, firstEventReceivedC, errC } - t.Run("WaitEventWithConditionMet", func(t *testing.T) { - t.Parallel() - s := newSuite(t) + // The waiter may miss put events during initialization, so we continuously emit them until one is caught. + // This can also be used to wait until a waiter is fully initialized. + waitForPutEvent := func(s *testEnv, ha *types.HeadlessAuthentication, firstEventReceivedC chan struct{}) { + item, err := local.MarshalHeadlessAuthenticationToItem(ha) + require.NoError(t, err) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) - firstEmitReceived := make(chan struct{}) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { select { - case <-firstEmitReceived: + case <-firstEventReceivedC: + return true default: - close(firstEmitReceived) + return false } + }, 2*time.Second, 100*time.Millisecond) + } + t.Run("WaitEventWithConditionMet", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) + + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return ha.User != "", nil }) - // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. + // Emit put event that passes the condition. stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) stub.User = "user" - item, err := local.MarshalHeadlessAuthenticationToItem(stub) - require.NoError(t, err) - // The waiter may miss put events during initialization, so we continuously emit them until one is caught. - require.Eventually(t, func() bool { - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - select { - case <-firstEmitReceived: - return true - default: - return false - } - }, time.Second, time.Millisecond*100) + waitForPutEvent(s, stub, firstEventReceivedC) require.NoError(t, <-errC) require.Equal(t, stub, <-headlessAuthnCh) @@ -129,41 +136,20 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("WaitEventWithConditionUnmet", func(t *testing.T) { t.Parallel() - s := newSuite(t) - - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() + s := newTestEnv(t) - firstEmitReceived := make(chan struct{}) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - select { - case <-firstEmitReceived: - default: - close(firstEmitReceived) - } + waitCtx, waitCancel := context.WithCancel(ctx) + t.Cleanup(waitCancel) + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(waitCtx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return ha.User != "", nil }) // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - item, err := local.MarshalHeadlessAuthenticationToItem(stub) - require.NoError(t, err) - // The waiter may miss put events during initialization, so we continuously emit them until one is caught. - require.Eventually(t, func() bool { - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - select { - case <-firstEmitReceived: - return true - default: - return false - } - }, time.Second, time.Millisecond*100) + waitForPutEvent(s, stub, firstEventReceivedC) // Ensure that the waiter did not finish with the condition unmet. select { @@ -179,13 +165,13 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("WaitBackend", func(t *testing.T) { t.Parallel() - s := newSuite(t) + s := newTestEnv(t) stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() + waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(waitCancel) // Wait should immediately check the backend and return the existing headless authentication stub. headlessAuthn, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { @@ -198,10 +184,10 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("WaitTimeout", func(t *testing.T) { t.Parallel() - s := newSuite(t) + s := newTestEnv(t) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Millisecond*10) - defer waitCancel() + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Millisecond) + t.Cleanup(waitCancel) _, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) require.Error(t, err) @@ -210,61 +196,35 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("StaleCheck", func(t *testing.T) { t.Parallel() - s := newSuite(t) - - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() - + s := newTestEnv(t) // Create a waiter that we can block/unblock. - firstEmitReceived := make(chan struct{}) blockWaiter := make(chan struct{}) - _, blockedWaiterErrC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + t.Cleanup(func() { select { - case <-firstEmitReceived: + case <-blockWaiter: default: - close(firstEmitReceived) - // we only block the first event received, incase additional events get to this waiter. - <-blockWaiter + close(blockWaiter) } + }) + + _, blockedWaiterEventReceived, blockedWaiterErrC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + <-blockWaiter return false, nil }) // Emit stub put event and wait for it to be caught by the waiter. stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - item, err := local.MarshalHeadlessAuthenticationToItem(stub) - require.NoError(t, err) - require.Eventually(t, func() bool { - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - select { - case <-firstEmitReceived: - return true - default: - return false - } - }, time.Second, time.Millisecond*100) + waitForPutEvent(s, stub, blockedWaiterEventReceived) // Create a second waiter to catch a second put event. - _, freeWaiterErrC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, freeWaiterEventReceivedC, freeWaiterErrC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - require.Eventually(t, func() bool { - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - select { - case <-freeWaiterErrC: - return true - default: - return false - } - }, time.Second, time.Millisecond*100) + waitForPutEvent(s, stub, freeWaiterEventReceivedC) + require.NoError(t, <-freeWaiterErrC) // unblock the waiter. It should perform a stale check and return a not found error. close(blockWaiter) @@ -274,39 +234,16 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("WatchReset", func(t *testing.T) { t.Parallel() - s := newSuite(t) - - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() - - firstEmitReceived := make(chan struct{}) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - select { - case <-firstEmitReceived: - default: - close(firstEmitReceived) - } + s := newTestEnv(t) + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return services.ValidateHeadlessAuthentication(ha) == nil, nil }) stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - item, err := local.MarshalHeadlessAuthenticationToItem(stub) - require.NoError(t, err) - require.Eventually(t, func() bool { - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - select { - case <-firstEmitReceived: - return true - default: - return false - } - }, time.Second, time.Millisecond*100) + waitForPutEvent(s, stub, firstEventReceivedC) // closed watchers should be handled gracefully and reset. s.buf.Clear() @@ -326,12 +263,9 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("WatcherClosed", func(t *testing.T) { t.Parallel() - s := newSuite(t) - - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*5) - defer waitCancel() + s := newTestEnv(t) - _, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, _, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) From c1706addd80d9101ef0d1ee46829376fa2e4ab87 Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 23 Mar 2023 15:42:42 -0700 Subject: [PATCH 4/4] Fix race condition that could cause a new watcher to be marked as stale before the channel is consumed; Fix minor test issues. --- lib/services/local/headlessauthn_watcher.go | 20 +++++--------- .../local/headlessauthn_watcher_test.go | 27 ++++++++++--------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/lib/services/local/headlessauthn_watcher.go b/lib/services/local/headlessauthn_watcher.go index 7ff88be174c75..b8c492c70f261 100644 --- a/lib/services/local/headlessauthn_watcher.go +++ b/lib/services/local/headlessauthn_watcher.go @@ -223,12 +223,6 @@ func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.Headless // in the backend, and waits for the given condition to be met, to result in an error, // or for the given context to close. func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (*types.HeadlessAuthentication, error) { - waiter, err := h.assignWaiter(ctx, name) - if err != nil { - return nil, trace.Wrap(err) - } - defer h.unassignWaiter(waiter) - checkBackend := func() (*types.HeadlessAuthentication, bool, error) { headlessAuthn, err := h.identityService.GetHeadlessAuthentication(ctx, name) if err != nil { @@ -243,7 +237,7 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, ok, nil } - // With the waiter allocated, check if there is an existing entry in the backend. + // Before the waiter is allocated, check if there is an existing entry in the backend. headlessAuthn, ok, err := checkBackend() if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) @@ -251,6 +245,12 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, nil } + waiter, err := h.assignWaiter(ctx, name) + if err != nil { + return nil, trace.Wrap(err) + } + defer h.unassignWaiter(waiter) + for { select { case <-waiter.stale: @@ -265,12 +265,6 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, nil } case headlessAuthn := <-waiter.ch: - select { - case <-waiter.stale: - // prioritize stale check. - continue - default: - } if ok, err := cond(headlessAuthn); err != nil { return nil, trace.Wrap(err) } else if ok { diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index 9b1d4ae83cfd4..ee8f9ad866b08 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -71,7 +71,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { } } - waitInGoroutine := func(ctx context.Context, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error), + waitInGoroutine := func(ctx context.Context, t *testing.T, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error), ) (headlessAuthnC chan *types.HeadlessAuthentication, firstEventReceivedC chan struct{}, errC chan error) { waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) t.Cleanup(waitCancel) @@ -96,7 +96,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { // The waiter may miss put events during initialization, so we continuously emit them until one is caught. // This can also be used to wait until a waiter is fully initialized. - waitForPutEvent := func(s *testEnv, ha *types.HeadlessAuthentication, firstEventReceivedC chan struct{}) { + waitForPutEvent := func(t *testing.T, s *testEnv, ha *types.HeadlessAuthentication, firstEventReceivedC chan struct{}) { item, err := local.MarshalHeadlessAuthenticationToItem(ha) require.NoError(t, err) @@ -119,7 +119,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Parallel() s := newTestEnv(t) - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return ha.User != "", nil }) @@ -128,7 +128,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { require.NoError(t, err) stub.User = "user" - waitForPutEvent(s, stub, firstEventReceivedC) + waitForPutEvent(t, s, stub, firstEventReceivedC) require.NoError(t, <-errC) require.Equal(t, stub, <-headlessAuthnCh) @@ -141,7 +141,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { waitCtx, waitCancel := context.WithCancel(ctx) t.Cleanup(waitCancel) - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(waitCtx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return ha.User != "", nil }) @@ -149,7 +149,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - waitForPutEvent(s, stub, firstEventReceivedC) + waitForPutEvent(t, s, stub, firstEventReceivedC) // Ensure that the waiter did not finish with the condition unmet. select { @@ -197,6 +197,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Run("StaleCheck", func(t *testing.T) { t.Parallel() s := newTestEnv(t) + // Create a waiter that we can block/unblock. blockWaiter := make(chan struct{}) t.Cleanup(func() { @@ -207,7 +208,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { } }) - _, blockedWaiterEventReceived, blockedWaiterErrC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, blockedWaiterEventReceived, blockedWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { <-blockWaiter return false, nil }) @@ -216,14 +217,14 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - waitForPutEvent(s, stub, blockedWaiterEventReceived) + waitForPutEvent(t, s, stub, blockedWaiterEventReceived) // Create a second waiter to catch a second put event. - _, freeWaiterEventReceivedC, freeWaiterErrC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, freeWaiterEventReceivedC, freeWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - waitForPutEvent(s, stub, freeWaiterEventReceivedC) + waitForPutEvent(t, s, stub, freeWaiterEventReceivedC) require.NoError(t, <-freeWaiterErrC) // unblock the waiter. It should perform a stale check and return a not found error. @@ -236,14 +237,14 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Parallel() s := newTestEnv(t) - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return services.ValidateHeadlessAuthentication(ha) == nil, nil }) stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - waitForPutEvent(s, stub, firstEventReceivedC) + waitForPutEvent(t, s, stub, firstEventReceivedC) // closed watchers should be handled gracefully and reset. s.buf.Clear() @@ -265,7 +266,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { t.Parallel() s := newTestEnv(t) - _, _, errC := waitInGoroutine(ctx, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, _, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil })