diff --git a/pkg/server/datastore/sqldriver/awsrds/auth_token.go b/pkg/server/datastore/sqldriver/awsrds/auth_token.go index 5179aed0e7..fa363956ec 100644 --- a/pkg/server/datastore/sqldriver/awsrds/auth_token.go +++ b/pkg/server/datastore/sqldriver/awsrds/auth_token.go @@ -13,19 +13,18 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/rds/auth" ) -const iso8601BasicFormat = "20060102T150405Z" +const ( + iso8601BasicFormat = "20060102T150405Z" + clockSkew = time.Minute // Make sure that the authentication token is valid for one more minute. +) type authTokenBuilder interface { buildAuthToken(ctx context.Context, endpoint string, region string, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error) } -type tokenGetter interface { - getAuthToken(ctx context.Context, params *Config, tokenBuilder authTokenBuilder) (string, error) -} - type authToken struct { - token string - expiresAt time.Time + cachedToken string + expiresAt time.Time } func (a *authToken) getAuthToken(ctx context.Context, config *Config, tokenBuilder authTokenBuilder) (string, error) { @@ -37,8 +36,8 @@ func (a *authToken) getAuthToken(ctx context.Context, config *Config, tokenBuild return "", errors.New("missing token builder") } - if !a.isExpired() { - return a.token, nil + if !a.shouldRotate() { + return a.cachedToken, nil } awsClientConfig, err := newAWSClientConfig(ctx, config) @@ -79,14 +78,18 @@ func (a *authToken) getAuthToken(ctx context.Context, config *Config, tokenBuild if err != nil { return "", fmt.Errorf("failed to parse X-Amz-Expires duration: %w", err) } - a.token = authenticationToken + a.cachedToken = authenticationToken a.expiresAt = dateTime.Add(durationTime) return authenticationToken, nil } -func (a *authToken) isExpired() bool { - clockSkew := time.Minute // Make sure that the authentication token is valid for one more minute. - return nowFunc().Add(-clockSkew).Sub(a.expiresAt) >= 0 +// shouldRotate returns true if the cached token is either expired or is +// expiring soon. This means that this function will return true also if the +// token is still valid but should be rotated because it's expiring soon. The +// time window that establish when a cached token should be rotated even if it's +// still valid is adjusted by a clock skew, defined in the clockSkew constant. +func (a *authToken) shouldRotate() bool { + return nowFunc().Add(clockSkew).Sub(a.expiresAt) >= 0 } type awsTokenBuilder struct{} diff --git a/pkg/server/datastore/sqldriver/awsrds/awsrds.go b/pkg/server/datastore/sqldriver/awsrds/awsrds.go index c966bba038..0de042607c 100644 --- a/pkg/server/datastore/sqldriver/awsrds/awsrds.go +++ b/pkg/server/datastore/sqldriver/awsrds/awsrds.go @@ -67,7 +67,7 @@ func (c *Config) getConnStringWithPassword(password string) (string, error) { } } -type tokens map[string]tokenGetter +type tokens map[string]*authToken // sqlDriverWrapper is a wrapper for SQL drivers, adding IAM authentication. type sqlDriverWrapper struct { diff --git a/pkg/server/datastore/sqldriver/awsrds/awsrds_test.go b/pkg/server/datastore/sqldriver/awsrds/awsrds_test.go index 0e754f8968..1c7dca6fc1 100644 --- a/pkg/server/datastore/sqldriver/awsrds/awsrds_test.go +++ b/pkg/server/datastore/sqldriver/awsrds/awsrds_test.go @@ -270,11 +270,12 @@ func TestCacheToken(t *testing.T) { dsn, err := config.FormatDSN() require.NoError(t, err) - now := time.Now().UTC() - nowString := now.Format(iso8601BasicFormat) + initialTime := time.Now().UTC() + nowString := initialTime.Format(iso8601BasicFormat) + ttl := 900 // Set a first token to be always returned by the token builder. - firstToken := fmt.Sprintf("X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Signature=first-token", nowString) + firstToken := fmt.Sprintf("X-Amz-Date=%s&X-Amz-Expires=%d&X-Amz-Signature=first-token", nowString, ttl) fakeSQLDriverWrapper.tokenBuilder = &fakeTokenBuilder{ authToken: firstToken, } @@ -299,11 +300,15 @@ func TestCacheToken(t *testing.T) { // token (not expired) that we can use. For that, we start by setting a new // token that will be returned by the token builder when getAWSAuthToken is // called. - newToken := fmt.Sprintf("X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Signature=second-token", nowString) + + newToken := fmt.Sprintf("X-Amz-Date=%s&X-Amz-Expires=%d&X-Amz-Signature=second-token", nowString, ttl) fakeSQLDriverWrapper.tokenBuilder = &fakeTokenBuilder{ authToken: newToken, } + // Advance the clock just a few seconds. + nowFunc = func() time.Time { return initialTime.Add(time.Second * 15) } + // Call Open again, the cached token should be used. db, err = gorm.Open(fakeSQLDriverName, dsn) require.NoError(t, err) @@ -318,8 +323,14 @@ func TestCacheToken(t *testing.T) { // We will now make firstToken to expire, so we can test that the token // builder is called to get a new token when the current token has expired. - // For that, we advance the clock one hour. - nowFunc = func() time.Time { return now.Add(time.Hour) } + // For that, we advance the clock the number of seconds of the ttl of the + // token. + newTime := initialTime.Add(time.Second * time.Duration(ttl)) + + // nowFunc will subtract the clock skew from the new time, to make sure + // that we get a new token even if it's not expired but it's within the + // clock skew period. + nowFunc = func() time.Time { return newTime.Add(-clockSkew) } // Call Open again, the new token should be used. db, err = gorm.Open(fakeSQLDriverName, dsn)