Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ const (
const (
notificationsPageReadInterval = 5 * time.Millisecond
notificationsWriteInterval = 40 * time.Millisecond

accessListsPageReadInterval = 5 * time.Millisecond
accessListsPageSize = 20
)

const (
Expand Down Expand Up @@ -6907,7 +6910,8 @@ const (

// createAccessListReminderNotificationsOptions defines the optional parameters for CreateAccessListReminderNotifications.
type createAccessListReminderNotificationsOptions struct {
createNotificationInterval time.Duration
createNotificationInterval time.Duration
accessListsPageReadInterval time.Duration
}

// CreateAccessListReminderNotificationsOptions is a functional option for CreateAccessListReminderNotifications.
Expand All @@ -6920,11 +6924,19 @@ func WithCreateNotificationInterval(d time.Duration) CreateAccessListReminderNot
}
}

// WithAccessListsPageReadInterval sets the interval between reading pages of access lists.
func WithAccessListsPageReadInterval(d time.Duration) CreateAccessListReminderNotificationsOptions {
return func(o *createAccessListReminderNotificationsOptions) {
o.accessListsPageReadInterval = d
}
}

// CreateAccessListReminderNotifications checks if there are any access lists expiring soon and creates notifications to remind their owners if so.
func (a *Server) CreateAccessListReminderNotifications(ctx context.Context, opts ...CreateAccessListReminderNotificationsOptions) {
opt := &createAccessListReminderNotificationsOptions{
// defaults to notificationsWriteInterval aka 40ms
createNotificationInterval: notificationsWriteInterval,
createNotificationInterval: notificationsWriteInterval,
accessListsPageReadInterval: accessListsPageReadInterval,
}
for _, o := range opts {
o(opt)
Expand Down Expand Up @@ -6962,21 +6974,36 @@ func (a *Server) CreateAccessListReminderNotifications(ctx context.Context, opts

// Fetch all access lists
var accessLists []*accesslist.AccessList
err = clientutils.IterateResources(ctx, a.Cache.ListAccessLists, func(al *accesslist.AccessList) error {
if !al.IsReviewable() {
return nil
var accessListsPageKey string
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about keeping IterateResources and just putting a time.Sleep at the end of the closure? It's only 5ms, so if the context is canceled it won't take us long to detect.

I'm just weary of continuing to hand-write pagination code given the dozens of pagination bugs we've had. Now that we've got a well-tested and uniform tool, I think we should use it wherever possible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I removed it in the first place is that if the sleep is inside the closure, it will sleep 20× more than intended. That’s because the IterateResources closure is invoked per item, not per page.

I could technically sleep for 1/20 of the time, but I think it’s better to keep the code as it is.

I could also add another iterator that operates on pages instead, but that would likely make the code more confusing.

accessListsReadLimiter := time.NewTicker(opt.accessListsPageReadInterval)
defer accessListsReadLimiter.Stop()
for {
select {
case <-accessListsReadLimiter.C:
case <-ctx.Done():
return
}

// Only keep access lists that fall within our thresholds in memory
if al.Spec.Audit.NextAuditDate.Sub(now) <= 15*24*time.Hour {
accessLists = append(accessLists, al)
response, nextKey, err := a.Cache.ListAccessLists(ctx, accessListsPageSize, accessListsPageKey)
if err != nil {
a.logger.WarnContext(ctx, "failed to list access lists for periodic reminder notification check", "error", err)
}
return nil
})
if err != nil {
a.logger.WarnContext(ctx, "failed to list access lists for periodic reminder notification check",
"error", err)
return

for _, al := range response {
if !al.IsReviewable() {
continue
}

// Only keep access lists that fall within our thresholds in memory
if al.Spec.Audit.NextAuditDate.Sub(now) <= 15*24*time.Hour {
accessLists = append(accessLists, al)
}
}

if nextKey == "" {
break
}
accessListsPageKey = nextKey
}

reminderThresholds := []struct {
Expand Down
19 changes: 6 additions & 13 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4854,7 +4854,7 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) {
_, err = authServer.UpsertUser(ctx, user)
require.NoError(t, err)

// Create 2001overdue access lists
// Create 2001 overdue access lists
// All are overdue by 10 days, which should trigger "overdue by more than 7 days" notification
const numAccessLists = 2001
overdueBy := -10 // 10 days overdue
Expand All @@ -4874,14 +4874,14 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) {
}, 3*time.Second, 100*time.Millisecond)

// Run CreateAccessListReminderNotifications()
authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond))
authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond))

identifiers := collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d)
require.Len(t, identifiers, numAccessLists,
"should have created unique identifiers for all %d overdue access lists", numAccessLists)

// Run CreateAccessListReminderNotifications() again to verify it can read multiple pages of identifiers without memory leak
authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond))
authServer.CreateAccessListReminderNotifications(ctx, auth.WithCreateNotificationInterval(time.Nanosecond), auth.WithAccessListsPageReadInterval(time.Nanosecond))

identifiers = collectAllUniqueNotificationIdentifiers(t, ctx, authServer, types.NotificationIdentifierPrefixAccessListOverdue7d)
require.Len(t, identifiers, numAccessLists,
Expand All @@ -4891,17 +4891,10 @@ func TestCreateAccessListReminderNotifications_LargeOverdueSet(t *testing.T) {
func collectAllUniqueNotificationIdentifiers(t *testing.T, ctx context.Context, authServer *auth.Server, prefix string) []*notificationsv1.UniqueNotificationIdentifier {
t.Helper()

var identifiers []*notificationsv1.UniqueNotificationIdentifier
iterator := clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) {
identifiers, err := stream.Collect(clientutils.Resources(ctx, func(ctx context.Context, pageSize int, pageKey string) ([]*notificationsv1.UniqueNotificationIdentifier, string, error) {
return authServer.ListUniqueNotificationIdentifiersForPrefix(ctx, prefix, pageSize, pageKey)
})

for identifiersResp, err := range iterator {
if err != nil {
require.NoError(t, err, "listing unique notification identifiers for prefix %q", prefix)
}
identifiers = append(identifiers, identifiersResp)
}
}))
require.NoError(t, err, "listing unique notification identifiers for prefix %q", prefix)

return identifiers
}
Expand Down
Loading