Skip to content

Commit

Permalink
Merge pull request #11 from marcrosis/cred-retry
Browse files Browse the repository at this point in the history
Retry connection on credential errors
  • Loading branch information
Xopherus authored Apr 11, 2018
2 parents 6c140a3 + 4a5606f commit ae8441a
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sqs/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ func newMockServer(concurrency int, mockSQS *mockSQSAPI) *Server {
receiverCtx, receiverCancelFunc := context.WithCancel(context.Background())

srv := &Server{
QueueURL: "https://myqueue.com",
Svc: mockSQS,
maxConcurrentReceives: make(chan struct{}, concurrency),
receiverCtx: receiverCtx,
receiverCancelFunc: receiverCancelFunc,
serverCtx: serverCtx,
serverCancelFunc: serverCancelFunc,
QueueURL: "https://myqueue.com",
Svc: mockSQS,
}

return srv
Expand Down
88 changes: 87 additions & 1 deletion sqs/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ package sqs
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"time"

"github.com/zerofox-oss/go-msg"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
Expand All @@ -31,6 +35,7 @@ type Server struct {
receiverCancelFunc context.CancelFunc // CancelFunc for all receiver routines
serverCtx context.Context // context used to control the life of the Server
serverCancelFunc context.CancelFunc // CancelFunc to signal the server should stop requesting messages
session *session.Session // session used to re-create `Svc` when needed
}

// convertToMsgAttrs creates msg.Attributes from sqs.Message.Attributes
Expand Down Expand Up @@ -144,6 +149,35 @@ func (s *Server) Shutdown(ctx context.Context) error {
}
}

// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay
// for credential errors (403 status code).
// This is needed in order to wait for credentials to be valid for SQS requests
// due to AWS "eventually consistent" credentials:
// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html
//
// RetryRules and ShouldRetry are overridden for 403 responses; all other
// behavior comes from the embedded Retryer.
type DefaultRetryer struct {
// Retryer is the wrapped retryer consulted for anything that is not a 403 response.
request.Retryer
// delay is the wait that RetryRules returns for a 403 response.
delay time.Duration
}

// RetryRules returns the delay to wait before the next attempt of req.
//
// A 403 response (credentials that are not valid yet) always waits the fixed
// delay configured on r; every other case is delegated to the embedded Retryer.
func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration {
	// HTTPResponse may be nil when the request failed before any response was
	// received, so guard the dereference instead of panicking.
	if req.HTTPResponse != nil && req.HTTPResponse.StatusCode == 403 {
		return r.delay
	}
	return r.Retryer.RetryRules(req)
}

// ShouldRetry determines if the passed request should be retried.
//
// A 403 response is always reported as retryable, since it is treated as a
// credential error that can resolve itself; every other case is decided by
// the embedded Retryer.
func (r DefaultRetryer) ShouldRetry(req *request.Request) bool {
	// HTTPResponse may be nil when the request failed before any response was
	// received, so guard the dereference instead of panicking.
	if req.HTTPResponse != nil && req.HTTPResponse.StatusCode == 403 {
		return true
	}
	return r.Retryer.ShouldRetry(req)
}

// Option is the signature of a function that modifies a `Server` to set some
// configuration. Options are passed to NewServer, which applies them in order
// once the Server has been built and fails if any of them returns an error.
type Option func(*Server) error

// NewServer creates and initializes a new Server using queueURL to a SQS queue
// `cl` represents the number of concurrent message receives (10 msgs each).
//
Expand All @@ -152,7 +186,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
//
// SQS_ENDPOINT can be set as an environment variable in order to
// override the aws.Client's Configured Endpoint
func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error) {
func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg.Server, error) {
// It makes no sense to have a concurrency of less than 1.
if cl < 1 {
log.Printf("[WARN] Requesting concurrency of %d, this makes no sense, setting to 1\n", cl)
Expand All @@ -176,6 +210,10 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error)
if url := os.Getenv("SQS_ENDPOINT"); url != "" {
conf.Endpoint = aws.String(url)
}
conf.Retryer = DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: 7},
delay: 2 * time.Second,
}

// Create an SQS Client with creds from the Environment
svc := sqs.New(sess, conf)
Expand All @@ -192,6 +230,54 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error)
serverCancelFunc: serverCancelFunc,
receiverCtx: receiverCtx,
receiverCancelFunc: receiverCancelFunc,
session: sess,
}

for _, opt := range opts {
if err = opt(srv); err != nil {
return nil, fmt.Errorf("Failed setting option: %s", err)
}
}

return srv, nil
}

