From a33b11350a4bc92ea39782f15afa6a51e2e38a18 Mon Sep 17 00:00:00 2001 From: Michael Wilson Date: Wed, 16 Aug 2023 20:44:32 -0400 Subject: [PATCH] Generate user login state from access lists and integrate into certificates. (#29364) * Generate user login state from access lists and integrate into certificates. On login, the user login state will be generated, using access lists to register additional roles and traits that will be inserted into the user's certificate. Tests have been added to exercise this as well. * Cache user login states, filter roles that aren't in the backend. * Small refactor. * Optimize RPC calls, test merge login in auth.go more thoroughly. * Warn when role is missing. * Update so access info uses the user login state directly, user login state comprises the whole state as opposed to a mix. * Logic tweaks to restore tests. * Integrate user login state cache. * Swap out get user for get user state where applicable. * Revert unrelated debug change. * Add in missing err check. * Further replacing with user state. * Revert changes to helpers to try to get integration tests working. * Revert "Revert changes to helpers to try to get integration tests working." This reverts commit 682e92064b81fc99db22ede79bf2b6c337588e39. * Add in user type to generator. * Use supplied user for generating SSH certs. --- .../userloginstate/v1/userloginstate.pb.go | 29 +- .../userloginstate/v1/userloginstate.proto | 3 + api/types/user.go | 15 + .../convert/v1/user_login_state.go | 11 +- .../convert/v1/user_login_state_test.go | 10 + api/types/userloginstate/user_login_state.go | 31 +- lib/auth/auth.go | 71 +++-- lib/auth/auth_test.go | 107 ++++++- lib/auth/auth_with_roles.go | 6 +- lib/auth/bot.go | 66 ++++- lib/auth/helpers.go | 20 +- lib/auth/init.go | 3 + lib/auth/methods.go | 15 +- lib/auth/sessions.go | 14 +- lib/auth/userloginstate/generator.go | 207 +++++++++++++ lib/auth/userloginstate/generator_test.go | 279 ++++++++++++++++++ lib/auth/userloginstate/service_test.go | 59 ++-- lib/services/access_checker.go | 30 ++ lib/services/parser.go | 4 +- lib/services/role_test.go | 2 +- 20 files changed, 883 insertions(+), 99 deletions(-) create mode 100644 lib/auth/userloginstate/generator.go create mode 100644 lib/auth/userloginstate/generator_test.go diff --git a/api/gen/proto/go/teleport/userloginstate/v1/userloginstate.pb.go b/api/gen/proto/go/teleport/userloginstate/v1/userloginstate.pb.go index 932b78b81d743..0a8344a9107ce 100644 --- a/api/gen/proto/go/teleport/userloginstate/v1/userloginstate.pb.go +++ b/api/gen/proto/go/teleport/userloginstate/v1/userloginstate.pb.go @@ -104,6 +104,8 @@ type Spec struct { Roles []string `protobuf:"bytes,1,rep,name=roles,proto3" json:"roles,omitempty"` // traits are the traits attached to the user. Traits []*v11.Trait `protobuf:"bytes,2,rep,name=traits,proto3" json:"traits,omitempty"` + // user_type is the type of user this state represents. + UserType string `protobuf:"bytes,3,opt,name=user_type,json=userType,proto3" json:"user_type,omitempty"` } func (x *Spec) Reset() { @@ -152,6 +154,13 @@ func (x *Spec) GetTraits() []*v11.Trait { return nil } +func (x *Spec) GetUserType() string { + if x != nil { + return x.UserType + } + return "" +} + var File_teleport_userloginstate_v1_userloginstate_proto protoreflect.FileDescriptor var file_teleport_userloginstate_v1_userloginstate_proto_rawDesc = []byte{ @@ -172,19 +181,21 @@ var file_teleport_userloginstate_v1_userloginstate_proto_rawDesc = []byte{ 0x61, 0x64, 0x65, 0x72, 0x12, 0x34, 0x0a, 0x04, 0x73, 0x70, 0x65, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x75, 0x73, 0x65, 0x72, 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x4e, 0x0a, 0x04, 0x53, 0x70, + 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x6b, 0x0a, 0x04, 0x53, 0x70, 0x65, 0x63, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x6f, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x72, 0x6f, 0x6c, 0x65, 0x73, 0x12, 0x30, 0x0a, 0x06, 0x74, 0x72, 0x61, 0x69, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x74, 0x72, 0x61, 0x69, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x72, 0x61, - 0x69, 0x74, 0x52, 0x06, 0x74, 0x72, 0x61, 0x69, 0x74, 0x73, 0x42, 0x60, 0x5a, 0x5e, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, 0x69, 0x74, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, - 0x61, 0x70, 0x69, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, - 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x6c, 0x6f, - 0x67, 0x69, 0x6e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x31, 0x3b, 0x75, 0x73, 0x65, 0x72, - 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x69, 0x74, 0x52, 0x06, 0x74, 0x72, 0x61, 0x69, 0x74, 0x73, 0x12, 0x1b, 0x0a, 0x09, 0x75, 0x73, + 0x65, 0x72, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, + 0x73, 0x65, 0x72, 0x54, 0x79, 0x70, 0x65, 0x42, 0x60, 0x5a, 0x5e, 0x67, 0x69, 0x74, 0x68, 0x75, + 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, 0x69, 0x74, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, + 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x74, 0x65, + 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x6c, 0x6f, 0x67, 0x69, 0x6e, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x31, 0x3b, 0x75, 0x73, 0x65, 0x72, 0x6c, 0x6f, 0x67, + 0x69, 0x6e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/api/proto/teleport/userloginstate/v1/userloginstate.proto b/api/proto/teleport/userloginstate/v1/userloginstate.proto index 1acbd1f3d9af0..56c7facbaa5c4 100644 --- a/api/proto/teleport/userloginstate/v1/userloginstate.proto +++ b/api/proto/teleport/userloginstate/v1/userloginstate.proto @@ -37,4 +37,7 @@ message Spec { // traits are the traits attached to the user. repeated teleport.trait.v1.Trait traits = 2; + + // user_type is the type of user this state represents. + string user_type = 3; } diff --git a/api/types/user.go b/api/types/user.go index a7c0d7288ba89..4072c909a9ed8 100644 --- a/api/types/user.go +++ b/api/types/user.go @@ -125,6 +125,10 @@ type User interface { GetTrustedDeviceIDs() []string // SetTrustedDeviceIDs assigns the IDs of the user's trusted devices. SetTrustedDeviceIDs(ids []string) + // IsBot returns true if the user is a bot. + IsBot() bool + // BotGenerationLabel returns the bot generation label. + BotGenerationLabel() string } // NewUser creates new empty user @@ -460,6 +464,17 @@ func (u UserV2) GetUserType() UserType { return UserTypeSSO } +// IsBot returns true if the user is a bot. +func (u UserV2) IsBot() bool { + _, ok := u.GetMetadata().Labels[BotGenerationLabel] + return ok +} + +// BotGenerationLabel returns the bot generation label. +func (u UserV2) BotGenerationLabel() string { + return u.GetMetadata().Labels[BotGenerationLabel] +} + func (u *UserV2) String() string { return fmt.Sprintf("User(name=%v, roles=%v, identities=%v)", u.Metadata.Name, u.Spec.Roles, u.Spec.OIDCIdentities) } diff --git a/api/types/userloginstate/convert/v1/user_login_state.go b/api/types/userloginstate/convert/v1/user_login_state.go index dfafc74cbc0b8..92ebd1d708628 100644 --- a/api/types/userloginstate/convert/v1/user_login_state.go +++ b/api/types/userloginstate/convert/v1/user_login_state.go @@ -20,6 +20,7 @@ import ( "github.com/gravitational/trace" userloginstatev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/userloginstate/v1" + "github.com/gravitational/teleport/api/types" headerv1 "github.com/gravitational/teleport/api/types/header/convert/v1" traitv1 "github.com/gravitational/teleport/api/types/trait/convert/v1" "github.com/gravitational/teleport/api/types/userloginstate" @@ -32,8 +33,9 @@ func FromProto(msg *userloginstatev1.UserLoginState) (*userloginstate.UserLoginS } uls, err := userloginstate.New(headerv1.FromMetadataProto(msg.Header.Metadata), userloginstate.Spec{ - Roles: msg.Spec.Roles, - Traits: traitv1.FromProto(msg.Spec.Traits), + Roles: msg.Spec.Roles, + Traits: traitv1.FromProto(msg.Spec.Traits), + UserType: types.UserType(msg.Spec.UserType), }) return uls, trace.Wrap(err) @@ -44,8 +46,9 @@ func ToProto(uls *userloginstate.UserLoginState) *userloginstatev1.UserLoginStat return &userloginstatev1.UserLoginState{ Header: headerv1.ToResourceHeaderProto(uls.ResourceHeader), Spec: &userloginstatev1.Spec{ - Roles: uls.GetRoles(), - Traits: traitv1.ToProto(uls.GetTraits()), + Roles: uls.GetRoles(), + Traits: traitv1.ToProto(uls.GetTraits()), + UserType: string(uls.Spec.UserType), }, } } diff --git a/api/types/userloginstate/convert/v1/user_login_state_test.go b/api/types/userloginstate/convert/v1/user_login_state_test.go index 705f054fa99e1..b5c4b0fc10995 100644 --- a/api/types/userloginstate/convert/v1/user_login_state_test.go +++ b/api/types/userloginstate/convert/v1/user_login_state_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/trait" "github.com/gravitational/teleport/api/types/userloginstate" @@ -58,6 +59,14 @@ func TestFromProtoNils(t *testing.T) { _, err = FromProto(uls) require.NoError(t, err) + + // UserType is empty + uls = ToProto(newUserLoginState(t, "user-login-state")) + uls.Spec.UserType = "" + + fromProto, err := FromProto(uls) + require.NoError(t, err) + require.Equal(t, fromProto.GetUserType(), types.UserTypeLocal) } func newUserLoginState(t *testing.T, name string) *userloginstate.UserLoginState { @@ -73,6 +82,7 @@ func newUserLoginState(t *testing.T, name string) *userloginstate.UserLoginState "key1": []string{"value1"}, "key2": []string{"value2"}, }, + UserType: types.UserTypeSSO, }, ) require.NoError(t, err) diff --git a/api/types/userloginstate/user_login_state.go b/api/types/userloginstate/user_login_state.go index a382ffa910be0..ad6b0f7611fa8 100644 --- a/api/types/userloginstate/user_login_state.go +++ b/api/types/userloginstate/user_login_state.go @@ -46,6 +46,9 @@ type Spec struct { // Traits are the traits attached to the user login state. Traits trait.Traits `json:"traits" yaml:"traits"` + + // UserType is the type of user that this state represents. + UserType types.UserType `json:"user_type" yaml:"user_type"` } // New creates a new user login state. @@ -67,7 +70,15 @@ func (u *UserLoginState) CheckAndSetDefaults() error { u.SetKind(types.KindUserLoginState) u.SetVersion(types.V1) - return trace.Wrap(u.ResourceHeader.CheckAndSetDefaults()) + if err := trace.Wrap(u.ResourceHeader.CheckAndSetDefaults()); err != nil { + return trace.Wrap(err) + } + + if u.Spec.UserType == "" { + u.Spec.UserType = types.UserTypeLocal + } + + return nil } // GetRoles returns the roles attached to the user login state. @@ -76,10 +87,26 @@ func (u *UserLoginState) GetRoles() []string { } // GetTraits returns the traits attached to the user login state. -func (u *UserLoginState) GetTraits() trait.Traits { +func (u *UserLoginState) GetTraits() map[string][]string { return u.Spec.Traits } +// GetUserType returns the user type for the user login state. +func (u *UserLoginState) GetUserType() types.UserType { + return u.Spec.UserType +} + +// IsBot returns true if the user is a bot. +func (u *UserLoginState) IsBot() bool { + _, ok := u.GetMetadata().Labels[types.BotGenerationLabel] + return ok +} + +// BotGenerationLabel returns the bot generation label. +func (u *UserLoginState) BotGenerationLabel() string { + return u.GetMetadata().Labels[types.BotGenerationLabel] +} + // GetMetadata returns metadata. This is specifically for conforming to the Resource interface, // and should be removed when possible. func (u *UserLoginState) GetMetadata() types.Metadata { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 63546154513bc..5f8089726a997 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -72,6 +72,7 @@ import ( "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/auth/userloginstate" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/authz" @@ -248,6 +249,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { if cfg.UserPreferences == nil { cfg.UserPreferences = local.NewUserPreferencesService(cfg.Backend) } + if cfg.UserLoginState == nil { + cfg.UserLoginState, err = local.NewUserLoginStateService(cfg.Backend) + if err != nil { + return nil, trace.Wrap(err) + } + } limiter, err := limiter.NewConnectionsLimiter(limiter.Config{ MaxConnections: defaults.LimiterMaxConcurrentSignatures, @@ -300,6 +307,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { Embeddings: cfg.Embeddings, Okta: cfg.Okta, AccessLists: cfg.AccessLists, + UserLoginStates: cfg.UserLoginState, StatusInternal: cfg.Status, UsageReporter: cfg.UsageReporter, Assistant: cfg.Assist, @@ -392,6 +400,19 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { ) } + // Add in a login hook for generating state during user login. + ulsGenerator, err := userloginstate.NewGenerator(userloginstate.GeneratorConfig{ + Log: log, + AccessLists: services, + Access: services, + Clock: cfg.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + as.RegisterLoginHook(ulsGenerator.LoginHook(services.UserLoginStates)) + return &as, nil } @@ -1497,7 +1518,7 @@ func (a *Server) GetKeyStore() *keystore.Manager { type certRequest struct { // user is a user to generate certificate for - user types.User + user services.UserState // impersonator is a user who generates the certificate, // is set when different from the user in the certificate impersonator string @@ -1631,6 +1652,21 @@ func certRequestDeviceExtensions(ext tlsca.DeviceExtensions) certRequestOption { } } +// getUserOrLoginState will return the given user or the login state associated with the user. +func (a *Server) getUserOrLoginState(ctx context.Context, username string) (services.UserState, error) { + uls, err := a.GetUserLoginState(ctx, username) + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + if err == nil { + return uls, nil + } + + user, err := a.GetUser(username, false) + return user, trace.Wrap(err) +} + func (a *Server) GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCertRequest) (*proto.OpenSSHCert, error) { if req.User == nil { return nil, trace.BadParameter("user is empty") @@ -1653,7 +1689,7 @@ func (a *Server) GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCert } // add implicit roles to the set and build a checker - accessInfo := services.AccessInfoFromUser(req.User) + accessInfo := services.AccessInfoFromUserState(req.User) roles := make([]types.Role, len(req.Roles)) for i := range req.Roles { var err error @@ -1702,11 +1738,11 @@ type GenerateUserTestCertsRequest struct { // GenerateUserTestCerts is used to generate user certificate, used internally for tests func (a *Server) GenerateUserTestCerts(req GenerateUserTestCertsRequest) ([]byte, []byte, error) { - user, err := a.GetUser(req.Username, false) + userState, err := a.getUserOrLoginState(context.Background(), req.Username) if err != nil { return nil, nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(userState) clusterName, err := a.GetClusterName() if err != nil { return nil, nil, trace.Wrap(err) @@ -1716,13 +1752,13 @@ func (a *Server) GenerateUserTestCerts(req GenerateUserTestCertsRequest) ([]byte return nil, nil, trace.Wrap(err) } certs, err := a.generateUserCert(certRequest{ - user: user, + user: userState, ttl: req.TTL, compatibility: req.Compatibility, publicKey: req.Key, routeToCluster: req.RouteToCluster, checker: checker, - traits: user.GetTraits(), + traits: userState.GetTraits(), loginIP: req.PinnedIP, pinIP: req.PinnedIP != "", mfaVerified: req.MFAVerified, @@ -1762,11 +1798,11 @@ type AppTestCertRequest struct { // GenerateUserAppTestCert generates an application specific certificate, used // internally for tests. func (a *Server) GenerateUserAppTestCert(req AppTestCertRequest) ([]byte, error) { - user, err := a.GetUser(req.Username, false) + userState, err := a.getUserOrLoginState(context.Background(), req.Username) if err != nil { return nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(userState) clusterName, err := a.GetClusterName() if err != nil { return nil, trace.Wrap(err) @@ -1786,7 +1822,7 @@ func (a *Server) GenerateUserAppTestCert(req AppTestCertRequest) ([]byte, error) } certs, err := a.generateUserCert(certRequest{ - user: user, + user: userState, publicKey: req.PublicKey, checker: checker, ttl: req.TTL, @@ -1832,11 +1868,11 @@ type DatabaseTestCertRequest struct { // GenerateDatabaseTestCert generates a database access certificate for the // provided parameters. Used only internally in tests. func (a *Server) GenerateDatabaseTestCert(req DatabaseTestCertRequest) ([]byte, error) { - user, err := a.GetUser(req.Username, false) + userState, err := a.getUserOrLoginState(context.Background(), req.Username) if err != nil { return nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(userState) clusterName, err := a.GetClusterName() if err != nil { return nil, trace.Wrap(err) @@ -1846,7 +1882,7 @@ func (a *Server) GenerateDatabaseTestCert(req DatabaseTestCertRequest) ([]byte, return nil, trace.Wrap(err) } certs, err := a.generateUserCert(certRequest{ - user: user, + user: userState, publicKey: req.PublicKey, loginIP: req.PinnedIP, pinIP: req.PinnedIP != "", @@ -2125,10 +2161,7 @@ func (a *Server) submitCertificateIssuedEvent(req *certRequest) { } // Bot users are regular Teleport users, but have a special internal label. - bot := false - if _, ok := req.user.GetMetadata().Labels[types.BotLabel]; ok { - bot = true - } + bot := req.user.IsBot() // Unfortunately the only clue we have about Windows certs is the usage // restriction: `RouteToWindowsDesktop` isn't actually passed along to the @@ -3347,7 +3380,7 @@ func (a *Server) getValidatedAccessRequest(ctx context.Context, identity tlsca.I // CreateWebSession creates a new web session for user without any // checks, is used by admins func (a *Server) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) { - u, err := a.GetUser(user, false) + u, err := a.getUserOrLoginState(ctx, user) if err != nil { return nil, trace.Wrap(err) } @@ -3914,7 +3947,7 @@ func (a *Server) GetTokens(ctx context.Context, opts ...services.MarshalOption) // NewWebSession creates and returns a new web session for the specified request func (a *Server) NewWebSession(ctx context.Context, req types.NewWebSessionRequest) (types.WebSession, error) { - user, err := a.GetUser(req.User, false) + userState, err := a.getUserOrLoginState(ctx, req.User) if err != nil { return nil, trace.Wrap(err) } @@ -3949,7 +3982,7 @@ func (a *Server) NewWebSession(ctx context.Context, req types.NewWebSessionReque sessionTTL = checker.AdjustSessionTTL(apidefaults.CertDuration) } certs, err := a.generateUserCert(certRequest{ - user: user, + user: userState, loginIP: req.LoginIP, ttl: sessionTTL, publicKey: pub, diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 34e7d9a80c620..4e1f859417b45 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -49,7 +49,11 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/types/installers" + "github.com/gravitational/teleport/api/types/trait" + "github.com/gravitational/teleport/api/types/userloginstate" + "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/keystore" @@ -1844,7 +1848,7 @@ func TestGenerateUserCertWithCertExtension(t *testing.T) { err = p.a.UpsertRole(ctx, role) require.NoError(t, err) - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(user) accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a) require.NoError(t, err) @@ -1930,7 +1934,7 @@ func TestGenerateUserCertWithLocks(t *testing.T) { user, _, err := CreateUserAndRole(p.a, "test-user", []string{}, nil) require.NoError(t, err) - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(user) accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a) require.NoError(t, err) const mfaID = "test-mfa-id" @@ -2028,6 +2032,105 @@ func TestGenerateHostCertWithLocks(t *testing.T) { require.Error(t, err) } +func TestGenerateUserCertWithUserLoginState(t *testing.T) { + t.Parallel() + ctx := context.Background() + p, err := newTestPack(ctx, t.TempDir()) + require.NoError(t, err) + + user, role, err := CreateUserAndRole(p.a, "test-user", []string{}, nil) + require.NoError(t, err) + userState, err := p.a.getUserOrLoginState(ctx, user.GetName()) + require.NoError(t, err) + accessInfo := services.AccessInfoFromUserState(userState) + accessChecker, err := services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a) + require.NoError(t, err) + keygen := testauthority.New() + _, pub, err := keygen.GetNewKeyPairFromPool() + require.NoError(t, err) + + // Generate cert with no user login state. + certReq := certRequest{ + user: user, + checker: accessChecker, + publicKey: pub, + traits: accessChecker.Traits(), + } + resp, err := p.a.generateUserCert(certReq) + require.NoError(t, err) + + sshCert, err := sshutils.ParseCertificate(resp.SSH) + require.NoError(t, err) + + roles, err := services.UnmarshalCertRoles(sshCert.Extensions[teleport.CertExtensionTeleportRoles]) + require.NoError(t, err) + require.Equal(t, []string{role.GetName()}, roles) + + traits := wrappers.Traits{} + err = wrappers.UnmarshalTraits([]byte(sshCert.Extensions[teleport.CertExtensionTeleportTraits]), &traits) + require.NoError(t, err) + require.Empty(t, traits) + + uls, err := userloginstate.New( + header.Metadata{ + Name: user.GetName(), + }, + userloginstate.Spec{ + Roles: []string{ + role.GetName(), // We'll try to grant a duplicate role, which should be deduplicated. + "uls-role1", + "uls-role2", + }, + Traits: trait.Traits{ + "uls-trait1": []string{"value1", "value2"}, + "uls-trait2": []string{"value3", "value4"}, + }, + }, + ) + require.NoError(t, err) + _, err = p.a.UpsertUserLoginState(ctx, uls) + require.NoError(t, err) + + ulsRole1, err := types.NewRole("uls-role1", types.RoleSpecV6{}) + require.NoError(t, err) + ulsRole2, err := types.NewRole("uls-role2", types.RoleSpecV6{}) + require.NoError(t, err) + + require.NoError(t, p.a.UpsertRole(ctx, ulsRole1)) + require.NoError(t, p.a.UpsertRole(ctx, ulsRole2)) + + userState, err = p.a.getUserOrLoginState(ctx, user.GetName()) + require.NoError(t, err) + accessInfo = services.AccessInfoFromUserState(userState) + accessChecker, err = services.NewAccessChecker(accessInfo, p.clusterName.GetClusterName(), p.a) + require.NoError(t, err) + + certReq = certRequest{ + user: user, + checker: accessChecker, + publicKey: pub, + traits: accessChecker.Traits(), + } + + resp, err = p.a.generateUserCert(certReq) + require.NoError(t, err) + + sshCert, err = sshutils.ParseCertificate(resp.SSH) + require.NoError(t, err) + + roles, err = services.UnmarshalCertRoles(sshCert.Extensions[teleport.CertExtensionTeleportRoles]) + require.NoError(t, err) + require.Equal(t, []string{role.GetName(), "uls-role1", "uls-role2"}, roles) + + traits = wrappers.Traits{} + err = wrappers.UnmarshalTraits([]byte(sshCert.Extensions[teleport.CertExtensionTeleportTraits]), &traits) + require.NoError(t, err) + require.Equal(t, map[string][]string{ + "uls-trait1": {"value1", "value2"}, + "uls-trait2": {"value3", "value4"}, + }, map[string][]string(traits)) +} + func TestNewWebSession(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 62bde3c0c4ea0..c05ad2d03284c 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -2756,7 +2756,11 @@ func (a *ServerWithRoles) desiredAccessInfoForUser(ctx context.Context, req *pro // Reset to the base roles and traits stored in the backend user, // currently active requests (not being dropped) and new access requests // will be filled in below. - accessInfo = services.AccessInfoFromUser(user) + userState, err := a.authServer.getUserOrLoginState(ctx, user.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + accessInfo = services.AccessInfoFromUserState(userState) // Check for ["*"] as special case to drop all requests. if len(req.DropAccessRequests) == 1 && req.DropAccessRequests[0] == "*" { diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 8418132a66c4d..898542d4e536a 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -30,6 +30,8 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/userloginstate" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/authz" @@ -111,9 +113,36 @@ func createBotUser( return nil, trace.Wrap(err) } + uls, err := ulsFromUser(user) + if err != nil { + return nil, trace.Wrap(err) + } + + if _, err := s.UserLoginStates.UpsertUserLoginState(ctx, uls); err != nil { + return nil, trace.Wrap(err) + } + return user, nil } +func ulsFromUser(user types.User) (*userloginstate.UserLoginState, error) { + uls, err := userloginstate.New(header.Metadata{ + Name: user.GetName(), + Labels: map[string]string{ + types.BotLabel: user.GetMetadata().Labels[types.BotLabel], + types.BotGenerationLabel: user.GetMetadata().Labels[types.BotGenerationLabel], + }, + }, userloginstate.Spec{ + Roles: user.GetRoles(), + Traits: user.GetTraits(), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return uls, nil +} + // createBot creates a new certificate renewal bot from a bot request. func (s *Server) createBot(ctx context.Context, req *proto.CreateBotRequest) (*proto.CreateBotResponse, error) { if req.Name == "" { @@ -342,17 +371,17 @@ func (s *Server) checkOrCreateBotToken(ctx context.Context, req *proto.CreateBot } // validateGenerationLabel validates and updates a generation label. -func (s *Server) validateGenerationLabel(ctx context.Context, user types.User, certReq *certRequest, currentIdentityGeneration uint64) error { +func (s *Server) validateGenerationLabel(ctx context.Context, userState services.UserState, certReq *certRequest, currentIdentityGeneration uint64) error { // Fetch the user, bypassing the cache. We might otherwise fetch a stale // value in case of a rapid certificate renewal. - user, err := s.Services.GetUser(user.GetName(), false) + user, err := s.Services.GetUser(userState.GetName(), false) if err != nil { return trace.Wrap(err) } var currentUserGeneration uint64 - label, labelOk := user.GetMetadata().Labels[types.BotGenerationLabel] - if labelOk { + label := userState.BotGenerationLabel() + if label != "" { currentUserGeneration, err = strconv.ParseUint(label, 10, 64) if err != nil { return trace.BadParameter("user has invalid value for label %q", types.BotGenerationLabel) @@ -394,7 +423,8 @@ func (s *Server) validateGenerationLabel(ctx context.Context, user types.User, c } newUser := apiutils.CloneProtoMsg(userV2) metadata := newUser.GetMetadata() - metadata.Labels[types.BotGenerationLabel] = fmt.Sprint(certReq.generation) + generation := fmt.Sprint(certReq.generation) + metadata.Labels[types.BotGenerationLabel] = generation newUser.SetMetadata(metadata) // Note: we bypass the RBAC check on purpose as bot users should not @@ -406,6 +436,22 @@ func (s *Server) validateGenerationLabel(ctx context.Context, user types.User, c return trace.CompareFailed("Database comparison failed, try the request again") } + uls, err := s.GetUserLoginState(ctx, user.GetName()) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } + if uls == nil { + uls, err = ulsFromUser(user) + if err != nil { + return trace.Wrap(err) + } + } + + uls.ResourceHeader.Metadata.Labels[types.BotGenerationLabel] = generation + if _, err := s.UpsertUserLoginState(ctx, uls); err != nil { + return trace.Wrap(err) + } + return nil } @@ -487,14 +533,14 @@ func (s *Server) generateInitialBotCerts(ctx context.Context, username string, p // This call bypasses RBAC check for users read on purpose. // Users who are allowed to impersonate other users might not have // permissions to read user data. - user, err := s.GetUser(username, false) + userState, err := s.getUserOrLoginState(ctx, username) if err != nil { log.WithError(err).Debugf("Could not impersonate user %v. The user could not be fetched from local store.", username) return nil, trace.AccessDenied("access denied") } // Do not allow SSO users to be impersonated. - if user.GetUserType() == types.UserTypeSSO { + if userState.GetUserType() == types.UserTypeSSO { log.Warningf("Tried to issue a renewable cert for externally managed user %v, this is not supported.", username) return nil, trace.AccessDenied("access denied") } @@ -505,7 +551,7 @@ func (s *Server) generateInitialBotCerts(ctx context.Context, username string, p } // Inherit the user's roles and traits verbatim. - accessInfo := services.AccessInfoFromUser(user) + accessInfo := services.AccessInfoFromUserState(userState) clusterName, err := s.GetClusterName() if err != nil { return nil, trace.Wrap(err) @@ -523,7 +569,7 @@ func (s *Server) generateInitialBotCerts(ctx context.Context, username string, p // Generate certificate certReq := certRequest{ - user: user, + user: userState, ttl: expires.Sub(s.GetClock().Now()), publicKey: pubKey, checker: checker, @@ -533,7 +579,7 @@ func (s *Server) generateInitialBotCerts(ctx context.Context, username string, p generation: generation, } - if err := s.validateGenerationLabel(ctx, user, &certReq, 0); err != nil { + if err := s.validateGenerationLabel(ctx, userState, &certReq, 0); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index d57ddbe4185f4..102b323ca642f 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -416,18 +416,22 @@ func (a *TestAuthServer) GenerateUserCert(key []byte, username string, ttl time. if err != nil { return nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + userState, err := a.AuthServer.getUserOrLoginState(context.Background(), user.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + accessInfo := services.AccessInfoFromUserState(userState) checker, err := services.NewAccessChecker(accessInfo, a.ClusterName, a.AuthServer) if err != nil { return nil, trace.Wrap(err) } certs, err := a.AuthServer.generateUserCert(certRequest{ - user: user, + user: userState, ttl: ttl, compatibility: compatibility, publicKey: key, checker: checker, - traits: user.GetTraits(), + traits: userState.GetTraits(), }) if err != nil { return nil, trace.Wrap(err) @@ -474,7 +478,11 @@ func generateCertificate(authServer *Server, identity TestIdentity) ([]byte, []b if err != nil { return nil, nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + userState, err := authServer.getUserOrLoginState(context.Background(), user.GetName()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + accessInfo := services.AccessInfoFromUserState(userState) checker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), authServer) if err != nil { return nil, nil, trace.Wrap(err) @@ -485,12 +493,12 @@ func generateCertificate(authServer *Server, identity TestIdentity) ([]byte, []b certs, err := authServer.generateUserCert(certRequest{ publicKey: pub, - user: user, + user: userState, ttl: identity.TTL, usage: identity.AcceptedUsage, routeToCluster: identity.RouteToCluster, checker: checker, - traits: user.GetTraits(), + traits: userState.GetTraits(), renewable: identity.Renewable, generation: identity.Generation, deviceExtensions: DeviceExtensions(id.Identity.DeviceExtensions), diff --git a/lib/auth/init.go b/lib/auth/init.go index 0f66807ee0e5e..9ef84fe5ec420 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -236,6 +236,9 @@ type InitConfig struct { // AccessLists is a service that manages access list resources. AccessLists services.AccessLists + // UserLoginStates is a service that manages user login states. + UserLoginState services.UserLoginStates + // Clock is the clock instance auth uses. Typically you'd only want to set // this during testing. Clock clockwork.Clock diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 5d60d450fa6f3..4696441ab3962 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -113,7 +113,7 @@ type SessionCreds struct { // AuthenticateUser authenticates user based on the request type. // Returns the username of the authenticated user. -func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (types.User, error) { +func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (services.UserState, error) { username := req.Username mfaDev, actualUsername, err := s.authenticateUser(ctx, req) @@ -614,7 +614,12 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq return nil, trace.Wrap(err) } - accessInfo := services.AccessInfoFromUser(user) + userState, err := s.getUserOrLoginState(ctx, user.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + + accessInfo := services.AccessInfoFromUserState(userState) checker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), s) if err != nil { return nil, trace.Wrap(err) @@ -645,12 +650,12 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq } certReq := certRequest{ - user: user, + user: userState, ttl: req.TTL, publicKey: req.PublicKey, compatibility: req.CompatibilityMode, checker: checker, - traits: user.GetTraits(), + traits: userState.GetTraits(), routeToCluster: req.RouteToCluster, kubernetesCluster: req.KubernetesCluster, loginIP: clientIP, @@ -702,7 +707,7 @@ func (s *Server) emitNoLocalAuthEvent(username string) { } } -func (s *Server) createUserWebSession(ctx context.Context, user types.User, loginIP string) (types.WebSession, error) { +func (s *Server) createUserWebSession(ctx context.Context, user services.UserState, loginIP string) (types.WebSession, error) { // It's safe to extract the roles and traits directly from services.User as this method // is only used for local accounts. return s.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ diff --git a/lib/auth/sessions.go b/lib/auth/sessions.go index e155b901b19e8..45a2c573ba3c1 100644 --- a/lib/auth/sessions.go +++ b/lib/auth/sessions.go @@ -266,7 +266,15 @@ func (s *Server) CreateSessionCert(user types.User, sessionTTL time.Duration, pu // It's safe to extract the access info directly from services.User because // this occurs during the initial login before the first certs have been // generated, so there's no possibility of any active access requests. - accessInfo := services.AccessInfoFromUser(user) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + userState, err := s.getUserOrLoginState(ctx, user.GetName()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + accessInfo := services.AccessInfoFromUserState(userState) clusterName, err := s.GetClusterName() if err != nil { return nil, nil, trace.Wrap(err) @@ -277,12 +285,12 @@ func (s *Server) CreateSessionCert(user types.User, sessionTTL time.Duration, pu } certs, err := s.generateUserCert(certRequest{ - user: user, + user: userState, ttl: sessionTTL, publicKey: publicKey, compatibility: compatibility, checker: checker, - traits: user.GetTraits(), + traits: userState.GetTraits(), routeToCluster: routeToCluster, kubernetesCluster: kubernetesCluster, attestationStatement: attestationReq, diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go new file mode 100644 index 0000000000000..af477094da035 --- /dev/null +++ b/lib/auth/userloginstate/generator.go @@ -0,0 +1,207 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package userloginstate + +import ( + "context" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/userloginstate" + "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" +) + +// GeneratorConfig is the configuration for the user login state generator. +type GeneratorConfig struct { + // Log is a logger to use for the generator. + Log *logrus.Entry + + // AccessLists is a service for retrieving access lists from the backend. + AccessLists services.AccessListsGetter + + // Access is a service that will be used for retrieving roles from the backend. + Access services.Access + + // Clock is the clock to use for the generator. + Clock clockwork.Clock +} + +func (g *GeneratorConfig) CheckAndSetDefaults() error { + if g.Log == nil { + return trace.BadParameter("missing log") + } + + if g.AccessLists == nil { + return trace.BadParameter("missing access lists") + } + + if g.Access == nil { + return trace.BadParameter("missing access") + } + + if g.Clock == nil { + g.Clock = clockwork.NewRealClock() + } + + return nil +} + +// Generator will generate a user login state from a user. +type Generator struct { + log *logrus.Entry + accessLists services.AccessListsGetter + access services.Access + clock clockwork.Clock +} + +// NewGenerator creates a new user login state generator. +func NewGenerator(config GeneratorConfig) (*Generator, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &Generator{ + log: config.Log, + accessLists: config.AccessLists, + access: config.Access, + clock: config.Clock, + }, nil +} + +// Generate will generate the user login state for the given user. +func (g *Generator) Generate(ctx context.Context, user types.User) (*userloginstate.UserLoginState, error) { + var traits map[string][]string + if len(user.GetTraits()) > 0 { + traits = make(map[string][]string, len(user.GetTraits())) + for k, v := range user.GetTraits() { + traits[k] = utils.CopyStrings(v) + } + } + // Create a new empty user login state. + uls, err := userloginstate.New( + header.Metadata{ + Name: user.GetName(), + }, userloginstate.Spec{ + Roles: utils.CopyStrings(user.GetRoles()), + Traits: traits, + UserType: user.GetUserType(), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // Generate the user login state. + if err := g.addAccessListsToState(ctx, user, uls); err != nil { + return nil, trace.Wrap(err) + } + + // Clean up the user login state after generating it. + if err := g.postProcess(ctx, uls); err != nil { + return nil, trace.Wrap(err) + } + + return uls, nil +} + +// addAccessListsToState will added the user's applicable access lists to the user login state. +func (g *Generator) addAccessListsToState(ctx context.Context, user types.User, state *userloginstate.UserLoginState) error { + accessLists, err := g.accessLists.GetAccessLists(ctx) + if err != nil { + return trace.Wrap(err) + } + + // Create an identity for testing membership to access lists. + identity := tlsca.Identity{ + Username: user.GetName(), + Groups: user.GetRoles(), + Traits: user.GetTraits(), + UserType: user.GetUserType(), + } + + for _, accessList := range accessLists { + if err := services.IsMember(identity, g.clock, accessList); err != nil { + continue + } + + state.Spec.Roles = append(state.Spec.Roles, accessList.Spec.Grants.Roles...) + + for k, values := range accessList.Spec.Grants.Traits { + state.Spec.Traits[k] = append(state.Spec.Traits[k], values...) + } + } + + return nil +} + +// postProcess will perform cleanup to the user login state after its generation. +func (g *Generator) postProcess(ctx context.Context, state *userloginstate.UserLoginState) error { + // Deduplicate roles and traits + state.Spec.Roles = utils.Deduplicate(state.Spec.Roles) + + for k, v := range state.Spec.Traits { + state.Spec.Traits[k] = utils.Deduplicate(v) + } + + // If there are no roles, don't bother filtering out non-existent roles + if len(state.Spec.Roles) == 0 { + return nil + } + + // Remove roles that don't exist in the backend so that we don't generate certs for non-existent roles. + // Doing so can prevent login from working properly. This could occur if access lists refer to roles that + // no longer exist, for example. + roles, err := g.access.GetRoles(ctx) + if err != nil { + return trace.Wrap(err) + } + + roleLookup := map[string]bool{} + for _, role := range roles { + roleLookup[role.GetName()] = true + } + + existingRoles := []string{} + for _, role := range state.Spec.Roles { + if roleLookup[role] { + existingRoles = append(existingRoles, role) + } else { + g.log.Warnf("Role %s does not exist when trying to add user login state, will be skipped", role) + } + } + state.Spec.Roles = existingRoles + + return nil +} + +// LoginHook creates a login hook from the Generator and the user login state service. +func (g *Generator) LoginHook(ulsService services.UserLoginStates) func(context.Context, types.User) error { + return func(ctx context.Context, user types.User) error { + uls, err := g.Generate(ctx, user) + if err != nil { + return trace.Wrap(err) + } + + _, err = ulsService.UpsertUserLoginState(ctx, uls) + return trace.Wrap(err) + } +} diff --git a/lib/auth/userloginstate/generator_test.go b/lib/auth/userloginstate/generator_test.go new file mode 100644 index 0000000000000..34954f4bc52e2 --- /dev/null +++ b/lib/auth/userloginstate/generator_test.go @@ -0,0 +1,279 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package userloginstate + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/trait" + "github.com/gravitational/teleport/api/types/userloginstate" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" +) + +func TestAccessLists(t *testing.T) { + user, err := types.NewUser("user") + user.SetRoles([]string{"orole1"}) + user.SetTraits(map[string][]string{ + "otrait1": {"value1", "value2"}, + }) + require.NoError(t, err) + + tests := []struct { + name string + accessLists []*accesslist.AccessList + roles []string + expected *userloginstate.UserLoginState + }{ + { + name: "access lists are empty", + roles: []string{"orole1"}, + expected: newUserLoginState(t, "user", []string{ + "orole1", + }, map[string][]string{ + "otrait1": {"value1", "value2"}, + }), + }, + { + name: "access lists add roles and traits", + accessLists: []*accesslist.AccessList{ + newAccessList(t, "1", []string{"user"}, []string{"role1"}, trait.Traits{ + "trait1": []string{"value1"}, + }), + newAccessList(t, "2", []string{"user"}, []string{"role2"}, trait.Traits{ + "trait1": []string{"value2"}, + "trait2": []string{"value3"}, + }), + }, + roles: []string{"orole1", "role1", "role2"}, + expected: newUserLoginState(t, "user", + []string{ + "orole1", + "role1", + "role2", + }, trait.Traits{ + "otrait1": []string{"value1", "value2"}, + "trait1": []string{"value1", "value2"}, + "trait2": []string{"value3"}, + }), + }, + { + name: "access lists add roles and traits, roles missing from backend", + accessLists: []*accesslist.AccessList{ + newAccessList(t, "1", []string{"user"}, []string{"role1"}, trait.Traits{ + "trait1": []string{"value1"}, + }), + newAccessList(t, "2", []string{"user"}, []string{"role2"}, trait.Traits{ + "trait1": []string{"value2"}, + "trait2": []string{"value3"}, + }), + }, + roles: []string{"orole1"}, + expected: newUserLoginState(t, "user", + []string{"orole1"}, trait.Traits{ + "otrait1": []string{"value1", "value2"}, + "trait1": []string{"value1", "value2"}, + "trait2": []string{"value3"}, + }), + }, + { + name: "access lists only a member of some lists", + accessLists: []*accesslist.AccessList{ + newAccessList(t, "1", []string{"user"}, []string{"role1"}, trait.Traits{ + "trait1": []string{"value1"}, + }), + newAccessList(t, "2", []string{"not-user"}, []string{"role2"}, trait.Traits{ + "trait1": []string{"value2"}, + "trait2": []string{"value3"}, + }), + }, + roles: []string{"orole1", "role1", "role2"}, + expected: newUserLoginState(t, "user", + []string{ + "orole1", + "role1", + }, trait.Traits{ + "otrait1": []string{"value1", "value2"}, + "trait1": []string{"value1"}, + }), + }, + { + name: "access lists add roles with duplicates", + accessLists: []*accesslist.AccessList{ + newAccessList(t, "1", []string{"user"}, []string{"role1", "role2"}, trait.Traits{}), + newAccessList(t, "2", []string{"user"}, []string{"role2", "role3"}, trait.Traits{}), + }, + roles: []string{"orole1", "role1", "role2", "role3"}, + expected: newUserLoginState(t, "user", + []string{ + "orole1", + "role1", + "role2", + "role3", + }, trait.Traits{ + "otrait1": []string{"value1", "value2"}, + }), + }, + { + name: "access lists add traits with duplicates", + accessLists: []*accesslist.AccessList{ + newAccessList(t, "1", []string{"user"}, []string{}, + trait.Traits{ + "trait1": []string{"value1", "value2"}, + "trait2": []string{"value3", "value4"}, + }, + ), + newAccessList(t, "2", []string{"user"}, []string{}, + trait.Traits{ + "trait2": []string{"value3", "value1"}, + "trait3": []string{"value5", "value6"}, + }, + ), + }, + roles: []string{"orole1"}, + expected: newUserLoginState(t, "user", + []string{ + "orole1", + }, + trait.Traits{ + "otrait1": []string{"value1", "value2"}, + "trait1": []string{"value1", "value2"}, + "trait2": []string{"value3", "value4", "value1"}, + "trait3": []string{"value5", "value6"}, + }), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + svc, backendSvc := initGeneratorSvc(t) + + for _, accessList := range test.accessLists { + _, err = backendSvc.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + } + + for _, role := range test.roles { + role, err := types.NewRole(role, types.RoleSpecV6{}) + require.NoError(t, err) + require.NoError(t, backendSvc.UpsertRole(ctx, role)) + } + + state, err := svc.Generate(ctx, user) + require.NoError(t, err) + require.Empty(t, cmp.Diff(test.expected, state, + cmpopts.SortSlices(func(str1, str2 string) bool { + return str1 < str2 + }))) + }) + } +} + +type svc struct { + services.AccessLists + services.Access +} + +func initGeneratorSvc(t *testing.T) (*Generator, *svc) { + t.Helper() + + clock := clockwork.NewFakeClock() + mem, err := memory.New(memory.Config{ + Clock: clock, + }) + require.NoError(t, err) + + accessListsSvc, err := local.NewAccessListService(mem, clock) + require.NoError(t, err) + accessSvc := local.NewAccessService(mem) + + log := logrus.WithField("test", "logger") + return &Generator{log: log, accessLists: accessListsSvc, access: accessSvc, clock: clock}, + &svc{AccessLists: accessListsSvc, Access: accessSvc} +} + +func newAccessList(t *testing.T, name string, members []string, roles []string, traits trait.Traits) *accesslist.AccessList { + t.Helper() + + alMembers := make([]accesslist.Member, len(members)) + for i, member := range members { + alMembers[i] = accesslist.Member{ + Name: member, + Joined: time.Now(), + Expires: time.Now().Add(24 * time.Hour), + Reason: "added", + AddedBy: "owner", + } + } + + accessList, err := accesslist.NewAccessList(header.Metadata{ + Name: name, + }, accesslist.Spec{ + Audit: accesslist.Audit{ + Frequency: time.Hour, + }, + Owners: []accesslist.Owner{ + { + Name: "owner", + Description: "description", + }, + }, + OwnershipRequires: accesslist.Requires{ + Roles: []string{}, + Traits: map[string][]string{}, + }, + MembershipRequires: accesslist.Requires{ + Roles: []string{}, + Traits: map[string][]string{}, + }, + Grants: accesslist.Grants{ + Roles: roles, + Traits: traits, + }, + Members: alMembers, + }) + require.NoError(t, err) + + return accessList +} + +func newUserLoginState(t *testing.T, name string, roles []string, traits map[string][]string) *userloginstate.UserLoginState { + t.Helper() + + uls, err := userloginstate.New(header.Metadata{ + Name: name, + }, userloginstate.Spec{ + Roles: roles, + Traits: traits, + }) + require.NoError(t, err) + + return uls +} diff --git a/lib/auth/userloginstate/service_test.go b/lib/auth/userloginstate/service_test.go index e22061f8bd040..208c7b4923c5c 100644 --- a/lib/auth/userloginstate/service_test.go +++ b/lib/auth/userloginstate/service_test.go @@ -42,13 +42,22 @@ const ( noAccessUser = "no-access-user" ) -// cmpOpts are general cmpOpts for all comparisons. -var cmpOpts = []cmp.Option{ - cmpopts.IgnoreFields(header.Metadata{}, "ID"), - cmpopts.SortSlices(func(a, b *userloginstate.UserLoginState) bool { - return a.GetName() < b.GetName() - }), -} +var ( + // cmpOpts are general cmpOpts for all comparisons across the service tests. + cmpOpts = []cmp.Option{ + cmpopts.IgnoreFields(header.Metadata{}, "ID"), + cmpopts.SortSlices(func(a, b *userloginstate.UserLoginState) bool { + return a.GetName() < b.GetName() + }), + } + + stRoles = []string{"role1", "role2"} + + stTraits = trait.Traits{ + "key1": []string{"value1"}, + "key2": []string{"value2"}, + } +) func TestGetUserLoginStates(t *testing.T) { t.Parallel() @@ -59,8 +68,8 @@ func TestGetUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1") - uls2 := newUserLoginState(t, "2") + uls1 := newUserLoginState(t, "1", stRoles, stTraits) + uls2 := newUserLoginState(t, "2", stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -85,8 +94,8 @@ func TestUpsertUserLoginStates(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1") - uls2 := newUserLoginState(t, "2") + uls1 := newUserLoginState(t, "1", stRoles, stTraits) + uls2 := newUserLoginState(t, "2", stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -108,7 +117,7 @@ func TestGetUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1") + uls1 := newUserLoginState(t, "1", stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -134,7 +143,7 @@ func TestDeleteUserLoginState(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1") + uls1 := newUserLoginState(t, "1", stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -161,8 +170,8 @@ func TestDeleteAllAccessLists(t *testing.T) { require.NoError(t, err) require.Empty(t, getResp.UserLoginStates) - uls1 := newUserLoginState(t, "1") - uls2 := newUserLoginState(t, "2") + uls1 := newUserLoginState(t, "1", stRoles, stTraits) + uls2 := newUserLoginState(t, "2", stRoles, stTraits) _, err = svc.UpsertUserLoginState(ctx, &userloginstatev1.UpsertUserLoginStateRequest{UserLoginState: conv.ToProto(uls1)}) require.NoError(t, err) @@ -279,26 +288,6 @@ func genUserContext(ctx context.Context, username string, groups []string) conte }) } -func newUserLoginState(t *testing.T, name string) *userloginstate.UserLoginState { - t.Helper() - - uls, err := userloginstate.New( - header.Metadata{ - Name: name, - }, - userloginstate.Spec{ - Roles: []string{"role1", "role2"}, - Traits: trait.Traits{ - "key1": []string{"value1"}, - "key2": []string{"value2"}, - }, - }, - ) - require.NoError(t, err) - - return uls -} - func mustFromProto(t *testing.T, uls *userloginstatev1.UserLoginState) *userloginstate.UserLoginState { t.Helper() diff --git a/lib/services/access_checker.go b/lib/services/access_checker.go index 9863f07981778..285b58905f5cb 100644 --- a/lib/services/access_checker.go +++ b/lib/services/access_checker.go @@ -1076,11 +1076,41 @@ func AccessInfoFromRemoteIdentity(identity tlsca.Identity, roleMap types.RoleMap }, nil } +// UserState is a representation of a user's current state. +type UserState interface { + // GetName returns the username associated with the user state. + GetName() string + + // GetRoles returns the roles associated with the user's current state. + GetRoles() []string + + // GetTraits returns the traits associated with the user's current sate. + GetTraits() map[string][]string + + // GetUserType returns the user type for the user login state. + GetUserType() types.UserType + + // IsBot returns true if the user belongs to a bot. + IsBot() bool + + // BotGenerationLabel returns the bot generation label for the user. + BotGenerationLabel() string +} + // AccessInfoFromUser return a new AccessInfo populated from the roles and // traits held be the given user. This should only be used in cases where the // user does not have any active access requests (initial web login, initial // tbot certs, tests). +// TODO(mdwn): Remove this once enterprise has been moved away from this function. func AccessInfoFromUser(user types.User) *AccessInfo { + return AccessInfoFromUserState(user) +} + +// AccessInfoFromUserState return a new AccessInfo populated from the roles and +// traits held be the given user state. This should only be used in cases where the +// user does not have any active access requests (initial web login, initial +// tbot certs, tests). +func AccessInfoFromUserState(user UserState) *AccessInfo { roles := user.GetRoles() traits := user.GetTraits() return &AccessInfo{ diff --git a/lib/services/parser.go b/lib/services/parser.go index d2fccdcb30129..7f3f4f58cfd67 100644 --- a/lib/services/parser.go +++ b/lib/services/parser.go @@ -262,7 +262,7 @@ func (l *LogAction) Log(level, format string, args ...interface{}) predicate.Boo // Context is a default rule context used in teleport type Context struct { // User is currently authenticated user - User types.User + User UserState // Resource is an optional resource, in case if the rule // checks access to the resource Resource types.Resource @@ -324,7 +324,7 @@ func (ctx *Context) GetResource() (types.Resource, error) { func (ctx *Context) GetIdentifier(fields []string) (interface{}, error) { switch fields[0] { case UserIdentifier: - var user types.User + var user UserState if ctx.User == nil { user = emptyUser } else { diff --git a/lib/services/role_test.go b/lib/services/role_test.go index be17067ca1a5f..0757f2e940245 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -7450,7 +7450,7 @@ func TestNewAccessCheckerForRemoteCluster(t *testing.T) { currentUser: user, } - accessInfo := AccessInfoFromUser(user) + accessInfo := AccessInfoFromUserState(user) accessChecker, err := NewAccessCheckerForRemoteCluster(context.Background(), accessInfo, "clustername", currentUserRoleGetter) require.NoError(t, err)