Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions lib/services/local/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,16 @@ func (a *AccessListService) ListAccessListMembers(ctx context.Context, accessLis
}

// ListAllAccessListMembers returns a paginated list of all access list members for all access lists.
func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) {
// Locks are not used here as these operations are more likely to be used by the cache.
// Lists all access list members for all access lists.
return a.memberService.ListResources(ctx, pageSize, nextToken)
func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.AccessListMember, string, error) {
members, next, err := a.memberService.ListResourcesReturnNextResource(ctx, pageSize, pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
var nextKey string
if next != nil {
nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name
}
return members, nextKey, nil
}

// GetAccessListMember returns the specified access list member resource.
Expand Down Expand Up @@ -355,10 +361,16 @@ func (a *AccessListService) ListAccessListReviews(ctx context.Context, accessLis
}

// ListAllAccessListReviews will list access list reviews for all access lists.
func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) {
// Locks are not used here as these operations are more likely to be used by the cache.
// Lists all access list reviews for all access lists.
return a.reviewService.ListResources(ctx, pageSize, pageToken)
func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.Review, string, error) {
reviews, next, err := a.reviewService.ListResourcesReturnNextResource(ctx, pageSize, pageToken)
if err != nil {
return nil, "", trace.Wrap(err)
}
var nextKey string
if next != nil {
nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name
}
return reviews, nextKey, nil
}

// CreateAccessListReview will create a new review for an access list.
Expand Down
120 changes: 119 additions & 1 deletion lib/services/local/access_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package local

import (
"context"
"fmt"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -786,7 +788,7 @@ func newAccessListReview(t *testing.T, accessList, name string) *accesslist.Revi

review, err := accesslist.NewReview(
header.Metadata{
Name: "test-access-list-review",
Name: name,
},
accesslist.ReviewSpec{
AccessList: accessList,
Expand Down Expand Up @@ -826,3 +828,119 @@ func newAccessListReview(t *testing.T, accessList, name string) *accesslist.Revi

return review
}

func TestAccessListService_ListAllAccessListMembers(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

mem, err := memory.New(memory.Config{
Context: ctx,
Clock: clock,
})
require.NoError(t, err)

service, err := NewAccessListService(backend.NewSanitizer(mem), clock)
require.NoError(t, err)

const numAccessLists = 10
const numAccessListMembersPerAccessList = 250
totalMembers := numAccessLists * numAccessListMembersPerAccessList

// Create several access lists.
expectedMembers := make([]*accesslist.AccessListMember, totalMembers)
for i := 0; i < numAccessLists; i++ {
alName := strconv.Itoa(i)
_, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock))
require.NoError(t, err)

for j := 0; j < numAccessListMembersPerAccessList; j++ {
member := newAccessListMember(t, alName, fmt.Sprintf("%03d", j))
expectedMembers[i*numAccessListMembersPerAccessList+j] = member
_, err := service.UpsertAccessListMember(ctx, member)
require.NoError(t, err)
}
}

allMembers := make([]*accesslist.AccessListMember, 0, totalMembers)
var nextToken string
for {
var members []*accesslist.AccessListMember
var err error
members, nextToken, err = service.ListAllAccessListMembers(ctx, 0, nextToken)
require.NoError(t, err)

allMembers = append(allMembers, members...)

if nextToken == "" {
break
}
}

require.Empty(t, cmp.Diff(expectedMembers, allMembers, cmpopts.IgnoreFields(header.Metadata{}, "ID")))
}

func TestAccessListService_ListAllAccessListReviews(t *testing.T) {
ctx := context.Background()
clock := clockwork.NewFakeClock()

mem, err := memory.New(memory.Config{
Context: ctx,
Clock: clock,
})
require.NoError(t, err)

service, err := NewAccessListService(backend.NewSanitizer(mem), clock)
require.NoError(t, err)

const numAccessLists = 10
const numAccessListReviewsPerAccessList = 250
totalReviews := numAccessLists * numAccessListReviewsPerAccessList

// Create several access lists.
expectedReviews := make([]*accesslist.Review, totalReviews)
for i := 0; i < numAccessLists; i++ {
alName := strconv.Itoa(i)
_, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock))
require.NoError(t, err)

