Skip to content

Commit

Permalink
feat: implement retry policy
Browse files Browse the repository at this point in the history
  • Loading branch information
jooola committed Jul 15, 2024
1 parent b3e91d9 commit 7e4e0a3
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 37 deletions.
30 changes: 27 additions & 3 deletions hcloud/action_waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,41 @@ func TestWaitFor(t *testing.T) {
},
},
{
Name: "fail with api error",
Name: "fail with server error",
WantRequests: []mockutils.Request{
{Method: "GET", Path: "/actions?id=1509772237&page=1&sort=status&sort=id",
Status: 503},
Status: 500,
},
},
Run: func(env testEnv) {
actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}}

err := env.Client.Action.WaitFor(context.Background(), actions...)
assert.Error(t, err)
assert.Equal(t, "hcloud: server responded with status code 503", err.Error())
assert.Equal(t, "hcloud: server responded with status code 500", err.Error())
},
},
{
Name: "succeed with retry",
WantRequests: []mockutils.Request{
{Method: "GET", Path: "/actions?id=1509772237&page=1&sort=status&sort=id",
Status: 503,
},
{Method: "GET", Path: "/actions?id=1509772237&page=1&sort=status&sort=id",
Status: 200,
JSONRaw: `{
"actions": [
{ "id": 1509772237, "status": "success", "progress": 100 }
],
"meta": { "pagination": { "page": 1 }}
}`,
},
},
Run: func(env testEnv) {
actions := []*Action{{ID: 1509772237, Status: ActionStatusRunning}}

err := env.Client.Action.WaitFor(context.Background(), actions...)
assert.NoError(t, err)
},
},
},
Expand Down
16 changes: 9 additions & 7 deletions hcloud/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ type Client struct {
endpoint string
token string
tokenValid bool
backoffFunc BackoffFunc
retryBackoffFunc BackoffFunc
retryMaxRetries int
pollBackoffFunc BackoffFunc
httpClient *http.Client
applicationName string
Expand Down Expand Up @@ -162,7 +163,7 @@ func WithPollBackoffFunc(f BackoffFunc) ClientOption {
// The backoff function is used for retrying HTTP requests.
func WithBackoffFunc(f BackoffFunc) ClientOption {
return func(client *Client) {
client.backoffFunc = f
client.retryBackoffFunc = f
}
}

Expand Down Expand Up @@ -201,11 +202,12 @@ func WithInstrumentation(registry prometheus.Registerer) ClientOption {
// NewClient creates a new client.
func NewClient(options ...ClientOption) *Client {
client := &Client{
endpoint: Endpoint,
tokenValid: true,
httpClient: &http.Client{},
backoffFunc: ExponentialBackoff(2, 500*time.Millisecond),
pollBackoffFunc: ConstantBackoff(500 * time.Millisecond),
endpoint: Endpoint,
tokenValid: true,
httpClient: &http.Client{},
retryBackoffFunc: ExponentialBackoff(2, time.Second),
retryMaxRetries: 5,
pollBackoffFunc: ConstantBackoff(500 * time.Millisecond),
}

for _, option := range options {
Expand Down
2 changes: 1 addition & 1 deletion hcloud/client_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func assembleHandlerChain(client *Client) handler {
h = wrapErrorHandler(h)

// Retry the request if the retry conditions are met
h = wrapRetryHandler(h, client.backoffFunc)
h = wrapRetryHandler(h, client.retryBackoffFunc, client.retryMaxRetries)

// Finally parse the response body into the provided schema
h = wrapParseHandler(h)
Expand Down
5 changes: 4 additions & 1 deletion hcloud/client_handler_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package hcloud

import (
"encoding/json"
"errors"
"fmt"
"net/http"

"github.com/hetznercloud/hcloud-go/v2/hcloud/schema"
)

// ErrStatusCode is the sentinel error wrapped by the error handler when the
// server responds with an HTTP status code between 400 and 599 and no error
// could be built from the response body. Match it with errors.Is.
var ErrStatusCode = errors.New("server responded with status code")

// wrapErrorHandler decorates the wrapped handler with an errorHandler and
// returns the resulting handler.
func wrapErrorHandler(wrapped handler) handler {
	var decorated handler = &errorHandler{wrapped}
	return decorated
}
Expand All @@ -25,7 +28,7 @@ func (h *errorHandler) Do(req *http.Request, v any) (resp *Response, err error)
if resp.StatusCode >= 400 && resp.StatusCode <= 599 {
err = errorFromBody(resp)
if err == nil {
err = fmt.Errorf("hcloud: server responded with status code %d", resp.StatusCode)
err = fmt.Errorf("hcloud: %w %d", ErrStatusCode, resp.StatusCode)
}
}
return resp, err
Expand Down
59 changes: 52 additions & 7 deletions hcloud/client_handler_retry.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
package hcloud

import (
"errors"
"net"
"net/http"
"time"
)

func wrapRetryHandler(wrapped handler, backoffFunc BackoffFunc) handler {
return &retryHandler{wrapped, backoffFunc}
func wrapRetryHandler(wrapped handler, backoffFunc BackoffFunc, maxRetries int) handler {
return &retryHandler{wrapped, backoffFunc, maxRetries}
}

// retryHandler is a handler that sends a request through the wrapped handler
// and retries it, within limits, when retryPolicy reports the outcome as
// retryable.
type retryHandler struct {
// handler is the wrapped handler.
handler handler
// backoffFunc returns how long to wait before the next attempt, given the
// number of retries already performed.
backoffFunc BackoffFunc
// maxRetries is the upper bound on the number of retries for one request.
maxRetries int
}

func (h *retryHandler) Do(req *http.Request, v any) (resp *Response, err error) {
retries := 0
ctx := req.Context()

for {
// Clone the request using the original context
cloned, err := cloneRequest(req, req.Context())
cloned, err := cloneRequest(req, ctx)
if err != nil {
return nil, err
}
Expand All @@ -30,13 +34,54 @@ func (h *retryHandler) Do(req *http.Request, v any) (resp *Response, err error)
// - request preparation
// - network connectivity
// - http status code (see [errorHandler])
if IsError(err, ErrorCodeConflict) {
time.Sleep(h.backoffFunc(retries))
retries++
continue
if ctx.Err() != nil {
// early return if the context was canceled or timed out
return resp, err
}

if retries < h.maxRetries && retryPolicy(resp, err) {
select {
case <-ctx.Done():
return resp, err
case <-time.After(h.backoffFunc(retries)):
retries++
continue
}
}
}

return resp, err
}
}

// retryPolicy reports whether a request that resulted in resp and err should
// be retried.
//
// The first matching rule decides the outcome:
//   - err is an API Error: retry only for the codes ErrorCodeConflict and
//     ErrorCodeRateLimitExceeded.
//   - err wraps ErrStatusCode: retry only for HTTP 429, 502, 503 and 504.
//     Without a response there is no status code to inspect, so the request
//     is not retried.
//   - err is a net.Error: retry only if it reports a timeout.
//
// Everything else, including a nil err, is not retried.
func retryPolicy(resp *Response, err error) bool {
	if err == nil {
		return false
	}

	var apiErr Error
	var netErr net.Error

	switch {
	case errors.As(err, &apiErr):
		switch apiErr.Code { //nolint:exhaustive
		case ErrorCodeConflict, ErrorCodeRateLimitExceeded:
			return true
		}
	case errors.Is(err, ErrStatusCode):
		// Guard against a handler that reports a status code error without
		// returning the response, which would otherwise panic below.
		if resp == nil {
			return false
		}
		switch resp.Response.StatusCode {
		// 4xx errors
		case http.StatusTooManyRequests:
			return true
		// 5xx errors
		case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
			return true
		}
	case errors.As(err, &netErr):
		return netErr.Timeout()
	}

	return false
}
148 changes: 131 additions & 17 deletions hcloud/client_handler_retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,166 @@ func TestRetryHandler(t *testing.T) {
testCases := []struct {
name string
wrapped func(req *http.Request, v any) (*Response, error)
want int
recover bool
want func(t *testing.T, err error, retryCount int)
}{
{
name: "network error",
name: "random error",
wrapped: func(_ *http.Request, _ any) (*Response, error) {
return nil, fmt.Errorf("network error")
return nil, fmt.Errorf("random error")
},
want: func(t *testing.T, err error, retryCount int) {
assert.EqualError(t, err, "random error")
assert.Equal(t, 0, retryCount)
},
},
{
name: "http 503 error recovery",
wrapped: func(req *http.Request, _ any) (*Response, error) {
resp := fakeResponse(t, 503, "", false)
resp.Response.Request = req
return resp, fmt.Errorf("%w %d", ErrStatusCode, 503)
},
recover: true,
want: func(t *testing.T, err error, retryCount int) {
assert.Nil(t, err)
assert.Equal(t, 1, retryCount)
},
want: 0,
},
{
name: "http 503 error",
wrapped: func(_ *http.Request, _ any) (*Response, error) {
return nil, nil
wrapped: func(req *http.Request, _ any) (*Response, error) {
resp := fakeResponse(t, 503, "", false)
resp.Response.Request = req
return resp, fmt.Errorf("%w %d", ErrStatusCode, 503)
},
want: func(t *testing.T, err error, retryCount int) {
assert.EqualError(t, err, "server responded with status code 503")
assert.Equal(t, 5, retryCount)
},
},
{
name: "api conflict error recovery",
wrapped: func(req *http.Request, _ any) (*Response, error) {
resp := fakeResponse(t, 409, "", false)
resp.Response.Request = req
return nil, ErrorFromSchema(schema.Error{Code: string(ErrorCodeConflict), Message: "A conflict occurred"})
},
recover: true,
want: func(t *testing.T, err error, retryCount int) {
assert.Nil(t, err)
assert.Equal(t, 1, retryCount)
},
want: 0,
},
{
name: "api conflict error",
wrapped: func(_ *http.Request, _ any) (*Response, error) {
return nil, ErrorFromSchema(schema.Error{Code: string(ErrorCodeConflict)})
wrapped: func(req *http.Request, _ any) (*Response, error) {
resp := fakeResponse(t, 409, "", false)
resp.Response.Request = req
return nil, ErrorFromSchema(schema.Error{Code: string(ErrorCodeConflict), Message: "A conflict occurred"})
},
want: func(t *testing.T, err error, retryCount int) {
assert.EqualError(t, err, "A conflict occurred (conflict)")
assert.Equal(t, 5, retryCount)
},
want: 1,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
m := &mockHandler{testCase.wrapped}

retryCount := 0
h := wrapRetryHandler(m, func(_ int) time.Duration {
// Reset the mock handler to exit the retry loop
m.f = func(_ *http.Request, _ any) (*Response, error) { return nil, nil }
h := wrapRetryHandler(m, func(retries int) time.Duration {
assert.Equal(t, retryCount, retries)

if testCase.recover {
// Reset the mock handler to exit the retry loop
m.f = func(_ *http.Request, _ any) (*Response, error) { return nil, nil }
}

retryCount++
return 0
})
}, 5)

client := NewClient(WithToken("dummy"))
req, err := client.NewRequest(context.Background(), "GET", "/", nil)
require.NoError(t, err)
require.Equal(t, 0, retryCount)

_, err = h.Do(req, nil)
testCase.want(t, err, retryCount)
})
}
}

assert.Equal(t, 0, retryCount)
func TestRetryPolicy(t *testing.T) {
testCases := []struct {
name string
resp *Response
want bool
}{
{
// The API error code unavailable is used in many unexpected situations, we must only
// retry if the server returns HTTP 503.
name: "api returns unavailable error",
resp: fakeResponse(t, 503, `{"error":{"code":"unavailable"}}`, true),
want: false,
},
{
name: "server returns 503 error",
resp: fakeResponse(t, 503, ``, false),
want: true,
},
{
name: "api returns rate_limit_exceeded error",
resp: fakeResponse(t, 429, `{"error":{"code":"rate_limit_exceeded"}}`, true),
want: true,
},
{
name: "server returns 429 error",
resp: fakeResponse(t, 429, ``, false),
want: true,
},
{
name: "api returns conflict error",
resp: fakeResponse(t, 409, `{"error":{"code":"conflict"}}`, true),
want: true,
},
{
// HTTP 409 is used in many situations (e.g. uniqueness_error), we must only
// retry if the API error code is conflict.
name: "server returns 409 error",
resp: fakeResponse(t, 409, ``, false),
want: false,
},
{
// The API error code locked is used in many unexpected situations, we can
// only retry in specific context where we know the error is not misused.
name: "api returns locked error",
resp: fakeResponse(t, 423, `{"error":{"code":"locked"}}`, true),
want: false,
},
{
// HTTP 423 is used in many situations (e.g. protected), we must only
// retry if the API error code is locked.
name: "server returns 423 error",
resp: fakeResponse(t, 423, ``, false),
want: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
require.NoError(t, err)

h.Do(req, nil)
m := &mockHandler{func(req *http.Request, _ any) (*Response, error) {
testCase.resp.Request = req
return testCase.resp, nil
}}
h := wrapErrorHandler(m)

assert.Equal(t, testCase.want, retryCount)
result := retryPolicy(h.Do(req, nil))
assert.Equal(t, testCase.want, result)
})
}
}
Loading

0 comments on commit 7e4e0a3

Please sign in to comment.