diff --git a/sns/topic.go b/sns/topic.go index ed50c30..380a203 100644 --- a/sns/topic.go +++ b/sns/topic.go @@ -3,6 +3,8 @@ package sns import ( "bytes" "context" + "errors" + "fmt" "log" "os" "strings" @@ -12,12 +14,12 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sns/snsiface" "github.com/zerofox-oss/go-aws-msg/retryer" - "github.com/zerofox-oss/go-msg" msg "github.com/zerofox-oss/go-msg" b64 "github.com/zerofox-oss/go-msg/decorators/base64" ) @@ -26,6 +28,50 @@ import ( type Topic struct { Svc snsiface.SNSAPI TopicARN string + session *session.Session +} + +func getConf(t *Topic) (*aws.Config, error) { + svc, ok := t.Svc.(*sns.SNS) + if !ok { + return nil, errors.New("Svc could not be casted to a SNS client") + } + return &svc.Client.Config, nil +} + +// Option is the signature that modifies a `Topic` to set some configuration +type Option func(*Topic) error + +// WithCustomRetryer sets a custom `Retryer` to use on the SQS client. +func WithCustomRetryer(r request.Retryer) Option { + return func(t *Topic) error { + c, err := getConf(t) + if err != nil { + return err + } + c.Retryer = r + t.Svc = sns.New(t.session, c) + return nil + } +} + +// WithRetries makes the `Server` retry on credential errors until +// `max` attempts with `delay` seconds between requests. +// This is needed in scenarios where credentials are automatically generated +// and the program starts before AWS finishes propagating them +func WithRetries(delay time.Duration, max int) Option { + return func(t *Topic) error { + c, err := getConf(t) + if err != nil { + return err + } + c.Retryer = retryer.DefaultRetryer{ + Retryer: client.DefaultRetryer{NumMaxRetries: max}, + Delay: delay, + } + t.Svc = sns.New(t.session, c) + return nil + } } // NewTopic returns a sns.Topic with fully configured SNSAPI. @@ -36,7 +82,7 @@ type Topic struct { // messages are base64-encoded as a best practice. // // You may use NewUnencodedTopic if you wish to ignore the encoding step. -func NewTopic(topicARN string) (msg.Topic, error) { +func NewTopic(topicARN string, opts ...Option) (msg.Topic, error) { sess, err := session.NewSession() if err != nil { return nil, err @@ -45,10 +91,6 @@ func NewTopic(topicARN string) (msg.Topic, error) { conf := &aws.Config{ Credentials: credentials.NewCredentials(&credentials.EnvProvider{}), Region: aws.String("us-west-2"), - Retryer: retryer.DefaultRetryer{ - Retryer: client.DefaultRetryer{NumMaxRetries: 7}, - Delay: 2 * time.Second, - }, } // You may override AWS_REGION, SNS_ENDPOINT @@ -60,17 +102,31 @@ func NewTopic(topicARN string) (msg.Topic, error) { conf.Endpoint = aws.String(url) } - return b64.Encoder(&Topic{ + t := &Topic{ Svc: sns.New(sess, conf), TopicARN: topicARN, - }), nil + session: sess, + } + + // Default retryer + if err = WithRetries(2*time.Second, 7)(t); err != nil { + return nil, err + } + + for _, opt := range opts { + if err = opt(t); err != nil { + return nil, fmt.Errorf("cannot set option: %s", err) + } + } + + return b64.Encoder(t), nil } // NewUnencodedTopic creates an concrete SNS msg.Topic // // Messages published by the `Topic` returned will not // have the body base64-encoded. -func NewUnencodedTopic(topicARN string) (msg.Topic, error) { +func NewUnencodedTopic(topicARN string, opts ...Option) (msg.Topic, error) { sess, err := session.NewSession() if err != nil { return nil, err @@ -86,10 +142,24 @@ func NewUnencodedTopic(topicARN string) (msg.Topic, error) { conf.Endpoint = aws.String(url) } - return &Topic{ + t := &Topic{ Svc: sns.New(sess, conf), TopicARN: topicARN, - }, nil + session: sess, + } + + // Default retryer + if err = WithRetries(2*time.Second, 7)(t); err != nil { + return nil, err + } + + for _, opt := range opts { + if err = opt(t); err != nil { + return nil, fmt.Errorf("cannot set option: %s", err) + } + } + + return t, nil } // NewWriter returns a sns.MessageWriter instance for writing to diff --git a/sns/topic_test.go b/sns/topic_test.go index 9b0a93c..ae7a233 100644 --- a/sns/topic_test.go +++ b/sns/topic_test.go @@ -2,7 +2,14 @@ package sns import ( "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sns" @@ -131,3 +138,75 @@ func TestMessageWriter_CloseProperlyConstructsPublishInput(t *testing.T) { <-control } + +type constructor func(string, ...Option) (msg.Topic, error) + +func TestMessageWriter_Close_retryer(t *testing.T) { + retries := make([]*http.Request, 0, 3) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := ioutil.ReadAll(r.Body) + t.Logf("Request: %s\n", b) + retries = append(retries, r) + w.WriteHeader(403) + fmt.Fprintln(w, ` + + + Sender + InvalidClientTokenId + The security token included in the request is invalid. + + 590d5457-e4b6-5464-a482-071900d4c7d6 +`) + })) + defer ts.Close() + + os.Setenv("SNS_ENDPOINT", ts.URL) + os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake") + os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake") + + defer func() { + os.Unsetenv("SNS_ENDPOINT") + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + }() + + cases := []struct { + name string + newTopic constructor + options []Option + numTries int + }{ + {"default", NewTopic, nil, 8}, + {"1 retry", NewTopic, []Option{WithRetries(0, 1)}, 2}, + {"No retries", NewTopic, []Option{WithRetries(0, 0)}, 1}, + {"UnencodedTopic default", NewUnencodedTopic, nil, 8}, + {"UnencodedTopic 1 retry", NewUnencodedTopic, []Option{WithRetries(0, 1)}, 2}, + {"UnencodedTopic No retries", NewUnencodedTopic, []Option{WithRetries(0, 0)}, 1}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + retries = make([]*http.Request, 0, 3) + tpc, err := c.newTopic("arn:aws:sns:us-west-2:777777777777:test-sns", c.options...) + if err != nil { + t.Errorf("Server creation should not fail: %s", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + w := tpc.NewWriter(ctx) + + w.Write([]byte("it's full of stars!")) + err = w.Close() + + if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 { + t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error()) + } + + t.Logf("retries: %v", retries) + if len(retries) != c.numTries { + t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries)) + } + }) + } +} diff --git a/sqs/server.go b/sqs/server.go index ef646b3..410eb41 100644 --- a/sqs/server.go +++ b/sqs/server.go @@ -210,7 +210,7 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg for _, opt := range opts { if err = opt(srv); err != nil { - return nil, fmt.Errorf("Failed setting option: %s", err) + return nil, fmt.Errorf("cannot set option: %s", err) } }