diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index 3f6ffaada787c..9b3c2e075ef65 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -278,13 +278,7 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces return accesslists.ValidateAccessListWithMembers(ctx, existingAccessList, accessList, listMembers, &accessListAndMembersGetter{a.service, a.memberService}) } - updateAccessList := func() error { - var err error - upserted, err = opFn(ctx, accessList) - return trace.Wrap(err) - } - - reconcileOwners := func() error { + reconcileOldOwners := func() error { currentOwnersMap := make(map[string]struct{}) for _, owner := range accessList.Spec.Owners { if owner.MembershipKind == accesslist.MembershipKindList { @@ -292,13 +286,6 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces } } - // update references for new owners - for ownerName := range currentOwnersMap { - if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), ownerName, true); err != nil { - return trace.Wrap(err) - } - } - // update references for old owners if existingAccessList != nil { for _, owner := range existingAccessList.Spec.Owners { @@ -318,6 +305,23 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces return nil } + updateAccessList := func() error { + var err error + upserted, err = opFn(ctx, accessList) + return trace.Wrap(err) + } + + reconcileNewOwners := func() error { + for _, owner := range accessList.Spec.Owners { + if owner.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), owner.Name, true); err != nil { + return trace.Wrap(err) + } + } + } + return nil + } + var actions []func() error // If IGS is not enabled for this cluster we need to wrap the whole @@ -328,7 +332,12 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces actions = append(actions, func() error { return a.VerifyAccessListCreateLimit(ctx, accessList.GetName()) }) } - actions = append(actions, validateAccessList, updateAccessList, reconcileOwners) + // Note we need to reconcile the old owners (clean status.owner_of for the owner lists + // which are removed with this request) first, then update the access list and then + // reconcile the new owners (set status.owner_of of the owner lists that are added with + // this request). This is to make sure the operation doesn't escalate privileges if + // interrupted as we user status.owner_of to calculate hierarchy. + actions = append(actions, validateAccessList, reconcileOldOwners, updateAccessList, reconcileNewOwners) err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { @@ -731,8 +740,11 @@ func (a *AccessListService) writeAccessListWithMembers(ctx context.Context, acce } } + var existingAccessList *accesslist.AccessList + validateAccessList := func() error { - existingAccessList, err := a.service.GetResource(ctx, accessList.GetName()) + var err error + existingAccessList, err = a.service.GetResource(ctx, accessList.GetName()) if err != nil { // a not found error is totally legal for an upsert operation, but // fatal for an update. @@ -826,12 +838,18 @@ func (a *AccessListService) writeAccessListWithMembers(ctx context.Context, acce return nil } - reconcileOwners := func() error { - // update references for new owners - for _, owner := range accessList.Spec.Owners { - if owner.MembershipKind == accesslist.MembershipKindList { - if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), owner.Name, true); err != nil { - return trace.Wrap(err) + reconcileOldOwners := func() error { + if existingAccessList == nil { + return nil + } + for _, existingOwner := range existingAccessList.Spec.Owners { + if existingOwner.MembershipKind == accesslist.MembershipKindList { + if !slices.ContainsFunc(accessList.Spec.Owners, func(owner accesslist.Owner) bool { + return owner.Name == existingOwner.Name + }) { + if err := a.updateAccessListOwnerOf(ctx, existingAccessList.GetName(), existingOwner.Name, false); err != nil { + return trace.Wrap(err) + } } } } @@ -844,6 +862,17 @@ func (a *AccessListService) writeAccessListWithMembers(ctx context.Context, acce return trace.Wrap(err) } + reconcileNewOwners := func() error { + for _, owner := range accessList.Spec.Owners { + if owner.MembershipKind == accesslist.MembershipKindList { + if err := a.updateAccessListOwnerOf(ctx, accessList.GetName(), owner.Name, true); err != nil { + return trace.Wrap(err) + } + } + } + return nil + } + var actions []func() error // If IGS is not enabled for this cluster we need to wrap the whole update and @@ -854,7 +883,12 @@ func (a *AccessListService) writeAccessListWithMembers(ctx context.Context, acce actions = append(actions, func() error { return a.VerifyAccessListCreateLimit(ctx, accessList.GetName()) }) } - actions = append(actions, validateAccessList, reconcileMembers, writeAccessList, reconcileOwners) + // Note we need to reconcile the old owners (clean status.owner_of for the owner lists + // which are removed with this request) first, then update the access list and then + // reconcile the new owners (set status.owner_of of the owner lists that are added with + // this request). This is to make sure the operation doesn't escalate privileges if + // interrupted as we use status.owner_of to calculate hierarchy. + actions = append(actions, validateAccessList, reconcileMembers, reconcileOldOwners, writeAccessList, reconcileNewOwners) if err := a.service.RunWhileLocked(ctx, []string{accessListResourceLockName}, 2*accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { return a.service.RunWhileLocked(ctx, lockName(accessList.GetName()), 2*accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index 508c8905b3aff..a62868cd811c3 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -2010,96 +2010,182 @@ func TestAccessListService_ListAllAccessListReviews(t *testing.T) { } func TestAccessListService_Status_OwnerOf(t *testing.T) { - ctx := context.Background() + ctx := t.Context() clock := clockwork.NewFakeClock() - mem, err := memory.New(memory.Config{ - Context: ctx, - Clock: clock, - }) - require.NoError(t, err) + testCases := []struct { + name string + newCreateFn func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) + newUpdateFn func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) + }{ + { + name: "UpdateAccessList", + newCreateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return service.UpsertAccessList + }, + newUpdateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return service.UpdateAccessList + }, + }, + { + name: "UpsertAccessList", + newCreateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return service.UpsertAccessList + }, + newUpdateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return service.UpsertAccessList + }, + }, + { + name: "UpsertAccessListWithMembers", + newCreateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + accessList, _, err := service.UpsertAccessListWithMembers(ctx, accessList, nil) + return accessList, err + } + }, + newUpdateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + accessList, _, err := service.UpsertAccessListWithMembers(ctx, accessList, nil) + return accessList, err + } + }, + }, + { + name: "UpdateAccessListAndOverwriteMembers", + newCreateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return service.UpsertAccessList + }, + newUpdateFn: func(service *AccessListService) func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + return func(ctx context.Context, accessList *accesslist.AccessList) (*accesslist.AccessList, error) { + accessList, _, err := service.UpdateAccessListAndOverwriteMembers(ctx, accessList, nil) + return accessList, err + } + }, + }, + } - service := newAccessListService(t, mem, modulestest.EnterpriseModules()) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) - ownerAccessList1 := createAccessList(t, service, "test-owners-acl-"+uuid.NewString(), clock) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), nil) - ownerAccessList2 := createAccessList(t, service, "test-owners-acl-"+uuid.NewString(), clock) - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), nil) + service := newAccessListService(t, mem, modulestest.EnterpriseModules()) - accessList := createAccessList(t, service, "test-acl-"+uuid.NewString(), clock, - withOwners([]accesslist.Owner{ - { - Name: ownerAccessList1.GetName(), - MembershipKind: accesslist.MembershipKindList, - }, - }), + testAccessListService_Status_OwnerOf_suite(t, + tt.newCreateFn(service), + tt.newUpdateFn(service), + service.GetAccessList, + ) + }) + } +} + +func testAccessListService_Status_OwnerOf_suite(t *testing.T, + createFn, updateFn func(context.Context, *accesslist.AccessList) (*accesslist.AccessList, error), + getFn testAccessListGetterFunc, +) { + t.Helper() + ctx := t.Context() + clock := clockwork.NewFakeClock() + + ownerAccessList1 := newAccessList(t, "test_owners_list_1", clock) + ownerAccessList2 := newAccessList(t, "test_owners_list_2", clock) + ownerAccessList3 := newAccessList(t, "test_owners_list_3", clock) + ownerAccessList4 := newAccessList(t, "test_owners_list_4", clock) + for _, al := range []*accesslist.AccessList{ownerAccessList1, ownerAccessList2, ownerAccessList3, ownerAccessList4} { + _, err := createFn(ctx, al) + require.NoError(t, err, "AccessList = %q", al.GetName()) + } + + accessList1 := newAccessList(t, "test_list_1", clock, + withOwners([]accesslist.Owner{{ + Name: "test_user_owner_1", + MembershipKind: accesslist.MembershipKindUser, + }}), + ) + accessList2 := newAccessList(t, "test_list_2", clock, + withOwners([]accesslist.Owner{{ + Name: "test_user_owner_2", + MembershipKind: accesslist.MembershipKindUser, + }}), ) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) - ownerAccessList1, _, err = service.UpsertAccessListWithMembers(ctx, ownerAccessList1, nil) + // Create 1 list with no list owners and 1 list with 2 list owners. + accessList1, err := createFn(ctx, accessList1) require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) - ownerAccessList1, err = service.UpsertAccessList(ctx, ownerAccessList1) + addListOwner(accessList2, ownerAccessList1) + addListOwner(accessList2, ownerAccessList2) + accessList2, err = createFn(ctx, accessList2) require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) - ownerAccessList1, err = service.UpdateAccessList(ctx, ownerAccessList1) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList2.GetName()}) + + // Add 1 owner. + addListOwner(accessList1, ownerAccessList1) + accessList1, err = updateFn(ctx, accessList1) require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) - t.Run("origin access list updates and upserts fix status.owner_of of existing list owners", func(t *testing.T) { - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList1.GetName(), accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList2.GetName()}) - err = service.updateAccessListOwnerOf(ctx, accessList.GetName(), ownerAccessList1.GetName(), false /* new - this will delete */) - require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{}) + // Add 2nd owner. + addListOwner(accessList1, ownerAccessList2) + accessList1, err = updateFn(ctx, accessList1) + require.NoError(t, err) - accessList, err = service.UpsertAccessList(ctx, accessList) - require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList1.GetName(), accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList1.GetName(), accessList2.GetName()}) - err = service.updateAccessListOwnerOf(ctx, accessList.GetName(), ownerAccessList1.GetName(), false /* new - this will delete */) - require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{}) + // Remove 1 owner. + rmListOwner(accessList1, ownerAccessList1) + accessList1, err = updateFn(ctx, accessList1) + require.NoError(t, err) - accessList, err = service.UpdateAccessList(ctx, accessList) - require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), []string{accessList.GetName()}) - }) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList1.GetName(), accessList2.GetName()}) - t.Run("when list owner is deleted during update or upsert former owners list status.owner_of is updated", func(t *testing.T) { - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), nil) + // Remove 2nd owner. + rmListOwner(accessList1, ownerAccessList2) + _, err = updateFn(ctx, accessList1) + require.NoError(t, err) - owner2 := accesslist.Owner{ - Name: ownerAccessList2.GetName(), - MembershipKind: accesslist.MembershipKindList, - } + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList2.GetName()}) - accessList.Spec.Owners = append(accessList.Spec.Owners, owner2) - accessList, err = service.UpsertAccessList(ctx, accessList) - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), []string{accessList.GetName()}) + // Remove 2 owners at a time. + rmListOwner(accessList2, ownerAccessList1) + rmListOwner(accessList2, ownerAccessList2) + accessList2, err = updateFn(ctx, accessList2) + require.NoError(t, err) - accessList.Spec.Owners = slices.DeleteFunc(accessList.Spec.Owners, func(o accesslist.Owner) bool { - return o.Name == owner2.Name - }) - accessList, err = service.UpsertAccessList(ctx, accessList) - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), nil) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{}) - accessList.Spec.Owners = append(accessList.Spec.Owners, owner2) - accessList, err = service.UpdateAccessList(ctx, accessList) - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), []string{accessList.GetName()}) + // Add 2 owners at a time. + addListOwner(accessList2, ownerAccessList1) + addListOwner(accessList2, ownerAccessList2) + accessList2, err = updateFn(ctx, accessList2) + require.NoError(t, err) - accessList.Spec.Owners = slices.DeleteFunc(accessList.Spec.Owners, func(o accesslist.Owner) bool { - return o.Name == owner2.Name - }) - accessList, err = service.UpdateAccessList(ctx, accessList) - requireStatusOwnerOf(t, service, ownerAccessList2.GetName(), nil) - }) + requireStatusOwnerOf(t, getFn, ownerAccessList1.GetName(), []string{accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList2.GetName(), []string{accessList2.GetName()}) - err = service.DeleteAccessList(ctx, accessList.GetName()) + // Do 2x2 swap + rmListOwner(accessList2, ownerAccessList1) + rmListOwner(accessList2, ownerAccessList2) + addListOwner(accessList2, ownerAccessList3) + addListOwner(accessList2, ownerAccessList4) + accessList2, err = updateFn(ctx, accessList2) require.NoError(t, err) - requireStatusOwnerOf(t, service, ownerAccessList1.GetName(), nil) + + requireStatusOwnerOf(t, getFn, ownerAccessList3.GetName(), []string{accessList2.GetName()}) + requireStatusOwnerOf(t, getFn, ownerAccessList4.GetName(), []string{accessList2.GetName()}) } func TestAccessListService_Status_MemberOf(t *testing.T) { @@ -2457,7 +2543,17 @@ func TestAccessListService_EnsureNestedAccessListStatuses(t *testing.T) { requireStatusMemberOf(t, service, a5, []string{ghost, a1}) } -func requireStatusOwnerOf(t *testing.T, service *AccessListService, accessListName string, ownerOf []string) { +type testAccessListGetter interface { + GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) +} + +type testAccessListGetterFunc func(ctx context.Context, name string) (*accesslist.AccessList, error) + +func (fn testAccessListGetterFunc) GetAccessList(ctx context.Context, name string) (*accesslist.AccessList, error) { + return fn(ctx, name) +} + +func requireStatusOwnerOf(t *testing.T, service testAccessListGetter, accessListName string, ownerOf []string) { t.Helper() ctx := context.Background() accessList, err := service.GetAccessList(ctx, accessListName) @@ -2467,7 +2563,7 @@ func requireStatusOwnerOf(t *testing.T, service *AccessListService, accessListNa require.ElementsMatch(t, ownerOf, accessList.Status.OwnerOf) } -func requireStatusMemberOf(t *testing.T, service *AccessListService, accessListName string, memberOf []string) { +func requireStatusMemberOf(t *testing.T, service testAccessListGetter, accessListName string, memberOf []string) { t.Helper() ctx := context.Background() accessList, err := service.GetAccessList(ctx, accessListName) @@ -2844,3 +2940,16 @@ func TestAccessListDeletePrevention_MissingReferences(t *testing.T) { }, ) } + +func addListOwner(accessList, owner *accesslist.AccessList) { + accessList.Spec.Owners = append(accessList.Spec.Owners, accesslist.Owner{ + Name: owner.GetName(), + MembershipKind: accesslist.MembershipKindList, + }) +} + +func rmListOwner(accessList, owner *accesslist.AccessList) { + accessList.Spec.Owners = slices.DeleteFunc(accessList.Spec.Owners, func(o accesslist.Owner) bool { + return o.MembershipKind == accesslist.MembershipKindList && o.Name == owner.GetName() + }) +}