for j := 0; j < numAccessListReviewsPerAccessList; j++ {
review, err := accesslist.NewReview(
header.Metadata{
Name: strconv.Itoa(j),
},
accesslist.ReviewSpec{
AccessList: alName,
Reviewers: []string{
"user1",
},
ReviewDate: time.Now(),
},
)
require.NoError(t, err)
review, _, err = service.CreateAccessListReview(ctx, review)
expectedReviews[i*numAccessListReviewsPerAccessList+j] = review
require.NoError(t, err)
}
}

allReviews := make([]*accesslist.Review, 0, totalReviews)
var nextToken string
for {
var reviews []*accesslist.Review
var err error
reviews, nextToken, err = service.ListAllAccessListReviews(ctx, 0, nextToken)
require.NoError(t, err)

allReviews = append(allReviews, reviews...)

if nextToken == "" {
break
}
}

require.Empty(t, cmp.Diff(expectedReviews, allReviews, cmpopts.IgnoreFields(header.Metadata{}, "ID"), cmpopts.SortSlices(
func(r1, r2 *accesslist.Review) bool {
return r1.GetName() < r2.GetName()
}),
))
}
21 changes: 16 additions & 5 deletions lib/services/local/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ func (s *Service[T]) GetResources(ctx context.Context) ([]T, error) {

// ListResources returns a paginated list of resources.
func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken string) ([]T, string, error) {
resources, next, err := s.ListResourcesReturnNextResource(ctx, pageSize, pageToken)
var nextKey string
if next != nil {
nextKey = backend.GetPaginationKey(*next)
}
return resources, nextKey, trace.Wrap(err)
}

// ListResourcesReturnNextResource returns a paginated list of resources. The next resource is returned, which allows consumers to construct
// the next pagination key as appropriate.
func (s *Service[T]) ListResourcesReturnNextResource(ctx context.Context, pageSize int, pageToken string) ([]T, *T, error) {
rangeStart := backend.Key(s.backendPrefix, pageToken)
rangeEnd := backend.RangeEnd(backend.ExactKey(s.backendPrefix))

Expand All @@ -151,26 +162,26 @@ func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken
// no filter provided get the range directly
result, err := s.backend.GetRange(ctx, rangeStart, rangeEnd, limit)
if err != nil {
return nil, "", trace.Wrap(err)
return nil, nil, trace.Wrap(err)
}

out := make([]T, 0, len(result.Items))
for _, item := range result.Items {
resource, err := s.unmarshalFunc(item.Value)
if err != nil {
return nil, "", trace.Wrap(err)
return nil, nil, trace.Wrap(err)
}
out = append(out, resource)
}

var nextKey string
var next *T
if len(out) > pageSize {
nextKey = backend.GetPaginationKey(out[len(out)-1])
next = &out[pageSize]
// Truncate the last item that was used to determine next row existence.
out = out[:pageSize]
}

return out, nextKey, nil
return out, next, nil
}

// GetResource returns the specified resource.
Expand Down
43 changes: 43 additions & 0 deletions lib/services/local/generic/generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,46 @@ func TestGenericCRUD(t *testing.T) {
require.Empty(t, nextToken)
require.Empty(t, out)
}

func TestGenericListResourcesReturnNextResource(t *testing.T) {
ctx := context.Background()

memBackend, err := memory.New(memory.Config{
Context: ctx,
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

service, err := NewService(&ServiceConfig[*testResource]{
Backend: memBackend,
ResourceKind: "generic resource",
PageLimit: 200,
BackendPrefix: "generic_prefix",
UnmarshalFunc: unmarshalResource,
MarshalFunc: marshalResource,
})
require.NoError(t, err)

// Create a couple test resources.
r1 := newTestResource("r1")
r2 := newTestResource("r2")

err = service.WithPrefix("a-unique-prefix").UpsertResource(ctx, r1)
require.NoError(t, err)
err = service.WithPrefix("another-unique-prefix").UpsertResource(ctx, r2)
require.NoError(t, err)

page, next, err := service.ListResourcesReturnNextResource(ctx, 1, "")
require.NoError(t, err)
require.Empty(t, cmp.Diff([]*testResource{r1}, page,
cmpopts.IgnoreFields(types.Metadata{}, "ID"),
))
require.NotNil(t, next)

page, next, err = service.ListResourcesReturnNextResource(ctx, 1, "another-unique-prefix"+string(backend.Separator)+backend.GetPaginationKey(*next))
require.NoError(t, err)
require.Empty(t, cmp.Diff([]*testResource{r2}, page,
cmpopts.IgnoreFields(types.Metadata{}, "ID"),
))
require.Nil(t, next)
}