diff --git a/integration/hostuser_test.go b/integration/hostuser_test.go index 9480989ec3a8a..2fb97d24ffcda 100644 --- a/integration/hostuser_test.go +++ b/integration/hostuser_test.go @@ -457,7 +457,9 @@ func TestRootHostUsers(t *testing.T) { namedShellUser := "named-shell" absoluteShellUser := "absolute-shell" - t.Cleanup(func() { cleanupUsersAndGroups([]string{defaultShellUser, namedShellUser, absoluteShellUser}, nil) }) + t.Cleanup(func() { + cleanupUsersAndGroups([]string{defaultShellUser, namedShellUser, absoluteShellUser}, nil) + }) // Create a user with a named shell expected to be available in the PATH users := srv.NewHostUsers(context.Background(), presence, "host_uuid") @@ -467,17 +469,16 @@ func TestRootHostUsers(t *testing.T) { }) require.NoError(t, err) - // Create a user with an absolute path to a shell - _, err = users.UpsertUser(absoluteShellUser, &decisionpb.HostUsersInfo{ - Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, - Shell: "/usr/bin/bash", + // Create a user with the host default shell (default behavior) + _, err = users.UpsertUser(defaultShellUser, &decisionpb.HostUsersInfo{ + Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, }) require.NoError(t, err) - // Create a user with the host default shell (default behavior) - _, err = users.UpsertUser(defaultShellUser, &decisionpb.HostUsersInfo{ + // Create a user with an absolute path to a shell + _, err = users.UpsertUser(absoluteShellUser, &decisionpb.HostUsersInfo{ Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, - Shell: "zsh", + Shell: "/usr/bin/bash", }) require.NoError(t, err) @@ -502,7 +503,9 @@ func TestRootHostUsers(t *testing.T) { assert.Equal(t, expectedShell, userShells[absoluteShellUser]) assert.NotEqual(t, expectedShell, userShells[defaultShellUser]) - // User's shell should not be overwritten when updating, only when creating a new host user + // User's shell should be overwritten when a different shell + // is provided + expectedShell = "/usr/bin/sh" _, err = users.UpsertUser(namedShellUser, &decisionpb.HostUsersInfo{ Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, Shell: "sh", @@ -512,6 +515,18 @@ func TestRootHostUsers(t *testing.T) { userShells, err = getUserShells("/etc/passwd") require.NoError(t, err) assert.Equal(t, expectedShell, userShells[namedShellUser]) + + // Make sure we can change the user's shell back again. + expectedShell = "/usr/bin/bash" + _, err = users.UpsertUser(namedShellUser, &decisionpb.HostUsersInfo{ + Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, + Shell: "bash", + }) + require.NoError(t, err) + + userShells, err = getUserShells("/etc/passwd") + require.NoError(t, err) + assert.Equal(t, expectedShell, userShells[namedShellUser]) }) t.Run("Test expiration removal", func(t *testing.T) { diff --git a/lib/srv/usermgmt.go b/lib/srv/usermgmt.go index 0c874be7553c6..f71bb1f8709e8 100644 --- a/lib/srv/usermgmt.go +++ b/lib/srv/usermgmt.go @@ -131,8 +131,8 @@ type HostUsersBackend interface { LookupGroup(group string) (*user.Group, error) // LookupGroupByID retrieves a group by its ID. LookupGroupByID(gid string) (*user.Group, error) - // SetUserGroups sets a user's groups, replacing their existing groups. - SetUserGroups(name string, groups []string) error + // UpdateUser sets a user's groups and default shell, replacing their existing groups. + UpdateUser(name string, groups []string, defaultShell string) error // CreateGroup creates a group on a host. CreateGroup(group string, gid string) error // CreateUser creates a user on a host. @@ -143,6 +143,8 @@ type HostUsersBackend interface { CreateHomeDirectory(userHome string, uid, gid string) error // GetDefaultHomeDirectory returns the default home directory path for the given user GetDefaultHomeDirectory(name string) (string, error) + // IsUsingShell returns whether or not the given user is currently using the given shell. + IsUsingShell(username, shell string) (bool, error) // RemoveExpirations removes any sort of password or account expiration from the user // that may have been placed by password policies. RemoveExpirations(name string) error @@ -315,7 +317,7 @@ func (u *HostUserManagement) updateUser(hostUser HostUser, ui *decisionpb.HostUs } return trace.Wrap(u.doWithUserLock(func(_ types.SemaphoreLease) error { - return trace.Wrap(u.backend.SetUserGroups(hostUser.Name, ui.Groups)) + return trace.Wrap(u.backend.UpdateUser(hostUser.Name, ui.Groups, ui.Shell)) })) } @@ -495,10 +497,26 @@ func (u *HostUserManagement) UpsertUser(name string, ui *decisionpb.HostUsersInf return closer, nil } - if groups != nil { - if err := u.updateUser(*hostUser, ui); err != nil { - return nil, trace.Wrap(err) + // nothing to update + if groups == nil && ui.Shell == "" { + return closer, nil + } + + if groups == nil { + // only bother checking the user's current shell if we aren't already + // updating their groups + usingShell, err := u.backend.IsUsingShell(name, ui.Shell) + if err != nil { + log.WarnContext(u.ctx, "Failed to check user's default shell", "error", err) } + + if usingShell { + return closer, nil + } + } + + if err := u.updateUser(*hostUser, ui); err != nil { + return nil, trace.Wrap(err) } // attempt to remove password expirations from managed users if they've been added diff --git a/lib/srv/usermgmt_linux.go b/lib/srv/usermgmt_linux.go index b1c6ed22c416c..667675301d30b 100644 --- a/lib/srv/usermgmt_linux.go +++ b/lib/srv/usermgmt_linux.go @@ -101,9 +101,9 @@ func (*HostUsersProvisioningBackend) LookupGroupByID(gid string) (*user.Group, e return user.LookupGroupId(gid) } -// SetUserGroups sets a user's groups, replacing their existing groups. -func (*HostUsersProvisioningBackend) SetUserGroups(name string, groups []string) error { - _, err := host.SetUserGroups(name, groups) +// UpdateUser sets a user's groups and default shell, replacing their existing groups. +func (*HostUsersProvisioningBackend) UpdateUser(name string, groups []string, defaultShell string) error { + _, err := host.UserUpdate(name, groups, defaultShell) return trace.Wrap(err) } @@ -113,6 +113,21 @@ func (*HostUsersProvisioningBackend) GetAllUsers() ([]string, error) { return users, err } +// IsUsingShell returns whether or not the given user is already using the given shell. +func (*HostUsersProvisioningBackend) IsUsingShell(username, shell string) (bool, error) { + currentShell, err := host.UserShell(username) + if err != nil { + return false, trace.Wrap(err) + } + + shellPath, err := exec.LookPath(shell) + if err != nil { + return false, trace.WrapWithMessage(err, "could not find path for shell %q", shell) + } + + return shellPath == currentShell, nil +} + // CreateGroup creates a group on a host func (*HostUsersProvisioningBackend) CreateGroup(name string, gid string) error { _, err := host.GroupAdd(name, gid) diff --git a/lib/srv/usermgmt_test.go b/lib/srv/usermgmt_test.go index 7d64c3952f832..9bc246c566d3d 100644 --- a/lib/srv/usermgmt_test.go +++ b/lib/srv/usermgmt_test.go @@ -49,19 +49,22 @@ type testHostUserBackend struct { userUID map[string]string // userGID: user -> gid userGID map[string]string + // userShells: user -> shell path + userShells map[string]string - setUserGroupsCalls int + updateUserCalls int createHomeDirectoryCalls int groupDatabaseErr error } func newTestUserMgmt() *testHostUserBackend { return &testHostUserBackend{ - users: map[string][]string{}, - groups: map[string]string{}, - sudoers: map[string]string{}, - userUID: map[string]string{}, - userGID: map[string]string{}, + users: map[string][]string{}, + groups: map[string]string{}, + sudoers: map[string]string{}, + userUID: map[string]string{}, + userGID: map[string]string{}, + userShells: map[string]string{}, } } @@ -111,12 +114,16 @@ func (tm *testHostUserBackend) LookupGroupByID(gid string) (*user.Group, error) return nil, user.UnknownGroupIdError(gid) } -func (tm *testHostUserBackend) SetUserGroups(name string, groups []string) error { - tm.setUserGroupsCalls++ +func (tm *testHostUserBackend) UpdateUser(name string, groups []string, defaultShell string) error { + tm.updateUserCalls++ if _, ok := tm.users[name]; !ok { return trace.NotFound("User %q doesn't exist", name) } + tm.users[name] = groups + if defaultShell != "" { + tm.userShells[name] = defaultShell + } return nil } @@ -158,6 +165,11 @@ func (tm *testHostUserBackend) CreateUser(user string, groups []string, opts hos tm.users[user] = groups tm.userUID[user] = opts.UID tm.userGID[user] = opts.GID + + if opts.Shell == "" { + opts.Shell = "/usr/bin/sh" + } + tm.userShells[user] = opts.Shell return nil } @@ -175,6 +187,10 @@ func (tm *testHostUserBackend) GetDefaultHomeDirectory(user string) (string, err return "", nil } +func (tm *testHostUserBackend) IsUsingShell(user, shell string) (bool, error) { + return tm.userShells[user] == shell, nil +} + // RemoveSudoersFile implements HostUsersBackend func (tm *testHostUserBackend) RemoveSudoersFile(user string) error { delete(tm.sudoers, user) @@ -425,7 +441,7 @@ func Test_UpdateUserGroups_Keep(t *testing.T) { closer, err := users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportDropGroup) @@ -435,15 +451,15 @@ func Test_UpdateUserGroups_Keep(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportDropGroup) - // Upsert again with same groups should not call SetUserGroups. + // Upsert again with same groups should not call UpdateUser. closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportDropGroup) @@ -452,7 +468,7 @@ func Test_UpdateUserGroups_Keep(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errStaticConversion) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) // Updates with INSECURE_DROP mode should convert the managed user @@ -461,7 +477,7 @@ func Test_UpdateUserGroups_Keep(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Equal(t, 2, backend.setUserGroupsCalls) + assert.Equal(t, 2, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportDropGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportKeepGroup) } @@ -481,7 +497,7 @@ func Test_UpdateUserGroups_Drop(t *testing.T) { closer, err := users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportDropGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportKeepGroup) @@ -491,7 +507,7 @@ func Test_UpdateUserGroups_Drop(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportDropGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportKeepGroup) @@ -499,7 +515,7 @@ func Test_UpdateUserGroups_Drop(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportDropGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportKeepGroup) @@ -508,7 +524,7 @@ func Test_UpdateUserGroups_Drop(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errStaticConversion) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportDropGroup), backend.users["alice"]) // Updates with KEEP mode should convert the ephemeral user @@ -517,7 +533,7 @@ func Test_UpdateUserGroups_Drop(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 2, backend.setUserGroupsCalls) + assert.Equal(t, 2, backend.updateUserCalls) assert.Equal(t, 1, backend.createHomeDirectoryCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportDropGroup) @@ -537,7 +553,7 @@ func Test_UpdateUserGroups_Static(t *testing.T) { closer, err := users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportStaticGroup), backend.users["alice"]) // Update user with new groups. @@ -545,14 +561,14 @@ func Test_UpdateUserGroups_Static(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportStaticGroup), backend.users["alice"]) // Upsert again with same groups should not call SetUserGroups. closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportStaticGroup), backend.users["alice"]) // Do not convert to KEEP. @@ -560,7 +576,7 @@ func Test_UpdateUserGroups_Static(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errStaticConversion) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(slices.Clone(allGroups[2:]), types.TeleportStaticGroup), backend.users["alice"]) // Do not convert to INSECURE_DROP. @@ -568,7 +584,7 @@ func Test_UpdateUserGroups_Static(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errStaticConversion) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) assert.ElementsMatch(t, append(slices.Clone(allGroups[2:]), types.TeleportStaticGroup), backend.users["alice"]) } @@ -589,7 +605,7 @@ func Test_DontManageExistingUser(t *testing.T) { closer, err := users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errUnmanagedUser) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, allGroups, backend.users["alice"]) // Update user in KEEP mode @@ -597,7 +613,7 @@ func Test_DontManageExistingUser(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errUnmanagedUser) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, allGroups, backend.users["alice"]) // Update static user @@ -605,7 +621,7 @@ func Test_DontManageExistingUser(t *testing.T) { closer, err = users.UpsertUser("alice", &userinfo) assert.ErrorIs(t, err, errUnmanagedUser) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, allGroups, backend.users["alice"]) } @@ -647,7 +663,7 @@ func Test_DontUpdateUnmanagedUsers(t *testing.T) { closer, err := users.UpsertUser("alice", tc.userinfo) assert.ErrorIs(t, err, errUnmanagedUser) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, allGroups[2:], backend.users["alice"]) }) } @@ -673,7 +689,7 @@ func Test_AllowExplicitlyManageExistingUsers(t *testing.T) { closer, err := users.UpsertUser("alice-keep", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.Equal(t, 1, backend.updateUserCalls) // slice off the end because teleport-system should be explicitly excluded assert.ElementsMatch(t, allGroups[:2], backend.users["alice-keep"]) assert.NotContains(t, backend.users["alice-keep"], types.TeleportDropGroup) @@ -683,7 +699,7 @@ func Test_AllowExplicitlyManageExistingUsers(t *testing.T) { closer, err = users.UpsertUser("alice-static", &userinfo, TakeOwnershipIfUserExists(true)) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Equal(t, 2, backend.setUserGroupsCalls) + assert.Equal(t, 2, backend.updateUserCalls) assert.Contains(t, backend.users["alice-static"], "foo") assert.Contains(t, backend.users["alice-static"], types.TeleportStaticGroup) assert.NotContains(t, backend.users["alice-static"], types.TeleportKeepGroup) @@ -694,7 +710,7 @@ func Test_AllowExplicitlyManageExistingUsers(t *testing.T) { closer, err = users.UpsertUser("alice-drop", &userinfo) assert.ErrorIs(t, err, errUnmanagedUser) assert.Equal(t, nil, closer) - assert.Equal(t, 2, backend.setUserGroupsCalls) + assert.Equal(t, 2, backend.updateUserCalls) assert.Empty(t, backend.users["alice-drop"]) // Don't assign teleport-keep to users created in DROP mode @@ -702,7 +718,7 @@ func Test_AllowExplicitlyManageExistingUsers(t *testing.T) { closer, err = users.UpsertUser("bob", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Equal(t, 2, backend.setUserGroupsCalls) + assert.Equal(t, 2, backend.updateUserCalls) assert.ElementsMatch(t, []string{"foo", types.TeleportDropGroup}, backend.users["bob"]) assert.NotContains(t, backend.users["bob"], types.TeleportKeepGroup) } @@ -750,21 +766,21 @@ func TestCreateUserWithExistingPrimaryGroup(t *testing.T) { closer, err := users.UpsertUser("bob", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) // create a user with primary group defined in userinfo.Groups, but not yet on the host userinfo.Groups = []string{"fred"} closer, err = users.UpsertUser("fred", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) // create a user with primary group defined in userinfo.Groups that already exists on the host userinfo.Groups = []string{"alice"} closer, err = users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.NotEqual(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) // create a user with primary group that already exists on the host but is not defined in userinfo.Groups userinfo.Groups = []string{""} @@ -772,7 +788,7 @@ func TestCreateUserWithExistingPrimaryGroup(t *testing.T) { assert.True(t, trace.IsAlreadyExists(err)) assert.Contains(t, err.Error(), "conflicts with an existing group") assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) } func TestHostUsersResolveGroups(t *testing.T) { @@ -1127,7 +1143,7 @@ func TestRegressionGroupErrorDoesNotPanic(t *testing.T) { closer, err := users.UpsertUser("alice", &userinfo) assert.NoError(t, err) assert.Equal(t, nil, closer) - assert.Zero(t, backend.setUserGroupsCalls) + assert.Zero(t, backend.updateUserCalls) assert.ElementsMatch(t, append(userinfo.Groups, types.TeleportKeepGroup), backend.users["alice"]) assert.NotContains(t, backend.users["alice"], types.TeleportDropGroup) @@ -1135,3 +1151,48 @@ func TestRegressionGroupErrorDoesNotPanic(t *testing.T) { _, err = users.UpsertUser("alice", &userinfo) require.Error(t, err) } + +func TestIsUserShell(t *testing.T) { + allGroups := []string{"foo", "bar", "baz"} + users, backend := initBackend(t, allGroups) + userinfo := decisionpb.HostUsersInfo{ + Groups: slices.Clone(allGroups[:2]), + Mode: decisionpb.HostUserMode_HOST_USER_MODE_KEEP, + } + + // no shell defined, create with default + closer, err := users.UpsertUser("alice", &userinfo) + assert.NoError(t, err) + assert.Equal(t, nil, closer) + hasShell, err := backend.IsUsingShell("alice", "/usr/bin/sh") + require.NoError(t, err) + require.True(t, hasShell) + + // shell defined, create with shell + userinfo.Shell = "/usr/bin/bash" + closer, err = users.UpsertUser("bob", &userinfo) + assert.NoError(t, err) + assert.Equal(t, nil, closer) + hasShell, err = backend.IsUsingShell("bob", "/usr/bin/bash") + require.NoError(t, err) + require.True(t, hasShell) + + // shell defined but unchanged, do nothing on update + closer, err = users.UpsertUser("bob", &userinfo) + assert.NoError(t, err) + assert.Equal(t, nil, closer) + hasShell, err = backend.IsUsingShell("bob", "/usr/bin/bash") + require.NoError(t, err) + require.True(t, hasShell) + require.Equal(t, 0, backend.updateUserCalls) + + // shell defined and changed, update user + userinfo.Shell = "/usr/bin/zsh" + closer, err = users.UpsertUser("bob", &userinfo) + assert.NoError(t, err) + assert.Equal(t, nil, closer) + hasShell, err = backend.IsUsingShell("bob", "/usr/bin/zsh") + require.NoError(t, err) + require.True(t, hasShell) + require.Equal(t, 1, backend.updateUserCalls) +} diff --git a/lib/utils/host/hostusers.go b/lib/utils/host/hostusers.go index 684edec2e8d79..f8cbbd9dbdade 100644 --- a/lib/utils/host/hostusers.go +++ b/lib/utils/host/hostusers.go @@ -144,15 +144,26 @@ func UserAdd(username string, groups []string, opts UserOpts) (exitCode int, err return cmd.ProcessState.ExitCode(), trace.Wrap(err) } -// SetUserGroups adds a user to a list of specified groups on a host using `usermod`, +// UserUpdate sets the groups and default shell for a host user using `usermod`, // overriding any existing supplementary groups. -func SetUserGroups(username string, groups []string) (exitCode int, err error) { +func UserUpdate(username string, groups []string, defaultShell string) (exitCode int, err error) { usermodBin, err := exec.LookPath("usermod") if err != nil { return -1, trace.Wrap(err, "cant find usermod binary") } - // usermod -G (replace groups) (username) - cmd := exec.Command(usermodBin, "-G", strings.Join(groups, ","), username) + var args []string + if groups != nil { + args = append(args, "-G", strings.Join(groups, ",")) + } + if defaultShell != "" { + if shell, err := exec.LookPath(defaultShell); err != nil { + slog.WarnContext(context.Background(), "configured shell not found, falling back to host default", "shell", defaultShell) + } else { + args = append(args, "--shell", shell) + } + } + // usermod -G (replace groups) --shell (default shell) (username) + cmd := exec.Command(usermodBin, append(args, username)...) output, err := cmd.CombinedOutput() slog.DebugContext(context.Background(), "usermod completed", "command_path", cmd.Path, @@ -303,3 +314,29 @@ func CheckSudoers(contents []byte) error { } return trace.Wrap(err) } + +// UserShell invokes the 'getent' binary in order to fetch the default shell for the +// given user. +func UserShell(username string) (string, error) { + if username == "" { + return "", trace.BadParameter("cannot lookup shell without username") + } + getentBin, err := exec.LookPath("getent") + if err != nil { + return "", trace.NotFound("cannot find getent binary: %s", err) + } + cmd := exec.Command(getentBin, "passwd", username) + output, err := cmd.CombinedOutput() + if err != nil { + return "", trace.Wrap(err) + } + + // grab last element in passwd entry + entry := bytes.TrimSpace(output) + shellIdx := bytes.LastIndex(entry, []byte(":")) + 1 + if shellIdx >= len(entry) { + return "", trace.Errorf("invalid passwd entry for user %q", username) + } + + return string(entry), nil +}