Fetch SA from apiserver #252

Merged · 7 commits · Jan 13, 2025

Changes from all commits
1 change: 1 addition & 0 deletions main.go
@@ -181,6 +181,7 @@ func main() {
saInformer,
cmInformer,
composeRoleArnCache,
clientset.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
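For context, the new argument is the typed CoreV1 client, which satisfies the corev1.ServiceAccountsGetter interface that cache.New now accepts. A minimal sketch of the surrounding wiring, assuming a standard in-cluster setup (variable names here are illustrative, not taken from main.go):

config, err := rest.InClusterConfig()
if err != nil {
	klog.Fatalf("building in-cluster config: %v", err)
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
	klog.Fatalf("building clientset: %v", err)
}
// clientset.CoreV1() implements corev1.ServiceAccountsGetter and is what the
// diff above passes into cache.New.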
86 changes: 66 additions & 20 deletions pkg/cache/cache.go
@@ -16,20 +16,27 @@
package cache

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
coreinformers "k8s.io/client-go/informers/core/v1"
"k8s.io/client-go/kubernetes"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/retry"
"k8s.io/klog/v2"
)

@@ -80,8 +87,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan struct{}
handlerMu sync.Mutex
notifications *notifications
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -159,20 +165,13 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSA(req Request) (*Entry, chan struct{}) {
func (c *serviceAccountCache) getSA(req Request) (*Entry, <-chan struct{}) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, ok := c.saCache[req.CacheKey()]
if !ok && req.RequestNotification {
klog.V(5).Infof("Service Account %s not found in cache, adding notification handler", req.CacheKey())
c.handlerMu.Lock()
defer c.handlerMu.Unlock()
notifier, found := c.notificationHandlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
c.notificationHandlers[req.CacheKey()] = notifier
}
return nil, notifier
return nil, c.notifications.create(req)
}
return entry, nil
}
@@ -267,13 +266,7 @@ func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) {
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry)
c.saCache[key] = entry

c.handlerMu.Lock()
defer c.handlerMu.Unlock()
if handler, found := c.notificationHandlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
close(handler)
delete(c.notificationHandlers, key)
}
c.notifications.broadcast(key)
}

func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
@@ -283,7 +276,15 @@ func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
c.cmCache[namespace+"/"+name] = entry
}

func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache {
func New(defaultAudience,
prefix string,
defaultRegionalSTS bool,
defaultTokenExpiration int64,
saInformer coreinformers.ServiceAccountInformer,
cmInformer coreinformers.ConfigMapInformer,
composeRoleArn ComposeRoleArn,
SAGetter corev1.ServiceAccountsGetter,
) ServiceAccountCache {
hasSynced := func() bool {
if cmInformer != nil {
return saInformer.Informer().HasSynced() && cmInformer.Informer().HasSynced()
@@ -292,6 +293,9 @@ }
}
}

// Allocate enough capacity that writers (the synchronous pod-mutation path) never block.
// Rate limiting happens on the consumer side below.
saFetchRequests := make(chan *Request, 1000)
c := &serviceAccountCache{
saCache: map[string]*Entry{},
cmCache: map[string]*Entry{},
@@ -302,9 +306,30 @@
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan struct{}{},
notifications: newNotifications(saFetchRequests),
}

// Rate limit to 10 requests per second with a burst of 20.
// If a request sits in the channel longer than service-account-lookup-grace-period and the
// service account has also not been synced by the informer cache in that time, the pod is
// not mutated. This bounds the extra latency added to pod mutation at
// service-account-lookup-grace-period.
rl := rate.NewLimiter(rate.Every(100*time.Millisecond), 20)
go func() {
for req := range saFetchRequests {
go func(req *Request) {
// Rate limit inside the goroutine: drain the channel as fast as possible so the
// writer never blocks, while still rate limiting requests sent to the API server.
_ = rl.Wait(context.Background())
sa, err := fetchFromAPI(SAGetter, req)
if err != nil {
klog.Errorf("fetching SA %s from API server failed: %v", req.CacheKey(), err)
return
}
c.addSA(sa)
}(req)
}
}()

saInformer.Informer().AddEventHandler(
cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
@@ -354,6 +379,27 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
return c
}

func fetchFromAPI(getter corev1.ServiceAccountsGetter, req *Request) (*v1.ServiceAccount, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
defer cancel()

klog.V(5).Infof("fetching SA: %s", req.CacheKey())

var sa *v1.ServiceAccount
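// Retry only on server timeouts (errors.IsServerTimeout); retry.DefaultBackoff is
// client-go's short exponential backoff, and every attempt shares the one-second
// context created above, so a slow API server cannot stall a fetch for long.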
err := retry.OnError(retry.DefaultBackoff, func(err error) bool {
return errors.IsServerTimeout(err)
}, func() error {
res, err := getter.ServiceAccounts(req.Namespace).Get(ctx, req.Name, metav1.GetOptions{})
if err != nil {
return err
}
sa = res
return nil
})

return sa, err
}

func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) error {
if newCM.Name != "pod-identity-webhook" {
return nil
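The limiter above is a token bucket: rate.Every(100*time.Millisecond) refills one token every 100ms (10 per second), and the burst of 20 lets up to 20 queued fetches through immediately after an idle period. A standalone sketch of that behavior (illustrative, not part of this PR):

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Same parameters as the cache: one token per 100ms, burst of 20.
	rl := rate.NewLimiter(rate.Every(100*time.Millisecond), 20)
	start := time.Now()
	for i := 0; i < 25; i++ {
		_ = rl.Wait(context.Background())
		fmt.Printf("request %2d at %v\n", i, time.Since(start).Round(time.Millisecond))
	}
	// The first 20 requests complete almost instantly (the burst); the
	// remaining 5 are spaced roughly 100ms apart as the bucket refills.
}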
97 changes: 88 additions & 9 deletions pkg/cache/cache_test.go
@@ -35,6 +35,7 @@ func TestSaCache(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

resp := cache.Get(Request{Name: "default", Namespace: "default"})
@@ -69,9 +70,9 @@ func TestNotification(t *testing.T) {

t.Run("with one notification handler", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

// test that the requested SA is not in the cache
@@ -106,9 +107,9 @@

t.Run("with 10 notification handlers", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 5)),
}

// test that the requested SA is not in the cache
@@ -153,6 +154,63 @@
})
}

func TestFetchFromAPIServer(t *testing.T) {
testSA := &v1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: "my-sa",
Namespace: "default",
Annotations: map[string]string{
"eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/s3-reader",
"eks.amazonaws.com/token-expiration": "3600",
},
},
}
fakeSAClient := fake.NewSimpleClientset(testSA)

// use an empty informer to simulate the need to fetch the SA from the API server:
fakeEmptyClient := fake.NewSimpleClientset()
emptyInformerFactory := informers.NewSharedInformerFactory(fakeEmptyClient, 0)
emptyInformer := emptyInformerFactory.Core().V1().ServiceAccounts()

cache := New(
"sts.amazonaws.com",
"eks.amazonaws.com",
true,
86400,
emptyInformer,
nil,
ComposeRoleArn{},
fakeSAClient.CoreV1(),
)

stop := make(chan struct{})
emptyInformerFactory.Start(stop)
emptyInformerFactory.WaitForCacheSync(stop)
cache.Start(stop)
defer close(stop)

err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
return len(fakeEmptyClient.Actions()) != 0, nil
})
if err != nil {
t.Fatalf("informer never called client: %v", err)
}

resp := cache.Get(Request{Name: "my-sa", Namespace: "default", RequestNotification: true})
assert.False(t, resp.FoundInCache, "Expected cache entry to not be found")

// wait for the notification while we fetch the SA from the API server:
select {
case <-resp.Notifier:
// expected
// test that the requested SA is now in the cache
resp := cache.Get(Request{Name: "my-sa", Namespace: "default", RequestNotification: false})
assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache")
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for notification")
}
}

func TestNonRegionalSTS(t *testing.T) {
trueStr := "true"
falseStr := "false"
@@ -237,7 +295,16 @@

testComposeRoleArn := ComposeRoleArn{}

cache := New(audience, "eks.amazonaws.com", tc.defaultRegionalSTS, 86400, informer, nil, testComposeRoleArn)
cache := New(
audience,
"eks.amazonaws.com",
tc.defaultRegionalSTS,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
@@ -295,7 +362,8 @@ func TestPopulateCacheFromCM(t *testing.T) {
}

c := serviceAccountCache{
cmCache: make(map[string]*Entry),
cmCache: make(map[string]*Entry),
notifications: newNotifications(make(chan *Request, 10)),
}

{
@@ -413,6 +481,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
saCache: make(map[string]*Entry),
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

c.addSA(oldSA)
@@ -476,6 +545,7 @@ func TestCachePrecedence(t *testing.T) {
defaultTokenExpiration: pkg.DefaultTokenExpiration,
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

{
@@ -574,7 +644,15 @@ func TestRoleArnComposition(t *testing.T) {
informerFactory := informers.NewSharedInformerFactory(fakeClient, 0)
informer := informerFactory.Core().V1().ServiceAccounts()

cache := New(audience, "eks.amazonaws.com", true, 86400, informer, nil, testComposeRoleArn)
cache := New(audience,
"eks.amazonaws.com",
true,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
@@ -673,6 +751,7 @@ func TestGetCommonConfigurations(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

if tc.serviceAccount != nil {
44 changes: 44 additions & 0 deletions pkg/cache/notifications.go
@@ -0,0 +1,44 @@
package cache

import (
"sync"

"k8s.io/klog/v2"
)

type notifications struct {
handlers map[string]chan struct{}
mu sync.Mutex
fetchRequests chan<- *Request
}

func newNotifications(saFetchRequests chan<- *Request) *notifications {
return &notifications{
handlers: map[string]chan struct{}{},
fetchRequests: saFetchRequests,
}
}

func (n *notifications) create(req Request) <-chan struct{} {
n.mu.Lock()
defer n.mu.Unlock()

// deduplicate concurrent requests for the same namespace/name into a single fetch request
notifier, found := n.handlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
n.handlers[req.CacheKey()] = notifier
n.fetchRequests <- &req
}
return notifier
}

func (n *notifications) broadcast(key string) {
n.mu.Lock()
defer n.mu.Unlock()
if handler, found := n.handlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
close(handler)
delete(n.handlers, key)
}
}
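The dedup contract between create and broadcast is easy to exercise in-package. A test-style sketch (names follow the code above; this test is not part of the PR): two waiters for the same service account share one notifier channel and enqueue a single fetch request, and one broadcast wakes both.

func TestNotificationDedupSketch(t *testing.T) {
	fetchCh := make(chan *Request, 2)
	n := newNotifications(fetchCh)

	req := Request{Name: "my-sa", Namespace: "default", RequestNotification: true}
	first := n.create(req)
	second := n.create(req) // same key: deduplicated onto the same channel

	if len(fetchCh) != 1 {
		t.Fatalf("expected exactly one enqueued fetch request, got %d", len(fetchCh))
	}
	if first != second {
		t.Fatal("expected both waiters to share a single notifier channel")
	}

	n.broadcast(req.CacheKey()) // closes the shared channel, waking every waiter
	<-first
	<-second // both receives return immediately because the channel is closed
}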
2 changes: 1 addition & 1 deletion pkg/handler/handler.go
@@ -434,7 +434,7 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {

// Use the STS WebIdentity method if set
gracePeriodEnabled := m.saLookupGraceTime > 0
request := cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: true}
request := cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: gracePeriodEnabled}
response := m.Cache.Get(request)
if !response.FoundInCache && !gracePeriodEnabled {
missingSACounter.WithLabelValues().Inc()
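With this change a notifier is only requested when the lookup grace period is enabled, so no handler is registered (and no API fetch is queued) when the grace period is zero. The waiting side is outside this diff; conceptually it presumably takes the following shape (a hypothetical sketch, not code from this PR; m.saLookupGraceTime is the field referenced above and is assumed to be a time.Duration):

if !response.FoundInCache && gracePeriodEnabled {
	// Hypothetical: wait for the API-server fetch to populate the cache,
	// but never longer than the configured grace period.
	select {
	case <-response.Notifier:
		response = m.Cache.Get(cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName})
	case <-time.After(m.saLookupGraceTime):
		// Grace period expired; proceed without mutating the pod.
	}
}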
4 changes: 4 additions & 0 deletions vendor/k8s.io/client-go/util/retry/OWNERS

Some generated files are not rendered by default.