diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index 8fed79ba604bc..fdc916725514d 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -222,25 +222,14 @@ func (g *Generator) postProcess(ctx context.Context, state *userloginstate.UserL return nil } - // Remove roles that don't exist in the backend so that we don't generate certs for non-existent roles. - // Doing so can prevent login from working properly. This could occur if access lists refer to roles that - // no longer exist, for example. - roles, err := g.access.GetRoles(ctx) - if err != nil { - return trace.Wrap(err) - } - - roleLookup := map[string]bool{} - for _, role := range roles { - roleLookup[role.GetName()] = true - } - - existingRoles := []string{} + // Make sure all the roles exist. If they don't, error out. + var existingRoles []string for _, role := range state.Spec.Roles { - if roleLookup[role] { + _, err := g.access.GetRole(ctx, role) + if err == nil { existingRoles = append(existingRoles, role) } else { - g.log.Warnf("Role %s does not exist when trying to add user login state, will be skipped", role) + return trace.Wrap(err) } } state.Spec.Roles = existingRoles diff --git a/lib/auth/userloginstate/generator_test.go b/lib/auth/userloginstate/generator_test.go index 5069caffa786c..0acd5b9ea1b15 100644 --- a/lib/auth/userloginstate/generator_test.go +++ b/lib/auth/userloginstate/generator_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -64,15 +65,17 @@ func TestAccessLists(t *testing.T) { members []*accesslist.AccessListMember locks []types.Lock roles []string + wantErr require.ErrorAssertionFunc expected *userloginstate.UserLoginState expectedRoleCount int expectedTraitCount int }{ { - name: "access lists are empty", - user: user, - cloud: true, - roles: []string{"orole1"}, + name: "access lists are empty", + user: user, + cloud: true, + roles: []string{"orole1"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -101,6 +104,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -131,7 +135,8 @@ func TestAccessLists(t *testing.T) { locks: []types.Lock{ newUserLock(t, "test-lock", user.GetName()), }, - roles: []string{"orole1", "role1", "role2"}, + roles: []string{"orole1", "role1", "role2"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -160,6 +165,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -188,18 +194,9 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, - expected: newUserLoginState(t, "user", - map[string]string{ - "label1": "value1", - "label2": "value2", - userloginstate.OriginalRolesAndTraitsSet: "true", - }, - []string{"orole1"}, - trait.Traits{"otrait1": {"value1", "value2"}}, - []string{"orole1"}, - trait.Traits{"otrait1": {"value1", "value2"}, "trait1": {"value1", "value2"}, "trait2": {"value3"}}), - expectedRoleCount: 0, - expectedTraitCount: 3, + wantErr: func(tt require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, trace.NotFound("role role1 is not found")) + }, }, { name: "access lists only a member of some lists", @@ -216,6 +213,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "not-user")...), roles: []string{"orole1", "role1", "role2"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -240,6 +238,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2", "role3"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -273,6 +272,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ "label1": "value1", @@ -305,6 +305,7 @@ func TestAccessLists(t *testing.T) { }, members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"role1"}, + wantErr: require.NoError, expected: newUserLoginState(t, "user", map[string]string{ userloginstate.OriginalRolesAndTraitsSet: "true", @@ -355,7 +356,12 @@ func TestAccessLists(t *testing.T) { } state, err := svc.Generate(ctx, test.user) - require.NoError(t, err) + test.wantErr(t, err) + + if err != nil { + return + } + require.Empty(t, cmp.Diff(test.expected, state, cmpopts.SortSlices(func(str1, str2 string) bool { return str1 < str2