Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: AWS Temporary credential support for SQS eventsource #2092

Merged
merged 8 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/event-source.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions api/event-source.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions api/jsonschema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2464,6 +2464,10 @@
"$ref": "#/definitions/io.k8s.api.core.v1.SecretKeySelector",
"description": "SecretKey refers K8s secret containing aws secret key"
},
"sessionToken": {
"$ref": "#/definitions/io.k8s.api.core.v1.SecretKeySelector",
"description": "SessionToken refers to K8s secret containing AWS temporary credentials(STS) session token"
},
"waitTimeSeconds": {
"description": "WaitTimeSeconds is The duration (in seconds) for which the call waits for a message to arrive in the queue before returning.",
"format": "int64",
Expand Down
4 changes: 4 additions & 0 deletions api/openapi-spec/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 13 additions & 3 deletions eventsources/common/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func GetAWSCredFromEnvironment(access *corev1.SecretKeySelector, secret *corev1.
}

// GetAWSCredFromVolume reads credential stored in mounted secret volume.
func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.SecretKeySelector) (*credentials.Credentials, error) {
func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.SecretKeySelector, sessionToken *corev1.SecretKeySelector) (*credentials.Credentials, error) {
accessKey, err := common.GetSecretFromVolume(access)
if err != nil {
return nil, errors.Wrap(err, "can not find access key")
Expand All @@ -53,9 +53,19 @@ func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.Secre
if err != nil {
return nil, errors.Wrap(err, "can not find secret key")
}

var token string
if sessionToken != nil {
token, err = common.GetSecretFromVolume(sessionToken)
if err != nil {
return nil, errors.Wrap(err, "can not find session token")
}
}

return credentials.NewStaticCredentialsFromCreds(credentials.Value{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
SessionToken: token,
}), nil
}

Expand Down Expand Up @@ -97,7 +107,7 @@ func CreateAWSSessionWithCredsInEnv(region string, roleARN string, accessKey *co
}

// CreateAWSSessionWithCredsInVolume based on credentials in mounted volumes, return a aws session
func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey *corev1.SecretKeySelector, secretKey *corev1.SecretKeySelector) (*session.Session, error) {
func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey *corev1.SecretKeySelector, secretKey *corev1.SecretKeySelector, sessionToken *corev1.SecretKeySelector) (*session.Session, error) {
if roleARN != "" {
return GetAWSAssumeRoleCreds(roleARN, region)
}
Expand All @@ -106,7 +116,7 @@ func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey
return GetAWSSessionWithoutCreds(region)
}

creds, err := GetAWSCredFromVolume(accessKey, secretKey)
creds, err := GetAWSCredFromVolume(accessKey, secretKey, sessionToken)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion eventsources/sources/awssns/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (router *Router) PostActivate() error {

snsEventSource := router.eventSource

awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(snsEventSource.Region, snsEventSource.RoleARN, snsEventSource.AccessKey, snsEventSource.SecretKey)
awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(snsEventSource.Region, snsEventSource.RoleARN, snsEventSource.AccessKey, snsEventSource.SecretKey, nil)
if err != nil {
return err
}
Expand Down
61 changes: 49 additions & 12 deletions eventsources/sources/awssqs/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
sqslib "github.com/aws/aws-sdk-go/service/sqs"
"github.com/pkg/errors"
Expand Down Expand Up @@ -68,19 +69,9 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
defer sources.Recover(el.GetEventName())

sqsEventSource := &el.SQSEventSource
var awsSession *session.Session
awsSession, err := awscommon.CreateAWSSessionWithCredsInVolume(sqsEventSource.Region, sqsEventSource.RoleARN, sqsEventSource.AccessKey, sqsEventSource.SecretKey)
sqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating AWS credentials", zap.Error(err))
return errors.Wrapf(err, "failed to create aws session for %s", el.GetEventName())
}

var sqsClient *sqslib.SQS

if sqsEventSource.Endpoint == "" {
sqsClient = sqslib.New(awsSession)
} else {
sqsClient = sqslib.New(awsSession, &aws.Config{Endpoint: &sqsEventSource.Endpoint, Region: &sqsEventSource.Region})
return err
}

log.Info("fetching queue url...")
Expand Down Expand Up @@ -112,6 +103,17 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
messages, err := fetchMessages(ctx, sqsClient, *queueURL.QueueUrl, 10, sqsEventSource.WaitTimeSeconds)
if err != nil {
log.Errorw("failed to get messages from SQS", zap.Error(err))
awsError, ok := err.(awserr.Error)
if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
log.Info("credentials expired, reading credentials again")
newSqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating SQS client", zap.Error(err))
} else if newSqsClient != nil {
sqsClient = newSqsClient
}
}

time.Sleep(2 * time.Second)
continue
}
Expand All @@ -123,6 +125,16 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
})
if err != nil {
log.Errorw("Failed to delete message", zap.Error(err))
awsError, ok := err.(awserr.Error)
if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
log.Info("credentials expired, reading credentials again")
newSqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating SQS client", zap.Error(err))
} else if newSqsClient != nil {
sqsClient = newSqsClient
}
}
}
}, log)
}
Expand Down Expand Up @@ -185,3 +197,28 @@ func fetchMessages(ctx context.Context, q *sqslib.SQS, url string, maxSize, wait
}
return result.Messages, nil
}

func (el *EventListener) createAWSSession() (*session.Session, error) {
sqsEventSource := &el.SQSEventSource
awsSession, err := awscommon.CreateAWSSessionWithCredsInVolume(sqsEventSource.Region, sqsEventSource.RoleARN, sqsEventSource.AccessKey, sqsEventSource.SecretKey, sqsEventSource.SessionToken)
if err != nil {
return nil, errors.Wrapf(err, "failed to create aws session for %s", el.GetEventName())
}
return awsSession, nil
}

func (el *EventListener) createSqsClient() (*sqslib.SQS, error) {
awsSession, err := el.createAWSSession()
if err != nil {
return nil, err
}

var sqsClient *sqslib.SQS
if el.SQSEventSource.Endpoint == "" {
sqsClient = sqslib.New(awsSession)
} else {
sqsClient = sqslib.New(awsSession, &aws.Config{Endpoint: &el.SQSEventSource.Endpoint, Region: &el.SQSEventSource.Region})
}

return sqsClient, nil
}
Loading