Skip to content

Commit

Permalink
Add default retrier to the SNS implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
elmarcoh committed Apr 16, 2018
1 parent d375be8 commit 7ae1dfa
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 12 deletions.
92 changes: 81 additions & 11 deletions sns/topic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sns
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"strings"
Expand All @@ -12,12 +14,12 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sns/snsiface"

"github.com/zerofox-oss/go-aws-msg/retryer"
"github.com/zerofox-oss/go-msg"
msg "github.com/zerofox-oss/go-msg"
b64 "github.com/zerofox-oss/go-msg/decorators/base64"
)
Expand All @@ -26,6 +28,50 @@ import (
type Topic struct {
Svc snsiface.SNSAPI
TopicARN string
session *session.Session
}

func getConf(t *Topic) (*aws.Config, error) {
svc, ok := t.Svc.(*sns.SNS)
if !ok {
return nil, errors.New("Svc could not be casted to a SNS client")
}
return &svc.Client.Config, nil
}

// Option is the signature that modifies a `Topic` to set some configuration
type Option func(*Topic) error

// WithCustomRetryer sets a custom `Retryer` to use on the SQS client.
func WithCustomRetryer(r request.Retryer) Option {
return func(t *Topic) error {
c, err := getConf(t)
if err != nil {
return err
}
c.Retryer = r
t.Svc = sns.New(t.session, c)
return nil
}
}

// WithRetries makes the `Server` retry on credential errors until
// `max` attempts with `delay` seconds between requests.
// This is needed in scenarios where credentials are automatically generated
// and the program starts before AWS finishes propagating them
func WithRetries(delay time.Duration, max int) Option {
return func(t *Topic) error {
c, err := getConf(t)
if err != nil {
return err
}
c.Retryer = retryer.DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: max},
Delay: delay,
}
t.Svc = sns.New(t.session, c)
return nil
}
}

// NewTopic returns a sns.Topic with fully configured SNSAPI.
Expand All @@ -36,7 +82,7 @@ type Topic struct {
// messages are base64-encoded as a best practice.
//
// You may use NewUnencodedTopic if you wish to ignore the encoding step.
func NewTopic(topicARN string) (msg.Topic, error) {
func NewTopic(topicARN string, opts ...Option) (msg.Topic, error) {
sess, err := session.NewSession()
if err != nil {
return nil, err
Expand All @@ -45,10 +91,6 @@ func NewTopic(topicARN string) (msg.Topic, error) {
conf := &aws.Config{
Credentials: credentials.NewCredentials(&credentials.EnvProvider{}),
Region: aws.String("us-west-2"),
Retryer: retryer.DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: 7},
Delay: 2 * time.Second,
},
}

// You may override AWS_REGION, SNS_ENDPOINT
Expand All @@ -60,17 +102,31 @@ func NewTopic(topicARN string) (msg.Topic, error) {
conf.Endpoint = aws.String(url)
}

return b64.Encoder(&Topic{
t := &Topic{
Svc: sns.New(sess, conf),
TopicARN: topicARN,
}), nil
session: sess,
}

// Default retryer
if err = WithRetries(2*time.Second, 7)(t); err != nil {
return nil, err
}

for _, opt := range opts {
if err = opt(t); err != nil {
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

return b64.Encoder(t), nil
}

// NewUnencodedTopic creates an concrete SNS msg.Topic
//
// Messages published by the `Topic` returned will not
// have the body base64-encoded.
func NewUnencodedTopic(topicARN string) (msg.Topic, error) {
func NewUnencodedTopic(topicARN string, opts ...Option) (msg.Topic, error) {
sess, err := session.NewSession()
if err != nil {
return nil, err
Expand All @@ -86,10 +142,24 @@ func NewUnencodedTopic(topicARN string) (msg.Topic, error) {
conf.Endpoint = aws.String(url)
}

return &Topic{
t := &Topic{
Svc: sns.New(sess, conf),
TopicARN: topicARN,
}, nil
session: sess,
}

// Default retryer
if err = WithRetries(2*time.Second, 7)(t); err != nil {
return nil, err
}

for _, opt := range opts {
if err = opt(t); err != nil {
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

return t, nil
}

// NewWriter returns a sns.MessageWriter instance for writing to
Expand Down
79 changes: 79 additions & 0 deletions sns/topic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ package sns

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sns"
Expand Down Expand Up @@ -131,3 +138,75 @@ func TestMessageWriter_CloseProperlyConstructsPublishInput(t *testing.T) {

<-control
}

type constructor func(string, ...Option) (msg.Topic, error)

func TestMessageWriter_Close_retryer(t *testing.T) {
retries := make([]*http.Request, 0, 3)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadAll(r.Body)
t.Logf("Request: %s\n", b)
retries = append(retries, r)
w.WriteHeader(403)
fmt.Fprintln(w, `
<ErrorResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<Error>
<Type>Sender</Type>
<Code>InvalidClientTokenId</Code>
<Message>The security token included in the request is invalid.</Message>
</Error>
<RequestId>590d5457-e4b6-5464-a482-071900d4c7d6</RequestId>
</ErrorResponse>`)
}))
defer ts.Close()

os.Setenv("SNS_ENDPOINT", ts.URL)
os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake")

defer func() {
os.Unsetenv("SNS_ENDPOINT")
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}()

cases := []struct {
name string
newTopic constructor
options []Option
numTries int
}{
{"default", NewTopic, nil, 8},
{"1 retry", NewTopic, []Option{WithRetries(0, 1)}, 2},
{"No retries", NewTopic, []Option{WithRetries(0, 0)}, 1},
{"UnencodedTopic default", NewUnencodedTopic, nil, 8},
{"UnencodedTopic 1 retry", NewUnencodedTopic, []Option{WithRetries(0, 1)}, 2},
{"UnencodedTopic No retries", NewUnencodedTopic, []Option{WithRetries(0, 0)}, 1},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
retries = make([]*http.Request, 0, 3)
tpc, err := c.newTopic("arn:aws:sns:us-west-2:777777777777:test-sns", c.options...)
if err != nil {
t.Errorf("Server creation should not fail: %s", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
w := tpc.NewWriter(ctx)

w.Write([]byte("it's full of stars!"))
err = w.Close()

if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 {
t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error())
}

t.Logf("retries: %v", retries)
if len(retries) != c.numTries {
t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries))
}
})
}
}
2 changes: 1 addition & 1 deletion sqs/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg

for _, opt := range opts {
if err = opt(srv); err != nil {
return nil, fmt.Errorf("Failed setting option: %s", err)
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

Expand Down

0 comments on commit 7ae1dfa

Please sign in to comment.