diff --git a/api/types/headlessauthn.go b/api/types/headlessauthn.go index b49d3d637f248..c37a8cd50d28c 100644 --- a/api/types/headlessauthn.go +++ b/api/types/headlessauthn.go @@ -17,9 +17,25 @@ limitations under the License. package types import ( + "time" + "github.com/gravitational/trace" ) +// NewHeadlessAuthenticationStub creates a new a headless authentication resource with limited data. +// The stub is used to initiate headless login. +func NewHeadlessAuthenticationStub(name string, expires time.Time) (*HeadlessAuthentication, error) { + ha := &HeadlessAuthentication{ + ResourceHeader: ResourceHeader{ + Metadata: Metadata{ + Name: name, + Expires: &expires, + }, + }, + } + return ha, ha.CheckAndSetDefaults() +} + // CheckAndSetDefaults does basic validation and default setting. func (h *HeadlessAuthentication) CheckAndSetDefaults() error { h.setStaticFields() diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index c8b0a70bb3cf4..9537761020aa7 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -31,17 +31,12 @@ import ( // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - expires := s.Clock().Now().Add(defaults.CallbackTimeout) - headlessAuthn := &types.HeadlessAuthentication{ - ResourceHeader: types.ResourceHeader{ - Metadata: types.Metadata{ - Name: name, - Expires: &expires, - }, - }, + headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, s.Clock().Now().Add(defaults.CallbackTimeout)) + if err != nil { + return nil, trace.Wrap(err) } - item, err := marshalHeadlessAuthenticationToItem(headlessAuthn) + item, err := MarshalHeadlessAuthenticationToItem(headlessAuthn) if err != nil { return nil, trace.Wrap(err) } @@ -60,12 +55,12 @@ func (s *IdentityService) CompareAndSwapHeadlessAuthentication(ctx context.Conte return nil, trace.Wrap(err) } - oldItem, err := marshalHeadlessAuthenticationToItem(old) + oldItem, err := MarshalHeadlessAuthenticationToItem(old) if err != nil { return nil, trace.Wrap(err) } - newItem, err := marshalHeadlessAuthenticationToItem(new) + newItem, err := MarshalHeadlessAuthenticationToItem(new) if err != nil { return nil, trace.Wrap(err) } @@ -120,7 +115,7 @@ func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name return trace.Wrap(err) } -func marshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) { +func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) { if err := headlessAuthn.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/local/headlessauthn_watcher.go b/lib/services/local/headlessauthn_watcher.go index 52a0e841b9d58..b8c492c70f261 100644 --- a/lib/services/local/headlessauthn_watcher.go +++ b/lib/services/local/headlessauthn_watcher.go @@ -48,6 +48,11 @@ var watcherClosedErr = trace.Errorf("headless authentication watcher closed") type HeadlessAuthenticationWatcherConfig struct { // Backend is the storage backend used to create watchers. Backend backend.Backend + // WatcherService is a service used to create new watchers. + // If nil, Backend will be used as the watcher service. + WatcherService interface { + NewWatcher(ctx context.Context, watch backend.Watch) (backend.Watcher, error) + } // Log is a logger. Log logrus.FieldLogger // Clock is used to control time. @@ -61,6 +66,9 @@ func (cfg *HeadlessAuthenticationWatcherConfig) CheckAndSetDefaults() error { if cfg.Backend == nil { return trace.BadParameter("missing parameter Backend") } + if cfg.WatcherService == nil { + cfg.WatcherService = cfg.Backend + } if cfg.Log == nil { cfg.Log = logrus.StandardLogger() cfg.Log.WithField("resource-kind", types.KindHeadlessAuthentication) @@ -145,7 +153,7 @@ func (h *HeadlessAuthenticationWatcher) runWatchLoop(ctx context.Context) { } func (h *HeadlessAuthenticationWatcher) watch(ctx context.Context) error { - watcher, err := h.Backend.NewWatcher(ctx, backend.Watch{ + watcher, err := h.WatcherService.NewWatcher(ctx, backend.Watch{ Name: types.KindHeadlessAuthentication, MetricComponent: types.KindHeadlessAuthentication, Prefixes: [][]byte{headlessAuthenticationKey("")}, @@ -211,29 +219,10 @@ func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.Headless } } -// CheckWaiter checks if there is an active waiter matching the given -// headless authentication ID. Used in tests. -func (h *HeadlessAuthenticationWatcher) CheckWaiter(name string) bool { - h.mux.Lock() - defer h.mux.Unlock() - for i := range h.waiters { - if h.waiters[i].name == name { - return true - } - } - return false -} - // Wait watches for the headless authentication with the given id to be added/updated // in the backend, and waits for the given condition to be met, to result in an error, // or for the given context to close. func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (*types.HeadlessAuthentication, error) { - waiter, err := h.assignWaiter(ctx, name) - if err != nil { - return nil, trace.Wrap(err) - } - defer h.unassignWaiter(waiter) - checkBackend := func() (*types.HeadlessAuthentication, bool, error) { headlessAuthn, err := h.identityService.GetHeadlessAuthentication(ctx, name) if err != nil { @@ -248,7 +237,7 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, ok, nil } - // With the waiter allocated, check if there is an existing entry in the backend. + // Before the waiter is allocated, check if there is an existing entry in the backend. headlessAuthn, ok, err := checkBackend() if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) @@ -256,6 +245,12 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, nil } + waiter, err := h.assignWaiter(ctx, name) + if err != nil { + return nil, trace.Wrap(err) + } + defer h.unassignWaiter(waiter) + for { select { case <-waiter.stale: @@ -270,12 +265,6 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c return headlessAuthn, nil } case headlessAuthn := <-waiter.ch: - select { - case <-waiter.stale: - // prioritize stale check. - continue - default: - } if ok, err := cond(headlessAuthn); err != nil { return nil, trace.Wrap(err) } else if ok { diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index bb4346b419782..ee8f9ad866b08 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -26,163 +26,251 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" ) func TestHeadlessAuthenticationWatcher(t *testing.T) { + t.Parallel() ctx := context.Background() + pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - t.Parallel() - identity := newIdentityService(t, clockwork.NewFakeClock()) + type testEnv struct { + watcher *local.HeadlessAuthenticationWatcher + watcherClock clockwork.FakeClock + watcherCancel context.CancelFunc + identity *local.IdentityService + buf *backend.CircularBuffer + } - watcherCtx, watcherCancel := context.WithCancel(ctx) - defer watcherCancel() + newTestEnv := func(t *testing.T) *testEnv { + identity := newIdentityService(t, clockwork.NewFakeClock()) - watcherClock := clockwork.NewFakeClock() - w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ - Clock: watcherClock, - Backend: identity.Backend, - }) - require.NoError(t, err) + // use a standalone buffer as a watcher service. + buf := backend.NewCircularBuffer() + buf.SetInit() - pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + watcherCtx, watcherCancel := context.WithCancel(ctx) + t.Cleanup(watcherCancel) + + watcherClock := clockwork.NewFakeClock() + w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ + Clock: watcherClock, + WatcherService: buf, + Backend: identity.Backend, + }) + require.NoError(t, err) + + return &testEnv{ + watcher: w, + watcherClock: watcherClock, + watcherCancel: watcherCancel, + identity: identity, + buf: buf, + } + } + + waitInGoroutine := func(ctx context.Context, t *testing.T, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error), + ) (headlessAuthnC chan *types.HeadlessAuthentication, firstEventReceivedC chan struct{}, errC chan error) { + waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(waitCancel) - waitInGoroutine := func(ctx context.Context, t *testing.T, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (chan *types.HeadlessAuthentication, chan error) { - headlessAuthnCh := make(chan *types.HeadlessAuthentication, 1) - errC := make(chan error, 1) + headlessAuthnC = make(chan *types.HeadlessAuthentication, 1) + errC = make(chan error, 1) + firstEventReceivedC = make(chan struct{}) go func() { - headlessAuthn, err := w.Wait(ctx, name, cond) + headlessAuthn, err := watcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + select { + case <-firstEventReceivedC: + default: + close(firstEventReceivedC) + } + return cond(ha) + }) errC <- err - headlessAuthnCh <- headlessAuthn + headlessAuthnC <- headlessAuthn }() - require.Eventually(t, func() bool { return w.CheckWaiter(name) }, time.Millisecond*100, time.Millisecond*10) - - return headlessAuthnCh, errC + return headlessAuthnC, firstEventReceivedC, errC } - t.Run("WaitTimeout", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Millisecond*10) - defer waitCancel() + // The waiter may miss put events during initialization, so we continuously emit them until one is caught. + // This can also be used to wait until a waiter is fully initialized. + waitForPutEvent := func(t *testing.T, s *testEnv, ha *types.HeadlessAuthentication, firstEventReceivedC chan struct{}) { + item, err := local.MarshalHeadlessAuthenticationToItem(ha) + require.NoError(t, err) - _, err = w.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - require.Error(t, err) - require.Equal(t, waitCtx.Err(), err) - }) + require.Eventually(t, func() bool { + s.buf.Emit(backend.Event{ + Type: types.OpPut, + Item: *item, + }) + + select { + case <-firstEventReceivedC: + return true + default: + return false + } + }, 2*time.Second, 100*time.Millisecond) + } - t.Run("WaitCreateStub", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) - defer waitCancel() + t.Run("WaitEventWithConditionMet", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return ha.User != "", nil }) - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // Emit put event that passes the condition. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) + stub.User = "user" + + waitForPutEvent(t, s, stub, firstEventReceivedC) require.NoError(t, <-errC) require.Equal(t, stub, <-headlessAuthnCh) }) - t.Run("WaitCompareAndSwap", func(t *testing.T) { - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) + t.Run("WaitEventWithConditionUnmet", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) - defer waitCancel() + waitCtx, waitCancel := context.WithCancel(ctx) + t.Cleanup(waitCancel) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return ha.State == types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, nil + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return ha.User != "", nil }) - replace := *stub - replace.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED - replace.PublicKey = []byte(sshPubKey) - replace.User = "user" + // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + require.NoError(t, err) + + waitForPutEvent(t, s, stub, firstEventReceivedC) + + // Ensure that the waiter did not finish with the condition unmet. + select { + case err := <-errC: + t.Errorf("Expected waiter to continue but instead the waiter returned with err: %v", err) + default: + waitCancel() + } - swapped, err := identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) + require.Error(t, <-errC) + require.Nil(t, <-headlessAuthnCh) + }) + + t.Run("WaitBackend", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) + + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - require.NoError(t, <-errC) - require.Equal(t, swapped, <-headlessAuthnCh) + waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(waitCancel) + + // Wait should immediately check the backend and return the existing headless authentication stub. + headlessAuthn, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return true, nil + }) + + require.NoError(t, err) + require.Equal(t, stub, headlessAuthn) }) - t.Run("StaleCheck", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) - defer waitCancel() + t.Run("WaitTimeout", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) - // Create two waiters - a blocked consumer and a free consumer. - blockWait := make(chan struct{}) - _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - <-blockWait - return false, nil + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Millisecond) + t.Cleanup(waitCancel) + + _, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + require.Error(t, err) + require.Equal(t, waitCtx.Err(), err) + }) + + t.Run("StaleCheck", func(t *testing.T) { + t.Parallel() + s := newTestEnv(t) + + // Create a waiter that we can block/unblock. + blockWaiter := make(chan struct{}) + t.Cleanup(func() { + select { + case <-blockWaiter: + default: + close(blockWaiter) + } }) - notifyReceived := make(chan struct{}, 1) - waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - notifyReceived <- struct{}{} + _, blockedWaiterEventReceived, blockedWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + <-blockWaiter return false, nil }) - // Create stub and wait for it to be caught by the free waiter. - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // Emit stub put event and wait for it to be caught by the waiter. + stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) require.NoError(t, err) - <-notifyReceived - replace := *stub - replace.PublicKey = []byte(sshPubKey) - replace.User = "user" + waitForPutEvent(t, s, stub, blockedWaiterEventReceived) - // perform a put to mark the blocked waiter as stale and - // wait for it to be caught by the free waiter. - _, err = identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) - require.NoError(t, err) - <-notifyReceived + // Create a second waiter to catch a second put event. + _, freeWaiterEventReceivedC, freeWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return true, nil + }) - // delete the headless authentication and unblock. - err = identity.DeleteHeadlessAuthentication(ctx, pubUUID) - require.NoError(t, err) - close(blockWait) + waitForPutEvent(t, s, stub, freeWaiterEventReceivedC) + require.NoError(t, <-freeWaiterErrC) - // the blocked waiter should perform a stale check and return a not found error. - err = <-errC + // unblock the waiter. It should perform a stale check and return a not found error. + close(blockWaiter) + err = <-blockedWaiterErrC require.True(t, trace.IsNotFound(err), "Expected a not found error from Wait but got %v", err) }) t.Run("WatchReset", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) - defer waitCancel() + t.Parallel() + s := newTestEnv(t) - headlessAuthnCh, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil + headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + return services.ValidateHeadlessAuthentication(ha) == nil, nil }) + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + require.NoError(t, err) + + waitForPutEvent(t, s, stub, firstEventReceivedC) + // closed watchers should be handled gracefully and reset. - identity.Backend.CloseWatchers() - watcherClock.BlockUntil(1) + s.buf.Clear() + s.watcherClock.BlockUntil(1) - // The watcher should notify waiters of missed events. - stub, err := identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + // The watcher should notify waiters of backend state on watcher reset. + replace := *stub + replace.PublicKey = []byte(sshPubKey) + replace.User = "user" + swapped, err := s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, identity.DeleteHeadlessAuthentication(ctx, pubUUID)) }) - watcherClock.Advance(w.MaxRetryPeriod) + s.watcherClock.Advance(s.watcher.MaxRetryPeriod) require.NoError(t, <-errC) - require.Equal(t, stub, <-headlessAuthnCh) + require.Equal(t, swapped, <-headlessAuthnCh) }) t.Run("WatcherClosed", func(t *testing.T) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second*2) - defer waitCancel() + t.Parallel() + s := newTestEnv(t) - _, errC := waitInGoroutine(waitCtx, t, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + _, _, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - watcherCancel() + s.watcherCancel() // waiters should be notified to close and result in ctx error waitErr := <-errC @@ -190,7 +278,7 @@ func TestHeadlessAuthenticationWatcher(t *testing.T) { require.Equal(t, waitErr.Error(), "headless authentication watcher closed") // New waiters should be prevented. - _, err = w.Wait(ctx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + _, err := s.watcher.Wait(ctx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) require.Error(t, err) }) }