From 3c05a072273c4febe5aeebbb503f8db059502bc2 Mon Sep 17 00:00:00 2001 From: Pawel Kopiczko Date: Mon, 29 Sep 2025 17:19:19 +0100 Subject: [PATCH] AccessRequest: Stop using Expires as MaxDuration --- api/types/access_request.go | 10 +++++----- lib/services/access_request.go | 20 +++++++------------- lib/services/access_request_test.go | 27 +++++++++++++-------------- lib/services/local/dynamic_access.go | 2 +- 4 files changed, 26 insertions(+), 33 deletions(-) diff --git a/api/types/access_request.go b/api/types/access_request.go index 152126bbb421d..963a6700ebcb2 100644 --- a/api/types/access_request.go +++ b/api/types/access_request.go @@ -965,20 +965,20 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti } // ValidateAssumeStartTime returns error if start time is in an invalid range. -func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error { +func ValidateAssumeStartTime(assumeStartTime time.Time, maxDuration time.Time, creationTime time.Time) error { // Guard against requesting a start time before the request creation time. if assumeStartTime.Before(creationTime) { return trace.BadParameter("assume start time has to be after %v", creationTime.Format(time.RFC3339)) } // Guard against requesting a start time after access expiry. - if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) { - return trace.BadParameter("assume start time must be prior to access expiry time at %v", - accessExpiry.Format(time.RFC3339)) + if assumeStartTime.After(maxDuration) || assumeStartTime.Equal(maxDuration) { + return trace.BadParameter("assume start time must be prior to max duration time at %v", + maxDuration.Format(time.RFC3339)) } // Access expiry can be greater than constants.MaxAssumeStartDuration, but start time // should be on or before constants.MaxAssumeStartDuration. maxAssumableStartTime := creationTime.Add(constants.MaxAssumeStartDuration) - if maxAssumableStartTime.Before(accessExpiry) && assumeStartTime.After(maxAssumableStartTime) { + if maxAssumableStartTime.Before(maxDuration) && assumeStartTime.After(maxAssumableStartTime) { return trace.BadParameter("assume start time is too far in the future, latest time allowed is %v", maxAssumableStartTime.Format(time.RFC3339)) } diff --git a/lib/services/access_request.go b/lib/services/access_request.go index 4ed77ca42055d..58bfd664cdac9 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -490,7 +490,7 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U req.SetReviews(append(req.GetReviews(), rev)) if rev.AssumeStartTime != nil { - if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil { + if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetMaxDuration(), req.GetCreationTime()); err != nil { return trace.Wrap(err) } req.SetAssumeStartTime(*rev.AssumeStartTime) @@ -514,7 +514,7 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U req.SetPromotedAccessListName(rev.GetAccessListName()) req.SetPromotedAccessListTitle(rev.GetAccessListTitle()) } - req.SetExpiry(req.GetAccessExpiry()) + req.SetExpiry(req.GetMaxDuration()) return nil } @@ -1364,15 +1364,9 @@ func (m *RequestValidator) validate(ctx context.Context, req types.AccessRequest maxAccessDuration = sessionTTL } - // This is the final adjusted access expiry where both max duration - // and session TTL were taken into consideration. - accessExpiry := now.Add(maxAccessDuration) - // Adjusted max access duration is equal to the access expiry time. - req.SetMaxDuration(accessExpiry) - - // Setting access expiry before calling `calculatePendingRequestTTL` - // matters since the func relies on this adjusted expiry. - req.SetAccessExpiry(accessExpiry) + // This is the final adjusted max duration where both max duration and session TTL + // were taken into consideration. + req.SetMaxDuration(now.Add(maxAccessDuration)) // Calculate the expiration time of the Access Request (how long it // will await approval). @@ -1384,7 +1378,7 @@ func (m *RequestValidator) validate(ctx context.Context, req types.AccessRequest if req.GetAssumeStartTime() != nil { assumeStartTime := *req.GetAssumeStartTime() - if err := types.ValidateAssumeStartTime(assumeStartTime, accessExpiry, req.GetCreationTime()); err != nil { + if err := types.ValidateAssumeStartTime(assumeStartTime, req.GetMaxDuration(), req.GetCreationTime()); err != nil { return trace.Wrap(err) } } @@ -1487,7 +1481,7 @@ func (m *RequestValidator) maxDurationForRole(roleName string) time.Duration { // approval). request TTL is capped to the smaller value between the const requestTTL and the // access request access expiry. func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest, now time.Time) (time.Duration, error) { - accessExpiryTTL := r.GetAccessExpiry().Sub(now) + accessExpiryTTL := r.GetMaxDuration().Sub(now) // If no expiration provided, use default. expiry := r.Expiry() diff --git a/lib/services/access_request_test.go b/lib/services/access_request_test.go index fbadf3a8248ce..da6734a82a011 100644 --- a/lib/services/access_request_test.go +++ b/lib/services/access_request_test.go @@ -708,7 +708,7 @@ func TestReviewThresholds(t *testing.T) { propose: approve, assumeStartTime: clock.Now().UTC().Add(10000 * time.Hour), errCheck: func(tt require.TestingT, err error, i ...any) { - require.ErrorContains(tt, err, "assume start time must be prior to access expiry time", i...) + require.ErrorContains(tt, err, "assume start time must be prior to max duration time", i...) }, }, }, @@ -2128,8 +2128,8 @@ func TestCalculatePendingRequestTTL(t *testing.T) { tests := []struct { desc string - // accessExpiryTTL == max access duration. - accessExpiryTTL time.Duration + // maxDuration == max access duration. + maxDuration time.Duration // when the access request expires in the PENDING state. requestPendingExpiryTTL time.Time assertion require.ErrorAssertionFunc @@ -2137,52 +2137,52 @@ func TestCalculatePendingRequestTTL(t *testing.T) { }{ { desc: "valid: requested ttl < access expiry", - accessExpiryTTL: requestTTL - (3 * day), + maxDuration: requestTTL - (3 * day), requestPendingExpiryTTL: now.Add(requestTTL - (4 * day)), expectedDuration: requestTTL - (4 * day), assertion: require.NoError, }, { desc: "valid: requested ttl == access expiry", - accessExpiryTTL: requestTTL - (3 * day), + maxDuration: requestTTL - (3 * day), requestPendingExpiryTTL: now.Add(requestTTL - (3 * day)), expectedDuration: requestTTL - (3 * day), assertion: require.NoError, }, { desc: "valid: requested ttl == default request ttl", - accessExpiryTTL: requestTTL, + maxDuration: requestTTL, requestPendingExpiryTTL: now.Add(requestTTL), expectedDuration: requestTTL, assertion: require.NoError, }, { desc: "valid: no TTL request defaults to the const requestTTL if access expiry is larger", - accessExpiryTTL: requestTTL + (3 * day), + maxDuration: requestTTL + (3 * day), expectedDuration: requestTTL, assertion: require.NoError, }, { desc: "valid: no TTL request defaults to accessExpiry if const requestTTL is larger", - accessExpiryTTL: requestTTL - (3 * day), + maxDuration: requestTTL - (3 * day), expectedDuration: requestTTL - (3 * day), assertion: require.NoError, }, { desc: "invalid: requested ttl > access expiry", - accessExpiryTTL: requestTTL - (3 * day), + maxDuration: requestTTL - (3 * day), requestPendingExpiryTTL: now.Add(requestTTL - (2 * day)), assertion: require.Error, }, { desc: "invalid: requested ttl > default request TTL", - accessExpiryTTL: requestTTL + (1 * day), + maxDuration: requestTTL + (1 * day), requestPendingExpiryTTL: now.Add(requestTTL + (1 * day)), assertion: require.Error, }, { desc: "invalid: requested ttl < now", - accessExpiryTTL: requestTTL - (3 * day), + maxDuration: requestTTL - (3 * day), requestPendingExpiryTTL: now.Add(-(3 * day)), assertion: require.Error, }, @@ -2213,7 +2213,7 @@ func TestCalculatePendingRequestTTL(t *testing.T) { request, err := types.NewAccessRequest("some-id", "foo", "bar") require.NoError(t, err) request.SetExpiry(tt.requestPendingExpiryTTL) - request.SetAccessExpiry(now.Add(tt.accessExpiryTTL)) + request.SetMaxDuration(now.Add(tt.maxDuration)) ttl, err := validator.calculatePendingRequestTTL(request, now) tt.assertion(t, err) @@ -3129,7 +3129,6 @@ func TestValidate_RequestedMaxDuration(t *testing.T) { err = validator.validate(context.Background(), req, identity) require.NoError(t, err) - require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetAccessExpiry()) require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetMaxDuration()) require.Equal(t, now.Add(tt.expectedSessionTTL), req.GetSessionTLL()) require.Equal(t, now.Add(tt.expectedPendingTTL), req.Expiry()) @@ -3180,7 +3179,7 @@ func TestValidate_RequestedPendingTTLAndMaxDuration(t *testing.T) { err = validator.validate(context.Background(), req, identity) require.NoError(t, err) - require.Equal(t, now.Add(requestedMaxDuration), req.GetAccessExpiry()) + require.Equal(t, time.Time{}, req.GetAccessExpiry()) require.Equal(t, now.Add(requestedMaxDuration), req.GetMaxDuration()) require.Equal(t, now.Add(defaultSessionTTL), req.GetSessionTLL()) require.Equal(t, now.Add(requestedPendingTTL), req.Expiry()) diff --git a/lib/services/local/dynamic_access.go b/lib/services/local/dynamic_access.go index c94b897fbd256..13071846be76f 100644 --- a/lib/services/local/dynamic_access.go +++ b/lib/services/local/dynamic_access.go @@ -123,7 +123,7 @@ func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, params } if params.AssumeStartTime != nil { - if err := types.ValidateAssumeStartTime(*params.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil { + if err := types.ValidateAssumeStartTime(*params.AssumeStartTime, req.GetMaxDuration(), req.GetCreationTime()); err != nil { return nil, trace.Wrap(err) } req.SetAssumeStartTime(*params.AssumeStartTime)