diff --git a/lib/auth/bot.go b/lib/auth/bot.go index b6d2e7a7ee61f..2dee779aea1c7 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -30,8 +30,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/types/header" - "github.com/gravitational/teleport/api/types/userloginstate" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/authz" @@ -113,36 +111,9 @@ func createBotUser( return nil, trace.Wrap(err) } - uls, err := ulsFromUser(user) - if err != nil { - return nil, trace.Wrap(err) - } - - if _, err := s.UserLoginStates.UpsertUserLoginState(ctx, uls); err != nil { - return nil, trace.Wrap(err) - } - return user, nil } -func ulsFromUser(user types.User) (*userloginstate.UserLoginState, error) { - uls, err := userloginstate.New(header.Metadata{ - Name: user.GetName(), - Labels: map[string]string{ - types.BotLabel: user.GetMetadata().Labels[types.BotLabel], - types.BotGenerationLabel: user.GetMetadata().Labels[types.BotGenerationLabel], - }, - }, userloginstate.Spec{ - Roles: user.GetRoles(), - Traits: user.GetTraits(), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return uls, nil -} - // createBot creates a new certificate renewal bot from a bot request. func (a *Server) createBot(ctx context.Context, req *proto.CreateBotRequest) (*proto.CreateBotResponse, error) { if req.Name == "" { @@ -438,22 +409,6 @@ func (a *Server) validateGenerationLabel(ctx context.Context, username string, c return trace.CompareFailed("Database comparison failed, try the request again") } - uls, err := a.GetUserLoginState(ctx, user.GetName()) - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } - if uls == nil { - uls, err = ulsFromUser(user) - if err != nil { - return trace.Wrap(err) - } - } - - uls.ResourceHeader.Metadata.Labels[types.BotGenerationLabel] = generation - if _, err := a.UpsertUserLoginState(ctx, uls); err != nil { - return trace.Wrap(err) - } - return nil } diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index b71f2f16ac3e4..1a93ee7268c18 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -131,13 +131,17 @@ func (g *Generator) Generate(ctx context.Context, user types.User) (*userloginst } } + labels := make(map[string]string, len(user.GetMetadata().Labels)) + for k, v := range user.GetMetadata().Labels { + labels[k] = v + } + labels[userloginstate.OriginalRolesAndTraitsSet] = "true" + // Create a new empty user login state. uls, err := userloginstate.New( header.Metadata{ - Name: user.GetName(), - Labels: map[string]string{ - userloginstate.OriginalRolesAndTraitsSet: "true", - }, + Name: user.GetName(), + Labels: labels, }, userloginstate.Spec{ OriginalRoles: utils.CopyStrings(user.GetRoles()), OriginalTraits: originalTraits, diff --git a/lib/auth/userloginstate/generator_test.go b/lib/auth/userloginstate/generator_test.go index 66dfab399dcdb..9ae2351eebecd 100644 --- a/lib/auth/userloginstate/generator_test.go +++ b/lib/auth/userloginstate/generator_test.go @@ -42,6 +42,13 @@ import ( func TestAccessLists(t *testing.T) { user, err := types.NewUser("user") + user.SetMetadata(types.Metadata{ + Name: "user", + Labels: map[string]string{ + "label1": "value1", + "label2": "value2", + }, + }) user.SetRoles([]string{"orole1"}) user.SetTraits(map[string][]string{ "otrait1": {"value1", "value2"}, @@ -70,6 +77,11 @@ func TestAccessLists(t *testing.T) { cloud: true, roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -93,6 +105,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2"}, @@ -119,6 +136,11 @@ func TestAccessLists(t *testing.T) { }, roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -142,6 +164,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2"}, @@ -165,6 +192,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -188,6 +220,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "not-user")...), roles: []string{"orole1", "role1", "role2"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1"}, @@ -207,6 +244,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1", "role1", "role2", "role3"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1", "role1", "role2", "role3"}, @@ -235,6 +277,11 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"orole1"}, expected: newUserLoginState(t, "user", + map[string]string{ + "label1": "value1", + "label2": "value2", + userloginstate.OriginalRolesAndTraitsSet: "true", + }, []string{"orole1"}, trait.Traits{"otrait1": {"value1", "value2"}}, []string{"orole1"}, @@ -262,6 +309,9 @@ func TestAccessLists(t *testing.T) { members: append(newAccessListMembers(t, clock, "1", "user"), newAccessListMembers(t, clock, "2", "user")...), roles: []string{"role1"}, expected: newUserLoginState(t, "user", + map[string]string{ + userloginstate.OriginalRolesAndTraitsSet: "true", + }, nil, nil, []string{"role1"}, @@ -420,15 +470,13 @@ func newAccessListMembers(t *testing.T, clock clockwork.Clock, accessList string return alMembers } -func newUserLoginState(t *testing.T, name string, originalRoles []string, originalTraits map[string][]string, +func newUserLoginState(t *testing.T, name string, labels map[string]string, originalRoles []string, originalTraits map[string][]string, roles []string, traits map[string][]string) *userloginstate.UserLoginState { t.Helper() uls, err := userloginstate.New(header.Metadata{ - Name: name, - Labels: map[string]string{ - userloginstate.OriginalRolesAndTraitsSet: "true", - }, + Name: name, + Labels: labels, }, userloginstate.Spec{ OriginalRoles: originalRoles, OriginalTraits: originalTraits, diff --git a/lib/auth/userloginstate/service_test.go b/lib/auth/userloginstate/service_test.go index 7532915b99ed8..36550fc68b60d 100644 --- a/lib/auth/userloginstate/service_test.go +++ b/lib/auth/userloginstate/service_test.go @@ -68,8 +68,8 @@ func TestGetUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -94,8 +94,8 @@ func TestUpsertUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -117,7 +117,7 @@ func TestGetUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -143,7 +143,7 @@ func TestDeleteUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -170,8 +170,8 @@ func TestDeleteAllAccessLists(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1", stRoles, stTraits, stRoles, stTraits) - uls2 := newUserLoginState(t, "2", stRoles, stTraits, stRoles, stTraits) + uls1 := newUserLoginState(t, "1", nil, stRoles, stTraits, stRoles, stTraits) + uls2 := newUserLoginState(t, "2", nil, stRoles, stTraits, stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err)