diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/benchmarks/go.mod b/benchmarks/go.mod new file mode 100644 index 0000000..a4c594e --- /dev/null +++ b/benchmarks/go.mod @@ -0,0 +1,9 @@ +module pester/benchmarks + +require ( + github.com/sethgrid/pester v1.0.0 +) + +replace github.com/sethgrid/pester v1.0.0 => ../ + +go 1.14 \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1e56fb6 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/sethgrid/pester + +go 1.14 \ No newline at end of file diff --git a/pester.go b/pester.go index f5d8866..8021e60 100644 --- a/pester.go +++ b/pester.go @@ -12,22 +12,26 @@ import ( "math/rand" "net/http" "net/url" + "strings" "sync" "time" ) const ( - methodDo = "Do" - methodGet = "Get" - methodHead = "Head" - methodPost = "Post" - methodPostForm = "PostForm" + methodDo = "Do" + methodGet = "Get" + methodHead = "Head" + methodPost = "Post" + methodPostForm = "PostForm" + headerKeyContentType = "Content-Type" + contentTypeFormURLEncoded = "application/x-www-form-urlencoded" ) //ErrUnexpectedMethod occurs when an http.Client method is unable to be mapped from a calling method in the pester client var ErrUnexpectedMethod = errors.New("unexpected client method, must be one of Do, Get, Head, Post, or PostFrom") // ErrReadingBody happens when we cannot read the body bytes +// Deprecated: use ErrReadingRequestBody var ErrReadingBody = errors.New("error reading body") // ErrReadingRequestBody happens when we cannot read the request body bytes @@ -91,7 +95,7 @@ type params struct { req *http.Request url string bodyType string - body io.Reader + body io.ReadCloser data url.Values } @@ -184,6 +188,16 @@ func (c *Client) Wait() { c.wg.Wait() } +func (c *Client) copyBody(src io.ReadCloser) ([]byte, error) { + b, err := ioutil.ReadAll(src) + if err != nil { + return nil, ErrReadingRequestBody + } + src.Close() + + return b, nil +} + // pester provides all the logic of retries, concurrency, backoff, and logging func (c *Client) pester(p params) (*http.Response, error) { resultCh := make(chan result) @@ -227,21 +241,34 @@ func (c *Client) pester(p params) (*http.Response, error) { } // if we have a request body, we need to save it for later - var originalRequestBody []byte - var originalBody []byte - var err error - if p.req != nil && p.req.Body != nil { - originalRequestBody, err = ioutil.ReadAll(p.req.Body) - if err != nil { - return nil, ErrReadingRequestBody - } - p.req.Body.Close() + var ( + request *http.Request + originalBody []byte + err error + ) + + if p.req != nil && p.req.Body != nil && p.body == nil { + originalBody, err = c.copyBody(p.req.Body) + } else if p.body != nil { + originalBody, err = c.copyBody(p.body) } - if p.body != nil { - originalBody, err = ioutil.ReadAll(p.body) - if err != nil { - return nil, ErrReadingBody - } + + switch p.method { + case methodDo: + request = p.req + case methodGet, methodHead: + request, err = http.NewRequest(p.verb, p.url, nil) + case methodPostForm, methodPost: + request, err = http.NewRequest(http.MethodPost, p.url, ioutil.NopCloser(bytes.NewBuffer(originalBody))) + default: + err = ErrUnexpectedMethod + } + if err != nil { + return nil, err + } + + if len(p.bodyType) > 0 { + request.Header.Set(headerKeyContentType, p.bodyType) } AttemptLimit := c.MaxRetries @@ -249,73 +276,45 @@ func (c *Client) pester(p params) (*http.Response, error) { AttemptLimit = 1 } - for req := 0; req < concurrency; req++ { + for n := 0; n < concurrency; n++ { c.wg.Add(1) totalSentRequests.Add(1) - go func(n int, p params) { + go func(n int, req *http.Request) { defer c.wg.Done() defer totalSentRequests.Done() - var err error for i := 1; i <= AttemptLimit; i++ { c.wg.Add(1) defer c.wg.Done() + select { case <-finishCh: return default: } - // rehydrate the body (it is drained each read) - if len(originalRequestBody) > 0 { - p.req.Body = ioutil.NopCloser(bytes.NewBuffer(originalRequestBody)) - } - if len(originalBody) > 0 { - p.body = bytes.NewBuffer(originalBody) - } - - var resp *http.Response - // route the calls - switch p.method { - case methodDo: - resp, err = httpClient.Do(p.req) - case methodGet: - resp, err = httpClient.Get(p.url) - case methodHead: - resp, err = httpClient.Head(p.url) - case methodPost: - resp, err = httpClient.Post(p.url, p.bodyType, p.body) - case methodPostForm: - resp, err = httpClient.PostForm(p.url, p.data) - default: - err = ErrUnexpectedMethod - } - + resp, err := httpClient.Do(req) // Early return if we have a valid result // Only retry (ie, continue the loop) on 5xx status codes and 429 - if err == nil && resp.StatusCode < http.StatusInternalServerError && (resp.StatusCode != http.StatusTooManyRequests || (resp.StatusCode == http.StatusTooManyRequests && !c.RetryOnHTTP429)) { multiplexCh <- result{resp: resp, err: err, req: n, retry: i} return } - loggingContext := context.Background() - if p.req != nil { - loggingContext = p.req.Context() - } - + loggingContext := req.Context() c.log( loggingContext, ErrEntry{ Time: time.Now(), Method: p.method, - Verb: p.verb, - URL: p.url, + Verb: req.Method, + URL: req.URL.String(), Request: n, Retry: i + 1, // would remove, but would break backward compatibility Attempt: i, Err: err, - }) + }, + ) // if it is the last iteration, grab the result (which is an error at this point) if i == AttemptLimit { @@ -324,14 +323,11 @@ func (c *Client) pester(p params) (*http.Response, error) { } //If the request has been cancelled, skip retries - if p.req != nil { - ctx := p.req.Context() - select { - case <-ctx.Done(): - multiplexCh <- result{resp: resp, err: ctx.Err()} - return - default: - } + select { + case <-req.Context().Done(): + multiplexCh <- result{resp: resp, err: req.Context().Err()} + return + default: } // if we are retrying, we should close this response body to free the fd @@ -342,7 +338,12 @@ func (c *Client) pester(p params) (*http.Response, error) { // prevent a 0 from causing the tick to block, pass additional microsecond <-time.After(c.Backoff(i) + 1*time.Microsecond) } - }(req, p) + }(n, request) + + // rehydrate the body (it is drained each read) + if request.Body != nil { + request.Body = ioutil.NopCloser(bytes.NewBuffer(originalBody)) + } } // spin off the go routine so it can continually listen in on late results and close the response bodies @@ -373,8 +374,8 @@ func (c *Client) pester(p params) (*http.Response, error) { defer c.Unlock() c.SuccessReqNum = res.req c.SuccessRetryNum = res.retry - return res.resp, res.err + return res.resp, res.err } // LogString provides a string representation of the errors the client has seen @@ -440,12 +441,12 @@ func (c *Client) Head(url string) (resp *http.Response, err error) { // Post provides the same functionality as http.Client.Post func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *http.Response, err error) { - return c.pester(params{method: methodPost, url: url, bodyType: bodyType, body: body, verb: http.MethodPost}) + return c.pester(params{method: methodPost, url: url, bodyType: bodyType, body: ioutil.NopCloser(body), verb: http.MethodPost}) } // PostForm provides the same functionality as http.Client.PostForm func (c *Client) PostForm(url string, data url.Values) (resp *http.Response, err error) { - return c.pester(params{method: methodPostForm, url: url, data: data, verb: http.MethodPost}) + return c.pester(params{method: methodPostForm, url: url, bodyType: contentTypeFormURLEncoded, body: ioutil.NopCloser(strings.NewReader(data.Encode())), verb: http.MethodPost}) } // set RetryOnHTTP429 for clients, diff --git a/sample/go.mod b/sample/go.mod new file mode 100644 index 0000000..8bc7a87 --- /dev/null +++ b/sample/go.mod @@ -0,0 +1,9 @@ +module pester/sample + +require ( + github.com/sethgrid/pester v1.0.0 +) + +replace github.com/sethgrid/pester v1.0.0 => ../ + +go 1.14 \ No newline at end of file diff --git a/sample/main.go b/sample/main.go index 3a63fdc..635dd81 100644 --- a/sample/main.go +++ b/sample/main.go @@ -164,15 +164,15 @@ func randoHandler(w http.ResponseWriter, r *http.Request) { var code int switch rand.Intn(10) { case 0: - code = 404 + code = http.StatusNotFound case 1: - code = 400 + code = http.StatusBadRequest case 2: - code = 501 + code = http.StatusNotImplemented case 3: - code = 500 + code = http.StatusInternalServerError default: - code = 200 + code = http.StatusOK } log.Printf("incoming request on :9000 - will return %d in %d ms", code, delay)