diff --git a/controller/execute.go b/controller/execute.go index 47940c8766..38b88f0426 100644 --- a/controller/execute.go +++ b/controller/execute.go @@ -22,6 +22,7 @@ import ( "net/http" "os" "os/signal" + "runtime" "syscall" "time" @@ -72,6 +73,24 @@ import ( "sigs.k8s.io/external-dns/source/wrappers" ) +// sigtermSignals is a package-level signal channel that is registered in init(). +// This way, SIGTERM is captured as soon as the package is loaded, preventing +// default process termination, even if application startup is delayed. +var sigtermSignals chan os.Signal + +func init() { + sigtermSignals = make(chan os.Signal, 1) + signal.Notify(sigtermSignals, terminationSignals()...) +} + +func terminationSignals() []os.Signal { + signals := []os.Signal{os.Interrupt} + if runtime.GOOS != "windows" { + signals = append(signals, syscall.SIGTERM) + } + return signals +} + func Execute() { cfg := externaldns.NewConfig() if err := cfg.ParseFlags(os.Args[1:]); err != nil { @@ -99,8 +118,14 @@ func Execute() { ctx, cancel := context.WithCancel(context.Background()) - go serveMetrics(cfg.MetricsAddress) - go handleSigterm(cancel) + // Connect global SIGTERM capture to this run's context cancellation. + go func() { + <-sigtermSignals + log.Info("Received termination signal. Terminating...") + cancel() + }() + + go serveMetrics(ctx, cfg.MetricsAddress) endpointsSource, err := buildSource(ctx, cfg) if err != nil { @@ -468,22 +493,25 @@ func createDomainFilter(cfg *externaldns.Config) *endpoint.DomainFilter { } } -// handleSigterm listens for a SIGTERM signal and triggers the provided cancel function +// handleSigterm listens for termination signals and triggers the provided cancel function // to gracefully terminate the application. It logs a message when the signal is received. func handleSigterm(cancel func()) { signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGTERM) + signal.Notify(signals, terminationSignals()...) <-signals - log.Info("Received SIGTERM. Terminating...") + log.Info("Received termination signal. Terminating...") cancel() + signal.Stop(signals) } // serveMetrics starts an HTTP server that serves health and metrics endpoints. // The /healthz endpoint returns a 200 OK status to indicate the service is healthy. // The /metrics endpoint serves Prometheus metrics. // The server listens on the specified address and logs debug information about the endpoints. -func serveMetrics(address string) { - http.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { +func serveMetrics(ctx context.Context, address string) { + mux := http.NewServeMux() + + mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("OK")) }) @@ -492,7 +520,19 @@ func serveMetrics(address string) { log.Debugf("serving 'metrics' on '%s/metrics'", address) log.Debugf("registered '%d' metrics", len(metrics.RegisterMetric.Metrics)) - http.Handle("/metrics", promhttp.Handler()) + mux.Handle("/metrics", promhttp.Handler()) + + srv := &http.Server{Addr: address, Handler: mux} - log.Fatal(http.ListenAndServe(address, nil)) + // Shutdown server on context cancellation + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + _ = srv.Shutdown(shutdownCtx) + cancel() + }() + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatal(err) + } } diff --git a/controller/execute_test.go b/controller/execute_test.go index d0198e7f36..5d02bdfbb6 100644 --- a/controller/execute_test.go +++ b/controller/execute_test.go @@ -20,12 +20,17 @@ import ( "bytes" "context" "errors" + "fmt" + "net" "net/http" "net/http/httptest" "os" "os/exec" + "os/signal" "reflect" "regexp" + "runtime" + "syscall" "testing" "time" @@ -375,6 +380,95 @@ func TestCreateDomainFilter(t *testing.T) { } } +func getRandomPort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, err + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +} + +func sendTerminationSignal() error { + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + return err + } + if runtime.GOOS == "windows" { + return proc.Signal(os.Interrupt) + } + return proc.Signal(syscall.SIGTERM) +} + +func TestServeMetrics(t *testing.T) { + // Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine) + http.DefaultServeMux = http.NewServeMux() + + port, err := getRandomPort() + require.NoError(t, err) + address := fmt.Sprintf("localhost:%d", port) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go serveMetrics(ctx, fmt.Sprintf(":%d", port)) + + // Wait for the TCP socket to be ready + require.Eventually(t, func() bool { + conn, err := net.Dial("tcp", address) + if err != nil { + return false + } + _ = conn.Close() + return true + }, 2*time.Second, 10*time.Millisecond, "server not ready with port open in time") + + resp, err := http.Get(fmt.Sprintf("http://%s/healthz", address)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + resp, err = http.Get(fmt.Sprintf("http://%s/metrics", address)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Stop the server to avoid leaking goroutines across tests + cancel() +} + +func TestHandleSigterm(t *testing.T) { + cancelCalled := make(chan bool, 1) + cancel := func() { cancelCalled <- true } + + var logOutput bytes.Buffer + log.SetOutput(&logOutput) + defer log.SetOutput(os.Stderr) + + go handleSigterm(cancel) + + // Simulate sending a termination signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, terminationSignals()...) + defer signal.Stop(sigChan) + err := sendTerminationSignal() + assert.NoError(t, err) + + // Wait for cancel to be called + select { + case <-cancelCalled: + assert.Contains(t, logOutput.String(), "Received termination signal. Terminating...") + case sig := <-sigChan: + assert.Contains(t, terminationSignals(), sig) + case <-time.After(1 * time.Second): + t.Fatal("cancel function was not called") + } +} + func TestBuildSource(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) @@ -629,6 +723,129 @@ func TestExecuteBuildControllerErrorExitsNonZero(t *testing.T) { assert.NotEqual(t, 0, code) } +// ValidateConfig triggers log.Fatalf (in-process). +func TestExecuteConfigValidationFatalInProcess(t *testing.T) { + // Prepare args to trigger validation error before any goroutines start + prevArgs := os.Args + os.Args = []string{ + "external-dns", + "--source", "fake", + "--provider", "inmemory", + "--ignore-hostname-annotation", // triggers validation: FQDN template required when ignoring annotations + "--metrics-address", ":0", + } + t.Cleanup(func() { os.Args = prevArgs }) + + // Capture logs and replace Fatalf with Goexit to stop only the Execute goroutine + logger := log.StandardLogger() + prevExit := logger.ExitFunc + prevOut := logger.Out + buf := new(bytes.Buffer) + logger.SetOutput(buf) + logger.ExitFunc = func(int) { runtime.Goexit() } + t.Cleanup(func() { logger.ExitFunc = prevExit; logger.SetOutput(prevOut) }) + + done := make(chan struct{}) + go func() { + defer close(done) + Execute() + }() + + select { + case <-done: + // ok + case <-time.After(2 * time.Second): + t.Fatal("Execute did not exit after validation fatal") + } + + // Do not assert on logger text to avoid flakiness with global logger +} + +// Run path with --events; shut down via SIGTERM. +func TestExecuteDefaultRunWithEventsStopsOnSigterm(t *testing.T) { + // Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine) + http.DefaultServeMux = http.NewServeMux() + + // Prepare args to run Execute without --once and with --events + prevArgs := os.Args + os.Args = []string{ + "external-dns", + "--source", "fake", + "--provider", "inmemory", + "--events", + "--dry-run", + "--metrics-address", ":0", + } + t.Cleanup(func() { os.Args = prevArgs }) + + // Prevent log.Fatal from terminating the test process + logger := log.StandardLogger() + prevExit := logger.ExitFunc + logger.ExitFunc = func(int) { runtime.Goexit() } + t.Cleanup(func() { logger.ExitFunc = prevExit }) + + done := make(chan struct{}) + go func() { + defer close(done) + Execute() + }() + + // Give goroutines time to start + time.Sleep(50 * time.Millisecond) + + // Send termination signal to trigger handleSigterm(cancel) + require.NoError(t, sendTerminationSignal()) + + select { + case <-done: + // ok + case <-time.After(2 * time.Second): + t.Fatal("Execute did not stop after termination signal") + } +} + +// Webhook server path; pre-bind 127.0.0.1:8888 to force a bind failure. +func TestExecuteWebhookServerFailsPortInUseInProcess(t *testing.T) { + // Use a fresh DefaultServeMux for this test (do not restore to avoid data race with server goroutine) + http.DefaultServeMux = http.NewServeMux() + + // Pre-bind the webhook server port so it is unavailable + l, err := net.Listen("tcp", "127.0.0.1:8888") + if err != nil { + // If we cannot bind, assume something else is bound already, which is fine for this test + } else { + t.Cleanup(func() { _ = l.Close() }) + } + + prevArgs := os.Args + os.Args = []string{ + "external-dns", + "--source", "fake", + "--provider", "inmemory", + "--webhook-server", + "--metrics-address", ":0", + } + t.Cleanup(func() { os.Args = prevArgs }) + + logger := log.StandardLogger() + prevExit := logger.ExitFunc + logger.ExitFunc = func(int) { runtime.Goexit() } + t.Cleanup(func() { logger.ExitFunc = prevExit }) + + done := make(chan struct{}) + go func() { + defer close(done) + Execute() + }() + + select { + case <-done: + // ok + case <-time.After(2 * time.Second): + t.Fatal("Execute did not exit after webhook server fatal") + } +} + // Controller run loop stops on context cancel. func TestControllerRunCancelContextStopsLoop(t *testing.T) { // Minimal controller using fake source and inmemory provider.