Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 5 additions & 58 deletions lib/accesslists/hierarchy.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,65 +385,12 @@ func IsAccessListMember(
}
}

members, err := fetchMembers(ctx, accessList.GetName(), g)
if err != nil {
return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_UNSPECIFIED, trace.Wrap(err, "fetching access list %q members", accessList.GetName())
}

var membershipErr error

for _, member := range members {
// Is user an explicit member?
if member.Spec.MembershipKind != accesslist.MembershipKindList && member.GetName() == user.GetName() {
if !UserMeetsRequirements(user, accessList.Spec.MembershipRequires) {
// Avoid non-deterministic behavior in these checks. Rather than returning immediately, continue
// through all members to make sure there isn't a valid match later on.
membershipErr = trace.AccessDenied("User '%s' does not meet the membership requirements for Access List '%s'", user.GetName(), accessList.Spec.Title)
continue
}
if !member.Spec.Expires.IsZero() && !clock.Now().Before(member.Spec.Expires) {
membershipErr = trace.AccessDenied("User '%s's membership in Access List '%s' has expired", user.GetName(), accessList.Spec.Title)
continue
}
return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_EXPLICIT, nil
}
// Is user an inherited member through any potential member AccessLists?
if member.Spec.MembershipKind == accesslist.MembershipKindList {
memberAccessList, err := g.GetAccessList(ctx, member.GetName())
if err != nil {
if trace.IsNotFound(err) {
continue
}
return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_UNSPECIFIED, trace.Wrap(err, "getting access list %q", member.GetName())
}
// Since we already verified that the user is not locked, don't provide lockGetter here
membershipType, err := IsAccessListMember(ctx, user, memberAccessList, g, nil, clock)
if err != nil {
if trace.IsAccessDenied(err) {
membershipErr = err
continue
}
return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_UNSPECIFIED, trace.Wrap(err)
}
if membershipType != accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_UNSPECIFIED {
if !UserMeetsRequirements(user, accessList.Spec.MembershipRequires) {
membershipErr = trace.AccessDenied("User '%s' does not meet the membership requirements for Access List '%s'", user.GetName(), accessList.Spec.Title)
continue
}
if !member.Spec.Expires.IsZero() && !clock.Now().Before(member.Spec.Expires) {
membershipErr = trace.AccessDenied("User '%s's membership in Access List '%s' has expired", user.GetName(), accessList.Spec.Title)
continue
}
return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_INHERITED, nil
}
}
}

if membershipErr == nil {
membershipErr = trace.AccessDenied("no access path found")
cfg := walkConfig{
getter: g,
root: accessList,
}

return accesslistv1.AccessListUserAssignmentType_ACCESS_LIST_USER_ASSIGNMENT_TYPE_UNSPECIFIED, trace.Wrap(membershipErr)
return isAccessListMember(ctx, user, cfg, clock.Now())
}

// UserMeetsRequirements is a helper which will return whether the User meets the AccessList Ownership/MembershipRequires.
Expand Down Expand Up @@ -528,7 +475,7 @@ func withUserRequirementsCheck(user types.User, clock clockwork.Clock) ancestorO
type HierarchyConfig struct {
// AccessListService is used to fetch Access Lists and their members.
AccessListsService AccessListAndMembersGetter
// Getter is used to fetch Access Lists and their members.
// Clock is used to check if memberships are expired.
Clock clockwork.Clock
}

Expand Down
221 changes: 219 additions & 2 deletions lib/accesslists/hierarchy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"iter"
"slices"
"sort"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -432,7 +433,6 @@ func TestAccessListIsMember_NestedRequirements(t *testing.T) {
})

