diff --git a/lib/srv/desktop/audit.go b/lib/srv/desktop/audit.go index 7b82b4e3e7e2c..75f82b93f068c 100644 --- a/lib/srv/desktop/audit.go +++ b/lib/srv/desktop/audit.go @@ -47,6 +47,7 @@ type desktopSessionAuditor struct { clusterName string desktopServiceUUID string + compactor auditCompactor auditCache sharedDirectoryAuditCache } @@ -75,17 +76,16 @@ func (s *WindowsService) newSessionAuditor( return &desktopSessionAuditor{ clock: s.cfg.Clock, - sessionID: sessionID, - identity: identity, - windowsUser: windowsUser, - desktop: desktop, - enableNLA: s.enableNLA, - + sessionID: sessionID, + identity: identity, + windowsUser: windowsUser, + desktop: desktop, + enableNLA: s.enableNLA, startTime: s.cfg.Clock.Now().UTC().Round(time.Millisecond), clusterName: s.clusterName, desktopServiceUUID: s.cfg.Heartbeat.HostUUID, - - auditCache: newSharedDirectoryAuditCache(), + compactor: newAuditCompactor(3*time.Second, 10*time.Second, s.emit), + auditCache: newSharedDirectoryAuditCache(), } } @@ -123,6 +123,10 @@ func (d *desktopSessionAuditor) makeSessionStart(err error) *events.WindowsDeskt return event } +func (d *desktopSessionAuditor) teardown(ctx context.Context) { + d.compactor.flush(ctx) +} + func (d *desktopSessionAuditor) makeSessionEnd(recorded bool) *events.WindowsDesktopSessionEnd { userMetadata := d.identity.GetUserMetadata() userMetadata.Login = d.windowsUser diff --git a/lib/srv/desktop/audit_compactor.go b/lib/srv/desktop/audit_compactor.go new file mode 100644 index 0000000000000..2cd94df522cdf --- /dev/null +++ b/lib/srv/desktop/audit_compactor.go @@ -0,0 +1,276 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package desktop
+
+import (
+	"context"
+	"iter"
+	"maps"
+	"math"
+	"slices"
+	"sync"
+	"time"
+
+	"github.com/gravitational/teleport/api/types/events"
+)
+
+// fileOperationsKey uniquely identifies a set of common file operations
+type fileOperationsKey struct {
+	path        string
+	directoryID directoryID
+	write       bool
+}
+
+// fileOperationEvent is an abstraction of read/write events
+// so that we need only one compactor implementation.
+type fileOperationEvent interface {
+	Base() events.AuditEvent
+	IsWriteEvent() bool
+	GetDirectoryID() directoryID
+	GetPath() string
+	GetOffset() uint64
+	GetLength() uint64
+	SetLength(uint64)
+}
+
+// fileOperationsBucket identifies a set of reads/writes
+// to a particular file within some period of time.
+type fileOperationsBucket struct {
+	expireTime time.Time
+	events     []fileOperationEvent
+	timer      *time.Timer
+	done       chan struct{}
+}
+
+// auditCompactor retains read and write events to a given file for a period of time before
+// emitting them to the audit log. Once the timeout period expires, contiguous read/write events are
+// compacted into a single audit event and emitted.
+type auditCompactor struct {
+	// refreshInterval defines how long a bucket should wait for a subsequent
+	// file operation to arrive before compacting and emitting its audit event(s).
+	refreshInterval time.Duration
+	// maxDelayInterval defines the maximum length of time that a bucket should wait
+	// before compacting and emitting its audit event(s).
+	// This prevents a slow trickle of read/write events within the refreshInterval from
+	// indefinitely delaying audit events from being emitted.
+	maxDelayInterval time.Duration
+	emitFn           func(context.Context, events.AuditEvent)
+	buckets          map[fileOperationsKey]*fileOperationsBucket
+	bucketsLock      sync.Mutex
+}
+
+func newAuditCompactor(refreshInterval, maxDelayInterval time.Duration, emitFn func(context.Context, events.AuditEvent)) auditCompactor {
+	return auditCompactor{
+		refreshInterval:  refreshInterval,
+		maxDelayInterval: maxDelayInterval,
+		emitFn:           emitFn,
+		buckets:          map[fileOperationsKey]*fileOperationsBucket{},
+	}
+}
+
+func (s *fileOperationsBucket) emitEvents(ctx context.Context, emitFn func(ctx context.Context, event events.AuditEvent)) {
+	for event := range s.compactEvents() {
+		emitFn(ctx, event.Base())
+	}
+}
+
+func (s *fileOperationsBucket) compactEvents() iter.Seq[fileOperationEvent] {
+	offsetMapping := map[uint64][]fileOperationEvent{}
+	for _, event := range s.events {
+		offsetMapping[event.GetOffset()] = append(offsetMapping[event.GetOffset()], event)
+	}
+
+	var finalEvents []fileOperationEvent
+	for len(offsetMapping) > 0 {
+		// Find the read/write event with the lowest offset
+		// so that we may greedily search for the longest
+		// contiguous segment we can produce.
+		smallestKey := slices.Min(slices.Collect(maps.Keys(offsetMapping)))
+		// The audit event at which we will begin our search.
+		event := offsetMapping[smallestKey][0]
+		// compact returns the longest slice of contiguous read/write audit events.
+		// It always returns a slice of at least length 1, containing the starting event.
+		sequentialEvents, sequenceLength := s.compact(event, offsetMapping)
+		// base is the first event in the sequence. We will mutate this
+		// event with the updated length and emit it.
+		base := sequentialEvents[0]
+
+		// Remove each event in this sequence from the map
+		for _, subsequent := range sequentialEvents {
+			offset := subsequent.GetOffset()
+			events := offsetMapping[offset]
+			deleteIdx := slices.Index(events, subsequent)
+			events = slices.Delete(events, deleteIdx, deleteIdx+1)
+			if len(events) > 0 {
+				offsetMapping[offset] = events
+			} else {
+				delete(offsetMapping, offset)
+			}
+		}
+		base.SetLength(sequenceLength)
+		finalEvents = append(finalEvents, base)
+	}
+
+	return slices.Values(finalEvents)
+}
+
+// compact finds the longest contiguous set of reads/writes following the given 'event'.
+func (s *fileOperationsBucket) compact(event fileOperationEvent, eventsByOffset map[uint64][]fileOperationEvent) ([]fileOperationEvent, uint64) {
+	// Determine the offset at which the next contiguous segment must start.
+	offset := event.GetOffset() + event.GetLength()
+	// Consult the map for any events with this offset.
+	if len(eventsByOffset[offset]) > 0 {
+		// There may be multiple candidate segments to follow.
+		// Try each of them out and select the longest contiguous set of segments
+		var winner []fileOperationEvent
+		var maxLength uint64
+		for _, choice := range eventsByOffset[offset] {
+			// TODO: We could probably speed this up with memoization/dynamic programming,
+			// but this code is fairly readable as-is and it's probably not likely that
+			// we'll end up with too many possible paths in production.
+			// Recursively evaluate each option.
+			option, length := s.compact(choice, eventsByOffset)
+			if length > maxLength {
+				winner = option
+				maxLength = length
+			}
+		}
+		return append([]fileOperationEvent{event}, winner...), maxLength + event.GetLength()
+	}
+	return []fileOperationEvent{event}, event.GetLength()
+}
+
+func (s *fileOperationsBucket) addEvent(event fileOperationEvent) {
+	s.events = append(s.events, event)
+}
+
+func (a *auditCompactor) handleEvent(ctx context.Context, event fileOperationEvent) {
+	// File Operations are grouped by directoryID, path, and read vs write
+	key := fileOperationsKey{
+		write:       event.IsWriteEvent(),
+		directoryID: event.GetDirectoryID(),
+		path:        event.GetPath(),
+	}
+
+	newBucket := true
+	a.bucketsLock.Lock()
+	defer a.bucketsLock.Unlock()
+
+	if bucket, exists := a.buckets[key]; exists {
+		// We're currently tracking this bucket
+		// Temporarily stop the timer (if possible)
+		alreadyFired := !bucket.timer.Stop()
+		if !alreadyFired {
+			// Update the current bucket. It is a continuation of the current bucket
+			// and the timer has not yet fired for it.
+			bucket.addEvent(event)
+			// Reset the timer either to the refresh interval, or until
+			// the bucket's expiration time
+			bucket.timer.Reset(time.Duration(math.Min(float64(a.refreshInterval), float64(time.Until(bucket.expireTime)))))
+			newBucket = false
+		} else {
+			// The timer has already fired. Stop tracking this bucket.
+			// A new bucket will be created below to handle this event.
+			delete(a.buckets, key)
+		}
+	}
+
+	// We need to create a new bucket due to one of the following:
+	// - We are not tracking any such bucket yet.
+	// - We were tracking this bucket but the timer has already fired.
+	if newBucket {
+		bucket := &fileOperationsBucket{
+			done:       make(chan struct{}),
+			expireTime: time.Now().Add(a.maxDelayInterval),
+			events:     []fileOperationEvent{event},
+		}
+		bucket.timer = time.AfterFunc(a.refreshInterval, func() {
+			// Close done channel so that the 'flush' function can
+			// verify that this goroutine has completed its work.
+			defer close(bucket.done)
+			a.bucketsLock.Lock()
+			delete(a.buckets, key)
+			a.bucketsLock.Unlock()
+			bucket.emitEvents(ctx, a.emitFn)
+
+		})
+		a.buckets[key] = bucket
+	}
+}
+
+// flush immediately compacts and emits audit events for all
+// unexpired buckets and blocks until completion.
+func (a *auditCompactor) flush(ctx context.Context) {
+	a.bucketsLock.Lock()
+	wait := []chan struct{}{}
+	for bucketKey, bucket := range a.buckets {
+		if bucket.timer.Stop() {
+			// If we successfully stop the timer before it fires,
+			// go ahead and emit the audit event.
+			bucket.emitEvents(ctx, a.emitFn)
+			delete(a.buckets, bucketKey)
+		} else {
+			// The timer was already firing, so wait until
+			// the emitFn has been executed by the underlying goroutine.
+			wait = append(wait, bucket.done)
+		}
+	}
+	// Unlock so that we may unblock timer functions.
+	a.bucketsLock.Unlock()
+	// Wait for pending timers to complete.
+	// We use our own "done" channel rather than the timer's
+	// because we need to know that the timer's underlying goroutine has finished.
+	for _, doneChan := range wait {
+		<-doneChan
+	}
+}
+
+// Adapters for current read/write audit events.
+
+type readEvent struct {
+	*events.DesktopSharedDirectoryRead
+}
+
+func (r *readEvent) SetLength(length uint64) { r.Length = uint32(length) }
+func (r *readEvent) GetLength() uint64       { return uint64(r.Length) }
+func (r *readEvent) GetOffset() uint64       { return r.Offset }
+func (r *readEvent) GetPath() string         { return r.Path }
+func (r *readEvent) IsWriteEvent() bool      { return false }
+func (r *readEvent) GetDirectoryID() directoryID { return directoryID(r.DirectoryID) }
+func (r *readEvent) Base() events.AuditEvent { return r.DesktopSharedDirectoryRead }
+
+type writeEvent struct {
+	*events.DesktopSharedDirectoryWrite
+}
+
+func (r *writeEvent) SetLength(length uint64) { r.Length = uint32(length) }
+func (r *writeEvent) GetLength() uint64       { return uint64(r.Length) }
+func (r *writeEvent) GetOffset() uint64       { return r.Offset }
+func (r *writeEvent) GetPath() string         { return r.Path }
+func (r *writeEvent) IsWriteEvent() bool      { return true }
+func (r *writeEvent) GetDirectoryID() directoryID { return directoryID(r.DirectoryID) }
+func (r *writeEvent) Base() events.AuditEvent { return r.DesktopSharedDirectoryWrite }
+
+func (a *auditCompactor) handleRead(ctx context.Context, event *events.DesktopSharedDirectoryRead) {
+	a.handleEvent(ctx, &readEvent{DesktopSharedDirectoryRead: event})
+}
+
+func (a *auditCompactor) handleWrite(ctx context.Context, event *events.DesktopSharedDirectoryWrite) {
+	a.handleEvent(ctx, &writeEvent{DesktopSharedDirectoryWrite: event})
+}
diff --git a/lib/srv/desktop/audit_compactor_test.go b/lib/srv/desktop/audit_compactor_test.go
new file mode 100644
index 0000000000000..59685dee9a802
--- /dev/null
+++ b/lib/srv/desktop/audit_compactor_test.go
@@ -0,0 +1,260 @@
+//go:build go1.24 && enablesynctest
+
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package desktop
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"testing/synctest"
+	"time"
+
+	"github.com/gravitational/teleport/api/types/events"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func newReadEvent(path string, directory directoryID, offset uint64, length uint32) *events.DesktopSharedDirectoryRead {
+	return &events.DesktopSharedDirectoryRead{
+		Path:        path,
+		DirectoryID: uint32(directory),
+		Offset:      offset,
+		Length:      length,
+	}
+}
+
+func newWriteEvent(path string, directory directoryID, offset uint64, length uint32) *events.DesktopSharedDirectoryWrite {
+	return &events.DesktopSharedDirectoryWrite{
+		Path:        path,
+		DirectoryID: uint32(directory),
+		Offset:      offset,
+		Length:      length,
+	}
+}
+
+func TestAuditCompactor(t *testing.T) {
+	auditEvents := []events.AuditEvent{}
+	eventsLock := sync.Mutex{}
+	const refreshInterval = 1 * time.Second
+	const maxDelayInterval = 3 * time.Second
+	compactor := &auditCompactor{
+		refreshInterval:  refreshInterval,
+		maxDelayInterval: maxDelayInterval,
+		emitFn: func(_ context.Context, event events.AuditEvent) {
+			eventsLock.Lock()
+			defer eventsLock.Unlock()
+			auditEvents = append(auditEvents, event)
+		},
+		buckets: map[fileOperationsKey]*fileOperationsBucket{},
+	}
+
+	t.Run("basic", func(t *testing.T) {
+		auditEvents = auditEvents[:0]
+		ctx := t.Context()
+		synctest.Run(func() {
+			// Read sequence A
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 200, 100))
+			// Read sequence B
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			// Read sequence A continued
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 300, 200))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 500, 50))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 550, 90))
+
+			compactor.flush(ctx)
+			require.Len(t, auditEvents, 2)
+			// Should be compacted to 2 audit events
+			// Once compacted, audit events should inherit the timestamp of
+			// the first event in the stream
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 640))
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 200))
+		})
+
+	})
+
+	t.Run("complex", func(t *testing.T) {
+		auditEvents = auditEvents[:0]
+		ctx := t.Context()
+		synctest.Run(func() {
+
+			// Three separate reads (with different lengths) of the same file
+			// Read sequence A
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 200, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 300, 50))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 350, 75))
+
+			// Read sequence B
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 200, 150))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 350, 400))
+
+			// Read sequence C (does not start at 0)
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 200, 500))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 700, 500))
+
+			compactor.flush(ctx)
+			require.Len(t, auditEvents, 3)
+			// Should be compacted to 3 audit events
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 100, 325))
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 750))
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 1200))
+		})
+
+	})
+
+	t.Run("expirations", func(t *testing.T) {
+		auditEvents = auditEvents[:0]
+		ctx := t.Context()
+		synctest.Run(func() {
+			// 2 sequential reads
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			time.Sleep(refreshInterval - time.Millisecond)
+			synctest.Wait()
+
+			// Should not be emitted yet; refresh interval has not been exceeded
+			eventsLock.Lock()
+			assert.Empty(t, auditEvents)
+			eventsLock.Unlock()
+
+			// Complete the refreshInterval and we should have an event available
+			time.Sleep(time.Millisecond)
+			synctest.Wait()
+			eventsLock.Lock()
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 200))
+			eventsLock.Unlock()
+
+			// Continue submitting events just before the refresh interval.
+			// No audit event should be submitted until maxDelayInterval is reached
+			auditEvents = auditEvents[:0]
+			var elapsedTime time.Duration
+			offset := uint64(200)
+			const length = uint32(100)
+
+			count := 0
+			for elapsedTime < maxDelayInterval {
+				compactor.handleRead(ctx, newReadEvent("foo", 1, offset, length))
+				time.Sleep(refreshInterval - time.Millisecond)
+				synctest.Wait()
+				elapsedTime += refreshInterval - time.Millisecond
+				offset += uint64(length)
+				count++
+			}
+			// maxDelay should be exceeded by now and we should have
+			// a single consolidated event
+			eventsLock.Lock()
+			require.Len(t, auditEvents, 1)
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 200, uint32(length*uint32(count))))
+			eventsLock.Unlock()
+
+		})
+
+	})
+
+	t.Run("mix-reads-writes", func(t *testing.T) {
+		auditEvents = auditEvents[:0]
+		ctx := t.Context()
+		synctest.Run(func() {
+			// 3 sequential reads
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 200, 100))
+			// same file and directory, and looks sequential, but it's a write
+			compactor.handleWrite(ctx, newWriteEvent("foo", 1, 300, 50))
+			compactor.handleWrite(ctx, newWriteEvent("foo", 1, 350, 50))
+
+			compactor.flush(ctx)
+			require.Len(t, auditEvents, 2)
+			// Should be compacted to 2 audit events
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 300))
+			assert.Contains(t, auditEvents, newWriteEvent("foo", 1, 300, 100))
+		})
+	})
+
+	t.Run("mix-files-and-directories", func(t *testing.T) {
+		auditEvents = auditEvents[:0]
+		ctx := t.Context()
+		synctest.Run(func() {
+			// Identical offsets and lengths, but different path and/or directoryID
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 2, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("bar", 1, 0, 100))
+
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("foo", 2, 100, 100))
+			compactor.handleRead(ctx, newReadEvent("bar", 1, 100, 100))
+
+			compactor.flush(ctx)
+			require.Len(t, auditEvents, 3)
+			// Should be compacted to 3 audit events
+			assert.Contains(t, auditEvents, newReadEvent("foo", 1, 0, 200))
+			assert.Contains(t, auditEvents, newReadEvent("foo", 2, 0, 200))
+			assert.Contains(t, auditEvents, newReadEvent("bar", 1, 0, 200))
+		})
+
+	})
+
+	t.Run("racy-flush", func(t *testing.T) {
+		ctx := t.Context()
+		synctest.Run(func() {
+			auditEvents := make(chan events.AuditEvent)
+			compactor.emitFn = func(_ context.Context, ae events.AuditEvent) {
+				auditEvents <- ae
+			}
+			// Identical offsets and lengths, but different path and/or directoryID
+			compactor.handleRead(ctx, newReadEvent("foo", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("bar", 1, 0, 100))
+			compactor.handleRead(ctx, newReadEvent("baz", 1, 0, 100))
+			time.Sleep(refreshInterval - 1*time.Nanosecond)
+			compactor.handleRead(ctx, newReadEvent("caz", 1, 0, 100))
+
+			// Timers should start firing
+			time.Sleep(1 * time.Nanosecond)
+			synctest.Wait()
+
+			flushDone := false
+			go func() {
+				compactor.flush(ctx)
+				flushDone = true
+			}()
+
+			expectedEvents := []events.AuditEvent{
+				newReadEvent("foo", 1, 0, 100),
+				newReadEvent("bar", 1, 0, 100),
+				newReadEvent("baz", 1, 0, 100),
+				newReadEvent("caz", 1, 0, 100),
+			}
+			for range len(expectedEvents) {
+				assert.False(t, flushDone)
+				assert.Contains(t, expectedEvents, <-auditEvents)
+				synctest.Wait()
+			}
+			assert.True(t, flushDone)
+		})
+	})
+}
diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go
index 0ee0de0f77d7f..22c93c992f410 100644
--- a/lib/srv/desktop/windows_server.go
+++ b/lib/srv/desktop/windows_server.go
@@ -1062,6 +1062,7 @@ func (s *WindowsService) connectRDP(ctx context.Context, log *slog.Logger, tdpCo
 	err = rdpc.Run(ctx, certDER, keyDER)
 
 	// ctx may have been canceled, so emit with a separate context
+	audit.teardown(context.Background())
 	endEvent := audit.makeSessionEnd(recordSession)
 	s.record(context.Background(), recorder, endEvent)
 	s.emit(context.Background(), endEvent)
@@ -1219,9 +1220,11 @@
 				s.emit(ctx, errorEvent)
 			}
 		case tdp.SharedDirectoryReadResponse:
-			s.emit(ctx, audit.makeSharedDirectoryReadResponse(msg))
+			// shared directory audit events can be noisy, so we use a compactor
+			// to retain and delay them in an attempt to coalesce contiguous events
+			audit.compactor.handleRead(ctx, audit.makeSharedDirectoryReadResponse(msg))
 		case tdp.SharedDirectoryWriteResponse:
-			s.emit(ctx, audit.makeSharedDirectoryWriteResponse(msg))
+			audit.compactor.handleWrite(ctx, audit.makeSharedDirectoryWriteResponse(msg))
 		}
 	}
 }