diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 0f990b3e4e8ce..a8453ae22f316 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -192,6 +192,9 @@ const ( const ( notificationsPageReadInterval = 5 * time.Millisecond notificationsWriteInterval = 40 * time.Millisecond + + accessListsPageReadInterval = 5 * time.Millisecond + accessListsPageSize = 20 ) const ( @@ -6907,7 +6910,8 @@ const ( // createAccessListReminderNotificationsOptions defines the optional parameters for CreateAccessListReminderNotifications. type createAccessListReminderNotificationsOptions struct { - createNotificationInterval time.Duration + createNotificationInterval time.Duration + accessListsPageReadInterval time.Duration } // CreateAccessListReminderNotificationsOptions is a functional option for CreateAccessListReminderNotifications. @@ -6920,11 +6924,19 @@ func WithCreateNotificationInterval(d time.Duration) CreateAccessListReminderNot } } +// WithAccessListsPageReadInterval sets the interval between reading pages of access lists. +func WithAccessListsPageReadInterval(d time.Duration) CreateAccessListReminderNotificationsOptions { + return func(o *createAccessListReminderNotificationsOptions) { + o.accessListsPageReadInterval = d + } +} + // CreateAccessListReminderNotifications checks if there are any access lists expiring soon and creates notifications to remind their owners if so. func (a *Server) CreateAccessListReminderNotifications(ctx context.Context, opts ...CreateAccessListReminderNotificationsOptions) { opt := &createAccessListReminderNotificationsOptions{ // defaults to notificationsWriteInterval aka 40ms - createNotificationInterval: notificationsWriteInterval, + createNotificationInterval: notificationsWriteInterval, + accessListsPageReadInterval: accessListsPageReadInterval, } for _, o := range opts { o(opt) @@ -6962,21 +6974,36 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context, opts // Fetch all access lists var accessLists []*accesslist.AccessList - err = clientutils.IterateResources(ctx, a.Cache.ListAccessLists, func(al *accesslist.AccessList) error { - if !al.IsReviewable() { - return nil + var accessListsPageKey string + accessListsReadLimiter := time.NewTicker(opt.accessListsPageReadInterval) + defer accessListsReadLimiter.Stop() + for { + select { + case <-accessListsReadLimiter.C: + case <-ctx.Done(): + return } - // Only keep access lists that fall within our thresholds in memory - if al.Spec.Audit.NextAuditDate.Sub(now) <= 15*24*time.Hour { - accessLists = append(accessLists, al) + response, nextKey, err := a.Cache.ListAccessLists(ctx, accessListsPageSize, accessListsPageKey) + if err != nil { + a.logger.WarnContext(ctx, "failed to list access lists for periodic reminder notification check", "error", err) } - return nil - }) - if err != nil { - a.logger.WarnContext(ctx, "failed to list access lists for periodic reminder notification check", - "error", err) - return + + for _, al := range response { + if !al.IsReviewable() { + continue + } + + // Only keep access lists that fall within our thresholds in memory + if al.Spec.Audit.NextAuditDate.Sub(now) <= 15*24*time.Hour { + accessLists = append(accessLists, al) + } + } + + if nextKey == "" { + break + } + accessListsPageKey = nextKey } reminderThresholds := []struct { diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 9857996ed6c1d..f18862bd03f6f 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -4854,7 +4854,7 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) { _, err = authServer.UpsertUser(ctx, user) require.NoError(t, err) - // Create 2001overdue access lists + // Create 2001 overdue access lists // All are overdue by 10 days, which should trigger "overdue by more than 7 days" notification const numAccessLists = 2001 overdueBy := -10 // 10 days overdue @@ -4874,14 +4874,14 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) { }, 3*time.Second, 100*time.Millisecond) // Run CreateAccessListReminderNotifications() - authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond)) + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond)) identifiers := collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d) require.Len(t, identifiers, numAccessLists, "should have created unique identifiers for all %d overdue access lists", numAccessLists) // Run CreateAccessListReminderNotifications() again to verify it can read multiple pages of identifiers without memory leak - authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond)) + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond)) identifiers = collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d) require.Len(t, identifiers, numAccessLists, @@ -4891,17 +4891,10 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) { func collectAllUniqueNotificationIdentifiers(t *testing.T, ctx context.Context, authServer *auth.Server, prefix string) []*notificationsv1.UniqueNotificationIdentifier { t.Helper() - var identifiers []*notificationsv1.UniqueNotificationIdentifier - iterator := clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) { + identifiers, err := stream.Collect(clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) { return authServer.ListUniqueNotificationIdentifiersForPrefix(ctx, prefix, pageSize, pageKey) - }) - - for identifiersResp, err := range iterator { - if err != nil { - require.NoError(t, err, "listing unique notification identifiers for prefix %q", prefix) - } - identifiers = append(identifiers, identifiersResp) - } + })) + require.NoError(t, err, "listing unique notification identifiers for prefix %q", prefix) return identifiers }