diff --git a/sqs/mock.go b/sqs/mock.go index 65dbfb4..62ba386 100644 --- a/sqs/mock.go +++ b/sqs/mock.go @@ -103,13 +103,13 @@ func newMockServer(concurrency int, mockSQS *mockSQSAPI) *Server { receiverCtx, receiverCancelFunc := context.WithCancel(context.Background()) srv := &Server{ - QueueURL: "https://myqueue.com", - Svc: mockSQS, maxConcurrentReceives: make(chan struct{}, concurrency), receiverCtx: receiverCtx, receiverCancelFunc: receiverCancelFunc, serverCtx: serverCtx, serverCancelFunc: serverCancelFunc, + QueueURL: "https://myqueue.com", + Svc: mockSQS, } return srv diff --git a/sqs/server.go b/sqs/server.go index b1d0cb2..fead591 100644 --- a/sqs/server.go +++ b/sqs/server.go @@ -3,6 +3,8 @@ package sqs import ( "bytes" "context" + "errors" + "fmt" "log" "os" "time" @@ -10,7 +12,9 @@ import ( "github.com/zerofox-oss/go-msg" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" @@ -31,6 +35,7 @@ type Server struct { receiverCancelFunc context.CancelFunc // CancelFunc for all receiver routines serverCtx context.Context // context used to control the life of the Server serverCancelFunc context.CancelFunc // CancelFunc to signal the server should stop requesting messages + session *session.Session // session used to re-create `Svc` when needed } // convertToMsgAttrs creates msg.Attributes from sqs.Message.Attributes @@ -144,6 +149,35 @@ func (s *Server) Shutdown(ctx context.Context) error { } } +// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay +// for credential errors (403 statuscode). 
+// This is needed in order to wait for credentials to be valid for SQS requests +// due to AWS "eventually consistent" credentials: +// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html +type DefaultRetryer struct { + request.Retryer + delay time.Duration +} + +// RetryRules returns the delay for the next request to be made +func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration { + if req.HTTPResponse.StatusCode == 403 { + return r.delay + } + return r.Retryer.RetryRules(req) +} + +// ShouldRetry determines if the passed request should be retried +func (r DefaultRetryer) ShouldRetry(req *request.Request) bool { + if req.HTTPResponse.StatusCode == 403 { + return true + } + return r.Retryer.ShouldRetry(req) +} + +// Option is the signature that modifies a `Server` to set some configuration +type Option func(*Server) error + // NewServer creates and initializes a new Server using queueURL to a SQS queue // `cl` represents the number of concurrent message receives (10 msgs each). // @@ -152,7 +186,7 @@ func (s *Server) Shutdown(ctx context.Context) error { // // SQS_ENDPOINT can be set as an environment variable in order to // override the aws.Client's Configured Endpoint -func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error) { +func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg.Server, error) { // It makes no sense to have a concurrency of less than 1. 
if cl < 1 { log.Printf("[WARN] Requesting concurrency of %d, this makes no sense, setting to 1\n", cl) @@ -176,6 +210,10 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error) if url := os.Getenv("SQS_ENDPOINT"); url != "" { conf.Endpoint = aws.String(url) } + conf.Retryer = DefaultRetryer{ + Retryer: client.DefaultRetryer{NumMaxRetries: 7}, + delay: 2 * time.Second, + } // Create an SQS Client with creds from the Environment svc := sqs.New(sess, conf) @@ -192,6 +230,54 @@ func NewServer(queueURL string, cl int, retryTimeout int64) (msg.Server, error) serverCancelFunc: serverCancelFunc, receiverCtx: receiverCtx, receiverCancelFunc: receiverCancelFunc, + session: sess, + } + + for _, opt := range opts { + if err = opt(srv); err != nil { + return nil, fmt.Errorf("Failed setting option: %s", err) + } } + return srv, nil } + +func getConf(s *Server) (*aws.Config, error) { + sqs, ok := s.Svc.(*sqs.SQS) + if !ok { + return nil, errors.New("`Svc` could not be casted to a SQS client") + } + return &sqs.Client.Config, nil +} + +// WithCustomRetryer sets a custom `Retryer` to use on the SQS client. +func WithCustomRetryer(r request.Retryer) Option { + return func(s *Server) error { + c, err := getConf(s) + if err != nil { + return err + } + c.Retryer = r + s.Svc = sqs.New(s.session, c) + return nil + } +} + +// WithRetries makes the `Server` retry on credential errors (HTTP 403), allowing up to +// `max` retries (so `max`+1 requests in total) and waiting the `delay` duration before each retry.
+// This is needed in scenarios where credentials are automatically generated +// and the program starts before AWS finishes propagating them +func WithRetries(delay time.Duration, max int) Option { + return func(s *Server) error { + c, err := getConf(s) + if err != nil { + return err + } + c.Retryer = DefaultRetryer{ + Retryer: client.DefaultRetryer{NumMaxRetries: max}, + delay: delay, + } + s.Svc = sqs.New(s.session, c) + return nil + } +} diff --git a/sqs/server_test.go b/sqs/server_test.go index 7ebfb87..c35a6c5 100644 --- a/sqs/server_test.go +++ b/sqs/server_test.go @@ -4,6 +4,11 @@ import ( "context" "errors" "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" "testing" "time" @@ -78,6 +83,64 @@ func TestServer_Serve(t *testing.T) { } } +func TestServer_Serve_retries(t *testing.T) { + retries := make([]*http.Request, 0, 3) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := ioutil.ReadAll(r.Body) + t.Logf("Request: %s\n", b) + retries = append(retries, r) + w.WriteHeader(403) + fmt.Fprintln(w, `SenderInvalidClientTokenIdThe security token included in the request is invalid.ee1c20d5-2537-5e47-97b1-73909c83231a`) + })) + defer ts.Close() + + os.Setenv("SQS_ENDPOINT", ts.URL) + os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake") + os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake") + + defer func() { + os.Unsetenv("SQS_ENDPOINT") + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + }() + + cases := []struct { + name string + options []Option + numTries int + }{ + {"default", nil, 8}, + {"1 retry", []Option{WithRetries(0, 1)}, 2}, + {"No retries", []Option{WithRetries(0, 0)}, 1}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + retries = make([]*http.Request, 0, 3) + srv, err := NewServer(ts.URL+"/queue", 1, 1, c.options...) 
+ if err != nil { + t.Errorf("Server creation should not fail: %s", err) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + srv.Shutdown(ctx) + cancel() + }() + + r := &SimpleReceiver{t: t} + err = srv.Serve(r) + if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 { + t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error()) + } + + t.Logf("retries: %v", retries) + if len(retries) != c.numTries { + t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries)) + } + }) + } +} + // TestServer_ServeConcurrency tests that an SQS server can process a lot of // messages using many concurrent goroutines. func TestServer_Concurrency(t *testing.T) {