diff --git a/lib/bpf/bpf.go b/lib/bpf/bpf.go index 4d532d1e08c0d..2ab9324fec4f1 100644 --- a/lib/bpf/bpf.go +++ b/lib/bpf/bpf.go @@ -199,9 +199,12 @@ func New(config *servicecfg.BPFConfig, restrictedSession *servicecfg.RestrictedS } // Close will stop any running BPF programs. Note this is only for a graceful -// shutdown, from the man page for BPF: "Generally, eBPF programs are loaded -// by the user process and automatically unloaded when the process exits." -func (s *Service) Close() error { +// shutdown, from the man page for BPF: "Generally, eBPF programs are loaded by +// the user process and automatically unloaded when the process exits". The +// restarting parameter indicates that Teleport is shutting down because of a +// restart, and thus we should skip any deinitialization that would interfere +// with the new Teleport instance. +func (s *Service) Close(restarting bool) error { // Unload the BPF programs. s.exec.close() s.open.close() @@ -209,8 +212,10 @@ func (s *Service) Close() error { s.conn.close() } - // Close cgroup service. - if err := s.cgroup.Close(); err != nil { + // Close cgroup service. We should not unmount the cgroup filesystem if + // we're restarting. + skipCgroupUnmount := restarting + if err := s.cgroup.Close(skipCgroupUnmount); err != nil { log.WithError(err).Warn("Failed to close cgroup") } diff --git a/lib/bpf/bpf_test.go b/lib/bpf/bpf_test.go index 45f6f90822330..31c115491d2aa 100644 --- a/lib/bpf/bpf_test.go +++ b/lib/bpf/bpf_test.go @@ -72,7 +72,8 @@ func TestRootWatch(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, service.Close()) + const restarting = false + require.NoError(t, service.Close(restarting)) }) // Create a fake audit log that can be used to capture the events emitted. @@ -463,7 +464,8 @@ func moveIntoCgroup(t *testing.T, pid int) (uint64, error) { return 0, trace.Wrap(err) } t.Cleanup(func() { - require.NoError(t, cgroupSrv.Close()) + const skipUnmount = false + require.NoError(t, cgroupSrv.Close(skipUnmount)) }) sessionID := uuid.New().String() diff --git a/lib/bpf/common.go b/lib/bpf/common.go index 70bd270cb0170..cf9f7b967fadc 100644 --- a/lib/bpf/common.go +++ b/lib/bpf/common.go @@ -39,7 +39,7 @@ type BPF interface { CloseSession(ctx *SessionContext) error // Close will stop any running BPF programs. - Close() error + Close(restarting bool) error } // SessionContext contains all the information needed to track and emit @@ -85,7 +85,7 @@ type NOP struct { } // Close closes the NOP service. Note this function does nothing. -func (s *NOP) Close() error { +func (s *NOP) Close(bool) error { return nil } diff --git a/lib/cgroup/cgroup.go b/lib/cgroup/cgroup.go index e1cad50a980e3..41f691222b849 100644 --- a/lib/cgroup/cgroup.go +++ b/lib/cgroup/cgroup.go @@ -92,13 +92,19 @@ func New(config *Config) (*Service, error) { return s, nil } -// Close will unmount the cgroup filesystem. -func (s *Service) Close() error { +// Close will clean up the session cgroups and unmount the cgroup2 filesystem, +// unless otherwise requested. +func (s *Service) Close(skipUnmount bool) error { err := s.cleanupHierarchy() if err != nil { return trace.Wrap(err) } + if skipUnmount { + log.Debugf("Cleaned up Teleport session hierarchy at: %v.", s.teleportRoot) + return nil + } + err = s.unmount() if err != nil { return trace.Wrap(err) diff --git a/lib/cgroup/cgroup_test.go b/lib/cgroup/cgroup_test.go index 95d3b1aa58fa5..d2c51d64f3ab8 100644 --- a/lib/cgroup/cgroup_test.go +++ b/lib/cgroup/cgroup_test.go @@ -71,7 +71,8 @@ func TestRootCreate(t *testing.T) { require.NoDirExists(t, cgroupPath) // Close the cgroup service, this should unmound the cgroup filesystem. - err = service.Close() + const skipUnmount = false + err = service.Close(skipUnmount) require.NoError(t, err) // Make sure the cgroup filesystem has been unmounted. @@ -96,7 +97,8 @@ func TestRootCleanup(t *testing.T) { MountPath: dir, }) require.NoError(t, err) - defer service.Close() + const skipUnmount = false + defer service.Close(skipUnmount) // Create fake session ID and cgroup. sessionID := uuid.New().String() @@ -112,6 +114,40 @@ func TestRootCleanup(t *testing.T) { require.NoDirExists(t, cgroupPath) } +// TestRootSkipUnmount checks that closing the service with skipUnmount set to +// true works correctly; i.e. it cleans up the cgroups we're responsible for but +// doesn't unmount the cgroup2 file system. +func TestRootSkipUnmount(t *testing.T) { + // This test must be run as root. Only root can create cgroups. + if !isRoot() { + t.Skip("Tests for package cgroup can only be run as root.") + } + + t.Parallel() + + // Start a cgroup service with a temporary directory as the mount path. + service, err := New(&Config{ + MountPath: t.TempDir(), + }) + require.NoError(t, err) + + sessionID := uuid.NewString() + sessionPath := path.Join(service.teleportRoot, sessionID) + require.NoError(t, service.Create(sessionID)) + + require.DirExists(t, sessionPath) + + const skipUnmount = true + require.NoError(t, service.Close(skipUnmount)) + + require.DirExists(t, service.teleportRoot) + require.NoDirExists(t, path.Join(service.teleportRoot, sessionID)) + + require.NoError(t, service.unmount()) + + require.NoDirExists(t, service.teleportRoot) +} + // isRoot returns a boolean if the test is being run as root or not. Tests // for this package must be run as root. func isRoot() bool { diff --git a/lib/restrictedsession/restricted_test.go b/lib/restrictedsession/restricted_test.go index 8ebedb00bed91..067589ef46088 100644 --- a/lib/restrictedsession/restricted_test.go +++ b/lib/restrictedsession/restricted_test.go @@ -239,7 +239,8 @@ func (tt *bpfContext) Close(t *testing.T) { if tt.enhancedRecorder != nil && tt.ctx != nil { err := tt.enhancedRecorder.CloseSession(tt.ctx) require.NoError(t, err) - err = tt.enhancedRecorder.Close() + const restarting = false + err = tt.enhancedRecorder.Close(restarting) require.NoError(t, err) } diff --git a/lib/service/service.go b/lib/service/service.go index c102c61672209..1c9cc2af05b84 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2266,6 +2266,12 @@ func (process *TeleportProcess) initSSH() error { proxyGetter := reversetunnel.NewConnectedProxyGetter() process.RegisterCriticalFunc("ssh.node", func() error { + // restartingOnGracefulShutdown will be set to true before the function + // exits if the function is exiting because Teleport is gracefully + // shutting down as a consequence of internally-triggered reloading or + // being signaled to restart. + var restartingOnGracefulShutdown bool + conn, err := process.WaitForConnector(SSHIdentityEvent, log) if conn == nil { return trace.Wrap(err) @@ -2318,7 +2324,7 @@ func (process *TeleportProcess) initSSH() error { if err != nil { return trace.Wrap(err) } - defer func() { warnOnErr(ebpf.Close(), log) }() + defer func() { warnOnErr(ebpf.Close(restartingOnGracefulShutdown), log) }() // Start access control programs. This is blocking and if the BPF programs fail to // load, the node will not start. If access control is not enabled, this will simply @@ -2529,7 +2535,9 @@ func (process *TeleportProcess) initSSH() error { warnOnErr(s.Close(), log) } else { log.Infof("Shutting down gracefully.") - warnOnErr(s.Shutdown(payloadContext(event.Payload, log)), log) + ctx := payloadContext(event.Payload, log) + restartingOnGracefulShutdown = services.IsProcessReloading(ctx) || services.HasProcessForked(ctx) + warnOnErr(s.Shutdown(ctx), log) } s.Wait()