diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go
index eba70aa56a690..2de3d8d0d746c 100644
--- a/lib/autoupdate/agent/process.go
+++ b/lib/autoupdate/agent/process.go
@@ -25,25 +25,51 @@ import (
 	"log/slog"
 	"os"
 	"os/exec"
+	"strconv"
+	"syscall"
+	"time"
 
 	"github.com/gravitational/trace"
+	"golang.org/x/sync/errgroup"
 )
 
-// SystemdService manages a Teleport systemd service.
+const (
+	// crashMonitorInterval is the polling interval for reading the service's PID from PIDPath.
+	crashMonitorInterval = 2 * time.Second
+	// minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detected
+	// before the service is determined stable.
+	minRunningIntervalsBeforeStable = 6
+	// maxCrashesBeforeFailure is the number of total crashes detected before the service is marked as crash-looping.
+	maxCrashesBeforeFailure = 2
+	// crashMonitorTimeout is the maximum amount of time to monitor the service for crashes after a reload.
+	crashMonitorTimeout = 30 * time.Second
+)
+
+// log keys
+const (
+	unitKey = "unit"
+)
+
+// SystemdService manages a systemd service (e.g., teleport or teleport-update).
 type SystemdService struct {
 	// ServiceName specifies the systemd service name.
 	ServiceName string
+	// PIDPath is a path to a file containing the service's PID.
+	PIDPath string
 	// Log contains a logger.
 	Log *slog.Logger
 }
 
-// Reload a systemd service.
+// Reload the systemd service.
 // Attempts a graceful reload before a hard restart.
 // See Process interface for more details.
 func (s SystemdService) Reload(ctx context.Context) error {
+	// TODO(sclevine): allow server to force restart instead of reload
+
 	if err := s.checkSystem(ctx); err != nil {
 		return trace.Wrap(err)
 	}
+
 	// Command error codes < 0 indicate that we are unable to run the command.
 	// Errors from s.systemctl are logged along with stderr and stdout (debug only).
@@ -55,30 +81,167 @@ func (s SystemdService) Reload(ctx context.Context) error { case code < 0: return trace.Errorf("unable to determine if systemd service is active") case code > 0: - s.Log.WarnContext(ctx, "Teleport systemd service not running.") + s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName) return trace.Wrap(ErrNotNeeded) } + + // Get initial PID for crash monitoring. + + initPID, err := readInt(s.PIDPath) + if errors.Is(err, os.ErrNotExist) { + s.Log.InfoContext(ctx, "No existing process detected. Skipping crash monitoring.", unitKey, s.ServiceName) + } else if err != nil { + s.Log.ErrorContext(ctx, "Error reading initial PID value. Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err) + } + // Attempt graceful reload of running service. code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) switch { case code < 0: - return trace.Errorf("unable to attempt reload of Teleport systemd service") + return trace.Errorf("unable to reload systemd service") case code > 0: // Graceful reload fails, try hard restart. code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName) if code != 0 { - return trace.Errorf("hard restart of Teleport systemd service failed") + return trace.Errorf("hard restart of systemd service failed") } - s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.") + s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName) default: - s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") + s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) + } + if initPID != 0 { + s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName) + return trace.Wrap(s.monitor(ctx, initPID)) + } + return nil +} + +// monitor for the started process to ensure it's running by polling PIDFile. 
+// This function detects several types of crashes while minimizing its own runtime during updates.
+// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID),
+// or getting stuck on quit (no change in PID).
+// initPID is the PID before the restart operation has been issued.
+func (s SystemdService) monitor(ctx context.Context, initPID int) error {
+	ctx, cancel := context.WithTimeout(ctx, crashMonitorTimeout)
+	defer cancel()
+	tickC := time.NewTicker(crashMonitorInterval).C
+
+	pidC := make(chan int)
+	g := &errgroup.Group{}
+	g.Go(func() error {
+		return tickFile(ctx, s.PIDPath, pidC, tickC)
+	})
+	err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure,
+		initPID, pidC, func(pid int) error {
+			p, err := os.FindProcess(pid)
+			if err != nil {
+				return trace.Wrap(err)
+			}
+			return trace.Wrap(p.Signal(syscall.Signal(0)))
+		})
+	cancel()
+	if err := g.Wait(); err != nil {
+		s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName)
 	}
+	return trace.Wrap(err)
+}
 
