diff --git a/api/client/accesslist/accesslist.go b/api/client/accesslist/accesslist.go index 92e4b3bf0ca02..7874706ada036 100644 --- a/api/client/accesslist/accesslist.go +++ b/api/client/accesslist/accesslist.go @@ -16,6 +16,7 @@ package accesslist import ( "context" + "time" "github.com/gravitational/trace" @@ -242,15 +243,15 @@ func (c *Client) ListAccessListReviews(ctx context.Context, accessList string, p } // CreateAccessListReview will create a new review for an access list. -func (c *Client) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, error) { +func (c *Client) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, time.Time, error) { resp, err := c.grpcClient.CreateAccessListReview(ctx, &accesslistv1.CreateAccessListReviewRequest{ Review: conv.ToReviewProto(review), }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } review.SetName(resp.ReviewName) - return review, nil + return review, resp.NextAuditDate.AsTime(), nil } // DeleteAccessListReview will delete an access list review from the backend. diff --git a/lib/services/access_list.go b/lib/services/access_list.go index 305e73396efbe..b4eb7b1482fed 100644 --- a/lib/services/access_list.go +++ b/lib/services/access_list.go @@ -297,7 +297,7 @@ type AccessListReviews interface { ListAccessListReviews(ctx context.Context, accessList string, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) // CreateAccessListReview will create a new review for an access list. - CreateAccessListReview(ctx context.Context, review *accesslist.Review) (updatedReview *accesslist.Review, err error) + CreateAccessListReview(ctx context.Context, review *accesslist.Review) (updatedReview *accesslist.Review, nextReviewDate time.Time, err error) // DeleteAccessListReview will delete an access list review from the backend. DeleteAccessListReview(ctx context.Context, accessListName, reviewName string) error diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index a7662244fbd37..104e44d484fd7 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -338,7 +338,7 @@ func (a *AccessListService) ListAccessListReviews(ctx context.Context, accessLis } // CreateAccessListReview will create a new review for an access list. -func (a *AccessListService) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, error) { +func (a *AccessListService) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, time.Time, error) { reviewName := uuid.New().String() createdReview, err := accesslist.NewReview(header.Metadata{ Name: reviewName, @@ -349,9 +349,11 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * Changes: review.Spec.Changes, }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } + var nextAuditDate time.Time + err = a.service.RunWhileLocked(ctx, lockName(review.Spec.AccessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { accessList, err := a.service.GetResource(ctx, review.Spec.AccessList) if err != nil { @@ -386,7 +388,8 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * return trace.Wrap(err) } - accessList.Spec.Audit.NextAuditDate = services.SelectNextReviewDate(accessList) + nextAuditDate = services.SelectNextReviewDate(accessList) + accessList.Spec.Audit.NextAuditDate = nextAuditDate for _, removedMember := range review.Spec.Changes.RemovedMembers { if err := a.memberService.WithPrefix(review.Spec.AccessList).DeleteResource(ctx, removedMember); err != nil { @@ -401,9 +404,10 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * return nil }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } - return createdReview, nil + + return createdReview, nextAuditDate, nil } // accessListRequiresEqual returns true if two access lists are equal. diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index abd7800f6fdfd..828d6ed66f990 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -472,9 +472,10 @@ func TestAccessListReviewCRUD(t *testing.T) { accessList2Review1.Spec.Changes.RemovedMembers = nil accessList2Review1.Spec.Changes.ReviewFrequencyChanged = 0 accessList2Review1.Spec.Changes.ReviewDayOfMonthChanged = 0 + var nextReviewDate time.Time // Add access list review. - accessList1Review1, err = service.CreateAccessListReview(ctx, accessList1Review1) + accessList1Review1, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList1Review1) require.NoError(t, err) // Verify changes to access list. @@ -487,6 +488,8 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(*(accessList1Review1.Spec.Changes.MembershipRequirementsChanged), accessList1Updated.Spec.MembershipRequires)) require.Equal(t, accessList1Review1.Spec.Changes.ReviewFrequencyChanged, accessList1Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList1Review1.Spec.Changes.ReviewDayOfMonthChanged, accessList1Updated.Spec.Audit.Recurrence.DayOfMonth) + // The Correct value is returned through the API. + require.Equal(t, accessList1Updated.Spec.Audit.NextAuditDate, nextReviewDate) _, err = service.GetAccessListMember(ctx, accessList1.GetName(), accessList1Member1.GetName()) require.True(t, trace.IsNotFound(err)) @@ -494,7 +497,7 @@ func TestAccessListReviewCRUD(t *testing.T) { require.True(t, trace.IsNotFound(err)) // Add another review - accessList1Review2, err = service.CreateAccessListReview(ctx, accessList1Review2) + accessList1Review2, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList1Review2) require.NoError(t, err) // Verify changes to the access list again. @@ -514,9 +517,10 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(*(accessList1Review1.Spec.Changes.MembershipRequirementsChanged), accessList1Updated.Spec.MembershipRequires)) require.Equal(t, accessList1Review1.Spec.Changes.ReviewFrequencyChanged, accessList1Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList1Review1.Spec.Changes.ReviewDayOfMonthChanged, accessList1Updated.Spec.Audit.Recurrence.DayOfMonth) + require.Equal(t, accessList1Updated.Spec.Audit.NextAuditDate, nextReviewDate) // Review that doesn't change anything - accessList2Review1, err = service.CreateAccessListReview(ctx, accessList2Review1) + accessList2Review1, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList2Review1) require.NoError(t, err) accessList2Updated, err := service.GetAccessList(ctx, accessList2.GetName()) @@ -528,6 +532,7 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(accessList2.Spec.MembershipRequires, accessList2Updated.Spec.MembershipRequires)) require.Equal(t, accessList2.Spec.Audit.Recurrence.Frequency, accessList2Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList2.Spec.Audit.Recurrence.DayOfMonth, accessList2Updated.Spec.Audit.Recurrence.DayOfMonth) + require.Equal(t, accessList2Updated.Spec.Audit.NextAuditDate, nextReviewDate) _, err = service.GetAccessListMember(ctx, accessList2.GetName(), accessList2Member1.GetName()) require.NoError(t, err)