diff --git a/api/utils/clientutils/resources.go b/api/utils/clientutils/resources.go index 3077e739e0241..a06fb39e34f1a 100644 --- a/api/utils/clientutils/resources.go +++ b/api/utils/clientutils/resources.go @@ -69,7 +69,7 @@ func ResourcesWithPageSize[T any](ctx context.Context, pageFunc func(context.Con continue } - yield(*new(T), err) + yield(*new(T), trace.Wrap(err)) return } for _, resource := range page { diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index 2b8a346d4c22d..4ce5aa00abc3a 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -204,7 +204,7 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces } var upserted *accesslist.AccessList - var existingList *accesslist.AccessList + var existingAccessList *accesslist.AccessList opFn := a.service.UpsertResource if op == opTypeUpdate { @@ -214,19 +214,14 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces validateAccessList := func() error { var err error - if op == opTypeUpdate { - existingList, err = a.service.GetResource(ctx, accessList.GetName()) - if err != nil { - return trace.Wrap(err) - } - // Set memberOf / ownerOf to the existing values to prevent them from being updated. - accessList.Status.MemberOf = existingList.Status.MemberOf - accessList.Status.OwnerOf = existingList.Status.OwnerOf - } else { - // In case the MemberOf/OwnerOf fields were manually changed, set to empty. - accessList.Status.MemberOf = []string{} - accessList.Status.OwnerOf = []string{} + existingAccessList, err = a.service.GetResource(ctx, accessList.GetName()) + if op == opTypeUpsert && trace.IsNotFound(err) { + // Not having already existing access_list in the backend is ok in case of + // upsert. + } else if err != nil { + return trace.Wrap(err) } + preserveAccessListFields(existingAccessList, accessList) listMembers, err := a.memberService.WithPrefix(accessList.GetName()).GetResources(ctx) if err != nil { @@ -245,8 +240,8 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces reconcileOwners := func() error { // Create map to store owners for efficient lookup originalOwnersMap := make(map[string]struct{}) - if existingList != nil { - for _, owner := range existingList.Spec.Owners { + if existingAccessList != nil { + for _, owner := range existingAccessList.Spec.Owners { if owner.MembershipKind == accesslist.MembershipKindList { originalOwnersMap[owner.Name] = struct{}{} } @@ -552,14 +547,17 @@ func (a *AccessListService) UpsertAccessListMember(ctx context.Context, member * } upserted, err = a.memberService.WithPrefix(member.Spec.AccessList).UpsertResource(ctx, member) + if err != nil { + return trace.Wrap(err) + } - if err == nil && member.Spec.MembershipKind == accesslist.MembershipKindList { + if member.Spec.MembershipKind == accesslist.MembershipKindList { if err := a.updateAccessListMemberOf(ctx, member.Spec.AccessList, member.Spec.Name, true); err != nil { return trace.Wrap(err) } } - return trace.Wrap(err) + return nil } err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { @@ -679,18 +677,11 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc } validateAccessList := func() error { - existingList, err := a.service.GetResource(ctx, accessList.GetName()) + existingAccessList, err := a.service.GetResource(ctx, accessList.GetName()) if err != nil && !trace.IsNotFound(err) { return trace.Wrap(err) } - if existingList != nil { - accessList.Status.MemberOf = existingList.Status.MemberOf - accessList.Status.OwnerOf = existingList.Status.OwnerOf - } else { - // In case the MemberOf/OwnerOf fields were manually changed, set to empty. - accessList.Status.MemberOf = []string{} - accessList.Status.OwnerOf = []string{} - } + preserveAccessListFields(existingAccessList, accessList) if err := accesslists.ValidateAccessListWithMembers(ctx, accessList, membersIn, &accessListAndMembersGetter{a.service, a.memberService}); err != nil { return trace.Wrap(err) @@ -1045,6 +1036,18 @@ func (a *AccessListService) VerifyAccessListCreateLimit(ctx context.Context, tar return trace.AccessDenied("%s", limitReachedMessage) } +func preserveAccessListFields(existingAccessList, accessList *accesslist.AccessList) { + if existingAccessList != nil { + // Set MemberOf/OwnerOf to the existing values to prevent them from being updated. + accessList.Status.MemberOf = existingAccessList.Status.MemberOf + accessList.Status.OwnerOf = existingAccessList.Status.OwnerOf + } else { + // For newly created AccessList make sure MemberOf/OwnerOf are empty. + accessList.Status.MemberOf = []string{} + accessList.Status.OwnerOf = []string{} + } +} + // keepAWSIdentityCenterLabels preserves member labels if // it originated from AWS Identity Center plugin. // The Web UI does not currently preserve metadata labels so this function should be called diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index 4258b388a1760..f17df9287b707 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -21,12 +21,14 @@ package local import ( "context" "fmt" + "slices" "strconv" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -1053,7 +1055,8 @@ func TestAccessListRequiresEqual(t *testing.T) { } type newAccessListOptions struct { - typ accesslist.Type + typ accesslist.Type + owners []accesslist.Owner } type newAccessListOpt func(*newAccessListOptions) @@ -1064,10 +1067,27 @@ func withType(typ accesslist.Type) newAccessListOpt { } } +func withOwners(owners []accesslist.Owner) newAccessListOpt { + return func(o *newAccessListOptions) { + o.owners = owners + } +} + func newAccessList(t *testing.T, name string, clock clockwork.Clock, opts ...newAccessListOpt) *accesslist.AccessList { t.Helper() - options := newAccessListOptions{} + options := newAccessListOptions{ + owners: []accesslist.Owner{ + { + Name: "test-user1", + Description: "test user 1", + }, + { + Name: "test-user2", + Description: "test user 2", + }, + }, + } for _, o := range opts { o(&options) } @@ -1083,19 +1103,10 @@ func newAccessList(t *testing.T, name string, clock clockwork.Clock, opts ...new }, accesslist.Spec{ Type: options.typ, - Title: "title", + Title: name + " title", Description: "test access list", - Owners: []accesslist.Owner{ - { - Name: "test-user1", - Description: "test user 1", - }, - { - Name: "test-user2", - Description: "test user 2", - }, - }, - Audit: audit, + Owners: options.owners, + Audit: audit, MembershipRequires: accesslist.Requires{ Roles: []string{"mrole1", "mrole2"}, Traits: map[string][]string{ @@ -1124,20 +1135,47 @@ func newAccessList(t *testing.T, name string, clock clockwork.Clock, opts ...new return accessList } -func newAccessListMember(t *testing.T, accessList, name string) *accesslist.AccessListMember { +func createAccessList(t *testing.T, service *AccessListService, name string, clock clockwork.Clock, opts ...newAccessListOpt) *accesslist.AccessList { t.Helper() + ctx := context.Background() + accessList := newAccessList(t, name, clock, opts...) + upserted, err := service.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + return upserted +} + +type accessListMemberOptions struct { + membershipKind string +} + +type accessListMemberOpt func(*accessListMemberOptions) + +func withMembershipKind(membershipKind string) accessListMemberOpt { + return func(o *accessListMemberOptions) { + o.membershipKind = membershipKind + } +} + +func newAccessListMember(t *testing.T, accessList, name string, opts ...accessListMemberOpt) *accesslist.AccessListMember { + t.Helper() + + options := accessListMemberOptions{} + for _, o := range opts { + o(&options) + } member, err := accesslist.NewAccessListMember( header.Metadata{ Name: name, }, accesslist.AccessListMemberSpec{ - AccessList: accessList, - Name: name, - Joined: time.Now(), - Expires: time.Now().Add(time.Hour * 24), - Reason: "a reason", - AddedBy: "dummy", + AccessList: accessList, + Name: name, + Joined: time.Now(), + Expires: time.Now().Add(time.Hour * 24), + Reason: "a reason", + AddedBy: "dummy", + MembershipKind: options.membershipKind, }, ) require.NoError(t, err) @@ -1305,6 +1343,179 @@ func TestAccessListService_ListAllAccessListReviews(t *testing.T) { )) } +func TestAccessListService_Status_OwnerOf(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service := newAccessListService(t, mem, clock, true /* igsEnabled */) + + ownersAccessList := createAccessList(t, service, "test-owners-acl-"+uuid.NewString(), clock) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), nil) + + accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock, + withOwners([]accesslist.Owner{ + { + Name: ownersAccessList.GetName(), + MembershipKind: accesslist.MembershipKindList, + }, + }), + ) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), []string{accessList.GetName()}) + + ownersAccessList, _, err = service.UpsertAccessListWithMembers(ctx, ownersAccessList, nil) + require.NoError(t, err) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), []string{accessList.GetName()}) + + ownersAccessList, err = service.UpsertAccessList(ctx, ownersAccessList) + require.NoError(t, err) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), []string{accessList.GetName()}) + + ownersAccessList, err = service.UpdateAccessList(ctx, ownersAccessList) + require.NoError(t, err) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), []string{accessList.GetName()}) + + err = service.DeleteAccessList(ctx, accessList.GetName()) + require.NoError(t, err) + requireStatusOwnerOf(t, service, ownersAccessList.GetName(), nil) +} + +func TestAccessListService_Status_MemberOf(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service := newAccessListService(t, mem, clock, true /* igsEnabled */) + + t.Run("creation for UpsertAccessListMember", func(t *testing.T) { + accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock) + nestedAccessList := createAccessList(t, service, "test-nested-acl-"+uuid.NewString(), clock) + + _, err = service.UpsertAccessListMember(ctx, newAccessListMember(t, + accessList.GetName(), + nestedAccessList.GetName(), + withMembershipKind(accesslist.MembershipKindList), + )) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + err = service.DeleteAccessListMember(ctx, accessList.GetName(), nestedAccessList.GetName()) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), nil) + }) + + t.Run("creation for UpsertAccessListWithMembers", func(t *testing.T) { + accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock) + nestedAccessList := createAccessList(t, service, "test-nested-acl-"+uuid.NewString(), clock) + + _, _, err = service.UpsertAccessListWithMembers( + ctx, + accessList, + []*accesslist.AccessListMember{ + newAccessListMember(t, + accessList.GetName(), + nestedAccessList.GetName(), + withMembershipKind(accesslist.MembershipKindList), + ), + }, + ) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + _, _, err = service.UpsertAccessListWithMembers( + ctx, + accessList, + []*accesslist.AccessListMember{ + // delete the member + }, + ) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), nil) + }) + + t.Run("member updates and upserts do not affect MemberOf", func(t *testing.T) { + accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock) + nestedAccessList := createAccessList(t, service, "test-nested-acl-"+uuid.NewString(), clock) + + member, err := service.UpsertAccessListMember(ctx, newAccessListMember(t, + accessList.GetName(), + nestedAccessList.GetName(), + withMembershipKind(accesslist.MembershipKindList), + )) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + updatedMember, err := service.UpdateAccessListMember(ctx, member) + require.NoError(t, err) + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + _, err = service.UpsertAccessListMember(ctx, updatedMember) + require.NoError(t, err) + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + }) + + t.Run("member access list updates and upserts do not affect its MemberOf", func(t *testing.T) { + accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock) + nestedAccessList := createAccessList(t, service, "test-nested-acl-"+uuid.NewString(), clock) + + _, err = service.UpsertAccessListMember(ctx, newAccessListMember(t, + accessList.GetName(), + nestedAccessList.GetName(), + withMembershipKind(accesslist.MembershipKindList), + )) + require.NoError(t, err) + + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + nestedAccessList, _, err = service.UpsertAccessListWithMembers(ctx, nestedAccessList, nil) + require.NoError(t, err) + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + nestedAccessList, err = service.UpdateAccessList(ctx, nestedAccessList) + require.NoError(t, err) + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + + nestedAccessList, err = service.UpsertAccessList(ctx, nestedAccessList) + require.NoError(t, err) + requireStatusMemberOf(t, service, nestedAccessList.GetName(), []string{accessList.GetName()}) + }) +} + +func requireStatusOwnerOf(t *testing.T, service *AccessListService, accessListName string, ownerOf []string) { + t.Helper() + ctx := context.Background() + accessList, err := service.GetAccessList(ctx, accessListName) + require.NoError(t, err) + slices.Sort(ownerOf) + slices.Sort(accessList.Status.OwnerOf) + require.ElementsMatch(t, ownerOf, accessList.Status.OwnerOf) +} + +func requireStatusMemberOf(t *testing.T, service *AccessListService, accessListName string, memberOf []string) { + t.Helper() + ctx := context.Background() + accessList, err := service.GetAccessList(ctx, accessListName) + require.NoError(t, err) + slices.Sort(memberOf) + slices.Sort(accessList.Status.MemberOf) + require.ElementsMatch(t, memberOf, accessList.Status.MemberOf) +} + func newAccessListService(t *testing.T, mem *memory.Memory, clock clockwork.Clock, igsEnabled bool) *AccessListService { t.Helper()