Skip to content

Commit 85a3a62

Browse files
author
Tyler Reid
committed
Break notify into submethods to create the session then create the publish input to send. Check we populate a region for all requests.
This reverts commit 4c2a5f1. Signed-off-by: Tyler Reid <[email protected]>
1 parent 51b9368 commit 85a3a62

File tree

1 file changed

+65
-41
lines changed

1 file changed

+65
-41
lines changed

notify/sns/sns.go

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,45 @@ func New(c *config.SNSConfig, t *template.Template, l log.Logger, httpOpts ...co
6262

6363
func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) {
6464
var (
65-
err error
66-
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
67-
tmpl = notify.TmplText(n.tmpl, data, &err)
68-
creds *credentials.Credentials = nil
65+
err error
66+
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
67+
tmpl = notify.TmplText(n.tmpl, data, &err)
6968
)
70-
if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
71-
creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")
69+
70+
client, err := createSNSClient(n, tmpl)
71+
if err != nil {
72+
if e, ok := err.(awserr.RequestFailure); ok {
73+
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
74+
} else {
75+
return true, err
76+
}
7277
}
7378

74-
attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
75-
for k, v := range n.conf.Attributes {
76-
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
79+
publishInput, err := createPublishInput(ctx, n, tmpl)
80+
if err != nil {
81+
return true, err
82+
}
83+
84+
publishOutput, err := client.Publish(publishInput)
85+
if err != nil {
86+
if e, ok := err.(awserr.RequestFailure); ok {
87+
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
88+
} else {
89+
return true, err
90+
}
7791
}
7892

93+
level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)
94+
95+
return false, nil
96+
}
97+
98+
func createSNSClient(n *Notifier, tmpl func(string) string) (*sns.SNS, error) {
99+
var creds *credentials.Credentials = nil
100+
// If there are provided sigV4 credentials we want to use those to create a session.
101+
if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
102+
creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")
103+
}
79104
sess, err := session.NewSessionWithOptions(session.Options{
80105
Config: aws.Config{
81106
Region: aws.String(n.conf.Sigv4.Region),
@@ -84,11 +109,7 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
84109
Profile: n.conf.Sigv4.Profile,
85110
})
86111
if err != nil {
87-
if e, ok := err.(awserr.RequestFailure); ok {
88-
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
89-
} else {
90-
return true, err
91-
}
112+
return nil, err
92113
}
93114

94115
if n.conf.Sigv4.RoleARN != "" {
@@ -105,32 +126,37 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
105126
Profile: n.conf.Sigv4.Profile,
106127
})
107128
if err != nil {
108-
if e, ok := err.(awserr.RequestFailure); ok {
109-
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
110-
} else {
111-
return true, err
112-
}
129+
return nil, err
113130
}
114131
}
115132
creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN)
116133
}
117-
// Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes.
118-
messageSizeLimit := 256 * 1024
134+
// Use our generated session with credentials to create the SNS Client.
119135
client := sns.New(sess, &aws.Config{Credentials: creds})
120-
publishInput := &sns.PublishInput{}
136+
// We will always need a region to be set by either the local config or the environment.
137+
if aws.StringValue(sess.Config.Region) == "" {
138+
return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain")
139+
}
140+
return client, nil
141+
}
121142

143+
func createPublishInput(ctx context.Context, n *Notifier, tmpl func(string) string) (*sns.PublishInput, error) {
144+
publishInput := &sns.PublishInput{}
145+
messageAttributes := createMessageAttributes(n, tmpl)
146+
// Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes.
147+
messageSizeLimit := 256 * 1024
122148
if n.conf.TopicARN != "" {
123149
topicTmpl := tmpl(n.conf.TopicARN)
124150
publishInput.SetTopicArn(topicTmpl)
125-
126151
if n.isFifo == nil {
152+
// If we are using a topic ARN it could be a FIFO topic specified by the topic postfix .fifo.
127153
n.isFifo = aws.Bool(n.conf.TopicARN[len(n.conf.TopicARN)-5:] == ".fifo")
128154
}
129155
if *n.isFifo {
130156
// Deduplication key and Message Group ID are only added if it's a FIFO SNS Topic.
131157
key, err := notify.ExtractGroupKey(ctx)
132158
if err != nil {
133-
return false, err
159+
return nil, err
134160
}
135161
publishInput.SetMessageDeduplicationId(key.Hash())
136162
publishInput.SetMessageGroupId(key.Hash())
@@ -143,36 +169,25 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
143169
}
144170
if n.conf.TargetARN != "" {
145171
publishInput.SetTargetArn(tmpl(n.conf.TargetARN))
146-
147172
}
148173

149174
messageToSend, isTrunc, err := validateAndTruncateMessage(tmpl(n.conf.Message), messageSizeLimit)
150175
if err != nil {
151-
return false, err
176+
return nil, err
152177
}
153178
if isTrunc {
154-
attributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
179+
// If we truncated the message we need to add a message attribute showing that it was truncated.
180+
messageAttributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
155181
}
182+
156183
publishInput.SetMessage(messageToSend)
184+
publishInput.SetMessageAttributes(messageAttributes)
157185

158186
if n.conf.Subject != "" {
159187
publishInput.SetSubject(tmpl(n.conf.Subject))
160188
}
161189

162-
publishInput.SetMessageAttributes(attributes)
163-
164-
publishOutput, err := client.Publish(publishInput)
165-
if err != nil {
166-
if e, ok := err.(awserr.RequestFailure); ok {
167-
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
168-
} else {
169-
return true, err
170-
}
171-
}
172-
173-
level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)
174-
175-
return false, nil
190+
return publishInput, nil
176191
}
177192

178193
func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) {
@@ -187,3 +202,12 @@ func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (stri
187202
copy(truncated, message)
188203
return string(truncated), true, nil
189204
}
205+
206+
func createMessageAttributes(n *Notifier, tmpl func(string) string) map[string]*sns.MessageAttributeValue {
207+
// Convert the given attributes map into the AWS Message Attributes Format
208+
attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
209+
for k, v := range n.conf.Attributes {
210+
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
211+
}
212+
return attributes
213+
}

0 commit comments

Comments
 (0)