diff --git a/lib/services/access_request.go b/lib/services/access_request.go index 87c5497c0807f..5ae13c5081c8a 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -1187,45 +1187,50 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest req.SetSuggestedReviewers(apiutils.Deduplicate(m.SuggestedReviewers)) } - now := m.clock.Now().UTC() - - // Calculate the expiration time of the Access Request (how long it - // will await approval). - ttl, err := m.requestTTL(ctx, identity, req) - if err != nil { - return trace.Wrap(err) - } - req.SetExpiry(now.Add(ttl)) - - maxDuration, err := m.calculateMaxAccessDuration(req) + // Calculate the expiration time of the elevated certificate that will + // be issued if the Access Request is approved. + sessionTTL, err := m.sessionTTL(ctx, identity, req) if err != nil { return trace.Wrap(err) } - // Calculate the expiration time of the elevated certificate that will - // be issued if the Access Request is approved. - sessionTTL, err := m.sessionTTL(ctx, identity, req) + maxDuration, err := m.calculateMaxAccessDuration(req, sessionTTL) if err != nil { return trace.Wrap(err) } // If the maxDuration flag is set, consider it instead of only using the session TTL. + var maxAccessDuration time.Duration + now := m.clock.Now().UTC() if maxDuration > 0 { req.SetSessionTLL(now.Add(min(sessionTTL, maxDuration))) - ttl = maxDuration + maxAccessDuration = maxDuration } else { req.SetSessionTLL(now.Add(sessionTTL)) - ttl = sessionTTL + maxAccessDuration = sessionTTL } - accessTTL := now.Add(ttl) - req.SetAccessExpiry(accessTTL) + // This is the final adjusted access expiry where both max duration + // and session TTL were taken into consideration. + accessExpiry := now.Add(maxAccessDuration) // Adjusted max access duration is equal to the access expiry time. - req.SetMaxDuration(accessTTL) + req.SetMaxDuration(accessExpiry) + + // Setting access expiry before calling `calculatePendingRequesetTTL` + // matters since the func relies on this adjusted expiry. + req.SetAccessExpiry(accessExpiry) + + // Calculate the expiration time of the Access Request (how long it + // will await approval). + requestTTL, err := m.calculatePendingRequestTTL(req) + if err != nil { + return trace.Wrap(err) + } + req.SetExpiry(now.Add(requestTTL)) if req.GetAssumeStartTime() != nil { assumeStartTime := *req.GetAssumeStartTime() - if err := types.ValidateAssumeStartTime(assumeStartTime, accessTTL, req.GetCreationTime()); err != nil { + if err := types.ValidateAssumeStartTime(assumeStartTime, accessExpiry, req.GetCreationTime()); err != nil { return trace.Wrap(err) } } @@ -1237,7 +1242,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest // calculateMaxAccessDuration calculates the maximum time for the access request. // The max duration time is the minimum of the max_duration time set on the request // and the max_duration time set on the request role. -func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) (time.Duration, error) { +func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest, sessionTTL time.Duration) (time.Duration, error) { // Check if the maxDuration time is set. maxDurationTime := req.GetMaxDuration() if maxDurationTime.IsZero() { @@ -1277,16 +1282,32 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) ( } } + // minAdjDuration can end up being 0, if any role does not have + // field `max_duration` defined. + // In this case, return the smaller value between the sessionTTL + // and the requested max duration. + if minAdjDuration == 0 && maxDuration < sessionTTL { + return maxDuration, nil + } + return minAdjDuration, nil } -// requestTTL calculates the TTL of the Access Request (how long it will await -// approval). -func (m *RequestValidator) requestTTL(ctx context.Context, identity tlsca.Identity, r types.AccessRequest) (time.Duration, error) { +// calculatePendingRequestTTL calculates the TTL of the Access Request (how long it will await +// approval). request TTL is capped to the smaller value between the const requsetTTL and the +// access request access expiry. +func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest) (time.Duration, error) { + accessExpiryTTL := r.GetAccessExpiry().Sub(m.clock.Now().UTC()) + // If no expiration provided, use default. expiry := r.Expiry() if expiry.IsZero() { - expiry = m.clock.Now().UTC().Add(requestTTL) + // Guard against the default expiry being greater than access expiry. + if requestTTL < accessExpiryTTL { + expiry = m.clock.Now().UTC().Add(requestTTL) + } else { + expiry = m.clock.Now().UTC().Add(accessExpiryTTL) + } } if expiry.Before(m.clock.Now().UTC()) { @@ -1297,8 +1318,13 @@ func (m *RequestValidator) requestTTL(ctx context.Context, identity tlsca.Identi // than the maximum value allowed. Used to return a sensible error to the // user. requestedTTL := expiry.Sub(m.clock.Now().UTC()) - if !r.Expiry().IsZero() && requestedTTL > requestTTL { - return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), requestTTL.Round(time.Minute)) + if !r.Expiry().IsZero() { + if requestedTTL > requestTTL { + return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), requestTTL.Round(time.Minute)) + } + if requestedTTL > accessExpiryTTL { + return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), accessExpiryTTL.Round(time.Minute)) + } } return requestedTTL, nil diff --git a/lib/services/access_request_test.go b/lib/services/access_request_test.go index b47fed510dc5b..3c99d28114f74 100644 --- a/lib/services/access_request_test.go +++ b/lib/services/access_request_test.go @@ -1608,6 +1608,110 @@ func TestPruneRequestRoles(t *testing.T) { } } +// TestCalculatePendingRequesTTL verifies that the TTL for the Access Request is capped to the +// request's access expiry or capped to the default const requestTTL, whichever is smaller. +func TestCalculatePendingRequesTTL(t *testing.T) { + clock := clockwork.NewFakeClock() + now := clock.Now().UTC() + + tests := []struct { + desc string + // accessExpiryTTL == max access duration. + accessExpiryTTL time.Duration + // when the access request expires in the PENDING state. + requestPendingExpiryTTL time.Time + assertion require.ErrorAssertionFunc + expectedDuration time.Duration + }{ + { + desc: "valid: requested ttl < access expiry", + accessExpiryTTL: requestTTL - (3 * day), + requestPendingExpiryTTL: now.Add(requestTTL - (4 * day)), + expectedDuration: requestTTL - (4 * day), + assertion: require.NoError, + }, + { + desc: "valid: requested ttl == access expiry", + accessExpiryTTL: requestTTL - (3 * day), + requestPendingExpiryTTL: now.Add(requestTTL - (3 * day)), + expectedDuration: requestTTL - (3 * day), + assertion: require.NoError, + }, + { + desc: "valid: requested ttl == default request ttl", + accessExpiryTTL: requestTTL, + requestPendingExpiryTTL: now.Add(requestTTL), + expectedDuration: requestTTL, + assertion: require.NoError, + }, + { + desc: "valid: no TTL request defaults to the const requestTTL if access expiry is larger", + accessExpiryTTL: requestTTL + (3 * day), + expectedDuration: requestTTL, + assertion: require.NoError, + }, + { + desc: "valid: no TTL request defaults to accessExpiry if const requestTTL is larger", + accessExpiryTTL: requestTTL - (3 * day), + expectedDuration: requestTTL - (3 * day), + assertion: require.NoError, + }, + { + desc: "invalid: requested ttl > access expiry", + accessExpiryTTL: requestTTL - (3 * day), + requestPendingExpiryTTL: now.Add(requestTTL - (2 * day)), + assertion: require.Error, + }, + { + desc: "invalid: requested ttl > default request TTL", + accessExpiryTTL: requestTTL + (1 * day), + requestPendingExpiryTTL: now.Add(requestTTL + (1 * day)), + assertion: require.Error, + }, + { + desc: "invalid: requested ttl < now", + accessExpiryTTL: requestTTL - (3 * day), + requestPendingExpiryTTL: now.Add(-(3 * day)), + assertion: require.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + // Setup test user "foo" and "bar" and the mock auth server that + // will return users and roles. + uls, err := userloginstate.New(header.Metadata{ + Name: "foo", + }, userloginstate.Spec{ + Roles: []string{"bar"}, + }) + require.NoError(t, err) + + role, err := types.NewRole("bar", types.RoleSpecV6{}) + require.NoError(t, err) + + getter := &mockGetter{ + userStates: map[string]*userloginstate.UserLoginState{"foo": uls}, + roles: map[string]types.Role{"bar": role}, + } + + validator, err := NewRequestValidator(context.Background(), clock, getter, "foo", ExpandVars(true)) + require.NoError(t, err) + + request, err := types.NewAccessRequest("some-id", "foo", "bar") + require.NoError(t, err) + request.SetExpiry(tt.requestPendingExpiryTTL) + request.SetAccessExpiry(now.Add(tt.accessExpiryTTL)) + + ttl, err := validator.calculatePendingRequestTTL(request) + tt.assertion(t, err) + if err == nil { + require.Equal(t, tt.expectedDuration, ttl) + } + }) + } +} + // TestSessionTTL verifies that the TTL for elevated access gets reduced by // requested access time, lifetime of certificate, and strictest session TTL on // any role. @@ -1867,7 +1971,10 @@ func TestValidateAccessRequestClusterNames(t *testing.T) { } } -func TestMaxDuration(t *testing.T) { +// TestValidate_RequestedMaxDuration tests requested max duration +// and the default values for session and pending TTL as a result +// of requested max duration. +func TestValidate_RequestedMaxDuration(t *testing.T) { // describes a collection of roles and their conditions roleDesc := roleTestSet{ "requestedRole": { @@ -1929,6 +2036,8 @@ func TestMaxDuration(t *testing.T) { "david": {"maxDurationReqRole"}, } + defaultSessionTTL := 8 * time.Hour + g := getMockGetter(t, roleDesc, userDesc) tts := []struct { @@ -1938,101 +2047,122 @@ func TestMaxDuration(t *testing.T) { requestor string // the roles to be requested (defaults to "dictator") roles []string - // maxDuration is the requested maxDuration duration - maxDuration time.Duration + // requestedMaxDuration is the requested requestedMaxDuration duration + requestedMaxDuration time.Duration // expectedAccessDuration is the expected access duration expectedAccessDuration time.Duration // expectedSessionTTL is the expected session TTL expectedSessionTTL time.Duration + // expectedPendingTTL is the time when request expires in PENDING state + expectedPendingTTL time.Duration // DryRun is true if the request is a dry run dryRun bool }{ { - desc: "role maxDuration is respected", + desc: "role max_duration is respected and sessionTTL does not exceed the calculated max duration", requestor: "alice", - roles: []string{"requestedRole"}, - maxDuration: 7 * day, + roles: []string{"requestedRole"}, // role max_duration capped to 3 days + requestedMaxDuration: 7 * day, // ignored b/c it's > role max_duration expectedAccessDuration: 3 * day, - expectedSessionTTL: 8 * time.Hour, + expectedSessionTTL: 8 * time.Hour, // caps to defaultSessionTTL b/c it's < than the expectedAccessDuration + expectedPendingTTL: 3 * day, // caps to expectedAccessDuration b/c it's < than the const default TTL }, { - desc: "dry run allows for longer maxDuration then 7d", + desc: "role max_duration is still respected even with dry run (which requests for longer maxDuration)", requestor: "alice", - roles: []string{"requestedRole"}, - maxDuration: 10 * day, + roles: []string{"requestedRole"}, // role max_duration capped to 3 days + requestedMaxDuration: 10 * day, // ignored b/c it's > role max_duration expectedAccessDuration: 3 * day, + expectedPendingTTL: 3 * day, expectedSessionTTL: 8 * time.Hour, dryRun: true, }, { - desc: "maxDuration not set, default maxTTL (8h)", - requestor: "bob", - roles: []string{"requestedRole"}, - expectedAccessDuration: 8 * time.Hour, + desc: "role max_duration is ignored when requestedMaxDuration is not set", + requestor: "alice", + roles: []string{"requestedRole"}, // role max_duration capped to 3 days + expectedAccessDuration: 8 * time.Hour, // caps to defaultSessionTTL since requestedMaxDuration was not set + expectedPendingTTL: 8 * time.Hour, expectedSessionTTL: 8 * time.Hour, }, { - desc: "maxDuration inside request is respected", + desc: "when role max_duration is not set: default to defaultSessionTTL when requestedMaxDuration is not set", requestor: "bob", - roles: []string{"requestedRole"}, - maxDuration: 5 * time.Hour, - expectedAccessDuration: 8 * time.Hour, + roles: []string{"requestedRole"}, // role max_duration is not set (0) + expectedAccessDuration: 8 * time.Hour, // caps to defaultSessionTTL since requestedMaxDuration was not set + expectedPendingTTL: 8 * time.Hour, expectedSessionTTL: 8 * time.Hour, }, { - desc: "users with no MaxDuration are constrained by normal maxTTL logic", + desc: "when role max_duration is not set: requestedMaxDuration is respected when < defaultSessionTTL", + requestor: "bob", + roles: []string{"requestedRole"}, // role max_duration is not set (0) + requestedMaxDuration: 5 * time.Hour, + expectedAccessDuration: 5 * time.Hour, + expectedPendingTTL: 5 * time.Hour, + expectedSessionTTL: 5 * time.Hour, // capped to expectedAccessDuration because it's < defaultSessionTTL (8h) + }, + { + desc: "when role max_duration is not set: requestedMaxDuration is ignored if > defaultSessionTTL", requestor: "bob", - roles: []string{"requestedRole"}, - maxDuration: 2 * day, - expectedAccessDuration: 8 * time.Hour, + roles: []string{"requestedRole"}, // role max_duration is not set (0) + requestedMaxDuration: 10 * time.Hour, + expectedAccessDuration: 8 * time.Hour, // caps to defaultSessionTTL (8h) which is < requestedMaxDuration + expectedPendingTTL: 8 * time.Hour, expectedSessionTTL: 8 * time.Hour, }, { - desc: "maxDuration can't exceed maxTTL by default", + desc: "when role max_duration is not set: requestedMaxDuration is ignored if > role defined sesssionTTL (6h)", requestor: "bob", - roles: []string{"setMaxTTLRole"}, - maxDuration: day, - expectedAccessDuration: 6 * time.Hour, + roles: []string{"setMaxTTLRole"}, // role max_duration is not set (0), caps sessionTTL to 6 hours + requestedMaxDuration: day, + expectedAccessDuration: 6 * time.Hour, // capped to the lowest sessionTTL found in role (6h) which is < requestedMaxDuration + expectedPendingTTL: 6 * time.Hour, expectedSessionTTL: 6 * time.Hour, }, { - desc: "maxDuration is ignored if max_duration is not set in role", + desc: "when role max_duration is not set: requestedMaxDuration is respected when < role defined sessionTTL (6h)", requestor: "bob", - roles: []string{"setMaxTTLRole"}, - maxDuration: 2 * time.Hour, - expectedAccessDuration: 6 * time.Hour, - expectedSessionTTL: 6 * time.Hour, + roles: []string{"setMaxTTLRole"}, // role max_duration is not set (0), caps sessionTTL to 6 hours + requestedMaxDuration: 5 * time.Hour, + expectedAccessDuration: 5 * time.Hour, // caps to requestedMaxDuration which is < role defined sessionTTL (6h) + expectedPendingTTL: 5 * time.Hour, + expectedSessionTTL: 5 * time.Hour, }, { - desc: "maxDuration can exceed maxTTL if max_duration is set in role", + desc: "requestedMaxDuration is respected if it's < the max_duration set in role", requestor: "david", - roles: []string{"setMaxTTLRole"}, - maxDuration: day, + roles: []string{"setMaxTTLRole"}, // role max_duration capped to default MaxAccessDuration, caps sessionTTL to 6 hours + requestedMaxDuration: day, // respected because it's < default const MaxAccessDuration expectedAccessDuration: day, - expectedSessionTTL: 6 * time.Hour, + expectedPendingTTL: day, + expectedSessionTTL: 6 * time.Hour, // capped to the lowest sessionTTL found in role which is < requestedMaxDuration }, { - desc: "maxDuration shorter than maxTTL if max_duration is set in role", + desc: "expectedSessionTTL does not exceed requestedMaxDuration", requestor: "david", - roles: []string{"setMaxTTLRole"}, - maxDuration: 2 * time.Hour, + roles: []string{"setMaxTTLRole"}, // caps max_duration to default MaxAccessDuration, caps sessionTTL to 6 hours + requestedMaxDuration: 2 * time.Hour, // respected because it's < default const MaxAccessDuration expectedAccessDuration: 2 * time.Hour, - expectedSessionTTL: 2 * time.Hour, + expectedPendingTTL: 2 * time.Hour, + expectedSessionTTL: 2 * time.Hour, // capped to requestedMaxDuration because it's < role defined sessionTTL (6h) }, { - desc: "only required roles are considered for maxDuration", - requestor: "carol", - roles: []string{"requestedRole"}, - maxDuration: 5 * day, + desc: "only the assigned role that allows the requested roles are considered for maxDuration", + requestor: "carol", // has multiple roles assigned + roles: []string{"requestedRole"}, // caps max_duration to 3 days + requestedMaxDuration: 5 * day, expectedAccessDuration: 3 * day, + expectedPendingTTL: 3 * day, expectedSessionTTL: 8 * time.Hour, }, { - desc: "only required roles are considered for maxDuration #2", - requestor: "carol", - roles: []string{"requestedRole2"}, - maxDuration: 6 * day, + desc: "only the assigned role that allows the requested roles are considered for maxDuration #2", + requestor: "carol", // has multiple roles assigned + roles: []string{"requestedRole2"}, // caps max_duration to 1 day + requestedMaxDuration: 6 * day, expectedAccessDuration: day, + expectedPendingTTL: day, expectedSessionTTL: 8 * time.Hour, }, } @@ -2048,24 +2178,73 @@ func TestMaxDuration(t *testing.T) { clock := clockwork.NewFakeClock() now := clock.Now().UTC() identity := tlsca.Identity{ - Expires: now.Add(8 * time.Hour), + Expires: now.Add(defaultSessionTTL), } validator, err := NewRequestValidator(context.Background(), clock, g, tt.requestor, ExpandVars(true)) require.NoError(t, err) req.SetCreationTime(now) - req.SetMaxDuration(now.Add(tt.maxDuration)) + req.SetMaxDuration(now.Add(tt.requestedMaxDuration)) req.SetDryRun(tt.dryRun) require.NoError(t, validator.Validate(context.Background(), req, identity)) require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetAccessExpiry()) require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetMaxDuration()) require.Equal(t, now.Add(tt.expectedSessionTTL), req.GetSessionTLL()) + require.Equal(t, now.Add(tt.expectedPendingTTL), req.Expiry()) }) } } +// TestValidate_RequestedPendingTTLAndMaxDuration tests that both requested +// max duration and pending TTL is respected (given within limits). +func TestValidate_RequestedPendingTTLAndMaxDuration(t *testing.T) { + // describes a collection of roles and their conditions + roleDesc := roleTestSet{ + "requestRole": { + condition: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + Roles: []string{"requestRole"}, + MaxDuration: types.Duration(5 * day), + }, + }, + }, + } + + // describes a collection of users with various roles + userDesc := map[string][]string{ + "alice": {"requestRole"}, + } + + g := getMockGetter(t, roleDesc, userDesc) + req, err := types.NewAccessRequest("some-id", "alice", []string{"requestRole"}...) + require.NoError(t, err) + + clock := clockwork.NewFakeClock() + now := clock.Now().UTC() + defaultSessionTTL := 8 * time.Hour + identity := tlsca.Identity{ + Expires: now.Add(defaultSessionTTL), + } + + validator, err := NewRequestValidator(context.Background(), clock, g, "alice", ExpandVars(true)) + require.NoError(t, err) + + requestedMaxDuration := 4 * day + requestedPendingTTL := 2 * day + + req.SetCreationTime(now) + req.SetMaxDuration(now.Add(requestedMaxDuration)) + req.SetExpiry(now.Add(requestedPendingTTL)) + + require.NoError(t, validator.Validate(context.Background(), req, identity)) + require.Equal(t, now.Add(requestedMaxDuration), req.GetAccessExpiry()) + require.Equal(t, now.Add(requestedMaxDuration), req.GetMaxDuration()) + require.Equal(t, now.Add(defaultSessionTTL), req.GetSessionTLL()) + require.Equal(t, now.Add(requestedPendingTTL), req.Expiry()) +} + type roleTestSet map[string]struct { condition types.RoleConditions options types.RoleOptions diff --git a/tool/tsh/common/kube_test.go b/tool/tsh/common/kube_test.go index ac19165591d30..c35d380fc830c 100644 --- a/tool/tsh/common/kube_test.go +++ b/tool/tsh/common/kube_test.go @@ -37,6 +37,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -365,6 +366,7 @@ func TestKubeSelection(t *testing.T) { cfg.Kube.ResourceMatchers = []services.ResourceMatcher{{ Labels: map[string]apiutils.Strings{"*": {"*"}}, }} + cfg.Clock = clockwork.NewFakeClock() }), ) kubeBarEKS := "bar-eks-us-west-1-123456789012" diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index d8f0f3dd1a3b0..7959fc6596a36 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -47,6 +47,7 @@ import ( "github.com/ghodss/yaml" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" otlp "go.opentelemetry.io/proto/otlp/trace/v1" @@ -1808,7 +1809,11 @@ func TestSSHAccessRequest(t *testing.T) { } alice.SetTraits(traits) - rootAuth, rootProxy := makeTestServers(t, withBootstrap(requester, nodeAccessRole, connector, alice)) + rootAuth, rootProxy := makeTestServers(t, + withBootstrap(requester, nodeAccessRole, connector, alice), + withConfig(func(cfg *servicecfg.Config) { + cfg.Clock = clockwork.NewFakeClock() + })) authAddr, err := rootAuth.AuthAddr() require.NoError(t, err)