Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions integration/hostuser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,9 @@ func TestRootHostUsers(t *testing.T) {
namedShellUser := "named-shell"
absoluteShellUser := "absolute-shell"

t.Cleanup(func() { cleanupUsersAndGroups([]string{defaultShellUser, namedShellUser, absoluteShellUser}, nil) })
t.Cleanup(func() {
cleanupUsersAndGroups([]string{defaultShellUser, namedShellUser, absoluteShellUser}, nil)
})

// Create a user with a named shell expected to be available in the PATH
users := srv.NewHostUsers(context.Background(), presence, "host_uuid")
Expand All @@ -467,17 +469,16 @@ func TestRootHostUsers(t *testing.T) {
})
require.NoError(t, err)

// Create a user with an absolute path to a shell
_, err = users.UpsertUser(absoluteShellUser, &decisionpb.HostUsersInfo{
Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP,
Shell: "/usr/bin/bash",
// Create a user with the host default shell (default behavior)
_, err = users.UpsertUser(defaultShellUser, &decisionpb.HostUsersInfo{
Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP,
})
require.NoError(t, err)

// Create a user with the host default shell (default behavior)
_, err = users.UpsertUser(defaultShellUser, &decisionpb.HostUsersInfo{
// Create a user with an absolute path to a shell
_, err = users.UpsertUser(absoluteShellUser, &decisionpb.HostUsersInfo{
Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP,
Shell: "zsh",
Shell: "/usr/bin/bash",
})
require.NoError(t, err)

Expand All @@ -502,7 +503,9 @@ func TestRootHostUsers(t *testing.T) {
assert.Equal(t, expectedShell, userShells[absoluteShellUser])
assert.NotEqual(t, expectedShell, userShells[defaultShellUser])

// User's shell should not be overwritten when updating, only when creating a new host user
// User's shell should be overwritten when a different shell
// is provided
expectedShell = "/usr/bin/sh"
_, err = users.UpsertUser(namedShellUser, &decisionpb.HostUsersInfo{
Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP,
Shell: "sh",
Expand All @@ -512,6 +515,18 @@ func TestRootHostUsers(t *testing.T) {
userShells, err = getUserShells("/etc/passwd")
require.NoError(t, err)
assert.Equal(t, expectedShell, userShells[namedShellUser])

// Make sure we can change the user's shell back again.
expectedShell = "/usr/bin/bash"
_, err = users.UpsertUser(namedShellUser, &decisionpb.HostUsersInfo{
Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP,
Shell: "bash",
})
require.NoError(t, err)

userShells, err = getUserShells("/etc/passwd")
require.NoError(t, err)
assert.Equal(t, expectedShell, userShells[namedShellUser])
})

t.Run("Test expiration removal", func(t *testing.T) {
Expand Down
30 changes: 24 additions & 6 deletions lib/srv/usermgmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ type HostUsersBackend interface {
LookupGroup(group string) (*user.Group, error)
// LookupGroupByID retrieves a group by its ID.
LookupGroupByID(gid string) (*user.Group, error)
// SetUserGroups sets a user's groups, replacing their existing groups.
SetUserGroups(name string, groups []string) error
// UpdateUser sets a user's groups and default shell, replacing their existing groups.
UpdateUser(name string, groups []string, defaultShell string) error
// CreateGroup creates a group on a host.
CreateGroup(group string, gid string) error
// CreateUser creates a user on a host.
Expand All @@ -143,6 +143,8 @@ type HostUsersBackend interface {
CreateHomeDirectory(userHome string, uid, gid string) error
// GetDefaultHomeDirectory returns the default home directory path for the given user
GetDefaultHomeDirectory(name string) (string, error)
// IsUsingShell returns whether or not the given user is currently using the given shell.
IsUsingShell(username, shell string) (bool, error)
// RemoveExpirations removes any sort of password or account expiration from the user
// that may have been placed by password policies.
RemoveExpirations(name string) error
Expand Down Expand Up @@ -315,7 +317,7 @@ func (u *HostUserManagement) updateUser(hostUser HostUser, ui *decisionpb.HostUs
}

return trace.Wrap(u.doWithUserLock(func(_ types.SemaphoreLease) error {
return trace.Wrap(u.backend.SetUserGroups(hostUser.Name, ui.Groups))
return trace.Wrap(u.backend.UpdateUser(hostUser.Name, ui.Groups, ui.Shell))
}))
}

Expand Down Expand Up @@ -495,10 +497,26 @@ func (u *HostUserManagement) UpsertUser(name string, ui *decisionpb.HostUsersInf
return closer, nil
}

if groups != nil {
if err := u.updateUser(*hostUser, ui); err != nil {
return nil, trace.Wrap(err)
// nothing to update
if groups == nil && ui.Shell == "" {
return closer, nil
}

if groups == nil {
// only bother checking the user's current shell if we aren't already
// updating their groups
usingShell, err := u.backend.IsUsingShell(name, ui.Shell)
if err != nil {
log.WarnContext(u.ctx, "Failed to check user's default shell", "error", err)
}

if usingShell {
return closer, nil
}
}

if err := u.updateUser(*hostUser, ui); err != nil {
return nil, trace.Wrap(err)
}

// attempt to remove password expirations from managed users if they've been added
Expand Down
21 changes: 18 additions & 3 deletions lib/srv/usermgmt_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ func (*HostUsersProvisioningBackend) LookupGroupByID(gid string) (*user.Group, e
return user.LookupGroupId(gid)
}

// SetUserGroups sets a user's groups, replacing their existing groups.
func (*HostUsersProvisioningBackend) SetUserGroups(name string, groups []string) error {
_, err := host.SetUserGroups(name, groups)
// UpdateUser sets a user's groups and default shell, replacing their existing groups.
func (*HostUsersProvisioningBackend) UpdateUser(name string, groups []string, defaultShell string) error {
_, err := host.UserUpdate(name, groups, defaultShell)
return trace.Wrap(err)
}

Expand All @@ -113,6 +113,21 @@ func (*HostUsersProvisioningBackend) GetAllUsers() ([]string, error) {
return users, err
}

// IsUsingShell returns whether or not the given user is already using the given shell.
func (*HostUsersProvisioningBackend) IsUsingShell(username, shell string) (bool, error) {
currentShell, err := host.UserShell(username)
if err != nil {
return false, trace.Wrap(err)
}

shellPath, err := exec.LookPath(shell)
if err != nil {
return false, trace.WrapWithMessage(err, "could not find path for shell %q", shell)
}

return shellPath == currentShell, nil
}

// CreateGroup creates a group on a host
func (*HostUsersProvisioningBackend) CreateGroup(name string, gid string) error {
_, err := host.GroupAdd(name, gid)
Expand Down
Loading
Loading