Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions internal/cloud/eksauth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package eksauth
import (
"context"
"fmt"
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -25,6 +26,14 @@ type service struct {
}

func NewService(cfg aws.Config) Iface {
// Set up custom HTTP client with 2s timeout
// default credential provider timeout
// https://github.com/boto/botocore/blob/develop/botocore/utils.py#L3032
customHTTPClient := &http.Client{
Timeout: 2 * time.Second,
}
// Inject custom HTTP client into AWS config
cfg.HTTPClient = customHTTPClient
eksAuthService := eksauth.NewFromConfig(cfg)
return &service{
eksAuthService: eksAuthService,
Expand Down
87 changes: 71 additions & 16 deletions internal/credsretriever/refreshing_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ var (
)

const (
defaultActiveRequestRetries = 16
defaultActiveRequestWaitTime = 250 * time.Millisecond
defaultActiveRequestInterval = 1 * time.Second
minActiveRequestWaitTime = 100 * time.Millisecond
maxActiveRequestWaitTime = 400 * time.Millisecond
// defaultCleanupInterval sets how often we go over the cache to check if
// there are expired credentials requiring renewal
defaultCleanupInterval = 1 * time.Minute
Expand Down Expand Up @@ -150,7 +149,8 @@ func (r *cachedCredentialRetriever) GetIamCredentials(ctx context.Context,
return nil, nil, fmt.Errorf("service account is empty, cannot fetch credentials without a valid one")
}

for i := 0; i <= defaultActiveRequestRetries; i++ {
n := 0
for {
// Check if the request is in the cache, if it is, return it
if val, ok := r.internalCache.Get(request.ServiceAccountToken); ok {
if _, withinTtl := r.credentialsInEntryWithinValidTtl(val); withinTtl {
Expand All @@ -169,36 +169,57 @@ func (r *cachedCredentialRetriever) GetIamCredentials(ctx context.Context,
log.Errorf("Failed the request with error from the same active request: %v", errActiveRequest)
return nil, nil, errActiveRequest
}
if i > 0 {
log.Infof("Waiting for active request with %v tries", i)
}

// Wait for active request to finish caching into internalCache, if not the last retry
if i < defaultActiveRequestRetries {
time.Sleep(defaultActiveRequestWaitTime)
// 2^n exponential backoff
waitTime := minActiveRequestWaitTime * (1 << uint(n)) // 2^n backoff
if waitTime > maxActiveRequestWaitTime {
break
}

log.Infof("Retrying %v waiting for internalActiveRequestCache in %v\n", n, waitTime)
time.Sleep(waitTime)
n++
} else {
// No active request, exit the loop to fetch from delegate
break
}
}

if _, ok := r.internalActiveRequestCache.Get(request.ServiceAccountToken); ok {
log.Warnf("Failed to complete active request in %v tries", defaultActiveRequestRetries)
log.Warnf("Failed to complete active request in %v", maxActiveRequestWaitTime*2)
}

r.internalActiveRequestCache.Add(request.ServiceAccountToken, nil)

log.WithField("cache-hit", 0).Tracef("Could not find entry in cache, requesting creds from delegate")

r.internalActiveRequestCache.Add(request.ServiceAccountToken, nil)
iamCredentials, metadata, err := r.callDelegateAndCache(ctx, request)
r.internalActiveRequestCache.Delete(request.ServiceAccountToken)

if err != nil {
r.internalActiveRequestCache.ReplaceWithExpire(request.ServiceAccountToken, err, defaultActiveRequestInterval)
return nil, nil, err
}
r.internalActiveRequestCache.Delete(request.ServiceAccountToken)
return iamCredentials.credentials, metadata, nil
}

// credentialsRenewalJitteredTtl returns a jittered TTL value between (ttl - 1h) and ttl
// or between 0 and ttl if ttl < 1h
func (r *cachedCredentialRetriever) credentialsRenewalJitteredTtl() time.Duration {
var (
ttl = r.credentialsRenewalTtl
jitterMax time.Duration
)

if ttl >= time.Hour {
jitterMax = time.Hour
ttl -= time.Hour
} else {
jitterMax = ttl / 2
ttl = ttl / 2
}
return ttl + time.Duration(rand.Int63n(int64(jitterMax)))
}

func (r *cachedCredentialRetriever) callDelegateAndCache(ctx context.Context,
request *credentials.EksCredentialsRequest) (cacheEntry, credentials.ResponseMetadata, error) {
log := logger.FromContext(ctx)
Expand All @@ -213,7 +234,7 @@ func (r *cachedCredentialRetriever) callDelegateAndCache(ctx context.Context,
return cacheEntry{}, nil, fmt.Errorf("fetched credentials are expired or will expire within the next %0.2f seconds", credsDuration.Seconds())
}

refreshTtl := minDuration(credsDuration, r.credentialsRenewalTtl)
refreshTtl := minDuration(credsDuration, r.credentialsRenewalJitteredTtl())
log.WithField("refreshTtl", refreshTtl).Infof("Storing creds in cache")

// Store credentials in cache if they are valid. It might be that
Expand All @@ -223,6 +244,40 @@ func (r *cachedCredentialRetriever) callDelegateAndCache(ctx context.Context,
return newCacheEntry, nil, nil
}

func (r *cachedCredentialRetriever) callDelegateAndCacheWithRetry(ctx context.Context,
request *credentials.EksCredentialsRequest, timeout time.Duration) (cacheEntry, credentials.ResponseMetadata, error) {
log := logger.FromContext(ctx)

n := 0
minWaitTime, maxWaitTime := minActiveRequestWaitTime, timeout/2
var (
iamCredentials cacheEntry
metadata credentials.ResponseMetadata
err error
)

for {
// Call callDelegateAndCache
iamCredentials, metadata, err = r.callDelegateAndCache(ctx, request)
if err == nil {
return iamCredentials, metadata, nil
}

// Wait if not the last retry
// 2^n exponential backoff
waitTime := minWaitTime * (1 << uint(n)) // 2^n backoff
if waitTime > maxWaitTime {
break
}

log.Infof("Retrying %v waiting for callDelegateAndCache in %v\n", n, waitTime)
time.Sleep(waitTime)
n++
}

return cacheEntry{}, nil, fmt.Errorf("error getting credentials and cache with retries: %w", err)
}

func (r *cachedCredentialRetriever) credentialsInEntryWithinValidTtl(newCacheEntry cacheEntry) (time.Duration, bool) {
credsDuration := newCacheEntry.credentials.Expiration.Time.Sub(r.now())
credentialsLessThanMinCredTtl := credsDuration > r.minCredentialTtl
Expand Down Expand Up @@ -257,7 +312,7 @@ func (r *cachedCredentialRetriever) onCredentialRenewal(key string, entry cacheE
log.Errorf("Problem waiting, will schedule refresh to next sweep")
return
}
_, _, err = r.callDelegateAndCache(ctx, entry.originatingRequest)
_, _, err = r.callDelegateAndCacheWithRetry(ctx, entry.originatingRequest, renewalTimeout)
if err == nil {
// if we retrieved the credentials successfully, exit we don't need to do anything else
promCacheState.WithLabelValues("hit").Inc()
Expand Down
Loading