diff --git a/lib/auth/auth.go b/lib/auth/auth.go index e74f7d9dafca3..40fa7e67badac 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -6866,8 +6866,40 @@ const ( accessListReminderSemaphoreMaxLeases = 1 ) +// createAccessListReminderNotificationsOptions defines the optional parameters for CreateAccessListReminderNotifications. +type createAccessListReminderNotificationsOptions struct { + createNotificationInterval time.Duration + accessListsPageReadInterval time.Duration +} + +// CreateAccessListReminderNotificationsOptions is a functional option for CreateAccessListReminderNotifications. +type CreateAccessListReminderNotificationsOptions func(*createAccessListReminderNotificationsOptions) + +// WithCreateNotificationInterval sets the interval between creating notifications. +func WithCreateNotificationInterval(d time.Duration) CreateAccessListReminderNotificationsOptions { + return func(o *createAccessListReminderNotificationsOptions) { + o.createNotificationInterval = d + } +} + +// WithAccessListsPageReadInterval sets the interval between reading pages of access lists. +func WithAccessListsPageReadInterval(d time.Duration) CreateAccessListReminderNotificationsOptions { + return func(o *createAccessListReminderNotificationsOptions) { + o.accessListsPageReadInterval = d + } +} + // CreateAccessListReminderNotifications checks if there are any access lists expiring soon and creates notifications to remind their owners if so. -func (a *Server) CreateAccessListReminderNotifications(ctx context.Context) { +func (a *Server) CreateAccessListReminderNotifications(ctx context.Context, opts ...CreateAccessListReminderNotificationsOptions) { + opt := &createAccessListReminderNotificationsOptions{ + // defaults to notificationsWriteInterval aka 40ms + createNotificationInterval: notificationsWriteInterval, + accessListsPageReadInterval: accessListsPageReadInterval, + } + for _, o := range opts { + o(opt) + } + // Ensure only one auth server is running this check at a time. lease, err := services.AcquireSemaphoreLock(ctx, services.SemaphoreLockConfig{ Service: a, @@ -6901,7 +6933,7 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context) { // Fetch all access lists var accessLists []*accesslist.AccessList var accessListsPageKey string - accessListsReadLimiter := time.NewTicker(accessListsPageReadInterval) + accessListsReadLimiter := time.NewTicker(opt.accessListsPageReadInterval) defer accessListsReadLimiter.Stop() for { select { @@ -6909,6 +6941,7 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context) { case <-ctx.Done(): return } + response, nextKey, err := a.Cache.ListAccessLists(ctx, accessListsPageSize, accessListsPageKey) if err != nil { a.logger.WarnContext(ctx, "failed to list access lists for periodic reminder notification check", "error", err) @@ -6966,17 +6999,16 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context) { // Fetch all identifiers for this treshold prefix. var identifiers []*notificationsv1.UniqueNotificationIdentifier - var nextKey string - for { - identifiersResp, nextKey, err := a.ListUniqueNotificationIdentifiersForPrefix(ctx, threshold.prefix, 0, nextKey) + iterator := clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) { + return a.ListUniqueNotificationIdentifiersForPrefix(ctx, threshold.prefix, pageSize, pageKey) + }) + + for identifiersResp, err := range iterator { if err != nil { a.logger.WarnContext(ctx, "failed to list notification identifiers", "error", err, "prefix", threshold.prefix) - continue - } - identifiers = append(identifiers, identifiersResp...) - if nextKey == "" { break } + identifiers = append(identifiers, identifiersResp) } // Create a map of identifiers for quick lookup @@ -6992,7 +7024,7 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context) { // Check for access lists which haven't already been accounted for in a notification var needsNotification bool - writeLimiter := time.NewTicker(notificationsWriteInterval) + writeLimiter := time.NewTicker(opt.createNotificationInterval) for _, accessList := range relevantLists { select { case <-writeLimiter.C: diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index bced9f22d8ab2..872607eaf1d48 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -4717,7 +4717,7 @@ func TestCreateAccessListReminderNotifications(t *testing.T) { } // Run CreateAccessListReminderNotifications() - authServer.CreateAccessListReminderNotifications(ctx) + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond)) reminderNotificationSubKind := func(n *notificationsv1.Notification) string { return n.GetSubKind() } expectedSubKinds := []string{ @@ -4735,7 +4735,7 @@ func TestCreateAccessListReminderNotifications(t *testing.T) { require.ElementsMatch(t, expectedSubKinds, slices.Map(resp.Notifications, reminderNotificationSubKind)) // Run CreateAccessListReminderNotifications() again to verify no duplicates are created - authServer.CreateAccessListReminderNotifications(ctx) + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond)) // Check notifications again, counts should remain the same. resp, err = client.ListNotifications(ctx, ¬ificationsv1.ListNotificationsRequest{}) @@ -4801,6 +4801,84 @@ func createAccessList(t *testing.T, authServer *auth.Server, name string, opts . require.NoError(t, err) } +func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) { + ctx := t.Context() + + modulestest.SetTestModules(t, modulestest.Modules{ + TestBuildType: modules.BuildEnterprise, + TestFeatures: modules.Features{ + Entitlements: map[entitlements.EntitlementKind]modules.EntitlementInfo{ + entitlements.Identity: {Enabled: true}, + }, + }, + }) + + // Setup test auth server + testServer := newTestTLSServer(t) + authServer := testServer.Auth() + + testRole, err := types.NewRole("test", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{"user"}, + ReviewRequests: &types.AccessReviewConditions{}, + }, + }) + require.NoError(t, err) + _, err = authServer.UpsertRole(ctx, testRole) + require.NoError(t, err) + + testUsername := "user1" + user, err := types.NewUser(testUsername) + require.NoError(t, err) + user.SetRoles([]string{"test"}) + _, err = authServer.UpsertUser(ctx, user) + require.NoError(t, err) + + // Create 2001 overdue access lists + // All are overdue by 10 days, which should trigger "overdue by more than 7 days" notification + const numAccessLists = 2001 + overdueBy := -10 // 10 days overdue + nextAuditDate := authServer.GetClock().Now().Add(time.Duration(overdueBy) * 24 * time.Hour) + + for i := range numAccessLists { + createAccessList(t, authServer, fmt.Sprintf("al-overdue-%d", i), + withOwners([]accesslist.Owner{{Name: testUsername}}), + withNextAuditDate(nextAuditDate), + ) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + lists, err := testServer.Auth().Cache.GetAccessLists(ctx) + assert.NoError(t, err) + assert.Len(t, lists, numAccessLists, "should have created all %d overdue access lists", numAccessLists) + }, 3*time.Second, 100*time.Millisecond) + + // Run CreateAccessListReminderNotifications() + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond)) + + identifiers := collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d) + require.Len(t, identifiers, numAccessLists, + "should have created unique identifiers for all %d overdue access lists", numAccessLists) + + // Run CreateAccessListReminderNotifications() again to verify it can read multiple pages of identifiers without memory leak + authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond)) + + identifiers = collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d) + require.Len(t, identifiers, numAccessLists, + "should have created unique identifiers for all %d overdue access lists", numAccessLists) +} + +func collectAllUniqueNotificationIdentifiers(t *testing.T, ctx context.Context, authServer *auth.Server, prefix string) []*notificationsv1.UniqueNotificationIdentifier { + t.Helper() + + identifiers, err := stream.Collect(clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) { + return authServer.ListUniqueNotificationIdentifiersForPrefix(ctx, prefix, pageSize, pageKey) + })) + require.NoError(t, err, "listing unique notification identifiers for prefix %q", prefix) + + return identifiers +} + func TestServer_GetAnonymizationKey(t *testing.T) { tests := []struct { name string