diff --git a/client.go b/client.go index 0e32376..57116e9 100644 --- a/client.go +++ b/client.go @@ -79,10 +79,10 @@ var ( // ReaderFunc is the type of function that can be given natively to NewRequest type ReaderFunc func() (io.Reader, error) -// ResponseHandlingFunc is a type of function that takes in a Response, and does something with it. +// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it. // It only runs if the initial part of the request was successful. // If an error is returned, the client's retry policy will be used to determine whether to retry the whole request. -type ResponseHandlingFunc func(*http.Response) error +type ResponseHandlerFunc func(*http.Response) error // LenReader is an interface implemented by many in-memory io.Reader's. Used // for automatically sending the right Content-Length header when possible. @@ -96,7 +96,7 @@ type Request struct { // used to rewind the request data in between retries. body ReaderFunc - responseHandler ResponseHandlingFunc + responseHandler ResponseHandlerFunc // Embed an HTTP request directly. This makes a *Request act exactly // like an *http.Request so that all meta methods are supported. @@ -114,7 +114,7 @@ func (r *Request) WithContext(ctx context.Context) *Request { } // SetResponseHandler allows setting the response handler. -func (r *Request) SetResponseHandler(fn ResponseHandlingFunc) { +func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) { r.responseHandler = fn } @@ -589,6 +589,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { var doErr, respErr, checkErr error for i := 0; ; i++ { + doErr, respErr = nil, nil attempt++ // Always rewind the request body when non-nil. @@ -619,12 +620,23 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // Attempt the request resp, doErr = c.HTTPClient.Do(req.Request) - if doErr != nil { + // Check if we should continue with retries. + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) + if !shouldRetry && doErr == nil && req.responseHandler != nil { + respErr = req.responseHandler(resp) + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr) + } + + err := doErr + if respErr != nil { + err = respErr + } + if err != nil { switch v := logger.(type) { case LeveledLogger: - v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL) + v.Error("request failed", "error", err, "method", req.Method, "url", req.URL) case Logger: - v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr) + v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err) } } else { // Call this here to maintain the behavior of logging all requests, @@ -642,13 +654,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } - // Check if we should continue with retries. - shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) - if !shouldRetry && doErr == nil && req.responseHandler != nil { - respErr = req.responseHandler(resp) - shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr) - } - if !shouldRetry { break } diff --git a/client_test.go b/client_test.go index 470e6e5..b3f8e6d 100644 --- a/client_test.go +++ b/client_test.go @@ -279,7 +279,7 @@ func TestClient_Do_WithResponseHandler(t *testing.T) { var shouldSucceed bool tests := []struct { name string - handler ResponseHandlingFunc + handler ResponseHandlerFunc expectedChecks int // often 2x number of attempts since we check twice err string }{