t.Run("cyclic graph, no membership", func(t *testing.T) {
t.Skip("cyclic graph not supported yet")
firstList := newAccessList(t, "first", clock)
secondList := newAccessList(t, "second", clock)
thirdList := newAccessList(t, "third", clock)
Expand Down Expand Up @@ -475,7 +475,6 @@ func TestAccessListIsMember_NestedRequirements(t *testing.T) {
})

t.Run("cyclic graph, user membership", func(t *testing.T) {
t.Skip("cyclic graph not supported yet")
firstList := newAccessList(t, "first", clock)
secondList := newAccessList(t, "second", clock)
thirdList := newAccessList(t, "third", clock)
Expand Down Expand Up @@ -981,3 +980,221 @@ func newAccessListMember(t *testing.T, accessListName, memberName string, member
require.NoError(t, err)
return member
}

func generateAccessList(name string) *accesslist.AccessList {
return &accesslist.AccessList{
ResourceHeader: header.ResourceHeader{
Metadata: header.Metadata{
Name: name,
},
},
}
}

func generateNestedALs(level, directMembers int, rootListName, userName string) (map[string]*accesslist.AccessList, map[string][]*accesslist.AccessListMember) {
accesslists := []*accesslist.AccessList{generateAccessList(rootListName)}
members := make(map[string][]*accesslist.AccessListMember)

for i := range level - 1 {
parentName := accesslists[i].GetName()
name := "nested-al-" + strconv.Itoa(i)
accesslists = append(accesslists, generateAccessList(name))
listMembers := generateUserMembers(directMembers/2, name)
listMembers = append(listMembers, &accesslist.AccessListMember{
ResourceHeader: header.ResourceHeader{
Metadata: header.Metadata{
Name: name,
},
},
Spec: accesslist.AccessListMemberSpec{
AccessList: parentName,
Name: name,
MembershipKind: accesslist.MembershipKindList,
},
})
listMembers = append(listMembers, generateUserMembers(directMembers/2+directMembers%2, name)...)
members[parentName] = listMembers
}

alMap := make(map[string]*accesslist.AccessList)
for _, al := range accesslists {
alMap[al.GetName()] = al
}
return alMap, members
}

func generateUserMembers(count int, alName string) []*accesslist.AccessListMember {
var members []*accesslist.AccessListMember
for i := range count {
memberName := "member-" + strconv.Itoa(i)
members = append(members, &accesslist.AccessListMember{
ResourceHeader: header.ResourceHeader{
Metadata: header.Metadata{
Name: memberName,
},
},
Spec: accesslist.AccessListMemberSpec{
AccessList: alName,
Name: memberName,
MembershipKind: accesslist.MembershipKindUser,
},
})
}
return members
}

func BenchmarkIsAccessListMember(b *testing.B) {
const mainAccessListName = "main-al"
const testUserName = "test-user"

lockGetter := &mockLocksGetter{}
clock := clockwork.NewFakeClock()

b.Run("no accessPaths", func(b *testing.B) {
mock := &mockAccessListAndMembersGetter{
accessLists: map[string]*accesslist.AccessList{
mainAccessListName: generateAccessList(mainAccessListName),
},
members: map[string][]*accesslist.AccessListMember{
mainAccessListName: {},
},
}

for b.Loop() {
_, err := IsAccessListMember(
b.Context(),
&types.UserV2{Metadata: types.Metadata{Name: testUserName}},
generateAccessList(mainAccessListName),
mock,
lockGetter,
clock)
if err != nil {
b.Fatal(err)
}
}
})

b.Run("single-page direct member", func(b *testing.B) {
member := &accesslist.AccessListMember{
ResourceHeader: header.ResourceHeader{
Metadata: header.Metadata{
Name: testUserName,
},
},
Spec: accesslist.AccessListMemberSpec{
AccessList: mainAccessListName,
Name: testUserName,
MembershipKind: accesslist.MembershipKindUser,
},
}
generatedMembers := generateUserMembers(50, mainAccessListName)
// We inject the member we are looking for in the middle of the member list
members := append(generatedMembers[:25], member)
members = append(members, generatedMembers[25:]...)

mock := &mockAccessListAndMembersGetter{
accessLists: map[string]*accesslist.AccessList{
mainAccessListName: generateAccessList(mainAccessListName),
},
members: map[string][]*accesslist.AccessListMember{
mainAccessListName: members,
},
}

for b.Loop() {
_, err := IsAccessListMember(
b.Context(),
&types.UserV2{Metadata: types.Metadata{Name: testUserName}},
generateAccessList(mainAccessListName),
mock,
lockGetter,
clock)
if err != nil {
b.Fatal(err)
}
}
})

b.Run("multiple-pages direct member", func(b *testing.B) {
member := &accesslist.AccessListMember{
ResourceHeader: header.ResourceHeader{
Metadata: header.Metadata{
Name: testUserName,
},
},
Spec: accesslist.AccessListMemberSpec{
AccessList: mainAccessListName,
Name: testUserName,
MembershipKind: accesslist.MembershipKindUser,
},
}
generatedMembers := generateUserMembers(500, mainAccessListName)
// We inject the member we are looking for in the middle of the member list
members := append(generatedMembers[:250], member)
members = append(members, generatedMembers[250:]...)

mock := &mockAccessListAndMembersGetter{
accessLists: map[string]*accesslist.AccessList{
mainAccessListName: generateAccessList(mainAccessListName),
},
members: map[string][]*accesslist.AccessListMember{
mainAccessListName: members,
},
}

for b.Loop() {
_, err := IsAccessListMember(
b.Context(),
&types.UserV2{Metadata: types.Metadata{Name: testUserName}},
generateAccessList(mainAccessListName),
mock,
lockGetter,
clock)
if err != nil {
b.Fatal(err)
}
}
})

b.Run("single-page nested member", func(b *testing.B) {
lists, members := generateNestedALs(5, 0, mainAccessListName, testUserName)
mock := &mockAccessListAndMembersGetter{
accessLists: lists,
members: members,
}

for b.Loop() {
_, err := IsAccessListMember(
b.Context(),
&types.UserV2{Metadata: types.Metadata{Name: testUserName}},
generateAccessList(mainAccessListName),
mock,
lockGetter,
clock)
if err != nil {
b.Fatal(err)
}
}
})

b.Run("multiple pages nested member", func(b *testing.B) {
lists, members := generateNestedALs(5, 501, mainAccessListName, testUserName)
mock := &mockAccessListAndMembersGetter{
accessLists: lists,
members: members,
}

for b.Loop() {
_, err := IsAccessListMember(
b.Context(),
&types.UserV2{Metadata: types.Metadata{Name: testUserName}},
generateAccessList(mainAccessListName),
mock,
lockGetter,
clock)
if err != nil {
b.Fatal(err)
}
}
})
}
Loading
Loading