Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retryJitter option to add jitter to retryTimeout #22

Merged
merged 1 commit into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqs/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func newMockServer(concurrency int, mockSQS *mockSQSAPI) *Server {
serverCancelFunc: serverCancelFunc,
QueueURL: "https://myqueue.com",
Svc: mockSQS,
retryTimeout: 100,
}

return srv
Expand Down
41 changes: 40 additions & 1 deletion sqs/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package sqs
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
"log"
"math/rand"
"os"
"time"

Expand All @@ -21,6 +24,15 @@ import (
msg "github.com/zerofox-oss/go-msg"
)

func init() {
var b [8]byte
_, err := crand.Read(b[:])
if err != nil {
panic("cannot seed math/rand package with cryptographically secure random number generator")
}
rand.Seed(int64(binary.LittleEndian.Uint64(b[:])))
}

// Server represents a msg.Server for receiving messages
// from an AWS SQS Queue
type Server struct {
Expand All @@ -31,6 +43,7 @@ type Server struct {

maxConcurrentReceives chan struct{} // The maximum number of message processing routines allowed
retryTimeout int64 // Visbility Timeout for a message when a receiver fails
retryJitter int64

receiverCtx context.Context // context used to control the life of receivers
receiverCancelFunc context.CancelFunc // CancelFunc for all receiver routines
Expand Down Expand Up @@ -102,7 +115,7 @@ func (s *Server) Serve(r msg.Receiver) error {
params := &sqs.ChangeMessageVisibilityInput{
QueueUrl: aws.String(s.QueueURL),
ReceiptHandle: sqsMsg.ReceiptHandle,
VisibilityTimeout: aws.Int64(s.retryTimeout),
VisibilityTimeout: aws.Int64(getVisiblityTimeout(s.retryTimeout, s.retryJitter)),
}
if _, err := s.Svc.ChangeMessageVisibility(params); err != nil {
log.Printf("[ERROR] cannot change message visibility %s", err)
Expand All @@ -125,6 +138,14 @@ func (s *Server) Serve(r msg.Receiver) error {
}
}

func getVisiblityTimeout(retryTimeout int64, retryJitter int64) int64 {
if retryJitter > retryTimeout {
panic("jitter must be less than or equal to retryTimeout")
}
minRetry, maxRetry := retryTimeout-retryJitter, retryTimeout+retryJitter
return int64(rand.Intn(int(maxRetry-minRetry)+1) + int(minRetry))
}

const shutdownPollInterval = 500 * time.Millisecond

// Shutdown stops the receipt of new messages and waits for routines
Expand Down Expand Up @@ -260,3 +281,21 @@ func WithRetries(delay time.Duration, max int) Option {
return nil
}
}

// WithRetryJitter sets a value for Jitter on the VisibilityTimeout.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// With jitter applied every message that needs to be retried will
// have a visibility timeout in the interval:
// [(visibilityTimeout - jitter), visibilityTimeout + jitter)]
func WithRetryJitter(retryJitter int64) Option {
return func(s *Server) error {
if retryJitter > s.retryTimeout {
return fmt.Errorf(
"invalid jitter: %d. Jitter must be less or equal to the retryTimeout (%d)",
retryJitter,
s.retryTimeout,
)
}
s.retryJitter = retryJitter
return nil
}
}
60 changes: 60 additions & 0 deletions sqs/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,63 @@ func TestServer_ShutdownHard(t *testing.T) {
t.Errorf("Expected context.DeadlineExceeded, got %v", err)
}
}

func TestWithRetryJitter_SetsValidJitter(t *testing.T) {
jitter := 10
msgs := newSQSMessages(0)
srv := newMockServer(1, newMockSQSAPI(msgs, t))
optionFunc := WithRetryJitter(int64(jitter))
err := optionFunc(srv)
if err != nil {
t.Errorf("Unexpected error %s", err)
}
if srv.retryJitter != int64(jitter) {
t.Errorf("Expected retryJitter to be %d", jitter)
}
}

func TestWithRetryJitter_ErrorOnInvalidJitter(t *testing.T) {
jitter := 1000
msgs := newSQSMessages(0)
srv := newMockServer(1, newMockSQSAPI(msgs, t))
optionFunc := WithRetryJitter(int64(jitter))
err := optionFunc(srv)
if err == nil {
t.Errorf("Expected error, received nil")
}
if !strings.HasPrefix(err.Error(), "invalid jitter:") {
t.Errorf("expected error to start with 'invalid jitter:', error is '%s'", err)
}
}

func TestGetVisiblityTimeout_NoJitter(t *testing.T) {
var retryTimeout int64 = 100
var jitter int64 = 0
val := getVisiblityTimeout(retryTimeout, jitter)
if val < (retryTimeout-jitter) || val > (retryTimeout+jitter) {
t.Errorf("val should be in the interval %d±%d", retryTimeout, jitter)
}
}

func TestGetVisiblityTimeout_ValidJitter(t *testing.T) {
var retryTimeout int64 = 100
var jitter int64 = 10
val := getVisiblityTimeout(retryTimeout, jitter)
if val < (retryTimeout-jitter) || val > (retryTimeout+jitter) {
t.Errorf("val should be in the interval %d±%d", retryTimeout, jitter)
}
}

func TestGetVisiblityTimeout_InvalidJitter(t *testing.T) {
var retryTimeout int64 = 100
var jitter int64 = 1000
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected panic")
}
}()
val := getVisiblityTimeout(retryTimeout, jitter)
if val < (retryTimeout-jitter) || val > (retryTimeout+jitter) {
t.Errorf("val should be in the interval %d±%d", retryTimeout, jitter)
}
}