From 7c7906ed3a8211bcf179dbd56677fda2ee887c5f Mon Sep 17 00:00:00 2001 From: sxllwx Date: Mon, 23 Oct 2023 20:59:40 +0800 Subject: [PATCH 1/4] UPSTREAM: 129472: Fix API server crash on concurrent map iteration and write Improve audit context handling by encapsulating event data and operations behind a structured API. Make the Audit system more robust in concurrent environments by properly isolating mutable state. The cleaner API simplifies interaction with audit events, improving maintainability. Encapsulation reduces bugs by preventing direct manipulation of audit events. Signed-off-by: Davanum Srinivas Co-Authored-By: Jordan Liggitt Co-Authored-By: sxllwx Kubernetes-commit: 75afa1e0acfb309d984be14937a06f796f220cd6 --- pkg/admission/audit.go | 6 +- pkg/admission/audit_test.go | 9 +- pkg/audit/context.go | 275 ++++++++++++++++-- pkg/audit/context_test.go | 57 ++-- pkg/audit/request.go | 204 ++++++------- .../token/cache/cached_token_authenticator.go | 4 +- .../cache/cached_token_authenticator_test.go | 10 +- pkg/endpoints/filters/audit.go | 103 +++---- pkg/endpoints/filters/audit_test.go | 193 ++++++++++-- pkg/endpoints/filters/authn_audit.go | 13 +- pkg/endpoints/filters/authorization_test.go | 16 +- pkg/endpoints/filters/impersonation.go | 3 +- pkg/endpoints/filters/request_deadline.go | 12 +- .../filters/request_deadline_test.go | 27 +- pkg/endpoints/handlers/delete_test.go | 2 +- pkg/server/config_test.go | 13 +- .../server_cert_deprecations_test.go | 16 +- 17 files changed, 662 insertions(+), 301 deletions(-) diff --git a/pkg/admission/audit.go b/pkg/admission/audit.go index 7c0993f09..f9f90cd02 100644 --- a/pkg/admission/audit.go +++ b/pkg/admission/audit.go @@ -83,7 +83,7 @@ func ensureAnnotationGetter(a Attributes) error { } func (handler *auditHandler) logAnnotations(ctx context.Context, a Attributes) { - ae := audit.AuditEventFrom(ctx) + ae := audit.AuditContextFrom(ctx) if ae == nil { return } @@ -91,9 +91,9 @@ func (handler *auditHandler) logAnnotations(ctx context.Context, a Attributes) { var annotations map[string]string switch a := a.(type) { case privateAnnotationsGetter: - annotations = a.getAnnotations(ae.Level) + annotations = a.getAnnotations(ae.GetEventLevel()) case AnnotationsGetter: - annotations = a.GetAnnotations(ae.Level) + annotations = a.GetAnnotations(ae.GetEventLevel()) default: // this will never happen, because we have already checked it in ensureAnnotationGetter } diff --git a/pkg/admission/audit_test.go b/pkg/admission/audit_test.go index e8bc2d8e0..dde433d79 100644 --- a/pkg/admission/audit_test.go +++ b/pkg/admission/audit_test.go @@ -144,8 +144,7 @@ func TestWithAudit(t *testing.T) { var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles} ctx := audit.WithAuditContext(context.Background()) ac := audit.AuditContextFrom(ctx) - ae := &ac.Event - ae.Level = auditinternal.LevelMetadata + ac.SetEventLevel(auditinternal.LevelMetadata) auditHandler := WithAudit(handler) a := attributes() @@ -171,9 +170,9 @@ func TestWithAudit(t *testing.T) { annotations[k] = v } if len(annotations) == 0 { - assert.Nil(t, ae.Annotations, tcName+": unexptected annotations set in audit event") + assert.Nil(t, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event") } else { - assert.Equal(t, annotations, ae.Annotations, tcName+": unexptected annotations set in audit event") + assert.Equal(t, annotations, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event") } } } @@ -188,7 +187,7 @@ func TestWithAuditConcurrency(t *testing.T) { var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true} ctx := audit.WithAuditContext(context.Background()) ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditinternal.LevelMetadata + ac.SetEventLevel(auditinternal.LevelMetadata) auditHandler := WithAudit(handler) a := attributes() diff --git a/pkg/audit/context.go b/pkg/audit/context.go index 964858737..538b3d956 100644 --- a/pkg/audit/context.go +++ b/pkg/audit/context.go @@ -18,10 +18,18 @@ package audit import ( "context" + "errors" + "maps" "sync" + "sync/atomic" + "time" + authnv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" auditinternal "k8s.io/apiserver/pkg/apis/audit" + "k8s.io/apiserver/pkg/authentication/user" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/klog/v2" ) @@ -35,22 +43,232 @@ const auditKey key = iota // AuditContext holds the information for constructing the audit events for the current request. type AuditContext struct { - // RequestAuditConfig is the audit configuration that applies to the request - RequestAuditConfig RequestAuditConfig - - // Event is the audit Event object that is being captured to be written in + // initialized indicates whether requestAuditConfig and sink have been populated and are safe to read unguarded. + // This should only be set via Init(). + initialized atomic.Bool + // initialize wraps setting requestAuditConfig and sink, and is only called via Init(). + initialize sync.Once + // requestAuditConfig is the audit configuration that applies to the request. + // This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true. + requestAuditConfig RequestAuditConfig + // sink is the sink to use when processing event stages. + // This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true. + sink Sink + + // lock guards event + lock sync.Mutex + + // event is the audit Event object that is being captured to be written in // the API audit log. - Event auditinternal.Event + event auditinternal.Event - // annotationMutex guards event.Annotations - annotationMutex sync.Mutex + // unguarded copy of auditID from the event + auditID atomic.Value } // Enabled checks whether auditing is enabled for this audit context. func (ac *AuditContext) Enabled() bool { - // Note: An unset Level should be considered Enabled, so that request data (e.g. annotations) - // can still be captured before the audit policy is evaluated. - return ac != nil && ac.RequestAuditConfig.Level != auditinternal.LevelNone + if ac == nil { + // protect against nil pointers + return false + } + if !ac.initialized.Load() { + // Note: An unset Level should be considered Enabled, so that request data (e.g. annotations) + // can still be captured before the audit policy is evaluated. + return true + } + return ac.requestAuditConfig.Level != auditinternal.LevelNone +} + +func (ac *AuditContext) Init(requestAuditConfig RequestAuditConfig, sink Sink) error { + initialized := false + ac.initialize.Do(func() { + ac.requestAuditConfig = requestAuditConfig + ac.sink = sink + ac.initialized.Store(true) + initialized = true + }) + if !initialized { + return errors.New("audit context was already initialized") + } + return nil +} + +func (ac *AuditContext) AuditID() types.UID { + // return the unguarded copy of the auditID + id, _ := ac.auditID.Load().(types.UID) + return id +} + +func (ac *AuditContext) visitEvent(f func(event *auditinternal.Event)) { + ac.lock.Lock() + defer ac.lock.Unlock() + f(&ac.event) +} + +// ProcessEventStage returns true on success, false if there was an error processing the stage. +func (ac *AuditContext) ProcessEventStage(ctx context.Context, stage auditinternal.Stage) bool { + if ac == nil || !ac.initialized.Load() { + return true + } + if ac.sink == nil { + return true + } + for _, omitStage := range ac.requestAuditConfig.OmitStages { + if stage == omitStage { + return true + } + } + + processed := false + ac.visitEvent(func(event *auditinternal.Event) { + event.Stage = stage + if stage == auditinternal.StageRequestReceived { + event.StageTimestamp = event.RequestReceivedTimestamp + } else { + event.StageTimestamp = metav1.NewMicroTime(time.Now()) + } + + ObserveEvent(ctx) + processed = ac.sink.ProcessEvents(event) + }) + return processed +} + +func (ac *AuditContext) LogImpersonatedUser(user user.Info) { + ac.visitEvent(func(ev *auditinternal.Event) { + if ev == nil || ev.Level.Less(auditinternal.LevelMetadata) { + return + } + ev.ImpersonatedUser = &authnv1.UserInfo{ + Username: user.GetName(), + } + ev.ImpersonatedUser.Groups = user.GetGroups() + ev.ImpersonatedUser.UID = user.GetUID() + ev.ImpersonatedUser.Extra = map[string]authnv1.ExtraValue{} + for k, v := range user.GetExtra() { + ev.ImpersonatedUser.Extra[k] = authnv1.ExtraValue(v) + } + }) +} + +func (ac *AuditContext) LogResponseObject(status *metav1.Status, obj *runtime.Unknown) { + ac.visitEvent(func(ae *auditinternal.Event) { + if status != nil { + // selectively copy the bounded fields. + ae.ResponseStatus = &metav1.Status{ + Status: status.Status, + Message: status.Message, + Reason: status.Reason, + Details: status.Details, + Code: status.Code, + } + } + if ae.Level.Less(auditinternal.LevelRequestResponse) { + return + } + ae.ResponseObject = obj + }) +} + +// LogRequestPatch fills in the given patch as the request object into an audit event. +func (ac *AuditContext) LogRequestPatch(patch []byte) { + ac.visitEvent(func(ae *auditinternal.Event) { + ae.RequestObject = &runtime.Unknown{ + Raw: patch, + ContentType: runtime.ContentTypeJSON, + } + }) +} + +func (ac *AuditContext) GetEventAnnotation(key string) (string, bool) { + var val string + var ok bool + ac.visitEvent(func(event *auditinternal.Event) { + val, ok = event.Annotations[key] + }) + return val, ok +} + +func (ac *AuditContext) GetEventLevel() auditinternal.Level { + var level auditinternal.Level + ac.visitEvent(func(event *auditinternal.Event) { + level = event.Level + }) + return level +} + +func (ac *AuditContext) SetEventLevel(level auditinternal.Level) { + ac.visitEvent(func(event *auditinternal.Event) { + event.Level = level + }) +} + +func (ac *AuditContext) SetEventStage(stage auditinternal.Stage) { + ac.visitEvent(func(event *auditinternal.Event) { + event.Stage = stage + }) +} + +func (ac *AuditContext) GetEventStage() auditinternal.Stage { + var stage auditinternal.Stage + ac.visitEvent(func(event *auditinternal.Event) { + stage = event.Stage + }) + return stage +} + +func (ac *AuditContext) SetEventStageTimestamp(timestamp metav1.MicroTime) { + ac.visitEvent(func(event *auditinternal.Event) { + event.StageTimestamp = timestamp + }) +} + +func (ac *AuditContext) GetEventResponseStatus() *metav1.Status { + var status *metav1.Status + ac.visitEvent(func(event *auditinternal.Event) { + status = event.ResponseStatus + }) + return status +} + +func (ac *AuditContext) GetEventRequestReceivedTimestamp() metav1.MicroTime { + var timestamp metav1.MicroTime + ac.visitEvent(func(event *auditinternal.Event) { + timestamp = event.RequestReceivedTimestamp + }) + return timestamp +} + +func (ac *AuditContext) GetEventStageTimestamp() metav1.MicroTime { + var timestamp metav1.MicroTime + ac.visitEvent(func(event *auditinternal.Event) { + timestamp = event.StageTimestamp + }) + return timestamp +} + +func (ac *AuditContext) SetEventResponseStatus(status *metav1.Status) { + ac.visitEvent(func(event *auditinternal.Event) { + event.ResponseStatus = status + }) +} + +func (ac *AuditContext) SetEventResponseStatusCode(statusCode int32) { + ac.visitEvent(func(event *auditinternal.Event) { + if event.ResponseStatus == nil { + event.ResponseStatus = &metav1.Status{} + } + event.ResponseStatus.Code = statusCode + }) +} + +func (ac *AuditContext) GetEventAnnotations() map[string]string { + var annotations map[string]string + ac.visitEvent(func(event *auditinternal.Event) { + annotations = maps.Clone(event.Annotations) + }) + return annotations } // AddAuditAnnotation sets the audit annotation for the given key, value pair. @@ -66,8 +284,8 @@ func AddAuditAnnotation(ctx context.Context, key, value string) { return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() addAuditAnnotationLocked(ac, key, value) } @@ -81,8 +299,8 @@ func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) { return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() if len(keysAndValues)%2 != 0 { klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1]) @@ -100,8 +318,8 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) return } - ac.annotationMutex.Lock() - defer ac.annotationMutex.Unlock() + ac.lock.Lock() + defer ac.lock.Unlock() for k, v := range annotations { addAuditAnnotationLocked(ac, k, v) @@ -110,8 +328,7 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string) // addAuditAnnotationLocked records the audit annotation on the event. func addAuditAnnotationLocked(ac *AuditContext, key, value string) { - ae := &ac.Event - + ae := &ac.event if ae.Annotations == nil { ae.Annotations = make(map[string]string) } @@ -128,15 +345,11 @@ func WithAuditContext(parent context.Context) context.Context { return parent // Avoid double registering. } - return genericapirequest.WithValue(parent, auditKey, &AuditContext{}) -} - -// AuditEventFrom returns the audit event struct on the ctx -func AuditEventFrom(ctx context.Context) *auditinternal.Event { - if ac := AuditContextFrom(ctx); ac.Enabled() { - return &ac.Event - } - return nil + return genericapirequest.WithValue(parent, auditKey, &AuditContext{ + event: auditinternal.Event{ + Stage: auditinternal.StageResponseStarted, + }, + }) } // AuditContextFrom returns the pair of the audit configuration object @@ -154,7 +367,10 @@ func WithAuditID(ctx context.Context, auditID types.UID) { return } if ac := AuditContextFrom(ctx); ac != nil { - ac.Event.AuditID = auditID + ac.visitEvent(func(event *auditinternal.Event) { + ac.auditID.Store(auditID) + event.AuditID = auditID + }) } } @@ -162,7 +378,8 @@ func WithAuditID(ctx context.Context, auditID types.UID) { // auditing is enabled. func AuditIDFrom(ctx context.Context) (types.UID, bool) { if ac := AuditContextFrom(ctx); ac != nil { - return ac.Event.AuditID, true + id, _ := ac.auditID.Load().(types.UID) + return id, true } return "", false } diff --git a/pkg/audit/context_test.go b/pkg/audit/context_test.go index 2bb3d39dd..9606d395c 100644 --- a/pkg/audit/context_test.go +++ b/pkg/audit/context_test.go @@ -40,16 +40,34 @@ func TestEnabled(t *testing.T) { ctx: &AuditContext{}, expectEnabled: true, // An AuditContext should be considered enabled before the level is set }, { - name: "level None", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelNone}}, + name: "level None", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelNone}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: false, }, { - name: "level Metadata", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelMetadata}}, + name: "level Metadata", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: true, }, { - name: "level RequestResponse", - ctx: &AuditContext{RequestAuditConfig: RequestAuditConfig{Level: auditinternal.LevelRequestResponse}}, + name: "level RequestResponse", + ctx: func() *AuditContext { + ctx := &AuditContext{} + if err := ctx.Init(RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, nil); err != nil { + t.Fatal(err) + } + return ctx + }(), expectEnabled: true, }} @@ -72,7 +90,7 @@ func TestAddAuditAnnotation(t *testing.T) { assert.Len(t, annotations, numAnnotations) } - ctxWithAnnotation := withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata) + ctxWithAnnotation := withAuditContextAndLevel(context.Background(), t, auditinternal.LevelMetadata) AddAuditAnnotation(ctxWithAnnotation, fmt.Sprintf(annotationKeyTemplate, 0), annotationExtraValue) tests := []struct { @@ -89,28 +107,28 @@ func TestAddAuditAnnotation(t *testing.T) { // Annotations should be retained. ctx: WithAuditContext(context.Background()), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) }, }, { description: "with metadata level", - ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelMetadata), + ctx: withAuditContextAndLevel(context.Background(), t, auditinternal.LevelMetadata), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) }, }, { description: "with none level", - ctx: withAuditContextAndLevel(context.Background(), auditinternal.LevelNone), + ctx: withAuditContextAndLevel(context.Background(), t, auditinternal.LevelNone), validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event assert.Empty(t, ev.Annotations) }, }, { description: "with overlapping annotations", ctx: ctxWithAnnotation, validator: func(t *testing.T, ctx context.Context) { - ev := AuditContextFrom(ctx).Event + ev := AuditContextFrom(ctx).event expectAnnotations(t, ev.Annotations) // Verify that the pre-existing annotation is not overwritten. assert.Equal(t, annotationExtraValue, ev.Annotations[fmt.Sprintf(annotationKeyTemplate, 0)]) @@ -144,8 +162,8 @@ func TestAuditAnnotationsWithAuditLoggingSetup(t *testing.T) { AddAuditAnnotation(ctx, "before-evaluation", "1") // policy evaluated, audit logging enabled - if ac := AuditContextFrom(ctx); ac != nil { - ac.RequestAuditConfig.Level = auditinternal.LevelMetadata + if err := AuditContextFrom(ctx).Init(RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) } AddAuditAnnotation(ctx, "after-evaluation", "2") @@ -153,13 +171,14 @@ func TestAuditAnnotationsWithAuditLoggingSetup(t *testing.T) { "before-evaluation": "1", "after-evaluation": "2", } - actual := AuditContextFrom(ctx).Event.Annotations + actual := AuditContextFrom(ctx).event.Annotations assert.Equal(t, expected, actual) } -func withAuditContextAndLevel(ctx context.Context, l auditinternal.Level) context.Context { +func withAuditContextAndLevel(ctx context.Context, t *testing.T, l auditinternal.Level) context.Context { ctx = WithAuditContext(ctx) - ac := AuditContextFrom(ctx) - ac.RequestAuditConfig.Level = l + if err := AuditContextFrom(ctx).Init(RequestAuditConfig{Level: l}, nil); err != nil { + t.Fatal(err) + } return ctx } diff --git a/pkg/audit/request.go b/pkg/audit/request.go index 9185278f0..d8662e63f 100644 --- a/pkg/audit/request.go +++ b/pkg/audit/request.go @@ -45,105 +45,69 @@ func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedT if !ac.Enabled() { return } - ev := &ac.Event - - ev.RequestReceivedTimestamp = metav1.NewMicroTime(requestReceivedTimestamp) - ev.Verb = attribs.GetVerb() - ev.RequestURI = req.URL.RequestURI() - ev.UserAgent = maybeTruncateUserAgent(req) - ev.Level = level - - ips := utilnet.SourceIPs(req) - ev.SourceIPs = make([]string, len(ips)) - for i := range ips { - ev.SourceIPs[i] = ips[i].String() - } - if user := attribs.GetUser(); user != nil { - ev.User.Username = user.GetName() - ev.User.Extra = map[string]authnv1.ExtraValue{} - for k, v := range user.GetExtra() { - ev.User.Extra[k] = authnv1.ExtraValue(v) + ac.visitEvent(func(ev *auditinternal.Event) { + ev.RequestReceivedTimestamp = metav1.NewMicroTime(requestReceivedTimestamp) + ev.Verb = attribs.GetVerb() + ev.RequestURI = req.URL.RequestURI() + ev.UserAgent = maybeTruncateUserAgent(req) + ev.Level = level + + ips := utilnet.SourceIPs(req) + ev.SourceIPs = make([]string, len(ips)) + for i := range ips { + ev.SourceIPs[i] = ips[i].String() } - ev.User.Groups = user.GetGroups() - ev.User.UID = user.GetUID() - } - if attribs.IsResourceRequest() { - ev.ObjectRef = &auditinternal.ObjectReference{ - Namespace: attribs.GetNamespace(), - Name: attribs.GetName(), - Resource: attribs.GetResource(), - Subresource: attribs.GetSubresource(), - APIGroup: attribs.GetAPIGroup(), - APIVersion: attribs.GetAPIVersion(), + if user := attribs.GetUser(); user != nil { + ev.User.Username = user.GetName() + ev.User.Extra = map[string]authnv1.ExtraValue{} + for k, v := range user.GetExtra() { + ev.User.Extra[k] = authnv1.ExtraValue(v) + } + ev.User.Groups = user.GetGroups() + ev.User.UID = user.GetUID() } - } + + if attribs.IsResourceRequest() { + ev.ObjectRef = &auditinternal.ObjectReference{ + Namespace: attribs.GetNamespace(), + Name: attribs.GetName(), + Resource: attribs.GetResource(), + Subresource: attribs.GetSubresource(), + APIGroup: attribs.GetAPIGroup(), + APIVersion: attribs.GetAPIVersion(), + } + } + }) } // LogImpersonatedUser fills in the impersonated user attributes into an audit event. -func LogImpersonatedUser(ae *auditinternal.Event, user user.Info) { - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { +func LogImpersonatedUser(ctx context.Context, user user.Info) { + ac := AuditContextFrom(ctx) + if !ac.Enabled() { return } - ae.ImpersonatedUser = &authnv1.UserInfo{ - Username: user.GetName(), - } - ae.ImpersonatedUser.Groups = user.GetGroups() - ae.ImpersonatedUser.UID = user.GetUID() - ae.ImpersonatedUser.Extra = map[string]authnv1.ExtraValue{} - for k, v := range user.GetExtra() { - ae.ImpersonatedUser.Extra[k] = authnv1.ExtraValue(v) - } + ac.LogImpersonatedUser(user) } // LogRequestObject fills in the request object into an audit event. The passed runtime.Object // will be converted to the given gv. func LogRequestObject(ctx context.Context, obj runtime.Object, objGV schema.GroupVersion, gvr schema.GroupVersionResource, subresource string, s runtime.NegotiatedSerializer) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { + ac := AuditContextFrom(ctx) + if !ac.Enabled() { return } - - // complete ObjectRef - if ae.ObjectRef == nil { - ae.ObjectRef = &auditinternal.ObjectReference{} - } - - // meta.Accessor is more general than ObjectMetaAccessor, but if it fails, we can just skip setting these bits - if meta, err := meta.Accessor(obj); err == nil { - if len(ae.ObjectRef.Namespace) == 0 { - ae.ObjectRef.Namespace = meta.GetNamespace() - } - if len(ae.ObjectRef.Name) == 0 { - ae.ObjectRef.Name = meta.GetName() - } - if len(ae.ObjectRef.UID) == 0 { - ae.ObjectRef.UID = meta.GetUID() - } - if len(ae.ObjectRef.ResourceVersion) == 0 { - ae.ObjectRef.ResourceVersion = meta.GetResourceVersion() - } - } - if len(ae.ObjectRef.APIVersion) == 0 { - ae.ObjectRef.APIGroup = gvr.Group - ae.ObjectRef.APIVersion = gvr.Version - } - if len(ae.ObjectRef.Resource) == 0 { - ae.ObjectRef.Resource = gvr.Resource - } - if len(ae.ObjectRef.Subresource) == 0 { - ae.ObjectRef.Subresource = subresource - } - - if ae.Level.Less(auditinternal.LevelRequest) { + if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { return } - if shouldOmitManagedFields(ctx) { + // meta.Accessor is more general than ObjectMetaAccessor, but if it fails, we can just skip setting these bits + objMeta, _ := meta.Accessor(obj) + if shouldOmitManagedFields(ac) { copy, ok, err := copyWithoutManagedFields(obj) if err != nil { - klog.ErrorS(err, "Error while dropping managed fields from the request", "auditID", ae.AuditID) + klog.ErrorS(err, "Error while dropping managed fields from the request", "auditID", ac.AuditID()) } if ok { obj = copy @@ -151,54 +115,73 @@ func LogRequestObject(ctx context.Context, obj runtime.Object, objGV schema.Grou } // TODO(audit): hook into the serializer to avoid double conversion - var err error - ae.RequestObject, err = encodeObject(obj, objGV, s) + requestObject, err := encodeObject(obj, objGV, s) if err != nil { // TODO(audit): add error slice to audit event struct - klog.ErrorS(err, "Encoding failed of request object", "auditID", ae.AuditID, "gvr", gvr.String(), "obj", obj) + klog.ErrorS(err, "Encoding failed of request object", "auditID", ac.AuditID(), "gvr", gvr.String(), "obj", obj) return } + + ac.visitEvent(func(ae *auditinternal.Event) { + if ae.ObjectRef == nil { + ae.ObjectRef = &auditinternal.ObjectReference{} + } + + if objMeta != nil { + if len(ae.ObjectRef.Namespace) == 0 { + ae.ObjectRef.Namespace = objMeta.GetNamespace() + } + if len(ae.ObjectRef.Name) == 0 { + ae.ObjectRef.Name = objMeta.GetName() + } + if len(ae.ObjectRef.UID) == 0 { + ae.ObjectRef.UID = objMeta.GetUID() + } + if len(ae.ObjectRef.ResourceVersion) == 0 { + ae.ObjectRef.ResourceVersion = objMeta.GetResourceVersion() + } + } + if len(ae.ObjectRef.APIVersion) == 0 { + ae.ObjectRef.APIGroup = gvr.Group + ae.ObjectRef.APIVersion = gvr.Version + } + if len(ae.ObjectRef.Resource) == 0 { + ae.ObjectRef.Resource = gvr.Resource + } + if len(ae.ObjectRef.Subresource) == 0 { + ae.ObjectRef.Subresource = subresource + } + + if ae.Level.Less(auditinternal.LevelRequest) { + return + } + ae.RequestObject = requestObject + }) } // LogRequestPatch fills in the given patch as the request object into an audit event. func LogRequestPatch(ctx context.Context, patch []byte) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelRequest) { + ac := AuditContextFrom(ctx) + if ac.GetEventLevel().Less(auditinternal.LevelRequest) { return } - - ae.RequestObject = &runtime.Unknown{ - Raw: patch, - ContentType: runtime.ContentTypeJSON, - } + ac.LogRequestPatch(patch) } // LogResponseObject fills in the response object into an audit event. The passed runtime.Object // will be converted to the given gv. func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupVersion, s runtime.NegotiatedSerializer) { - ae := AuditEventFrom(ctx) - if ae == nil || ae.Level.Less(auditinternal.LevelMetadata) { + ac := AuditContextFrom(WithAuditContext(ctx)) + if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { return } - if status, ok := obj.(*metav1.Status); ok { - // selectively copy the bounded fields. - ae.ResponseStatus = &metav1.Status{ - Status: status.Status, - Message: status.Message, - Reason: status.Reason, - Details: status.Details, - Code: status.Code, - } - } - if ae.Level.Less(auditinternal.LevelRequestResponse) { - return - } + status, _ := obj.(*metav1.Status) - if shouldOmitManagedFields(ctx) { + if shouldOmitManagedFields(ac) { copy, ok, err := copyWithoutManagedFields(obj) if err != nil { - klog.ErrorS(err, "Error while dropping managed fields from the response", "auditID", ae.AuditID) + klog.ErrorS(err, "Error while dropping managed fields from the response", "auditID", ac.AuditID()) } if ok { obj = copy @@ -207,10 +190,11 @@ func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupV // TODO(audit): hook into the serializer to avoid double conversion var err error - ae.ResponseObject, err = encodeObject(obj, gv, s) + responseObject, err := encodeObject(obj, gv, s) if err != nil { - klog.ErrorS(err, "Encoding failed of response object", "auditID", ae.AuditID, "obj", obj) + klog.ErrorS(err, "Encoding failed of response object", "auditID", ac.AuditID(), "obj", obj) } + ac.LogResponseObject(status, responseObject) } func encodeObject(obj runtime.Object, gv schema.GroupVersion, serializer runtime.NegotiatedSerializer) (*runtime.Unknown, error) { @@ -301,9 +285,9 @@ func removeManagedFields(obj runtime.Object) error { return nil } -func shouldOmitManagedFields(ctx context.Context) bool { - if auditContext := AuditContextFrom(ctx); auditContext != nil { - return auditContext.RequestAuditConfig.OmitManagedFields +func shouldOmitManagedFields(ac *AuditContext) bool { + if ac != nil && ac.initialized.Load() && ac.requestAuditConfig.OmitManagedFields { + return true } // If we can't decide, return false to maintain current behavior which is diff --git a/pkg/authentication/token/cache/cached_token_authenticator.go b/pkg/authentication/token/cache/cached_token_authenticator.go index 18167dddc..1b448e5d8 100644 --- a/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/pkg/authentication/token/cache/cached_token_authenticator.go @@ -201,10 +201,10 @@ func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, toke ac := audit.AuditContextFrom(ctx) // since this is shared work between multiple requests, we have no way of knowing if any // particular request supports audit annotations. thus we always attempt to record them. - ac.Event.Level = auditinternal.LevelMetadata + ac.SetEventLevel(auditinternal.LevelMetadata) record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token) - record.annotations = ac.Event.Annotations + record.annotations = ac.GetEventAnnotations() record.warnings = recorder.extractWarnings() if !a.cacheErrs && record.err != nil { diff --git a/pkg/authentication/token/cache/cached_token_authenticator_test.go b/pkg/authentication/token/cache/cached_token_authenticator_test.go index fab6381b2..c4902d808 100644 --- a/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -306,7 +306,7 @@ func TestCachedAuditAnnotations(t *testing.T) { ctx := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx, "token") - allAnnotations <- audit.AuditEventFrom(ctx).Annotations + allAnnotations <- audit.AuditContextFrom(ctx).GetEventAnnotations() }() } @@ -343,7 +343,7 @@ func TestCachedAuditAnnotations(t *testing.T) { for i := 0; i < cap(allAnnotations); i++ { ctx := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx, "token") - allAnnotations = append(allAnnotations, audit.AuditEventFrom(ctx).Annotations) + allAnnotations = append(allAnnotations, audit.AuditContextFrom(ctx).GetEventAnnotations()) } if len(allAnnotations) != cap(allAnnotations) { @@ -370,14 +370,14 @@ func TestCachedAuditAnnotations(t *testing.T) { ctx1 := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx1, "token1") - annotations1 := audit.AuditEventFrom(ctx1).Annotations + annotations1 := audit.AuditContextFrom(ctx1).GetEventAnnotations() // guarantee different now times time.Sleep(time.Second) ctx2 := withAudit(context.Background()) _, _, _ = a.AuthenticateToken(ctx2, "token2") - annotations2 := audit.AuditEventFrom(ctx2).Annotations + annotations2 := audit.AuditContextFrom(ctx2).GetEventAnnotations() if ok := len(annotations1) == 1 && len(annotations1["timestamp"]) > 0; !ok { t.Errorf("invalid annotations 1: %v", annotations1) @@ -547,7 +547,7 @@ func (s *singleBenchmark) bench(b *testing.B) { func withAudit(ctx context.Context) context.Context { ctx = audit.WithAuditContext(ctx) ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditinternal.LevelMetadata + ac.SetEventLevel(auditinternal.LevelMetadata) return ctx } diff --git a/pkg/endpoints/filters/audit.go b/pkg/endpoints/filters/audit.go index 6f850f728..5f992fd9c 100644 --- a/pkg/endpoints/filters/audit.go +++ b/pkg/endpoints/filters/audit.go @@ -44,7 +44,7 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva return handler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -55,41 +55,37 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva handler.ServeHTTP(w, req) return } - ev := &ac.Event ctx := req.Context() - omitStages := ac.RequestAuditConfig.OmitStages - ev.Stage = auditinternal.StageRequestReceived - if processed := processAuditEvent(ctx, sink, ev, omitStages); !processed { + if processed := ac.ProcessEventStage(ctx, auditinternal.StageRequestReceived); !processed { audit.ApiserverAuditDroppedCounter.WithContext(ctx).Inc() responsewriters.InternalError(w, req, errors.New("failed to store audit event")) return } // intercept the status code - var longRunningSink audit.Sink + isLongRunning := false if longRunningCheck != nil { ri, _ := request.RequestInfoFrom(ctx) if longRunningCheck(req, ri) { - longRunningSink = sink + isLongRunning = true } } - respWriter := decorateResponseWriter(ctx, w, ev, longRunningSink, omitStages) + respWriter := decorateResponseWriter(ctx, w, isLongRunning) // send audit event when we leave this func, either via a panic or cleanly. In the case of long // running requests, this will be the second audit event. defer func() { if r := recover(); r != nil { defer panic(r) - ev.Stage = auditinternal.StagePanic - ev.ResponseStatus = &metav1.Status{ + ac.SetEventResponseStatus(&metav1.Status{ Code: http.StatusInternalServerError, Status: metav1.StatusFailure, Reason: metav1.StatusReasonInternalError, Message: fmt.Sprintf("APIServer panic'd: %v", r), - } - processAuditEvent(ctx, sink, ev, omitStages) + }) + ac.ProcessEventStage(ctx, auditinternal.StagePanic) return } @@ -100,27 +96,25 @@ func WithAudit(handler http.Handler, sink audit.Sink, policy audit.PolicyRuleEva Status: metav1.StatusSuccess, Message: "Connection closed early", } - if ev.ResponseStatus == nil && longRunningSink != nil { - ev.ResponseStatus = fakedSuccessStatus - ev.Stage = auditinternal.StageResponseStarted - processAuditEvent(ctx, longRunningSink, ev, omitStages) - } - - ev.Stage = auditinternal.StageResponseComplete - if ev.ResponseStatus == nil { - ev.ResponseStatus = fakedSuccessStatus + if ac.GetEventResponseStatus() == nil { + ac.SetEventResponseStatus(fakedSuccessStatus) + if isLongRunning { + // A nil ResponseStatus means the writer never processed the ResponseStarted stage, so do that now. + ac.ProcessEventStage(ctx, auditinternal.StageResponseStarted) + } } - processAuditEvent(ctx, sink, ev, omitStages) + writeLatencyToAnnotation(ctx) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) }() handler.ServeHTTP(respWriter, req) }) } // evaluatePolicyAndCreateAuditEvent is responsible for evaluating the audit -// policy configuration applicable to the request and create a new audit -// event that will be written to the API audit log. +// policy configuration applicable to the request and initializing the audit +// context with the audit config for the request, the sink to write to, and the request metadata. // - error if anything bad happened -func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator) (*audit.AuditContext, error) { +func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRuleEvaluator, sink audit.Sink) (*audit.AuditContext, error) { ctx := req.Context() ac := audit.AuditContextFrom(ctx) if ac == nil { @@ -135,7 +129,10 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul rac := policy.EvaluatePolicyRule(attribs) audit.ObservePolicyLevel(ctx, rac.Level) - ac.RequestAuditConfig = rac + err = ac.Init(rac, sink) + if err != nil { + return nil, fmt.Errorf("failed to initialize audit context: %w", err) + } if rac.Level == auditinternal.LevelNone { // Don't audit. return ac, nil @@ -153,13 +150,14 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul // writeLatencyToAnnotation writes the latency incurred in different // layers of the apiserver to the annotations of the audit object. // it should be invoked after ev.StageTimestamp has been set appropriately. -func writeLatencyToAnnotation(ctx context.Context, ev *auditinternal.Event) { +func writeLatencyToAnnotation(ctx context.Context) { + ac := audit.AuditContextFrom(ctx) // we will track latency in annotation only when the total latency // of the given request exceeds 500ms, this is in keeping with the // traces in rest/handlers for create, delete, update, // get, list, and deletecollection. const threshold = 500 * time.Millisecond - latency := ev.StageTimestamp.Time.Sub(ev.RequestReceivedTimestamp.Time) + latency := ac.GetEventStageTimestamp().Sub(ac.GetEventRequestReceivedTimestamp().Time) if latency <= threshold { return } @@ -177,34 +175,12 @@ func writeLatencyToAnnotation(ctx context.Context, ev *auditinternal.Event) { audit.AddAuditAnnotationsMap(ctx, layerLatencies) } -func processAuditEvent(ctx context.Context, sink audit.Sink, ev *auditinternal.Event, omitStages []auditinternal.Stage) bool { - for _, stage := range omitStages { - if ev.Stage == stage { - return true - } - } - - switch { - case ev.Stage == auditinternal.StageRequestReceived: - ev.StageTimestamp = metav1.NewMicroTime(ev.RequestReceivedTimestamp.Time) - case ev.Stage == auditinternal.StageResponseComplete: - ev.StageTimestamp = metav1.NewMicroTime(time.Now()) - writeLatencyToAnnotation(ctx, ev) - default: - ev.StageTimestamp = metav1.NewMicroTime(time.Now()) - } - - audit.ObserveEvent(ctx) - return sink.ProcessEvents(ev) -} - -func decorateResponseWriter(ctx context.Context, responseWriter http.ResponseWriter, ev *auditinternal.Event, sink audit.Sink, omitStages []auditinternal.Stage) http.ResponseWriter { +func decorateResponseWriter(ctx context.Context, responseWriter http.ResponseWriter, processResponseStartedStage bool) http.ResponseWriter { delegate := &auditResponseWriter{ ctx: ctx, ResponseWriter: responseWriter, - event: ev, - sink: sink, - omitStages: omitStages, + + processResponseStartedStage: processResponseStartedStage, } return responsewriter.WrapForHTTP1Or2(delegate) @@ -217,11 +193,10 @@ var _ responsewriter.UserProvidedDecorator = &auditResponseWriter{} // create immediately an event (for long running requests). type auditResponseWriter struct { http.ResponseWriter - ctx context.Context - event *auditinternal.Event - once sync.Once - sink audit.Sink - omitStages []auditinternal.Stage + ctx context.Context + once sync.Once + + processResponseStartedStage bool } func (a *auditResponseWriter) Unwrap() http.ResponseWriter { @@ -230,14 +205,10 @@ func (a *auditResponseWriter) Unwrap() http.ResponseWriter { func (a *auditResponseWriter) processCode(code int) { a.once.Do(func() { - if a.event.ResponseStatus == nil { - a.event.ResponseStatus = &metav1.Status{} - } - a.event.ResponseStatus.Code = int32(code) - a.event.Stage = auditinternal.StageResponseStarted - - if a.sink != nil { - processAuditEvent(a.ctx, a.sink, a.event, a.omitStages) + ac := audit.AuditContextFrom(a.ctx) + ac.SetEventResponseStatusCode(int32(code)) + if a.processResponseStartedStage { + ac.ProcessEventStage(a.ctx, auditinternal.StageResponseStarted) } }) } diff --git a/pkg/endpoints/filters/audit_test.go b/pkg/endpoints/filters/audit_test.go index e9c375a9b..cffed607b 100644 --- a/pkg/endpoints/filters/audit_test.go +++ b/pkg/endpoints/filters/audit_test.go @@ -18,24 +18,37 @@ package filters import ( "context" + "math/rand" "net/http" "net/http/httptest" + "net/url" "reflect" "sync" "testing" "time" + "unsafe" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" auditinternal "k8s.io/apiserver/pkg/apis/audit" + auditv1 "k8s.io/apiserver/pkg/apis/audit/v1" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/audit/policy" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/endpoints/responsewriter" + "k8s.io/apiserver/plugin/pkg/audit/buffered" + "k8s.io/apiserver/plugin/pkg/audit/log" + "k8s.io/apiserver/plugin/pkg/audit/webhook" + "k8s.io/client-go/rest" + "k8s.io/client-go/util/flowcontrol" + "k8s.io/client-go/util/retry" ) type fakeAuditSink struct { @@ -76,7 +89,7 @@ func (s *fakeAuditSink) Pop(timeout time.Duration) (*auditinternal.Event, error) func TestConstructResponseWriter(t *testing.T) { inner := &responsewriter.FakeResponseWriter{} - actual := decorateResponseWriter(context.Background(), inner, nil, nil, nil) + actual := decorateResponseWriter(context.Background(), inner, false) switch v := actual.(type) { case *auditResponseWriter: default: @@ -86,7 +99,7 @@ func TestConstructResponseWriter(t *testing.T) { t.Errorf("Expected the decorator to return the inner http.ResponseWriter object") } - actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifier{}, nil, nil, nil) + actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifier{}, false) //lint:file-ignore SA1019 Keep supporting deprecated http.CloseNotifier if _, ok := actual.(http.CloseNotifier); !ok { t.Errorf("Expected http.ResponseWriter to implement http.CloseNotifier") @@ -98,7 +111,7 @@ func TestConstructResponseWriter(t *testing.T) { t.Errorf("Expected http.ResponseWriter not to implement http.Hijacker") } - actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifierHijacker{}, nil, nil, nil) + actual = decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriterFlusherCloseNotifierHijacker{}, false) //lint:file-ignore SA1019 Keep supporting deprecated http.CloseNotifier if _, ok := actual.(http.CloseNotifier); !ok { t.Errorf("Expected http.ResponseWriter to implement http.CloseNotifier") @@ -112,37 +125,43 @@ func TestConstructResponseWriter(t *testing.T) { } func TestDecorateResponseWriterWithoutChannel(t *testing.T) { - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, nil, nil) + ctx := audit.WithAuditContext(context.Background()) + ac := audit.AuditContextFrom(ctx) + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, false) // write status. This will not block because firstEventSentCh is nil actual.WriteHeader(42) - if ev.ResponseStatus == nil { + if ac.GetEventResponseStatus() == nil { t.Fatalf("Expected ResponseStatus to be non-nil") } - if ev.ResponseStatus.Code != 42 { - t.Errorf("expected status code 42, got %d", ev.ResponseStatus.Code) + if ac.GetEventResponseStatus().Code != 42 { + t.Errorf("expected status code 42, got %d", ac.GetEventResponseStatus().Code) } } func TestDecorateResponseWriterWithImplicitWrite(t *testing.T) { - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, nil, nil) + ctx := audit.WithAuditContext(context.Background()) + ac := audit.AuditContextFrom(ctx) + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, false) // write status. This will not block because firstEventSentCh is nil actual.Write([]byte("foo")) - if ev.ResponseStatus == nil { + if ac.GetEventResponseStatus() == nil { t.Fatalf("Expected ResponseStatus to be non-nil") } - if ev.ResponseStatus.Code != 200 { - t.Errorf("expected status code 200, got %d", ev.ResponseStatus.Code) + if ac.GetEventResponseStatus().Code != 200 { + t.Errorf("expected status code 200, got %d", ac.GetEventResponseStatus().Code) } } func TestDecorateResponseWriterChannel(t *testing.T) { + ctx := audit.WithAuditContext(context.Background()) sink := &fakeAuditSink{} - ev := &auditinternal.Event{} - actual := decorateResponseWriter(context.Background(), &responsewriter.FakeResponseWriter{}, ev, sink, nil) + auditContext := audit.AuditContextFrom(ctx) + if err := auditContext.Init(audit.RequestAuditConfig{}, sink); err != nil { + t.Fatal(err) + } + actual := decorateResponseWriter(ctx, &responsewriter.FakeResponseWriter{}, true) done := make(chan struct{}) go func() { @@ -164,8 +183,11 @@ func TestDecorateResponseWriterChannel(t *testing.T) { } t.Logf("Seen event with status %v", ev1.ResponseStatus) - if !reflect.DeepEqual(ev, ev1) { - t.Fatalf("ev1 and ev must be equal") + ev := getAuditContextEvent(auditContext) + if diff := cmp.Diff(ev, ev1, cmp.FilterPath(func(p cmp.Path) bool { + return p.String() == "StageTimestamp" + }, cmp.Ignore())); diff != "" { + t.Fatalf("ev1 and ev must be equal, diff: %s", diff) } <-done @@ -178,6 +200,20 @@ func TestDecorateResponseWriterChannel(t *testing.T) { } } +func getAuditContextEvent(ac *audit.AuditContext) *auditinternal.Event { + // Get the reflect.Value of the AuditContext + val := reflect.ValueOf(ac).Elem() + + // Access the unexported `event` field + eventField := val.FieldByName("event") + + // Use unsafe to get a pointer to the field + eventPtr := unsafe.Pointer(eventField.UnsafeAddr()) + + // Cast the pointer to the correct type + return (*auditinternal.Event)(eventPtr) +} + type fakeHTTPHandler struct{} func (*fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -848,11 +884,130 @@ func withTestContext(req *http.Request, user user.Info, ae *auditinternal.Event) ctx = request.WithUser(ctx, user) } if ae != nil { - ac := audit.AuditContextFrom(ctx) - ac.Event = *ae + ev := getAuditContextEvent(audit.AuditContextFrom(ctx)) + *ev = *ae } if info, err := newTestRequestInfoResolver().NewRequestInfo(req); err == nil { ctx = request.WithRequestInfo(ctx, info) } return req.WithContext(ctx) } + +type fakeAuditFile struct{} + +func (s fakeAuditFile) Write(p []byte) (n int, err error) { + time.Sleep(time.Duration(rand.Int63n(10000))) + return len(p), nil +} + +type fakeAuditWebhookAuditBackend struct { +} + +func (f fakeAuditWebhookAuditBackend) RoundTrip(r *http.Request) (*http.Response, error) { + time.Sleep(time.Duration(rand.Int63n(10000))) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil +} + +// Test case for https://github.com/kubernetes/kubernetes/issues/120507 +// to test for race conditions in audit backends use the following command: +// `go test ./ -race --run=TestAuditBackendRaceCondition -v` +func TestAuditBackendRaceCondition(t *testing.T) { + defaultFakeLogBackend := log.NewBackend(fakeAuditFile{}, log.FormatJson, auditv1.SchemeGroupVersion) + testCases := []struct { + name string + backendBuilder func() audit.Backend + }{ + { + "log audit backend", + func() audit.Backend { + return defaultFakeLogBackend + }, + }, + { + "buffered audit backend", + func() audit.Backend { + backend := buffered.NewBackend(defaultFakeLogBackend, buffered.BatchConfig{ + BufferSize: 10000, + MaxBatchSize: 1, + ThrottleEnable: false, + AsyncDelegate: false, + }) + err := backend.Run(wait.NeverStop) + if err != nil { + t.Fatal(err) + } + return backend + }, + }, + { + name: "webhook audit backend", + backendBuilder: func() audit.Backend { + codecFactory := audit.Codecs + codec := codecFactory.LegacyCodec(auditv1.SchemeGroupVersion) + negotiatedSerializer := serializer.NegotiatedSerializerWrapper(runtime.SerializerInfo{Serializer: codec}) + client, err := rest.NewRESTClient(&url.URL{}, "/hello", rest.ClientContentConfig{ + ContentType: "application/json", + Negotiator: runtime.NewClientNegotiator(negotiatedSerializer, auditv1.SchemeGroupVersion), + }, flowcontrol.NewTokenBucketRateLimiter(100, 200), &http.Client{Transport: fakeAuditWebhookAuditBackend{}}) + if err != nil { + t.Fatal(err) + } + return webhook.NewDynamicBackend(client, retry.DefaultBackoff) + }, + }, + { + "union audit backend", + func() audit.Backend { + return audit.Union(defaultFakeLogBackend, defaultFakeLogBackend) + }, + }, + } + fakeRuleEvaluator := policy.NewFakePolicyRuleEvaluator(auditinternal.LevelRequestResponse, nil) + longRunningCheck := func(r *http.Request, ri *request.RequestInfo) bool { return false } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), wait.ForeverTestTimeout) + defer cancel() + for { + select { + case <-ctx.Done(): + // finished the test + return + default: + } + serveStarted := make(chan struct{}) + req, _ := http.NewRequest(http.MethodGet, "/api/v1/namespaces/default/pods/foo", nil) + req = withTestContext(req, &user.DefaultInfo{Name: "admin"}, nil) + backend := tc.backendBuilder() + go func() { + <-serveStarted + for { + select { + case <-ctx.Done(): + // finished the test + backend.Shutdown() + return + default: + } + audit.AddAuditAnnotations(req.Context(), "a", "b") + } + }() + realHandler := http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + close(serveStarted) + // mock some business logic + time.Sleep(time.Millisecond) + }) + handler := WithAudit(realHandler, backend, fakeRuleEvaluator, longRunningCheck) + handler = WithAuditInit(handler) + serveFinished := make(chan struct{}) + go func() { + defer close(serveFinished) + handler.ServeHTTP(httptest.NewRecorder(), req) + }() + <-serveFinished + } + }) + } +} diff --git a/pkg/endpoints/filters/authn_audit.go b/pkg/endpoints/filters/authn_audit.go index 4bd6bbc13..d9cdcd2d6 100644 --- a/pkg/endpoints/filters/authn_audit.go +++ b/pkg/endpoints/filters/authn_audit.go @@ -24,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" ) @@ -36,7 +35,7 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink, return failedHandler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -47,13 +46,11 @@ func WithFailedAuthenticationAudit(failedHandler http.Handler, sink audit.Sink, failedHandler.ServeHTTP(w, req) return } - ev := &ac.Event - ev.ResponseStatus = &metav1.Status{} - ev.ResponseStatus.Message = getAuthMethods(req) - ev.Stage = auditinternal.StageResponseStarted - - rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages) + ac.SetEventResponseStatus(&metav1.Status{ + Message: getAuthMethods(req), + }) + rw := decorateResponseWriter(req.Context(), w, true) failedHandler.ServeHTTP(rw, req) }) } diff --git a/pkg/endpoints/filters/authorization_test.go b/pkg/endpoints/filters/authorization_test.go index deef9054b..b2bd49fd6 100644 --- a/pkg/endpoints/filters/authorization_test.go +++ b/pkg/endpoints/filters/authorization_test.go @@ -286,11 +286,21 @@ func TestAuditAnnotation(t *testing.T) { req, _ := http.NewRequest("GET", "/api/v1/namespaces/default/pods", nil) req = withTestContext(req, nil, &auditinternal.Event{Level: auditinternal.LevelMetadata}) - ae := audit.AuditEventFrom(req.Context()) + ae := audit.AuditContextFrom(req.Context()) req.RemoteAddr = "127.0.0.1" handler.ServeHTTP(httptest.NewRecorder(), req) - assert.Equal(t, tc.decisionAnnotation, ae.Annotations[decisionAnnotationKey], k+": unexpected decision annotation") - assert.Equal(t, tc.reasonAnnotation, ae.Annotations[reasonAnnotationKey], k+": unexpected reason annotation") + + var annotation string + var ok bool + if len(tc.decisionAnnotation) > 0 { + annotation, ok = ae.GetEventAnnotation(decisionAnnotationKey) + assert.True(t, ok, k+": decision annotation not found") + assert.Equal(t, tc.decisionAnnotation, annotation, k+": unexpected decision annotation") + } + + annotation, ok = ae.GetEventAnnotation(reasonAnnotationKey) + assert.True(t, ok, k+": reason annotation not found") + assert.Equal(t, tc.reasonAnnotation, annotation, k+": unexpected reason annotation") } } diff --git a/pkg/endpoints/filters/impersonation.go b/pkg/endpoints/filters/impersonation.go index a6d293a15..aa47a7536 100644 --- a/pkg/endpoints/filters/impersonation.go +++ b/pkg/endpoints/filters/impersonation.go @@ -166,8 +166,7 @@ func WithImpersonation(handler http.Handler, a authorizer.Authorizer, s runtime. oldUser, _ := request.UserFrom(ctx) httplog.LogOf(req, w).Addf("%v is impersonating %v", userString(oldUser), userString(newUser)) - ae := audit.AuditEventFrom(ctx) - audit.LogImpersonatedUser(ae, newUser) + audit.LogImpersonatedUser(audit.WithAuditContext(ctx), newUser) // clear all the impersonation headers from the request req.Header.Del(authenticationv1.ImpersonateUserHeader) diff --git a/pkg/endpoints/filters/request_deadline.go b/pkg/endpoints/filters/request_deadline.go index 7497bc38a..066d670a2 100644 --- a/pkg/endpoints/filters/request_deadline.go +++ b/pkg/endpoints/filters/request_deadline.go @@ -108,7 +108,7 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta return failedHandler } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ac, err := evaluatePolicyAndCreateAuditEvent(req, policy) + ac, err := evaluatePolicyAndCreateAuditEvent(req, policy, sink) if err != nil { utilruntime.HandleError(fmt.Errorf("failed to create audit event: %v", err)) responsewriters.InternalError(w, req, errors.New("failed to create audit event")) @@ -119,15 +119,15 @@ func withFailedRequestAudit(failedHandler http.Handler, statusErr *apierrors.Sta failedHandler.ServeHTTP(w, req) return } - ev := &ac.Event - ev.ResponseStatus = &metav1.Status{} - ev.Stage = auditinternal.StageResponseStarted + respStatus := &metav1.Status{} if statusErr != nil { - ev.ResponseStatus.Message = statusErr.Error() + respStatus.Message = statusErr.Error() } + ac.SetEventResponseStatus(respStatus) + ac.SetEventStage(auditinternal.StageResponseStarted) - rw := decorateResponseWriter(req.Context(), w, ev, sink, ac.RequestAuditConfig.OmitStages) + rw := decorateResponseWriter(req.Context(), w, true) failedHandler.ServeHTTP(rw, req) }) } diff --git a/pkg/endpoints/filters/request_deadline_test.go b/pkg/endpoints/filters/request_deadline_test.go index 6cc1b3c38..6216429f8 100644 --- a/pkg/endpoints/filters/request_deadline_test.go +++ b/pkg/endpoints/filters/request_deadline_test.go @@ -22,7 +22,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" "time" @@ -408,21 +407,21 @@ func TestWithFailedRequestAudit(t *testing.T) { t.Errorf("expected an http.ResponseWriter of type: %T but got: %T", &auditResponseWriter{}, rwGot) } - auditEventGot := audit.AuditEventFrom(requestGot.Context()) - if auditEventGot == nil { + auditContext := audit.AuditContextFrom(requestGot.Context()) + if auditContext == nil { t.Fatal("expected an audit event object but got nil") } - if auditEventGot.Stage != auditinternal.StageResponseStarted { - t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditEventGot.Stage) + if auditContext.GetEventStage() != auditinternal.StageResponseStarted { + t.Errorf("expected audit event Stage: %s, but got: %s", auditinternal.StageResponseStarted, auditContext.GetEventStage()) } - if auditEventGot.ResponseStatus == nil { + if auditContext.GetEventResponseStatus() == nil { t.Fatal("expected a ResponseStatus field of the audit event object, but got nil") } - if test.statusCodeExpected != int(auditEventGot.ResponseStatus.Code) { - t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditEventGot.ResponseStatus.Code) + if test.statusCodeExpected != int(auditContext.GetEventResponseStatus().Code) { + t.Errorf("expected audit event ResponseStatus.Code: %d, but got: %d", test.statusCodeExpected, auditContext.GetEventResponseStatus().Code) } - if test.statusErr.Error() != auditEventGot.ResponseStatus.Message { - t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditEventGot.ResponseStatus.Message) + if test.statusErr.Error() != auditContext.GetEventResponseStatus().Message { + t.Errorf("expected audit event ResponseStatus.Message: %s, but got: %s", test.statusErr, auditContext.GetEventResponseStatus().Message) } // verify that the audit event from the request context is written to the audit sink. @@ -430,8 +429,12 @@ func TestWithFailedRequestAudit(t *testing.T) { t.Fatalf("expected audit sink to have 1 event, but got: %d", len(fakeSink.events)) } auditEventFromSink := fakeSink.events[0] - if !reflect.DeepEqual(auditEventGot, auditEventFromSink) { - t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", cmp.Diff(auditEventGot, auditEventFromSink)) + eventFromAuditContext := getAuditContextEvent(auditContext) + + if diff := cmp.Diff(eventFromAuditContext, auditEventFromSink, cmp.FilterPath(func(p cmp.Path) bool { + return p.String() == "StageTimestamp" + }, cmp.Ignore())); diff != "" { + t.Errorf("expected the audit event from the request context to be written to the audit sink, but got diffs: %s", diff) } } }) diff --git a/pkg/endpoints/handlers/delete_test.go b/pkg/endpoints/handlers/delete_test.go index 90cb2f9f4..9f38420c3 100644 --- a/pkg/endpoints/handlers/delete_test.go +++ b/pkg/endpoints/handlers/delete_test.go @@ -74,7 +74,7 @@ func TestDeleteResourceAuditLogRequestObject(t *testing.T) { ctx := audit.WithAuditContext(context.TODO()) ac := audit.AuditContextFrom(ctx) - ac.Event.Level = auditapis.LevelRequestResponse + ac.SetEventLevel(auditapis.LevelRequestResponse) policy := metav1.DeletePropagationBackground deleteOption := &metav1.DeleteOptions{ diff --git a/pkg/server/config_test.go b/pkg/server/config_test.go index a4c2b7359..577b99fc2 100644 --- a/pkg/server/config_test.go +++ b/pkg/server/config_test.go @@ -351,8 +351,8 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } // confirm that we have an audit event - ae := audit.AuditEventFrom(r.Context()) - if ae == nil { + ac := audit.AuditContextFrom(r.Context()) + if ac == nil { t.Error("unexpected nil audit event") } @@ -376,11 +376,15 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } // these should all be the same because the handler chain mutates the event in place want := map[string]string{"pandas": "are awesome", "dogs": "are okay"} + foundResponseComplete := false for _, event := range backend.events { + if event.Stage == auditinternal.StageRequestReceived { + continue + } if event.Stage != auditinternal.StageResponseComplete { t.Errorf("expected event stage to be complete, got: %s", event.Stage) } - + foundResponseComplete = true for wantK, wantV := range want { gotV, ok := event.Annotations[wantK] if !ok { @@ -392,6 +396,9 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { } } } + if !foundResponseComplete { + t.Errorf("expected to find %s in events", auditinternal.StageResponseComplete) + } } type testBackend struct { diff --git a/pkg/util/x509metrics/server_cert_deprecations_test.go b/pkg/util/x509metrics/server_cert_deprecations_test.go index dfdd565b2..d9d16c4c8 100644 --- a/pkg/util/x509metrics/server_cert_deprecations_test.go +++ b/pkg/util/x509metrics/server_cert_deprecations_test.go @@ -247,15 +247,15 @@ func TestCheckForHostnameError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.Event.Level = auditapi.LevelMetadata + auditCtx.SetEventLevel(auditapi.LevelMetadata) _, err = client.Transport.RoundTrip(req) if sanChecker.CheckRoundTripError(err) { sanChecker.IncreaseMetricsCounter(req) - - if len(auditCtx.Event.Annotations["missing-san.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { - t.Errorf("expected audit annotations, got %#v", auditCtx.Event.Annotations) + annotations := auditCtx.GetEventAnnotations() + if len(annotations["missing-san.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { + t.Errorf("expected audit annotations, got %#v", annotations) } } @@ -390,7 +390,7 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.Event.Level = auditapi.LevelMetadata + auditCtx.SetEventLevel(auditapi.LevelMetadata) // can't use tlsServer.Client() as it contains the server certificate // in tls.Config.Certificates. The signatures are, however, only checked @@ -414,9 +414,9 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) { if sha1checker.CheckRoundTripError(err) { sha1checker.IncreaseMetricsCounter(req) - - if len(auditCtx.Event.Annotations["insecure-sha1.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { - t.Errorf("expected audit annotations, got %#v", auditCtx.Event.Annotations) + annotations := auditCtx.GetEventAnnotations() + if len(annotations["insecure-sha1.invalid-cert.kubernetes.io/"+req.URL.Hostname()]) == 0 { + t.Errorf("expected audit annotations, got %#v", annotations) } } From fb6914a125364bf7f0e75de18095f8ba85ab1147 Mon Sep 17 00:00:00 2001 From: Davanum Srinivas Date: Fri, 9 May 2025 06:57:31 -0400 Subject: [PATCH 2/4] UPSTREAM: 131694: Eliminate AuditContext`s SetEventLevel Signed-off-by: Davanum Srinivas Co-Authored-By: Jordan Liggitt Set event level during context init Signed-off-by: Davanum Srinivas Kubernetes-commit: 960a4939f2502f2a8f2b923203e9075354e4bdc0 --- pkg/admission/audit_test.go | 7 +++--- pkg/audit/context.go | 23 ++++++------------- pkg/audit/request.go | 3 +-- .../token/cache/cached_token_authenticator.go | 4 ---- .../cache/cached_token_authenticator_test.go | 3 --- pkg/endpoints/filters/audit.go | 2 +- pkg/endpoints/handlers/delete_test.go | 6 +++-- .../server_cert_deprecations_test.go | 3 --- 8 files changed, 17 insertions(+), 34 deletions(-) diff --git a/pkg/admission/audit_test.go b/pkg/admission/audit_test.go index dde433d79..8b291244b 100644 --- a/pkg/admission/audit_test.go +++ b/pkg/admission/audit_test.go @@ -144,7 +144,10 @@ func TestWithAudit(t *testing.T) { var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles} ctx := audit.WithAuditContext(context.Background()) ac := audit.AuditContextFrom(ctx) - ac.SetEventLevel(auditinternal.LevelMetadata) + if err := ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil { + t.Fatal(err) + } + auditHandler := WithAudit(handler) a := attributes() @@ -186,8 +189,6 @@ func TestWithAuditConcurrency(t *testing.T) { } var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true} ctx := audit.WithAuditContext(context.Background()) - ac := audit.AuditContextFrom(ctx) - ac.SetEventLevel(auditinternal.LevelMetadata) auditHandler := WithAudit(handler) a := attributes() diff --git a/pkg/audit/context.go b/pkg/audit/context.go index 538b3d956..5b93d594b 100644 --- a/pkg/audit/context.go +++ b/pkg/audit/context.go @@ -46,8 +46,6 @@ type AuditContext struct { // initialized indicates whether requestAuditConfig and sink have been populated and are safe to read unguarded. // This should only be set via Init(). initialized atomic.Bool - // initialize wraps setting requestAuditConfig and sink, and is only called via Init(). - initialize sync.Once // requestAuditConfig is the audit configuration that applies to the request. // This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true. requestAuditConfig RequestAuditConfig @@ -81,16 +79,15 @@ func (ac *AuditContext) Enabled() bool { } func (ac *AuditContext) Init(requestAuditConfig RequestAuditConfig, sink Sink) error { - initialized := false - ac.initialize.Do(func() { - ac.requestAuditConfig = requestAuditConfig - ac.sink = sink - ac.initialized.Store(true) - initialized = true - }) - if !initialized { + ac.lock.Lock() + defer ac.lock.Unlock() + if ac.initialized.Load() { return errors.New("audit context was already initialized") } + ac.requestAuditConfig = requestAuditConfig + ac.sink = sink + ac.event.Level = requestAuditConfig.Level + ac.initialized.Store(true) return nil } @@ -198,12 +195,6 @@ func (ac *AuditContext) GetEventLevel() auditinternal.Level { return level } -func (ac *AuditContext) SetEventLevel(level auditinternal.Level) { - ac.visitEvent(func(event *auditinternal.Event) { - event.Level = level - }) -} - func (ac *AuditContext) SetEventStage(stage auditinternal.Stage) { ac.visitEvent(func(event *auditinternal.Event) { event.Stage = stage diff --git a/pkg/audit/request.go b/pkg/audit/request.go index d8662e63f..60b69b0b2 100644 --- a/pkg/audit/request.go +++ b/pkg/audit/request.go @@ -40,7 +40,7 @@ const ( userAgentTruncateSuffix = "...TRUNCATED" ) -func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedTimestamp time.Time, level auditinternal.Level, attribs authorizer.Attributes) { +func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedTimestamp time.Time, attribs authorizer.Attributes) { ac := AuditContextFrom(ctx) if !ac.Enabled() { return @@ -51,7 +51,6 @@ func LogRequestMetadata(ctx context.Context, req *http.Request, requestReceivedT ev.Verb = attribs.GetVerb() ev.RequestURI = req.URL.RequestURI() ev.UserAgent = maybeTruncateUserAgent(req) - ev.Level = level ips := utilnet.SourceIPs(req) ev.SourceIPs = make([]string, len(ips)) diff --git a/pkg/authentication/token/cache/cached_token_authenticator.go b/pkg/authentication/token/cache/cached_token_authenticator.go index 1b448e5d8..9d1556e63 100644 --- a/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/pkg/authentication/token/cache/cached_token_authenticator.go @@ -33,7 +33,6 @@ import ( "golang.org/x/sync/singleflight" apierrors "k8s.io/apimachinery/pkg/api/errors" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/warning" @@ -199,9 +198,6 @@ func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, toke ctx = audit.WithAuditContext(ctx) ac := audit.AuditContextFrom(ctx) - // since this is shared work between multiple requests, we have no way of knowing if any - // particular request supports audit annotations. thus we always attempt to record them. - ac.SetEventLevel(auditinternal.LevelMetadata) record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token) record.annotations = ac.GetEventAnnotations() diff --git a/pkg/authentication/token/cache/cached_token_authenticator_test.go b/pkg/authentication/token/cache/cached_token_authenticator_test.go index c4902d808..7913575ce 100644 --- a/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -35,7 +35,6 @@ import ( utilrand "k8s.io/apimachinery/pkg/util/rand" "k8s.io/apimachinery/pkg/util/uuid" - auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" @@ -546,8 +545,6 @@ func (s *singleBenchmark) bench(b *testing.B) { // extraction. func withAudit(ctx context.Context) context.Context { ctx = audit.WithAuditContext(ctx) - ac := audit.AuditContextFrom(ctx) - ac.SetEventLevel(auditinternal.LevelMetadata) return ctx } diff --git a/pkg/endpoints/filters/audit.go b/pkg/endpoints/filters/audit.go index 5f992fd9c..d25bf35ae 100644 --- a/pkg/endpoints/filters/audit.go +++ b/pkg/endpoints/filters/audit.go @@ -142,7 +142,7 @@ func evaluatePolicyAndCreateAuditEvent(req *http.Request, policy audit.PolicyRul if !ok { requestReceivedTimestamp = time.Now() } - audit.LogRequestMetadata(ctx, req, requestReceivedTimestamp, rac.Level, attribs) + audit.LogRequestMetadata(ctx, req, requestReceivedTimestamp, attribs) return ac, nil } diff --git a/pkg/endpoints/handlers/delete_test.go b/pkg/endpoints/handlers/delete_test.go index 9f38420c3..7fd19df10 100644 --- a/pkg/endpoints/handlers/delete_test.go +++ b/pkg/endpoints/handlers/delete_test.go @@ -34,7 +34,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apiserver/pkg/admission" - auditapis "k8s.io/apiserver/pkg/apis/audit" + auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authorization/authorizer" @@ -74,7 +74,9 @@ func TestDeleteResourceAuditLogRequestObject(t *testing.T) { ctx := audit.WithAuditContext(context.TODO()) ac := audit.AuditContextFrom(ctx) - ac.SetEventLevel(auditapis.LevelRequestResponse) + if err := ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, nil); err != nil { + t.Fatal(err) + } policy := metav1.DeletePropagationBackground deleteOption := &metav1.DeleteOptions{ diff --git a/pkg/util/x509metrics/server_cert_deprecations_test.go b/pkg/util/x509metrics/server_cert_deprecations_test.go index d9d16c4c8..eaa17bcf8 100644 --- a/pkg/util/x509metrics/server_cert_deprecations_test.go +++ b/pkg/util/x509metrics/server_cert_deprecations_test.go @@ -30,7 +30,6 @@ import ( "testing" "github.com/stretchr/testify/require" - auditapi "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/component-base/metrics" "k8s.io/component-base/metrics/testutil" @@ -247,7 +246,6 @@ func TestCheckForHostnameError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.SetEventLevel(auditapi.LevelMetadata) _, err = client.Transport.RoundTrip(req) @@ -390,7 +388,6 @@ func TestCheckForInsecureAlgorithmError(t *testing.T) { } req = req.WithContext(audit.WithAuditContext(req.Context())) auditCtx := audit.AuditContextFrom(req.Context()) - auditCtx.SetEventLevel(auditapi.LevelMetadata) // can't use tlsServer.Client() as it contains the server certificate // in tls.Config.Certificates. The signatures are, however, only checked From 550c6e4749a9979cf61969e455416fce33bd1bbf Mon Sep 17 00:00:00 2001 From: Davanum Srinivas Date: Mon, 12 May 2025 09:29:22 -0400 Subject: [PATCH 3/4] UPSTREAM: 131725: Avoid encoding in LogResponseObject when we are not going to use it Signed-off-by: Davanum Srinivas Kubernetes-commit: e418ee3a92ca6c670d26f775b0f669e8a5fe233c --- pkg/audit/request.go | 2 +- pkg/audit/request_log_test.go | 259 ++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 pkg/audit/request_log_test.go diff --git a/pkg/audit/request.go b/pkg/audit/request.go index 60b69b0b2..705d11e88 100644 --- a/pkg/audit/request.go +++ b/pkg/audit/request.go @@ -171,7 +171,7 @@ func LogRequestPatch(ctx context.Context, patch []byte) { // will be converted to the given gv. func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupVersion, s runtime.NegotiatedSerializer) { ac := AuditContextFrom(WithAuditContext(ctx)) - if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { + if ac.GetEventLevel().Less(auditinternal.LevelRequestResponse) { return } diff --git a/pkg/audit/request_log_test.go b/pkg/audit/request_log_test.go new file mode 100644 index 000000000..b41450db2 --- /dev/null +++ b/pkg/audit/request_log_test.go @@ -0,0 +1,259 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package audit + +import ( + "context" + "io" + "strings" + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/runtime/serializer" + auditinternal "k8s.io/apiserver/pkg/apis/audit" +) + +func TestLogResponseObjectWithPod(t *testing.T) { + testPod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "test-namespace", + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + } + + scheme := runtime.NewScheme() + if err := corev1.AddToScheme(scheme); err != nil { + t.Fatalf("Failed to add core/v1 to scheme: %v", err) + } + codecs := serializer.NewCodecFactory(scheme) + negotiatedSerializer := codecs.WithoutConversion() + + // Create audit context with RequestResponse level + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } + + LogResponseObject(ctx, testPod, schema.GroupVersion{Group: "", Version: "v1"}, negotiatedSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + if event.ResponseObject == nil { + t.Fatal("Expected ResponseObject to be set, but it was nil") + } + if event.ResponseObject.ContentType != runtime.ContentTypeJSON { + t.Errorf("Expected ContentType to be %q, got %q", runtime.ContentTypeJSON, event.ResponseObject.ContentType) + } + if len(event.ResponseObject.Raw) == 0 { + t.Error("Expected ResponseObject.Raw to contain data, but it was empty") + } + + responseJSON := string(event.ResponseObject.Raw) + expectedFields := []string{"test-pod", "test-namespace", "test-container", "test-image"} + for _, field := range expectedFields { + if !strings.Contains(responseJSON, field) { + t.Errorf("Response should contain %q but didn't. Response: %s", field, responseJSON) + } + } + + if event.ResponseStatus != nil { + t.Errorf("Expected ResponseStatus to be nil for regular object, got: %+v", event.ResponseStatus) + } +} + +func TestLogResponseObjectWithStatus(t *testing.T) { + // Create a status object to test ResponseStatus handling + testStatus := &metav1.Status{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Status", + }, + Status: "Success", + Message: "Test status message", + Reason: "TestReason", + Code: 200, + } + + scheme := runtime.NewScheme() + err := metav1.AddMetaToScheme(scheme) + if err != nil { + t.Fatalf("Failed to add meta to scheme: %v", err) + } + scheme.AddKnownTypes(schema.GroupVersion{Version: "v1"}, &metav1.Status{}) + codecs := serializer.NewCodecFactory(scheme) + negotiatedSerializer := codecs.WithoutConversion() + + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: auditinternal.LevelRequestResponse}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } + + LogResponseObject(ctx, testStatus, schema.GroupVersion{Group: "", Version: "v1"}, negotiatedSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + + if event.ResponseObject == nil { + t.Fatal("Expected ResponseObject to be set, but it was nil") + } + if event.ResponseStatus == nil { + t.Fatal("Expected ResponseStatus to be set for Status object, but it was nil") + } + if event.ResponseStatus.Status != "Success" { + t.Errorf("Expected ResponseStatus.Status to be 'Success', got %q", event.ResponseStatus.Status) + } + if event.ResponseStatus.Message != "Test status message" { + t.Errorf("Expected ResponseStatus.Message to be 'Test status message', got %q", event.ResponseStatus.Message) + } + if event.ResponseStatus.Reason != "TestReason" { + t.Errorf("Expected ResponseStatus.Reason to be 'TestReason', got %q", event.ResponseStatus.Reason) + } + if event.ResponseStatus.Code != 200 { + t.Errorf("Expected ResponseStatus.Code to be 200, got %d", event.ResponseStatus.Code) + } +} + +func TestLogResponseObjectLevelCheck(t *testing.T) { + testCases := []struct { + name string + level auditinternal.Level + shouldEncode bool + }{ + { + name: "None level should not encode", + level: auditinternal.LevelNone, + shouldEncode: false, + }, + { + name: "Metadata level should not encode", + level: auditinternal.LevelMetadata, + shouldEncode: false, + }, + { + name: "Request level should not encode", + level: auditinternal.LevelRequest, + shouldEncode: false, + }, + { + name: "RequestResponse level should encode", + level: auditinternal.LevelRequestResponse, + shouldEncode: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a test object + testObj := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "test-namespace", + }, + } + + // Create audit context with the specified level + ctx := WithAuditContext(context.Background()) + ac := AuditContextFrom(ctx) + ac.Init(RequestAuditConfig{Level: tc.level}, nil) + + // Create a mock serializer that tracks if encoding was attempted + mockSerializer := &mockNegotiatedSerializer{} + + // Call the function under test + LogResponseObject(ctx, testObj, schema.GroupVersion{Group: "", Version: "v1"}, mockSerializer) + + // Check if encoding was attempted as expected + if mockSerializer.encodeCalled != tc.shouldEncode { + t.Errorf("Expected encoding to be called: %v, but got: %v", tc.shouldEncode, mockSerializer.encodeCalled) + } + }) + } +} + +type mockNegotiatedSerializer struct { + encodeCalled bool +} + +func (m *mockNegotiatedSerializer) SupportedMediaTypes() []runtime.SerializerInfo { + return []runtime.SerializerInfo{ + { + MediaType: runtime.ContentTypeJSON, + EncodesAsText: true, + Serializer: nil, + PrettySerializer: nil, + StreamSerializer: nil, + }, + } +} + +func (m *mockNegotiatedSerializer) EncoderForVersion(serializer runtime.Encoder, gv runtime.GroupVersioner) runtime.Encoder { + m.encodeCalled = true + return &mockEncoder{} +} + +func (m *mockNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder, gv runtime.GroupVersioner) runtime.Decoder { + return nil +} + +type mockEncoder struct{} + +func (e *mockEncoder) Encode(obj runtime.Object, w io.Writer) error { + return nil +} + +func (e *mockEncoder) Identifier() runtime.Identifier { + return runtime.Identifier("mock") +} + +type capturingAuditSink struct { + events []*auditinternal.Event +} + +func (s *capturingAuditSink) ProcessEvents(events ...*auditinternal.Event) bool { + for _, event := range events { + eventCopy := event.DeepCopy() + s.events = append(s.events, eventCopy) + } + return true +} From 7bc3c9987f36d1fdb63f0e20aacb50f1000e105e Mon Sep 17 00:00:00 2001 From: Davanum Srinivas Date: Mon, 12 May 2025 12:50:38 -0400 Subject: [PATCH 4/4] UPSTREAM: 131725: Avoid encoding in LogResponseObject when we are not going to use it Signed-off-by: Davanum Srinivas Kubernetes-commit: 153233c677d62c0254d54c1e7013645a081ac03d --- pkg/audit/request.go | 8 ++- pkg/audit/request_log_test.go | 128 ++++++++++++++++++++++++++-------- 2 files changed, 103 insertions(+), 33 deletions(-) diff --git a/pkg/audit/request.go b/pkg/audit/request.go index 705d11e88..d5f9c730f 100644 --- a/pkg/audit/request.go +++ b/pkg/audit/request.go @@ -171,12 +171,14 @@ func LogRequestPatch(ctx context.Context, patch []byte) { // will be converted to the given gv. func LogResponseObject(ctx context.Context, obj runtime.Object, gv schema.GroupVersion, s runtime.NegotiatedSerializer) { ac := AuditContextFrom(WithAuditContext(ctx)) - if ac.GetEventLevel().Less(auditinternal.LevelRequestResponse) { + status, _ := obj.(*metav1.Status) + if ac.GetEventLevel().Less(auditinternal.LevelMetadata) { + return + } else if ac.GetEventLevel().Less(auditinternal.LevelRequestResponse) { + ac.LogResponseObject(status, nil) return } - status, _ := obj.(*metav1.Status) - if shouldOmitManagedFields(ac) { copy, ok, err := copyWithoutManagedFields(obj) if err != nil { diff --git a/pkg/audit/request_log_test.go b/pkg/audit/request_log_test.go index b41450db2..236b9ebdf 100644 --- a/pkg/audit/request_log_test.go +++ b/pkg/audit/request_log_test.go @@ -156,57 +156,125 @@ func TestLogResponseObjectWithStatus(t *testing.T) { func TestLogResponseObjectLevelCheck(t *testing.T) { testCases := []struct { - name string - level auditinternal.Level - shouldEncode bool + name string + level auditinternal.Level + obj runtime.Object + shouldEncode bool + expectResponseObj bool + expectStatusFields bool }{ { - name: "None level should not encode", - level: auditinternal.LevelNone, - shouldEncode: false, + name: "None level should not encode or log anything", + level: auditinternal.LevelNone, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, }, { - name: "Metadata level should not encode", - level: auditinternal.LevelMetadata, - shouldEncode: false, + name: "Metadata level should not encode or log anything", + level: auditinternal.LevelMetadata, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, }, { - name: "Request level should not encode", - level: auditinternal.LevelRequest, - shouldEncode: false, + name: "Request level with Pod should not encode or log", + level: auditinternal.LevelRequest, + obj: &corev1.Pod{}, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: false, }, { - name: "RequestResponse level should encode", - level: auditinternal.LevelRequestResponse, - shouldEncode: true, + name: "Request level with Status should log status fields without encoding", + level: auditinternal.LevelRequest, + obj: &metav1.Status{ + Status: "Success", + Message: "Test message", + Code: 200, + }, + shouldEncode: false, + expectResponseObj: false, + expectStatusFields: true, + }, + { + name: "RequestResponse level with Pod should encode", + level: auditinternal.LevelRequestResponse, + obj: &corev1.Pod{}, + shouldEncode: true, + expectResponseObj: true, + expectStatusFields: false, + }, + { + name: "RequestResponse level with Status should encode and log status fields", + level: auditinternal.LevelRequestResponse, + obj: &metav1.Status{ + Status: "Success", + Message: "Test message", + Code: 200, + }, + shouldEncode: true, + expectResponseObj: true, + expectStatusFields: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Create a test object - testObj := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pod", - Namespace: "test-namespace", - }, - } - - // Create audit context with the specified level ctx := WithAuditContext(context.Background()) ac := AuditContextFrom(ctx) - ac.Init(RequestAuditConfig{Level: tc.level}, nil) - // Create a mock serializer that tracks if encoding was attempted - mockSerializer := &mockNegotiatedSerializer{} + captureSink := &capturingAuditSink{} + if err := ac.Init(RequestAuditConfig{Level: tc.level}, captureSink); err != nil { + t.Fatalf("Failed to initialize audit context: %v", err) + } - // Call the function under test - LogResponseObject(ctx, testObj, schema.GroupVersion{Group: "", Version: "v1"}, mockSerializer) + mockSerializer := &mockNegotiatedSerializer{} + LogResponseObject(ctx, tc.obj, schema.GroupVersion{Group: "", Version: "v1"}, mockSerializer) + ac.ProcessEventStage(ctx, auditinternal.StageResponseComplete) - // Check if encoding was attempted as expected if mockSerializer.encodeCalled != tc.shouldEncode { t.Errorf("Expected encoding to be called: %v, but got: %v", tc.shouldEncode, mockSerializer.encodeCalled) } + + if len(captureSink.events) != 1 { + t.Fatalf("Expected one audit event to be captured, got %d", len(captureSink.events)) + } + event := captureSink.events[0] + + if tc.expectResponseObj { + if event.ResponseObject == nil { + t.Error("Expected ResponseObject to be set, but it was nil") + } + } else { + if event.ResponseObject != nil { + t.Error("Expected ResponseObject to be nil") + } + } + + // Check ResponseStatus for Status objects + status, isStatus := tc.obj.(*metav1.Status) + if isStatus && tc.expectStatusFields { + if event.ResponseStatus == nil { + t.Error("Expected ResponseStatus to be set for Status object, but it was nil") + } else { + if event.ResponseStatus.Status != status.Status { + t.Errorf("Expected ResponseStatus.Status to be %q, got %q", status.Status, event.ResponseStatus.Status) + } + if event.ResponseStatus.Message != status.Message { + t.Errorf("Expected ResponseStatus.Message to be %q, got %q", status.Message, event.ResponseStatus.Message) + } + if event.ResponseStatus.Code != status.Code { + t.Errorf("Expected ResponseStatus.Code to be %d, got %d", status.Code, event.ResponseStatus.Code) + } + } + } else { + if event.ResponseStatus != nil { + t.Error("Expected ResponseStatus to be nil") + } + } }) } }