Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions lib/auth/userloginstate/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
)

// AccessListsAndLockGetter is an interface for retrieving access lists and locks.
type AccessListsAndLockGetter interface {
services.AccessListsGetter
services.LockGetter
}

// GeneratorConfig is the configuration for the user login state generator.
type GeneratorConfig struct {
// Log is a logger to use for the generator.
Log *logrus.Entry

// AccessLists is a service for retrieving access lists from the backend.
AccessLists services.AccessListsGetter
// AccessLists is a service for retrieving access lists and locks from the backend.
AccessLists AccessListsAndLockGetter

// Access is a service that will be used for retrieving roles from the backend.
Access services.Access
Expand Down Expand Up @@ -88,11 +94,12 @@ func (g *GeneratorConfig) CheckAndSetDefaults() error {

// Generator will generate a user login state from a user.
type Generator struct {
log *logrus.Entry
accessLists services.AccessListsGetter
access services.Access
usageEvents UsageEventsClient
clock clockwork.Clock
log *logrus.Entry
accessLists AccessListsAndLockGetter
access services.Access
usageEvents UsageEventsClient
memberChecker *services.AccessListMembershipChecker
clock clockwork.Clock
}

// NewGenerator creates a new user login state generator.
Expand All @@ -102,11 +109,12 @@ func NewGenerator(config GeneratorConfig) (*Generator, error) {
}

return &Generator{
log: config.Log,
accessLists: config.AccessLists,
access: config.Access,
usageEvents: config.UsageEvents,
clock: config.Clock,
log: config.Log,
accessLists: config.AccessLists,
access: config.Access,
usageEvents: config.UsageEvents,
memberChecker: services.NewAccessListMembershipChecker(config.Clock, config.AccessLists, config.Access),
clock: config.Clock,
}, nil
}

Expand Down Expand Up @@ -171,7 +179,7 @@ func (g *Generator) addAccessListsToState(ctx context.Context, user types.User,

for _, accessList := range accessLists {
// Check that the user meets the access list requirements.
if err := services.IsAccessListMember(ctx, identity, g.clock, accessList, g.accessLists); err != nil {
if err := g.memberChecker.IsAccessListMember(ctx, identity, accessList); err != nil {
continue
}

Expand Down
44 changes: 44 additions & 0 deletions lib/auth/userloginstate/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func TestAccessLists(t *testing.T) {
cloud bool
accessLists []*accesslist.AccessList
members []*accesslist.AccessListMember
locks []types.Lock
roles []string
expected *userloginstate.UserLoginState
expectedRoleCount int
Expand Down Expand Up @@ -97,6 +98,31 @@ func TestAccessLists(t *testing.T) {
expectedRoleCount: 2,
expectedTraitCount: 3,
},
{
name: "lock prevents adding roles and traits",
user: user,
cloud: true,
accessLists: []*accesslist.AccessList{
newAccessList(t, clock, "1", []string{"role1"}, trait.Traits{
"trait1": []string{"value1"},
}),
newAccessList(t, clock, "2", []string{"role2"}, trait.Traits{
"trait1": []string{"value2"},
"trait2": []string{"value3"},
}),
},
members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...),
locks: []types.Lock{
newUserLock(t, "test-lock", user.GetName()),
},
roles: []string{"orole1", "role1", "role2"},
expected: newUserLoginState(t, "user",
[]string{"orole1"},
[]string{"orole1"},
trait.Traits{"otrait1": []string{"value1", "value2"}}),
expectedRoleCount: 0,
expectedTraitCount: 0,
},
{
name: "access lists add roles and traits (cloud disabled)",
user: user,
Expand Down Expand Up @@ -267,6 +293,10 @@ func TestAccessLists(t *testing.T) {
require.NoError(t, backendSvc.UpsertRole(ctx, role))
}

for _, lock := range test.locks {
require.NoError(t, backendSvc.UpsertLock(ctx, lock))
}

state, err := svc.Generate(ctx, test.user)
require.NoError(t, err)
require.Empty(t, cmp.Diff(test.expected, state,
Expand All @@ -277,6 +307,7 @@ func TestAccessLists(t *testing.T) {
if test.expectedRoleCount == 0 && test.expectedTraitCount == 0 {
require.Nil(t, backendSvc.event)
} else {
require.NotNil(t, backendSvc.event)
require.IsType(t, &usageeventsv1.UsageEventOneOf_AccessListGrantsToUser{}, backendSvc.event.Event)
event := (backendSvc.event.Event).(*usageeventsv1.UsageEventOneOf_AccessListGrantsToUser)

Expand Down Expand Up @@ -394,3 +425,16 @@ func newUserLoginState(t *testing.T, name string, originalRoles, roles []string,

return uls
}

func newUserLock(t *testing.T, name string, username string) types.Lock {
t.Helper()

lock, err := types.NewLock(name, types.LockSpecV2{
Target: types.LockTarget{
User: username,
},
})
require.NoError(t, err)

return lock
}
3 changes: 3 additions & 0 deletions lib/modules/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ type AccessResourcesGetter interface {

GetUser(userName string, withSecrets bool) (types.User, error)
GetRole(ctx context.Context, name string) (types.Role, error)

GetLock(ctx context.Context, name string) (types.Lock, error)
GetLocks(ctx context.Context, inForceOnly bool, targets ...types.LockTarget) ([]types.Lock, error)
}

// Modules defines interface that external libraries can implement customizing
Expand Down
49 changes: 46 additions & 3 deletions lib/services/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

accesslistclient "github.com/gravitational/teleport/api/client/accesslist"
accesslistv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accesslist/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -210,11 +211,42 @@ func IsAccessListOwner(identity tlsca.Identity, accessList *accesslist.AccessLis
return nil
}

// AccessListMembershipChecker will check if users are members of an access list and
// makes sure the user is not locked and meets membership requirements.
type AccessListMembershipChecker struct {
members AccessListMembersGetter
locks LockGetter
clock clockwork.Clock
}

// NewAccessListMembershipChecker will create a new access list membership checker.
func NewAccessListMembershipChecker(clock clockwork.Clock, members AccessListMembersGetter, locks LockGetter) *AccessListMembershipChecker {
return &AccessListMembershipChecker{
members: members,
locks: locks,
clock: clock,
}
}

// IsAccessListMember will return true if the user is a member for the current list.
func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, memberGetter AccessListMembersGetter) error {
func (a AccessListMembershipChecker) IsAccessListMember(ctx context.Context, identity tlsca.Identity, accessList *accesslist.AccessList) error {
username := identity.Username

member, err := memberGetter.GetAccessListMember(ctx, accessList.GetName(), username)
// Allow for nil locks while we transition away from using `IsAccessListMember` outside of this struct.
if a.locks != nil {
locks, err := a.locks.GetLocks(ctx, true, types.LockTarget{
User: username,
})
if err != nil {
return trace.Wrap(err)
}

if len(locks) > 0 {
return trace.AccessDenied("user %s is currently locked", username)
}
}

member, err := a.members.GetAccessListMember(ctx, accessList.GetName(), username)
if trace.IsNotFound(err) {
// The member has not been found, so we know they're not a member of this list.
return trace.NotFound("user %s is not a member of the access list", username)
Expand All @@ -228,7 +260,7 @@ func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock cloc
return nil
}

if !clock.Now().Before(expires) {
if !a.clock.Now().Before(expires) {
return trace.AccessDenied("user %s's membership has expired in the access list", username)
}

Expand All @@ -238,6 +270,17 @@ func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock cloc
return nil
}

// TODO(mdwn): Remove this in favor of using the access list membership checker.
func IsAccessListMember(ctx context.Context, identity tlsca.Identity, clock clockwork.Clock, accessList *accesslist.AccessList, members AccessListMembersGetter) error {
// See if the member getter also implements lock getter. If so, use it. Otherwise, nil is fine.
lockGetter, _ := members.(LockGetter)
return AccessListMembershipChecker{
members: members,
locks: lockGetter,
clock: clock,
}.IsAccessListMember(ctx, identity, accessList)
}

// UserMeetsRequirements will return true if the user meets the requirements for the access list.
func UserMeetsRequirements(identity tlsca.Identity, requires accesslist.Requires) bool {
// Assemble the user's roles for easy look up.
Expand Down
72 changes: 65 additions & 7 deletions lib/services/access_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/api/types/trait"
Expand Down Expand Up @@ -265,21 +266,22 @@ func TestIsAccessListOwner(t *testing.T) {
}
}

// testMembersGetter implements AccessListMembersGetter for testing.
type testMembersGetter struct {
// testMembersAndLockGetter implements AccessListMembersGetter and LockGetter for testing.
type testMembersAndLockGetter struct {
members map[string]map[string]*accesslist.AccessListMember
locks map[string]types.Lock
}

// ListAccessListMembers returns a paginated list of all access list members.
func (t *testMembersGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) {
func (t *testMembersAndLockGetter) ListAccessListMembers(ctx context.Context, accessList string, _ int, _ string) (members []*accesslist.AccessListMember, nextToken string, err error) {
for _, member := range t.members[accessList] {
members = append(members, member)
}
return members, "", nil
}

// GetAccessListMember returns the specified access list member resource.
func (t *testMembersGetter) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) {
func (t *testMembersAndLockGetter) GetAccessListMember(ctx context.Context, accessList string, memberName string) (*accesslist.AccessListMember, error) {
members, ok := t.members[accessList]
if !ok {
return nil, trace.NotFound("not found")
Expand All @@ -293,12 +295,36 @@ func (t *testMembersGetter) GetAccessListMember(ctx context.Context, accessList
return member, nil
}

func TestIsAccessListMember(t *testing.T) {
// GetLock gets a lock by name.
func (t *testMembersAndLockGetter) GetLock(_ context.Context, name string) (types.Lock, error) {
if t.locks == nil {
return nil, trace.NotFound("not found")
}

lock, ok := t.locks[name]
if !ok {
return nil, trace.NotFound("not found")
}

return lock, nil
}

// GetLocks gets all/in-force locks that match at least one of the targets when specified.
func (t *testMembersAndLockGetter) GetLocks(ctx context.Context, inForceOnly bool, targets ...types.LockTarget) ([]types.Lock, error) {
locks := make([]types.Lock, 0, len(t.locks))
for _, lock := range t.locks {
locks = append(locks, lock)
}
return locks, nil
}

func TestIsAccessListMemberChecker(t *testing.T) {
tests := []struct {
name string
identity tlsca.Identity
memberCtx context.Context
currentTime time.Time
locks map[string]types.Lock
errAssertionFunc require.ErrorAssertionFunc
}{
{
Expand All @@ -314,6 +340,24 @@ func TestIsAccessListMember(t *testing.T) {
currentTime: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC),
errAssertionFunc: require.NoError,
},
{
name: "is locked member",
identity: tlsca.Identity{
Username: member1,
Groups: []string{"mrole1", "mrole2"},
Traits: map[string][]string{
"mtrait1": {"mvalue1", "mvalue2"},
"mtrait2": {"mvalue3", "mvalue4"},
},
},
locks: map[string]types.Lock{
"test-lock": newUserLock(t, "test-lock", member1),
},
currentTime: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC),
errAssertionFunc: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, trace.AccessDenied("user %s is currently locked", member1))
},
},
{
name: "is not a member",
identity: tlsca.Identity{
Expand Down Expand Up @@ -407,9 +451,10 @@ func TestIsAccessListMember(t *testing.T) {
}
memberMap[accessListName][member.Spec.Name] = member
}
getter := &testMembersGetter{members: memberMap}
getter := &testMembersAndLockGetter{members: memberMap, locks: test.locks}

test.errAssertionFunc(t, IsAccessListMember(ctx, test.identity, clockwork.NewFakeClockAt(test.currentTime), accessList, getter))
checker := NewAccessListMembershipChecker(clockwork.NewFakeClockAt(test.currentTime), getter, getter)
test.errAssertionFunc(t, checker.IsAccessListMember(ctx, test.identity, accessList))
})
}
}
Expand Down Expand Up @@ -761,3 +806,16 @@ spec:
review_frequency_changed: 3 months
review_day_of_month_changed: "15"
`

func newUserLock(t *testing.T, name, user string) types.Lock {
t.Helper()

lock, err := types.NewLock(name, types.LockSpecV2{
Target: types.LockTarget{
User: user,
},
})
require.NoError(t, err)

return lock
}