Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions lib/auth/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ package auth

import (
"context"
"errors"
"slices"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/events"
Expand All @@ -48,6 +51,12 @@ func (a *Server) UpsertRole(ctx context.Context, role types.Role) error {
return nil
}

var (
errDeleteRoleUser = errors.New("failed to delete a role that is still in use by a user, check the system server logs for more details")
errDeleteRoleCA = errors.New("failed to delete a role that is still in use by a certificate authority, check the system server logs for more details")
errDeleteRoleAccessList = errors.New("failed to delete a role that is still in use by an access list, check the system server logs for more details")
)

// DeleteRole deletes a role and emits a related audit event.
func (a *Server) DeleteRole(ctx context.Context, name string) error {
// check if this role is used by CA or Users
Expand All @@ -56,13 +65,11 @@ func (a *Server) DeleteRole(ctx context.Context, name string) error {
return trace.Wrap(err)
}
for _, u := range users {
for _, r := range u.GetRoles() {
if r == name {
// Mask the actual error here as it could be used to enumerate users
// within the system.
log.Warnf("Failed to delete role: role %v is used by user %v.", name, u.GetName())
return trace.BadParameter("failed to delete a role that is still in use by a user, check the system server logs for more details")
}
if slices.Contains(u.GetRoles(), name) {
// Mask the actual error here as it could be used to enumerate users
// within the system.
log.Warnf("Failed to delete role: role %v is used by user %v.", name, u.GetName())
return trace.Wrap(errDeleteRoleUser)
}
}
// check if it's used by some external cert authorities, e.g.
Expand All @@ -72,13 +79,42 @@ func (a *Server) DeleteRole(ctx context.Context, name string) error {
return trace.Wrap(err)
}
for _, a := range cas {
for _, r := range a.GetRoles() {
if r == name {
// Mask the actual error here as it could be used to enumerate users
// within the system.
log.Warnf("Failed to delete role: role %v is used by user cert authority %v", name, a.GetClusterName())
return trace.BadParameter("failed to delete a role that is still in use by a user, check the system server logs for more details")
if slices.Contains(a.GetRoles(), name) {
// Mask the actual error here as it could be used to enumerate users
// within the system.
log.Warnf("Failed to delete role: role %v is used by user cert authority %v", name, a.GetClusterName())
return trace.Wrap(errDeleteRoleCA)
}
}

var nextToken string
for {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no "forEach" kind of helper method for access lists already somewhere?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, unfortunately. :-(

var accessLists []*accesslist.AccessList
var err error
accessLists, nextToken, err = a.Services.AccessListClient().ListAccessLists(ctx, 0 /* default page size */, nextToken)
if err != nil {
return trace.Wrap(err)
}

for _, accessList := range accessLists {
if slices.Contains(accessList.Spec.Grants.Roles, name) {
log.Warnf("Failed to delete role: role %v is granted by access list %s", name, accessList.GetName())
return trace.Wrap(errDeleteRoleAccessList)
}

if slices.Contains(accessList.Spec.MembershipRequires.Roles, name) {
log.Warnf("Failed to delete role: role %v is required by members of access list %s", name, accessList.GetName())
return trace.Wrap(errDeleteRoleAccessList)
}

if slices.Contains(accessList.Spec.OwnershipRequires.Roles, name) {
log.Warnf("Failed to delete role: role %v is required by owners of access list %s", name, accessList.GetName())
return trace.Wrap(errDeleteRoleAccessList)
}
}

if nextToken == "" {
break
}
}

Expand Down
95 changes: 95 additions & 0 deletions lib/auth/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ package auth
import (
"context"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/lib/events"
)

Expand Down Expand Up @@ -68,6 +71,98 @@ func TestUpsertDeleteRoleEventsEmitted(t *testing.T) {
require.Nil(t, p.mockEmitter.LastEvent())
}

func TestUpsertDeleteDependentRoles(t *testing.T) {
t.Parallel()
ctx := context.Background()
p, err := newTestPack(ctx, t.TempDir())
require.NoError(t, err)

// test create new role
role, err := types.NewRole("test-role", types.RoleSpecV6{
Options: types.RoleOptions{},
Allow: types.RoleConditions{},
})
require.NoError(t, err)

// Create a role and assign it to a user.
err = p.a.UpsertRole(ctx, role)
require.NoError(t, err)
user, err := types.NewUser("test-user")
require.NoError(t, err)
user.AddRole(role.GetName())
err = p.a.CreateUser(ctx, user)
require.NoError(t, err)

// Deletion should fail.
require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleUser)
require.NoError(t, p.a.DeleteUser(ctx, user.GetName()))

clusterName, err := p.a.GetClusterName()
require.NoError(t, err)

// Update the user CA with the role.
ca, err := p.a.GetCertAuthority(ctx, types.CertAuthID{Type: types.UserCA, DomainName: clusterName.GetClusterName()}, true)
require.NoError(t, err)
ca.AddRole(role.GetName())
require.NoError(t, p.a.UpsertCertAuthority(ctx, ca))

// Deletion should fail.
require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleCA)

// Clear out the roles for the CA.
ca.SetRoles([]string{})
require.NoError(t, p.a.UpsertCertAuthority(ctx, ca))

// Create an access list that references the role.
accessList, err := accesslist.NewAccessList(header.Metadata{
Name: "test-access-list",
}, accesslist.Spec{
Title: "simple",
Owners: []accesslist.Owner{
{Name: "some-user"},
},
Grants: accesslist.Grants{
Roles: []string{role.GetName()},
},
Audit: accesslist.Audit{
NextAuditDate: time.Now(),
},
MembershipRequires: accesslist.Requires{
Roles: []string{role.GetName()},
},
OwnershipRequires: accesslist.Requires{
Roles: []string{role.GetName()},
},
})
require.NoError(t, err)
_, err = p.a.UpsertAccessList(ctx, accessList)
require.NoError(t, err)

// Deletion should fail due to the grant.
require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList)

accessList.Spec.Grants.Roles = []string{"non-existent-role"}
_, err = p.a.UpsertAccessList(ctx, accessList)
require.NoError(t, err)

// Deletion should fail due to membership requires.
require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList)

accessList.Spec.MembershipRequires.Roles = []string{"non-existent-role"}
_, err = p.a.UpsertAccessList(ctx, accessList)
require.NoError(t, err)

// Deletion should fail due to ownership requires.
require.ErrorIs(t, p.a.DeleteRole(ctx, role.GetName()), errDeleteRoleAccessList)

accessList.Spec.OwnershipRequires.Roles = []string{"non-existent-role"}
_, err = p.a.UpsertAccessList(ctx, accessList)
require.NoError(t, err)

// Deletion should succeed
require.NoError(t, p.a.DeleteRole(ctx, role.GetName()))
}

func TestUpsertDeleteLockEventsEmitted(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down
24 changes: 17 additions & 7 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
}

// Add in a login hook for generating state during user login.
ulsGenerator, err := userloginstate.NewGenerator(userloginstate.GeneratorConfig{
as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{
Log: log,
AccessLists: services,
Access: services,
Expand All @@ -442,7 +442,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
return nil, trace.Wrap(err)
}

as.RegisterLoginHook(ulsGenerator.LoginHook(services.UserLoginStates))
as.RegisterLoginHook(as.ulsGenerator.LoginHook(services.UserLoginStates))

return &as, nil
}
Expand Down Expand Up @@ -791,6 +791,9 @@ type Server struct {

// accessMonitoringEnabled is a flag that indicates whether access monitoring is enabled.
accessMonitoringEnabled bool

// ulsGenerator is the user login state generator.
ulsGenerator *userloginstate.Generator
}

// SetSAMLService registers svc as the SAMLService that provides the SAML
Expand Down Expand Up @@ -3354,10 +3357,17 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi
if err != nil {
return nil, trace.Wrap(err)
}

// Make sure to refresh the user login state.
userState, err := a.ulsGenerator.Refresh(ctx, user, a.UserLoginStates)
if err != nil {
return nil, trace.Wrap(err)
}

// Updating traits is needed for guided SSH flow in Discover.
traits = user.GetTraits()
traits = userState.GetTraits()
// Updating roles is needed for guided Connect My Computer flow in Discover.
roles = user.GetRoles()
roles = userState.GetRoles()

} else if req.AccessRequestID != "" {
accessRequest, err := a.getValidatedAccessRequest(ctx, identity, req.User, req.AccessRequestID)
Expand Down Expand Up @@ -3391,7 +3401,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi
}

// Get default/static roles.
user, err := a.GetUser(req.User, false)
userState, err := a.GetUserOrLoginState(ctx, req.User)
if err != nil {
return nil, trace.Wrap(err, "failed to switchback")
}
Expand All @@ -3400,7 +3410,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi
allowedResourceIDs = nil

// Calculate expiry time.
roleSet, err := services.FetchRoles(user.GetRoles(), a, user.GetTraits())
roleSet, err := services.FetchRoles(userState.GetRoles(), a, userState.GetTraits())
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -3409,7 +3419,7 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi

// Set default roles and expiration.
expiresAt = prevSession.GetLoginTime().UTC().Add(sessionTTL)
roles = user.GetRoles()
roles = userState.GetRoles()
accessRequests = nil
}

Expand Down
18 changes: 12 additions & 6 deletions lib/auth/userloginstate/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,21 @@ func (g *Generator) emitUsageEvent(ctx context.Context, user types.User, state *
return nil
}

// Refresh will take the user and update the user login state in the backend.
func (g *Generator) Refresh(ctx context.Context, user types.User, ulsService services.UserLoginStates) (*userloginstate.UserLoginState, error) {
uls, err := g.Generate(ctx, user)
if err != nil {
return nil, trace.Wrap(err)
}

uls, err = ulsService.UpsertUserLoginState(ctx, uls)
return uls, trace.Wrap(err)
}

// LoginHook creates a login hook from the Generator and the user login state service.
func (g *Generator) LoginHook(ulsService services.UserLoginStates) func(context.Context, types.User) error {
return func(ctx context.Context, user types.User) error {
uls, err := g.Generate(ctx, user)
if err != nil {
return trace.Wrap(err)
}

_, err = ulsService.UpsertUserLoginState(ctx, uls)
_, err := g.Refresh(ctx, user, ulsService)
return trace.Wrap(err)
}
}