Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry connection on credential errors #11

Merged
merged 1 commit into from
Apr 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqs/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ func newMockServer(concurrency int, mockSQS *mockSQSAPI) *Server {
receiverCtx, receiverCancelFunc := context.WithCancel(context.Background())

srv := &Server{
QueueURL: "https://myqueue.com",
Svc: mockSQS,
maxConcurrentReceives: make(chan struct{}, concurrency),
receiverCtx: receiverCtx,
receiverCancelFunc: receiverCancelFunc,
serverCtx: serverCtx,
serverCancelFunc: serverCancelFunc,
QueueURL: "https://myqueue.com",
Svc: mockSQS,
}

return srv
Expand Down
88 changes: 87 additions & 1 deletion sqs/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ package sqs
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"time"

"github.com/zerofox-oss/go-msg"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
Expand All @@ -31,6 +35,7 @@ type Server struct {
receiverCancelFunc context.CancelFunc // CancelFunc for all receiver routines
serverCtx context.Context // context used to control the life of the Server
serverCancelFunc context.CancelFunc // CancelFunc to signal the server should stop requesting messages
session *session.Session // session used to re-create `Svc` when needed
}

// convertToMsgAttrs creates msg.Attributes from sqs.Message.Attributes
Expand Down Expand Up @@ -144,6 +149,35 @@ func (s *Server) Shutdown(ctx context.Context) error {
}
}

// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay
// for credential errors (403 statuscode).
// This is needed in order to wait for credentials to be valid for SQS requests
// due to AWS "eventually consistent" credentials:
// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html
type DefaultRetryer struct {
request.Retryer
delay time.Duration
}

// RetryRules returns the delay for the next request to be made
func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration {
if req.HTTPResponse.StatusCode == 403 {
return r.delay
}
return r.Retryer.RetryRules(req)
}

// ShouldRetry determines if the passed request should be retried
func (r DefaultRetryer) ShouldRetry(req *request.Request) bool {
if req.HTTPResponse.StatusCode == 403 {
return true
}
return r.Retryer.ShouldRetry(req)
}

// Option is the signature that modifies a `Server` to set some configuration
type Option func(*Server) error

// NewServer creates and initializes a new Server using queueURL to a SQS queue
// `cl` represents the number of concurrent message receives (10 msgs each).
//
Expand All @@ -152,7 +186,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
//
// SQS_ENDPOINT can be set as an environment variable in order to
// override the aws.Client's Configured Endpoint
func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error) {
func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg.Server, error) {
// It makes no sense to have a concurrency of less than 1.
if cl < 1 {
log.Printf("[WARN] Requesting concurrency of %d, this makes no sense, setting to 1\n", cl)
Expand All @@ -176,6 +210,10 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error)
if url := os.Getenv("SQS_ENDPOINT"); url != "" {
conf.Endpoint = aws.String(url)
}
conf.Retryer = DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: 7},
delay: 2 * time.Second,
}

// Create an SQS Client with creds from the Environment
svc := sqs.New(sess, conf)
Expand All @@ -192,6 +230,54 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error)
serverCancelFunc: serverCancelFunc,
receiverCtx: receiverCtx,
receiverCancelFunc: receiverCancelFunc,
session: sess,
}

for _, opt := range opts {
if err = opt(srv); err != nil {
return nil, fmt.Errorf("Failed setting option: %s", err)
}
}

return srv, nil
}

func getConf(s *Server) (*aws.Config, error) {
sqs, ok := s.Svc.(*sqs.SQS)
if !ok {
return nil, errors.New("`Svc` could not be casted to a SQS client")
}
return &sqs.Client.Config, nil
}

// WithCustomRetryer sets a custom `Retryer` to use on the SQS client.
func WithCustomRetryer(r request.Retryer) Option {
return func(s *Server) error {
c, err := getConf(s)
if err != nil {
return err
}
c.Retryer = r
s.Svc = sqs.New(s.session, c)
return nil
}
}

// WithRetries makes the `Server` retry on credential errors until
// `max` attempts with `delay` seconds between requests.
// This is needed in scenarios where credentials are automatically generated
// and the program starts before AWS finishes propagating them
func WithRetries(delay time.Duration, max int) Option {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I'm misunderstanding or not but I think what's happening here is:

  • 403s simply apply a delay between retries, though there is no maximum number of retries
  • 500s (specific codes are specified by client.DefaultRetryer) simply specify the maximum number of retries, though the delay is determined by the aws-sdk

Is that right?

Copy link
Contributor Author

@elmarcoh elmarcoh Apr 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The maximum number of retries is the same for all requests, it's defined by the request.Retryer.MaxRetries() method.
For 403 we simply set a different, longer delay between retries.

Originally 403 statuscodes were not retried, that's the normal DefaultRetrier behavior, thats why in the retryer implementation we have to override the ShouldRetry method.

Copy link
Contributor

@Xopherus Xopherus Apr 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I wasn't sure if we should have checked if we exceeded MaxRetries() in the ShouldRetry func or not. I did a bit more digging into the aws-sdk so it looks like that it's not the case. I assume clients of a Retryer will continue calling ShouldRetry until you reach MaxRetries() and then it stops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the MaxRetries() check it's done outside the Retryer implementation/interface

return func(s *Server) error {
c, err := getConf(s)
if err != nil {
return err
}
c.Retryer = DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: max},
delay: delay,
}
s.Svc = sqs.New(s.session, c)
return nil
}
}
63 changes: 63 additions & 0 deletions sqs/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -78,6 +83,64 @@ func TestServer_Serve(t *testing.T) {
}
}

func TestServer_Serve_retries(t *testing.T) {
retries := make([]*http.Request, 0, 3)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadAll(r.Body)
t.Logf("Request: %s\n", b)
retries = append(retries, r)
w.WriteHeader(403)
fmt.Fprintln(w, `<?xml version="1.0"?><ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/"><Error><Type>Sender</Type><Code>InvalidClientTokenId</Code><Message>The security token included in the request is invalid.</Message><Detail/></Error><RequestId>ee1c20d5-2537-5e47-97b1-73909c83231a</RequestId></ErrorResponse>`)
}))
defer ts.Close()

os.Setenv("SQS_ENDPOINT", ts.URL)
os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake")

defer func() {
os.Unsetenv("SQS_ENDPOINT")
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}()

cases := []struct {
name string
options []Option
numTries int
}{
{"default", nil, 8},
{"1 retry", []Option{WithRetries(0, 1)}, 2},
{"No retries", []Option{WithRetries(0, 0)}, 1},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
retries = make([]*http.Request, 0, 3)
srv, err := NewServer(ts.URL+"/queue", 1, 1, c.options...)
if err != nil {
t.Errorf("Server creation should not fail: %s", err)
}
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
srv.Shutdown(ctx)
cancel()
}()

r := &SimpleReceiver{t: t}
err = srv.Serve(r)
if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 {
t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error())
}

t.Logf("retries: %v", retries)
if len(retries) != c.numTries {
t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries))
}
})
}
}

// TestServer_ServeConcurrency tests that an SQS server can process a lot of
// messages using many concurrent goroutines.
func TestServer_Concurrency(t *testing.T) {
Expand Down