Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 44 additions & 41 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt
}

// applicableSearchAsRoles prunes the search_as_roles and only returns those
// application for the given list of resourceIDs.
// applications for the given list of resourceIDs.
func (m *RequestValidator) applicableSearchAsRoles(ctx context.Context, resourceIDs []types.ResourceID, loginHint string) ([]string, error) {
// First collect all possible search_as_roles.
// First, collect all possible search_as_roles.
var rolesToRequest []string
for _, roleName := range m.Roles.AllowSearch {
if !m.CanSearchAsRole(roleName) {
Expand Down Expand Up @@ -349,9 +349,9 @@ type thresholdFilterContext struct {
}

// reviewPermissionContext is the top-level context used to evaluate
// a user's review permissions. It is functionally identical to the
// a user's review permissions. It is functionally identical to the
// thresholdFilterContext except that it does not expose review parameters.
// this is because review permissions are used to determine which requests
// This is because review permissions are used to determine which requests
// a user is allowed to see, and therefore needs to be calculable prior
// to construction of review parameters.
type reviewPermissionContext struct {
Expand All @@ -363,7 +363,7 @@ type reviewPermissionContext struct {
// syntax errors. Used to help prevent users from accidentally writing incorrect
// predicates. This function should only be called by the auth server prior to
// storing new/updated roles. Normal role validation deliberately omits these
// checks in order to allow us to extend the available namespaces without breaking
// checks to allow us to extend the available namespaces without breaking
// backwards compatibility with older nodes/proxies (which never need to evaluate
// these predicates).
func ValidateAccessPredicates(role types.Role) error {
Expand Down Expand Up @@ -462,7 +462,7 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
req.SetAssumeStartTime(*rev.AssumeStartTime)
}

// request is still pending, so check to see if this
// the request is still pending, so check to see if this
// review introduces a state-transition.
res, err := calculateReviewBasedResolution(req)
if err != nil || res == nil {
Expand All @@ -487,7 +487,7 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
// checkReviewCompat performs basic checks to ensure that the specified review can be
// applied to the specified request (part of review application logic).
func checkReviewCompat(req types.AccessRequest, rev types.AccessReview) error {
// Proposal cannot be already resolved.
// The Proposal cannot be yet resolved.
if !rev.ProposedState.IsResolved() {
// Skip the promoted state in the error message. It's not a state that most people
// should be concerned with.
Expand Down Expand Up @@ -621,7 +621,7 @@ func calculateReviewBasedResolution(req types.AccessRequest) (*requestResolution
// of their approval thresholds.
approved := make(map[string]struct{})

// denied keeps track of whether or not we've seen *any* role get denied
// denied keeps track of whether we've seen *any* role get denied
// (which role does not currently matter since we short-circuit on the
// first denial to be triggered).
denied := false
Expand Down Expand Up @@ -668,11 +668,11 @@ ProcessReviews:
CheckRoleApprovals:
for role, thresholdSets := range req.GetRoleThresholdMapping() {
if _, ok := approved[role]; ok {
// role was marked approved during a previous iteration
// the role was marked approved during a previous iteration
continue CheckRoleApprovals
}

// iterate through all threshold sets. All sets must have at least
// iterate through all threshold sets. All sets must have at least
// one threshold which has hit its approval count in order for the
// role to be considered approved.
CheckThresholdSets:
Expand All @@ -693,7 +693,7 @@ ProcessReviews:
}

// no thresholds met for this set. there may be additional roles/thresholds
// which did meet their requirements this iteration, but there is no point
// that did meet their requirements this iteration, but there is no point in
// processing them unless this set has also hit its requirements. we therefore
// move immediately to processing the next review.
continue ProcessReviews
Expand Down Expand Up @@ -765,7 +765,7 @@ func GetTraitMappings(cms []types.ClaimMapping) types.TraitMappingSet {
}

// RequestValidatorGetter is the interface required by the request validation
// functions used to get necessary resources.
// functions used to get the necessary resources.
type RequestValidatorGetter interface {
UserLoginStatesGetter
UserGetter
Expand Down Expand Up @@ -844,7 +844,7 @@ func (c *ReviewPermissionChecker) HasAllowDirectives() bool {
}

// CanReviewRequest checks if the user is allowed to review the specified request.
// note that the ability to review a request does not necessarily imply that any specific
// Note that the ability to review a request does not necessarily imply that any specific
// approval/denial thresholds will actually match the user's review. Matching one or more
// thresholds is not a pre-requisite for review submission.
func (c *ReviewPermissionChecker) CanReviewRequest(req types.AccessRequest) (bool, error) {
Expand All @@ -856,7 +856,7 @@ func (c *ReviewPermissionChecker) CanReviewRequest(req types.AccessRequest) (boo
return false, nil
}

// method allocates new array if an override has already been
// method allocates a new array if an override has already been
// called, so get the role list once in advance.
requestedRoles := req.GetOriginalRoles()

Expand Down Expand Up @@ -1187,7 +1187,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
return trace.Wrap(err)
}

// build the thresholds array and role-threshold-mapping. the rtm encodes the
// build the threshold array and role-threshold-mapping. the rtm encodes the
// relationship between a role, and the thresholds which must pass in order
// for that role to be considered approved. when building the validator we
// recorded the relationship between the various allow matchers and their associated
Expand All @@ -1211,15 +1211,18 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
// RBAC system propagates sideband information to plugins.
req.SetSystemAnnotations(m.SystemAnnotations())

// if no suggested reviewers were provided by the user then
// if no suggested reviewers were provided by the user, then
// use the defaults suggested by the user's static roles.
if len(req.GetSuggestedReviewers()) == 0 {
req.SetSuggestedReviewers(apiutils.Deduplicate(m.SuggestedReviewers))
}

// Pin the time to the current time to prevent time drift.
now := m.clock.Now().UTC()

// Calculate the expiration time of the elevated certificate that will
// be issued if the Access Request is approved.
sessionTTL, err := m.sessionTTL(ctx, identity, req)
sessionTTL, err := m.sessionTTL(ctx, identity, req, now)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -1231,7 +1234,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest

// If the maxDuration flag is set, consider it instead of only using the session TTL.
var maxAccessDuration time.Duration
now := m.clock.Now().UTC()

if maxDuration > 0 {
req.SetSessionTLL(now.Add(min(sessionTTL, maxDuration)))
maxAccessDuration = maxDuration
Expand All @@ -1246,13 +1249,13 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
// Adjusted max access duration is equal to the access expiry time.
req.SetMaxDuration(accessExpiry)

// Setting access expiry before calling `calculatePendingRequesetTTL`
// Setting access expiry before calling `calculatePendingRequestTTL`
// matters since the func relies on this adjusted expiry.
req.SetAccessExpiry(accessExpiry)

// Calculate the expiration time of the Access Request (how long it
// will await approval).
requestTTL, err := m.calculatePendingRequestTTL(req)
requestTTL, err := m.calculatePendingRequestTTL(req, now)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1324,36 +1327,36 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest, s
}

// calculatePendingRequestTTL calculates the TTL of the Access Request (how long it will await
// approval). request TTL is capped to the smaller value between the const requsetTTL and the
// approval). request TTL is capped to the smaller value between the const requestTTL and the
// access request access expiry.
func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest) (time.Duration, error) {
accessExpiryTTL := r.GetAccessExpiry().Sub(m.clock.Now().UTC())
func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest, now time.Time) (time.Duration, error) {
accessExpiryTTL := r.GetAccessExpiry().Sub(now)

// If no expiration provided, use default.
expiry := r.Expiry()
if expiry.IsZero() {
// Guard against the default expiry being greater than access expiry.
if requestTTL < accessExpiryTTL {
expiry = m.clock.Now().UTC().Add(requestTTL)
expiry = now.Add(requestTTL)
} else {
expiry = m.clock.Now().UTC().Add(accessExpiryTTL)
expiry = now.Add(accessExpiryTTL)
}
}

if expiry.Before(m.clock.Now().UTC()) {
if expiry.Before(now) {
return 0, trace.BadParameter("invalid request TTL: Access Request can not be created in the past")
}

// Before returning the TTL, validate that the value requested was smaller
// than the maximum value allowed. Used to return a sensible error to the
// user.
requestedTTL := expiry.Sub(m.clock.Now().UTC())
requestedTTL := expiry.Sub(now)
if !r.Expiry().IsZero() {
if requestedTTL > requestTTL {
return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), requestTTL.Round(time.Minute))
return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL, requestTTL)
}
if requestedTTL > accessExpiryTTL {
return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), accessExpiryTTL.Round(time.Minute))
return 0, trace.BadParameter("invalid request TTL: %v greater than maximum allowed (%v)", requestedTTL, accessExpiryTTL)
}
}

Expand All @@ -1362,37 +1365,37 @@ func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest) (ti

// sessionTTL calculates the TTL of the elevated certificate that will be issued
// if the Access Request is approved.
func (m *RequestValidator) sessionTTL(ctx context.Context, identity tlsca.Identity, r types.AccessRequest) (time.Duration, error) {
ttl, err := m.truncateTTL(ctx, identity, r.GetAccessExpiry(), r.GetRoles())
func (m *RequestValidator) sessionTTL(ctx context.Context, identity tlsca.Identity, r types.AccessRequest, now time.Time) (time.Duration, error) {
ttl, err := m.truncateTTL(ctx, identity, r.GetAccessExpiry(), r.GetRoles(), now)
if err != nil {
return 0, trace.BadParameter("invalid session TTL: %v", err)
}

// Before returning the TTL, validate that the value requested was smaller
// than the maximum value allowed. Used to return a sensible error to the
// user.
requestedTTL := r.GetAccessExpiry().Sub(m.clock.Now().UTC())
requestedTTL := r.GetAccessExpiry().Sub(now)
if !r.GetAccessExpiry().IsZero() && requestedTTL > ttl {
return 0, trace.BadParameter("invalid session TTL: %v greater than maximum allowed (%v)", requestedTTL.Round(time.Minute), ttl.Round(time.Minute))
return 0, trace.BadParameter("invalid session TTL: %v greater than maximum allowed (%v)", requestedTTL, ttl)
}

return ttl, nil
}

// truncateTTL will truncate given expiration by identity expiration and
// shortest session TTL of any role.
func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Identity, expiry time.Time, roles []string) (time.Duration, error) {
func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Identity, expiry time.Time, roles []string, now time.Time) (time.Duration, error) {
ttl := apidefaults.MaxCertDuration

// Reduce by remaining TTL on requesting certificate (identity).
identityTTL := identity.Expires.Sub(m.clock.Now())
identityTTL := identity.Expires.Sub(now)
if identityTTL > 0 && identityTTL < ttl {
ttl = identityTTL
}

// Reduce TTL further if expiration time requested is shorter than that
// identity.
expiryTTL := expiry.Sub(m.clock.Now())
expiryTTL := expiry.Sub(now)
if expiryTTL > 0 && expiryTTL < ttl {
ttl = expiryTTL
}
Expand All @@ -1415,7 +1418,7 @@ func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Ident
}

// getResourceViewingRoles gets the subset of the user's roles that could be used
// to view resources (i.e. base roles + search as roles).
// to view resources (i.e., base roles + search as roles).
func (m *RequestValidator) getResourceViewingRoles() []string {
roles := slices.Clone(m.userState.GetRoles())
for _, role := range m.Roles.AllowSearch {
Expand All @@ -1428,8 +1431,8 @@ func (m *RequestValidator) getResourceViewingRoles() []string {

// GetRequestableRoles gets the list of all existent roles which the user is
// able to request. This operation is expensive since it loads all existent
// roles in order to determine the role list. Prefer calling CanRequestRole
// when checking against a known role list. If resource IDs or a login hint
// roles to determine the role list. Prefer calling CanRequestRole
// when checking against a known role list. If resource IDs or a login hints
// are provided, roles will be filtered to only include those that would
// allow access to the given resource with the given login.
func (m *RequestValidator) GetRequestableRoles(ctx context.Context, identity tlsca.Identity, resourceIDs []types.ResourceID, loginHint string) ([]string, error) {
Expand Down Expand Up @@ -1562,7 +1565,7 @@ func (m *RequestValidator) push(role types.Role) error {
}

// setRolesForResourceRequest determines if the given access request is
// resource-based, and if so it determines which underlying roles are necessary
// resource-based, and if so, it determines which underlying roles are necessary
// and adds them to the request.
func (m *RequestValidator) setRolesForResourceRequest(ctx context.Context, req types.AccessRequest) error {
if !m.opts.expandVars {
Expand All @@ -1588,7 +1591,7 @@ func (m *RequestValidator) setRolesForResourceRequest(ctx context.Context, req t
return nil
}

// thresholdCollector is a helper which assembles the Thresholds array for a request.
// thresholdCollector is a helper that assembles the Thresholds array for a request.
// the push() method is used to insert groups of related thresholds and calculate their
// corresponding index set.
type thresholdCollector struct {
Expand Down
8 changes: 4 additions & 4 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1774,9 +1774,9 @@ func TestGetRequestableRoles(t *testing.T) {
}
}

// TestCalculatePendingRequesTTL verifies that the TTL for the Access Request is capped to the
// TestCalculatePendingRequestTTL verifies that the TTL for the Access Request is capped to the
// request's access expiry or capped to the default const requestTTL, whichever is smaller.
func TestCalculatePendingRequesTTL(t *testing.T) {
func TestCalculatePendingRequestTTL(t *testing.T) {
clock := clockwork.NewFakeClock()
now := clock.Now().UTC()

Expand Down Expand Up @@ -1869,7 +1869,7 @@ func TestCalculatePendingRequesTTL(t *testing.T) {
request.SetExpiry(tt.requestPendingExpiryTTL)
request.SetAccessExpiry(now.Add(tt.accessExpiryTTL))

ttl, err := validator.calculatePendingRequestTTL(request)
ttl, err := validator.calculatePendingRequestTTL(request, now)
tt.assertion(t, err)
if err == nil {
require.Equal(t, tt.expectedDuration, ttl)
Expand Down Expand Up @@ -1943,7 +1943,7 @@ func TestSessionTTL(t *testing.T) {
request.SetAccessExpiry(tt.accessExpiry)
require.NoError(t, err)

ttl, err := validator.sessionTTL(context.Background(), tt.identity, request)
ttl, err := validator.sessionTTL(context.Background(), tt.identity, request, now)
tt.assertion(t, err)
if err == nil {
require.Equal(t, tt.expectedTTL, ttl)
Expand Down
3 changes: 1 addition & 2 deletions tool/tsh/common/kube_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -366,7 +365,7 @@ func TestKubeSelection(t *testing.T) {
cfg.Kube.ResourceMatchers = []services.ResourceMatcher{{
Labels: map[string]apiutils.Strings{"*": {"*"}},
}}
cfg.Clock = clockwork.NewFakeClock()
// Do not use a fake clock to better imitate real-world behavior.
}),
)
kubeBarEKS := "bar-eks-us-west-1-123456789012"
Expand Down
6 changes: 2 additions & 4 deletions tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import (

"github.com/ghodss/yaml"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
otlp "go.opentelemetry.io/proto/otlp/trace/v1"
Expand Down Expand Up @@ -1823,9 +1822,8 @@ func TestSSHAccessRequest(t *testing.T) {

rootAuth, rootProxy := makeTestServers(t,
withBootstrap(requester, searchOnlyRequester, nodeAccessRole, emptyRole, connector, alice),
withConfig(func(cfg *servicecfg.Config) {
cfg.Clock = clockwork.NewFakeClock()
}))
// Do not use a fake clock to better imitate real-world behavior.
)

authAddr, err := rootAuth.AuthAddr()
require.NoError(t, err)
Expand Down