From 8bbf1d0ea7eac316744ddc0c456fe1ba5ce23a49 Mon Sep 17 00:00:00 2001 From: Forrest <30576607+fspmarshall@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:25:48 -0700 Subject: [PATCH] fix hot keepalives (#53298) --- lib/inventory/controller.go | 349 +++++++++++---------- lib/inventory/controller_test.go | 66 ++-- lib/inventory/internal/delay/heap.go | 106 +++++++ lib/inventory/internal/delay/heap_test.go | 68 ++++ lib/inventory/internal/delay/multi.go | 210 +++++++++++++ lib/inventory/internal/delay/multi_test.go | 157 +++++++++ lib/inventory/inventory.go | 13 + 7 files changed, 783 insertions(+), 186 deletions(-) create mode 100644 lib/inventory/internal/delay/heap.go create mode 100644 lib/inventory/internal/delay/heap_test.go create mode 100644 lib/inventory/internal/delay/multi.go create mode 100644 lib/inventory/internal/delay/multi_test.go diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go index 06c0e5e23126d..73a7e1ed05360 100644 --- a/lib/inventory/controller.go +++ b/lib/inventory/controller.go @@ -375,16 +375,14 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { // these delays are lazily initialized upon receipt of the first heartbeat // since not all servers send all heartbeats var sshKeepAliveDelay *delay.Delay - var appKeepAliveDelay *delay.Delay - var dbKeepAliveDelay *delay.Delay - var kubeKeepAliveDelay *delay.Delay + defer func() { // this is a function expression because the variables are initialized // later and we want to call Stop on the initialized value (if any) sshKeepAliveDelay.Stop() - appKeepAliveDelay.Stop() - dbKeepAliveDelay.Stop() - kubeKeepAliveDelay.Stop() + handle.appKeepAliveDelay.Stop() + handle.dbKeepAliveDelay.Stop() + handle.kubeKeepAliveDelay.Stop() }() for _, service := range handle.hello.Services { @@ -502,36 +500,36 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { } if m.AppServer != nil { + if handle.appKeepAliveDelay == nil { + handle.appKeepAliveDelay = c.createKeepAliveMultiDelay(c.appHBVariableDuration) + } + if err := c.handleAppServerHB(handle, m.AppServer); err != nil { handle.CloseWithError(err) return } - - if appKeepAliveDelay == nil { - appKeepAliveDelay = c.createKeepAliveDelay(c.appHBVariableDuration) - } } if m.DatabaseServer != nil { + if handle.dbKeepAliveDelay == nil { + handle.dbKeepAliveDelay = c.createKeepAliveMultiDelay(c.dbHBVariableDuration) + } + if err := c.handleDatabaseServerHB(handle, m.DatabaseServer); err != nil { handle.CloseWithError(err) return } - - if dbKeepAliveDelay == nil { - dbKeepAliveDelay = c.createKeepAliveDelay(c.dbHBVariableDuration) - } } if m.KubernetesServer != nil { + if handle.kubeKeepAliveDelay == nil { + handle.kubeKeepAliveDelay = c.createKeepAliveMultiDelay(c.kubeHBVariableDuration) + } + if err := c.handleKubernetesServerHB(handle, m.KubernetesServer); err != nil { handle.CloseWithError(err) return } - - if kubeKeepAliveDelay == nil { - kubeKeepAliveDelay = c.createKeepAliveDelay(c.kubeHBVariableDuration) - } } case proto.UpstreamInventoryPong: @@ -563,28 +561,29 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { } c.testEvent(keepAliveSSHTick) - case now := <-appKeepAliveDelay.Elapsed(): - appKeepAliveDelay.Advance(now) + case now := <-handle.appKeepAliveDelay.Elapsed(): + key := handle.appKeepAliveDelay.Tick(now) - if err := c.keepAliveAppServer(handle, now); err != nil { + if err := c.keepAliveAppServer(handle, now, key); err != nil { handle.CloseWithError(err) return } c.testEvent(keepAliveAppTick) - case 
now := <-dbKeepAliveDelay.Elapsed(): - dbKeepAliveDelay.Advance(now) + case now := <-handle.dbKeepAliveDelay.Elapsed(): + key := handle.dbKeepAliveDelay.Tick(now) - if err := c.keepAliveDatabaseServer(handle, now); err != nil { + if err := c.keepAliveDatabaseServer(handle, now, key); err != nil { handle.CloseWithError(err) return } + c.testEvent(keepAliveDatabaseTick) - case now := <-kubeKeepAliveDelay.Elapsed(): - kubeKeepAliveDelay.Advance(now) + case now := <-handle.kubeKeepAliveDelay.Elapsed(): + key := handle.kubeKeepAliveDelay.Tick(now) - if err := c.keepAliveKubernetesServer(handle, now); err != nil { + if err := c.keepAliveKubernetesServer(handle, now, key); err != nil { handle.CloseWithError(err) return } @@ -797,6 +796,7 @@ func (c *Controller) handleAppServerHB(handle *upstreamHandle, appServer *types. c.appHBVariableDuration.Inc() } handle.appServers[appKey] = &heartBeatInfo[*types.AppServerV3]{} + handle.appKeepAliveDelay.Add(appKey) } now := time.Now() @@ -851,6 +851,7 @@ func (c *Controller) handleDatabaseServerHB(handle *upstreamHandle, databaseServ c.dbHBVariableDuration.Inc() } handle.databaseServers[dbKey] = &heartBeatInfo[*types.DatabaseServerV3]{} + handle.dbKeepAliveDelay.Add(dbKey) } now := time.Now() @@ -905,6 +906,7 @@ func (c *Controller) handleKubernetesServerHB(handle *upstreamHandle, kubernetes c.kubeHBVariableDuration.Inc() } handle.kubernetesServers[kubeKey] = &heartBeatInfo[*types.KubernetesServerV3]{} + handle.kubeKeepAliveDelay.Add(kubeKey) } now := time.Now() @@ -960,163 +962,177 @@ func (c *Controller) handleAgentMetadata(handle *upstreamHandle, m proto.Upstrea }) } -func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) error { - for name, srv := range handle.appServers { - if srv.lease != nil { - lease := *srv.lease - lease.Expires = now.Add(c.serverTTL).UTC() - if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { - c.testEvent(appKeepAliveErr) - - srv.keepAliveErrs++ - handle.appServers[name] = srv - shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs - slog.WarnContext(c.closeContext, "Failed to keep alive app server", - "server_id", handle.Hello().ServerID, - "error", err, - "error_count", srv.keepAliveErrs, - "should_remove", shouldRemove, - ) +func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time, name resourceKey) error { + srv, ok := handle.appServers[name] + if !ok { + handle.appKeepAliveDelay.Remove(name) + return trace.Errorf("desync between app server hb registry and keepalive delay (this is a bug)") + } - if shouldRemove { - c.testEvent(appKeepAliveDel) - c.onDisconnectFunc(constants.KeepAliveApp, 1) - if c.appHBVariableDuration != nil { - c.appHBVariableDuration.Dec() - } - delete(handle.appServers, name) + if srv.lease != nil { + lease := *srv.lease + lease.Expires = now.Add(c.serverTTL).UTC() + if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { + c.testEvent(appKeepAliveErr) + + srv.keepAliveErrs++ + handle.appServers[name] = srv + shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs + slog.WarnContext(c.closeContext, "Failed to keep alive app server", + "server_id", handle.Hello().ServerID, + "error", err, + "error_count", srv.keepAliveErrs, + "should_remove", shouldRemove, + ) + + if shouldRemove { + c.testEvent(appKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveApp, 1) + if c.appHBVariableDuration != nil { + c.appHBVariableDuration.Dec() } - } else { - srv.keepAliveErrs = 0 - c.testEvent(appKeepAliveOk) - } - } else if srv.retryUpsert { - 
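		// retryUpsert with no lease means the initial upsert of this resource failed;
		// the keepalive tick is reused to re-attempt the upsert instead of renewing
		// a lease that was never established.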
srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) - lease, err := c.auth.UpsertApplicationServer(c.closeContext, srv.resource) - if err != nil { - c.testEvent(appUpsertRetryErr) - slog.WarnContext(c.closeContext, "Failed to upsert app server on retry", - "server_id", handle.Hello().ServerID, - "error", err, - ) - // since this is retry-specific logic, an error here means that upsert failed twice in - // a row. Missing upserts is more problematic than missing keepalives so we don't bother - // attempting a third time. - return trace.Errorf("failed to upsert app server on retry: %v", err) + delete(handle.appServers, name) + handle.appKeepAliveDelay.Remove(name) } - c.testEvent(appUpsertRetryOk) - - srv.lease = lease - srv.retryUpsert = false + } else { + srv.keepAliveErrs = 0 + c.testEvent(appKeepAliveOk) } + } else if srv.retryUpsert { + srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) + lease, err := c.auth.UpsertApplicationServer(c.closeContext, srv.resource) + if err != nil { + c.testEvent(appUpsertRetryErr) + slog.WarnContext(c.closeContext, "Failed to upsert app server on retry", + "server_id", handle.Hello().ServerID, + "error", err, + ) + // since this is retry-specific logic, an error here means that upsert failed twice in + // a row. Missing upserts is more problematic than missing keepalives so we don't bother + // attempting a third time. + return trace.Errorf("failed to upsert app server on retry: %v", err) + } + c.testEvent(appUpsertRetryOk) + + srv.lease = lease + srv.retryUpsert = false } return nil } -func (c *Controller) keepAliveDatabaseServer(handle *upstreamHandle, now time.Time) error { - for name, srv := range handle.databaseServers { - if srv.lease != nil { - lease := *srv.lease - lease.Expires = now.Add(c.serverTTL).UTC() - if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { - c.testEvent(dbKeepAliveErr) - - srv.keepAliveErrs++ - handle.databaseServers[name] = srv - shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs - slog.WarnContext(c.closeContext, "Failed to keep alive database server", - "server_id", handle.Hello().ServerID, - "error", err, - "error_count", srv.keepAliveErrs, - "should_remove", shouldRemove, - ) +func (c *Controller) keepAliveDatabaseServer(handle *upstreamHandle, now time.Time, name resourceKey) error { + srv, ok := handle.databaseServers[name] + if !ok { + handle.dbKeepAliveDelay.Remove(name) + return trace.Errorf("desync between db server hb registry and keepalive delay (this is a bug)") + } + if srv.lease != nil { + lease := *srv.lease + lease.Expires = now.Add(c.serverTTL).UTC() + if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { + c.testEvent(dbKeepAliveErr) + + srv.keepAliveErrs++ + handle.databaseServers[name] = srv + shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs + slog.WarnContext(c.closeContext, "Failed to keep alive database server", + "server_id", handle.Hello().ServerID, + "error", err, + "error_count", srv.keepAliveErrs, + "should_remove", shouldRemove, + ) - if shouldRemove { - c.testEvent(dbKeepAliveDel) - c.onDisconnectFunc(constants.KeepAliveDatabase, 1) - if c.dbHBVariableDuration != nil { - c.dbHBVariableDuration.Dec() - } - delete(handle.databaseServers, name) + if shouldRemove { + c.testEvent(dbKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveDatabase, 1) + if c.dbHBVariableDuration != nil { + c.dbHBVariableDuration.Dec() } - } else { - srv.keepAliveErrs = 0 - c.testEvent(dbKeepAliveOk) - } - } else if srv.retryUpsert { - 
srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) - lease, err := c.auth.UpsertDatabaseServer(c.closeContext, srv.resource) - if err != nil { - c.testEvent(dbUpsertRetryErr) - slog.WarnContext(c.closeContext, "Failed to upsert database server on retry", - "server_id", handle.Hello().ServerID, - "error", err, - ) - // since this is retry-specific logic, an error here means that upsert failed twice in - // a row. Missing upserts is more problematic than missing keepalives so we don't bother - // attempting a third time. - return trace.Errorf("failed to upsert database server on retry: %v", err) + delete(handle.databaseServers, name) + handle.dbKeepAliveDelay.Remove(name) } - c.testEvent(dbUpsertRetryOk) - - srv.lease = lease - srv.retryUpsert = false + } else { + srv.keepAliveErrs = 0 + c.testEvent(dbKeepAliveOk) + } + } else if srv.retryUpsert { + srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) + lease, err := c.auth.UpsertDatabaseServer(c.closeContext, srv.resource) + if err != nil { + c.testEvent(dbUpsertRetryErr) + slog.WarnContext(c.closeContext, "Failed to upsert database server on retry", + "server_id", handle.Hello().ServerID, + "error", err, + ) + // since this is retry-specific logic, an error here means that upsert failed twice in + // a row. Missing upserts is more problematic than missing keepalives so we don't bother + // attempting a third time. + return trace.Errorf("failed to upsert database server on retry: %v", err) } + c.testEvent(dbUpsertRetryOk) + + srv.lease = lease + srv.retryUpsert = false } return nil } -func (c *Controller) keepAliveKubernetesServer(handle *upstreamHandle, now time.Time) error { - for name, srv := range handle.kubernetesServers { - if srv.lease != nil { - lease := *srv.lease - lease.Expires = now.Add(c.serverTTL).UTC() - if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { - c.testEvent(kubeKeepAliveErr) - - srv.keepAliveErrs++ - handle.kubernetesServers[name] = srv - shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs - slog.WarnContext(c.closeContext, "Failed to keep alive kubernetes server", - "server_id", handle.Hello().ServerID, - "error", err, - "error_count", srv.keepAliveErrs, - "should_remove", shouldRemove, - ) +func (c *Controller) keepAliveKubernetesServer(handle *upstreamHandle, now time.Time, name resourceKey) error { + srv, ok := handle.kubernetesServers[name] + if !ok { + handle.kubeKeepAliveDelay.Remove(name) + return trace.Errorf("desync between kube server hb registry and keepalive delay (this is a bug)") + } - if shouldRemove { - c.testEvent(kubeKeepAliveDel) - c.onDisconnectFunc(constants.KeepAliveKube, 1) - if c.kubeHBVariableDuration != nil { - c.kubeHBVariableDuration.Dec() - } - delete(handle.kubernetesServers, name) + if srv.lease != nil { + lease := *srv.lease + lease.Expires = now.Add(c.serverTTL).UTC() + if err := c.auth.KeepAliveServer(c.closeContext, lease); err != nil { + c.testEvent(kubeKeepAliveErr) + + srv.keepAliveErrs++ + handle.kubernetesServers[name] = srv + shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs + slog.WarnContext(c.closeContext, "Failed to keep alive kubernetes server", + "server_id", handle.Hello().ServerID, + "error", err, + "error_count", srv.keepAliveErrs, + "should_remove", shouldRemove, + ) + + if shouldRemove { + c.testEvent(kubeKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveKube, 1) + if c.kubeHBVariableDuration != nil { + c.kubeHBVariableDuration.Dec() } - } else { - srv.keepAliveErrs = 0 - c.testEvent(kubeKeepAliveOk) - } - } else if 
srv.retryUpsert { - srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) - lease, err := c.auth.UpsertKubernetesServer(c.closeContext, srv.resource) - if err != nil { - c.testEvent(kubeUpsertRetryErr) - slog.WarnContext(c.closeContext, "Failed to upsert kubernetes server on retry.", - "server_id", handle.Hello().ServerID, - "error", err, - ) - // since this is retry-specific logic, an error here means that upsert failed twice in - // a row. Missing upserts is more problematic than missing keepalives so we don'resource bother - // attempting a third time. - return trace.Errorf("failed to upsert kubernetes server on retry: %v", err) + delete(handle.kubernetesServers, name) + handle.kubeKeepAliveDelay.Remove(name) } - c.testEvent(kubeUpsertRetryOk) - - srv.lease = lease - srv.retryUpsert = false + } else { + srv.keepAliveErrs = 0 + c.testEvent(kubeKeepAliveOk) } + } else if srv.retryUpsert { + srv.resource.SetExpiry(time.Now().Add(c.serverTTL).UTC()) + lease, err := c.auth.UpsertKubernetesServer(c.closeContext, srv.resource) + if err != nil { + c.testEvent(kubeUpsertRetryErr) + slog.WarnContext(c.closeContext, "Failed to upsert kubernetes server on retry.", + "server_id", handle.Hello().ServerID, + "error", err, + ) + // since this is retry-specific logic, an error here means that upsert failed twice in + // a row. Missing upserts is more problematic than missing keepalives so we don'resource bother + // attempting a third time. + return trace.Errorf("failed to upsert kubernetes server on retry: %v", err) + } + c.testEvent(kubeUpsertRetryOk) + + srv.lease = lease + srv.retryUpsert = false } return nil @@ -1176,6 +1192,15 @@ func (c *Controller) createKeepAliveDelay(variableDuration *interval.VariableDur }) } +func (c *Controller) createKeepAliveMultiDelay(variableDuration *interval.VariableDuration) *delay.Multi[resourceKey] { + return delay.NewMulti[resourceKey](delay.MultiParams{ + FixedInterval: c.serverKeepAlive, + VariableInterval: variableDuration, + FirstJitter: retryutils.HalfJitter, + Jitter: retryutils.SeventhJitter, + }) +} + // Close terminates all control streams registered with this controller. Control streams // registered after Close() is called are closed immediately. func (c *Controller) Close() error { diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go index 0183716b880a4..7e9d72227f082 100644 --- a/lib/inventory/controller_test.go +++ b/lib/inventory/controller_test.go @@ -461,6 +461,19 @@ func TestAppServerBasics(t *testing.T) { deny(appUpsertErr, handlerClose), ) + // jitter can sometimes cause one app to keepalive twice before another has completed one keepalive. for that + // reason, we want 2x the number of apps worth of keepalives to ensure that the failed keepalive counts associated + // with each app have been reset. otherwise, later parts of this test become flaky. + var keepaliveEvents []testEvent + for i := 0; i < appCount; i++ { + keepaliveEvents = append(keepaliveEvents, []testEvent{appKeepAliveOk, appKeepAliveOk}...) + } + + awaitEvents(t, events, + expect(keepaliveEvents...), + deny(appKeepAliveErr, handlerClose), + ) + for i := 0; i < appCount; i++ { err := downstream.Send(ctx, proto.InventoryHeartbeat{ AppServer: &types.AppServerV3{ @@ -486,7 +499,7 @@ func TestAppServerBasics(t *testing.T) { // we should now see an upsert failure, but no additional // keepalive failures, and the upsert should succeed on retry. 
awaitEvents(t, events, - expect(appKeepAliveOk, appKeepAliveOk, appKeepAliveOk, appUpsertErr, appUpsertRetryOk), + expect(appUpsertErr, appUpsertRetryOk), deny(appKeepAliveErr, handlerClose), ) @@ -535,13 +548,6 @@ func TestAppServerBasics(t *testing.T) { deny(handlerClose), ) - // verify that further keepalive ticks to not result in attempts to keepalive - // apps (successful or not). - awaitEvents(t, events, - expect(keepAliveAppTick, keepAliveAppTick, keepAliveAppTick), - deny(appKeepAliveOk, appKeepAliveErr, handlerClose), - ) - // set up to induce enough consecutive errors to cause stream closure auth.mu.Lock() auth.failUpserts = 5 @@ -680,6 +686,19 @@ func TestDatabaseServerBasics(t *testing.T) { deny(dbUpsertErr, handlerClose), ) + // jitter can sometimes cause one app to keepalive twice before another has completed one keepalive. for that + // reason, we want 2x the number of apps worth of keepalives to ensure that the failed keepalive counts associated + // with each app have been reset. otherwise, later parts of this test become flaky. + var keepaliveEvents []testEvent + for i := 0; i < dbCount; i++ { + keepaliveEvents = append(keepaliveEvents, []testEvent{dbKeepAliveOk, dbKeepAliveOk}...) + } + + awaitEvents(t, events, + expect(keepaliveEvents...), + deny(appKeepAliveErr, handlerClose), + ) + for i := 0; i < dbCount; i++ { err := downstream.Send(ctx, proto.InventoryHeartbeat{ DatabaseServer: &types.DatabaseServerV3{ @@ -705,7 +724,7 @@ func TestDatabaseServerBasics(t *testing.T) { // we should now see an upsert failure, but no additional // keepalive failures, and the upsert should succeed on retry. awaitEvents(t, events, - expect(dbKeepAliveOk, dbKeepAliveOk, dbKeepAliveOk, dbUpsertErr, dbUpsertRetryOk), + expect(dbUpsertErr, dbUpsertRetryOk), deny(dbKeepAliveErr, handlerClose), ) @@ -754,13 +773,6 @@ func TestDatabaseServerBasics(t *testing.T) { deny(handlerClose), ) - // verify that further keepalive ticks to not result in attempts to keepalive - // dbs (successful or not). - awaitEvents(t, events, - expect(keepAliveDatabaseTick, keepAliveDatabaseTick, keepAliveDatabaseTick), - deny(dbKeepAliveOk, dbKeepAliveErr, handlerClose), - ) - // set up to induce enough consecutive errors to cause stream closure auth.mu.Lock() auth.failUpserts = 5 @@ -1279,6 +1291,19 @@ func TestKubernetesServerBasics(t *testing.T) { deny(kubeUpsertErr, handlerClose), ) + // jitter can sometimes cause one app to keepalive twice before another has completed one keepalive. for that + // reason, we want 2x the number of apps worth of keepalives to ensure that the failed keepalive counts associated + // with each app have been reset. otherwise, later parts of this test become flaky. + var keepaliveEvents []testEvent + for i := 0; i < kubeCount; i++ { + keepaliveEvents = append(keepaliveEvents, []testEvent{kubeKeepAliveOk, kubeKeepAliveOk}...) + } + + awaitEvents(t, events, + expect(keepaliveEvents...), + deny(appKeepAliveErr, handlerClose), + ) + for i := 0; i < kubeCount; i++ { err := downstream.Send(ctx, proto.InventoryHeartbeat{ KubernetesServer: &types.KubernetesServerV3{ @@ -1305,7 +1330,7 @@ func TestKubernetesServerBasics(t *testing.T) { // we should now see an upsert failure, but no additional // keepalive failures, and the upsert should succeed on retry. 
awaitEvents(t, events, - expect(kubeKeepAliveOk, kubeKeepAliveOk, kubeKeepAliveOk, kubeUpsertErr, kubeUpsertRetryOk), + expect(kubeUpsertErr, kubeUpsertRetryOk), deny(kubeKeepAliveErr, handlerClose), ) @@ -1354,13 +1379,6 @@ func TestKubernetesServerBasics(t *testing.T) { deny(handlerClose), ) - // verify that further keepalive ticks to not result in attempts to keepalive - // apps (successful or not). - awaitEvents(t, events, - expect(keepAliveKubeTick, keepAliveKubeTick, keepAliveKubeTick), - deny(kubeKeepAliveOk, kubeKeepAliveErr, handlerClose), - ) - // set up to induce enough consecutive errors to cause stream closure auth.mu.Lock() auth.failUpserts = 5 diff --git a/lib/inventory/internal/delay/heap.go b/lib/inventory/internal/delay/heap.go new file mode 100644 index 0000000000000..ea59ddbaea4e1 --- /dev/null +++ b/lib/inventory/internal/delay/heap.go @@ -0,0 +1,106 @@ +// Copyright 2025 Gravitational, Inc. +// Copyright 2009 The Go Authors +// SPDX-License-Identifier: BSD-3-Clause + +package delay + +type noUnkeyedLiterals struct{} + +type heap[T any] struct { + _ noUnkeyedLiterals + + Less func(T, T) bool + Slice []T +} + +func (h *heap[T]) up(j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.Less(h.Slice[j], h.Slice[i]) { + break + } + h.Slice[i], h.Slice[j] = h.Slice[j], h.Slice[i] + j = i + } +} + +func (h *heap[T]) down(i int) bool { + i0 := i + for { + j1 := 2*i + 1 + if j1 >= len(h.Slice) || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < len(h.Slice) && h.Less(h.Slice[j2], h.Slice[j1]) { + j = j2 // = 2*i + 2 // right child + } + if !h.Less(h.Slice[j], h.Slice[i]) { + break + } + h.Slice[i], h.Slice[j] = h.Slice[j], h.Slice[i] + i = j + } + return i > i0 +} + +func (h *heap[T]) Root() *T { + if len(h.Slice) == 0 { + return nil + } + + return &h.Slice[0] +} + +func (h *heap[T]) FixRoot() { + h.Fix(0) +} + +func (h *heap[T]) Init() { + // heapify + n := len(h.Slice) + for i := n/2 - 1; i >= 0; i-- { + h.down(i) + } +} + +func (h *heap[T]) Push(x T) { + h.Slice = append(h.Slice, x) + h.up(len(h.Slice) - 1) +} + +func (h *heap[T]) Pop() T { + n := len(h.Slice) - 1 + x := h.Slice[0] + h.Slice[0] = h.Slice[n] + h.Slice[n] = *new(T) + h.Slice = h.Slice[:n] + if n != 0 { + h.down(0) + } + return x +} + +func (h *heap[T]) Remove(i int) T { + n := len(h.Slice) - 1 + x := h.Slice[i] + h.Slice[i] = h.Slice[n] + h.Slice[n] = *new(T) + h.Slice = h.Slice[:n] + if n != i { + if !h.down(i) { + h.up(i) + } + } + return x +} + +func (h *heap[T]) Fix(i int) { + if !h.down(i) { + h.up(i) + } +} + +func (h *heap[T]) Clear() { + h.Slice = nil +} diff --git a/lib/inventory/internal/delay/heap_test.go b/lib/inventory/internal/delay/heap_test.go new file mode 100644 index 0000000000000..440d8c0363067 --- /dev/null +++ b/lib/inventory/internal/delay/heap_test.go @@ -0,0 +1,68 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. 
+// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package delay + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestHeapBasics(t *testing.T) { + heap := heap[entry[int]]{ + Less: entryLess[int], + } + + now := time.Now() + + t1 := now.Add(time.Millisecond) + heap.Push(entry[int]{tick: t1, key: 1}) + + t2 := now.Add(time.Millisecond * 2) + heap.Push(entry[int]{tick: t2, key: 2}) + + require.Equal(t, entry[int]{tick: t1, key: 1}, heap.Pop()) + require.Equal(t, entry[int]{tick: t2, key: 2}, heap.Pop()) + + for i := 0; i < 100; i++ { + ts := now.Add(time.Duration(i+1) * time.Millisecond) + heap.Push(entry[int]{tick: ts, key: i}) + } + + root := heap.Root() + require.NotNil(t, root) + require.Equal(t, 0, root.key) + root.tick = now.Add(time.Hour) + heap.FixRoot() + + newRoot := heap.Root() + require.NotNil(t, newRoot) + require.Equal(t, 1, newRoot.key) + + var prev *entry[int] + for i := 0; i < 100; i++ { + next := heap.Pop() + if prev != nil { + require.True(t, prev.tick.Before(next.tick), "prev: %v, next: %v", prev, next) + } + require.Equal(t, (i+1)%100, next.key) + prev = &next + } + + require.Empty(t, heap.Slice) +} diff --git a/lib/inventory/internal/delay/multi.go b/lib/inventory/internal/delay/multi.go new file mode 100644 index 0000000000000..c9fcf8fe3a032 --- /dev/null +++ b/lib/inventory/internal/delay/multi.go @@ -0,0 +1,210 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package delay + +import ( + "fmt" + "time" + + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/utils/interval" +) + +type entry[T any] struct { + tick time.Time + key T +} + +func (e entry[T]) String() string { + return fmt.Sprintf("entry{tick: %v, key: %v}", e.tick.Format(time.RFC3339Nano), e.key) +} + +func entryLess[T any](a, b entry[T]) bool { + return a.tick.Before(b.tick) +} + +// MultiParams contains the parameters for [NewMulti]. +type MultiParams struct { + // FirstInterval is the expected time between the creation of the [Delay] + // and the first tick. It's not modified by the configured jitter. + FirstInterval time.Duration + // FixedInterval is the interval of the delay, unless VariableInterval is + // set. If a jitter is configured, the interval will be jittered every tick. + FixedInterval time.Duration + // VariableInterval, if set, overrides FixedInterval at every tick. + VariableInterval *interval.VariableDuration + // FirstJitter is the jitter applied to the first interval. It's not applied + // to the interval after the first tick. If unset, the standard jitter is + // applied to the first interval. 
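+	// (The inventory controller passes retryutils.HalfJitter here, so the first
+	// keepalive for a newly added resource fires roughly half an interval out,
+	// spreading the initial keepalive load.)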
+ FirstJitter retryutils.Jitter + // Jitter is a jitter function, applied every tick (if set) to the fixed or + // variable interval (except for the first tick if FirstJitter is set). + Jitter retryutils.Jitter + + clock clockwork.Clock +} + +// Multi is a ticker-like abstraction around a [*time.Timer] that's made to tick +// periodically with a potentially varying interval and optionally some jitter. +// Its use requires some care as the logic driving the ticks and the jittering +// must be explicitly invoked by the code making use of it, but uses no +// background resources. It tracks an arbitrary number of sub-intervals by key, +// allowing a single delay to be applied to multiple overlapping intervals. +type Multi[T comparable] struct { + clock clockwork.Clock + timer clockwork.Timer + + heap heap[entry[T]] + + fixedInterval time.Duration + variableInterval *interval.VariableDuration + + firstJitter retryutils.Jitter + jitter retryutils.Jitter +} + +// NewMulti returns a new [*Multi]. Note that the delay starts with no subintervals +// and will not tick until at least one subinterval is added. +func NewMulti[T comparable](p MultiParams) *Multi[T] { + if p.clock == nil { + p.clock = clockwork.NewRealClock() + } + return &Multi[T]{ + clock: p.clock, + + heap: heap[entry[T]]{ + Less: entryLess[T], + }, + + fixedInterval: p.FixedInterval, + variableInterval: p.VariableInterval, + + firstJitter: p.FirstJitter, + jitter: p.Jitter, + } +} + +func (h *Multi[T]) Add(key T) { + // add new target to the heap + now := h.clock.Now() + entry := entry[T]{ + tick: now.Add(h.interval(true /* first */)), + key: key, + } + h.heap.Push(entry) + + // trigger reset in case the new entry should be the next target + h.reset(now, false /* fired */) +} + +func (h *Multi[T]) Remove(key T) { + // key is not the current target, remove it from the heap + for i, entry := range h.heap.Slice { + if entry.key == key { + h.heap.Remove(i) + if i == 0 { + // if the removed entry was the root of the heap, then our target + // has changed and we need to reset the timer to a new target. + h.reset(h.clock.Now(), false /* fired */) + } + return + } + } +} + +// Tick *must* be called exactly once for each firing observed on the Elapsed channel, with the time +// of the firing. Tick will advance the internal state of the multi to start targeting the next interval, +// and return the key associated with the interval that just fired. +func (h *Multi[T]) Tick(now time.Time) (key T) { + // advance the current root entry (source of the tick), and record its + // key for later return. + root := h.heap.Root() + key = root.key + root.tick = now.Add(h.interval(false /* first */)) + + // fix the heap ordering to reflect the updated state + h.heap.FixRoot() + + // reset timer to match the new state + h.reset(now, true /* fired */) + + return +} + +// reset configures the appropriate timer/channel for the current state given the +// current time. reset must be called after any addition, removal, or advancement. +// the fired parameter must be true if the call context is one in which a timer firing +// has been *observed* (i.e. the channel alread drained) and false otherwise. 
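+// When fired is false, a timer that has already fired without being observed is
+// stopped and its stale tick is drained from the channel, so that a later receive
+// cannot report an entry that was just removed or re-prioritized.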
+func (h *Multi[T]) reset(now time.Time, fired bool) { + // if reset isn't in *response* to firing timer may need to be reset + if h.timer != nil && !fired && !h.timer.Stop() { + <-h.timer.Chan() + } + + root := h.heap.Root() + if root == nil { + // no targets, fully reset state to free resources and ensure that we're + // in the expected state if/when new targets are added in the future. + h.timer = nil + h.heap.Clear() + return + } + + d := root.tick.Sub(now) + + if h.timer == nil { + h.timer = h.clock.NewTimer(d) + } else { + h.timer.Reset(d) + } +} + +func (h *Multi[T]) Elapsed() <-chan time.Time { + if h == nil || h.timer == nil { + return nil + } + + return h.timer.Chan() +} + +func (h *Multi[T]) interval(first bool) time.Duration { + ivl := h.fixedInterval + if h.variableInterval != nil { + ivl = h.variableInterval.Duration() + } + + if first && h.firstJitter != nil { + ivl = h.firstJitter(ivl) + } else if h.jitter != nil { + ivl = h.jitter(ivl) + } + + return ivl +} + +// Stop stops the delay. Only needed for Go 1.22 and [clockwork.Clock] +// compatibility. Can be called on a nil delay, as a no-op. The delay should not +// be used afterwards. +func (h *Multi[T]) Stop() { + if h == nil || h.timer == nil { + return + } + + h.timer.Stop() +} diff --git a/lib/inventory/internal/delay/multi_test.go b/lib/inventory/internal/delay/multi_test.go new file mode 100644 index 0000000000000..24c7a04c081fd --- /dev/null +++ b/lib/inventory/internal/delay/multi_test.go @@ -0,0 +1,157 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package delay + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils/interval" +) + +func TestMultiBasics(t *testing.T) { + const interval = time.Millisecond * 20 + t.Parallel() + + multi := NewMulti[int](MultiParams{ + FixedInterval: interval, + }) + + // verify that delay is in an initial state that will never fire + require.Nil(t, multi.Elapsed()) + + for i := 1; i <= 10; i++ { + // add a subinterval + multi.Add(i) + } + + for i := 0; i < 30; i++ { + now := <-multi.Elapsed() + require.Equal(t, i%10+1, multi.Tick(now)) + } + + // remove some subintervals + for i := 1; i <= 8; i++ { + multi.Remove(i) + } + + // verify that remaining subintervals are still being serviced + for i := 0; i < 30; i++ { + k := 10 + if i%2 == 0 { + k = 9 + } + now := <-multi.Elapsed() + require.Equal(t, k, multi.Tick(now)) + } + + multi.Remove(9) + multi.Remove(10) + + // verify complete removal of all sub-intervals + select { + case <-multi.Elapsed(): + t.Fatal("expected no more ticks") + case <-time.After(interval * 3): + } + + // verify that the multi is still usable after having been + // fully drained. 
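+	// (removing the last entry nils out the timer and clears the heap, so Add
+	// lazily re-creates the timer for the new entry)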
+ multi.Add(777) + select { + case now := <-multi.Elapsed(): + require.Equal(t, 777, multi.Tick(now)) + case <-time.After(interval * 3): + t.Fatal("timeout waiting for re-added delay to fire") + } +} + +func TestMultiJitter(t *testing.T) { + t.Parallel() + + var jitterCalled atomic.Bool + fakeJitter := func(d time.Duration) time.Duration { + jitterCalled.Store(true) + return time.Millisecond * 20 + } + + multi := NewMulti[int](MultiParams{ + FixedInterval: time.Hour, + Jitter: fakeJitter, + }) + + for i := 0; i < 10; i++ { + multi.Add(i + 1) + } + + for i := 0; i < 10; i++ { + select { + case now := <-multi.Elapsed(): + multi.Tick(now) + case <-time.After(time.Second * 10): + t.Fatal("timeout waiting for delay to fire") + } + require.True(t, jitterCalled.Swap(false)) + } +} + +func TestMultiVariable(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + start := clock.Now() + + ivl := interval.NewVariableDuration(interval.VariableDurationConfig{ + MinDuration: 2 * time.Minute, + MaxDuration: 4 * time.Minute, + Step: 1, + }) + + // deterministic jitter, always half the actual time + multi := NewMulti[int](MultiParams{ + VariableInterval: ivl, + Jitter: func(d time.Duration) time.Duration { + return d / 2 + }, + clock: clock, + }) + defer multi.Stop() + + multi.Add(1) + + clock.BlockUntil(1) + clock.Advance(time.Minute) + + // this is enough to saturate the VariableDuration, so we are going to hit + // the max duration every time + ivl.Add(100) + + ts := <-multi.Elapsed() + multi.Tick(ts) + require.Equal(t, start.Add(time.Minute), ts) + + clock.BlockUntil(1) + clock.Advance(2 * time.Minute) + + ts = <-multi.Elapsed() + multi.Tick(ts) + require.Equal(t, start.Add(3*time.Minute), ts) +} diff --git a/lib/inventory/inventory.go b/lib/inventory/inventory.go index 2cae348402472..a25d5c6bb4e22 100644 --- a/lib/inventory/inventory.go +++ b/lib/inventory/inventory.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/inventory/internal/delay" "github.com/gravitational/teleport/lib/inventory/metadata" "github.com/gravitational/teleport/lib/utils" vc "github.com/gravitational/teleport/lib/versioncontrol" @@ -577,6 +578,18 @@ type upstreamHandle struct { // kubernetesServers track kubernetesServers server details. kubernetesServers map[resourceKey]*heartBeatInfo[*types.KubernetesServerV3] + + // appKeepAliveDelay is a multi-delay that controls the cadence of app server keepalive + // operations. Note that this is not created automatically by newUpstreamHandle. + appKeepAliveDelay *delay.Multi[resourceKey] + + // dbKeepAliveDelay is a multi-delay that controls the cadence of database server keepalive + // operations. Note that this is not created automatically by newUpstreamHandle. + dbKeepAliveDelay *delay.Multi[resourceKey] + + // kubeKeepAliveDelay is a multi-delay that controls the cadence of kubernetes server keepalive + // operations. Note that this is not created automatically by newUpstreamHandle. + kubeKeepAliveDelay *delay.Multi[resourceKey] } type resourceKey struct {
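The core of this change is that each upstream handle now owns a keyed delay.Multi per resource type instead of a single shared Delay: every firing drives a keepalive for exactly one app/db/kube server (the key returned by Tick) on its own jittered subinterval, so keepalive writes are spread across the interval rather than issued in a burst for all servers on each tick. A minimal sketch of the Add / Elapsed / Tick / Remove contract follows. It assumes the code lives in package inventory (the delay package is internal to lib/inventory); the keys channel and the keepAliveOne callback are hypothetical stand-ins for the controller's heartbeat handling and keepalive logic, not part of this patch.

```go
// Sketch only: resourceKey comes from package inventory; keys and keepAliveOne are
// hypothetical stand-ins for the controller's heartbeat and keepalive plumbing.
package inventory

import (
	"context"
	"time"

	"github.com/gravitational/teleport/api/utils/retryutils"
	"github.com/gravitational/teleport/lib/inventory/internal/delay"
)

func keepAliveLoop(ctx context.Context, keys <-chan resourceKey, keepAliveOne func(resourceKey) error) {
	multi := delay.NewMulti[resourceKey](delay.MultiParams{
		FixedInterval: time.Minute,              // placeholder for c.serverKeepAlive
		FirstJitter:   retryutils.HalfJitter,    // spread the first tick of a newly added resource
		Jitter:        retryutils.SeventhJitter, // jitter every subsequent tick
	})
	defer multi.Stop()

	for {
		select {
		case key := <-keys:
			// first heartbeat seen for this resource: start its keepalive subinterval
			multi.Add(key)
		case now := <-multi.Elapsed():
			// exactly one Tick per observed firing; it returns the key whose
			// subinterval elapsed, so only that resource is kept alive
			key := multi.Tick(now)
			if err := keepAliveOne(key); err != nil {
				// hypothetical policy: stop scheduling keepalives for this resource
				multi.Remove(key)
			}
		case <-ctx.Done():
			return
		}
	}
}
```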