Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DefaultRetryer to SNS topics #14

Merged
merged 2 commits into from
Apr 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions retryer/retryer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package retryer

import (
"time"

"github.com/aws/aws-sdk-go/aws/request"
)

// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay
// for credential errors (403 statuscode).
// This is needed in order to wait for credentials to be valid for SQS requests
// due to AWS "eventually consistent" credentials:
// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html
type DefaultRetryer struct {
request.Retryer
Delay time.Duration
}

// RetryRules returns the delay for the next request to be made
func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration {
if req.HTTPResponse.StatusCode == 403 {
return r.Delay
}
return r.Retryer.RetryRules(req)
}

// ShouldRetry determines if the passed request should be retried
func (r DefaultRetryer) ShouldRetry(req *request.Request) bool {
if req.HTTPResponse.StatusCode == 403 {
return true
}
return r.Retryer.ShouldRetry(req)
}
101 changes: 89 additions & 12 deletions sns/topic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@ package sns
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sns/snsiface"

"github.com/zerofox-oss/go-aws-msg/retryer"
msg "github.com/zerofox-oss/go-msg"
b64 "github.com/zerofox-oss/go-msg/decorators/base64"
)
Expand All @@ -21,6 +28,50 @@ import (
type Topic struct {
Svc snsiface.SNSAPI
TopicARN string
session *session.Session
}

func getConf(t *Topic) (*aws.Config, error) {
svc, ok := t.Svc.(*sns.SNS)
if !ok {
return nil, errors.New("Svc could not be casted to a SNS client")
}
return &svc.Client.Config, nil
}

// Option is the signature that modifies a `Topic` to set some configuration
type Option func(*Topic) error

// WithCustomRetryer sets a custom `Retryer` to use on the SQS client.
func WithCustomRetryer(r request.Retryer) Option {
return func(t *Topic) error {
c, err := getConf(t)
if err != nil {
return err
}
c.Retryer = r
t.Svc = sns.New(t.session, c)
return nil
}
}

// WithRetries makes the `Server` retry on credential errors until
// `max` attempts with `delay` seconds between requests.
// This is needed in scenarios where credentials are automatically generated
// and the program starts before AWS finishes propagating them
func WithRetries(delay time.Duration, max int) Option {
return func(t *Topic) error {
c, err := getConf(t)
if err != nil {
return err
}
c.Retryer = retryer.DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: max},
Delay: delay,
}
t.Svc = sns.New(t.session, c)
return nil
}
}

