diff --git a/api/types/target_health.go b/api/types/target_health.go
index e6b36eb64e0e6..40a16b0f19dce 100644
--- a/api/types/target_health.go
+++ b/api/types/target_health.go
@@ -25,8 +25,10 @@ import (
type TargetHealthProtocol string
const (
- // TargetHealthProtocolTCP is a target health check protocol.
- TargetHealthProtocolTCP TargetHealthProtocol = "TCP"
+ // TargetHealthProtocolTCP is the TCP target health check protocol.
+ TargetHealthProtocolTCP TargetHealthProtocol = "tcp"
+ // TargetHealthProtocolHTTP is the HTTP target health check protocol.
+ TargetHealthProtocolHTTP TargetHealthProtocol = "http"
)
// TargetHealthStatus is a target resource's health status.
diff --git a/lib/healthcheck/config.go b/lib/healthcheck/config.go
index 1de67559a7a59..44a5df8903b51 100644
--- a/lib/healthcheck/config.go
+++ b/lib/healthcheck/config.go
@@ -34,7 +34,6 @@ import (
// [*healthcheckconfigv1.HealthCheckConfig] with defaults set.
type healthCheckConfig struct {
name string
- protocol types.TargetHealthProtocol
interval time.Duration
timeout time.Duration
healthyThreshold uint32
@@ -48,14 +47,11 @@ func newHealthCheckConfig(cfg *healthcheckconfigv1.HealthCheckConfig) *healthChe
spec := cfg.GetSpec()
match := spec.GetMatch()
return &healthCheckConfig{
- name: cfg.GetMetadata().GetName(),
- timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout),
- interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval),
- healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), defaults.HealthCheckHealthyThreshold),
- unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold),
- // we only support plain TCP health checks currently, but eventually we
- // may add support for other protocols such as TLS or HTTP
- protocol: types.TargetHealthProtocolTCP,
+ name: cfg.GetMetadata().GetName(),
+ timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout),
+ interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval),
+ healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), defaults.HealthCheckHealthyThreshold),
+ unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold),
databaseLabelMatchers: newLabelMatchers(match.GetDbLabelsExpression(), match.GetDbLabels()),
}
}
@@ -66,7 +62,6 @@ func (h *healthCheckConfig) equivalent(other *healthCheckConfig) bool {
return (h == nil && other == nil) ||
h != nil && other != nil &&
h.name == other.name &&
- h.protocol == other.protocol &&
h.interval == other.interval &&
h.timeout == other.timeout &&
h.healthyThreshold == other.healthyThreshold &&
diff --git a/lib/healthcheck/config_test.go b/lib/healthcheck/config_test.go
index d9edbe9910ee8..21837c40bc762 100644
--- a/lib/healthcheck/config_test.go
+++ b/lib/healthcheck/config_test.go
@@ -77,7 +77,6 @@ func Test_newHealthCheckConfig(t *testing.T) {
interval: time.Second * 43,
healthyThreshold: 7,
unhealthyThreshold: 8,
- protocol: types.TargetHealthProtocolTCP,
databaseLabelMatchers: types.LabelMatchers{
Labels: types.Labels{
"foo": utils.Strings{"bar", "baz"},
@@ -95,7 +94,6 @@ func Test_newHealthCheckConfig(t *testing.T) {
interval: defaults.HealthCheckInterval,
healthyThreshold: defaults.HealthCheckHealthyThreshold,
unhealthyThreshold: defaults.HealthCheckUnhealthyThreshold,
- protocol: types.TargetHealthProtocolTCP,
databaseLabelMatchers: types.LabelMatchers{
Labels: types.Labels{},
Expression: `labels["*"] == "*"`,
@@ -139,7 +137,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
desc: "all fields equal",
a: &healthCheckConfig{
name: "test",
- protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
@@ -147,7 +144,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
},
b: &healthCheckConfig{
name: "test",
- protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
@@ -159,7 +155,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
desc: "all fields equal ignoring labels",
a: &healthCheckConfig{
name: "test",
- protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
@@ -168,7 +163,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
},
b: &healthCheckConfig{
name: "test",
- protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
@@ -187,16 +181,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
},
want: false,
},
- {
- desc: "different protocol",
- a: &healthCheckConfig{
- protocol: "http",
- },
- b: &healthCheckConfig{
- protocol: "tcp",
- },
- want: false,
- },
{
desc: "different interval",
a: &healthCheckConfig{
diff --git a/lib/healthcheck/manager_test.go b/lib/healthcheck/manager_test.go
index af8c22f1f50b2..ae6362fdb126e 100644
--- a/lib/healthcheck/manager_test.go
+++ b/lib/healthcheck/manager_test.go
@@ -181,17 +181,19 @@ func TestManager(t *testing.T) {
var endpointMu sync.Mutex
prodDialer := fakeDialer{}
err = mgr.AddTarget(Target{
+ HealthChecker: &TargetDialer{
+ Resolver: func(ctx context.Context) ([]string, error) {
+ endpointMu.Lock()
+ defer endpointMu.Unlock()
+ return []string{prodDB.GetURI()}, nil
+ },
+ dial: prodDialer.DialContext,
+ },
GetResource: func() types.ResourceWithLabels {
endpointMu.Lock()
defer endpointMu.Unlock()
return prodDB
},
- ResolverFn: func(ctx context.Context) ([]string, error) {
- endpointMu.Lock()
- defer endpointMu.Unlock()
- return []string{prodDB.GetURI()}, nil
- },
- dialFn: prodDialer.DialContext,
onHealthCheck: func(lastResultErr error) {
eventsCh <- lastResultTestEvent(prodDB.GetName(), lastResultErr)
},
@@ -205,18 +207,20 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
devDialer := fakeDialer{}
- err = mgr.AddTarget(Target{
+ devTarget := Target{
+ HealthChecker: &TargetDialer{
+ Resolver: func(ctx context.Context) ([]string, error) {
+ endpointMu.Lock()
+ defer endpointMu.Unlock()
+ return []string{devDB.GetURI()}, nil
+ },
+ dial: devDialer.DialContext,
+ },
GetResource: func() types.ResourceWithLabels {
endpointMu.Lock()
defer endpointMu.Unlock()
return devDB
},
- ResolverFn: func(ctx context.Context) ([]string, error) {
- endpointMu.Lock()
- defer endpointMu.Unlock()
- return []string{devDB.GetURI()}, nil
- },
- dialFn: devDialer.DialContext,
onHealthCheck: func(lastResultErr error) {
eventsCh <- lastResultTestEvent(devDB.GetName(), lastResultErr)
},
@@ -226,21 +230,27 @@ func TestManager(t *testing.T) {
onClose: func() {
eventsCh <- closedTestEvent(devDB.GetName())
},
- })
+ }
+ err = mgr.AddTarget(devTarget)
require.NoError(t, err)
t.Run("duplicate target is an error", func(t *testing.T) {
- err = mgr.AddTarget(Target{
- GetResource: func() types.ResourceWithLabels { return devDB },
- ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil },
- })
+ err = mgr.AddTarget(devTarget)
require.Error(t, err)
require.IsType(t, trace.AlreadyExists(""), err)
})
t.Run("unsupported target resource is an error", func(t *testing.T) {
err = mgr.AddTarget(Target{
- GetResource: func() types.ResourceWithLabels { return &fakeResource{kind: "node"} },
- ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil },
+ HealthChecker: &TargetDialer{
+ Resolver: func(ctx context.Context) ([]string, error) {
+ endpointMu.Lock()
+ defer endpointMu.Unlock()
+ return nil, nil
+ },
+ },
+ GetResource: func() types.ResourceWithLabels {
+ return &fakeResource{kind: "node"}
+ },
})
require.Error(t, err)
require.IsType(t, trace.BadParameter(""), err)
diff --git a/lib/healthcheck/net.go b/lib/healthcheck/net.go
new file mode 100644
index 0000000000000..557eff587d78a
--- /dev/null
+++ b/lib/healthcheck/net.go
@@ -0,0 +1,97 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package healthcheck
+
+import (
+ "context"
+ "net"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sync/errgroup"
+
+ "github.com/gravitational/teleport/api/types"
+)
+
+// EndpointsResolverFunc is a callback func that resolves and returns the endpoints (addresses to dial) for a target.
+type EndpointsResolverFunc func(ctx context.Context) ([]string, error)
+
+// dialFunc dials an address on the given network; it has the same signature as (*net.Dialer).DialContext.
+type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
+
+// TargetDialer is a health checker which resolves a target's endpoint(s) and dials them over TCP.
+type TargetDialer struct {
+ // Resolver resolves the target endpoint(s); it is called on every health check.
+ Resolver EndpointsResolverFunc
+ // dial is used to dial network connections; NewTargetDialer sets it to a net.Dialer's DialContext and tests substitute a fake.
+ dial dialFunc
+}
+
+// NewTargetDialer returns a TargetDialer that resolves endpoints with the
+// given resolver and dials them using a default net.Dialer.
+func NewTargetDialer(resolver EndpointsResolverFunc) *TargetDialer {
+	d := new(TargetDialer)
+	d.Resolver, d.dial = resolver, defaultDialer().DialContext
+	return d
+}
+
+// GetProtocol returns the network protocol used for checking health, which is always TCP for a TargetDialer.
+func (t *TargetDialer) GetProtocol() types.TargetHealthProtocol {
+ return types.TargetHealthProtocolTCP
+}
+
+// CheckHealth checks the health of the target resource by resolving its endpoint(s) and dialing each one; it returns the resolved endpoints (nil when resolution fails or yields none) and any resolve or dial error.
+func (t *TargetDialer) CheckHealth(ctx context.Context) ([]string, error) {
+ return t.dialEndpoints(ctx)
+}
+
+// dialEndpoints resolves the target endpoint(s) and dials each of them.
+//
+// It returns the resolved endpoints together with the health check result:
+// a nil error means every endpoint was dialed successfully. When the
+// resolver is missing, fails, or resolves no endpoints, it returns nil
+// endpoints and an error. Multiple endpoints are dialed concurrently, at
+// most 10 at a time, and the first dial failure cancels the context shared
+// by the remaining dials.
+func (t *TargetDialer) dialEndpoints(ctx context.Context) ([]string, error) {
+	// a TargetDialer without a resolver (nil receiver, zero value, or
+	// NewTargetDialer(nil)) must fail the health check instead of panicking.
+	if t == nil || t.Resolver == nil {
+		return nil, trace.BadParameter("missing target endpoint resolver")
+	}
+	endpoints, err := t.Resolver(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err, "failed to resolve target endpoints")
+	}
+	switch len(endpoints) {
+	case 0:
+		return nil, trace.NotFound("resolved zero target endpoints")
+	case 1:
+		// single endpoint: dial inline, no goroutine needed.
+		return endpoints, t.dialEndpoint(ctx, endpoints[0])
+	default:
+		// the group context is canceled as soon as one dial fails, so the
+		// remaining dials are abandoned early.
+		group, ctx := errgroup.WithContext(ctx)
+		group.SetLimit(10)
+		for _, ep := range endpoints {
+			group.Go(func() error {
+				return trace.Wrap(t.dialEndpoint(ctx, ep))
+			})
+		}
+		// Wait returns the first non-nil dial error, if any.
+		return endpoints, group.Wait()
+	}
+}
+
+// dialEndpoint dials a single endpoint over TCP and immediately closes the
+// connection. A nil error means the endpoint accepted the connection and
+// the connection was closed without error.
+func (t *TargetDialer) dialEndpoint(ctx context.Context, endpoint string) error {
+	// the dial field is unexported, so a TargetDialer built as a struct
+	// literal outside this package cannot set it; fall back to the default
+	// dialer instead of calling a nil func.
+	dial := t.dial
+	if dial == nil {
+		dial = defaultDialer().DialContext
+	}
+	conn, err := dial(ctx, "tcp", endpoint)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	// an error while closing the connection could indicate an RST packet from
+	// the endpoint - that's a health check failure.
+	return trace.Wrap(conn.Close())
+}
+
+func defaultDialer() *net.Dialer { // defaultDialer returns the net.Dialer that NewTargetDialer uses for health check connections.
+ return &net.Dialer{} // zero value: default dial options, no dialer-level timeout set
+}
diff --git a/lib/healthcheck/net_test.go b/lib/healthcheck/net_test.go
new file mode 100644
index 0000000000000..b5880f5dc2c6c
--- /dev/null
+++ b/lib/healthcheck/net_test.go
@@ -0,0 +1,117 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package healthcheck
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+)
+
+// TestTargetDialer_dialEndpoints verifies how dialEndpoints reports resolver
+// failures, empty resolutions, and healthy/unhealthy endpoints, using a fake
+// dial func so that no real network connections are made.
+func TestTargetDialer_dialEndpoints(t *testing.T) {
+	t.Parallel()
+
+	const (
+		healthyAddr   = "healthy.com:123"
+		unhealthyAddr = "unhealthy.com:123"
+	)
+
+	// resolveTo builds a resolver that always yields the given endpoints.
+	resolveTo := func(endpoints ...string) EndpointsResolverFunc {
+		return func(context.Context) ([]string, error) {
+			return endpoints, nil
+		}
+	}
+	// fakeDial only "connects" to healthyAddr; every other address fails.
+	fakeDial := func(_ context.Context, _, addr string) (net.Conn, error) {
+		if addr != healthyAddr {
+			return nil, trace.Errorf("unhealthy addr")
+		}
+		return fakeConn{}, nil
+	}
+
+	cases := []struct {
+		desc            string
+		resolver        EndpointsResolverFunc
+		wantErrContains string
+	}{
+		{
+			desc: "resolver error",
+			resolver: func(context.Context) ([]string, error) {
+				return nil, trace.Errorf("resolver error")
+			},
+			wantErrContains: "resolver error",
+		},
+		{
+			desc:            "resolved zero addrs",
+			resolver:        resolveTo(),
+			wantErrContains: "resolved zero target endpoints",
+		},
+		{
+			desc:     "resolved one healthy addr",
+			resolver: resolveTo(healthyAddr),
+		},
+		{
+			desc:            "resolved one unhealthy addr",
+			resolver:        resolveTo(unhealthyAddr),
+			wantErrContains: "unhealthy addr",
+		},
+		{
+			desc:     "resolved multiple healthy addrs",
+			resolver: resolveTo(healthyAddr, healthyAddr, healthyAddr),
+		},
+		{
+			desc:            "resolved a mix of healthy and unhealthy addrs",
+			resolver:        resolveTo(healthyAddr, unhealthyAddr, healthyAddr),
+			wantErrContains: "unhealthy addr",
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.desc, func(t *testing.T) {
+			dialer := &TargetDialer{Resolver: tc.resolver, dial: fakeDial}
+			_, err := dialer.dialEndpoints(t.Context())
+			if tc.wantErrContains == "" {
+				require.NoError(t, err)
+				return
+			}
+			require.ErrorContains(t, err, tc.wantErrContains)
+		})
+	}
+}
+
+type fakeConn struct { // fakeConn is a stub connection returned by the fake dial funcs in tests.
+ net.Conn // embedded interface, left nil by the tests; only Close is overridden below
+}
+
+func (fakeConn) Close() error { return nil } // Close always succeeds, so closing a fake connection never fails a health check.
+
+type fakeResource struct { // fakeResource is a test resource whose kind is configurable.
+ kind string // value returned by GetKind
+ types.ResourceWithLabels // embedded interface, left unset where the tests construct it; only GetKind is overridden
+}
+
+func (r *fakeResource) GetKind() string { // GetKind returns the configured kind rather than delegating to the embedded resource.
+ return r.kind
+}
diff --git a/lib/healthcheck/target.go b/lib/healthcheck/target.go
index b9c9ca0851c73..f785848566e65 100644
--- a/lib/healthcheck/target.go
+++ b/lib/healthcheck/target.go
@@ -20,30 +20,29 @@ package healthcheck
import (
"context"
- "net"
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
)
-// EndpointsResolverFunc is callback func that returns endpoints for a target.
-type EndpointsResolverFunc func(ctx context.Context) ([]string, error)
-
-// OnHealthChangeFunc is a func called on each health change.
-type OnHealthChangeFunc func(oldHealth, newHealth types.TargetHealth)
+// HealthChecker is a resource which provides health checks; it is embedded in Target and must be non-nil.
+type HealthChecker interface {
+ // CheckHealth checks the health of a target resource and returns the endpoints that were checked; a non-nil error is a failed health check.
+ CheckHealth(ctx context.Context) ([]string, error)
+ // GetProtocol returns the network protocol used for checking health.
+ GetProtocol() types.TargetHealthProtocol
+}
// Target is a health check target.
type Target struct {
+ // HealthChecker checks the resource's health.
+ HealthChecker
// GetResource gets a copy of the target resource with updated labels.
GetResource func() types.ResourceWithLabels
- // ResolverFn resolves the target endpoint(s).
- ResolverFn EndpointsResolverFunc
// -- test fields below --
- // dialFn used to mock dialing in tests
- dialFn dialFunc
// onHealthCheck is called after each health check.
onHealthCheck func(lastResultErr error)
// onConfigUpdate is called after each config update.
@@ -56,15 +55,8 @@ func (t *Target) checkAndSetDefaults() error {
if t.GetResource == nil {
return trace.BadParameter("missing target resource getter")
}
- if t.ResolverFn == nil {
- return trace.BadParameter("missing target endpoint resolver")
- }
- if t.dialFn == nil {
- t.dialFn = defaultDialer().DialContext
+ if t.HealthChecker == nil {
+ return trace.BadParameter("missing health checker")
}
return nil
}
-
-func defaultDialer() *net.Dialer {
- return &net.Dialer{}
-}
diff --git a/lib/healthcheck/target_test.go b/lib/healthcheck/target_test.go
index 0d4c3c5ec4ba0..f5bcedc2ae9cf 100644
--- a/lib/healthcheck/target_test.go
+++ b/lib/healthcheck/target_test.go
@@ -28,15 +28,6 @@ import (
"github.com/gravitational/teleport/api/types"
)
-type fakeResource struct {
- kind string
- types.ResourceWithLabels
-}
-
-func (r *fakeResource) GetKind() string {
- return r.kind
-}
-
func TestTarget_checkAndSetDefaults(t *testing.T) {
tests := []struct {
name string
@@ -46,23 +37,27 @@ func TestTarget_checkAndSetDefaults(t *testing.T) {
{
name: "valid target",
target: &Target{
+ HealthChecker: NewTargetDialer(
+ func(ctx context.Context) ([]string, error) { return []string{"127.0.0.1"}, nil },
+ ),
GetResource: func() types.ResourceWithLabels { return &fakeResource{kind: types.KindDatabase} },
- ResolverFn: func(ctx context.Context) ([]string, error) { return []string{"127.0.0.1"}, nil },
},
},
{
name: "missing resource getter",
target: &Target{
- ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil },
+ HealthChecker: NewTargetDialer(
+ func(ctx context.Context) ([]string, error) { return nil, nil },
+ ),
},
wantErr: "missing target resource getter",
},
{
- name: "missing resolver",
+ name: "missing health checker",
target: &Target{
GetResource: func() types.ResourceWithLabels { return nil },
},
- wantErr: "missing target endpoint resolver",
+ wantErr: "missing health checker",
},
}
diff --git a/lib/healthcheck/worker.go b/lib/healthcheck/worker.go
index 4ff9c678ca344..d224639fe0d3f 100644
--- a/lib/healthcheck/worker.go
+++ b/lib/healthcheck/worker.go
@@ -22,7 +22,6 @@ import (
"context"
"fmt"
"log/slog"
- "net"
"strings"
"sync"
"time"
@@ -30,7 +29,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
- "golang.org/x/sync/errgroup"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
@@ -160,9 +158,6 @@ type worker struct {
metricType string
}
-// dialFunc dials an address on the given network.
-type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
-
// GetTargetHealth returns the worker's target health.
func (w *worker) GetTargetHealth() *types.TargetHealth {
w.mu.RLock()
@@ -238,7 +233,6 @@ func (w *worker) run() {
func (w *worker) startHealthCheckInterval(ctx context.Context) {
w.log.InfoContext(ctx, "Health checker started",
"health_check_config", w.healthCheckCfg.name,
- "protocol", w.healthCheckCfg.protocol,
"interval", log.StringerAttr(w.healthCheckCfg.interval),
"timeout", log.StringerAttr(w.healthCheckCfg.timeout),
"healthy_threshold", w.healthCheckCfg.healthyThreshold,
@@ -276,25 +270,31 @@ func (w *worker) nextHealthCheck() <-chan time.Time {
// updates the worker's health check result history, and possibly updates the
// target health.
func (w *worker) checkHealth(ctx context.Context) {
- initializing := w.lastResultCount == 0
- dialErr := w.dialEndpoints(ctx)
+ ctx, cancel := context.WithTimeout(ctx, w.healthCheckCfg.timeout)
+ defer cancel()
+
+ // check target health
+ var curErr error
+ w.lastResolvedEndpoints, curErr = w.target.CheckHealth(ctx)
+
if ctx.Err() == context.Canceled {
return
}
- if (dialErr == nil) == (w.lastResultErr == nil) {
+ initializing := w.lastResultCount == 0
+ if (curErr == nil) == (w.lastResultErr == nil) {
w.lastResultCount++
} else {
// the passing/failing result streak has ended, so reset the count
w.lastResultCount = 1
}
- w.lastResultErr = dialErr
+ w.lastResultErr = curErr
if w.lastResultErr != nil {
w.log.DebugContext(ctx, "Failed health check",
"error", w.lastResultErr,
)
}
- // update target health when we exactly reach the threshold or initialize
+ // update target health when we initialize or exactly reach the threshold
if initializing || w.getThreshold(w.healthCheckCfg) == w.lastResultCount {
w.setThresholdReached(ctx)
}
@@ -322,7 +322,6 @@ func (w *worker) updateHealthCheckConfig(ctx context.Context, newCfg *healthChec
}
w.log.DebugContext(ctx, "Updated health check config",
"health_check_config", w.healthCheckCfg.name,
- "protocol", w.healthCheckCfg.protocol,
"interval", log.StringerAttr(w.healthCheckCfg.interval),
"timeout", log.StringerAttr(w.healthCheckCfg.timeout),
"healthy_threshold", w.healthCheckCfg.healthyThreshold,
@@ -347,41 +346,6 @@ func (w *worker) updateHealthCheckConfig(ctx context.Context, newCfg *healthChec
}
}
-func (w *worker) dialEndpoints(ctx context.Context) error {
- ctx, cancel := context.WithTimeout(ctx, w.healthCheckCfg.timeout)
- defer cancel()
- endpoints, err := w.target.ResolverFn(ctx)
- if err != nil {
- return trace.Wrap(err, "failed to resolve target endpoints")
- }
- w.lastResolvedEndpoints = endpoints
- switch len(endpoints) {
- case 0:
- return trace.NotFound("resolved zero target endpoints")
- case 1:
- return w.dialEndpoint(ctx, endpoints[0])
- default:
- group, ctx := errgroup.WithContext(ctx)
- group.SetLimit(10)
- for _, ep := range endpoints {
- group.Go(func() error {
- return trace.Wrap(w.dialEndpoint(ctx, ep))
- })
- }
- return group.Wait()
- }
-}
-
-func (w *worker) dialEndpoint(ctx context.Context, endpoint string) error {
- conn, err := w.target.dialFn(ctx, "tcp", endpoint)
- if err != nil {
- return trace.Wrap(err)
- }
- // an error while closing the connection could indicate an RST packet from
- // the endpoint - that's a health check failure.
- return trace.Wrap(conn.Close())
-}
-
// getThreshold returns the appropriate threshold to compare against the last
// result.
func (w *worker) getThreshold(cfg *healthCheckConfig) uint32 {
@@ -466,6 +430,7 @@ func (w *worker) setTargetHealthStatus(ctx context.Context, newStatus types.Targ
now := w.clock.Now()
w.targetHealth = types.TargetHealth{
Address: strings.Join(w.lastResolvedEndpoints, ","),
+ Protocol: string(w.target.GetProtocol()),
Status: string(newStatus),
TransitionTimestamp: &now,
TransitionReason: string(reason),
@@ -474,9 +439,6 @@ func (w *worker) setTargetHealthStatus(ctx context.Context, newStatus types.Targ
if w.lastResultErr != nil {
w.targetHealth.TransitionError = w.lastResultErr.Error()
}
- if w.healthCheckCfg != nil {
- w.targetHealth.Protocol = string(w.healthCheckCfg.protocol)
- }
}
// notifyInitStatusAvailableLocked closes the pending init status channel, if
diff --git a/lib/healthcheck/worker_experimental_test.go b/lib/healthcheck/worker_experimental_test.go
index c0860dc74637a..8c25fd4bad49b 100644
--- a/lib/healthcheck/worker_experimental_test.go
+++ b/lib/healthcheck/worker_experimental_test.go
@@ -72,15 +72,17 @@ func TestGetTargetHealth(t *testing.T) {
worker, err := newWorker(t.Context(), workerConfig{
HealthCheckCfg: test.healthCheckConfig,
Target: Target{
- GetResource: func() types.ResourceWithLabels { return nil },
- ResolverFn: func(ctx context.Context) ([]string, error) {
- return []string{"localhost:1234"}, nil
- },
- dialFn: func(ctx context.Context, network, addr string) (net.Conn, error) {
- time.Sleep(5*time.Second - time.Nanosecond)
- synctest.Wait()
- return fakeConn{}, test.dialErr
+ HealthChecker: &TargetDialer{
+ Resolver: func(ctx context.Context) ([]string, error) {
+ return []string{"localhost:1234"}, nil
+ },
+ dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ time.Sleep(5*time.Second - time.Nanosecond)
+ synctest.Wait()
+ return fakeConn{}, test.dialErr
+ },
},
+ GetResource: func() types.ResourceWithLabels { return nil },
},
})
assert.NoError(t, err)
diff --git a/lib/healthcheck/worker_test.go b/lib/healthcheck/worker_test.go
index 73468d7016eab..b495894c017e3 100644
--- a/lib/healthcheck/worker_test.go
+++ b/lib/healthcheck/worker_test.go
@@ -20,14 +20,12 @@ package healthcheck
import (
"context"
- "log/slog"
"net"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
@@ -37,7 +35,8 @@ import (
func Test_newUnstartedWorker(t *testing.T) {
t.Parallel()
ctx := context.Background()
- listener, err := net.Listen("tcp", "localhost:0")
+ protocol := string(types.TargetHealthProtocolTCP)
+ listener, err := net.Listen(protocol, "localhost:0")
require.NoError(t, err)
t.Cleanup(func() { _ = listener.Close() })
@@ -63,15 +62,15 @@ func Test_newUnstartedWorker(t *testing.T) {
cfg: workerConfig{
HealthCheckCfg: nil,
Target: Target{
+ HealthChecker: NewTargetDialer(
+ func(ctx context.Context) ([]string, error) { return nil, nil },
+ ),
GetResource: func() types.ResourceWithLabels { return db },
- ResolverFn: func(ctx context.Context) ([]string, error) {
- return []string{db.GetURI()}, nil
- },
},
},
wantHealth: types.TargetHealth{
Address: "",
- Protocol: "",
+ Protocol: protocol,
Status: string(types.TargetHealthStatusUnknown),
TransitionReason: string(types.TargetHealthTransitionReasonDisabled),
Message: "No health check config matches this resource",
@@ -87,16 +86,16 @@ func Test_newUnstartedWorker(t *testing.T) {
unhealthyThreshold: 10,
},
Target: Target{
+ HealthChecker: NewTargetDialer(
+ func(ctx context.Context) ([]string, error) { return nil, nil },
+ ),
GetResource: func() types.ResourceWithLabels { return db },
- ResolverFn: func(ctx context.Context) ([]string, error) {
- return []string{db.GetURI()}, nil
- },
},
getTargetHealthTimeout: time.Millisecond,
},
wantHealth: types.TargetHealth{
Address: "",
- Protocol: "",
+ Protocol: protocol,
Status: string(types.TargetHealthStatusUnknown),
TransitionReason: string(types.TargetHealthTransitionReasonInit),
Message: "Health checker initialized",
@@ -112,9 +111,9 @@ func Test_newUnstartedWorker(t *testing.T) {
unhealthyThreshold: 10,
},
Target: Target{
- ResolverFn: func(ctx context.Context) ([]string, error) {
- return []string{db.GetURI()}, nil
- },
+ HealthChecker: NewTargetDialer(
+ func(ctx context.Context) ([]string, error) { return nil, nil },
+ ),
},
},
wantErr: "missing target resource getter",
@@ -135,87 +134,3 @@ func Test_newUnstartedWorker(t *testing.T) {
})
}
}
-
-func Test_dialEndpoints(t *testing.T) {
- t.Parallel()
- ctx, cancel := context.WithCancel(context.Background())
- t.Cleanup(cancel)
-
- const healthyAddr = "healthy.com:123"
- const unhealthyAddr = "unhealthy.com:123"
- tests := []struct {
- desc string
- resolverFn EndpointsResolverFunc
- wantErrContains string
- }{
- {
- desc: "resolver error",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return nil, trace.Errorf("resolver error")
- },
- wantErrContains: "resolver error",
- },
- {
- desc: "resolved zero addrs",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return nil, nil
- },
- wantErrContains: "resolved zero target endpoints",
- },
- {
- desc: "resolved one healthy addr",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return []string{healthyAddr}, nil
- },
- },
- {
- desc: "resolved one unhealthy addr",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return []string{unhealthyAddr}, nil
- },
- wantErrContains: "unhealthy addr",
- },
- {
- desc: "resolved multiple healthy addrs",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return []string{healthyAddr, healthyAddr, healthyAddr}, nil
- },
- },
- {
- desc: "resolved a mix of healthy and unhealthy addrs",
- resolverFn: func(ctx context.Context) ([]string, error) {
- return []string{healthyAddr, unhealthyAddr, healthyAddr}, nil
- },
- wantErrContains: "unhealthy addr",
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- w := &worker{
- healthCheckCfg: &healthCheckConfig{},
- log: slog.Default(),
- target: Target{
- ResolverFn: test.resolverFn,
- dialFn: func(ctx context.Context, network, addr string) (net.Conn, error) {
- if addr == healthyAddr {
- return fakeConn{}, nil
- }
- return nil, trace.Errorf("unhealthy addr")
- },
- },
- }
- err := w.dialEndpoints(ctx)
- if test.wantErrContains != "" {
- require.ErrorContains(t, err, test.wantErrContains)
- return
- }
- require.NoError(t, err)
- })
- }
-}
-
-type fakeConn struct {
- net.Conn
-}
-
-func (fakeConn) Close() error { return nil }
diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go
index 54a54daac172f..c738b43ae9f23 100644
--- a/lib/srv/db/server.go
+++ b/lib/srv/db/server.go
@@ -1481,12 +1481,12 @@ func (s *Server) startHealthCheck(ctx context.Context, db types.Database) error
return trace.Wrap(err)
}
err = s.cfg.healthCheckManager.AddTarget(healthcheck.Target{
+ HealthChecker: healthcheck.NewTargetDialer(resolver),
GetResource: func() types.ResourceWithLabels {
s.mu.RLock()
defer s.mu.RUnlock()
return s.copyDatabaseWithUpdatedLabelsLocked(db)
},
- ResolverFn: resolver,
})
return trace.Wrap(err)
}