Skip to content

Commit

Permalink
Fix race in jwk.AutoRefresh (#686)
Browse files Browse the repository at this point in the history
* Add test that shows up the race condition

This happens when AutoRefresh.Configure and
AutoRefresh.doRefreshRequests happens at same time, trying to read/write
the same target

* Fix the race in the AutoRefresh

Splicitly protect read and write to targets in the method
AutoRefresh.doRefreshRequest
  • Loading branch information
sergicastro authored Apr 13, 2022
1 parent e831228 commit ea97e8c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jwk/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,9 @@ func (af *AutoRefresh) refreshLoop(ctx context.Context) {
func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error {
af.muRegistry.RLock()
t, ok := af.registry[url]
af.muRegistry.RUnlock()

if !ok {
af.muRegistry.RUnlock()
return errors.Errorf(`url "%s" is not registered`, url)
}

Expand All @@ -505,6 +505,7 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
if t.wl != nil {
fetchOptions = append(fetchOptions, WithFetchWhitelist(t.wl))
}
af.muRegistry.RUnlock()

res, err := fetch(ctx, url, fetchOptions...)
if err == nil {
Expand All @@ -520,7 +521,9 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
af.muCache.Lock()
af.cache[url] = keyset
af.muCache.Unlock()
af.muRegistry.RLock()
nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
af.muRegistry.RUnlock()
rtr := &resetTimerReq{
t: t,
d: nextInterval,
Expand All @@ -532,8 +535,10 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
}

now := time.Now()
af.muRegistry.Lock()
t.lastRefresh = now.Local()
t.nextRefresh = now.Add(nextInterval).Local()
af.muRegistry.Unlock()
return nil
}
err = parseErr
Expand Down
48 changes: 48 additions & 0 deletions jwk/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/lestrrat-go/jwx/internal/jwxtest"
"github.com/lestrrat-go/jwx/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

//nolint:revive,golint
Expand Down Expand Up @@ -384,3 +385,50 @@ func TestErrorSink(t *testing.T) {
})
}
}

func TestAutoRefreshRace(t *testing.T) {
k, err := jwxtest.GenerateRsaJwk()
if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
return
}
set := jwk.NewSet()
set.Add(k)

// set up a server that always success since we need to update the registered target
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(k)
}))
defer srv.Close()

// configure a unique auto-refresh
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
ar := jwk.NewAutoRefresh(ctx)
ch := make(chan jwk.AutoRefreshError, 256) // big buffer
ar.ErrorSink(ch)

wg := sync.WaitGroup{}
routineErr := make(chan error, 20)

// execute a bunch of parallel refresh forcing the requests to the server
// need to simulate configure happening also in the goroutine since this is
// the cause of races when refresh is updating the registered targets
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()

ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond))
_, err := ar.Refresh(ctx, srv.URL)

if err != nil {
routineErr <- err
}
}()
}
wg.Wait()

require.Len(t, routineErr, 0)
}

0 comments on commit ea97e8c

Please sign in to comment.