// NewTopic returns a sns.Topic with fully configured SNSAPI.
Expand All @@ -31,7 +82,7 @@ type Topic struct {
// messages are base64-encoded as a best practice.
//
// You may use NewUnencodedTopic if you wish to ignore the encoding step.
func NewTopic(topicARN string) (msg.Topic, error) {
func NewTopic(topicARN string, opts ...Option) (msg.Topic, error) {
sess, err := session.NewSession()
if err != nil {
return nil, err
Expand All @@ -51,17 +102,31 @@ func NewTopic(topicARN string) (msg.Topic, error) {
conf.Endpoint = aws.String(url)
}

return b64.Encoder(&Topic{
t := &Topic{
Svc: sns.New(sess, conf),
TopicARN: topicARN,
}), nil
session: sess,
}

// Default retryer
if err = WithRetries(2*time.Second, 7)(t); err != nil {
return nil, err
}

for _, opt := range opts {
if err = opt(t); err != nil {
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

return b64.Encoder(t), nil
}

// NewUnencodedTopic creates an concrete SNS msg.Topic
//
// Messages published by the `Topic` returned will not
// have the body base64-encoded.
func NewUnencodedTopic(topicARN string) (msg.Topic, error) {
func NewUnencodedTopic(topicARN string, opts ...Option) (msg.Topic, error) {
sess, err := session.NewSession()
if err != nil {
return nil, err
Expand All @@ -77,10 +142,24 @@ func NewUnencodedTopic(topicARN string) (msg.Topic, error) {
conf.Endpoint = aws.String(url)
}

return &Topic{
t := &Topic{
Svc: sns.New(sess, conf),
TopicARN: topicARN,
}, nil
session: sess,
}

// Default retryer
if err = WithRetries(2*time.Second, 7)(t); err != nil {
return nil, err
}

for _, opt := range opts {
if err = opt(t); err != nil {
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

return t, nil
}

// NewWriter returns a sns.MessageWriter instance for writing to
Expand Down Expand Up @@ -124,7 +203,7 @@ func (w *MessageWriter) Close() error {
w.mux.Lock()
defer w.mux.Unlock()

if w.closed == true {
if w.closed {
return msg.ErrClosedMessageWriter
}
w.closed = true
Expand All @@ -137,10 +216,8 @@ func (w *MessageWriter) Close() error {
}

log.Printf("[TRACE] writing to sns: %v", snsPublishParams)
if _, err := w.snsClient.PublishWithContext(w.ctx, snsPublishParams); err != nil {
return err
}
return nil
_, err := w.snsClient.PublishWithContext(w.ctx, snsPublishParams)
return err
}

// Write writes data to the MessageWriter's internal buffer for aggregation
Expand All @@ -152,7 +229,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
w.mux.Lock()
defer w.mux.Unlock()

if w.closed == true {
if w.closed {
return 0, msg.ErrClosedMessageWriter
}
return w.buf.Write(p)
Expand Down
79 changes: 79 additions & 0 deletions sns/topic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ package sns

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sns"
Expand Down Expand Up @@ -131,3 +138,75 @@ func TestMessageWriter_CloseProperlyConstructsPublishInput(t *testing.T) {

<-control
}

type constructor func(string, ...Option) (msg.Topic, error)

func TestMessageWriter_Close_retryer(t *testing.T) {
retries := make([]*http.Request, 0, 3)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadAll(r.Body)
t.Logf("Request: %s\n", b)
retries = append(retries, r)
w.WriteHeader(403)
fmt.Fprintln(w, `
<ErrorResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<Error>
<Type>Sender</Type>
<Code>InvalidClientTokenId</Code>
<Message>The security token included in the request is invalid.</Message>
</Error>
<RequestId>590d5457-e4b6-5464-a482-071900d4c7d6</RequestId>
</ErrorResponse>`)
}))
defer ts.Close()

os.Setenv("SNS_ENDPOINT", ts.URL)
os.Setenv("AWS_ACCESS_KEY_ID", "fake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "fake")

defer func() {
os.Unsetenv("SNS_ENDPOINT")
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}()

cases := []struct {
name string
newTopic constructor
options []Option
numTries int
}{
{"default", NewTopic, nil, 8},
{"1 retry", NewTopic, []Option{WithRetries(0, 1)}, 2},
{"No retries", NewTopic, []Option{WithRetries(0, 0)}, 1},
{"UnencodedTopic default", NewUnencodedTopic, nil, 8},
{"UnencodedTopic 1 retry", NewUnencodedTopic, []Option{WithRetries(0, 1)}, 2},
{"UnencodedTopic No retries", NewUnencodedTopic, []Option{WithRetries(0, 0)}, 1},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
retries = make([]*http.Request, 0, 3)
tpc, err := c.newTopic("arn:aws:sns:us-west-2:777777777777:test-sns", c.options...)
if err != nil {
t.Errorf("Server creation should not fail: %s", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
w := tpc.NewWriter(ctx)

w.Write([]byte("it's full of stars!"))
err = w.Close()

if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 {
t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error())
}

t.Logf("retries: %v", retries)
if len(retries) != c.numTries {
t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries))
}
})
}
}
42 changes: 9 additions & 33 deletions sqs/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"

"github.com/zerofox-oss/go-aws-msg/retryer"
msg "github.com/zerofox-oss/go-msg"
)

Expand Down Expand Up @@ -148,32 +150,6 @@ func (s *Server) Shutdown(ctx context.Context) error {
}
}

// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay
// for credential errors (403 statuscode).
// This is needed in order to wait for credentials to be valid for SQS requests
// due to AWS "eventually consistent" credentials:
// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html
type DefaultRetryer struct {
request.Retryer
delay time.Duration
}

// RetryRules returns the delay for the next request to be made
func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration {
if req.HTTPResponse.StatusCode == 403 {
return r.delay
}
return r.Retryer.RetryRules(req)
}

// ShouldRetry determines if the passed request should be retried
func (r DefaultRetryer) ShouldRetry(req *request.Request) bool {
if req.HTTPResponse.StatusCode == 403 {
return true
}
return r.Retryer.ShouldRetry(req)
}

// Option is the signature that modifies a `Server` to set some configuration
type Option func(*Server) error

Expand All @@ -199,6 +175,10 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg
conf := &aws.Config{
Credentials: credentials.NewCredentials(&credentials.EnvProvider{}),
Region: aws.String("us-west-2"),
Retryer: retryer.DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: 7},
Delay: 2 * time.Second,
},
}

// http://docs.aws.amazon.com/sdk-for-go/api/aws/client/#Config
Expand All @@ -209,10 +189,6 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg
if url := os.Getenv("SQS_ENDPOINT"); url != "" {
conf.Endpoint = aws.String(url)
}
conf.Retryer = DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: 7},
delay: 2 * time.Second,
}

// Create an SQS Client with creds from the Environment
svc := sqs.New(sess, conf)
Expand All @@ -234,7 +210,7 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg

for _, opt := range opts {
if err = opt(srv); err != nil {
return nil, fmt.Errorf("Failed setting option: %s", err)
return nil, fmt.Errorf("cannot set option: %s", err)
}
}

Expand Down Expand Up @@ -272,9 +248,9 @@ func WithRetries(delay time.Duration, max int) Option {
if err != nil {
return err
}
c.Retryer = DefaultRetryer{
c.Retryer = retryer.DefaultRetryer{
Retryer: client.DefaultRetryer{NumMaxRetries: max},
delay: delay,
Delay: delay,
}
s.Svc = sqs.New(s.session, c)
return nil
Expand Down
4 changes: 2 additions & 2 deletions sqs/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func TestServer_Serve_retries(t *testing.T) {
defer ts.Close()

os.Setenv("SQS_ENDPOINT", ts.URL)
os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake")
os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake")
os.Setenv("AWS_ACCESS_KEY_ID", "foo")
os.Setenv("AWS_SECRET_ACCESS_KEY", "bar")

defer func() {
os.Unsetenv("SQS_ENDPOINT")
Expand Down