Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/admission/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ func ensureAnnotationGetter(a Attributes) error {
}

func (handler *auditHandler) logAnnotations(ctx context.Context, a Attributes) {
ae := audit.AuditEventFrom(ctx)
ae := audit.AuditContextFrom(ctx)
if ae == nil {
return
}

var annotations map[string]string
switch a := a.(type) {
case privateAnnotationsGetter:
annotations = a.getAnnotations(ae.Level)
annotations = a.getAnnotations(ae.GetEventLevel())
case AnnotationsGetter:
annotations = a.GetAnnotations(ae.Level)
annotations = a.GetAnnotations(ae.GetEventLevel())
default:
// this will never happen, because we have already checked it in ensureAnnotationGetter
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/admission/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ func TestWithAudit(t *testing.T) {
var handler Interface = fakeHandler{tc.admit, tc.admitAnnotations, tc.validate, tc.validateAnnotations, tc.handles}
ctx := audit.WithAuditContext(context.Background())
ac := audit.AuditContextFrom(ctx)
ae := &ac.Event
ae.Level = auditinternal.LevelMetadata
if err := ac.Init(audit.RequestAuditConfig{Level: auditinternal.LevelMetadata}, nil); err != nil {
t.Fatal(err)
}

auditHandler := WithAudit(handler)
a := attributes()

Expand All @@ -171,9 +173,9 @@ func TestWithAudit(t *testing.T) {
annotations[k] = v
}
if len(annotations) == 0 {
assert.Nil(t, ae.Annotations, tcName+": unexptected annotations set in audit event")
assert.Nil(t, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event")
} else {
assert.Equal(t, annotations, ae.Annotations, tcName+": unexptected annotations set in audit event")
assert.Equal(t, annotations, ac.GetEventAnnotations(), tcName+": unexptected annotations set in audit event")
}
}
}
Expand All @@ -187,8 +189,6 @@ func TestWithAuditConcurrency(t *testing.T) {
}
var handler Interface = fakeHandler{admitAnnotations: admitAnnotations, handles: true}
ctx := audit.WithAuditContext(context.Background())
ac := audit.AuditContextFrom(ctx)
ac.Event.Level = auditinternal.LevelMetadata
auditHandler := WithAudit(handler)
a := attributes()

Expand Down
266 changes: 237 additions & 29 deletions pkg/audit/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@ package audit

import (
"context"
"errors"
"maps"
"sync"
"sync/atomic"
"time"

authnv1 "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
auditinternal "k8s.io/apiserver/pkg/apis/audit"
"k8s.io/apiserver/pkg/authentication/user"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/klog/v2"
)
Expand All @@ -35,22 +43,223 @@ const auditKey key = iota

// AuditContext holds the information for constructing the audit events for the current request.
type AuditContext struct {
// RequestAuditConfig is the audit configuration that applies to the request
RequestAuditConfig RequestAuditConfig

// Event is the audit Event object that is being captured to be written in
// initialized indicates whether requestAuditConfig and sink have been populated and are safe to read unguarded.
// This should only be set via Init().
initialized atomic.Bool
// requestAuditConfig is the audit configuration that applies to the request.
// This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true.
requestAuditConfig RequestAuditConfig
// sink is the sink to use when processing event stages.
// This should only be written via Init(RequestAuditConfig, Sink), and only read when initialized.Load() is true.
sink Sink

// lock guards event
lock sync.Mutex

// event is the audit Event object that is being captured to be written in
// the API audit log.
Event auditinternal.Event
event auditinternal.Event

// annotationMutex guards event.Annotations
annotationMutex sync.Mutex
// unguarded copy of auditID from the event
auditID atomic.Value
}

// Enabled checks whether auditing is enabled for this audit context.
func (ac *AuditContext) Enabled() bool {
// Note: An unset Level should be considered Enabled, so that request data (e.g. annotations)
// can still be captured before the audit policy is evaluated.
return ac != nil && ac.RequestAuditConfig.Level != auditinternal.LevelNone
if ac == nil {
// protect against nil pointers
return false
}
if !ac.initialized.Load() {
// Note: An unset Level should be considered Enabled, so that request data (e.g. annotations)
// can still be captured before the audit policy is evaluated.
return true
}
return ac.requestAuditConfig.Level != auditinternal.LevelNone
}

func (ac *AuditContext) Init(requestAuditConfig RequestAuditConfig, sink Sink) error {
ac.lock.Lock()
defer ac.lock.Unlock()
if ac.initialized.Load() {
return errors.New("audit context was already initialized")
}
ac.requestAuditConfig = requestAuditConfig
ac.sink = sink
ac.event.Level = requestAuditConfig.Level
ac.initialized.Store(true)
return nil
}

func (ac *AuditContext) AuditID() types.UID {
// return the unguarded copy of the auditID
id, _ := ac.auditID.Load().(types.UID)
return id
}

func (ac *AuditContext) visitEvent(f func(event *auditinternal.Event)) {
ac.lock.Lock()
defer ac.lock.Unlock()
f(&ac.event)
}

// ProcessEventStage returns true on success, false if there was an error processing the stage.
func (ac *AuditContext) ProcessEventStage(ctx context.Context, stage auditinternal.Stage) bool {
if ac == nil || !ac.initialized.Load() {
return true
}
if ac.sink == nil {
return true
}
for _, omitStage := range ac.requestAuditConfig.OmitStages {
if stage == omitStage {
return true
}
}

processed := false
ac.visitEvent(func(event *auditinternal.Event) {
event.Stage = stage
if stage == auditinternal.StageRequestReceived {
event.StageTimestamp = event.RequestReceivedTimestamp
} else {
event.StageTimestamp = metav1.NewMicroTime(time.Now())
}

ObserveEvent(ctx)
processed = ac.sink.ProcessEvents(event)
})
return processed
}

func (ac *AuditContext) LogImpersonatedUser(user user.Info) {
ac.visitEvent(func(ev *auditinternal.Event) {
if ev == nil || ev.Level.Less(auditinternal.LevelMetadata) {
return
}
ev.ImpersonatedUser = &authnv1.UserInfo{
Username: user.GetName(),
}
ev.ImpersonatedUser.Groups = user.GetGroups()
ev.ImpersonatedUser.UID = user.GetUID()
ev.ImpersonatedUser.Extra = map[string]authnv1.ExtraValue{}
for k, v := range user.GetExtra() {
ev.ImpersonatedUser.Extra[k] = authnv1.ExtraValue(v)
}
})
}

func (ac *AuditContext) LogResponseObject(status *metav1.Status, obj *runtime.Unknown) {
ac.visitEvent(func(ae *auditinternal.Event) {
if status != nil {
// selectively copy the bounded fields.
ae.ResponseStatus = &metav1.Status{
Status: status.Status,
Message: status.Message,
Reason: status.Reason,
Details: status.Details,
Code: status.Code,
}
}
if ae.Level.Less(auditinternal.LevelRequestResponse) {
return
}
ae.ResponseObject = obj
})
}

// LogRequestPatch fills in the given patch as the request object into an audit event.
func (ac *AuditContext) LogRequestPatch(patch []byte) {
ac.visitEvent(func(ae *auditinternal.Event) {
ae.RequestObject = &runtime.Unknown{
Raw: patch,
ContentType: runtime.ContentTypeJSON,
}
})
}

func (ac *AuditContext) GetEventAnnotation(key string) (string, bool) {
var val string
var ok bool
ac.visitEvent(func(event *auditinternal.Event) {
val, ok = event.Annotations[key]
})
return val, ok
}

func (ac *AuditContext) GetEventLevel() auditinternal.Level {
var level auditinternal.Level
ac.visitEvent(func(event *auditinternal.Event) {
level = event.Level
})
return level
}

func (ac *AuditContext) SetEventStage(stage auditinternal.Stage) {
ac.visitEvent(func(event *auditinternal.Event) {
event.Stage = stage
})
}

func (ac *AuditContext) GetEventStage() auditinternal.Stage {
var stage auditinternal.Stage
ac.visitEvent(func(event *auditinternal.Event) {
stage = event.Stage
})
return stage
}

func (ac *AuditContext) SetEventStageTimestamp(timestamp metav1.MicroTime) {
ac.visitEvent(func(event *auditinternal.Event) {
event.StageTimestamp = timestamp
})
}

func (ac *AuditContext) GetEventResponseStatus() *metav1.Status {
var status *metav1.Status
ac.visitEvent(func(event *auditinternal.Event) {
status = event.ResponseStatus
})
return status
}

func (ac *AuditContext) GetEventRequestReceivedTimestamp() metav1.MicroTime {
var timestamp metav1.MicroTime
ac.visitEvent(func(event *auditinternal.Event) {
timestamp = event.RequestReceivedTimestamp
})
return timestamp
}

func (ac *AuditContext) GetEventStageTimestamp() metav1.MicroTime {
var timestamp metav1.MicroTime
ac.visitEvent(func(event *auditinternal.Event) {
timestamp = event.StageTimestamp
})
return timestamp
}

func (ac *AuditContext) SetEventResponseStatus(status *metav1.Status) {
ac.visitEvent(func(event *auditinternal.Event) {
event.ResponseStatus = status
})
}

func (ac *AuditContext) SetEventResponseStatusCode(statusCode int32) {
ac.visitEvent(func(event *auditinternal.Event) {
if event.ResponseStatus == nil {
event.ResponseStatus = &metav1.Status{}
}
event.ResponseStatus.Code = statusCode
})
}

func (ac *AuditContext) GetEventAnnotations() map[string]string {
var annotations map[string]string
ac.visitEvent(func(event *auditinternal.Event) {
annotations = maps.Clone(event.Annotations)
})
return annotations
}

// AddAuditAnnotation sets the audit annotation for the given key, value pair.
Expand All @@ -66,8 +275,8 @@ func AddAuditAnnotation(ctx context.Context, key, value string) {
return
}

ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
ac.lock.Lock()
defer ac.lock.Unlock()

addAuditAnnotationLocked(ac, key, value)
}
Expand All @@ -81,8 +290,8 @@ func AddAuditAnnotations(ctx context.Context, keysAndValues ...string) {
return
}

ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
ac.lock.Lock()
defer ac.lock.Unlock()

if len(keysAndValues)%2 != 0 {
klog.Errorf("Dropping mismatched audit annotation %q", keysAndValues[len(keysAndValues)-1])
Expand All @@ -100,8 +309,8 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string)
return
}

