diff --git a/lib/healthcheck/manager_test.go b/lib/healthcheck/manager_test.go index ae6362fdb126e..67012809a4059 100644 --- a/lib/healthcheck/manager_test.go +++ b/lib/healthcheck/manager_test.go @@ -237,7 +237,7 @@ func TestManager(t *testing.T) { t.Run("duplicate target is an error", func(t *testing.T) { err = mgr.AddTarget(devTarget) require.Error(t, err) - require.IsType(t, trace.AlreadyExists(""), err) + require.ErrorIs(t, trace.AlreadyExists("target health checker \"name=devDB, kind=db\" already exists"), err) }) t.Run("unsupported target resource is an error", func(t *testing.T) { err = mgr.AddTarget(Target{ @@ -253,7 +253,7 @@ func TestManager(t *testing.T) { }, }) require.Error(t, err) - require.IsType(t, trace.BadParameter(""), err) + require.ErrorIs(t, trace.BadParameter("health check target resource kind \"node\" is not supported"), err) }) requireTargetHealth := func(t *testing.T, r types.ResourceWithLabels, status types.TargetHealthStatus, reason types.TargetHealthTransitionReason) { @@ -397,11 +397,11 @@ func TestManager(t *testing.T) { // shouldn't be any target health after the target is removed _, err = mgr.GetTargetHealth(devDB) require.Error(t, err) - require.IsType(t, trace.NotFound(""), err) + require.ErrorIs(t, trace.NotFound("health checker \"name=devDB, kind=db\" not found"), err) err = mgr.RemoveTarget(devDB) require.Error(t, err) - require.IsType(t, trace.NotFound(""), err) + require.ErrorIs(t, trace.NotFound("health checker \"name=devDB, kind=db\" not found"), err) // prodDB should still be disabled requireTargetHealth(t, prodDB, types.TargetHealthStatusUnknown, types.TargetHealthTransitionReasonDisabled) diff --git a/lib/healthcheck/order_by.go b/lib/healthcheck/order_by.go new file mode 100644 index 0000000000000..1996e2c2bf2fe --- /dev/null +++ b/lib/healthcheck/order_by.go @@ -0,0 +1,45 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package healthcheck + +import ( + "iter" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" +) + +// OrderByTargetHealthStatus returns an iterator over resources ordered by +// health status: healthy, unknown, and unhealthy. Each group is shuffled +// to distributing load on resources. +func OrderByTargetHealthStatus[T types.TargetHealthStatusGetter](resources []T) iter.Seq[T] { + return func(yield func(T) bool) { + groups := types.GroupByTargetHealthStatus(resources) + for _, group := range [][]T{groups.Healthy, groups.Unknown, groups.Unhealthy} { + // ShuffleVisit is used for its efficient early return and partial shuffle. + // The whole healthy group is likely not shuffled or visited. + // And the unknown and unhealthy groups are likely not shuffled or visited. + for _, resource := range utils.ShuffleVisit(group) { + if !yield(resource) { + return + } + } + } + } +} diff --git a/lib/healthcheck/order_by_test.go b/lib/healthcheck/order_by_test.go new file mode 100644 index 0000000000000..ee5c13412aab6 --- /dev/null +++ b/lib/healthcheck/order_by_test.go @@ -0,0 +1,122 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package healthcheck + +import ( + "math" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +func TestOrderByHealthEmpty(t *testing.T) { + t.Parallel() + var servers []types.KubeServer + var visited []string + for server := range OrderByTargetHealthStatus(servers) { + visited = append(visited, server.GetHostID()) + } + require.Empty(t, visited) +} + +func TestOrderByHealthOne(t *testing.T) { + t.Parallel() + servers := []types.KubeServer{ + newKubeServer(t, "one", types.TargetHealthStatusHealthy), + } + var visited []string + for server := range OrderByTargetHealthStatus(servers) { + visited = append(visited, server.GetHostID()) + } + require.Equal(t, []string{"one"}, visited) +} + +func TestOrderByHealthUnsorted(t *testing.T) { + t.Parallel() + servers := []types.KubeServer{ + newKubeServer(t, "unknown-2", types.TargetHealthStatusUnknown), + newKubeServer(t, "unknown-1", types.TargetHealthStatusUnknown), + newKubeServer(t, "unhealthy-1", types.TargetHealthStatusUnhealthy), + newKubeServer(t, "healthy-2", types.TargetHealthStatusHealthy), + newKubeServer(t, "unhealthy-2", types.TargetHealthStatusUnhealthy), + newKubeServer(t, "healthy-1", types.TargetHealthStatusHealthy), + } + var visited []types.KubeServer + for server := range OrderByTargetHealthStatus(servers) { + visited = append(visited, server) + } + require.Len(t, visited, len(servers)) + require.True(t, slices.IsSortedFunc(visited, byHealthOrder)) +} + +func TestOrderByHealthEarlyExit(t *testing.T) { + t.Parallel() + servers := []types.KubeServer{ + newKubeServer(t, "unknown-1", types.TargetHealthStatusUnknown), + newKubeServer(t, "unhealthy-1", types.TargetHealthStatusUnhealthy), + newKubeServer(t, "healthy-2", types.TargetHealthStatusHealthy), + newKubeServer(t, "healthy-1", types.TargetHealthStatusHealthy), + } + var visited []string + for server := range OrderByTargetHealthStatus(servers) { + visited = append(visited, server.GetHostID()) + if len(visited) >= 2 { + break + } + } + require.Len(t, visited, 2) + require.NotContains(t, visited, "unknown-1") + require.NotContains(t, visited, "unhealthy-1") +} + +func newKubeServer(t *testing.T, hostID string, health types.TargetHealthStatus) types.KubeServer { + t.Helper() + cluster, err := types.NewKubernetesClusterV3( + types.Metadata{Name: "test-cluster"}, + types.KubernetesClusterSpecV3{}, + ) + require.NoError(t, err) + server, err := types.NewKubernetesServerV3FromCluster(cluster, "localhost:8080", hostID) + require.NoError(t, err) + server.Status = &types.KubernetesServerStatusV3{ + TargetHealth: &types.TargetHealth{ + Status: string(health), + }, + } + return server +} + +func healthOrder(s types.KubeServer) int { + switch s.GetTargetHealthStatus() { + case types.TargetHealthStatusHealthy: + return 0 + case types.TargetHealthStatusUnknown: + return 1 + case types.TargetHealthStatusUnhealthy: + return 2 + } + return math.MaxInt +} + +func byHealthOrder(a, b types.KubeServer) int { + return healthOrder(a) - healthOrder(b) +} diff --git a/lib/kube/proxy/transport.go b/lib/kube/proxy/transport.go index 58bd4c07b4ee1..dee8783317e12 100644 --- a/lib/kube/proxy/transport.go +++ b/lib/kube/proxy/transport.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/healthcheck" "github.com/gravitational/teleport/lib/kube/internal" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" @@ -319,19 +320,21 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi return nil, trace.Wrap(err) } + // Dial kube servers in the order of health status healthy, unknown, and unhealthy. + // Each health group is shuffled to distribute load. + // Unknown servers and unhealthy servers are still dialed + // in case health status changed since last check. var errs []error - // Shuffle the list of servers to avoid always connecting to the same - // server. - for _, s := range utils.ShuffleVisit(kubeServers) { + for server := range healthcheck.OrderByTargetHealthStatus(kubeServers) { // Validate that the requested kube cluster is registered. - kubeCluster := s.GetCluster() - if kubeCluster.GetName() != kubeClusterName || !opt.matches(s.GetHostID()) { + if server.GetCluster().GetName() != kubeClusterName || !opt.matches(server.GetHostID()) { continue } + // serverID is a unique identifier of the server in the cluster. // It is a combination of the server's hostname and the cluster name. // . - serverID := fmt.Sprintf("%s.%s", s.GetHostID(), f.cfg.ClusterName) + serverID := server.GetHostID() + "." + f.cfg.ClusterName conn, err := localCluster.DialTCP(reversetunnelclient.DialParams{ // Send a sentinel value to the remote cluster because this connection // will be used to forward multiple requests to the remote cluster from @@ -339,18 +342,17 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi // IP Pinning is based on the source IP address of the connection that // we transport over HTTP headers so it's not affected. From: &utils.NetAddr{AddrNetwork: "tcp", Addr: "0.0.0.0:0"}, - To: &utils.NetAddr{AddrNetwork: "tcp", Addr: s.GetHostname()}, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: server.GetHostname()}, ConnType: types.KubeTunnel, ServerID: serverID, - ProxyIDs: s.GetProxyIDs(), + ProxyIDs: server.GetProxyIDs(), }) if err == nil { - opt.collect(s.GetHostID()) + opt.collect(server.GetHostID()) return conn, nil } - errs = append(errs, trace.Wrap(err)) + errs = append(errs, err) } - if len(errs) > 0 { return nil, trace.NewAggregate(errs...) } diff --git a/lib/kube/proxy/transport_test.go b/lib/kube/proxy/transport_test.go index d23e095d9b220..eb78522b1b4c2 100644 --- a/lib/kube/proxy/transport_test.go +++ b/lib/kube/proxy/transport_test.go @@ -21,10 +21,14 @@ package proxy import ( "context" "crypto/tls" - "fmt" + "errors" + "math" "net" + "slices" + "strings" "testing" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" @@ -41,6 +45,7 @@ func TestForwarderClusterDialer(t *testing.T) { hostId = "hostId" proxyIds = []string{"proxyId"} clusterName = "cluster" + health = types.TargetHealthStatusHealthy ) f := &Forwarder{ cfg: ForwarderConfig{ @@ -49,7 +54,7 @@ func TestForwarderClusterDialer(t *testing.T) { }, getKubernetesServersForKubeCluster: func(_ context.Context, kubeClusterName string) ([]types.KubeServer, error) { return []types.KubeServer{ - newKubeServerWithProxyIDs(t, hostname, hostId, proxyIds), + newKubeServer(t, hostname, hostId, proxyIds, health), }, nil }, } @@ -72,7 +77,7 @@ func TestForwarderClusterDialer(t *testing.T) { Addr: hostname, AddrNetwork: "tcp", }, - ServerID: fmt.Sprintf("%s.%s", hostId, clusterName), + ServerID: hostId + "." + clusterName, ConnType: types.KubeTunnel, ProxyIDs: proxyIds, }, @@ -128,7 +133,7 @@ func (f *fakeRemoteSiteTunnel) DialTCP(p reversetunnelclient.DialParams) (net.Co return nil, nil } -func newKubeServerWithProxyIDs(t *testing.T, hostname, hostID string, proxyIds []string) types.KubeServer { +func newKubeServer(t *testing.T, hostname, hostID string, proxyIds []string, health types.TargetHealthStatus) types.KubeServer { k, err := types.NewKubernetesClusterV3(types.Metadata{ Name: "cluster", }, types.KubernetesClusterSpecV3{}) @@ -137,6 +142,11 @@ func newKubeServerWithProxyIDs(t *testing.T, hostname, hostID string, proxyIds [ ks, err := types.NewKubernetesServerV3FromCluster(k, hostname, hostID) require.NoError(t, err) ks.Spec.ProxyIDs = proxyIds + ks.Status = &types.KubernetesServerStatusV3{ + TargetHealth: &types.TargetHealth{ + Status: string(health), + }, + } return ks } @@ -181,3 +191,167 @@ func TestDirectTransportNotCached(t *testing.T) { require.NoError(t, err) require.Equal(t, "example.com", tlsConfig.ServerName) } + +func TestLocalClusterDialsByHealth(t *testing.T) { + t.Parallel() + ctx := t.Context() + const ( + hostname = "localhost:8080" + clusterName = "cluster" + ) + proxyIds := []string{"proxyId"} + tests := []struct { + name string + servers []types.KubeServer + }{ + { + name: "one", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "healthy-1", proxyIds, types.TargetHealthStatusHealthy), + }, + }, + { + name: "healthy", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "healthy-1", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-2", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-3", proxyIds, types.TargetHealthStatusHealthy), + }, + }, + { + name: "unknown", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "unknown-1", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-2", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-3", proxyIds, types.TargetHealthStatusUnknown), + }, + }, + { + name: "unhealthy", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "unhealthy-1", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-2", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-3", proxyIds, types.TargetHealthStatusUnhealthy), + }, + }, + { + name: "random", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "unhealthy-1", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "healthy-3", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "unknown-2", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "healthy-2", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "unknown-1", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unhealthy-3", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unknown-3", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unhealthy-2", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "healthy-1", proxyIds, types.TargetHealthStatusHealthy), + }, + }, + { + name: "reversed", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "unhealthy-3", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-2", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-1", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unknown-3", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-2", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-1", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "healthy-3", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-2", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-1", proxyIds, types.TargetHealthStatusHealthy), + }, + }, + { + name: "sorted", + servers: []types.KubeServer{ + newKubeServer(t, hostname, "healthy-1", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-2", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "healthy-3", proxyIds, types.TargetHealthStatusHealthy), + newKubeServer(t, hostname, "unknown-1", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-2", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unknown-3", proxyIds, types.TargetHealthStatusUnknown), + newKubeServer(t, hostname, "unhealthy-1", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-2", proxyIds, types.TargetHealthStatusUnhealthy), + newKubeServer(t, hostname, "unhealthy-3", proxyIds, types.TargetHealthStatusUnhealthy), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &Forwarder{ + cfg: ForwarderConfig{ + tracer: otel.Tracer("test"), + ClusterName: clusterName, + ReverseTunnelSrv: healthReverseTunnel{}, + }, + getKubernetesServersForKubeCluster: func(context.Context, string) ([]types.KubeServer, error) { + return tt.servers, nil + }, + } + _, err := f.localClusterDialer(clusterName)(ctx, "tcp", "") + require.Error(t, err) + + var aggErr trace.Aggregate + require.ErrorAs(t, err, &aggErr, "expected an aggregate error") + var healthErrs []healthError + for _, e := range aggErr.Errors() { + var he healthError + if errors.As(e, &he) { + healthErrs = append(healthErrs, he) + } + } + + require.Len(t, healthErrs, len(tt.servers)) + require.True(t, slices.IsSortedFunc(healthErrs, byHealthOrder), + "expected dialed errors to be order by healthy, unknown, and unhealthy") + }) + } +} + +type healthReverseTunnel struct { + reversetunnelclient.Server +} + +func (f healthReverseTunnel) Cluster(context.Context, string) (reversetunnelclient.Cluster, error) { + return &healthRemoteSiteTunnel{}, nil +} + +type healthRemoteSiteTunnel struct { + reversetunnelclient.Cluster +} + +func (f healthRemoteSiteTunnel) DialTCP(p reversetunnelclient.DialParams) (net.Conn, error) { + // Extract health from ServerID. + // ServerID = -. + idx := strings.Index(p.ServerID, "-") + if idx < 0 { + return nil, trace.BadParameter("invalid server ID: %q", p.ServerID) + } + return nil, healthError{health: types.TargetHealthStatus(p.ServerID[:idx])} +} + +type healthError struct { + health types.TargetHealthStatus +} + +func (e healthError) Error() string { + return string(e.health) +} + +func healthOrder(e healthError) int { + switch e.health { + case types.TargetHealthStatusHealthy: + return 0 + case types.TargetHealthStatusUnknown: + return 1 + case types.TargetHealthStatusUnhealthy: + return 2 + } + return math.MaxInt +} + +func byHealthOrder(a, b healthError) int { + return healthOrder(a) - healthOrder(b) +}