-	// TODO(sclevine): Ensure restart was successful and verify healthcheck.
+// waitForStablePID monitors a service's PID via pidC and determines whether the service is crashing.
+// verifyPID must be passed so that waitForStablePID can determine whether the process is running.
+// verifyPID must return os.ErrProcessDone in the case that the PID cannot be found, or nil otherwise.
+// baselinePID is the initial PID before any operation that might cause the process to start crashing.
+// minStable is the number of times pidC must return the same running PID before waitForStablePID returns nil.
+// maxCrashes is the number of times pidC conveys a process crash or bad state before waitForStablePID returns an error.
+func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, verifyPID func(pid int) error) error { + pid := baselinePID + var last, stale int + var crashes int + for stable := 0; stable < minStable; stable++ { + select { + case <-ctx.Done(): + return ctx.Err() + case p := <-pidC: + last = pid + pid = p + } + // A "crash" is defined as a transition away from a new (non-baseline) PID, or + // an interval where the current PID remains non-running (stale) since the last check. + if (last != 0 && pid != last && last != baselinePID) || + (stale != 0 && pid == stale && last == stale) { + crashes++ + } + if crashes > maxCrashes { + return trace.Errorf("detected crashing process") + } + // PID can only be stable if it is a real PID that is not new, + // has changed at least once, and hasn't been observed as missing. + if pid == 0 || + pid == baselinePID || + pid == stale || + pid != last { + stable = -1 + continue + } + err := verifyPID(pid) + // A stale PID most likely indicates that the process forked and crashed without systemd noticing. + // There is a small chance that we read the PID file before systemd removed it. + // Note: we only perform this check on PIDs that survive one iteration. + if errors.Is(err, os.ErrProcessDone) || + errors.Is(err, syscall.ESRCH) { + if pid != stale && + pid != baselinePID { + stale = pid + s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale) + } + stable = -1 + continue + } + if err != nil { + return trace.Wrap(err) + } + } return nil } +// readInt reads an integer from a file. 
+func readInt(path string) (int, error) { + p, err := readFileN(path, 32) + if err != nil { + return 0, trace.Wrap(err) + } + i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64) + if err != nil { + return 0, trace.Wrap(err) + } + return int(i), nil +} + +// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick. +// If the path cannot be read, tickFile sends 0 on ch. +// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist. +func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error { + var err error + for { + // two select statements -> never skip reads + select { + case <-tickC: + case <-ctx.Done(): + return err + } + var t int + t, err = readInt(path) + if errors.Is(err, os.ErrNotExist) { + err = nil + } + select { + case ch <- t: + case <-ctx.Done(): + return err + } + } +} + // Sync systemd service configuration by running systemctl daemon-reload. // See Process interface for more details. func (s SystemdService) Sync(ctx context.Context) error { @@ -106,9 +269,42 @@ func (s SystemdService) checkSystem(ctx context.Context) error { // Output sent to stdout is logged at debug level. // Output sent to stderr is logged at the level specified by errLevel. func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { - cmd := exec.CommandContext(ctx, "systemctl", args...) - stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel} - stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug} + cmd := &localExec{ + Log: s.Log, + ErrLevel: errLevel, + OutLevel: slog.LevelDebug, + } + code, err := cmd.Run(ctx, "systemctl", args...) 
+ if err == nil { + return code + } + if code >= 0 { + s.Log.Log(ctx, errLevel, "Error running systemctl.", + "args", args, "code", code) + return code + } + s.Log.Log(ctx, errLevel, "Unable to run systemctl.", + "args", args, "code", code, errorKey, err) + return code +} + +// localExec runs a command locally, logging any output. +type localExec struct { + // Log contains a slog logger. + // Defaults to slog.Default() if nil. + Log *slog.Logger + // ErrLevel is the log level for stderr. + ErrLevel slog.Level + // OutLevel is the log level for stdout. + OutLevel slog.Level +} + +// Run the command. Same arguments as exec.CommandContext. +// Outputs the status code, or -1 if out-of-range or unstarted. +func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) { + cmd := exec.CommandContext(ctx, name, args...) + stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel, prefix: "[stderr] "} + stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel, prefix: "[stdout] "} cmd.Stderr = stderr cmd.Stdout = stdout err := cmd.Run() @@ -122,24 +318,23 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args if code == 255 { code = -1 } - if err != nil { - s.Log.Log(ctx, errLevel, "Failed to run systemctl.", - "args", args, - "code", code, - "error", err) - } - return code + return code, trace.Wrap(err) } // lineLogger logs each line written to it. 
type lineLogger struct { - ctx context.Context - log *slog.Logger - level slog.Level + ctx context.Context + log *slog.Logger + level slog.Level + prefix string last bytes.Buffer } +func (w *lineLogger) out(s string) { + w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant +} + func (w *lineLogger) Write(p []byte) (n int, err error) { lines := bytes.Split(p, []byte("\n")) // Finish writing line @@ -153,13 +348,13 @@ func (w *lineLogger) Write(p []byte) (n int, err error) { } // Newline found, log line - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) n += 1 w.last.Reset() // Log lines that are already newline-terminated for _, line := range lines[:len(lines)-1] { - w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant + w.out(string(line)) n += len(line) + 1 } @@ -174,6 +369,6 @@ func (w *lineLogger) Flush() { if w.last.Len() == 0 { return } - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) w.last.Reset() } diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index 5ffa70dd0091e..c558a7539831a 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -21,8 +21,13 @@ package agent import ( "bytes" "context" + "errors" + "fmt" "log/slog" + "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -69,3 +74,266 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr { } return slog.Attr{Key: a.Key, Value: a.Value} } + +func TestWaitForStablePID(t *testing.T) { + t.Parallel() + + svc := &SystemdService{ + Log: slog.Default(), + } + + for _, tt := range []struct { + name string + ticks []int + baseline int + minStable int + maxCrashes int + findErrs map[int]error + + errored bool + canceled bool + }{ + { + name: "immediate restart", + ticks: []int{2, 2}, + baseline: 1, + minStable: 1, + 
maxCrashes: 1, + }, + { + name: "zero stable", + }, + { + name: "immediate crash", + ticks: []int{2, 3}, + baseline: 1, + minStable: 1, + maxCrashes: 0, + errored: true, + }, + { + name: "no changes times out", + ticks: []int{1, 1, 1, 1}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + canceled: true, + }, + { + name: "baseline restart", + ticks: []int{2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "one restart then stable", + ticks: []int{1, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "two restarts then stable", + ticks: []int{1, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "three restarts then stable", + ticks: []int{1, 2, 3, 4, 4, 4, 4}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "too many restarts excluding baseline", + ticks: []int{1, 2, 3, 4, 5}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts including baseline", + ticks: []int{1, 2, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts slow", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts after stable", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + }, + { + name: "stable after too many restarts", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "cancel", + ticks: []int{1, 1, 1}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + canceled: true, + }, + { + name: "stale PID crash", + ticks: []int{2, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + errored: true, + }, + { + name: "stale PID but fixed", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + 
maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + }, + { + name: "error PID", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: errors.New("bad"), + }, + errored: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan int) + go func() { + defer cancel() // always quit after last tick + for _, tick := range tt.ticks { + ch <- tick + } + }() + err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes, + tt.baseline, ch, func(pid int) error { + return tt.findErrs[pid] + }) + require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) + if !tt.canceled { + require.Equal(t, tt.errored, err != nil) + } + }) + } +} + +func TestTickFile(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + ticks []int + errored bool + }{ + { + name: "consistent", + ticks: []int{1, 1, 1}, + errored: false, + }, + { + name: "divergent", + ticks: []int{1, 2, 3}, + errored: false, + }, + { + name: "start error", + ticks: []int{-1, 1, 1}, + errored: false, + }, + { + name: "ephemeral error", + ticks: []int{1, -1, 1}, + errored: false, + }, + { + name: "end error", + ticks: []int{1, 1, -1}, + errored: true, + }, + { + name: "start missing", + ticks: []int{0, 1, 1}, + errored: false, + }, + { + name: "ephemeral missing", + ticks: []int{1, 0, 1}, + errored: false, + }, + { + name: "end missing", + ticks: []int{1, 1, 0}, + errored: false, + }, + { + name: "cancel-only", + errored: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + filePath := filepath.Join(t.TempDir(), "file") + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + tickC := make(chan time.Time) + ch := make(chan int) + + go func() { + defer cancel() // always quit after last tick or fail + for _, tick := range tt.ticks { + _ = os.RemoveAll(filePath) + switch { + case 
tick > 0: + err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm) + require.NoError(t, err) + case tick < 0: + err := os.Mkdir(filePath, os.ModePerm) + require.NoError(t, err) + } + tickC <- time.Now() + res := <-ch + if tick < 0 { + tick = 0 + } + require.Equal(t, tick, res) + } + }() + err := tickFile(ctx, filePath, ch, tickC) + require.Equal(t, tt.errored, err != nil) + }) + } +} diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 9625481df2cd2..5d82017998263 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -159,6 +159,7 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { }, Process: &SystemdService{ ServiceName: "teleport.service", + PIDPath: "/run/teleport.pid", Log: cfg.Log, }, }, nil @@ -530,7 +531,7 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s } else if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { u.Log.ErrorContext(ctx, "Failed to revert Teleport to older version. Installation likely broken.", errorKey, err) } else { - u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") + u.Log.WarnContext(ctx, "Teleport updater encountered an error during the update and successfully reverted the installation.") } return trace.Errorf("failed to start new version %q of Teleport: %w", targetVersion, err)