diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index af50a22bcc816..d9f5e5c309b6e 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -34,13 +34,19 @@ import ( "github.com/gravitational/teleport/lib/tlsca" ) +// AccessListsAndLockGetter is an interface for retrieving access lists and locks. +type AccessListsAndLockGetter interface { + services.AccessListsGetter + services.LockGetter +} + // GeneratorConfig is the configuration for the user login state generator. type GeneratorConfig struct { // Log is a logger to use for the generator. Log *logrus.Entry - // AccessLists is a service for retrieving access lists from the backend. - AccessLists services.AccessListsGetter + // AccessLists is a service for retrieving access lists and locks from the backend. + AccessLists AccessListsAndLockGetter // Access is a service that will be used for retrieving roles from the backend. Access services.Access @@ -88,11 +94,12 @@ func (g *GeneratorConfig) CheckAndSetDefaults() error { // Generator will generate a user login state from a user. type Generator struct { - log *logrus.Entry - accessLists services.AccessListsGetter - access services.Access - usageEvents UsageEventsClient - clock clockwork.Clock + log *logrus.Entry + accessLists AccessListsAndLockGetter + access services.Access + usageEvents UsageEventsClient + memberChecker *services.AccessListMembershipChecker + clock clockwork.Clock } // NewGenerator creates a new user login state generator. @@ -102,11 +109,12 @@ func NewGenerator(config GeneratorConfig) (*Generator, error) { } return &Generator{ - log: config.Log, - accessLists: config.AccessLists, - access: config.Access, - usageEvents: config.UsageEvents, - clock: config.Clock, + log: config.Log, + accessLists: config.AccessLists, + access: config.Access, + usageEvents: config.UsageEvents, + memberChecker: services.NewAccessListMembershipChecker(config.Clock, config.AccessLists, config.Access), + clock: config.Clock, }, nil } @@ -171,7 +179,7 @@ func (g *Generator) addAccessListsToState(ctx context.Context, user types.User, for _, accessList := range accessLists { // Check that the user meets the access list requirements. - if err := services.IsAccessListMember(ctx, identity, g.clock, accessList, g.accessLists); err != nil { + if err := g.memberChecker.IsAccessListMember(ctx, identity, accessList); err != nil { continue } diff --git a/lib/auth/userloginstate/generator_test.go b/lib/auth/userloginstate/generator_test.go index 104b2eeb4b620..23b95e922d682 100644 --- a/lib/auth/userloginstate/generator_test.go +++ b/lib/auth/userloginstate/generator_test.go @@ -58,6 +58,7 @@ func TestAccessLists(t *testing.T) { cloud bool accessLists []*accesslist.AccessList members []*accesslist.AccessListMember + locks []types.Lock roles []string expected *userloginstate.UserLoginState expectedRoleCount int @@ -97,6 +98,31 @@ func TestAccessLists(t *testing.T) { expectedRoleCount: 2, expectedTraitCount: 3, }, + { + name: "lock prevents adding roles and traits", + user: user, + cloud: true, + accessLists: []*accesslist.AccessList{ + newAccessList(t, clock, "1", []string{"role1"}, trait.Traits{ + "trait1": []string{"value1"}, + }), + newAccessList(t, clock, "2", []string{"role2"}, trait.Traits{ + "trait1": []string{"value2"}, + "trait2": []string{"value3"}, + }), + }, + members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), + locks: []types.Lock{ + newUserLock(t, "test-lock", user.GetName()), + }, + roles: []string{"orole1", "role1", "role2"}, + expected: newUserLoginState(t, "user", + []string{"orole1"}, + []string{"orole1"}, + trait.Traits{"otrait1": []string{"value1", "value2"}}), + expectedRoleCount: 0, + expectedTraitCount: 0, + }, { name: "access lists add roles and traits (cloud disabled)", user: user, @@ -267,6 +293,10 @@ func TestAccessLists(t *testing.T) { require.NoError(t, backendSvc.UpsertRole(ctx, role)) } + for _, lock := range test.locks { + require.NoError(t, backendSvc.UpsertLock(ctx, lock)) + } + state, err := svc.Generate(ctx, test.user) require.NoError(t, err) require.Empty(t, cmp.Diff(test.expected, state, @@ -277,6 +307,7 @@ func TestAccessLists(t *testing.T) { if test.expectedRoleCount == 0 && test.expectedTraitCount == 0 { require.Nil(t, backendSvc.event) } else { + require.NotNil(t, backendSvc.event) require.IsType(t, &usageeventsv1.UsageEventOneOf_AccessListGrantsToUser{}, backendSvc.event.Event) event := (backendSvc.event.Event).(*usageeventsv1.UsageEventOneOf_AccessListGrantsToUser) @@ -394,3 +425,16 @@ func newUserLoginState(t *testing.T, name string, originalRoles, roles []string, return uls } + +func newUserLock(t *testing.T, name string, username string) types.Lock { + t.Helper() + + lock, err := types.NewLock(name, types.LockSpecV2{ + Target: types.LockTarget{ + User: username, + }, + }) + require.NoError(t, err) + + return lock +} diff --git a/lib/modules/modules.go b/lib/modules/modules.go index 4522069de26a1..fc0aae88cfc5f 100644 --- a/lib/modules/modules.go +++ b/lib/modules/modules.go @@ -148,6 +148,9 @@ type AccessResourcesGetter interface { GetUser(userName string, withSecrets bool) (types.User, error) GetRole(ctx context.Context, name string) (types.Role, error) + + GetLock(ctx context.Context, name string) (types.Lock, error) + GetLocks(ctx context.Context, inForceOnly bool, targets ...types.LockTarget) ([]types.Lock, error) } // Modules defines interface that external libraries can implement customizing diff --git a/lib/services/access_list.go b/lib/services/access_list.go index 73689e8b1ba8f..3b1ba3f9d4abd 100644 --- a/lib/services/access_list.go +++ b/lib/services/access_list.go @@ -25,6 +25,7 @@ import ( accesslistclient "github.com/gravitational/teleport/api/client/accesslist" accesslistv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accesslist/v1" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -210,11 +211,42 @@ func IsAccessListOwner(identity tlsca.Identity, accessList *accesslist.AccessLis return nil } +// AccessListMembershipChecker will check if users are members of an access list and +// makes sure the user is not locked and meets membership requirements. +type AccessListMembershipChecker struct { + members AccessListMembersGetter + locks LockGetter + clock clockwork.Clock +} + +// NewAccessListMembershipChecker will create a new access list membership checker. +func NewAccessListMembershipChecker(clock clockwork.Clock, members AccessListMembersGetter, locks LockGetter) *AccessListMembershipChecker { + return &AccessListMembershipChecker{ + members: members, + locks: locks, + clock: clock, + } +} + // IsAccessListMember will return true if the user is a member for the current list. -func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, memberGetter AccessListMembersGetter) error { +func (a AccessListMembershipChecker) IsAccessListMember(ctx context.Context, identity tlsca.Identity, accessList *accesslist.AccessList) error { username := identity.Username - member, err := memberGetter.GetAccessListMember(ctx, accessList.GetName(), username) + // Allow for nil locks while we transition away from using `IsAccessListMember` outside of this struct. + if a.locks != nil { + locks, err := a.locks.GetLocks(ctx, true, types.LockTarget{ + User: username, + }) + if err != nil { + return trace.Wrap(err) + } + + if len(locks) > 0 { + return trace.AccessDenied("user %s is currently locked", username) + } + } + + member, err := a.members.GetAccessListMember(ctx, accessList.GetName(), username) if trace.IsNotFound(err) { // The member has not been found, so we know they're not a member of this list. return trace.NotFound("user %s is not a member of the access list", username) @@ -228,7 +260,7 @@ func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock cloc return nil } - if !clock.Now().Before(expires) { + if !a.clock.Now().Before(expires) { return trace.AccessDenied("user %s's membership has expired in the access list", username) } @@ -238,6 +270,17 @@ func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock cloc return nil } +// TODO(mdwn): Remove this in favor of using the access list membership checker. +func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, members AccessListMembersGetter) error { + // See if the member getter also implements lock getter. If so, use it. Otherwise, nil is fine. + lockGetter, _ := members.(LockGetter) + return AccessListMembershipChecker{ + members: members, + locks: lockGetter, + clock: clock, + }.IsAccessListMember(ctx, identity, accessList) +} + // UserMeetsRequirements will return true if the user meets the requirements for the access list. func UserMeetsRequirements(identity tlsca.Identity, requires accesslist.Requires) bool { // Assemble the user's roles for easy look up. diff --git a/lib/services/access_list_test.go b/lib/services/access_list_test.go index 1de31a7e60c1e..8c511245f1c8a 100644 --- a/lib/services/access_list_test.go +++ b/lib/services/access_list_test.go @@ -25,6 +25,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/trait" @@ -265,13 +266,14 @@ func TestIsAccessListOwner(t *testing.T) { } } -// testMembersGetter implements AccessListMembersGetter for testing. -type testMembersGetter struct { +// testMembersAndLockGetter implements AccessListMembersGetter and LockGetter for testing. +type testMembersAndLockGetter struct { members map[string]map[string]*accesslist.AccessListMember + locks map[string]types.Lock } // ListAccessListMembers returns a paginated list of all access list members. -func (t *testMembersGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) { +func (t *testMembersAndLockGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) { for _, member := range t.members[accessList] { members = append(members, member) } @@ -279,7 +281,7 @@ func (t *testMembersGetter) ListAccessListMembers(ctx context.Context, accessLis } // GetAccessListMember returns the specified access list member resource. -func (t *testMembersGetter) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) { +func (t *testMembersAndLockGetter) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) { members, ok := t.members[accessList] if !ok { return nil, trace.NotFound("not found") @@ -293,12 +295,36 @@ func (t *testMembersGetter) GetAccessListMember(ctx context.Context, accessList return member, nil } -func TestIsAccessListMember(t *testing.T) { +// GetLock gets a lock by name. +func (t *testMembersAndLockGetter) GetLock(_ context.Context, name string) (types.Lock, error) { + if t.locks == nil { + return nil, trace.NotFound("not found") + } + + lock, ok := t.locks[name] + if !ok { + return nil, trace.NotFound("not found") + } + + return lock, nil +} + +// GetLocks gets all/in-force locks that match at least one of the targets when specified. +func (t *testMembersAndLockGetter) GetLocks(ctx context.Context, inForceOnly bool, targets ...types.LockTarget) ([]types.Lock, error) { + locks := make([]types.Lock, 0, len(t.locks)) + for _, lock := range t.locks { + locks = append(locks, lock) + } + return locks, nil +} + +func TestIsAccessListMemberChecker(t *testing.T) { tests := []struct { name string identity tlsca.Identity memberCtx context.Context currentTime time.Time + locks map[string]types.Lock errAssertionFunc require.ErrorAssertionFunc }{ { @@ -314,6 +340,24 @@ func TestIsAccessListMember(t *testing.T) { currentTime: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC), errAssertionFunc: require.NoError, }, + { + name: "is locked member", + identity: tlsca.Identity{ + Username: member1, + Groups: []string{"mrole1", "mrole2"}, + Traits: map[string][]string{ + "mtrait1": {"mvalue1", "mvalue2"}, + "mtrait2": {"mvalue3", "mvalue4"}, + }, + }, + locks: map[string]types.Lock{ + "test-lock": newUserLock(t, "test-lock", member1), + }, + currentTime: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC), + errAssertionFunc: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, trace.AccessDenied("user %s is currently locked", member1)) + }, + }, { name: "is not a member", identity: tlsca.Identity{ @@ -407,9 +451,10 @@ func TestIsAccessListMember(t *testing.T) { } memberMap[accessListName][member.Spec.Name] = member } - getter := &testMembersGetter{members: memberMap} + getter := &testMembersAndLockGetter{members: memberMap, locks: test.locks} - test.errAssertionFunc(t, IsAccessListMember(ctx, test.identity, clockwork.NewFakeClockAt(test.currentTime), accessList, getter)) + checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter) + test.errAssertionFunc(t, checker.IsAccessListMember(ctx, test.identity, accessList)) }) } } @@ -761,3 +806,16 @@ spec: review_frequency_changed: 3 months review_day_of_month_changed: "15" ` + +func newUserLock(t *testing.T, name, user string) types.Lock { + t.Helper() + + lock, err := types.NewLock(name, types.LockSpecV2{ + Target: types.LockTarget{ + User: user, + }, + }) + require.NoError(t, err) + + return lock +}