From 61e8c3cd77375627017c759fe72f28efaa65b965 Mon Sep 17 00:00:00 2001 From: Erik Tate Date: Thu, 29 Aug 2024 16:55:45 -0400 Subject: [PATCH] adding which role caused auth failure to the resulting error when checking for create_host_user_mode and surfacing for helpful error logging server side --- lib/services/access_checker.go | 2 +- lib/srv/regular/sshserver.go | 10 +- lib/srv/regular/sshserver_test.go | 4 +- lib/srv/sess.go | 59 ++++--- lib/srv/sess_test.go | 285 ++++++++++++++++++++++++++++++ 5 files changed, 330 insertions(+), 30 deletions(-) diff --git a/lib/services/access_checker.go b/lib/services/access_checker.go index b4ccd0cfc0c48..406ac2e1185d0 100644 --- a/lib/services/access_checker.go +++ b/lib/services/access_checker.go @@ -1008,7 +1008,7 @@ func (a *accessChecker) HostUsers(s types.Server) (*HostUsersInfo, error) { // if any of the matching roles do not enable create host // user, the user should not be allowed on if createHostUserMode == types.CreateHostUserMode_HOST_USER_MODE_OFF { - return nil, trace.AccessDenied("user is not allowed to create host users") + return nil, trace.AccessDenied("role %q prevents creating host users", role.GetName()) } if mode == types.CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED { diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 8421e3ec2009e..f2ed7810a3dc3 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1264,20 +1264,22 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont } // Create host user. - created, userCloser, err := s.termHandlers.SessionRegistry.TryCreateHostUser(identityContext) + created, userCloser, err := s.termHandlers.SessionRegistry.UpsertHostUser(identityContext) if err != nil { - return ctx, trace.Wrap(err) + log.Infof("error while creating host users: %s", err) } + // Indicate that the user was created by Teleport. ccx.UserCreatedByTeleport = created if userCloser != nil { ccx.AddCloser(userCloser) } - sudoersCloser, err := s.termHandlers.SessionRegistry.TryWriteSudoersFile(identityContext) + sudoersCloser, err := s.termHandlers.SessionRegistry.WriteSudoersFile(identityContext) if err != nil { - return ctx, trace.Wrap(err) + log.Warnf("error while writing sudoers: %s", err) } + if sudoersCloser != nil { ccx.AddCloser(sudoersCloser) } diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index d6d6e5ad5f53d..35c5a13bfa13a 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -3065,13 +3065,13 @@ func TestHostUserCreationProxy(t *testing.T) { reg, err := srv.NewSessionRegistry(srv.SessionRegistryConfig{Srv: proxy, SessionTrackerService: proxyClient}) require.NoError(t, err) - _, err = reg.TryWriteSudoersFile(srv.IdentityContext{ + _, err = reg.WriteSudoersFile(srv.IdentityContext{ AccessChecker: &fakeAccessChecker{}, }) assert.NoError(t, err) assert.Equal(t, 0, sudoers.writeAttempts) - _, _, err = reg.TryCreateHostUser(srv.IdentityContext{ + _, _, err = reg.UpsertHostUser(srv.IdentityContext{ AccessChecker: &fakeAccessChecker{}, }) assert.NoError(t, err) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 9b11ae3b095dd..6e1b1a64058d6 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -235,11 +235,15 @@ func (sc *sudoersCloser) Close() error { return nil } -// TryWriteSudoersFile tries to write the needed sudoers entry to the sudoers +// WriteSudoersFile tries to write the needed sudoers entry to the sudoers // file, if any. If the returned closer is not nil, it must be called at the // end of the session to cleanup the sudoers file. -func (s *SessionRegistry) TryWriteSudoersFile(identityContext IdentityContext) (io.Closer, error) { - // Pulling sudoers directly from the Srv so TryWriteSudoersFile always +func (s *SessionRegistry) WriteSudoersFile(identityContext IdentityContext) (io.Closer, error) { + if identityContext.Login == teleport.SSHSessionJoinPrincipal { + return nil, nil + } + + // Pulling sudoers directly from the Srv so WriteSudoersFile always // respects the invariant that we shouldn't write sudoers on proxy servers. // This might invalidate the cached sudoers field on SessionRegistry, so // we may be able to remove that in a future PR @@ -265,42 +269,51 @@ func (s *SessionRegistry) TryWriteSudoersFile(identityContext IdentityContext) ( return &sudoersCloser{ username: identityContext.Login, userSessions: s.sessionsByUser, - cleanup: s.sudoers.RemoveSudoers, + cleanup: sudoWriter.RemoveSudoers, }, nil } -// TryCreateHostUser attempts to create a local user on the host if needed. -// If the returned closer is not nil, it must be called at the end of the -// session to clean up the local user. -func (s *SessionRegistry) TryCreateHostUser(identityContext IdentityContext) (created bool, closer io.Closer, err error) { +// UpsertHostUser attempts to create or update a local user on the host if needed. +// If the returned closer is not nil, it must be called at the end of the session to +// clean up the local user. +func (s *SessionRegistry) UpsertHostUser(identityContext IdentityContext) (bool, io.Closer, error) { + if identityContext.Login == teleport.SSHSessionJoinPrincipal { + return false, nil, nil + } + if !s.Srv.GetCreateHostUser() || s.users == nil { s.log.Debug("Not creating host user: node has disabled host user creation.") return false, nil, nil // not an error to not be able to create host users } - ui, err := identityContext.AccessChecker.HostUsers(s.Srv.GetInfo()) - if err != nil { - if trace.IsAccessDenied(err) { - log.Warnf("Unable to create host users: %v", err) - return false, nil, nil + ui, accessErr := identityContext.AccessChecker.HostUsers(s.Srv.GetInfo()) + if trace.IsAccessDenied(accessErr) { + existsErr := s.users.UserExists(identityContext.Login) + if existsErr != nil { + if trace.IsNotFound(existsErr) { + return false, nil, trace.WrapWithMessage(accessErr, "insufficient permissions for host user creation") + } + + return false, nil, trace.Wrap(existsErr) } - log.Debug("Error while checking host users creation permission: ", err) - return false, nil, trace.Wrap(err) } - existsErr := s.users.UserExists(identityContext.Login) - if trace.IsAccessDenied(err) && existsErr != nil { - return false, nil, trace.WrapWithMessage(err, "Insufficient permission for host user creation") + if accessErr != nil { + return false, nil, trace.Wrap(accessErr) } userCloser, err := s.users.UpsertUser(identityContext.Login, *ui) - if err != nil && !trace.IsAlreadyExists(err) && !errors.Is(err, unmanagedUserErr) { + if err != nil { log.Debugf("Error creating user %s: %s", identityContext.Login, err) - return false, nil, trace.Wrap(err) - } - if errors.Is(err, unmanagedUserErr) { - log.Warnf("User %q is not managed by teleport. Either manually delete the user from this machine or update the host_groups defined in their role to include %q. https://goteleport.com/docs/enroll-resources/server-access/guides/host-user-creation/#migrating-unmanaged-users", identityContext.Login, types.TeleportKeepGroup) + if errors.Is(err, unmanagedUserErr) { + log.Warnf("User %q is not managed by teleport. Either manually delete the user from this machine or update the host_groups defined in their role to include %q. https://goteleport.com/docs/enroll-resources/server-access/guides/host-user-creation/#migrating-unmanaged-users", identityContext.Login, types.TeleportKeepGroup) + return false, nil, nil + } + + if !trace.IsAlreadyExists(err) { + return false, nil, trace.Wrap(err) + } } return true, userCloser, nil diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index 4df45cd64286b..fafb4f481957a 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/term" @@ -1428,4 +1429,288 @@ func mockSSHSession(t *testing.T) *tracessh.Session { require.Fail(t, "timeout while waiting for the SSH session") return nil } + +} + +func TestUpsertHostUser(t *testing.T) { + username := "alice" + + cases := []struct { + name string + + identityContext IdentityContext + hostUsers *fakeHostUsersBackend + createHostUser bool + + expectCreated bool + expectErrIs error + expectErrContains string + expectUsers map[string][]string + }{ + { + name: "should upsert existing user with permission", + createHostUser: true, + identityContext: IdentityContext{ + Login: username, + AccessChecker: &fakeAccessChecker{ + hostInfo: services.HostUsersInfo{ + Groups: []string{"foo", "bar"}, + }, + }, + }, + hostUsers: &fakeHostUsersBackend{users: map[string][]string{ + username: {}, + }}, + + expectCreated: true, + + expectUsers: map[string][]string{ + username: {"foo", "bar"}, + }, + }, + { + name: "should upsert new user with permission", + createHostUser: true, + identityContext: IdentityContext{ + Login: username, + AccessChecker: &fakeAccessChecker{ + hostInfo: services.HostUsersInfo{ + Groups: []string{"foo", "bar"}, + }, + }, + }, + hostUsers: &fakeHostUsersBackend{}, + + expectCreated: true, + expectUsers: map[string][]string{ + username: {"foo", "bar"}, + }, + }, + { + name: "should not upsert existing user without permission", + createHostUser: true, + identityContext: IdentityContext{Login: username, AccessChecker: &fakeAccessChecker{err: trace.AccessDenied("test")}}, + hostUsers: &fakeHostUsersBackend{ + users: map[string][]string{ + username: {}, + }, + }, + + expectCreated: false, + expectErrIs: trace.AccessDenied("test"), + expectUsers: map[string][]string{ + username: {}, + }, + }, + { + name: "should not upsert new user without permission", + createHostUser: true, + identityContext: IdentityContext{Login: username, AccessChecker: &fakeAccessChecker{err: trace.AccessDenied("test")}}, + hostUsers: &fakeHostUsersBackend{}, + + expectCreated: false, + expectUsers: make(map[string][]string), + expectErrIs: trace.AccessDenied("test"), + expectErrContains: "insufficient permissions for host user creation", + }, + { + name: "should do nothing if login is session join principal", + createHostUser: true, + identityContext: IdentityContext{Login: teleport.SSHSessionJoinPrincipal}, + hostUsers: &fakeHostUsersBackend{}, + + expectCreated: false, + expectUsers: make(map[string][]string), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + registry := SessionRegistry{ + SessionRegistryConfig: SessionRegistryConfig{ + Srv: &fakeServer{createHostUser: c.createHostUser}, + }, + users: c.hostUsers, + } + + userCreated, _, err := registry.UpsertHostUser(c.identityContext) + + if c.expectErrIs != nil { + assert.ErrorIs(t, err, c.expectErrIs) + } + + if c.expectErrContains != "" { + assert.Contains(t, err.Error(), c.expectErrContains) + } + + if c.expectErrIs == nil && c.expectErrContains == "" { + assert.NoError(t, err) + } + + assert.Equal(t, c.expectCreated, userCreated) + + for name, groups := range c.hostUsers.users { + expectedGroups, ok := c.expectUsers[name] + assert.True(t, ok, "user must be present in expected users") + assert.ElementsMatch(t, expectedGroups, groups) + } + }) + } +} + +func TestWriteSudoersFile(t *testing.T) { + username := "alice" + + cases := []struct { + name string + + identityContext IdentityContext + hostSudoers *fakeSudoersBackend + + expectSudoers map[string][]string + expectErrIs error + expectErrContains string + }{ + { + name: "should write sudoers with permission", + identityContext: IdentityContext{Login: username, AccessChecker: &fakeAccessChecker{}}, + hostSudoers: &fakeSudoersBackend{}, + + expectSudoers: map[string][]string{ + username: {"foo", "bar"}, + }, + }, + { + name: "should not write sudoers without permission", + identityContext: IdentityContext{Login: username, AccessChecker: &fakeAccessChecker{err: trace.AccessDenied("test")}}, + hostSudoers: &fakeSudoersBackend{}, + + expectSudoers: map[string][]string{}, + expectErrIs: trace.AccessDenied("test"), + }, + { + name: "should do nothing for session join principal", + identityContext: IdentityContext{Login: teleport.SSHSessionJoinPrincipal, AccessChecker: &fakeAccessChecker{}}, + hostSudoers: &fakeSudoersBackend{}, + + expectSudoers: map[string][]string{}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + registry := SessionRegistry{ + SessionRegistryConfig: SessionRegistryConfig{ + Srv: &fakeServer{hostSudoers: c.hostSudoers}, + }, + sessionsByUser: &userSessions{ + sessionsByUser: make(map[string]int), + }, + } + + _, err := registry.WriteSudoersFile(c.identityContext) + + if c.expectErrIs != nil { + assert.ErrorIs(t, err, c.expectErrIs) + } + + if c.expectErrContains != "" { + assert.Contains(t, err.Error(), c.expectErrContains) + } + + if c.expectErrIs == nil && c.expectErrContains == "" { + assert.NoError(t, err) + } + + for name, sudoers := range c.hostSudoers.sudoers { + expectedSudoers, ok := c.expectSudoers[name] + assert.True(t, ok, "there should be an expected name for each login name") + assert.ElementsMatch(t, expectedSudoers, sudoers) + } + }) + } +} + +type fakeServer struct { + Server + + createHostUser bool + hostSudoers HostSudoers +} + +func (f *fakeServer) GetCreateHostUser() bool { + return f.createHostUser +} + +func (f *fakeServer) GetHostSudoers() HostSudoers { + return f.hostSudoers +} + +func (f *fakeServer) GetInfo() types.Server { + return nil +} + +type fakeAccessChecker struct { + services.AccessChecker + err error + hostInfo services.HostUsersInfo +} + +func (f *fakeAccessChecker) HostSudoers(srv types.Server) ([]string, error) { + return []string{"foo", "bar"}, f.err +} + +func (f *fakeAccessChecker) HostUsers(srv types.Server) (*services.HostUsersInfo, error) { + return &f.hostInfo, f.err +} + +type fakeHostUsersBackend struct { + HostUsers + + users map[string][]string +} + +func (f *fakeHostUsersBackend) UpsertUser(name string, hostRoleInfo services.HostUsersInfo) (io.Closer, error) { + if f.users == nil { + f.users = make(map[string][]string) + } + + f.users[name] = hostRoleInfo.Groups + return nil, nil +} + +func (f *fakeHostUsersBackend) UserExists(name string) error { + if f.users == nil { + return trace.NotFound(name) + } + + _, exists := f.users[name] + if !exists { + return trace.NotFound(name) + } + + return nil +} + +type fakeSudoersBackend struct { + sudoers map[string][]string + err error +} + +func (f *fakeSudoersBackend) WriteSudoers(name string, sudoers []string) error { + if f.sudoers == nil { + f.sudoers = make(map[string][]string) + } + + f.sudoers[name] = append(f.sudoers[name], sudoers...) + return f.err +} + +func (f *fakeSudoersBackend) RemoveSudoers(name string) error { + if f.sudoers == nil { + return nil + } + + delete(f.sudoers, name) + return f.err }