diff --git a/api/client/accesslist/accesslist.go b/api/client/accesslist/accesslist.go index e3e05c6a2ad02..0499a2e87a725 100644 --- a/api/client/accesslist/accesslist.go +++ b/api/client/accesslist/accesslist.go @@ -16,6 +16,7 @@ package accesslist import ( "context" + "time" "github.com/gravitational/trace" @@ -261,15 +262,15 @@ func (c *Client) ListAccessListReviews(ctx context.Context, accessList string, p } // CreateAccessListReview will create a new review for an access list. -func (c *Client) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, error) { +func (c *Client) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, time.Time, error) { resp, err := c.grpcClient.CreateAccessListReview(ctx, &accesslistv1.CreateAccessListReviewRequest{ Review: conv.ToReviewProto(review), }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } review.SetName(resp.ReviewName) - return review, nil + return review, resp.NextAuditDate.AsTime(), nil } // DeleteAccessListReview will delete an access list review from the backend. diff --git a/lib/services/access_list.go b/lib/services/access_list.go index d1b7f7650ad90..73689e8b1ba8f 100644 --- a/lib/services/access_list.go +++ b/lib/services/access_list.go @@ -307,7 +307,7 @@ type AccessListReviews interface { ListAccessListReviews(ctx context.Context, accessList string, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) // CreateAccessListReview will create a new review for an access list. - CreateAccessListReview(ctx context.Context, review *accesslist.Review) (updatedReview *accesslist.Review, err error) + CreateAccessListReview(ctx context.Context, review *accesslist.Review) (updatedReview *accesslist.Review, nextReviewDate time.Time, err error) // DeleteAccessListReview will delete an access list review from the backend. DeleteAccessListReview(ctx context.Context, accessListName, reviewName string) error diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index b89edeeae7c5d..78be4e416333a 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -343,7 +343,7 @@ func (a *AccessListService) ListAccessListReviews(ctx context.Context, accessLis } // CreateAccessListReview will create a new review for an access list. -func (a *AccessListService) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, error) { +func (a *AccessListService) CreateAccessListReview(ctx context.Context, review *accesslist.Review) (*accesslist.Review, time.Time, error) { reviewName := uuid.New().String() createdReview, err := accesslist.NewReview(header.Metadata{ Name: reviewName, @@ -354,9 +354,11 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * Changes: review.Spec.Changes, }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } + var nextAuditDate time.Time + err = a.service.RunWhileLocked(ctx, lockName(review.Spec.AccessList), accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { accessList, err := a.service.GetResource(ctx, review.Spec.AccessList) if err != nil { @@ -391,7 +393,8 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * return trace.Wrap(err) } - accessList.Spec.Audit.NextAuditDate = services.SelectNextReviewDate(accessList) + nextAuditDate = services.SelectNextReviewDate(accessList) + accessList.Spec.Audit.NextAuditDate = nextAuditDate for _, removedMember := range review.Spec.Changes.RemovedMembers { if err := a.memberService.WithPrefix(review.Spec.AccessList).DeleteResource(ctx, removedMember); err != nil { @@ -406,9 +409,10 @@ func (a *AccessListService) CreateAccessListReview(ctx context.Context, review * return nil }) if err != nil { - return nil, trace.Wrap(err) + return nil, time.Time{}, trace.Wrap(err) } - return createdReview, nil + + return createdReview, nextAuditDate, nil } // accessListRequiresEqual returns true if two access lists are equal. diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index a786c2d8b93eb..78768a4f839b1 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -473,9 +473,10 @@ func TestAccessListReviewCRUD(t *testing.T) { accessList2Review1.Spec.Changes.RemovedMembers = nil accessList2Review1.Spec.Changes.ReviewFrequencyChanged = 0 accessList2Review1.Spec.Changes.ReviewDayOfMonthChanged = 0 + var nextReviewDate time.Time // Add access list review. - accessList1Review1, err = service.CreateAccessListReview(ctx, accessList1Review1) + accessList1Review1, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList1Review1) require.NoError(t, err) // Verify changes to access list. @@ -488,6 +489,8 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(*(accessList1Review1.Spec.Changes.MembershipRequirementsChanged), accessList1Updated.Spec.MembershipRequires)) require.Equal(t, accessList1Review1.Spec.Changes.ReviewFrequencyChanged, accessList1Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList1Review1.Spec.Changes.ReviewDayOfMonthChanged, accessList1Updated.Spec.Audit.Recurrence.DayOfMonth) + // The Correct value is returned through the API. + require.Equal(t, accessList1Updated.Spec.Audit.NextAuditDate, nextReviewDate) _, err = service.GetAccessListMember(ctx, accessList1.GetName(), accessList1Member1.GetName()) require.True(t, trace.IsNotFound(err)) @@ -495,7 +498,7 @@ func TestAccessListReviewCRUD(t *testing.T) { require.True(t, trace.IsNotFound(err)) // Add another review - accessList1Review2, err = service.CreateAccessListReview(ctx, accessList1Review2) + accessList1Review2, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList1Review2) require.NoError(t, err) // Verify changes to the access list again. @@ -515,9 +518,10 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(*(accessList1Review1.Spec.Changes.MembershipRequirementsChanged), accessList1Updated.Spec.MembershipRequires)) require.Equal(t, accessList1Review1.Spec.Changes.ReviewFrequencyChanged, accessList1Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList1Review1.Spec.Changes.ReviewDayOfMonthChanged, accessList1Updated.Spec.Audit.Recurrence.DayOfMonth) + require.Equal(t, accessList1Updated.Spec.Audit.NextAuditDate, nextReviewDate) // Review that doesn't change anything - accessList2Review1, err = service.CreateAccessListReview(ctx, accessList2Review1) + accessList2Review1, nextReviewDate, err = service.CreateAccessListReview(ctx, accessList2Review1) require.NoError(t, err) accessList2Updated, err := service.GetAccessList(ctx, accessList2.GetName()) @@ -529,6 +533,7 @@ func TestAccessListReviewCRUD(t *testing.T) { require.Empty(t, cmp.Diff(accessList2.Spec.MembershipRequires, accessList2Updated.Spec.MembershipRequires)) require.Equal(t, accessList2.Spec.Audit.Recurrence.Frequency, accessList2Updated.Spec.Audit.Recurrence.Frequency) require.Equal(t, accessList2.Spec.Audit.Recurrence.DayOfMonth, accessList2Updated.Spec.Audit.Recurrence.DayOfMonth) + require.Equal(t, accessList2Updated.Spec.Audit.NextAuditDate, nextReviewDate) _, err = service.GetAccessListMember(ctx, accessList2.GetName(), accessList2Member1.GetName()) require.NoError(t, err)