diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 4cab2ac78948c..00824d2a3f4ee 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -426,39 +426,40 @@ func TestOIDCLogin(t *testing.T) { proxyAddr, err := proxyProcess.ProxyWebAddr() require.NoError(t, err) + // set up watcher to approve the automatic request in background var didAutoRequest atomic.Bool + watcher, err := authServer.NewWatcher(ctx, types.Watch{ + Kinds: []types.WatchKind{ + {Kind: types.KindAccessRequest}, + }, + }) + require.NoError(t, err) + + // ensure that we observe init event prior to moving watcher to background + // goroutine (ensures watcher init does not race with request creation). + select { + case event := <-watcher.Events(): + require.Equal(t, event.Type, types.OpInit) + case <-watcher.Done(): + require.FailNow(t, "watcher closed unexpected", "err: %v", watcher.Error()) + } - errCh := make(chan error) go func() { - watcher, err := authServer.NewWatcher(ctx, types.Watch{ - Kinds: []types.WatchKind{ - {Kind: types.KindAccessRequest}, - }, - }) - if err != nil { - errCh <- err - return - } - for { - select { - case event := <-watcher.Events(): - if event.Type != types.OpPut { - continue - } - err = authServer.SetAccessRequestState(ctx, types.AccessRequestUpdate{ - RequestID: event.Resource.(types.AccessRequest).GetName(), - State: types.RequestState_APPROVED, - }) - didAutoRequest.Store(true) - errCh <- err - return - case <-watcher.Done(): - errCh <- nil - return - case <-ctx.Done(): - errCh <- nil - return + select { + case event := <-watcher.Events(): + if event.Type != types.OpPut { + panic(fmt.Sprintf("unexpected event type: %v\n", event)) } + err = authServer.SetAccessRequestState(ctx, types.AccessRequestUpdate{ + RequestID: event.Resource.(types.AccessRequest).GetName(), + State: types.RequestState_APPROVED, + }) + if err != nil { + panic(fmt.Sprintf("failed to approve request: %v", err)) + } + didAutoRequest.Store(true) + case <-watcher.Done(): + panic(fmt.Sprintf("watcher exited unexpectedly: %v", watcher.Error())) } }() @@ -479,7 +480,6 @@ func TestOIDCLogin(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, <-errCh) // verify that auto-request happened require.True(t, didAutoRequest.Load())