diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 8ee1e71a1c412..74c2ecda3188c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4803,13 +4803,19 @@ func (a *Server) GetHeadlessAuthentication(ctx context.Context, name string) (*t return nil, trace.Wrap(err) } - // wait for the headless authentication to be updated with valid login details - // by the login process. If the headless authentication is already updated, - // Wait will return it immediately. + sub, err := a.headlessAuthenticationWatcher.Subscribe(ctx, name) + if err != nil { + return nil, trace.Wrap(err) + } + defer sub.Close() + waitCtx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) defer cancel() - headlessAuthn, err := a.headlessAuthenticationWatcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + // wait for the headless authentication to be updated with valid login details + // by the login process. If the headless authentication is already updated, + // WaitForUpdate will return it immediately. + headlessAuthn, err := a.headlessAuthenticationWatcher.WaitForUpdate(waitCtx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { return services.ValidateHeadlessAuthentication(ha) == nil, nil }) return headlessAuthn, trace.Wrap(err) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 81ea4d12b65c6..555a9f16a8ab7 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -361,10 +361,16 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR return nil, trace.Wrap(err) } + sub, err := s.headlessAuthenticationWatcher.Subscribe(ctx, req.HeadlessAuthenticationID) + if err != nil { + return nil, trace.Wrap(err) + } + defer sub.Close() + // Wait for a headless authenticated stub to be inserted by an authenticated // call to GetHeadlessAuthentication. We do this to avoid immediately inserting // backend items from an unauthenticated endpoint. 
- headlessAuthnStub, err := s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthnStub, err := s.headlessAuthenticationWatcher.WaitForUpdate(ctx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { // Only headless authentication stub can be inserted without the standard validation. if services.ValidateHeadlessAuthentication(ha) == nil { return false, trace.AlreadyExists("headless auth request already exists") @@ -381,7 +387,7 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR } // Wait for the request to be approved/denied. - headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthn, err = s.headlessAuthenticationWatcher.WaitForUpdate(ctx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { switch ha.State { case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: if ha.MfaDevice == nil { diff --git a/lib/services/local/headlessauthn_watcher.go b/lib/services/local/headlessauthn_watcher.go index 05720b83d21eb..56eb1274be9b8 100644 --- a/lib/services/local/headlessauthn_watcher.go +++ b/lib/services/local/headlessauthn_watcher.go @@ -18,7 +18,7 @@ package local import ( "context" - "fmt" + "errors" "sync" "time" @@ -34,26 +34,22 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// maxWaiters is the maximum number of concurrent waiters that a headless authentication watcher -// will accept. This limit is introduced because the headless login flow creates waiters from an +// maxSubscribers is the maximum number of concurrent subscribers that a headless authentication watcher +// will accept. This limit is introduced because the headless login flow creates subscribers from an // unauthenticated endpoint, which could be exploited in a ddos attack without the limit in place. 
// // 1024 was chosen as a reasonable limit, as under normal conditions, a single Teleport Cluster // would never have over 1000 concurrent headless logins, each of which has a maximum lifetime // of 30-60 seconds. If this limit is exceeded in a reasonable scenario, this limit should be // made configurable in the server configuration file. -const maxWaiters = 1024 +const maxSubscribers = 1024 -var watcherClosedErr = trace.Errorf("headless authentication watcher closed") +var ErrHeadlessAuthenticationWatcherClosed = errors.New("headless authentication watcher closed") +// HeadlessAuthenticationWatcherConfig contains configuration options for a HeadlessAuthenticationWatcher. type HeadlessAuthenticationWatcherConfig struct { // Backend is the storage backend used to create watchers. Backend backend.Backend - // WatcherService is a service used to create new watchers. - // If nil, Backend will be used as the watcher service. - WatcherService interface { - NewWatcher(ctx context.Context, watch backend.Watch) (backend.Watcher, error) - } // Log is a logger. Log logrus.FieldLogger // Clock is used to control time. @@ -67,9 +63,6 @@ func (cfg *HeadlessAuthenticationWatcherConfig) CheckAndSetDefaults() error { if cfg.Backend == nil { return trace.BadParameter("missing parameter Backend") } - if cfg.WatcherService == nil { - cfg.WatcherService = cfg.Backend - } if cfg.Log == nil { cfg.Log = logrus.StandardLogger() cfg.Log.WithField("resource-kind", types.KindHeadlessAuthentication) @@ -84,14 +77,15 @@ func (cfg *HeadlessAuthenticationWatcherConfig) CheckAndSetDefaults() error { return nil } -// HeadlessAuthenticationWatcher is a custom backend watcher for the headless authentication resource. +// HeadlessAuthenticationWatcher is a lightweight backend watcher for the headless authentication resource. 
type HeadlessAuthenticationWatcher struct { HeadlessAuthenticationWatcherConfig identityService *IdentityService retry retryutils.Retry mux sync.Mutex - waiters [maxWaiters]headlessAuthenticationWaiter + subscribers [maxSubscribers]*headlessAuthenticationSubscriber closed chan struct{} + running chan struct{} } // NewHeadlessAuthenticationWatcher creates a new headless authentication resource watcher. @@ -112,16 +106,29 @@ func NewHeadlessAuthenticationWatcher(ctx context.Context, cfg HeadlessAuthentic return nil, trace.Wrap(err) } - watcher := &HeadlessAuthenticationWatcher{ + h := &HeadlessAuthenticationWatcher{ HeadlessAuthenticationWatcherConfig: cfg, identityService: NewIdentityService(cfg.Backend), retry: retry, closed: make(chan struct{}), + running: make(chan struct{}), } - go watcher.runWatchLoop(ctx) + go h.runWatchLoop(ctx) - return watcher, nil + // Wait for the watch loop to initialize before returning. + select { + case <-h.running: + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + } + + return h, nil +} + +// Done returns a channel that's closed when the watcher is closed. +func (h *HeadlessAuthenticationWatcher) Done() <-chan struct{} { + return h.closed } func (h *HeadlessAuthenticationWatcher) close() { @@ -138,7 +145,7 @@ func (h *HeadlessAuthenticationWatcher) runWatchLoop(ctx context.Context) { startedWaiting := h.Clock.Now() select { case t := <-h.retry.After(): - h.Log.Debugf("Attempting to restart watch after waiting %v.", t.Sub(startedWaiting)) + h.Log.Warningf("Restarting watch on error after waiting %v. Error: %v.", t.Sub(startedWaiting), err) h.retry.Inc() case <-ctx.Done(): h.Log.WithError(ctx.Err()).Debugf("Context closed with err. Returning from watch loop.") @@ -147,42 +154,21 @@ func (h *HeadlessAuthenticationWatcher) runWatchLoop(ctx context.Context) { h.Log.Debug("Watcher closed. 
Returning from watch loop.") return } - if err != nil { - h.Log.Warningf("Restart watch on error: %v.", err) - } } } func (h *HeadlessAuthenticationWatcher) watch(ctx context.Context) error { - watcher, err := h.WatcherService.NewWatcher(ctx, backend.Watch{ - Name: types.KindHeadlessAuthentication, - MetricComponent: types.KindHeadlessAuthentication, - Prefixes: [][]byte{headlessAuthenticationKey("")}, - }) + watcher, err := h.newWatcher(ctx) if err != nil { return trace.Wrap(err) } defer watcher.Close() - select { - case <-watcher.Done(): - return fmt.Errorf("watcher closed") - case <-ctx.Done(): - return ctx.Err() - case event := <-watcher.Events(): - if event.Type != types.OpInit { - return trace.BadParameter("expected init event, got %v instead", event.Type) - } - } - - h.retry.Reset() - + // Notify any subscribers initiated before the new watcher initialized. headlessAuthns, err := h.identityService.GetHeadlessAuthentications(ctx) if err != nil { return trace.Wrap(err) } - - // Notify any waiters initiated before the new watcher initialized. h.notify(headlessAuthns...) 
for { @@ -198,168 +184,198 @@ func (h *HeadlessAuthenticationWatcher) watch(ctx context.Context) error { } } case <-watcher.Done(): - return fmt.Errorf("watcher closed") + return errors.New("watcher closed") case <-ctx.Done(): return ctx.Err() + case h.running <- struct{}{}: + } + } +} + +func (h *HeadlessAuthenticationWatcher) newWatcher(ctx context.Context) (backend.Watcher, error) { + watcher, err := h.identityService.NewWatcher(ctx, backend.Watch{ + Name: types.KindHeadlessAuthentication, + MetricComponent: types.KindHeadlessAuthentication, + Prefixes: [][]byte{headlessAuthenticationKey("")}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + select { + case <-watcher.Done(): + return nil, errors.New("watcher closed") + case <-ctx.Done(): + return nil, ctx.Err() + case event := <-watcher.Events(): + if event.Type != types.OpInit { + return nil, trace.BadParameter("expected init event, got %v instead", event.Type) } } + + h.retry.Reset() + return watcher, nil } func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.HeadlessAuthentication) { h.mux.Lock() defer h.mux.Unlock() + for _, ha := range headlessAuthns { - for _, waiter := range h.waiters { - if waiter.name == ha.Metadata.Name { + for _, s := range h.subscribers { + if s != nil && s.name == ha.Metadata.Name { select { - case waiter.ch <- proto.Clone(ha).(*types.HeadlessAuthentication): + case s.updates <- proto.Clone(ha).(*types.HeadlessAuthentication): default: - waiter.markStale() + select { + case s.stale <- struct{}{}: + default: + // subscriber is already stale, skip. + } } } } } } -// Wait watches for the headless authentication with the given id to be added/updated -// in the backend, and waits for the given condition to be met, to result in an error, -// or for the given context to close. 
-func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, cond func(*types.HeadlessAuthentication) (ok bool, err error)) (*types.HeadlessAuthentication, error) { - const bufferSize = 3 // one for each of the "main", "stale", and "initial backend check" goroutines. - conditionMet := make(chan *types.HeadlessAuthentication, bufferSize) - conditionErr := make(chan error, bufferSize) - checkCondition := func(ha *types.HeadlessAuthentication) { - if ok, err := cond(ha); err != nil { - conditionErr <- trace.Wrap(err) - } else if ok { - conditionMet <- ha - } - } +// HeadlessAuthenticationSubscriber is a subscriber of updates +// for a specific headless authentication resource. +type HeadlessAuthenticationSubscriber interface { + Name() string + // Updates is a channel used by the watcher to send headless authentication updates. + // After receiving an update, the caller should check the stale channel to ensure it + // is not a stale update. + Updates() <-chan *types.HeadlessAuthentication + // Stale is a channel used by the watcher to notify the subscriber that one or more + // updates have been missed, due to slow receives on the Updates channel. + Stale() <-chan struct{} + // Close closes the subscriber and its channels. This frees up resources for the watcher + // and should always be called on completion. + Close() +} - waiter, err := h.assignWaiter(name) +// Subscribe creates a new headless authentication subscriber for the given headless authentication name. +func (h *HeadlessAuthenticationWatcher) Subscribe(ctx context.Context, name string) (HeadlessAuthenticationSubscriber, error) { + i, err := h.assignSubscriber(name) if err != nil { return nil, trace.Wrap(err) } + subscriber := h.subscribers[i] - ctx, cancel := context.WithCancel(ctx) - var wg sync.WaitGroup - defer func() { - cancel() - wg.Wait() - h.unassignWaiter(waiter) - }() - - // Consume the main channel. 
- wg.Add(1) go func() { - defer wg.Done() - for { - select { - case ha := <-waiter.ch: - checkCondition(ha) - case <-ctx.Done(): - return - } - } - }() - - // Consume the stale channel. - wg.Add(1) - go func() { - defer wg.Done() select { - case <-waiter.stale: - ha, err := h.identityService.GetHeadlessAuthentication(ctx, name) - if err != nil { - conditionErr <- trace.Wrap(err) - } else { - checkCondition(ha) - } case <-ctx.Done(): - return + case <-subscriber.closed: } - }() - // With the waiter allocated, check the backend for an existing entry. - wg.Add(1) - go func() { - defer wg.Done() - ha, err := h.identityService.GetHeadlessAuthentication(ctx, name) - if trace.IsNotFound(err) { - // Ignore not found errors in the initial stale check. - return - } else if err != nil { - conditionErr <- trace.Wrap(err) - } else { - checkCondition(ha) - } + // reclaim the subscriber and close remaining open channels. + h.unassignSubscriber(i) + close(subscriber.updates) + close(subscriber.stale) }() - select { - case ha := <-conditionMet: - return ha, nil - case err := <-conditionErr: - return nil, trace.Wrap(err) - case <-ctx.Done(): - return nil, trace.Wrap(ctx.Err()) - case <-h.closed: - return nil, watcherClosedErr - } + return subscriber, nil } -func (h *HeadlessAuthenticationWatcher) assignWaiter(name string) (*headlessAuthenticationWaiter, error) { +func (h *HeadlessAuthenticationWatcher) assignSubscriber(name string) (int, error) { h.mux.Lock() defer h.mux.Unlock() select { case <-h.closed: - return nil, watcherClosedErr + return 0, ErrHeadlessAuthenticationWatcherClosed default: } - for i := range h.waiters { - if h.waiters[i].ch != nil { - continue + for i := range h.subscribers { + if h.subscribers[i] == nil { + h.subscribers[i] = &headlessAuthenticationSubscriber{ + name: name, + // small buffer for updates to avoid unnecessary stale checks. + updates: make(chan *types.HeadlessAuthentication, 1), + // buffer required to mark as stale. 
+ stale: make(chan struct{}, 1), + closed: make(chan struct{}), + } + return i, nil } - h.waiters[i].ch = make(chan *types.HeadlessAuthentication) - h.waiters[i].name = name - h.waiters[i].stale = make(chan struct{}, 1) // buffer required by markStale - return &h.waiters[i], nil } - return nil, trace.LimitExceeded("too many in-flight headless login requests") + return 0, trace.LimitExceeded("too many in-flight headless login requests") } -func (h *HeadlessAuthenticationWatcher) unassignWaiter(waiter *headlessAuthenticationWaiter) { +func (h *HeadlessAuthenticationWatcher) unassignSubscriber(i int) { h.mux.Lock() defer h.mux.Unlock() - - // close channels. - close(waiter.ch) - close(waiter.stale) - - waiter.ch = nil - waiter.name = "" - waiter.stale = nil + h.subscribers[i] = nil } -// headlessAuthenticationWaiter is a waiter for a specific headless authentication. -type headlessAuthenticationWaiter struct { - // name is the name of the headless authentication resource being waited on. +// headlessAuthenticationSubscriber is a subscriber for a specific headless authentication. +type headlessAuthenticationSubscriber struct { + // name is the name of the headless authentication resource being subscribed to. name string - // ch is a channel used by the watcher to send resource updates. - ch chan *types.HeadlessAuthentication - // stale is a channel used to determine if the waiter is stale and + // updates is a channel used by the watcher to send resource updates. + updates chan *types.HeadlessAuthentication + // stale is a channel used to determine if the subscriber is stale and // needs to check the backend for missed data. stale chan struct{} + // closed is a channel used to determine if the subscriber is closed. + closed chan struct{} } -// markStale marks a waiter as stale so it will update itself once available. -// This should be called when a waiter misses an update due to slow consumption on its channel. 
-func (w *headlessAuthenticationWaiter) markStale() { - select { - case w.stale <- struct{}{}: - default: - // waiter is already stale, carry on. +func (s *headlessAuthenticationSubscriber) Name() string { + return s.name +} + +func (s *headlessAuthenticationSubscriber) Updates() <-chan *types.HeadlessAuthentication { + return s.updates +} + +func (s *headlessAuthenticationSubscriber) Stale() <-chan struct{} { + return s.stale +} + +func (s *headlessAuthenticationSubscriber) Close() { + close(s.closed) +} + +// WaitForUpdate waits until the headless authentication with the given name is updated in the +// backend to meet the given condition or returns early if the condition results in an +// error or if the watcher or given context is closed. +func (h *HeadlessAuthenticationWatcher) WaitForUpdate(ctx context.Context, subscriber HeadlessAuthenticationSubscriber, cond func(*types.HeadlessAuthentication) (bool, error)) (*types.HeadlessAuthentication, error) { + // First check for an existing backend entry. + ha, err := h.identityService.GetHeadlessAuthentication(ctx, subscriber.Name()) + if trace.IsNotFound(err) { + // If not found, that's ok. Continue to watch update channel. + } else if err != nil { + return nil, trace.Wrap(err) + } else if ok, err := cond(ha); err != nil { + return nil, trace.Wrap(err) + } else if ok { + return ha, nil + } + + for { + select { + case ha := <-subscriber.Updates(): + select { + case <-subscriber.Stale(): + // If stale, then this update is not the most recent. Check the backend. 
+ ha, err = h.identityService.GetHeadlessAuthentication(ctx, subscriber.Name()) + if err != nil { + return nil, trace.Wrap(err) + } + default: + } + if ok, err := cond(ha); err != nil { + return nil, trace.Wrap(err) + } else if ok { + return ha, nil + } + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case <-h.Done(): + return nil, ErrHeadlessAuthenticationWatcherClosed + } } } diff --git a/lib/services/local/headlessauthn_watcher_test.go b/lib/services/local/headlessauthn_watcher_test.go index 666f611a1ca07..438ebaec10ae7 100644 --- a/lib/services/local/headlessauthn_watcher_test.go +++ b/lib/services/local/headlessauthn_watcher_test.go @@ -18,283 +18,295 @@ package local_test import ( "context" - "sync" + "errors" "testing" "time" - "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/retryutils" - "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" ) -func TestHeadlessAuthenticationWatcher(t *testing.T) { - t.Parallel() - ctx := context.Background() - pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - - type testEnv struct { - watcher *local.HeadlessAuthenticationWatcher - watcherClock clockwork.FakeClock - watcherCancel context.CancelFunc - identity *local.IdentityService - buf *backend.CircularBuffer - } - - newTestEnv := func(t *testing.T) *testEnv { - identity := newIdentityService(t, clockwork.NewFakeClock()) +type headlessAuthenticationWatcherTestEnv struct { + watcher *local.HeadlessAuthenticationWatcher + watcherCancel context.CancelFunc + identity *local.IdentityService +} - // use a standalone buffer as a watcher service. 
- buf := backend.NewCircularBuffer() - buf.SetInit() +func newHeadlessAuthenticationWatcherTestEnv(t *testing.T, clock clockwork.Clock) *headlessAuthenticationWatcherTestEnv { + identity := newIdentityService(t, clock) - watcherCtx, watcherCancel := context.WithCancel(ctx) - t.Cleanup(watcherCancel) + watcherCtx, watcherCancel := context.WithCancel(context.Background()) + t.Cleanup(watcherCancel) - watcherClock := clockwork.NewFakeClock() - w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ - Clock: watcherClock, - WatcherService: buf, - Backend: identity.Backend, - }) - require.NoError(t, err) + w, err := local.NewHeadlessAuthenticationWatcher(watcherCtx, local.HeadlessAuthenticationWatcherConfig{ + Clock: clock, + Backend: identity.Backend, + }) + require.NoError(t, err) - return &testEnv{ - watcher: w, - watcherClock: watcherClock, - watcherCancel: watcherCancel, - identity: identity, - buf: buf, - } + return &headlessAuthenticationWatcherTestEnv{ + watcher: w, + watcherCancel: watcherCancel, + identity: identity, } +} - waitInGoroutine := func(ctx context.Context, t *testing.T, watcher *local.HeadlessAuthenticationWatcher, name string, cond func(*types.HeadlessAuthentication) (bool, error), - ) (headlessAuthnC chan *types.HeadlessAuthentication, firstEventReceivedC chan struct{}, errC chan error) { - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second) - t.Cleanup(waitCancel) - - headlessAuthnC = make(chan *types.HeadlessAuthentication, 1) - errC = make(chan error, 1) - firstEventReceivedC = make(chan struct{}) - go func() { - var closeOnce sync.Once - headlessAuthn, err := watcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { - closeOnce.Do(func() { close(firstEventReceivedC) }) - return cond(ha) - }) - errC <- err - headlessAuthnC <- headlessAuthn - }() - return headlessAuthnC, firstEventReceivedC, errC - } +func TestHeadlessAuthenticationWatcher_Subscribe(t *testing.T) 
{ + t.Parallel() + ctx := context.Background() + pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) - // The waiter may miss put events during initialization, so we continuously emit them until one is caught. - // This can also be used to wait until a waiter is fully initialized. - waitForPutEvent := func(t *testing.T, s *testEnv, ha *types.HeadlessAuthentication, firstEventReceivedC chan struct{}) { - item, err := local.MarshalHeadlessAuthenticationToItem(ha) - require.NoError(t, err) + t.Run("Updates", func(t *testing.T) { + t.Parallel() + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - retry, err := retryutils.NewLinear(retryutils.LinearConfig{ - // Send the first event after a short tick before slowing down. - // We want to catch the first event quickly without risking sending - // multiple events in quick succession, which could mark the waiter - // as stale. In practice, this stale behavior protects against race - // conditions, but in these these tests we want to control when the - // watcher is marked as stale. - First: 25 * time.Millisecond, - Step: 75 * time.Millisecond, - Max: 100 * time.Millisecond, - }) + sub, err := s.watcher.Subscribe(ctx, pubUUID) require.NoError(t, err) + t.Cleanup(sub.Close) - timeoutCtx, timeoutCancel := context.WithTimeout(ctx, 500*time.Millisecond) - defer timeoutCancel() + // Make an update. Make sure we are servicing the updates channel first. + readyForUpdate := make(chan struct{}) + stubC := make(chan *types.HeadlessAuthentication, 1) + go func() { + <-readyForUpdate + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + assert.NoError(t, err) + stubC <- stub + }() - // We don't know when the waiter will be initialized, so we send - // events on each tick until one is received by the waiter. 
for { select { - case <-timeoutCtx.Done(): - t.Fatal("Watcher never received an event") - case <-retry.After(): - retry.Inc() - s.buf.Emit(backend.Event{ - Type: types.OpPut, - Item: *item, - }) - case <-firstEventReceivedC: + case update := <-sub.Updates(): + // We should receive the update. + require.Equal(t, <-stubC, update) return + case <-sub.Stale(): + t.Fatal("Expected subscriber to not be marked as stale") + case <-time.After(time.Second): + t.Fatal("Expected subscriber to receive an update") + case readyForUpdate <- struct{}{}: } } - } + }) - t.Run("WaitEventWithConditionMet", func(t *testing.T) { + t.Run("Stale", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) - - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return ha.User != "", nil - }) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - // Emit put event that passes the condition. - stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + sub, err := s.watcher.Subscribe(ctx, pubUUID) require.NoError(t, err) - stub.User = "user" + t.Cleanup(sub.Close) - waitForPutEvent(t, s, stub, firstEventReceivedC) + // Make 2 updates without servicing the subscriber's Updates channel. + // The second update will be dropped and result in the subscriber being + // marked as stale. 
+ stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + require.NoError(t, err) + replace := *stub + replace.User = "user" + replace.PublicKey = []byte(sshPubKey) + _, err = s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) + require.NoError(t, err) - require.NoError(t, <-errC) - require.Equal(t, stub, <-headlessAuthnCh) + select { + case <-sub.Stale(): + case <-time.After(time.Second): + t.Fatal("Expected subscriber to be marked as stale") + } }) - t.Run("WaitEventWithConditionUnmet", func(t *testing.T) { + t.Run("WatchReset", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + clock := clockwork.NewFakeClock() + s := newHeadlessAuthenticationWatcherTestEnv(t, clock) - waitCtx, waitCancel := context.WithCancel(ctx) - t.Cleanup(waitCancel) + sub, err := s.watcher.Subscribe(ctx, pubUUID) + require.NoError(t, err) + t.Cleanup(sub.Close) - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(waitCtx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return ha.User != "", nil - }) + // Closed watchers should be handled gracefully and reset. + s.identity.Backend.CloseWatchers() + clock.BlockUntil(1) - // Emit put event that doesn't pass the condition (user not set). The waiter should ignore these events. - stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) - require.NoError(t, err) + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + assert.NoError(t, err) - waitForPutEvent(t, s, stub, firstEventReceivedC) + // Reset the watcher. Make sure we are servicing the updates channel first. + readyForUpdate := make(chan struct{}) + go func() { + <-readyForUpdate + clock.Advance(s.watcher.MaxRetryPeriod) + }() - // Ensure that the waiter did not finish with the condition unmet. 
+ readyForUpdate <- struct{}{} select { - case err := <-errC: - t.Errorf("Expected waiter to continue but instead the waiter returned with err: %v", err) - default: - waitCancel() + case update := <-sub.Updates(): + // We should receive an update of the current backend state on watcher reset. + require.Equal(t, stub, update) + return + case <-sub.Stale(): + t.Fatal("Expected subscriber to not be marked as stale") + case <-time.After(time.Second): + t.Fatal("Expected subscriber to receive an update") } - - require.Error(t, <-errC) - require.Nil(t, <-headlessAuthnCh) }) +} - t.Run("WaitBackend", func(t *testing.T) { +func TestHeadlessAuthenticationWatcher_WaitForUpdate(t *testing.T) { + t.Parallel() + ctx := context.Background() + pubUUID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + + t.Run("ConditionMet", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + sub, err := s.watcher.Subscribe(ctx, pubUUID) require.NoError(t, err) + t.Cleanup(sub.Close) - waitCtx, waitCancel := context.WithTimeout(ctx, time.Second) - t.Cleanup(waitCancel) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(cancel) - // Wait should immediately check the backend and return the existing headless authentication stub. - headlessAuthn, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil - }) + headlessAuthnCh := make(chan *types.HeadlessAuthentication, 1) + errC := make(chan error, 1) + go func() { + ha, err := s.watcher.WaitForUpdate(ctx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { + return ha.User != "", nil + }) + headlessAuthnCh <- ha + errC <- err + }() + // Make an update that passes the condition. 
+ stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) require.NoError(t, err) - require.Equal(t, stub, headlessAuthn) + + replace := *stub + replace.User = "user" + replace.PublicKey = []byte(sshPubKey) + _, err = s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) + require.NoError(t, err) + + require.NoError(t, <-errC) + require.Equal(t, &replace, <-headlessAuthnCh) }) - t.Run("WaitTimeout", func(t *testing.T) { + t.Run("ConditionUnmet", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Millisecond) - t.Cleanup(waitCancel) + sub, err := s.watcher.Subscribe(ctx, pubUUID) + require.NoError(t, err) + t.Cleanup(sub.Close) + + unknownUserErr := errors.New("Unknown user") + conditionFunc := func(ha *types.HeadlessAuthentication) (bool, error) { + if ha.User == "" { + return false, nil + } else if ha.User == "unknown" { + return false, unknownUserErr + } + return true, nil + } + + // Make an update that doesn't pass the condition (user not set). + // The waiter should ignore this update and timeout. + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + t.Cleanup(cancel) - _, err := s.watcher.Wait(waitCtx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + _, err = s.watcher.WaitForUpdate(ctx, sub, conditionFunc) require.Error(t, err) require.ErrorIs(t, err, context.DeadlineExceeded) + + // Make an update that causes the condition to error (user "unknown"). + // The waiter should return the condition error during the initial backend check. 
+ replace := *stub + replace.User = "unknown" + replace.PublicKey = []byte(sshPubKey) + _, err = s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) + require.NoError(t, err) + + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + t.Cleanup(cancel) + + _, err = s.watcher.WaitForUpdate(ctx, sub, conditionFunc) + require.Error(t, err) + require.ErrorIs(t, err, unknownUserErr) }) - t.Run("StaleCheck", func(t *testing.T) { + t.Run("InitialBackendCheck", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - // Create a waiter that we can block/unblock. - blockWaiter := make(chan struct{}) - var closeOnce sync.Once - t.Cleanup(func() { - closeOnce.Do(func() { close(blockWaiter) }) - }) + stub, err := s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) + require.NoError(t, err) - _, blockedWaiterEventReceived, blockedWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - <-blockWaiter - return false, nil - }) + waitCtx, waitCancel := context.WithTimeout(ctx, 5*time.Second) + t.Cleanup(waitCancel) - // Emit stub put event and wait for it to be caught by the waiter. - stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + sub, err := s.watcher.Subscribe(ctx, pubUUID) require.NoError(t, err) - waitForPutEvent(t, s, stub, blockedWaiterEventReceived) + t.Cleanup(sub.Close) - // Create a second waiter to catch a second put event. - _, freeWaiterEventReceivedC, freeWaiterErrC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { + // WaitForUpdate should immediately check the backend and return the existing headless authentication stub. 
+ headlessAuthn, err := s.watcher.WaitForUpdate(waitCtx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) - waitForPutEvent(t, s, stub, freeWaiterEventReceivedC) - require.NoError(t, <-freeWaiterErrC) - - // unblock the waiter. It should perform a stale check and return a not found error. - closeOnce.Do(func() { close(blockWaiter) }) - err = <-blockedWaiterErrC - require.True(t, trace.IsNotFound(err), "Expected a not found error from Wait but got %v", err) + require.NoError(t, err) + require.Equal(t, stub, headlessAuthn) }) - t.Run("WatchReset", func(t *testing.T) { + t.Run("Timeout", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - headlessAuthnCh, firstEventReceivedC, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return services.ValidateHeadlessAuthentication(ha) == nil, nil - }) - - stub, err := types.NewHeadlessAuthenticationStub(pubUUID, time.Now()) + sub, err := s.watcher.Subscribe(ctx, pubUUID) require.NoError(t, err) - waitForPutEvent(t, s, stub, firstEventReceivedC) - - // closed watchers should be handled gracefully and reset. - s.buf.Clear() - s.watcherClock.BlockUntil(1) + t.Cleanup(sub.Close) - // The watcher should notify waiters of backend state on watcher reset. 
- replace := *stub - replace.PublicKey = []byte(sshPubKey) - replace.User = "user" - - stub, err = s.identity.CreateHeadlessAuthenticationStub(ctx, pubUUID) - require.NoError(t, err) - swapped, err := s.identity.CompareAndSwapHeadlessAuthentication(ctx, stub, &replace) - require.NoError(t, err) + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Millisecond) + t.Cleanup(waitCancel) - s.watcherClock.Advance(s.watcher.MaxRetryPeriod) - require.NoError(t, <-errC) - require.Equal(t, swapped, <-headlessAuthnCh) + _, err = s.watcher.WaitForUpdate(waitCtx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) }) t.Run("WatcherClosed", func(t *testing.T) { t.Parallel() - s := newTestEnv(t) + s := newHeadlessAuthenticationWatcherTestEnv(t, clockwork.NewFakeClock()) - _, _, errC := waitInGoroutine(ctx, t, s.watcher, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { - return true, nil - }) + sub, err := s.watcher.Subscribe(ctx, pubUUID) + require.NoError(t, err) + t.Cleanup(sub.Close) + + errC := make(chan error) + go func() { + _, err := s.watcher.WaitForUpdate(ctx, sub, func(ha *types.HeadlessAuthentication) (bool, error) { + return true, nil + }) + errC <- err + }() s.watcherCancel() - // waiters should be notified to close and result in ctx error + // WaitForUpdate should end with closed watcher error. waitErr := <-errC require.Error(t, waitErr) - require.Equal(t, waitErr.Error(), "headless authentication watcher closed") + require.ErrorIs(t, waitErr, local.ErrHeadlessAuthenticationWatcherClosed) - // New waiters should be prevented. - _, err := s.watcher.Wait(ctx, pubUUID, func(ha *types.HeadlessAuthentication) (bool, error) { return true, nil }) + // New subscribers should be prevented. + _, err = s.watcher.Subscribe(ctx, pubUUID) require.Error(t, err) }) }