diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 5a49eb1c261f4..3397b3da2a3e2 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1700,17 +1700,7 @@ func certRequestDeviceExtensions(ext tlsca.DeviceExtensions) certRequestOption { // GetUserOrLoginState will return the given user or the login state associated with the user. func (a *Server) GetUserOrLoginState(ctx context.Context, username string) (services.UserState, error) { - uls, err := a.GetUserLoginState(ctx, username) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - - if err == nil { - return uls, nil - } - - user, err := a.GetUser(username, false) - return user, trace.Wrap(err) + return services.GetUserOrLoginState(ctx, a, username) } func (a *Server) GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCertRequest) (*proto.OpenSSHCert, error) { diff --git a/lib/services/access_request.go b/lib/services/access_request.go index 7d24a7fab33d8..667bd1f73821a 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -241,7 +241,7 @@ func (m *RequestValidator) applicableSearchAsRoles(ctx context.Context, resource rolesToRequest = append(rolesToRequest, roleName) } if len(rolesToRequest) == 0 { - return nil, trace.AccessDenied(`Resource Access Requests require usable "search_as_roles", none found for user %q`, m.user.GetName()) + return nil, trace.AccessDenied(`Resource Access Requests require usable "search_as_roles", none found for user %q`, m.userState.GetName()) } // Prune the list of roles to request to only those which may be necessary @@ -373,7 +373,7 @@ func ValidateAccessPredicates(role types.Role) error { } // ApplyAccessReview attempts to apply the specified access review to the specified request. -func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author types.User) error { +func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author UserState) error { if rev.Author != author.GetName() { return trace.BadParameter("mismatched review author (expected %q, got %q)", rev.Author, author) } @@ -485,7 +485,7 @@ func checkReviewCompat(req types.AccessRequest, rev types.AccessReview) error { // collectReviewThresholdIndexes aggregates the indexes of all thresholds whose filters match // the supplied review (part of review application logic). -func collectReviewThresholdIndexes(req types.AccessRequest, rev types.AccessReview, author types.User) ([]uint32, error) { +func collectReviewThresholdIndexes(req types.AccessRequest, rev types.AccessReview, author UserState) ([]uint32, error) { parser, err := newThresholdFilterParser(req, rev, author) if err != nil { return nil, trace.Wrap(err) @@ -534,7 +534,7 @@ func accessReviewThresholdMatchesFilter(t types.AccessReviewThreshold, parser pr // newThresholdFilterParser creates a custom parser context which exposes a simplified view of the review author // and the request for evaluation of review threshold filters. -func newThresholdFilterParser(req types.AccessRequest, rev types.AccessReview, author types.User) (BoolPredicateParser, error) { +func newThresholdFilterParser(req types.AccessRequest, rev types.AccessReview, author UserState) (BoolPredicateParser, error) { return NewJSONBoolParser(thresholdFilterContext{ Reviewer: reviewAuthorContext{ Roles: author.GetRoles(), @@ -723,6 +723,7 @@ type ResourceLister interface { // RequestValidatorGetter is the interface required by the request validation // functions used to get necessary resources. type RequestValidatorGetter interface { + UserLoginStatesGetter UserGetter RoleGetter ResourceLister @@ -779,8 +780,8 @@ func insertAnnotations(annotations map[string][]string, conditions types.AccessR // ReviewPermissionChecker is a helper for validating whether a user // is allowed to review specific access requests. type ReviewPermissionChecker struct { - User types.User - Roles struct { + UserState UserState + Roles struct { // allow/deny mappings sort role matches into lists based on their // constraining predicate (where) expression. AllowReview, DenyReview map[string][]parse.Matcher @@ -807,7 +808,7 @@ func (c *ReviewPermissionChecker) CanReviewRequest(req types.AccessRequest) (boo // adding role subselection support. // user cannot review their own request - if c.User.GetName() == req.GetUser() { + if c.UserState.GetName() == req.GetUser() { return false, nil } @@ -817,8 +818,8 @@ func (c *ReviewPermissionChecker) CanReviewRequest(req types.AccessRequest) (boo parser, err := NewJSONBoolParser(reviewPermissionContext{ Reviewer: reviewAuthorContext{ - Roles: c.User.GetRoles(), - Traits: c.User.GetTraits(), + Roles: c.UserState.GetRoles(), + Traits: c.UserState.GetTraits(), }, Request: reviewRequestContext{ Roles: requestedRoles, @@ -903,13 +904,13 @@ Outer: } func NewReviewPermissionChecker(ctx context.Context, getter RequestValidatorGetter, username string) (ReviewPermissionChecker, error) { - user, err := getter.GetUser(username, false) + uls, err := GetUserOrLoginState(ctx, getter, username) if err != nil { return ReviewPermissionChecker{}, trace.Wrap(err) } c := ReviewPermissionChecker{ - User: user, + UserState: uls, } c.Roles.AllowReview = make(map[string][]parse.Matcher) @@ -917,7 +918,7 @@ func NewReviewPermissionChecker(ctx context.Context, getter RequestValidatorGett // load all statically assigned roles for the user and // use them to build our checker state. - for _, roleName := range c.User.GetRoles() { + for _, roleName := range c.UserState.GetRoles() { role, err := getter.GetRole(ctx, roleName) if err != nil { return ReviewPermissionChecker{}, trace.Wrap(err) @@ -935,12 +936,12 @@ func (c *ReviewPermissionChecker) push(role types.Role) error { var err error - c.Roles.DenyReview[deny.Where], err = appendRoleMatchers(c.Roles.DenyReview[deny.Where], deny.Roles, deny.ClaimsToRoles, c.User.GetTraits()) + c.Roles.DenyReview[deny.Where], err = appendRoleMatchers(c.Roles.DenyReview[deny.Where], deny.Roles, deny.ClaimsToRoles, c.UserState.GetTraits()) if err != nil { return trace.Wrap(err) } - c.Roles.AllowReview[allow.Where], err = appendRoleMatchers(c.Roles.AllowReview[allow.Where], allow.Roles, allow.ClaimsToRoles, c.User.GetTraits()) + c.Roles.AllowReview[allow.Where], err = appendRoleMatchers(c.Roles.AllowReview[allow.Where], allow.Roles, allow.ClaimsToRoles, c.UserState.GetTraits()) if err != nil { return trace.Wrap(err) } @@ -957,7 +958,7 @@ func (c *ReviewPermissionChecker) push(role types.Role) error { type RequestValidator struct { clock clockwork.Clock getter RequestValidatorGetter - user types.User + userState UserState requireReason bool opts struct { expandVars bool @@ -982,15 +983,15 @@ type RequestValidator struct { // NewRequestValidator configures a new RequestValidator for the specified user. func NewRequestValidator(ctx context.Context, clock clockwork.Clock, getter RequestValidatorGetter, username string, opts ...ValidateRequestOption) (RequestValidator, error) { - user, err := getter.GetUser(username, false) + uls, err := GetUserOrLoginState(ctx, getter, username) if err != nil { return RequestValidator{}, trace.Wrap(err) } m := RequestValidator{ - clock: clock, - getter: getter, - user: user, + clock: clock, + getter: getter, + userState: uls, } for _, opt := range opts { opt(&m) @@ -1005,7 +1006,7 @@ func NewRequestValidator(ctx context.Context, clock clockwork.Clock, getter Requ // load all statically assigned roles for the user and // use them to build our validation state. - for _, roleName := range m.user.GetRoles() { + for _, roleName := range m.userState.GetRoles() { role, err := m.getter.GetRole(ctx, roleName) if err != nil { return RequestValidator{}, trace.Wrap(err) @@ -1020,7 +1021,7 @@ func NewRequestValidator(ctx context.Context, clock clockwork.Clock, getter Requ // Validate validates an access request and potentially modifies it depending on how // the validator was configured. func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest, identity tlsca.Identity) error { - if m.user.GetName() != req.GetUser() { + if m.userState.GetName() != req.GetUser() { return trace.BadParameter("request validator configured for different user (this is a bug)") } @@ -1305,7 +1306,7 @@ func (m *RequestValidator) GetRequestableRoles() ([]string, error) { var expanded []string for _, role := range allRoles { - if n := role.GetName(); !slices.Contains(m.user.GetRoles(), n) && m.CanRequestRole(n) { + if n := role.GetName(); !slices.Contains(m.userState.GetRoles(), n) && m.CanRequestRole(n) { // user does not currently hold this role, and is allowed to request it. expanded = append(expanded, n) } @@ -1323,7 +1324,7 @@ func (m *RequestValidator) push(role types.Role) error { allow, deny := role.GetAccessRequestConditions(types.Allow), role.GetAccessRequestConditions(types.Deny) - m.Roles.DenyRequest, err = appendRoleMatchers(m.Roles.DenyRequest, deny.Roles, deny.ClaimsToRoles, m.user.GetTraits()) + m.Roles.DenyRequest, err = appendRoleMatchers(m.Roles.DenyRequest, deny.Roles, deny.ClaimsToRoles, m.userState.GetTraits()) if err != nil { return trace.Wrap(err) } @@ -1332,7 +1333,7 @@ func (m *RequestValidator) push(role types.Role) error { // matchers for this role, if it applies any. astart := len(m.Roles.AllowRequest) - m.Roles.AllowRequest, err = appendRoleMatchers(m.Roles.AllowRequest, allow.Roles, allow.ClaimsToRoles, m.user.GetTraits()) + m.Roles.AllowRequest, err = appendRoleMatchers(m.Roles.AllowRequest, allow.Roles, allow.ClaimsToRoles, m.userState.GetTraits()) if err != nil { return trace.Wrap(err) } @@ -1371,8 +1372,8 @@ func (m *RequestValidator) push(role types.Role) error { // validation process for incoming access requests requires // generating system annotations to be attached to the request // before it is inserted into the backend. - insertAnnotations(m.Annotations.Deny, deny, m.user.GetTraits()) - insertAnnotations(m.Annotations.Allow, allow, m.user.GetTraits()) + insertAnnotations(m.Annotations.Deny, deny, m.userState.GetTraits()) + insertAnnotations(m.Annotations.Allow, allow, m.userState.GetTraits()) m.SuggestedReviewers = append(m.SuggestedReviewers, allow.SuggestedReviewers...) } @@ -1670,7 +1671,7 @@ func (m *RequestValidator) pruneResourceRequestRoles( } } - allRoles, err := FetchRoles(roles, m.getter, m.user.GetTraits()) + allRoles, err := FetchRoles(roles, m.getter, m.userState.GetTraits()) if err != nil { return nil, trace.Wrap(err) } @@ -1787,7 +1788,7 @@ func (m *RequestValidator) roleAllowsResource( matchers = append(matchers, NewLoginMatcher(loginHint)) } matchers = append(matchers, extraMatchers...) - err := roleSet.checkAccess(resource, m.user.GetTraits(), AccessState{MFAVerified: true}, matchers...) + err := roleSet.checkAccess(resource, m.userState.GetTraits(), AccessState{MFAVerified: true}, matchers...) if trace.IsAccessDenied(err) { // Access denied, this role does not allow access to this resource, no // unexpected error to report. diff --git a/lib/services/access_request_test.go b/lib/services/access_request_test.go index 707bce74b0222..3424bd92ec80d 100644 --- a/lib/services/access_request_test.go +++ b/lib/services/access_request_test.go @@ -30,6 +30,8 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/userloginstate" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/tlsca" @@ -37,6 +39,7 @@ import ( // mockGetter mocks the UserAndRoleGetter interface. type mockGetter struct { + userStates map[string]*userloginstate.UserLoginState users map[string]types.User roles map[string]types.Role nodes map[string]types.Server @@ -50,23 +53,27 @@ type mockGetter struct { // user inserts a new user with the specified roles and returns the username. func (m *mockGetter) user(t *testing.T, roles ...string) string { name := uuid.New().String() - user, err := types.NewUser(name) + uls, err := userloginstate.New(header.Metadata{ + Name: name, + }, userloginstate.Spec{ + Roles: roles, + }) require.NoError(t, err) - user.SetRoles(roles) - m.users[name] = user + m.userStates[name] = uls return name } -func (m *mockGetter) GetUser(name string, withSecrets bool) (types.User, error) { - if withSecrets { - return nil, trace.NotImplemented("mock getter does not store secrets") - } - user, ok := m.users[name] +func (m *mockGetter) GetUserLoginStates(context.Context) ([]*userloginstate.UserLoginState, error) { + return nil, trace.NotImplemented("GetUserLoginStates is not implemented") +} + +func (m *mockGetter) GetUserLoginState(ctx context.Context, name string) (*userloginstate.UserLoginState, error) { + uls, ok := m.userStates[name] if !ok { - return nil, trace.NotFound("no such user: %q", name) + return nil, trace.NotFound("no such user login state: %q", name) } - return user, nil + return uls, nil } func (m *mockGetter) GetRole(ctx context.Context, name string) (types.Role, error) { @@ -77,6 +84,18 @@ func (m *mockGetter) GetRole(ctx context.Context, name string) (types.Role, erro return role, nil } +func (m *mockGetter) GetUser(name string, withSecrets bool) (types.User, error) { + if withSecrets { + return nil, trace.NotImplemented("") + } + + user, ok := m.users[name] + if !ok { + return nil, trace.NotFound("no such user: %q", name) + } + return user, nil +} + func (m *mockGetter) GetRoles(ctx context.Context) ([]types.Role, error) { roles := make([]types.Role, 0, len(m.roles)) for _, r := range m.roles { @@ -241,15 +260,28 @@ func TestReviewThresholds(t *testing.T) { } // describes a collection of users with various roles - userDesc := map[string][]string{ + ulsDesc := map[string][]string{ "alice": {"populist", "proletariat", "intelligentsia", "military"}, - "bob": {"general", "proletariat", "intelligentsia", "military"}, "carol": {"conqueror", "proletariat", "intelligentsia", "military"}, - "dave": {"populist", "general", "conqueror"}, "erika": {"populist", "idealist"}, } + userStates := make(map[string]*userloginstate.UserLoginState) + for name, roles := range ulsDesc { + uls, err := userloginstate.New(header.Metadata{ + Name: name, + }, userloginstate.Spec{ + Roles: roles, + }) + require.NoError(t, err) + userStates[name] = uls + } + users := make(map[string]types.User) + userDesc := map[string][]string{ + "bob": {"general", "proletariat", "intelligentsia", "military"}, + "dave": {"populist", "general", "conqueror"}, + } for name, roles := range userDesc { user, err := types.NewUser(name) @@ -260,8 +292,9 @@ func TestReviewThresholds(t *testing.T) { } g := &mockGetter{ - roles: roles, - users: users, + roles: roles, + userStates: userStates, + users: users, } const ( @@ -635,7 +668,7 @@ func TestReviewThresholds(t *testing.T) { ProposedState: rt.propose, } - author, ok := users[rt.author] + author, ok := userStates[rt.author] require.True(t, ok, "scenario=%q, rev=%d", tt.desc, ri) err = ApplyAccessReview(req, rev, author) @@ -1104,21 +1137,24 @@ func TestRolesForResourceRequest(t *testing.T) { } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - user, err := types.NewUser("test-user") + uls, err := userloginstate.New(header.Metadata{ + Name: "test-user", + }, userloginstate.Spec{ + Roles: tc.currentRoles, + }) require.NoError(t, err) - user.SetRoles(tc.currentRoles) - users := map[string]types.User{ - user.GetName(): user, + userStates := map[string]*userloginstate.UserLoginState{ + uls.GetName(): uls, } g := &mockGetter{ roles: roles, - users: users, + userStates: userStates, clusterName: "my-cluster", } req, err := types.NewAccessRequestWithResources( - "some-id", user.GetName(), tc.requestRoles, tc.requestResourceIDs) + "some-id", uls.GetName(), tc.requestRoles, tc.requestResourceIDs) require.NoError(t, err) clock := clockwork.NewFakeClock() @@ -1126,7 +1162,7 @@ func TestRolesForResourceRequest(t *testing.T) { Expires: clock.Now().UTC().Add(8 * time.Hour), } - validator, err := NewRequestValidator(context.Background(), clock, g, user.GetName(), ExpandVars(true)) + validator, err := NewRequestValidator(context.Background(), clock, g, uls.GetName(), ExpandVars(true)) require.NoError(t, err) err = validator.Validate(context.Background(), req, identity) @@ -1147,6 +1183,7 @@ func TestPruneRequestRoles(t *testing.T) { g := &mockGetter{ roles: make(map[string]types.Role), + userStates: make(map[string]*userloginstate.UserLoginState), users: make(map[string]types.User), nodes: make(map[string]types.Server), kubeServers: make(map[string]types.KubeServer), @@ -1241,10 +1278,10 @@ func TestPruneRequestRoles(t *testing.T) { } user := g.user(t, "response-team") - g.users[user].SetTraits(map[string][]string{ + g.userStates[user].Spec.Traits = map[string][]string{ "logins": {"responder"}, "team": {"response-team"}, - }) + } nodeDesc := []struct { name string @@ -1598,9 +1635,12 @@ func TestRequestTTL(t *testing.T) { t.Run(tt.desc, func(t *testing.T) { // Setup test user "foo" and "bar" and the mock auth server that // will return users and roles. - user, err := types.NewUser("foo") + uls, err := userloginstate.New(header.Metadata{ + Name: "foo", + }, userloginstate.Spec{ + Roles: []string{"bar"}, + }) require.NoError(t, err) - user.SetRoles([]string{"bar"}) role, err := types.NewRole("bar", types.RoleSpecV6{ Options: types.RoleOptions{ @@ -1610,8 +1650,8 @@ func TestRequestTTL(t *testing.T) { require.NoError(t, err) getter := &mockGetter{ - users: map[string]types.User{"foo": user}, - roles: map[string]types.Role{"bar": role}, + userStates: map[string]*userloginstate.UserLoginState{"foo": uls}, + roles: map[string]types.Role{"bar": role}, } validator, err := NewRequestValidator(context.Background(), clock, getter, "foo", ExpandVars(true)) @@ -1675,7 +1715,6 @@ func TestSessionTTL(t *testing.T) { // will return users and roles. user, err := types.NewUser("foo") require.NoError(t, err) - user.SetRoles([]string{"bar"}) role, err := types.NewRole("bar", types.RoleSpecV6{ Options: types.RoleOptions{ @@ -2049,19 +2088,22 @@ func getMockGetter(t *testing.T, roleDesc roleTestSet, userDesc map[string][]str roles[name] = role } - users := make(map[string]types.User) + userStates := make(map[string]*userloginstate.UserLoginState) for name, roles := range userDesc { - user, err := types.NewUser(name) + uls, err := userloginstate.New(header.Metadata{ + Name: name, + }, userloginstate.Spec{ + Roles: roles, + }) require.NoError(t, err) - user.SetRoles(roles) - users[name] = user + userStates[name] = uls } g := &mockGetter{ - roles: roles, - users: users, + roles: roles, + userStates: userStates, } return g } diff --git a/lib/services/local/dynamic_access.go b/lib/services/local/dynamic_access.go index acf914fe721d4..f92a11c3337b8 100644 --- a/lib/services/local/dynamic_access.go +++ b/lib/services/local/dynamic_access.go @@ -171,7 +171,7 @@ func (s *DynamicAccessService) ApplyAccessReview(ctx context.Context, params typ } // run the application logic - if err := services.ApplyAccessReview(req, params.Review, checker.User); err != nil { + if err := services.ApplyAccessReview(req, params.Review, checker.UserState); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/user_login_state.go b/lib/services/user_login_state.go index 652a9d0f51f60..34376dd889dc5 100644 --- a/lib/services/user_login_state.go +++ b/lib/services/user_login_state.go @@ -90,3 +90,24 @@ func UnmarshalUserLoginState(data []byte, opts ...MarshalOption) (*userloginstat return uls, nil } + +// UserOrLoginStateGetter defines an interface that can get user login states or users. +type UserOrLoginStateGetter interface { + UserLoginStatesGetter + UserGetter +} + +// GetUserOrLoginState will return the given user or the login state associated with the user. +func GetUserOrLoginState(ctx context.Context, getter UserOrLoginStateGetter, username string) (UserState, error) { + uls, err := getter.GetUserLoginState(ctx, username) + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + if err == nil { + return uls, nil + } + + user, err := getter.GetUser(username, false) + return user, trace.Wrap(err) +} diff --git a/tool/tctl/common/access_request_command.go b/tool/tctl/common/access_request_command.go index d2db54a91515c..d021180ba074d 100644 --- a/tool/tctl/common/access_request_command.go +++ b/tool/tctl/common/access_request_command.go @@ -276,7 +276,14 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client auth.ClientI) req.SetRequestReason(c.reason) if c.dryRun { - err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), client, req, tlsca.Identity{}, services.ExpandVars(true)) + users := &struct { + auth.ClientI + services.UserLoginStatesGetter + }{ + ClientI: client, + UserLoginStatesGetter: client.UserLoginStateClient(), + } + err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), users, req, tlsca.Identity{}, services.ExpandVars(true)) if err != nil { return trace.Wrap(err) }