ac.annotationMutex.Lock()
defer ac.annotationMutex.Unlock()
ac.lock.Lock()
defer ac.lock.Unlock()

for k, v := range annotations {
addAuditAnnotationLocked(ac, k, v)
Expand All @@ -110,8 +319,7 @@ func AddAuditAnnotationsMap(ctx context.Context, annotations map[string]string)

// addAuditAnnotationLocked records the audit annotation on the event.
func addAuditAnnotationLocked(ac *AuditContext, key, value string) {
ae := &ac.Event

ae := &ac.event
if ae.Annotations == nil {
ae.Annotations = make(map[string]string)
}
Expand All @@ -128,15 +336,11 @@ func WithAuditContext(parent context.Context) context.Context {
return parent // Avoid double registering.
}

return genericapirequest.WithValue(parent, auditKey, &AuditContext{})
}

// AuditEventFrom returns the audit event struct on the ctx
func AuditEventFrom(ctx context.Context) *auditinternal.Event {
if ac := AuditContextFrom(ctx); ac.Enabled() {
return &ac.Event
}
return nil
return genericapirequest.WithValue(parent, auditKey, &AuditContext{
event: auditinternal.Event{
Stage: auditinternal.StageResponseStarted,
},
})
}

// AuditContextFrom returns the pair of the audit configuration object
Expand All @@ -154,15 +358,19 @@ func WithAuditID(ctx context.Context, auditID types.UID) {
return
}
if ac := AuditContextFrom(ctx); ac != nil {
ac.Event.AuditID = auditID
ac.visitEvent(func(event *auditinternal.Event) {
ac.auditID.Store(auditID)
event.AuditID = auditID
})
}
}

// AuditIDFrom returns the value of the audit ID from the request context, along with whether
// auditing is enabled.
func AuditIDFrom(ctx context.Context) (types.UID, bool) {
if ac := AuditContextFrom(ctx); ac != nil {
return ac.Event.AuditID, true
id, _ := ac.auditID.Load().(types.UID)
return id, true
}
return "", false
}
Expand Down
Loading