// getConf returns a pointer to the AWS configuration of the SQS client held
// in s.Svc. It fails when s.Svc is not a concrete *sqs.SQS client.
func getConf(s *Server) (*aws.Config, error) {
	if sqsClient, ok := s.Svc.(*sqs.SQS); ok {
		return &sqsClient.Client.Config, nil
	}
	return nil, errors.New("`Svc` could not be casted to a SQS client")
}

// WithCustomRetryer returns an Option that installs r as the `Retryer` of the
// SQS client. The client is re-created from the Server's session using the
// updated configuration. The Option fails when the current `Svc` is not a
// concrete SQS client.
func WithCustomRetryer(r request.Retryer) Option {
	return func(srv *Server) error {
		conf, err := getConf(srv)
		if err == nil {
			conf.Retryer = r
			srv.Svc = sqs.New(srv.session, conf)
		}
		return err
	}
}

// WithRetries returns an Option that makes the `Server` retry on credential
// errors, using a DefaultRetryer limited to `max` retries that waits `delay`
// before retrying a 403 response.
// This is needed in scenarios where credentials are automatically generated
// and the program starts before AWS finishes propagating them.
func WithRetries(delay time.Duration, max int) Option {
	return func(srv *Server) error {
		conf, err := getConf(srv)
		if err != nil {
			return err
		}

		retryer := DefaultRetryer{
			delay:   delay,
			Retryer: client.DefaultRetryer{NumMaxRetries: max},
		}
		conf.Retryer = retryer

		// Re-create the client so that it picks up the new retryer.
		srv.Svc = sqs.New(srv.session, conf)
		return nil
	}
}
63 changes: 63 additions & 0 deletions sqs/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -78,6 +83,64 @@ func TestServer_Serve(t *testing.T) {
}
}

// TestServer_Serve_retries checks how many requests the Server sends when the
// SQS endpoint keeps answering 403 (InvalidClientTokenId), both with the
// default retryer and with the WithRetries option.
func TestServer_Serve_retries(t *testing.T) {
	retries := make([]*http.Request, 0, 3)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Could not read request body: %s", err)
		}
		t.Logf("Request: %s\n", b)
		retries = append(retries, r)
		w.WriteHeader(403)
		fmt.Fprintln(w, `<?xml version="1.0"?><ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/"><Error><Type>Sender</Type><Code>InvalidClientTokenId</Code><Message>The security token included in the request is invalid.</Message><Detail/></Error><RequestId>ee1c20d5-2537-5e47-97b1-73909c83231a</RequestId></ErrorResponse>`)
	}))
	defer ts.Close()

	// Point the SQS client at the fake endpoint with fake credentials, and put
	// back whatever the environment held before once the test is done.
	env := map[string]string{
		"SQS_ENDPOINT":          ts.URL,
		"AWS_ACCESS_KEY_ID":     "AKIyJLQDLOCKWMFHfake",
		"AWS_SECRET_ACCESS_KEY": "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake",
	}
	for key, value := range env {
		previous, wasSet := os.LookupEnv(key)
		os.Setenv(key, value)
		defer func(key, previous string, wasSet bool) {
			if wasSet {
				os.Setenv(key, previous)
				return
			}
			os.Unsetenv(key)
		}(key, previous, wasSet)
	}

	cases := []struct {
		name     string
		options  []Option
		numTries int
	}{
		{"default", nil, 8},
		{"1 retry", []Option{WithRetries(0, 1)}, 2},
		{"No retries", []Option{WithRetries(0, 0)}, 1},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			retries = make([]*http.Request, 0, 3)
			srv, err := NewServer(ts.URL+"/queue", 1, 1, c.options...)
			if err != nil {
				// Without a server nothing below can run; stop instead of
				// dereferencing a nil server.
				t.Fatalf("Server creation should not fail: %s", err)
			}
			defer func() {
				ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				srv.Shutdown(ctx)
			}()

			r := &SimpleReceiver{t: t}
			err = srv.Serve(r)
			if err == nil {
				t.Fatal("Expected Serve to fail with a credential error, got nil")
			}

			const wantPrefix = "InvalidClientTokenId: The security token included in the request is invalid"
			if !strings.HasPrefix(err.Error(), wantPrefix) {
				t.Errorf("Expected error message to start with `%s`, was `%s`", wantPrefix, err.Error())
			}

			t.Logf("retries: %v", retries)
			if len(retries) != c.numTries {
				t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries))
			}
		})
	}
}

// TestServer_Concurrency tests that an SQS server can process a lot of
// messages using many concurrent goroutines.
func TestServer_Concurrency(t *testing.T) {
Expand Down

0 comments on commit ae8441a

Please sign in to comment.