diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 1fa5c738eca8b..881dc708187b8 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -32,8 +32,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/types/header" - "github.com/gravitational/teleport/api/types/userloginstate" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/authz" @@ -116,36 +114,9 @@ func createBotUser( return nil, trace.Wrap(err) } - uls, err := ulsFromUser(user) - if err != nil { - return nil, trace.Wrap(err) - } - - if _, err := s.UserLoginStates.UpsertUserLoginState(ctx, uls); err != nil { - return nil, trace.Wrap(err) - } - return user, nil } -func ulsFromUser(user types.User) (*userloginstate.UserLoginState, error) { - uls, err := userloginstate.New(header.Metadata{ - Name: user.GetName(), - Labels: map[string]string{ - types.BotLabel: user.GetMetadata().Labels[types.BotLabel], - types.BotGenerationLabel: user.GetMetadata().Labels[types.BotGenerationLabel], - }, - }, userloginstate.Spec{ - Roles: user.GetRoles(), - Traits: user.GetTraits(), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return uls, nil -} - // createBot creates a new certificate renewal bot from a bot request. func (a *Server) createBot(ctx context.Context, req *proto.CreateBotRequest) (*proto.CreateBotResponse, error) { if req.Name == "" { @@ -441,22 +412,6 @@ func (a *Server) validateGenerationLabel(ctx context.Context, username string, c return trace.CompareFailed("Database comparison failed, try the request again") } - uls, err := a.GetUserLoginState(ctx, user.GetName()) - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } - if uls == nil { - uls, err = ulsFromUser(user) - if err != nil { - return trace.Wrap(err) - } - } - - uls.ResourceHeader.Metadata.Labels[types.BotGenerationLabel] = generation - if _, err := a.UpsertUserLoginState(ctx, uls); err != nil { - return trace.Wrap(err) - } - return nil } diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index 5ac0927bea9c3..a3858430cc31c 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -133,13 +133,17 @@ func (g *Generator) Generate(ctx context.Context, user types.User) (*userloginst } } + labels := make(map[string]string, len(user.GetAllLabels())) + for k, v := range user.GetAllLabels() { + labels[k] = v + } + labels[userloginstate.OriginalRolesAndTraitsSet] = "true" + // Create a new empty user login state. uls, err := userloginstate.New( header.Metadata{ - Name: user.GetName(), - Labels: map[string]string{ - userloginstate.OriginalRolesAndTraitsSet: "true", - }, + Name: user.GetName(), + Labels: labels, }, userloginstate.Spec{ OriginalRoles: utils.CopyStrings(user.GetRoles()), OriginalTraits: originalTraits, diff --git a/lib/auth/userloginstate/generator_test.go b/lib/auth/userloginstate/generator_test.go index 06e25ff3463ba..b9f879c0af770 100644 --- a/lib/auth/userloginstate/generator_test.go +++ b/lib/auth/userloginstate/generator_test.go @@ -44,6 +44,10 @@ import ( func TestAccessLists(t *testing.T) { user, err := types.NewUser("user") + user.SetStaticLabels(map[string]string{ + "label1": "value1", + "label2": "value2", + }) user.SetRoles([]string{"orole1"}) user.SetTraits(map[string][]string{ "otrait1": {"value1", "value2"}, @@ -72,6 +76,11 @@ func TestAccessLists(t *testing.T) { cloud: true, roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -95,6 +104,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2"}, @@ -121,6 +135,11 @@ func TestAccessLists(t *testing.T) { }, roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -144,6 +163,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2"}, @@ -167,6 +191,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -190,6 +219,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "not-user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1"}, @@ -209,6 +243,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2", "role3"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2", "role3"}, @@ -237,6 +276,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -264,6 +308,9 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"role1"}, expected: newUserLoginState(t, "user", + map[string]string{ + userloginstate.OriginalRolesAndTraitsSet: "true", + }, nil, nil, []string{"role1"}, @@ -424,15 +471,13 @@ func newAccessListMembers(t *testing.T, clock clockwork.Clock, accessList string return alMembers } -func newUserLoginState(t *testing.T, name string, originalRoles []string, originalTraits map[string][]string, +func newUserLoginState(t *testing.T, name string, labels map[string]string, originalRoles []string, originalTraits map[string][]string, roles []string, traits map[string][]string) *userloginstate.UserLoginState { t.Helper() uls, err := userloginstate.New(header.Metadata{ - Name: name, - Labels: map[string]string{ - userloginstate.OriginalRolesAndTraitsSet: "true", - }, + Name: name, + Labels: labels, }, userloginstate.Spec{ OriginalRoles: originalRoles, OriginalTraits: originalTraits, diff --git a/lib/auth/userloginstate/service_test.go b/lib/auth/userloginstate/service_test.go index 8e58925885b5d..cfec14b3bb992 100644 --- a/lib/auth/userloginstate/service_test.go +++ b/lib/auth/userloginstate/service_test.go @@ -73,8 +73,8 @@ func TestGetUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -99,8 +99,8 @@ func TestUpsertUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -122,7 +122,7 @@ func TestGetUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -148,7 +148,7 @@ func TestDeleteUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -175,8 +175,8 @@ func TestDeleteAllAccessLists(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err)