diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 92cb198e2b8a6..021e74bf175af 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -5199,7 +5199,7 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ // Always perform variable expansion on creation only; this ensures the // access request that is reviewed is the same that is approved. - expandOpts := services.ExpandVars(true) + expandOpts := services.WithExpandVars(true) if err := services.ValidateAccessRequestForUser(ctx, a.clock, a, req, identity, expandOpts); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/access_request.go b/lib/services/access_request.go index b28d270945c21..71c56d2bbea52 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -214,7 +214,7 @@ func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt var caps types.AccessCapabilities // all capabilities require use of a request validator. calculating suggested reviewers // requires that the validator be configured for variable expansion. - v, err := NewRequestValidator(ctx, clock, clt, req.User, ExpandVars(req.SuggestedReviewers)) + v, err := newRequestValidator(ctx, clock, clt, req.User, WithExpandVars(req.SuggestedReviewers)) if err != nil { return nil, trace.Wrap(err) } @@ -231,14 +231,14 @@ func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt if req.FilterRequestableRolesByResource { resourceIDs = req.ResourceIDs } - caps.RequestableRoles, err = v.GetRequestableRoles(ctx, identity, resourceIDs, req.Login) + caps.RequestableRoles, err = v.getRequestableRoles(ctx, identity, resourceIDs, req.Login) if err != nil { return nil, trace.Wrap(err) } } if req.SuggestedReviewers { - caps.SuggestedReviewers = v.SuggestedReviewers + caps.SuggestedReviewers = v.suggestedReviewers } caps.RequireReason, err = v.calcRequireReasonCap(ctx, req, caps) @@ -251,7 +251,7 @@ func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt return &caps, nil } -func (v *RequestValidator) calcRequireReasonCap(ctx context.Context, req types.AccessCapabilitiesRequest, caps types.AccessCapabilities) (requireReason bool, err error) { +func (v *requestValidator) calcRequireReasonCap(ctx context.Context, req types.AccessCapabilitiesRequest, caps types.AccessCapabilities) (requireReason bool, err error) { var roles []string if req.RequestableRoles { roles = caps.RequestableRoles @@ -270,10 +270,10 @@ func (v *RequestValidator) calcRequireReasonCap(ctx context.Context, req types.A // allowedSearchAsRoles returns all allowed `allow.request.search_as_roles` for the user that are // not in the `deny.request.search_as_roles`. It does not filter out any roles that should not be // allowed based on requests. -func (m *RequestValidator) allowedSearchAsRoles() ([]string, error) { +func (m *requestValidator) allowedSearchAsRoles() ([]string, error) { var rolesToRequest []string - for _, roleName := range m.Roles.AllowSearch { - if !m.CanSearchAsRole(roleName) { + for _, roleName := range m.roles.allowSearch { + if !m.canSearchAsRole(roleName) { continue } rolesToRequest = append(rolesToRequest, roleName) @@ -289,7 +289,7 @@ func (m *RequestValidator) allowedSearchAsRoles() ([]string, error) { // applicable for the given list of resourceIDs. // // If loginHint is provided, it will attempt to prune the list to a single role. -func (m *RequestValidator) applicableSearchAsRoles(ctx context.Context, resourceIDs []types.ResourceID, loginHint string) ([]string, error) { +func (m *requestValidator) applicableSearchAsRoles(ctx context.Context, resourceIDs []types.ResourceID, loginHint string) ([]string, error) { rolesToRequest, err := m.allowedSearchAsRoles() if err != nil { return nil, trace.Wrap(err) @@ -1021,13 +1021,14 @@ func (c *ReviewPermissionChecker) push(role types.Role) error { return nil } -// RequestValidator a helper for validating access requests. +// requestValidator a helper for validating access requests. // a user's statically assigned roles are "added" to the // validator via the push() method, which extracts all the // relevant rules, performs variable substitutions, and builds // a set of simple Allow/Deny datastructures. These, in turn, // are used to validate and expand the access request. -type RequestValidator struct { +type requestValidator struct { + logger *slog.Logger clock clockwork.Clock getter RequestValidatorGetter userState UserState @@ -1058,57 +1059,54 @@ type RequestValidator struct { } autoRequest bool prompt string - opts struct { - expandVars bool + opts ValidateRequestOptions + roles struct { + allowRequest, denyRequest []parse.Matcher + allowSearch, denySearch []string } - Roles struct { - AllowRequest, DenyRequest []parse.Matcher - AllowSearch, DenySearch []string - } - Annotations struct { - // Allowed annotations are not greedy, the role that defines the annotation must allow requesting one + annotations struct { + // allow annotations are not greedy, the role that defines the annotation must allow requesting one // of the roles that are being requested in order for the annotation to be applied. - Allow map[singleAnnotation]annotationMatcher - // Denied annotations match greedily, if a user has any role that denies a specific annotation it will + allow map[singleAnnotation]annotationMatcher + // deny annotations match greedily, if a user has any role that denies a specific annotation it will // always be denied. - Deny map[singleAnnotation]struct{} + deny map[singleAnnotation]struct{} } - ThresholdMatchers []struct { - Matchers []parse.Matcher - Thresholds []types.AccessReviewThreshold + thresholdMatchers []struct { + matchers []parse.Matcher + thresholds []types.AccessReviewThreshold } - SuggestedReviewers []string - MaxDurationMatchers []struct { - Matchers []parse.Matcher - MaxDuration time.Duration + suggestedReviewers []string + maxDurationMatchers []struct { + matchers []parse.Matcher + maxDuration time.Duration } - logger *slog.Logger } -// NewRequestValidator configures a new RequestValidator for the specified user. -func NewRequestValidator(ctx context.Context, clock clockwork.Clock, getter RequestValidatorGetter, username string, opts ...ValidateRequestOption) (RequestValidator, error) { +// newRequestValidator configures a new RequestValidator for the specified user. +func newRequestValidator(ctx context.Context, clock clockwork.Clock, getter RequestValidatorGetter, username string, opts ...ValidateRequestOption) (requestValidator, error) { uls, err := GetUserOrLoginState(ctx, getter, username) if err != nil { - return RequestValidator{}, trace.Wrap(err) + return requestValidator{}, trace.Wrap(err) } - m := RequestValidator{ + m := requestValidator{ + logger: slog.With(teleport.ComponentKey, "request.validator"), clock: clock, getter: getter, userState: uls, - logger: slog.With(teleport.ComponentKey, "request.validator"), requiringReasonRoles: make(map[string]struct{}), } for _, opt := range opts { - opt(&m) + opt(&m.opts) } if m.opts.expandVars { // validation process for incoming access requests requires // generating system annotations to be attached to the request // before it is inserted into the backend. - m.Annotations.Allow = make(map[singleAnnotation]annotationMatcher) - m.Annotations.Deny = make(map[singleAnnotation]struct{}) + m.annotations.allow = make(map[singleAnnotation]annotationMatcher) + m.annotations.deny = make(map[singleAnnotation]struct{}) } m.kubernetesResource.allow = make(map[string][]types.RequestKubernetesResource) @@ -1118,18 +1116,18 @@ func NewRequestValidator(ctx context.Context, clock clockwork.Clock, getter Requ for _, roleName := range m.userState.GetRoles() { role, err := m.getter.GetRole(ctx, roleName) if err != nil { - return RequestValidator{}, trace.Wrap(err) + return requestValidator{}, trace.Wrap(err) } if err := m.push(ctx, role); err != nil { - return RequestValidator{}, trace.Wrap(err) + return requestValidator{}, trace.Wrap(err) } } return m, nil } -// Validate validates an access request and potentially modifies it depending on how +// validate validates an access request and potentially modifies it depending on how // the validator was configured. -func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest, identity tlsca.Identity) error { +func (m *requestValidator) validate(ctx context.Context, req types.AccessRequest, identity tlsca.Identity) error { if m.userState.GetName() != req.GetUser() { return trace.BadParameter("request validator configured for different user (this is a bug)") } @@ -1155,7 +1153,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest return trace.BadParameter("unexpected wildcard request (this is a bug)") } - requestable, err := m.GetRequestableRoles(ctx, identity, nil, "") + requestable, err := m.getRequestableRoles(ctx, identity, nil, "") if err != nil { return trace.Wrap(err) } @@ -1181,14 +1179,14 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest // verify that all requested roles are permissible for _, roleName := range req.GetRoles() { if len(req.GetRequestedResourceIDs()) > 0 { - if !m.CanSearchAsRole(roleName) { + if !m.canSearchAsRole(roleName) { // Roles are normally determined automatically for resource // access requests, this role must have been explicitly // requested, or a new deny rule has since been added. return trace.BadParameter("user %q can not request role %q", req.GetUser(), roleName) } } else { - if !m.CanRequestRole(roleName) { + if !m.canRequestRole(roleName) { return trace.BadParameter("user %q can not request role %q", req.GetUser(), roleName) } } @@ -1256,7 +1254,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest // incoming requests must have system annotations attached // before being inserted into the backend. this is how the // RBAC system propagates sideband information to plugins. - systemAnnotations, err := m.SystemAnnotations(req) + systemAnnotations, err := m.systemAnnotations(req) if err != nil { return trace.Wrap(err) } @@ -1265,7 +1263,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest // if no suggested reviewers were provided by the user, then // use the defaults suggested by the user's static roles. if len(req.GetSuggestedReviewers()) == 0 { - req.SetSuggestedReviewers(apiutils.Deduplicate(m.SuggestedReviewers)) + req.SetSuggestedReviewers(apiutils.Deduplicate(m.suggestedReviewers)) } // Pin the time to the current time to prevent time drift. @@ -1324,7 +1322,7 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest } // isReasonRequired checks if the reason is required for the given roles and resource IDs. -func (v *RequestValidator) isReasonRequired(ctx context.Context, requestedRoles []string, requestedResourceIDs []types.ResourceID) (required bool, explanation string, err error) { +func (v *requestValidator) isReasonRequired(ctx context.Context, requestedRoles []string, requestedResourceIDs []types.ResourceID) (required bool, explanation string, err error) { if v.requireReasonForAllRoles { return true, "request reason must be specified (required request_access option in one of the roles)", nil } @@ -1355,7 +1353,7 @@ func (v *RequestValidator) isReasonRequired(ctx context.Context, requestedRoles // calculateMaxAccessDuration calculates the maximum time for the access request. // The max duration time is the minimum of the max_duration time set on the request // and the max_duration time set on the request role. -func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest, sessionTTL time.Duration) (time.Duration, error) { +func (m *requestValidator) calculateMaxAccessDuration(req types.AccessRequest, sessionTTL time.Duration) (time.Duration, error) { // Check if the maxDuration time is set. maxDurationTime := req.GetMaxDuration() maxDuration := maxDurationTime.Sub(req.GetCreationTime()) @@ -1396,13 +1394,13 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest, s return minAdjDuration, nil } -func (m *RequestValidator) maxDurationForRole(roleName string) time.Duration { +func (m *requestValidator) maxDurationForRole(roleName string) time.Duration { var maxDurationForRole time.Duration - for _, tms := range m.MaxDurationMatchers { - for _, matcher := range tms.Matchers { + for _, tms := range m.maxDurationMatchers { + for _, matcher := range tms.matchers { if matcher.Match(roleName) { - if tms.MaxDuration > maxDurationForRole { - maxDurationForRole = tms.MaxDuration + if tms.maxDuration > maxDurationForRole { + maxDurationForRole = tms.maxDuration } } } @@ -1413,7 +1411,7 @@ func (m *RequestValidator) maxDurationForRole(roleName string) time.Duration { // calculatePendingRequestTTL calculates the TTL of the Access Request (how long it will await // approval). request TTL is capped to the smaller value between the const requestTTL and the // access request access expiry. -func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest, now time.Time) (time.Duration, error) { +func (m *requestValidator) calculatePendingRequestTTL(r types.AccessRequest, now time.Time) (time.Duration, error) { accessExpiryTTL := r.GetAccessExpiry().Sub(now) // If no expiration provided, use default. @@ -1449,7 +1447,7 @@ func (m *RequestValidator) calculatePendingRequestTTL(r types.AccessRequest, now // sessionTTL calculates the TTL of the elevated certificate that will be issued // if the Access Request is approved. -func (m *RequestValidator) sessionTTL(ctx context.Context, identity tlsca.Identity, r types.AccessRequest, now time.Time) (time.Duration, error) { +func (m *requestValidator) sessionTTL(ctx context.Context, identity tlsca.Identity, r types.AccessRequest, now time.Time) (time.Duration, error) { ttl, err := m.truncateTTL(ctx, identity, r.GetAccessExpiry(), r.GetRoles(), now) if err != nil { return 0, trace.BadParameter("invalid session TTL: %v", err) @@ -1468,7 +1466,7 @@ func (m *RequestValidator) sessionTTL(ctx context.Context, identity tlsca.Identi // truncateTTL will truncate given expiration by identity expiration and // shortest session TTL of any role. -func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Identity, expiry time.Time, roles []string, now time.Time) (time.Duration, error) { +func (m *requestValidator) truncateTTL(ctx context.Context, identity tlsca.Identity, expiry time.Time, roles []string, now time.Time) (time.Duration, error) { ttl := apidefaults.MaxCertDuration // Reduce by remaining TTL on requesting certificate (identity). @@ -1503,23 +1501,23 @@ func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Ident // getResourceViewingRoles gets the subset of the user's roles that could be used // to view resources (i.e., base roles + search as roles). -func (m *RequestValidator) getResourceViewingRoles() []string { +func (m *requestValidator) getResourceViewingRoles() []string { roles := slices.Clone(m.userState.GetRoles()) - for _, role := range m.Roles.AllowSearch { - if m.CanSearchAsRole(role) { + for _, role := range m.roles.allowSearch { + if m.canSearchAsRole(role) { roles = append(roles, role) } } return apiutils.Deduplicate(roles) } -// GetRequestableRoles gets the list of all existent roles which the user is +// getRequestableRoles gets the list of all existent roles which the user is // able to request. This operation is expensive since it loads all existent // roles to determine the role list. Prefer calling CanRequestRole // when checking against a known role list. If resource IDs or a login hints // are provided, roles will be filtered to only include those that would // allow access to the given resource with the given login. -func (m *RequestValidator) GetRequestableRoles(ctx context.Context, identity tlsca.Identity, resourceIDs []types.ResourceID, loginHint string) ([]string, error) { +func (m *requestValidator) getRequestableRoles(ctx context.Context, identity tlsca.Identity, resourceIDs []types.ResourceID, loginHint string) ([]string, error) { allRoles, err := m.getter.GetRoles(ctx) if err != nil { return nil, trace.Wrap(err) @@ -1555,7 +1553,7 @@ func (m *RequestValidator) GetRequestableRoles(ctx context.Context, identity tls var expanded []string for _, role := range allRoles { n := role.GetName() - if slices.Contains(m.userState.GetRoles(), n) || !m.CanRequestRole(n) { + if slices.Contains(m.userState.GetRoles(), n) || !m.canRequestRole(n) { continue } @@ -1604,7 +1602,7 @@ func setAllowRequestKubeResourceLookup(allowKubernetesResources []types.RequestK // push compiles a role's configuration into the request validator. // All of the requesting user's statically assigned roles must be pushed // before validation begins. -func (m *RequestValidator) push(ctx context.Context, role types.Role) error { +func (m *requestValidator) push(ctx context.Context, role types.Role) error { var err error m.requireReasonForAllRoles = m.requireReasonForAllRoles || role.GetOptions().RequestAccess.RequireReason() @@ -1630,27 +1628,27 @@ func (m *RequestValidator) push(ctx context.Context, role types.Role) error { m.kubernetesResource.deny = append(m.kubernetesResource.deny, deny.KubernetesResources...) } - m.Roles.DenyRequest, err = appendRoleMatchers(m.Roles.DenyRequest, deny.Roles, deny.ClaimsToRoles, m.userState.GetTraits()) + m.roles.denyRequest, err = appendRoleMatchers(m.roles.denyRequest, deny.Roles, deny.ClaimsToRoles, m.userState.GetTraits()) if err != nil { return trace.Wrap(err) } // record what will be the starting index of the allow and deny matchers for this role, if it applies any. - astart := len(m.Roles.AllowRequest) + astart := len(m.roles.allowRequest) - m.Roles.AllowRequest, err = appendRoleMatchers(m.Roles.AllowRequest, allow.Roles, allow.ClaimsToRoles, m.userState.GetTraits()) + m.roles.allowRequest, err = appendRoleMatchers(m.roles.allowRequest, allow.Roles, allow.ClaimsToRoles, m.userState.GetTraits()) if err != nil { return trace.Wrap(err) } - m.Roles.AllowSearch = apiutils.Deduplicate(append(m.Roles.AllowSearch, allow.SearchAsRoles...)) - m.Roles.DenySearch = apiutils.Deduplicate(append(m.Roles.DenySearch, deny.SearchAsRoles...)) + m.roles.allowSearch = apiutils.Deduplicate(append(m.roles.allowSearch, allow.SearchAsRoles...)) + m.roles.denySearch = apiutils.Deduplicate(append(m.roles.denySearch, deny.SearchAsRoles...)) if m.opts.expandVars { // if this role added additional allow matchers, then we need to record the relationship // between its matchers and its thresholds. This information is used later to calculate // the rtm and threshold list. - newAllowRequestMatchers := m.Roles.AllowRequest[astart:] + newAllowRequestMatchers := m.roles.allowRequest[astart:] newAllowSearchMatchers := literalMatchers(allow.SearchAsRoles) allNewAllowMatchers := make([]parse.Matcher, 0, len(newAllowRequestMatchers)+len(newAllowSearchMatchers)) @@ -1658,22 +1656,22 @@ func (m *RequestValidator) push(ctx context.Context, role types.Role) error { allNewAllowMatchers = append(allNewAllowMatchers, newAllowSearchMatchers...) if len(allNewAllowMatchers) > 0 { - m.ThresholdMatchers = append(m.ThresholdMatchers, struct { - Matchers []parse.Matcher - Thresholds []types.AccessReviewThreshold + m.thresholdMatchers = append(m.thresholdMatchers, struct { + matchers []parse.Matcher + thresholds []types.AccessReviewThreshold }{ - Matchers: allNewAllowMatchers, - Thresholds: allow.Thresholds, + matchers: allNewAllowMatchers, + thresholds: allow.Thresholds, }) } if allow.MaxDuration != 0 { - m.MaxDurationMatchers = append(m.MaxDurationMatchers, struct { - Matchers []parse.Matcher - MaxDuration time.Duration + m.maxDurationMatchers = append(m.maxDurationMatchers, struct { + matchers []parse.Matcher + maxDuration time.Duration }{ - Matchers: allNewAllowMatchers, - MaxDuration: allow.MaxDuration.Duration(), + matchers: allNewAllowMatchers, + maxDuration: allow.MaxDuration.Duration(), }) } @@ -1683,7 +1681,7 @@ func (m *RequestValidator) push(ctx context.Context, role types.Role) error { m.insertAllowedAnnotations(ctx, allow, newAllowRequestMatchers, newAllowSearchMatchers) m.insertDeniedAnnotations(ctx, deny) - m.SuggestedReviewers = append(m.SuggestedReviewers, allow.SuggestedReviewers...) + m.suggestedReviewers = append(m.suggestedReviewers, allow.SuggestedReviewers...) } return nil } @@ -1691,7 +1689,7 @@ func (m *RequestValidator) push(ctx context.Context, role types.Role) error { // setRolesForResourceRequest determines if the given access request is // resource-based, and if so, it determines which underlying roles are necessary // and adds them to the request. -func (m *RequestValidator) setRolesForResourceRequest(ctx context.Context, req types.AccessRequest) error { +func (m *requestValidator) setRolesForResourceRequest(ctx context.Context, req types.AccessRequest) error { if !m.opts.expandVars { // Don't set the roles if expandVars is not set, they have probably // already been set and we are just validating the request. @@ -1721,7 +1719,7 @@ func (m *RequestValidator) setRolesForResourceRequest(ctx context.Context, req t // // Returns pruned roles, and a map of requested roles with allowed kinds (with denied applied), used to help aid user in case a request gets rejected, // lets user know which kinds are allowed for each requested roles. -func (m *RequestValidator) pruneRequestedRolesNotMatchingKubernetesResourceKinds(requestedResourceIDs []types.ResourceID, requestedRoles []string) ([]string, map[string][]string) { +func (m *requestValidator) pruneRequestedRolesNotMatchingKubernetesResourceKinds(requestedResourceIDs []types.ResourceID, requestedRoles []string) ([]string, map[string][]string) { // Filter for the kube_cluster and its subresource kinds. requestedKubeKinds := make(map[string]struct{}) for _, resourceID := range requestedResourceIDs { @@ -1834,14 +1832,14 @@ func (c *thresholdCollector) pushThreshold(t types.AccessReviewThreshold) (uint3 return uint32(len(c.Thresholds) - 1), nil } -// CanRequestRole checks if a given role can be requested. -func (m *RequestValidator) CanRequestRole(name string) bool { - for _, deny := range m.Roles.DenyRequest { +// canRequestRole checks if a given role can be requested. +func (m *requestValidator) canRequestRole(name string) bool { + for _, deny := range m.roles.denyRequest { if deny.Match(name) { return false } } - for _, allow := range m.Roles.AllowRequest { + for _, allow := range m.roles.allowRequest { if allow.Match(name) { return true } @@ -1849,30 +1847,30 @@ func (m *RequestValidator) CanRequestRole(name string) bool { return false } -// CanSearchAsRole check if a given role can be requested through a search-based +// canSearchAsRole check if a given role can be requested through a search-based // access request -func (m *RequestValidator) CanSearchAsRole(name string) bool { - if slices.Contains(m.Roles.DenySearch, name) { +func (m *requestValidator) canSearchAsRole(name string) bool { + if slices.Contains(m.roles.denySearch, name) { return false } - for _, deny := range m.Roles.DenyRequest { + for _, deny := range m.roles.denyRequest { if deny.Match(name) { return false } } - return slices.Contains(m.Roles.AllowSearch, name) + return slices.Contains(m.roles.allowSearch, name) } // collectSetsForRole collects the threshold index sets which describe the various groups of // thresholds which must pass in order for a request for the given role to be approved. -func (m *RequestValidator) collectSetsForRole(c *thresholdCollector, role string) ([]types.ThresholdIndexSet, error) { +func (m *requestValidator) collectSetsForRole(c *thresholdCollector, role string) ([]types.ThresholdIndexSet, error) { var sets []types.ThresholdIndexSet Outer: - for _, tms := range m.ThresholdMatchers { - for _, matcher := range tms.Matchers { + for _, tms := range m.thresholdMatchers { + for _, matcher := range tms.matchers { if matcher.Match(role) { - set, err := c.push(tms.Thresholds) + set, err := c.push(tms.thresholds) if err != nil { return nil, trace.Wrap(err) } @@ -1930,7 +1928,7 @@ func (m *annotationMatcher) matchesRequest(req types.AccessRequest) bool { // // Annotations are only applied to access requests requests when one of the requested roles matches one of the // role matchers. -func (m *RequestValidator) insertAllowedAnnotations(ctx context.Context, conditions types.AccessRequestConditions, roleRequestMatchers, resourceRequestMatchers []parse.Matcher) { +func (m *requestValidator) insertAllowedAnnotations(ctx context.Context, conditions types.AccessRequestConditions, roleRequestMatchers, resourceRequestMatchers []parse.Matcher) { for annotationKey, annotationValueTemplates := range conditions.Annotations { // iterate through all new values and expand any // variable interpolation syntax they contain. @@ -1944,10 +1942,10 @@ func (m *RequestValidator) insertAllowedAnnotations(ctx context.Context, conditi } for _, expanded := range expandedValues { annotation := singleAnnotation{annotationKey, expanded} - matchers := m.Annotations.Allow[annotation] + matchers := m.annotations.allow[annotation] matchers.roleRequestMatchers = append(matchers.roleRequestMatchers, roleRequestMatchers...) matchers.resourceRequestMatchers = append(matchers.resourceRequestMatchers, resourceRequestMatchers...) - m.Annotations.Allow[annotation] = matchers + m.annotations.allow[annotation] = matchers } } } @@ -1955,7 +1953,7 @@ func (m *RequestValidator) insertAllowedAnnotations(ctx context.Context, conditi // insertDeniedAnnotations constructs all denied annotations for a given AccessRequestConditions instance // from one of the users current roles and adds them to the denied annotations set. -func (m *RequestValidator) insertDeniedAnnotations(ctx context.Context, conditions types.AccessRequestConditions) { +func (m *requestValidator) insertDeniedAnnotations(ctx context.Context, conditions types.AccessRequestConditions) { for annotationKey, annotationValueTemplates := range conditions.Annotations { // iterate through all new values and expand any // variable interpolation syntax they contain. @@ -1969,19 +1967,19 @@ func (m *RequestValidator) insertDeniedAnnotations(ctx context.Context, conditio } for _, expanded := range expandedValues { annotation := singleAnnotation{annotationKey, expanded} - m.Annotations.Deny[annotation] = struct{}{} + m.annotations.deny[annotation] = struct{}{} } } } } -// SystemAnnotations calculates the system annotations for a pending +// systemAnnotations calculates the system annotations for a pending // access request. -func (m *RequestValidator) SystemAnnotations(req types.AccessRequest) (map[string][]string, error) { +func (m *requestValidator) systemAnnotations(req types.AccessRequest) (map[string][]string, error) { annotations := make(map[string][]string) - for annotation, allowMatchers := range m.Annotations.Allow { - if _, denied := m.Annotations.Deny[annotation]; denied { + for annotation, allowMatchers := range m.annotations.allow { + if _, denied := m.annotations.deny[annotation]; denied { // Deny matches are greedy, if any of the users roles denies this annotation it is filtered out. continue } @@ -2001,28 +1999,32 @@ func (m *RequestValidator) SystemAnnotations(req types.AccessRequest) (map[strin return annotations, nil } -type ValidateRequestOption func(*RequestValidator) +type ValidateRequestOptions struct { + expandVars bool +} + +type ValidateRequestOption func(*ValidateRequestOptions) -// ExpandVars toggles variable expansion during request validation. Variable expansion includes +// WithExpandVars toggles variable expansion during request validation. Variable expansion includes // expanding wildcard requests, setting system annotations, finding applicable roles for // resource-based requests and gathering threshold information. Variable expansion should be run // by the auth server prior to storing an access request for the first time. -func ExpandVars(expand bool) ValidateRequestOption { - return func(v *RequestValidator) { - v.opts.expandVars = expand +func WithExpandVars(expandVars bool) ValidateRequestOption { + return func(v *ValidateRequestOptions) { + v.expandVars = expandVars } } // ValidateAccessRequestForUser validates an access request against the associated users's -// *statically assigned* roles. If [[ExpandVars]] is set to true, it will also expand wildcard +// *statically assigned* roles. If [[WithExpandVars]] is set to true, it will also expand wildcard // requests, setting their role list to include all roles the user is allowed to request. // Expansion should be performed before an access request is initially placed in the backend. func ValidateAccessRequestForUser(ctx context.Context, clock clockwork.Clock, getter RequestValidatorGetter, req types.AccessRequest, identity tlsca.Identity, opts ...ValidateRequestOption) error { - v, err := NewRequestValidator(ctx, clock, getter, req.GetUser(), opts...) + v, err := newRequestValidator(ctx, clock, getter, req.GetUser(), opts...) if err != nil { return trace.Wrap(err) } - return trace.Wrap(v.Validate(ctx, req, identity)) + return trace.Wrap(v.validate(ctx, req, identity)) } // UnmarshalAccessRequest unmarshals the AccessRequest resource from JSON. @@ -2115,7 +2117,7 @@ func getInvalidKubeKindAccessRequestsError(mappedRequestedRolesToAllowedKinds ma // resource is in a leaf cluster. // // If loginHint is provided, it will attempt to prune the list to a single role. -func (m *RequestValidator) pruneResourceRequestRoles( +func (m *requestValidator) pruneResourceRequestRoles( ctx context.Context, resourceIDs []types.ResourceID, loginHint string, @@ -2292,7 +2294,7 @@ func getAllowedKubeResourceKinds(allowedKinds []string, deniedKinds []string) [] return slices.Collect(maps.Keys(allowed)) } -func (m *RequestValidator) roleAllowsResource( +func (m *requestValidator) roleAllowsResource( role types.Role, resource types.ResourceWithLabels, loginHint string, @@ -2322,7 +2324,7 @@ func (m *RequestValidator) roleAllowsResource( // requested access. Except for resource Kinds present in types.KubernetesResourcesKinds, // the underlying resources are the same as requested. If the resource requested // is a Kubernetes resource, we return the underlying Kubernetes cluster. -func (m *RequestValidator) getUnderlyingResourcesByResourceIDs(ctx context.Context, resourceIDs []types.ResourceID) ([]types.ResourceWithLabels, error) { +func (m *requestValidator) getUnderlyingResourcesByResourceIDs(ctx context.Context, resourceIDs []types.ResourceID) ([]types.ResourceWithLabels, error) { if len(resourceIDs) == 0 { return []types.ResourceWithLabels{}, nil } diff --git a/lib/services/access_request_test.go b/lib/services/access_request_test.go index 3e5cd903861d7..7529a55d03081 100644 --- a/lib/services/access_request_test.go +++ b/lib/services/access_request_test.go @@ -742,10 +742,10 @@ func TestReviewThresholds(t *testing.T) { // perform request validation (necessary in order to initialize internal // request variables like annotations and thresholds). - validator, err := NewRequestValidator(ctx, clock, g, tt.requestor, ExpandVars(true)) + validator, err := newRequestValidator(ctx, clock, g, tt.requestor, WithExpandVars(true)) require.NoError(t, err, "scenario=%q", tt.desc) - require.NoError(t, validator.Validate(ctx, req, identity), "scenario=%q", tt.desc) + require.NoError(t, validator.validate(ctx, req, identity), "scenario=%q", tt.desc) Inner: for ri, rt := range tt.reviews { @@ -1275,10 +1275,10 @@ func TestRolesForResourceRequest(t *testing.T) { Expires: clock.Now().UTC().Add(8 * time.Hour), } - validator, err := NewRequestValidator(context.Background(), clock, g, uls.GetName(), ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, g, uls.GetName(), WithExpandVars(true)) require.NoError(t, err) - err = validator.Validate(context.Background(), req, identity) + err = validator.validate(context.Background(), req, identity) require.ErrorIs(t, err, tc.expectError) if err != nil { return @@ -1669,7 +1669,7 @@ func TestPruneRequestRoles(t *testing.T) { accessCaps, err := CalculateAccessCapabilities(ctx, clock, g, tlsca.Identity{}, types.AccessCapabilitiesRequest{User: user, ResourceIDs: tc.requestResourceIDs}) require.NoError(t, err) - err = ValidateAccessRequestForUser(ctx, clock, g, req, identity, ExpandVars(true)) + err = ValidateAccessRequestForUser(ctx, clock, g, req, identity, WithExpandVars(true)) if tc.expectError { require.Error(t, err) return @@ -1940,7 +1940,7 @@ func TestCalculatePendingRequestTTL(t *testing.T) { roles: map[string]types.Role{"bar": role}, } - validator, err := NewRequestValidator(context.Background(), clock, getter, "foo", ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, getter, "foo", WithExpandVars(true)) require.NoError(t, err) request, err := types.NewAccessRequest("some-id", "foo", "bar") @@ -2015,7 +2015,7 @@ func TestSessionTTL(t *testing.T) { roles: map[string]types.Role{"bar": role}, } - validator, err := NewRequestValidator(context.Background(), clock, getter, "foo", ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, getter, "foo", WithExpandVars(true)) require.NoError(t, err) request, err := types.NewAccessRequest("some-id", "foo", "bar") @@ -2070,11 +2070,11 @@ func TestAutoRequest(t *testing.T) { cases := []struct { name string roles []types.Role - assertion func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) + assertion func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) }{ { name: "no roles", - assertion: func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) { + assertion: func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) { require.False(t, validator.requireReasonForAllRoles) require.False(t, validator.autoRequest) require.Empty(t, validator.prompt) @@ -2087,7 +2087,7 @@ func TestAutoRequest(t *testing.T) { { name: "with prompt", roles: []types.Role{empty, optionalRole, promptRole}, - assertion: func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) { + assertion: func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) { require.False(t, validator.requireReasonForAllRoles) require.False(t, validator.autoRequest) require.Equal(t, "test prompt", validator.prompt) @@ -2100,7 +2100,7 @@ func TestAutoRequest(t *testing.T) { { name: "with auto request", roles: []types.Role{alwaysRole}, - assertion: func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) { + assertion: func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) { require.False(t, validator.requireReasonForAllRoles) require.True(t, validator.autoRequest) require.Empty(t, validator.prompt) @@ -2113,7 +2113,7 @@ func TestAutoRequest(t *testing.T) { { name: "with prompt and auto request", roles: []types.Role{promptRole, alwaysRole}, - assertion: func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) { + assertion: func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) { require.False(t, validator.requireReasonForAllRoles) require.True(t, validator.autoRequest) require.Equal(t, "test prompt", validator.prompt) @@ -2126,7 +2126,7 @@ func TestAutoRequest(t *testing.T) { { name: "with reason and auto prompt", roles: []types.Role{reasonRole}, - assertion: func(t *testing.T, validator *RequestValidator, accessCaps *types.AccessCapabilities) { + assertion: func(t *testing.T, validator *requestValidator, accessCaps *types.AccessCapabilities) { require.True(t, validator.requireReasonForAllRoles) require.True(t, validator.autoRequest) require.Empty(t, validator.prompt) @@ -2159,7 +2159,7 @@ func TestAutoRequest(t *testing.T) { getter.userStates[uls.GetName()] = uls - validator, err := NewRequestValidator(ctx, clock, getter, uls.GetName(), ExpandVars(true)) + validator, err := newRequestValidator(ctx, clock, getter, uls.GetName(), WithExpandVars(true)) require.NoError(t, err) accessCapabilities, err := CalculateAccessCapabilities(ctx, clock, getter, tlsca.Identity{}, types.AccessCapabilitiesRequest{ @@ -2428,7 +2428,7 @@ func TestReasonRequired(t *testing.T) { // test RequestValidator.Validate { - validator, err := NewRequestValidator(ctx, clock, g, uls.GetName(), ExpandVars(true)) + validator, err := newRequestValidator(ctx, clock, g, uls.GetName(), WithExpandVars(true)) require.NoError(t, err) req, err := types.NewAccessRequestWithResources( @@ -2436,17 +2436,17 @@ func TestReasonRequired(t *testing.T) { require.NoError(t, err) // No reason in the request. - err = validator.Validate(ctx, req.Copy(), identity) + err = validator.validate(ctx, req.Copy(), identity) require.ErrorIs(t, err, tc.expectError) // White-space reason should be treated as no reason. req.SetRequestReason(" \t \n ") - err = validator.Validate(ctx, req.Copy(), identity) + err = validator.validate(ctx, req.Copy(), identity) require.ErrorIs(t, err, tc.expectError) // When non-empty reason is provided then validation should pass. req.SetRequestReason("good reason") - err = validator.Validate(ctx, req.Copy(), identity) + err = validator.validate(ctx, req.Copy(), identity) require.NoError(t, err) } @@ -2544,7 +2544,7 @@ func TestValidateResourceRequestSizeLimits(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, ExpandVars(true))) + require.NoError(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, WithExpandVars(true))) require.Len(t, req.GetRequestedResourceIDs(), 2) require.Equal(t, "/someCluster/node/resource1", types.ResourceIDToString(req.GetRequestedResourceIDs()[0])) require.Equal(t, "/someCluster/node/resource2", types.ResourceIDToString(req.GetRequestedResourceIDs()[1])) @@ -2558,7 +2558,7 @@ func TestValidateResourceRequestSizeLimits(t *testing.T) { }) } req.SetRequestedResourceIDs(requestedResourceIDs) - require.ErrorContains(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, ExpandVars(true)), "access request exceeds maximum length") + require.ErrorContains(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, WithExpandVars(true)), "access request exceeds maximum length") } func TestValidateAccessRequestClusterNames(t *testing.T) { @@ -2848,7 +2848,7 @@ func TestValidate_RequestedMaxDuration(t *testing.T) { Expires: now.Add(defaultSessionTTL), } - validator, err := NewRequestValidator(context.Background(), clock, g, tt.requestor, ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, g, tt.requestor, WithExpandVars(true)) require.NoError(t, err) req.SetCreationTime(now) @@ -2857,7 +2857,7 @@ func TestValidate_RequestedMaxDuration(t *testing.T) { } req.SetDryRun(tt.dryRun) - require.NoError(t, validator.Validate(context.Background(), req, identity)) + require.NoError(t, validator.validate(context.Background(), req, identity)) require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetAccessExpiry()) require.Equal(t, now.Add(tt.expectedAccessDuration), req.GetMaxDuration()) require.Equal(t, now.Add(tt.expectedSessionTTL), req.GetSessionTLL()) @@ -2897,7 +2897,7 @@ func TestValidate_RequestedPendingTTLAndMaxDuration(t *testing.T) { Expires: now.Add(defaultSessionTTL), } - validator, err := NewRequestValidator(context.Background(), clock, g, "alice", ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, g, "alice", WithExpandVars(true)) require.NoError(t, err) requestedMaxDuration := 4 * day @@ -2907,7 +2907,7 @@ func TestValidate_RequestedPendingTTLAndMaxDuration(t *testing.T) { req.SetMaxDuration(now.Add(requestedMaxDuration)) req.SetExpiry(now.Add(requestedPendingTTL)) - require.NoError(t, validator.Validate(context.Background(), req, identity)) + require.NoError(t, validator.validate(context.Background(), req, identity)) require.Equal(t, now.Add(requestedMaxDuration), req.GetAccessExpiry()) require.Equal(t, now.Add(requestedMaxDuration), req.GetMaxDuration()) require.Equal(t, now.Add(defaultSessionTTL), req.GetSessionTLL()) @@ -3479,10 +3479,10 @@ func TestValidate_WithAllowRequestKubernetesResources(t *testing.T) { Expires: clock.Now().UTC().Add(8 * time.Hour), } - validator, err := NewRequestValidator(context.Background(), clock, g, uls.GetName(), ExpandVars(true)) + validator, err := newRequestValidator(context.Background(), clock, g, uls.GetName(), WithExpandVars(true)) require.NoError(t, err) - err = validator.Validate(context.Background(), req, identity) + err = validator.validate(context.Background(), req, identity) if tc.wantInvalidRequestKindErr { require.Error(t, err) require.Contains(t, err.Error(), InvalidKubernetesKindAccessRequest) diff --git a/tool/tctl/common/access_request_command.go b/tool/tctl/common/access_request_command.go index ef62637dda8ca..40f3f357f0740 100644 --- a/tool/tctl/common/access_request_command.go +++ b/tool/tctl/common/access_request_command.go @@ -332,7 +332,7 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Cl Client: client, UserLoginStatesGetter: client.UserLoginStateClient(), } - err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), users, req, tlsca.Identity{}, services.ExpandVars(true)) + err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), users, req, tlsca.Identity{}, services.WithExpandVars(true)) if err != nil { return trace.Wrap(err) }