diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index b82a82a5c5cca..08132492cef35 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -202,10 +202,16 @@ func (a *AccessListService) ListAccessListMembers(ctx context.Context, accessLis } // ListAllAccessListMembers returns a paginated list of all access list members for all access lists. -func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { - // Locks are not used here as these operations are more likely to be used by the cache. - // Lists all access list members for all access lists. - return a.memberService.ListResources(ctx, pageSize, nextToken) +func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.AccessListMember, string, error) { + members, next, err := a.memberService.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + var nextKey string + if next != nil { + nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name + } + return members, nextKey, nil } // GetAccessListMember returns the specified access list member resource. @@ -355,10 +361,16 @@ func (a *AccessListService) ListAccessListReviews(ctx context.Context, accessLis } // ListAllAccessListReviews will list access list reviews for all access lists. -func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) { - // Locks are not used here as these operations are more likely to be used by the cache. - // Lists all access list reviews for all access lists. - return a.reviewService.ListResources(ctx, pageSize, pageToken) +func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.Review, string, error) { + reviews, next, err := a.reviewService.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + var nextKey string + if next != nil { + nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name + } + return reviews, nextKey, nil } // CreateAccessListReview will create a new review for an access list. diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index d934adfb4ac7e..536e19b6a7288 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -18,6 +18,8 @@ package local import ( "context" + "fmt" + "strconv" "testing" "time" @@ -786,7 +788,7 @@ func newAccessListReview(t *testing.T, accessList, name string) *accesslist.Revi review, err := accesslist.NewReview( header.Metadata{ - Name: "test-access-list-review", + Name: name, }, accesslist.ReviewSpec{ AccessList: accessList, @@ -826,3 +828,119 @@ func newAccessListReview(t *testing.T, accessList, name string) *accesslist.Revi return review } + +func TestAccessListService_ListAllAccessListMembers(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service, err := NewAccessListService(backend.NewSanitizer(mem), clock) + require.NoError(t, err) + + const numAccessLists = 10 + const numAccessListMembersPerAccessList = 250 + totalMembers := numAccessLists * numAccessListMembersPerAccessList + + // Create several access lists. + expectedMembers := make([]*accesslist.AccessListMember, totalMembers) + for i := 0; i < numAccessLists; i++ { + alName := strconv.Itoa(i) + _, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock)) + require.NoError(t, err) + + for j := 0; j < numAccessListMembersPerAccessList; j++ { + member := newAccessListMember(t, alName, fmt.Sprintf("%03d", j)) + expectedMembers[i*numAccessListMembersPerAccessList+j] = member + _, err := service.UpsertAccessListMember(ctx, member) + require.NoError(t, err) + } + } + + allMembers := make([]*accesslist.AccessListMember, 0, totalMembers) + var nextToken string + for { + var members []*accesslist.AccessListMember + var err error + members, nextToken, err = service.ListAllAccessListMembers(ctx, 0, nextToken) + require.NoError(t, err) + + allMembers = append(allMembers, members...) + + if nextToken == "" { + break + } + } + + require.Empty(t, cmp.Diff(expectedMembers, allMembers, cmpopts.IgnoreFields(header.Metadata{}, "ID"))) +} + +func TestAccessListService_ListAllAccessListReviews(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service, err := NewAccessListService(backend.NewSanitizer(mem), clock) + require.NoError(t, err) + + const numAccessLists = 10 + const numAccessListReviewsPerAccessList = 250 + totalReviews := numAccessLists * numAccessListReviewsPerAccessList + + // Create several access lists. + expectedReviews := make([]*accesslist.Review, totalReviews) + for i := 0; i < numAccessLists; i++ { + alName := strconv.Itoa(i) + _, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock)) + require.NoError(t, err) + + for j := 0; j < numAccessListReviewsPerAccessList; j++ { + review, err := accesslist.NewReview( + header.Metadata{ + Name: strconv.Itoa(j), + }, + accesslist.ReviewSpec{ + AccessList: alName, + Reviewers: []string{ + "user1", + }, + ReviewDate: time.Now(), + }, + ) + require.NoError(t, err) + review, _, err = service.CreateAccessListReview(ctx, review) + expectedReviews[i*numAccessListReviewsPerAccessList+j] = review + require.NoError(t, err) + } + } + + allReviews := make([]*accesslist.Review, 0, totalReviews) + var nextToken string + for { + var reviews []*accesslist.Review + var err error + reviews, nextToken, err = service.ListAllAccessListReviews(ctx, 0, nextToken) + require.NoError(t, err) + + allReviews = append(allReviews, reviews...) + + if nextToken == "" { + break + } + } + + require.Empty(t, cmp.Diff(expectedReviews, allReviews, cmpopts.IgnoreFields(header.Metadata{}, "ID"), cmpopts.SortSlices( + func(r1, r2 *accesslist.Review) bool { + return r1.GetName() < r2.GetName() + }), + )) +} diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index c884c1dbd90f3..c98e1a8a47195 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -138,6 +138,17 @@ func (s *Service[T]) GetResources(ctx context.Context) ([]T, error) { // ListResources returns a paginated list of resources. func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken string) ([]T, string, error) { + resources, next, err := s.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + var nextKey string + if next != nil { + nextKey = backend.GetPaginationKey(*next) + } + return resources, nextKey, trace.Wrap(err) +} + +// ListResourcesReturnNextResource returns a paginated list of resources. The next resource is returned, which allows consumers to construct +// the next pagination key as appropriate. +func (s *Service[T]) ListResourcesReturnNextResource(ctx context.Context, pageSize int, pageToken string) ([]T, *T, error) { rangeStart := backend.Key(s.backendPrefix, pageToken) rangeEnd := backend.RangeEnd(backend.ExactKey(s.backendPrefix)) @@ -151,26 +162,26 @@ func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken // no filter provided get the range directly result, err := s.backend.GetRange(ctx, rangeStart, rangeEnd, limit) if err != nil { - return nil, "", trace.Wrap(err) + return nil, nil, trace.Wrap(err) } out := make([]T, 0, len(result.Items)) for _, item := range result.Items { resource, err := s.unmarshalFunc(item.Value) if err != nil { - return nil, "", trace.Wrap(err) + return nil, nil, trace.Wrap(err) } out = append(out, resource) } - var nextKey string + var next *T if len(out) > pageSize { - nextKey = backend.GetPaginationKey(out[len(out)-1]) + next = &out[pageSize] // Truncate the last item that was used to determine next row existence. out = out[:pageSize] } - return out, nextKey, nil + return out, next, nil } // GetResource returns the specified resource. diff --git a/lib/services/local/generic/generic_test.go b/lib/services/local/generic/generic_test.go index 1e5efa1adc2ec..21d147dbf94e4 100644 --- a/lib/services/local/generic/generic_test.go +++ b/lib/services/local/generic/generic_test.go @@ -264,3 +264,46 @@ func TestGenericCRUD(t *testing.T) { require.Empty(t, nextToken) require.Empty(t, out) } + +func TestGenericListResourcesReturnNextResource(t *testing.T) { + ctx := context.Background() + + memBackend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + service, err := NewService(&ServiceConfig[*testResource]{ + Backend: memBackend, + ResourceKind: "generic resource", + PageLimit: 200, + BackendPrefix: "generic_prefix", + UnmarshalFunc: unmarshalResource, + MarshalFunc: marshalResource, + }) + require.NoError(t, err) + + // Create a couple test resources. + r1 := newTestResource("r1") + r2 := newTestResource("r2") + + err = service.WithPrefix("a-unique-prefix").UpsertResource(ctx, r1) + require.NoError(t, err) + err = service.WithPrefix("another-unique-prefix").UpsertResource(ctx, r2) + require.NoError(t, err) + + page, next, err := service.ListResourcesReturnNextResource(ctx, 1, "") + require.NoError(t, err) + require.Empty(t, cmp.Diff([]*testResource{r1}, page, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + )) + require.NotNil(t, next) + + page, next, err = service.ListResourcesReturnNextResource(ctx, 1, "another-unique-prefix"+string(backend.Separator)+backend.GetPaginationKey(*next)) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]*testResource{r2}, page, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + )) + require.Nil(t, next) +}