diff --git a/retryer/retryer.go b/retryer/retryer.go new file mode 100644 index 0000000..96b81f5 --- /dev/null +++ b/retryer/retryer.go @@ -0,0 +1,33 @@ +package retryer + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws/request" +) + +// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay +// for credential errors (403 statuscode). +// This is needed in order to wait for credentials to be valid for SQS requests +// due to AWS "eventually consistent" credentials: +// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html +type DefaultRetryer struct { + request.Retryer + Delay time.Duration +} + +// RetryRules returns the delay for the next request to be made +func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration { + if req.HTTPResponse.StatusCode == 403 { + return r.Delay + } + return r.Retryer.RetryRules(req) +} + +// ShouldRetry determines if the passed request should be retried +func (r DefaultRetryer) ShouldRetry(req *request.Request) bool { + if req.HTTPResponse.StatusCode == 403 { + return true + } + return r.Retryer.ShouldRetry(req) +} diff --git a/sns/topic.go b/sns/topic.go index dda7e1f..380a203 100644 --- a/sns/topic.go +++ b/sns/topic.go @@ -3,16 +3,23 @@ package sns import ( "bytes" "context" + "errors" + "fmt" "log" "os" "strings" "sync" + "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sns/snsiface" + + "github.com/zerofox-oss/go-aws-msg/retryer" msg "github.com/zerofox-oss/go-msg" b64 "github.com/zerofox-oss/go-msg/decorators/base64" ) @@ -21,6 +28,50 @@ import ( type Topic struct { Svc snsiface.SNSAPI TopicARN string + session *session.Session +} + +func getConf(t *Topic) (*aws.Config, error) { + svc, ok := t.Svc.(*sns.SNS) + if !ok { + return nil, errors.New("Svc could not be casted to a SNS client") + } + return &svc.Client.Config, nil +} + +// Option is the signature that modifies a `Topic` to set some configuration +type Option func(*Topic) error + +// WithCustomRetryer sets a custom `Retryer` to use on the SQS client. +func WithCustomRetryer(r request.Retryer) Option { + return func(t *Topic) error { + c, err := getConf(t) + if err != nil { + return err + } + c.Retryer = r + t.Svc = sns.New(t.session, c) + return nil + } +} + +// WithRetries makes the `Server` retry on credential errors until +// `max` attempts with `delay` seconds between requests. +// This is needed in scenarios where credentials are automatically generated +// and the program starts before AWS finishes propagating them +func WithRetries(delay time.Duration, max int) Option { + return func(t *Topic) error { + c, err := getConf(t) + if err != nil { + return err + } + c.Retryer = retryer.DefaultRetryer{ + Retryer: client.DefaultRetryer{NumMaxRetries: max}, + Delay: delay, + } + t.Svc = sns.New(t.session, c) + return nil + } } // NewTopic returns a sns.Topic with fully configured SNSAPI. @@ -31,7 +82,7 @@ type Topic struct { // messages are base64-encoded as a best practice. // // You may use NewUnencodedTopic if you wish to ignore the encoding step. -func NewTopic(topicARN string) (msg.Topic, error) { +func NewTopic(topicARN string, opts ...Option) (msg.Topic, error) { sess, err := session.NewSession() if err != nil { return nil, err @@ -51,17 +102,31 @@ func NewTopic(topicARN string) (msg.Topic, error) { conf.Endpoint = aws.String(url) } - return b64.Encoder(&Topic{ + t := &Topic{ Svc: sns.New(sess, conf), TopicARN: topicARN, - }), nil + session: sess, + } + + // Default retryer + if err = WithRetries(2*time.Second, 7)(t); err != nil { + return nil, err + } + + for _, opt := range opts { + if err = opt(t); err != nil { + return nil, fmt.Errorf("cannot set option: %s", err) + } + } + + return b64.Encoder(t), nil } // NewUnencodedTopic creates an concrete SNS msg.Topic // // Messages published by the `Topic` returned will not // have the body base64-encoded. -func NewUnencodedTopic(topicARN string) (msg.Topic, error) { +func NewUnencodedTopic(topicARN string, opts ...Option) (msg.Topic, error) { sess, err := session.NewSession() if err != nil { return nil, err @@ -77,10 +142,24 @@ func NewUnencodedTopic(topicARN string) (msg.Topic, error) { conf.Endpoint = aws.String(url) } - return &Topic{ + t := &Topic{ Svc: sns.New(sess, conf), TopicARN: topicARN, - }, nil + session: sess, + } + + // Default retryer + if err = WithRetries(2*time.Second, 7)(t); err != nil { + return nil, err + } + + for _, opt := range opts { + if err = opt(t); err != nil { + return nil, fmt.Errorf("cannot set option: %s", err) + } + } + + return t, nil } // NewWriter returns a sns.MessageWriter instance for writing to @@ -124,7 +203,7 @@ func (w *MessageWriter) Close() error { w.mux.Lock() defer w.mux.Unlock() - if w.closed == true { + if w.closed { return msg.ErrClosedMessageWriter } w.closed = true @@ -137,10 +216,8 @@ func (w *MessageWriter) Close() error { } log.Printf("[TRACE] writing to sns: %v", snsPublishParams) - if _, err := w.snsClient.PublishWithContext(w.ctx, snsPublishParams); err != nil { - return err - } - return nil + _, err := w.snsClient.PublishWithContext(w.ctx, snsPublishParams) + return err } // Write writes data to the MessageWriter's internal buffer for aggregation @@ -152,7 +229,7 @@ func (w *MessageWriter) Write(p []byte) (int, error) { w.mux.Lock() defer w.mux.Unlock() - if w.closed == true { + if w.closed { return 0, msg.ErrClosedMessageWriter } return w.buf.Write(p) diff --git a/sns/topic_test.go b/sns/topic_test.go index 9b0a93c..529af75 100644 --- a/sns/topic_test.go +++ b/sns/topic_test.go @@ -2,7 +2,14 @@ package sns import ( "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sns" @@ -131,3 +138,75 @@ func TestMessageWriter_CloseProperlyConstructsPublishInput(t *testing.T) { <-control } + +type constructor func(string, ...Option) (msg.Topic, error) + +func TestMessageWriter_Close_retryer(t *testing.T) { + retries := make([]*http.Request, 0, 3) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := ioutil.ReadAll(r.Body) + t.Logf("Request: %s\n", b) + retries = append(retries, r) + w.WriteHeader(403) + fmt.Fprintln(w, ` + + + Sender + InvalidClientTokenId + The security token included in the request is invalid. + + 590d5457-e4b6-5464-a482-071900d4c7d6 +`) + })) + defer ts.Close() + + os.Setenv("SNS_ENDPOINT", ts.URL) + os.Setenv("AWS_ACCESS_KEY_ID", "fake") + os.Setenv("AWS_SECRET_ACCESS_KEY", "fake") + + defer func() { + os.Unsetenv("SNS_ENDPOINT") + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + }() + + cases := []struct { + name string + newTopic constructor + options []Option + numTries int + }{ + {"default", NewTopic, nil, 8}, + {"1 retry", NewTopic, []Option{WithRetries(0, 1)}, 2}, + {"No retries", NewTopic, []Option{WithRetries(0, 0)}, 1}, + {"UnencodedTopic default", NewUnencodedTopic, nil, 8}, + {"UnencodedTopic 1 retry", NewUnencodedTopic, []Option{WithRetries(0, 1)}, 2}, + {"UnencodedTopic No retries", NewUnencodedTopic, []Option{WithRetries(0, 0)}, 1}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + retries = make([]*http.Request, 0, 3) + tpc, err := c.newTopic("arn:aws:sns:us-west-2:777777777777:test-sns", c.options...) + if err != nil { + t.Errorf("Server creation should not fail: %s", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + w := tpc.NewWriter(ctx) + + w.Write([]byte("it's full of stars!")) + err = w.Close() + + if strings.Index(err.Error(), "InvalidClientTokenId: The security token included in the request is invalid") != 0 { + t.Errorf("Expected error message to start with `InvalidClientTokenId: The security token included in the request is invalid`, was `%s`", err.Error()) + } + + t.Logf("retries: %v", retries) + if len(retries) != c.numTries { + t.Errorf("It should try %d times before failing, was %d", c.numTries, len(retries)) + } + }) + } +} diff --git a/sqs/server.go b/sqs/server.go index e2d8f7d..410eb41 100644 --- a/sqs/server.go +++ b/sqs/server.go @@ -16,6 +16,8 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + + "github.com/zerofox-oss/go-aws-msg/retryer" msg "github.com/zerofox-oss/go-msg" ) @@ -148,32 +150,6 @@ func (s *Server) Shutdown(ctx context.Context) error { } } -// DefaultRetryer implements an AWS `request.Retryer` that has a custom delay -// for credential errors (403 statuscode). -// This is needed in order to wait for credentials to be valid for SQS requests -// due to AWS "eventually consistent" credentials: -// https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_general.html -type DefaultRetryer struct { - request.Retryer - delay time.Duration -} - -// RetryRules returns the delay for the next request to be made -func (r DefaultRetryer) RetryRules(req *request.Request) time.Duration { - if req.HTTPResponse.StatusCode == 403 { - return r.delay - } - return r.Retryer.RetryRules(req) -} - -// ShouldRetry determines if the passed request should be retried -func (r DefaultRetryer) ShouldRetry(req *request.Request) bool { - if req.HTTPResponse.StatusCode == 403 { - return true - } - return r.Retryer.ShouldRetry(req) -} - // Option is the signature that modifies a `Server` to set some configuration type Option func(*Server) error @@ -199,6 +175,10 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg conf := &aws.Config{ Credentials: credentials.NewCredentials(&credentials.EnvProvider{}), Region: aws.String("us-west-2"), + Retryer: retryer.DefaultRetryer{ + Retryer: client.DefaultRetryer{NumMaxRetries: 7}, + Delay: 2 * time.Second, + }, } // http://docs.aws.amazon.com/sdk-for-go/api/aws/client/#Config @@ -209,10 +189,6 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg if url := os.Getenv("SQS_ENDPOINT"); url != "" { conf.Endpoint = aws.String(url) } - conf.Retryer = DefaultRetryer{ - Retryer: client.DefaultRetryer{NumMaxRetries: 7}, - delay: 2 * time.Second, - } // Create an SQS Client with creds from the Environment svc := sqs.New(sess, conf) @@ -234,7 +210,7 @@ func NewServer(queueURL string, cl int, retryTimeout int64, opts ...Option) (msg for _, opt := range opts { if err = opt(srv); err != nil { - return nil, fmt.Errorf("Failed setting option: %s", err) + return nil, fmt.Errorf("cannot set option: %s", err) } } @@ -272,9 +248,9 @@ func WithRetries(delay time.Duration, max int) Option { if err != nil { return err } - c.Retryer = DefaultRetryer{ + c.Retryer = retryer.DefaultRetryer{ Retryer: client.DefaultRetryer{NumMaxRetries: max}, - delay: delay, + Delay: delay, } s.Svc = sqs.New(s.session, c) return nil diff --git a/sqs/server_test.go b/sqs/server_test.go index 6f21efd..148339b 100644 --- a/sqs/server_test.go +++ b/sqs/server_test.go @@ -94,8 +94,8 @@ func TestServer_Serve_retries(t *testing.T) { defer ts.Close() os.Setenv("SQS_ENDPOINT", ts.URL) - os.Setenv("AWS_ACCESS_KEY_ID", "AKIyJLQDLOCKWMFHfake") - os.Setenv("AWS_SECRET_ACCESS_KEY", "T1PERSo63zFp1q5AGkGERmqOLQNZGfFu6iqAfake") + os.Setenv("AWS_ACCESS_KEY_ID", "foo") + os.Setenv("AWS_SECRET_ACCESS_KEY", "bar") defer func() { os.Unsetenv("SQS_ENDPOINT")