diff --git a/lib/service/signals.go b/lib/service/signals.go index c319533a85dc1..290d28c2def46 100644 --- a/lib/service/signals.go +++ b/lib/service/signals.go @@ -82,9 +82,16 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { process.Shutdown(ctx) process.log.Infof("All services stopped, exiting.") return nil - case syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT: - process.log.Infof("Got signal %q, exiting immediately.", signal) - process.Close() + case syscall.SIGTERM, syscall.SIGINT: + timeout := getShutdownTimeout(process.log) + cancelCtx, cancelFunc := context.WithTimeout(ctx, timeout) + process.log.Infof("Got signal %q, exiting within %vs.", signal, timeout.Seconds()) + go func() { + defer cancelFunc() + process.Shutdown(cancelCtx) + }() + <-cancelCtx.Done() + process.log.Infof("All services stopped or timeout passed, exiting immediately.") return nil case syscall.SIGUSR1: // All programs placed diagnostics on the standard output. @@ -151,6 +158,31 @@ func (process *TeleportProcess) WaitForSignals(ctx context.Context) error { } } +const defaultShutdownTimeout = time.Second * 3 +const maxShutdownTimeout = time.Minute * 10 + +func getShutdownTimeout(log logrus.FieldLogger) time.Duration { + timeout := defaultShutdownTimeout + + // read undocumented env var TELEPORT_UNSTABLE_SHUTDOWN_TIMEOUT. + // TODO(Tener): DELETE IN 15.0. after ironing out all possible shutdown bugs. + override := os.Getenv("TELEPORT_UNSTABLE_SHUTDOWN_TIMEOUT") + if override != "" { + t, err := time.ParseDuration(override) + if err != nil { + log.Warnf("Cannot parse timeout override %q, using default instead.", override) + } + if err == nil { + if t > maxShutdownTimeout { + log.Warnf("Timeout override %q exceeds maximum value, reducing.", override) + t = maxShutdownTimeout + } + timeout = t + } + } + return timeout +} + // ErrTeleportReloading is returned when signal waiter exits // because the teleport process has initiaded shutdown var ErrTeleportReloading = &trace.CompareFailedError{Message: "teleport process is reloading"} diff --git a/lib/service/signals_test.go b/lib/service/signals_test.go new file mode 100644 index 0000000000000..2819e90c92462 --- /dev/null +++ b/lib/service/signals_test.go @@ -0,0 +1,63 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func Test_getShutdownTimeout(t *testing.T) { + tests := []struct { + name string + envValue string + want time.Duration + }{ + { + name: "no override", + envValue: "", + want: defaultShutdownTimeout, + }, + { + name: "accept valid override, one second", + envValue: "1s", + want: time.Second * 1, + }, + { + name: "accept valid override, one minute", + envValue: "1m", + want: time.Minute * 1, + }, + { + name: "ignore invalid override", + envValue: "one moment", + want: defaultShutdownTimeout, + }, + { + name: "valid override above maximum, trim", + envValue: "3000h", + want: maxShutdownTimeout, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("TELEPORT_UNSTABLE_SHUTDOWN_TIMEOUT", tt.envValue) + require.Equal(t, tt.want, getShutdownTimeout(logrus.StandardLogger())) + }) + } +}