diff --git a/lib/auth/access.go b/lib/auth/access.go index 5440dfb0debb3..72d5a32ee1b8d 100644 --- a/lib/auth/access.go +++ b/lib/auth/access.go @@ -18,10 +18,13 @@ package auth import ( "context" + "errors" + "slices" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" @@ -48,6 +51,12 @@ func (a *Server) UpsertRole(ctx context.Context, role types.Role) error { return nil } +var ( + errDeleteRoleUser = errors.New("failed to delete a role that is still in use by a user, check the system server logs for more details") + errDeleteRoleCA = errors.New("failed to delete a role that is still in use by a certificate authority, check the system server logs for more details") + errDeleteRoleAccessList = errors.New("failed to delete a role that is still in use by an access list, check the system server logs for more details") +) + // DeleteRole deletes a role and emits a related audit event. func (a *Server) DeleteRole(ctx context.Context, name string) error { // check if this role is used by CA or Users @@ -56,13 +65,11 @@ func (a *Server) DeleteRole(ctx context.Context, name string) error { return trace.Wrap(err) } for _, u := range users { - for _, r := range u.GetRoles() { - if r == name { - // Mask the actual error here as it could be used to enumerate users - // within the system. - log.Warnf("Failed to delete role: role %v is used by user %v.", name, u.GetName()) - return trace.BadParameter("failed to delete a role that is still in use by a user, check the system server logs for more details") - } + if slices.Contains(u.GetRoles(), name) { + // Mask the actual error here as it could be used to enumerate users + // within the system. + log.Warnf("Failed to delete role: role %v is used by user %v.", name, u.GetName()) + return trace.Wrap(errDeleteRoleUser) } } // check if it's used by some external cert authorities, e.g. @@ -72,13 +79,42 @@ func (a *Server) DeleteRole(ctx context.Context, name string) error { return trace.Wrap(err) } for _, a := range cas { - for _, r := range a.GetRoles() { - if r == name { - // Mask the actual error here as it could be used to enumerate users - // within the system. - log.Warnf("Failed to delete role: role %v is used by user cert authority %v", name, a.GetClusterName()) - return trace.BadParameter("failed to delete a role that is still in use by a user, check the system server logs for more details") + if slices.Contains(a.GetRoles(), name) { + // Mask the actual error here as it could be used to enumerate users + // within the system. + log.Warnf("Failed to delete role: role %v is used by user cert authority %v", name, a.GetClusterName()) + return trace.Wrap(errDeleteRoleCA) + } + } + + var nextToken string + for { + var accessLists []*accesslist.AccessList + var err error + accessLists, nextToken, err = a.Services.AccessListClient().ListAccessLists(ctx, 0 /* default page size */, nextToken) + if err != nil { + return trace.Wrap(err) + } + + for _, accessList := range accessLists { + if slices.Contains(accessList.Spec.Grants.Roles, name) { + log.Warnf("Failed to delete role: role %v is granted by access list %s", name, accessList.GetName()) + return trace.Wrap(errDeleteRoleAccessList) + } + + if slices.Contains(accessList.Spec.MembershipRequires.Roles, name) { + log.Warnf("Failed to delete role: role %v is required by members of access list %s", name, accessList.GetName()) + return trace.Wrap(errDeleteRoleAccessList) } + + if slices.Contains(accessList.Spec.OwnershipRequires.Roles, name) { + log.Warnf("Failed to delete role: role %v is required by owners of access list %s", name, accessList.GetName()) + return trace.Wrap(errDeleteRoleAccessList) + } + } + + if nextToken == "" { + break } } diff --git a/lib/auth/access_test.go b/lib/auth/access_test.go index 5e96e3b800aad..3fbbbeea2ec3b 100644 --- a/lib/auth/access_test.go +++ b/lib/auth/access_test.go @@ -19,12 +19,15 @@ package auth import ( "context" "testing" + "time" "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/lib/events" ) @@ -68,6 +71,98 @@ func TestUpsertDeleteRoleEventsEmitted(t *testing.T) { require.Nil(t, p.mockEmitter.LastEvent()) } +func TestUpsertDeleteDependentRoles(t *testing.T) { + t.Parallel() + ctx := context.Background() + p, err := newTestPack(ctx, t.TempDir()) + require.NoError(t, err) + + // test create new role + role, err := types.NewRole("test-role", types.RoleSpecV6{ + Options: types.RoleOptions{}, + Allow: types.RoleConditions{}, + }) + require.NoError(t, err) + + // Create a role and assign it to a user. + err = p.a.UpsertRole(ctx, role) + require.NoError(t, err) + user, err := types.NewUser("test-user") + require.NoError(t, err) + user.AddRole(role.GetName()) + err = p.a.CreateUser(ctx, user) + require.NoError(t, err) + + // Deletion should fail. + require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleUser) + require.NoError(t, p.a.DeleteUser(ctx, user.GetName())) + + clusterName, err := p.a.GetClusterName() + require.NoError(t, err) + + // Update the user CA with the role. + ca, err := p.a.GetCertAuthority(ctx, types.CertAuthID{Type: types.UserCA, DomainName: clusterName.GetClusterName()}, true) + require.NoError(t, err) + ca.AddRole(role.GetName()) + require.NoError(t, p.a.UpsertCertAuthority(ctx, ca)) + + // Deletion should fail. + require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleCA) + + // Clear out the roles for the CA. + ca.SetRoles([]string{}) + require.NoError(t, p.a.UpsertCertAuthority(ctx, ca)) + + // Create an access list that references the role. + accessList, err := accesslist.NewAccessList(header.Metadata{ + Name: "test-access-list", + }, accesslist.Spec{ + Title: "simple", + Owners: []accesslist.Owner{ + {Name: "some-user"}, + }, + Grants: accesslist.Grants{ + Roles: []string{role.GetName()}, + }, + Audit: accesslist.Audit{ + NextAuditDate: time.Now(), + }, + MembershipRequires: accesslist.Requires{ + Roles: []string{role.GetName()}, + }, + OwnershipRequires: accesslist.Requires{ + Roles: []string{role.GetName()}, + }, + }) + require.NoError(t, err) + _, err = p.a.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + + // Deletion should fail due to the grant. + require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList) + + accessList.Spec.Grants.Roles = []string{"non-existent-role"} + _, err = p.a.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + + // Deletion should fail due to membership requires. + require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList) + + accessList.Spec.MembershipRequires.Roles = []string{"non-existent-role"} + _, err = p.a.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + + // Deletion should fail due to ownership requires. + require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList) + + accessList.Spec.OwnershipRequires.Roles = []string{"non-existent-role"} + _, err = p.a.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + + // Deletion should succeed + require.NoError(t, p.a.DeleteRole(ctx, role.GetName())) +} + func TestUpsertDeleteLockEventsEmitted(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 477befe9a3987..4168069b5ba4c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -431,7 +431,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { } // Add in a login hook for generating state during user login. - ulsGenerator, err := userloginstate.NewGenerator(userloginstate.GeneratorConfig{ + as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{ Log: log, AccessLists: services, Access: services, @@ -442,7 +442,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { return nil, trace.Wrap(err) } - as.RegisterLoginHook(ulsGenerator.LoginHook(services.UserLoginStates)) + as.RegisterLoginHook(as.ulsGenerator.LoginHook(services.UserLoginStates)) return &as, nil } @@ -791,6 +791,9 @@ type Server struct { // accessMonitoringEnabled is a flag that indicates whether access monitoring is enabled. accessMonitoringEnabled bool + + // ulsGenerator is the user login state generator. + ulsGenerator *userloginstate.Generator } // SetSAMLService registers svc as the SAMLService that provides the SAML @@ -3354,10 +3357,17 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi if err != nil { return nil, trace.Wrap(err) } + + // Make sure to refresh the user login state. + userState, err := a.ulsGenerator.Refresh(ctx, user, a.UserLoginStates) + if err != nil { + return nil, trace.Wrap(err) + } + // Updating traits is needed for guided SSH flow in Discover. - traits = user.GetTraits() + traits = userState.GetTraits() // Updating roles is needed for guided Connect My Computer flow in Discover. - roles = user.GetRoles() + roles = userState.GetRoles() } else if req.AccessRequestID != "" { accessRequest, err := a.getValidatedAccessRequest(ctx, identity, req.User, req.AccessRequestID) @@ -3391,7 +3401,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi } // Get default/static roles. - user, err := a.GetUser(req.User, false) + userState, err := a.GetUserOrLoginState(ctx, req.User) if err != nil { return nil, trace.Wrap(err, "failed to switchback") } @@ -3400,7 +3410,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi allowedResourceIDs = nil // Calculate expiry time. - roleSet, err := services.FetchRoles(user.GetRoles(), a, user.GetTraits()) + roleSet, err := services.FetchRoles(userState.GetRoles(), a, userState.GetTraits()) if err != nil { return nil, trace.Wrap(err) } @@ -3409,7 +3419,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi // Set default roles and expiration. expiresAt = prevSession.GetLoginTime().UTC().Add(sessionTTL) - roles = user.GetRoles() + roles = userState.GetRoles() accessRequests = nil } diff --git a/lib/auth/userloginstate/generator.go b/lib/auth/userloginstate/generator.go index 4a4322d482bf8..c5229bbcc800b 100644 --- a/lib/auth/userloginstate/generator.go +++ b/lib/auth/userloginstate/generator.go @@ -266,15 +266,21 @@ func (g *Generator) emitUsageEvent(ctx context.Context, user types.User, state * return nil } +// Refresh will take the user and update the user login state in the backend. +func (g *Generator) Refresh(ctx context.Context, user types.User, ulsService services.UserLoginStates) (*userloginstate.UserLoginState, error) { + uls, err := g.Generate(ctx, user) + if err != nil { + return nil, trace.Wrap(err) + } + + uls, err = ulsService.UpsertUserLoginState(ctx, uls) + return uls, trace.Wrap(err) +} + // LoginHook creates a login hook from the Generator and the user login state service. func (g *Generator) LoginHook(ulsService services.UserLoginStates) func(context.Context, types.User) error { return func(ctx context.Context, user types.User) error { - uls, err := g.Generate(ctx, user) - if err != nil { - return trace.Wrap(err) - } - - _, err = ulsService.UpsertUserLoginState(ctx, uls) + _, err := g.Refresh(ctx, user, ulsService) return trace.Wrap(err) } }