From 7c2c225d88ec029c3347ca06adbbe47b615e55d9 Mon Sep 17 00:00:00 2001 From: Rana Ian Date: Wed, 24 Sep 2025 11:48:48 -0700 Subject: [PATCH] Refactor `healthcheck` for Kubernetes extensibility with `HealthChecker` interface The main intent of refactoring is to provide health check extensibility for Kubernetes while supporting the existing DB health checks. A `HealthChecker` interface is added to support the different health check approaches of DBs and Kubernetes. Existing DB TCP health check logic is moved to a new `TargetDialer` struct. Changes: - Added `HealthChecker` interface with two functions: - `CheckHealth(ctx context.Context) ([]string, error)` - `GetProtocol() types.TargetHealthProtocol` - Added `TargetDialer` struct which encapsulates existing TCP health check logic - Changed `Target` struct to use the `HealthChecker` interface - Changed `worker.checkHealth` to call the new `CheckHealth` function - Removed a `protocol` field from `healthCheckConfig` - Added `TargetHealthProtocolHTTP` for use with Kubernetes health checks - Moved and renamed test `Test_dialEndpoints` to `TestTargetDialer_dialEndpoints` - Added files `net.go` and `net_test.go` for `TargetDialer` Part of #58413 --- api/types/target_health.go | 6 +- lib/healthcheck/config.go | 15 +-- lib/healthcheck/config_test.go | 16 --- lib/healthcheck/manager_test.go | 50 +++++---- lib/healthcheck/net.go | 97 ++++++++++++++++ lib/healthcheck/net_test.go | 117 ++++++++++++++++++++ lib/healthcheck/target.go | 30 ++--- lib/healthcheck/target_test.go | 21 ++-- lib/healthcheck/worker.go | 62 ++--------- lib/healthcheck/worker_experimental_test.go | 18 +-- lib/healthcheck/worker_test.go | 111 +++---------------- lib/srv/db/server.go | 2 +- 12 files changed, 308 insertions(+), 237 deletions(-) create mode 100644 lib/healthcheck/net.go create mode 100644 lib/healthcheck/net_test.go diff --git a/api/types/target_health.go b/api/types/target_health.go index e6b36eb64e0e6..40a16b0f19dce 100644 --- 
a/api/types/target_health.go +++ b/api/types/target_health.go @@ -25,8 +25,10 @@ import ( type TargetHealthProtocol string const ( - // TargetHealthProtocolTCP is a target health check protocol. - TargetHealthProtocolTCP TargetHealthProtocol = "TCP" + // TargetHealthProtocolTCP is the TCP target health check protocol. + TargetHealthProtocolTCP TargetHealthProtocol = "tcp" + // TargetHealthProtocolHTTP is the HTTP target health check protocol. + TargetHealthProtocolHTTP TargetHealthProtocol = "http" ) // TargetHealthStatus is a target resource's health status. diff --git a/lib/healthcheck/config.go b/lib/healthcheck/config.go index 1de67559a7a59..44a5df8903b51 100644 --- a/lib/healthcheck/config.go +++ b/lib/healthcheck/config.go @@ -34,7 +34,6 @@ import ( // [*healthcheckconfigv1.HealthCheckConfig] with defaults set. type healthCheckConfig struct { name string - protocol types.TargetHealthProtocol interval time.Duration timeout time.Duration healthyThreshold uint32 @@ -48,14 +47,11 @@ func newHealthCheckConfig(cfg *healthcheckconfigv1.HealthCheckConfig) *healthChe spec := cfg.GetSpec() match := spec.GetMatch() return &healthCheckConfig{ - name: cfg.GetMetadata().GetName(), - timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout), - interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval), - healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), defaults.HealthCheckHealthyThreshold), - unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold), - // we only support plain TCP health checks currently, but eventually we - // may add support for other protocols such as TLS or HTTP - protocol: types.TargetHealthProtocolTCP, + name: cfg.GetMetadata().GetName(), + timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout), + interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval), + healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), 
defaults.HealthCheckHealthyThreshold), + unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold), databaseLabelMatchers: newLabelMatchers(match.GetDbLabelsExpression(), match.GetDbLabels()), } } @@ -66,7 +62,6 @@ func (h *healthCheckConfig) equivalent(other *healthCheckConfig) bool { return (h == nil && other == nil) || h != nil && other != nil && h.name == other.name && - h.protocol == other.protocol && h.interval == other.interval && h.timeout == other.timeout && h.healthyThreshold == other.healthyThreshold && diff --git a/lib/healthcheck/config_test.go b/lib/healthcheck/config_test.go index d9edbe9910ee8..21837c40bc762 100644 --- a/lib/healthcheck/config_test.go +++ b/lib/healthcheck/config_test.go @@ -77,7 +77,6 @@ func Test_newHealthCheckConfig(t *testing.T) { interval: time.Second * 43, healthyThreshold: 7, unhealthyThreshold: 8, - protocol: types.TargetHealthProtocolTCP, databaseLabelMatchers: types.LabelMatchers{ Labels: types.Labels{ "foo": utils.Strings{"bar", "baz"}, @@ -95,7 +94,6 @@ func Test_newHealthCheckConfig(t *testing.T) { interval: defaults.HealthCheckInterval, healthyThreshold: defaults.HealthCheckHealthyThreshold, unhealthyThreshold: defaults.HealthCheckUnhealthyThreshold, - protocol: types.TargetHealthProtocolTCP, databaseLabelMatchers: types.LabelMatchers{ Labels: types.Labels{}, Expression: `labels["*"] == "*"`, @@ -139,7 +137,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) { desc: "all fields equal", a: &healthCheckConfig{ name: "test", - protocol: "http", interval: time.Second, timeout: 500 * time.Millisecond, healthyThreshold: 3, @@ -147,7 +144,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) { }, b: &healthCheckConfig{ name: "test", - protocol: "http", interval: time.Second, timeout: 500 * time.Millisecond, healthyThreshold: 3, @@ -159,7 +155,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) { desc: "all fields equal ignoring labels", a: &healthCheckConfig{ name: "test", - 
protocol: "http", interval: time.Second, timeout: 500 * time.Millisecond, healthyThreshold: 3, @@ -168,7 +163,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) { }, b: &healthCheckConfig{ name: "test", - protocol: "http", interval: time.Second, timeout: 500 * time.Millisecond, healthyThreshold: 3, @@ -187,16 +181,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) { }, want: false, }, - { - desc: "different protocol", - a: &healthCheckConfig{ - protocol: "http", - }, - b: &healthCheckConfig{ - protocol: "tcp", - }, - want: false, - }, { desc: "different interval", a: &healthCheckConfig{ diff --git a/lib/healthcheck/manager_test.go b/lib/healthcheck/manager_test.go index af8c22f1f50b2..ae6362fdb126e 100644 --- a/lib/healthcheck/manager_test.go +++ b/lib/healthcheck/manager_test.go @@ -181,17 +181,19 @@ func TestManager(t *testing.T) { var endpointMu sync.Mutex prodDialer := fakeDialer{} err = mgr.AddTarget(Target{ + HealthChecker: &TargetDialer{ + Resolver: func(ctx context.Context) ([]string, error) { + endpointMu.Lock() + defer endpointMu.Unlock() + return []string{prodDB.GetURI()}, nil + }, + dial: prodDialer.DialContext, + }, GetResource: func() types.ResourceWithLabels { endpointMu.Lock() defer endpointMu.Unlock() return prodDB }, - ResolverFn: func(ctx context.Context) ([]string, error) { - endpointMu.Lock() - defer endpointMu.Unlock() - return []string{prodDB.GetURI()}, nil - }, - dialFn: prodDialer.DialContext, onHealthCheck: func(lastResultErr error) { eventsCh <- lastResultTestEvent(prodDB.GetName(), lastResultErr) }, @@ -205,18 +207,20 @@ func TestManager(t *testing.T) { require.NoError(t, err) devDialer := fakeDialer{} - err = mgr.AddTarget(Target{ + devTarget := Target{ + HealthChecker: &TargetDialer{ + Resolver: func(ctx context.Context) ([]string, error) { + endpointMu.Lock() + defer endpointMu.Unlock() + return []string{devDB.GetURI()}, nil + }, + dial: devDialer.DialContext, + }, GetResource: func() types.ResourceWithLabels { 
endpointMu.Lock() defer endpointMu.Unlock() return devDB }, - ResolverFn: func(ctx context.Context) ([]string, error) { - endpointMu.Lock() - defer endpointMu.Unlock() - return []string{devDB.GetURI()}, nil - }, - dialFn: devDialer.DialContext, onHealthCheck: func(lastResultErr error) { eventsCh <- lastResultTestEvent(devDB.GetName(), lastResultErr) }, @@ -226,21 +230,27 @@ func TestManager(t *testing.T) { onClose: func() { eventsCh <- closedTestEvent(devDB.GetName()) }, - }) + } + err = mgr.AddTarget(devTarget) require.NoError(t, err) t.Run("duplicate target is an error", func(t *testing.T) { - err = mgr.AddTarget(Target{ - GetResource: func() types.ResourceWithLabels { return devDB }, - ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil }, - }) + err = mgr.AddTarget(devTarget) require.Error(t, err) require.IsType(t, trace.AlreadyExists(""), err) }) t.Run("unsupported target resource is an error", func(t *testing.T) { err = mgr.AddTarget(Target{ - GetResource: func() types.ResourceWithLabels { return &fakeResource{kind: "node"} }, - ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil }, + HealthChecker: &TargetDialer{ + Resolver: func(ctx context.Context) ([]string, error) { + endpointMu.Lock() + defer endpointMu.Unlock() + return nil, nil + }, + }, + GetResource: func() types.ResourceWithLabels { + return &fakeResource{kind: "node"} + }, }) require.Error(t, err) require.IsType(t, trace.BadParameter(""), err) diff --git a/lib/healthcheck/net.go b/lib/healthcheck/net.go new file mode 100644 index 0000000000000..557eff587d78a --- /dev/null +++ b/lib/healthcheck/net.go @@ -0,0 +1,97 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package healthcheck + +import ( + "context" + "net" + + "github.com/gravitational/trace" + "golang.org/x/sync/errgroup" + + "github.com/gravitational/teleport/api/types" +) + +// EndpointsResolverFunc is a callback func that returns endpoints for a target. +type EndpointsResolverFunc func(ctx context.Context) ([]string, error) + +// dialFunc dials an address on the given network. +type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// TargetDialer is a health check target which uses a net.Dialer. +type TargetDialer struct { + // Resolver resolves the target endpoint(s). + Resolver EndpointsResolverFunc + // dial is used to dial network connections. + dial dialFunc +} + +// NewTargetDialer returns a new TargetDialer ready for use. +func NewTargetDialer(resolver EndpointsResolverFunc) *TargetDialer { + return &TargetDialer{ + Resolver: resolver, + dial: defaultDialer().DialContext, + } +} + +// GetProtocol returns the network protocol used for checking health. +func (t *TargetDialer) GetProtocol() types.TargetHealthProtocol { + return types.TargetHealthProtocolTCP +} + +// CheckHealth checks the health of the target resource. 
+func (t *TargetDialer) CheckHealth(ctx context.Context) ([]string, error) { + return t.dialEndpoints(ctx) +} + +func (t *TargetDialer) dialEndpoints(ctx context.Context) ([]string, error) { + endpoints, err := t.Resolver(ctx) + if err != nil { + return nil, trace.Wrap(err, "failed to resolve target endpoints") + } + switch len(endpoints) { + case 0: + return nil, trace.NotFound("resolved zero target endpoints") + case 1: + return endpoints, t.dialEndpoint(ctx, endpoints[0]) + default: + group, ctx := errgroup.WithContext(ctx) + group.SetLimit(10) + for _, ep := range endpoints { + group.Go(func() error { + return trace.Wrap(t.dialEndpoint(ctx, ep)) + }) + } + return endpoints, group.Wait() + } +} + +func (t *TargetDialer) dialEndpoint(ctx context.Context, endpoint string) error { + conn, err := t.dial(ctx, "tcp", endpoint) + if err != nil { + return trace.Wrap(err) + } + // an error while closing the connection could indicate an RST packet from + // the endpoint - that's a health check failure. + return trace.Wrap(conn.Close()) +} + +func defaultDialer() *net.Dialer { + return &net.Dialer{} +} diff --git a/lib/healthcheck/net_test.go b/lib/healthcheck/net_test.go new file mode 100644 index 0000000000000..b5880f5dc2c6c --- /dev/null +++ b/lib/healthcheck/net_test.go @@ -0,0 +1,117 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. 
+ * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package healthcheck + +import ( + "context" + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +func TestTargetDialer_dialEndpoints(t *testing.T) { + t.Parallel() + + const healthyAddr = "healthy.com:123" + const unhealthyAddr = "unhealthy.com:123" + tests := []struct { + desc string + resolver EndpointsResolverFunc + wantErrContains string + }{ + { + desc: "resolver error", + resolver: func(ctx context.Context) ([]string, error) { + return nil, trace.Errorf("resolver error") + }, + wantErrContains: "resolver error", + }, + { + desc: "resolved zero addrs", + resolver: func(ctx context.Context) ([]string, error) { + return nil, nil + }, + wantErrContains: "resolved zero target endpoints", + }, + { + desc: "resolved one healthy addr", + resolver: func(ctx context.Context) ([]string, error) { + return []string{healthyAddr}, nil + }, + }, + { + desc: "resolved one unhealthy addr", + resolver: func(ctx context.Context) ([]string, error) { + return []string{unhealthyAddr}, nil + }, + wantErrContains: "unhealthy addr", + }, + { + desc: "resolved multiple healthy addrs", + resolver: func(ctx context.Context) ([]string, error) { + return []string{healthyAddr, healthyAddr, healthyAddr}, nil + }, + }, + { + desc: "resolved a mix of healthy and unhealthy addrs", + resolver: func(ctx context.Context) ([]string, error) { + return []string{healthyAddr, unhealthyAddr, healthyAddr}, nil + }, + wantErrContains: "unhealthy addr", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + d := &TargetDialer{ + Resolver: test.resolver, + dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + if addr == healthyAddr { + return fakeConn{}, nil + } + return nil, trace.Errorf("unhealthy addr") + }, + } + _, err := 
d.dialEndpoints(t.Context()) + if test.wantErrContains != "" { + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.NoError(t, err) + }) + } +} + +type fakeConn struct { + net.Conn +} + +func (fakeConn) Close() error { return nil } + +type fakeResource struct { + kind string + types.ResourceWithLabels +} + +func (r *fakeResource) GetKind() string { + return r.kind +} diff --git a/lib/healthcheck/target.go b/lib/healthcheck/target.go index b9c9ca0851c73..f785848566e65 100644 --- a/lib/healthcheck/target.go +++ b/lib/healthcheck/target.go @@ -20,30 +20,29 @@ package healthcheck import ( "context" - "net" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" ) -// EndpointsResolverFunc is callback func that returns endpoints for a target. -type EndpointsResolverFunc func(ctx context.Context) ([]string, error) - -// OnHealthChangeFunc is a func called on each health change. -type OnHealthChangeFunc func(oldHealth, newHealth types.TargetHealth) +// HealthChecker is a resource which provides health checks. +type HealthChecker interface { + // CheckHealth checks the health of a target resource. + CheckHealth(ctx context.Context) ([]string, error) + // GetProtocol returns the network protocol used for checking health. + GetProtocol() types.TargetHealthProtocol +} // Target is a health check target. type Target struct { + // HealthChecker checks the resource's health. + HealthChecker // GetResource gets a copy of the target resource with updated labels. GetResource func() types.ResourceWithLabels - // ResolverFn resolves the target endpoint(s). - ResolverFn EndpointsResolverFunc // -- test fields below -- - // dialFn used to mock dialing in tests - dialFn dialFunc // onHealthCheck is called after each health check. onHealthCheck func(lastResultErr error) // onConfigUpdate is called after each config update. 
@@ -56,15 +55,8 @@ func (t *Target) checkAndSetDefaults() error { if t.GetResource == nil { return trace.BadParameter("missing target resource getter") } - if t.ResolverFn == nil { - return trace.BadParameter("missing target endpoint resolver") - } - if t.dialFn == nil { - t.dialFn = defaultDialer().DialContext + if t.HealthChecker == nil { + return trace.BadParameter("missing health checker") } return nil } - -func defaultDialer() *net.Dialer { - return &net.Dialer{} -} diff --git a/lib/healthcheck/target_test.go b/lib/healthcheck/target_test.go index 0d4c3c5ec4ba0..f5bcedc2ae9cf 100644 --- a/lib/healthcheck/target_test.go +++ b/lib/healthcheck/target_test.go @@ -28,15 +28,6 @@ import ( "github.com/gravitational/teleport/api/types" ) -type fakeResource struct { - kind string - types.ResourceWithLabels -} - -func (r *fakeResource) GetKind() string { - return r.kind -} - func TestTarget_checkAndSetDefaults(t *testing.T) { tests := []struct { name string @@ -46,23 +37,27 @@ func TestTarget_checkAndSetDefaults(t *testing.T) { { name: "valid target", target: &Target{ + HealthChecker: NewTargetDialer( + func(ctx context.Context) ([]string, error) { return []string{"127.0.0.1"}, nil }, + ), GetResource: func() types.ResourceWithLabels { return &fakeResource{kind: types.KindDatabase} }, - ResolverFn: func(ctx context.Context) ([]string, error) { return []string{"127.0.0.1"}, nil }, }, }, { name: "missing resource getter", target: &Target{ - ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil }, + HealthChecker: NewTargetDialer( + func(ctx context.Context) ([]string, error) { return nil, nil }, + ), }, wantErr: "missing target resource getter", }, { - name: "missing resolver", + name: "missing health checker", target: &Target{ GetResource: func() types.ResourceWithLabels { return nil }, }, - wantErr: "missing target endpoint resolver", + wantErr: "missing health checker", }, } diff --git a/lib/healthcheck/worker.go b/lib/healthcheck/worker.go index 
4ff9c678ca344..d224639fe0d3f 100644 --- a/lib/healthcheck/worker.go +++ b/lib/healthcheck/worker.go @@ -22,7 +22,6 @@ import ( "context" "fmt" "log/slog" - "net" "strings" "sync" "time" @@ -30,7 +29,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" - "golang.org/x/sync/errgroup" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -160,9 +158,6 @@ type worker struct { metricType string } -// dialFunc dials an address on the given network. -type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) - // GetTargetHealth returns the worker's target health. func (w *worker) GetTargetHealth() *types.TargetHealth { w.mu.RLock() @@ -238,7 +233,6 @@ func (w *worker) run() { func (w *worker) startHealthCheckInterval(ctx context.Context) { w.log.InfoContext(ctx, "Health checker started", "health_check_config", w.healthCheckCfg.name, - "protocol", w.healthCheckCfg.protocol, "interval", log.StringerAttr(w.healthCheckCfg.interval), "timeout", log.StringerAttr(w.healthCheckCfg.timeout), "healthy_threshold", w.healthCheckCfg.healthyThreshold, @@ -276,25 +270,31 @@ func (w *worker) nextHealthCheck() <-chan time.Time { // updates the worker's health check result history, and possibly updates the // target health. 
func (w *worker) checkHealth(ctx context.Context) { - initializing := w.lastResultCount == 0 - dialErr := w.dialEndpoints(ctx) + ctx, cancel := context.WithTimeout(ctx, w.healthCheckCfg.timeout) + defer cancel() + + // check target health + var curErr error + w.lastResolvedEndpoints, curErr = w.target.CheckHealth(ctx) + if ctx.Err() == context.Canceled { return } - if (dialErr == nil) == (w.lastResultErr == nil) { + initializing := w.lastResultCount == 0 + if (curErr == nil) == (w.lastResultErr == nil) { w.lastResultCount++ } else { // the passing/failing result streak has ended, so reset the count w.lastResultCount = 1 } - w.lastResultErr = dialErr + w.lastResultErr = curErr if w.lastResultErr != nil { w.log.DebugContext(ctx, "Failed health check", "error", w.lastResultErr, ) } - // update target health when we exactly reach the threshold or initialize + // update target health when we initialize or exactly reach the threshold if initializing || w.getThreshold(w.healthCheckCfg) == w.lastResultCount { w.setThresholdReached(ctx) } @@ -322,7 +322,6 @@ func (w *worker) updateHealthCheckConfig(ctx context.Context, newCfg *healthChec } w.log.DebugContext(ctx, "Updated health check config", "health_check_config", w.healthCheckCfg.name, - "protocol", w.healthCheckCfg.protocol, "interval", log.StringerAttr(w.healthCheckCfg.interval), "timeout", log.StringerAttr(w.healthCheckCfg.timeout), "healthy_threshold", w.healthCheckCfg.healthyThreshold, @@ -347,41 +346,6 @@ func (w *worker) updateHealthCheckConfig(ctx context.Context, newCfg *healthChec } } -func (w *worker) dialEndpoints(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, w.healthCheckCfg.timeout) - defer cancel() - endpoints, err := w.target.ResolverFn(ctx) - if err != nil { - return trace.Wrap(err, "failed to resolve target endpoints") - } - w.lastResolvedEndpoints = endpoints - switch len(endpoints) { - case 0: - return trace.NotFound("resolved zero target endpoints") - case 1: - return 
w.dialEndpoint(ctx, endpoints[0]) - default: - group, ctx := errgroup.WithContext(ctx) - group.SetLimit(10) - for _, ep := range endpoints { - group.Go(func() error { - return trace.Wrap(w.dialEndpoint(ctx, ep)) - }) - } - return group.Wait() - } -} - -func (w *worker) dialEndpoint(ctx context.Context, endpoint string) error { - conn, err := w.target.dialFn(ctx, "tcp", endpoint) - if err != nil { - return trace.Wrap(err) - } - // an error while closing the connection could indicate an RST packet from - // the endpoint - that's a health check failure. - return trace.Wrap(conn.Close()) -} - // getThreshold returns the appropriate threshold to compare against the last // result. func (w *worker) getThreshold(cfg *healthCheckConfig) uint32 { @@ -466,6 +430,7 @@ func (w *worker) setTargetHealthStatus(ctx context.Context, newStatus types.Targ now := w.clock.Now() w.targetHealth = types.TargetHealth{ Address: strings.Join(w.lastResolvedEndpoints, ","), + Protocol: string(w.target.GetProtocol()), Status: string(newStatus), TransitionTimestamp: &now, TransitionReason: string(reason), @@ -474,9 +439,6 @@ func (w *worker) setTargetHealthStatus(ctx context.Context, newStatus types.Targ if w.lastResultErr != nil { w.targetHealth.TransitionError = w.lastResultErr.Error() } - if w.healthCheckCfg != nil { - w.targetHealth.Protocol = string(w.healthCheckCfg.protocol) - } } // notifyInitStatusAvailableLocked closes the pending init status channel, if diff --git a/lib/healthcheck/worker_experimental_test.go b/lib/healthcheck/worker_experimental_test.go index c0860dc74637a..8c25fd4bad49b 100644 --- a/lib/healthcheck/worker_experimental_test.go +++ b/lib/healthcheck/worker_experimental_test.go @@ -72,15 +72,17 @@ func TestGetTargetHealth(t *testing.T) { worker, err := newWorker(t.Context(), workerConfig{ HealthCheckCfg: test.healthCheckConfig, Target: Target{ - GetResource: func() types.ResourceWithLabels { return nil }, - ResolverFn: func(ctx context.Context) ([]string, error) { - 
return []string{"localhost:1234"}, nil - }, - dialFn: func(ctx context.Context, network, addr string) (net.Conn, error) { - time.Sleep(5*time.Second - time.Nanosecond) - synctest.Wait() - return fakeConn{}, test.dialErr + HealthChecker: &TargetDialer{ + Resolver: func(ctx context.Context) ([]string, error) { + return []string{"localhost:1234"}, nil + }, + dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(5*time.Second - time.Nanosecond) + synctest.Wait() + return fakeConn{}, test.dialErr + }, }, + GetResource: func() types.ResourceWithLabels { return nil }, }, }) assert.NoError(t, err) diff --git a/lib/healthcheck/worker_test.go b/lib/healthcheck/worker_test.go index 73468d7016eab..b495894c017e3 100644 --- a/lib/healthcheck/worker_test.go +++ b/lib/healthcheck/worker_test.go @@ -20,14 +20,12 @@ package healthcheck import ( "context" - "log/slog" "net" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -37,7 +35,8 @@ import ( func Test_newUnstartedWorker(t *testing.T) { t.Parallel() ctx := context.Background() - listener, err := net.Listen("tcp", "localhost:0") + protocol := string(types.TargetHealthProtocolTCP) + listener, err := net.Listen(protocol, "localhost:0") require.NoError(t, err) t.Cleanup(func() { _ = listener.Close() }) @@ -63,15 +62,15 @@ func Test_newUnstartedWorker(t *testing.T) { cfg: workerConfig{ HealthCheckCfg: nil, Target: Target{ + HealthChecker: NewTargetDialer( + func(ctx context.Context) ([]string, error) { return nil, nil }, + ), GetResource: func() types.ResourceWithLabels { return db }, - ResolverFn: func(ctx context.Context) ([]string, error) { - return []string{db.GetURI()}, nil - }, }, }, wantHealth: types.TargetHealth{ Address: "", - Protocol: "", + Protocol: protocol, Status: string(types.TargetHealthStatusUnknown), TransitionReason: 
string(types.TargetHealthTransitionReasonDisabled), Message: "No health check config matches this resource", @@ -87,16 +86,16 @@ func Test_newUnstartedWorker(t *testing.T) { unhealthyThreshold: 10, }, Target: Target{ + HealthChecker: NewTargetDialer( + func(ctx context.Context) ([]string, error) { return nil, nil }, + ), GetResource: func() types.ResourceWithLabels { return db }, - ResolverFn: func(ctx context.Context) ([]string, error) { - return []string{db.GetURI()}, nil - }, }, getTargetHealthTimeout: time.Millisecond, }, wantHealth: types.TargetHealth{ Address: "", - Protocol: "", + Protocol: protocol, Status: string(types.TargetHealthStatusUnknown), TransitionReason: string(types.TargetHealthTransitionReasonInit), Message: "Health checker initialized", @@ -112,9 +111,9 @@ func Test_newUnstartedWorker(t *testing.T) { unhealthyThreshold: 10, }, Target: Target{ - ResolverFn: func(ctx context.Context) ([]string, error) { - return []string{db.GetURI()}, nil - }, + HealthChecker: NewTargetDialer( + func(ctx context.Context) ([]string, error) { return nil, nil }, + ), }, }, wantErr: "missing target resource getter", @@ -135,87 +134,3 @@ func Test_newUnstartedWorker(t *testing.T) { }) } } - -func Test_dialEndpoints(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - const healthyAddr = "healthy.com:123" - const unhealthyAddr = "unhealthy.com:123" - tests := []struct { - desc string - resolverFn EndpointsResolverFunc - wantErrContains string - }{ - { - desc: "resolver error", - resolverFn: func(ctx context.Context) ([]string, error) { - return nil, trace.Errorf("resolver error") - }, - wantErrContains: "resolver error", - }, - { - desc: "resolved zero addrs", - resolverFn: func(ctx context.Context) ([]string, error) { - return nil, nil - }, - wantErrContains: "resolved zero target endpoints", - }, - { - desc: "resolved one healthy addr", - resolverFn: func(ctx context.Context) ([]string, error) { - return 
[]string{healthyAddr}, nil - }, - }, - { - desc: "resolved one unhealthy addr", - resolverFn: func(ctx context.Context) ([]string, error) { - return []string{unhealthyAddr}, nil - }, - wantErrContains: "unhealthy addr", - }, - { - desc: "resolved multiple healthy addrs", - resolverFn: func(ctx context.Context) ([]string, error) { - return []string{healthyAddr, healthyAddr, healthyAddr}, nil - }, - }, - { - desc: "resolved a mix of healthy and unhealthy addrs", - resolverFn: func(ctx context.Context) ([]string, error) { - return []string{healthyAddr, unhealthyAddr, healthyAddr}, nil - }, - wantErrContains: "unhealthy addr", - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - w := &worker{ - healthCheckCfg: &healthCheckConfig{}, - log: slog.Default(), - target: Target{ - ResolverFn: test.resolverFn, - dialFn: func(ctx context.Context, network, addr string) (net.Conn, error) { - if addr == healthyAddr { - return fakeConn{}, nil - } - return nil, trace.Errorf("unhealthy addr") - }, - }, - } - err := w.dialEndpoints(ctx) - if test.wantErrContains != "" { - require.ErrorContains(t, err, test.wantErrContains) - return - } - require.NoError(t, err) - }) - } -} - -type fakeConn struct { - net.Conn -} - -func (fakeConn) Close() error { return nil } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 54a54daac172f..c738b43ae9f23 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -1481,12 +1481,12 @@ func (s *Server) startHealthCheck(ctx context.Context, db types.Database) error return trace.Wrap(err) } err = s.cfg.healthCheckManager.AddTarget(healthcheck.Target{ + HealthChecker: healthcheck.NewTargetDialer(resolver), GetResource: func() types.ResourceWithLabels { s.mu.RLock() defer s.mu.RUnlock() return s.copyDatabaseWithUpdatedLabelsLocked(db) }, - ResolverFn: resolver, }) return trace.Wrap(err) }