Skip to content

Commit

Permalink
Make retries cancelable
Browse files Browse the repository at this point in the history
When using a request.WithContext it is possible to cancel a long running
request.  However, when using a cancelable request with retries enabled,
we want to quit the retry loop.

This is impactful when using a conservative backoff strategy and a
large number of retries.
  • Loading branch information
stevenjackson committed Dec 5, 2017
1 parent 760f891 commit 8c123bb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pester.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,16 @@ func (c *Client) pester(p params) (*http.Response, error) {
return
}

//If the request has been cancelled, skip retries
if p.req != nil {
ctx := p.req.Context()
select {
case <-ctx.Done():
multiplexCh <- result{resp: resp, err: ctx.Err()}
return
}
}

// if we are retrying, we should close this response body to free the fd
if resp != nil {
resp.Body.Close()
Expand Down
43 changes: 43 additions & 0 deletions pester_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pester

import (
"context"
"fmt"
"log"
"net"
Expand Down Expand Up @@ -397,6 +398,48 @@ func TestConcurrentRequestsNotRacyAndDontLeak_SuccessfulRequest(t *testing.T) {
}
}

func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())

port, err := timeoutServer(1 * time.Second)
if err != nil {
t.Fatal("unable to start timeout server", err)
}

timeoutURL := fmt.Sprintf("http://localhost:%d", port)
req, err := http.NewRequest("GET", timeoutURL, nil)
if err != nil {
t.Fatalf("unable to create request %v", err)
}
req = req.WithContext(ctx)

c := New()
c.MaxRetries = 10
c.KeepLog = true
c.Backoff = ExponentialBackoff

//Cancel the context in another routine (eg: user interrupt)
go func() {
cancel()
t.Logf("\n%d - cancelled", time.Now().Unix())
}()

_, err = c.Do(req)
if err == nil {
t.Fatal("expected to get an error")
}
c.Wait()

// in the event of an error, let's see what the logs were
t.Log("\n", c.LogString())

if got, want := c.LogErrCount(), 1; got != want {
t.Fatalf("got %d errors, want %d", got, want)
}
}

func withinEpsilon(got, want int64, epslion float64) bool {
if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) {
return false
Expand Down

0 comments on commit 8c123bb

Please sign in to comment.