diff --git a/sqs/mock.go b/sqs/mock.go
index 62ba386..3e29ef3 100644
--- a/sqs/mock.go
+++ b/sqs/mock.go
@@ -110,6 +110,7 @@ func newMockServer(concurrency int, mockSQS *mockSQSAPI) *Server {
 		serverCancelFunc: serverCancelFunc,
 		QueueURL:         "https://myqueue.com",
 		Svc:              mockSQS,
+		retryTimeout:     100,
 	}
 
 	return srv
diff --git a/sqs/server.go b/sqs/server.go
index e579ee5..5fa0c52 100644
--- a/sqs/server.go
+++ b/sqs/server.go
@@ -3,9 +3,12 @@ package sqs
 import (
 	"bytes"
 	"context"
+	crand "crypto/rand"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"log"
+	"math/rand"
 	"os"
 	"time"
 
@@ -21,6 +24,17 @@ import (
 	msg "github.com/zerofox-oss/go-msg"
 )
 
+// init seeds the global math/rand source with 8 bytes read from crypto/rand.
+// getVisiblityTimeout draws its jitter from that global source.
+func init() {
+	var b [8]byte
+	_, err := crand.Read(b[:])
+	if err != nil {
+		panic("cannot seed math/rand package with cryptographically secure random number generator: " + err.Error())
+	}
+	rand.Seed(int64(binary.LittleEndian.Uint64(b[:])))
+}
+
 // Server represents a msg.Server for receiving messages
 // from an AWS SQS Queue
 type Server struct {
@@ -31,6 +45,7 @@ type Server struct {
 	maxConcurrentReceives chan struct{} // The maximum number of message processing routines allowed
 
 	retryTimeout int64 // Visbility Timeout for a message when a receiver fails
+	retryJitter  int64 // Maximum deviation (plus or minus) applied to retryTimeout on each retry
 
 	receiverCtx        context.Context    // context used to control the life of receivers
 	receiverCancelFunc context.CancelFunc // CancelFunc for all receiver routines
@@ -102,7 +117,7 @@ func (s *Server) Serve(r msg.Receiver) error {
 						params := &sqs.ChangeMessageVisibilityInput{
 							QueueUrl:          aws.String(s.QueueURL),
 							ReceiptHandle:     sqsMsg.ReceiptHandle,
-							VisibilityTimeout: aws.Int64(s.retryTimeout),
+							VisibilityTimeout: aws.Int64(getVisiblityTimeout(s.retryTimeout, s.retryJitter)),
 						}
 						if _, err := s.Svc.ChangeMessageVisibility(params); err != nil {
 							log.Printf("[ERROR] cannot change message visibility %s", err)
@@ -125,6 +140,24 @@ func (s *Server) Serve(r msg.Receiver) error {
 	}
 }
 
+// getVisiblityTimeout returns a visibility timeout drawn at random from the
+// closed interval [retryTimeout-retryJitter, retryTimeout+retryJitter]. A
+// retryJitter of 0 always yields retryTimeout.
+//
+// It panics if retryJitter is negative or greater than retryTimeout;
+// WithRetryJitter rejects both cases when the option is applied.
+func getVisiblityTimeout(retryTimeout int64, retryJitter int64) int64 {
+	if retryJitter < 0 {
+		panic("jitter must not be negative")
+	}
+	if retryJitter > retryTimeout {
+		panic("jitter must be less than or equal to retryTimeout")
+	}
+	// Int63n(2*retryJitter+1) is uniform over [0, 2*retryJitter]; adding the
+	// lower bound keeps the arithmetic in int64 instead of truncating to int.
+	return retryTimeout - retryJitter + rand.Int63n(2*retryJitter+1)
+}
+
 const shutdownPollInterval = 500 * time.Millisecond
 
 // Shutdown stops the receipt of new messages and waits for routines
@@ -260,3 +293,24 @@ func WithRetries(delay time.Duration, max int) Option {
 		return nil
 	}
 }
+
+// WithRetryJitter sets the jitter applied to the VisibilityTimeout of a
+// message whose receiver failed. With jitter applied, every message that
+// needs to be retried gets a visibility timeout in the closed interval
+// [retryTimeout - retryJitter, retryTimeout + retryJitter].
+//
+// retryJitter must be between 0 and the server's retryTimeout (inclusive)
+// at the time the option is applied; otherwise an error is returned.
+func WithRetryJitter(retryJitter int64) Option {
+	return func(s *Server) error {
+		if retryJitter < 0 || retryJitter > s.retryTimeout {
+			return fmt.Errorf(
+				"invalid jitter: %d. Jitter must be non-negative and less or equal to the retryTimeout (%d)",
+				retryJitter,
+				s.retryTimeout,
+			)
+		}
+		s.retryJitter = retryJitter
+		return nil
+	}
+}
diff --git a/sqs/server_test.go b/sqs/server_test.go
index f89b3e1..18cf427 100644
--- a/sqs/server_test.go
+++ b/sqs/server_test.go
@@ -249,3 +249,69 @@ func TestServer_ShutdownHard(t *testing.T) {
 		t.Errorf("Expected context.DeadlineExceeded, got %v", err)
 	}
 }
+
+func TestWithRetryJitter_SetsValidJitter(t *testing.T) {
+	jitter := 10
+	msgs := newSQSMessages(0)
+	srv := newMockServer(1, newMockSQSAPI(msgs, t))
+	optionFunc := WithRetryJitter(int64(jitter))
+	err := optionFunc(srv)
+	if err != nil {
+		t.Errorf("Unexpected error %s", err)
+	}
+	if srv.retryJitter != int64(jitter) {
+		t.Errorf("Expected retryJitter to be %d", jitter)
+	}
+}
+
+func TestWithRetryJitter_ErrorOnInvalidJitter(t *testing.T) {
+	jitter := 1000
+	msgs := newSQSMessages(0)
+	srv := newMockServer(1, newMockSQSAPI(msgs, t))
+	optionFunc := WithRetryJitter(int64(jitter))
+	err := optionFunc(srv)
+	if err == nil {
+		// Fatal rather than Errorf: err.Error() below would dereference nil.
+		t.Fatal("Expected error, received nil")
+	}
+	if !strings.HasPrefix(err.Error(), "invalid jitter:") {
+		t.Errorf("expected error to start with 'invalid jitter:', error is '%s'", err)
+	}
+}
+
+func TestWithRetryJitter_ErrorOnNegativeJitter(t *testing.T) {
+	msgs := newSQSMessages(0)
+	srv := newMockServer(1, newMockSQSAPI(msgs, t))
+	if err := WithRetryJitter(-1)(srv); err == nil {
+		t.Error("Expected error for negative jitter, received nil")
+	}
+}
+
+func TestGetVisiblityTimeout_NoJitter(t *testing.T) {
+	var retryTimeout int64 = 100
+	var jitter int64 = 0
+	val := getVisiblityTimeout(retryTimeout, jitter)
+	if val < (retryTimeout-jitter) || val > (retryTimeout+jitter) {
+		t.Errorf("val should be in the interval %d±%d", retryTimeout, jitter)
+	}
+}
+
+func TestGetVisiblityTimeout_ValidJitter(t *testing.T) {
+	var retryTimeout int64 = 100
+	var jitter int64 = 10
+	val := getVisiblityTimeout(retryTimeout, jitter)
+	if val < (retryTimeout-jitter) || val > (retryTimeout+jitter) {
+		t.Errorf("val should be in the interval %d±%d", retryTimeout, jitter)
+	}
+}
+
+func TestGetVisiblityTimeout_InvalidJitter(t *testing.T) {
+	var retryTimeout int64 = 100
+	var jitter int64 = 1000
+	defer func() {
+		if r := recover(); r == nil {
+			t.Errorf("Expected panic")
+		}
+	}()
+	getVisiblityTimeout(retryTimeout, jitter)
+}