Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions api/types/target_health.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
type TargetHealthProtocol string

const (
// TargetHealthProtocolTCP is a target health check protocol.
TargetHealthProtocolTCP TargetHealthProtocol = "TCP"
// TargetHealthProtocolTCP is the TCP target health check protocol.
TargetHealthProtocolTCP TargetHealthProtocol = "tcp"
Comment thread
rana marked this conversation as resolved.
// TargetHealthProtocolHTTP is the HTTP target health check protocol.
TargetHealthProtocolHTTP TargetHealthProtocol = "http"
)

// TargetHealthStatus is a target resource's health status.
Expand Down
15 changes: 5 additions & 10 deletions lib/healthcheck/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
// [*healthcheckconfigv1.HealthCheckConfig] with defaults set.
type healthCheckConfig struct {
name string
protocol types.TargetHealthProtocol
interval time.Duration
timeout time.Duration
healthyThreshold uint32
Expand All @@ -48,14 +47,11 @@ func newHealthCheckConfig(cfg *healthcheckconfigv1.HealthCheckConfig) *healthChe
spec := cfg.GetSpec()
match := spec.GetMatch()
return &healthCheckConfig{
name: cfg.GetMetadata().GetName(),
timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout),
interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval),
healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), defaults.HealthCheckHealthyThreshold),
unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold),
// we only support plain TCP health checks currently, but eventually we
// may add support for other protocols such as TLS or HTTP
protocol: types.TargetHealthProtocolTCP,
name: cfg.GetMetadata().GetName(),
timeout: cmp.Or(spec.GetTimeout().AsDuration(), defaults.HealthCheckTimeout),
interval: cmp.Or(spec.GetInterval().AsDuration(), defaults.HealthCheckInterval),
healthyThreshold: cmp.Or(spec.GetHealthyThreshold(), defaults.HealthCheckHealthyThreshold),
unhealthyThreshold: cmp.Or(spec.GetUnhealthyThreshold(), defaults.HealthCheckUnhealthyThreshold),
databaseLabelMatchers: newLabelMatchers(match.GetDbLabelsExpression(), match.GetDbLabels()),
}
}
Expand All @@ -66,7 +62,6 @@ func (h *healthCheckConfig) equivalent(other *healthCheckConfig) bool {
return (h == nil && other == nil) ||
h != nil && other != nil &&
h.name == other.name &&
h.protocol == other.protocol &&
h.interval == other.interval &&
h.timeout == other.timeout &&
h.healthyThreshold == other.healthyThreshold &&
Expand Down
16 changes: 0 additions & 16 deletions lib/healthcheck/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func Test_newHealthCheckConfig(t *testing.T) {
interval: time.Second * 43,
healthyThreshold: 7,
unhealthyThreshold: 8,
protocol: types.TargetHealthProtocolTCP,
databaseLabelMatchers: types.LabelMatchers{
Labels: types.Labels{
"foo": utils.Strings{"bar", "baz"},
Expand All @@ -95,7 +94,6 @@ func Test_newHealthCheckConfig(t *testing.T) {
interval: defaults.HealthCheckInterval,
healthyThreshold: defaults.HealthCheckHealthyThreshold,
unhealthyThreshold: defaults.HealthCheckUnhealthyThreshold,
protocol: types.TargetHealthProtocolTCP,
databaseLabelMatchers: types.LabelMatchers{
Labels: types.Labels{},
Expression: `labels["*"] == "*"`,
Expand Down Expand Up @@ -139,15 +137,13 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
desc: "all fields equal",
a: &healthCheckConfig{
name: "test",
protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
unhealthyThreshold: 5,
},
b: &healthCheckConfig{
name: "test",
protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
Expand All @@ -159,7 +155,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
desc: "all fields equal ignoring labels",
a: &healthCheckConfig{
name: "test",
protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
Expand All @@ -168,7 +163,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
},
b: &healthCheckConfig{
name: "test",
protocol: "http",
interval: time.Second,
timeout: 500 * time.Millisecond,
healthyThreshold: 3,
Expand All @@ -187,16 +181,6 @@ func TestHealthCheckConfig_equivalent(t *testing.T) {
},
want: false,
},
{
desc: "different protocol",
a: &healthCheckConfig{
protocol: "http",
},
b: &healthCheckConfig{
protocol: "tcp",
},
want: false,
},
{
desc: "different interval",
a: &healthCheckConfig{
Expand Down
50 changes: 30 additions & 20 deletions lib/healthcheck/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,19 @@ func TestManager(t *testing.T) {
var endpointMu sync.Mutex
prodDialer := fakeDialer{}
err = mgr.AddTarget(Target{
HealthChecker: &TargetDialer{
Resolver: func(ctx context.Context) ([]string, error) {
endpointMu.Lock()
defer endpointMu.Unlock()
return []string{prodDB.GetURI()}, nil
},
dial: prodDialer.DialContext,
},
GetResource: func() types.ResourceWithLabels {
endpointMu.Lock()
defer endpointMu.Unlock()
return prodDB
},
ResolverFn: func(ctx context.Context) ([]string, error) {
endpointMu.Lock()
defer endpointMu.Unlock()
return []string{prodDB.GetURI()}, nil
},
dialFn: prodDialer.DialContext,
onHealthCheck: func(lastResultErr error) {
eventsCh <- lastResultTestEvent(prodDB.GetName(), lastResultErr)
},
Expand All @@ -205,18 +207,20 @@ func TestManager(t *testing.T) {
require.NoError(t, err)

devDialer := fakeDialer{}
err = mgr.AddTarget(Target{
devTarget := Target{
HealthChecker: &TargetDialer{
Resolver: func(ctx context.Context) ([]string, error) {
endpointMu.Lock()
defer endpointMu.Unlock()
return []string{devDB.GetURI()}, nil
},
dial: devDialer.DialContext,
},
GetResource: func() types.ResourceWithLabels {
endpointMu.Lock()
defer endpointMu.Unlock()
return devDB
},
ResolverFn: func(ctx context.Context) ([]string, error) {
endpointMu.Lock()
defer endpointMu.Unlock()
return []string{devDB.GetURI()}, nil
},
dialFn: devDialer.DialContext,
onHealthCheck: func(lastResultErr error) {
eventsCh <- lastResultTestEvent(devDB.GetName(), lastResultErr)
},
Expand All @@ -226,21 +230,27 @@ func TestManager(t *testing.T) {
onClose: func() {
eventsCh <- closedTestEvent(devDB.GetName())
},
})
}
err = mgr.AddTarget(devTarget)
require.NoError(t, err)

t.Run("duplicate target is an error", func(t *testing.T) {
err = mgr.AddTarget(Target{
GetResource: func() types.ResourceWithLabels { return devDB },
ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil },
})
err = mgr.AddTarget(devTarget)
require.Error(t, err)
require.IsType(t, trace.AlreadyExists(""), err)
})
t.Run("unsupported target resource is an error", func(t *testing.T) {
err = mgr.AddTarget(Target{
GetResource: func() types.ResourceWithLabels { return &fakeResource{kind: "node"} },
ResolverFn: func(ctx context.Context) ([]string, error) { return nil, nil },
HealthChecker: &TargetDialer{
Resolver: func(ctx context.Context) ([]string, error) {
endpointMu.Lock()
defer endpointMu.Unlock()
return nil, nil
},
},
GetResource: func() types.ResourceWithLabels {
return &fakeResource{kind: "node"}
},
})
require.Error(t, err)
require.IsType(t, trace.BadParameter(""), err)
Expand Down
97 changes: 97 additions & 0 deletions lib/healthcheck/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package healthcheck

import (
"context"
"net"

"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/api/types"
)

// EndpointsResolverFunc is a callback func that returns the endpoints for a
// target, or an error if they could not be resolved. TargetDialer passes each
// returned endpoint as the address of a "tcp" dial.
type EndpointsResolverFunc func(ctx context.Context) ([]string, error)

// dialFunc dials an address on the given network and returns the resulting
// connection. NewTargetDialer supplies (*net.Dialer).DialContext as the
// implementation.
type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// TargetDialer is a health check target which uses a net.Dialer.
// A health check resolves the target endpoint(s) with Resolver, dials each one
// over TCP and closes the resulting connection.
type TargetDialer struct {
// Resolver resolves the target endpoint(s).
Resolver EndpointsResolverFunc
// dial is used to dial network connections.
// NewTargetDialer sets it to the DialContext method of a default net.Dialer.
dial dialFunc
}

// NewTargetDialer builds a TargetDialer that resolves its endpoints with the
// given resolver and dials them using the default net.Dialer.
func NewTargetDialer(resolver EndpointsResolverFunc) *TargetDialer {
	dialer := defaultDialer()
	target := &TargetDialer{Resolver: resolver}
	target.dial = dialer.DialContext
	return target
}

// GetProtocol returns the network protocol used for checking health.
// A TargetDialer always reports types.TargetHealthProtocolTCP.
func (t *TargetDialer) GetProtocol() types.TargetHealthProtocol {
return types.TargetHealthProtocolTCP
}

// CheckHealth checks the health of the target resource.
// It delegates to dialEndpoints and returns its results unchanged: the
// endpoints that were resolved and the error, if any, from resolving or
// dialing them.
func (t *TargetDialer) CheckHealth(ctx context.Context) ([]string, error) {
return t.dialEndpoints(ctx)
}

// dialEndpoints resolves the target endpoints and dials every one of them.
//
// On a resolver failure, a missing resolver, or an empty endpoint list it
// returns nil endpoints and an error. Otherwise it returns the resolved
// endpoints together with the dial result, so the endpoints are still reported
// when a dial fails. With several endpoints the dials run concurrently; the
// errgroup context is canceled as soon as one dial fails and the first error
// is returned.
func (t *TargetDialer) dialEndpoints(ctx context.Context) ([]string, error) {
	// maxConcurrentDials bounds how many endpoints are dialed in parallel.
	const maxConcurrentDials = 10

	// Resolver is an exported field, so a TargetDialer built from a struct
	// literal may not have one; report that instead of panicking on a nil
	// func call.
	if t.Resolver == nil {
		return nil, trace.BadParameter("missing target endpoints resolver")
	}
	endpoints, err := t.Resolver(ctx)
	if err != nil {
		return nil, trace.Wrap(err, "failed to resolve target endpoints")
	}
	switch len(endpoints) {
	case 0:
		return nil, trace.NotFound("resolved zero target endpoints")
	case 1:
		// a single endpoint is dialed inline, no goroutine needed.
		return endpoints, t.dialEndpoint(ctx, endpoints[0])
	default:
		group, ctx := errgroup.WithContext(ctx)
		group.SetLimit(maxConcurrentDials)
		for _, ep := range endpoints {
			group.Go(func() error {
				return trace.Wrap(t.dialEndpoint(ctx, ep))
			})
		}
		return endpoints, group.Wait()
	}
}

// dialEndpoint opens a TCP connection to the endpoint and immediately closes
// it. It returns an error if either the dial or the close fails.
func (t *TargetDialer) dialEndpoint(ctx context.Context, endpoint string) error {
	// dial is unexported and only populated by NewTargetDialer; fall back to
	// the default dialer so a TargetDialer built from a struct literal does
	// not panic on a nil func call.
	dial := t.dial
	if dial == nil {
		dial = defaultDialer().DialContext
	}
	conn, err := dial(ctx, "tcp", endpoint)
	if err != nil {
		return trace.Wrap(err)
	}
	// an error while closing the connection could indicate an RST packet from
	// the endpoint - that's a health check failure.
	return trace.Wrap(conn.Close())
}

// defaultDialer returns a fresh zero-value net.Dialer, i.e. one with no
// options set.
func defaultDialer() *net.Dialer {
	return new(net.Dialer)
}
Loading
Loading