diff --git a/jwk/refresh.go b/jwk/refresh.go index 051263259..0a8f75452 100644 --- a/jwk/refresh.go +++ b/jwk/refresh.go @@ -489,9 +489,9 @@ func (af *AutoRefresh) refreshLoop(ctx context.Context) { func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error { af.muRegistry.RLock() t, ok := af.registry[url] - af.muRegistry.RUnlock() if !ok { + af.muRegistry.RUnlock() return errors.Errorf(`url "%s" is not registered`, url) } @@ -505,6 +505,7 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB if t.wl != nil { fetchOptions = append(fetchOptions, WithFetchWhitelist(t.wl)) } + af.muRegistry.RUnlock() res, err := fetch(ctx, url, fetchOptions...) if err == nil { @@ -520,7 +521,9 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB af.muCache.Lock() af.cache[url] = keyset af.muCache.Unlock() + af.muRegistry.RLock() nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval) + af.muRegistry.RUnlock() rtr := &resetTimerReq{ t: t, d: nextInterval, @@ -532,8 +535,10 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB } now := time.Now() + af.muRegistry.Lock() t.lastRefresh = now.Local() t.nextRefresh = now.Add(nextInterval).Local() + af.muRegistry.Unlock() return nil } err = parseErr diff --git a/jwk/refresh_test.go b/jwk/refresh_test.go index 723079086..4faa3a904 100644 --- a/jwk/refresh_test.go +++ b/jwk/refresh_test.go @@ -16,6 +16,7 @@ import ( "github.com/lestrrat-go/jwx/internal/jwxtest" "github.com/lestrrat-go/jwx/jwk" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) //nolint:revive,golint @@ -384,3 +385,50 @@ func TestErrorSink(t *testing.T) { }) } } + +func TestAutoRefreshRace(t *testing.T) { + k, err := jwxtest.GenerateRsaJwk() + if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) { + return + } + set := jwk.NewSet() + set.Add(k) + + // set up a server that always success since we need to update the registered target + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(k) + })) + defer srv.Close() + + // configure a unique auto-refresh + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + ar := jwk.NewAutoRefresh(ctx) + ch := make(chan jwk.AutoRefreshError, 256) // big buffer + ar.ErrorSink(ch) + + wg := sync.WaitGroup{} + routineErr := make(chan error, 20) + + // execute a bunch of parallel refresh forcing the requests to the server + // need to simulate configure happening also in the goroutine since this is + // the cause of races when refresh is updating the registered targets + for i := 0; i < 5000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + + ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond)) + _, err := ar.Refresh(ctx, srv.URL) + + if err != nil { + routineErr <- err + } + }() + } + wg.Wait() + + require.Len(t, routineErr, 0) +}