diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 372253e950e44..bea574ef75039 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -516,9 +516,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s) } }), - inventory.WithOnDisconnect(func(s string) { + inventory.WithOnDisconnect(func(s string, c int) { if g, ok := connectedResourceGauges[s]; ok { - g.Dec() + g.Sub(float64(c)) } else { log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s) } diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go index 92220597b0e9a..71868b6da189a 100644 --- a/lib/inventory/controller.go +++ b/lib/inventory/controller.go @@ -131,7 +131,7 @@ type controllerOptions struct { maxKeepAliveErrs int authID string onConnectFunc func(string) - onDisconnectFunc func(string) + onDisconnectFunc func(string, int) } func (options *controllerOptions) SetDefaults() { @@ -153,11 +153,11 @@ func (options *controllerOptions) SetDefaults() { } if options.onConnectFunc == nil { - options.onConnectFunc = func(s string) {} + options.onConnectFunc = func(string) {} } if options.onDisconnectFunc == nil { - options.onDisconnectFunc = func(s string) {} + options.onDisconnectFunc = func(string, int) {} } } @@ -180,12 +180,12 @@ func WithOnConnect(f func(heartbeatKind string)) ControllerOption { } } -// WithOnDisconnect sets a function to be called every time an existing -// instance disconnects from the inventory control stream. The value -// provided to the callback is the keep alive type of the disconnected -// resource. The callback should return quickly so as not to prevent -// processing of heartbeats. -func WithOnDisconnect(f func(heartbeatKind string)) ControllerOption { +// WithOnDisconnect sets a function to be called every time an existing instance +// disconnects from the inventory control stream. The values provided to the +// callback are the keep alive type of the disconnected resource, as well as a +// count of how many resources disconnected at once. The callback should return +// quickly so as not to prevent processing of heartbeats. +func WithOnDisconnect(f func(heartbeatKind string, amount int)) ControllerOption { return func(opts *controllerOptions) { opts.onDisconnectFunc = f } @@ -226,7 +226,7 @@ type Controller struct { usageReporter usagereporter.UsageReporter testEvents chan testEvent onConnectFunc func(string) - onDisconnectFunc func(string) + onDisconnectFunc func(string, int) closeContext context.Context cancel context.CancelFunc } @@ -351,9 +351,10 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { defer func() { if handle.goodbye.GetDeleteResources() { log.WithFields(log.Fields{ - "apps": len(handle.appServers), - "dbs": len(handle.databaseServers), - "kube": len(handle.kubernetesServers), + "apps": len(handle.appServers), + "dbs": len(handle.databaseServers), + "kube": len(handle.kubernetesServers), + "server_id": handle.Hello().ServerID, }).Debug("Cleaning up resources in response to instance termination") for _, app := range handle.appServers { if err := c.auth.DeleteApplicationServer(c.closeContext, apidefaults.Namespace, app.resource.GetHostID(), app.resource.GetName()); err != nil && !trace.IsNotFound(err) { @@ -383,19 +384,19 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { handle.ticker.Stop() if handle.sshServer != nil { - c.onDisconnectFunc(constants.KeepAliveNode) + c.onDisconnectFunc(constants.KeepAliveNode, 1) } - for range handle.appServers { - c.onDisconnectFunc(constants.KeepAliveApp) + if len(handle.appServers) > 0 { + c.onDisconnectFunc(constants.KeepAliveApp, len(handle.appServers)) } - for range handle.databaseServers { - c.onDisconnectFunc(constants.KeepAliveDatabase) + if len(handle.databaseServers) > 0 { + c.onDisconnectFunc(constants.KeepAliveDatabase, len(handle.databaseServers)) } - for range handle.kubernetesServers { - c.onDisconnectFunc(constants.KeepAliveKube) + if len(handle.kubernetesServers) > 0 { + c.onDisconnectFunc(constants.KeepAliveKube, len(handle.kubernetesServers)) } clear(handle.appServers) @@ -845,6 +846,7 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e if shouldRemove { c.testEvent(appKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveApp, 1) delete(handle.appServers, name) } } else { @@ -887,6 +889,7 @@ func (c *Controller) keepAliveDatabaseServer(handle *upstreamHandle, now time.Ti if shouldRemove { c.testEvent(dbKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveDatabase, 1) delete(handle.databaseServers, name) } } else { @@ -929,6 +932,7 @@ func (c *Controller) keepAliveKubernetesServer(handle *upstreamHandle, now time. if shouldRemove { c.testEvent(kubeKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveKube, 1) delete(handle.kubernetesServers, name) } } else { diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go index d9c40e279a4ec..ba16edf66994f 100644 --- a/lib/inventory/controller_test.go +++ b/lib/inventory/controller_test.go @@ -176,11 +176,14 @@ func TestSSHServerBasics(t *testing.T) { expectAddr: wantAddr, } + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -314,6 +317,9 @@ func TestSSHServerBasics(t *testing.T) { // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + // verify that metrics have been updated correctly + require.Zero(t, 0, rc.count()) + // verify that the peer address of the control stream was used to override // zero-value IPs for heartbeats. auth.mu.Lock() @@ -337,11 +343,14 @@ func TestAppServerBasics(t *testing.T) { auth := &fakeAuth{} + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -532,6 +541,9 @@ func TestAppServerBasics(t *testing.T) { // always *before* closure is propagated to downstream handle, hence being safe to load // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + + // verify that metrics have been updated correctly + require.Zero(t, rc.count()) } // TestDatabaseServerBasics verifies basic expected behaviors for a single control stream heartbeating @@ -549,11 +561,14 @@ func TestDatabaseServerBasics(t *testing.T) { auth := &fakeAuth{} + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -745,6 +760,9 @@ func TestDatabaseServerBasics(t *testing.T) { // always *before* closure is propagated to downstream handle, hence being safe to load // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + + // verify that metrics have been updated correctly + require.Zero(t, rc.count()) } // TestInstanceHeartbeat verifies basic expected behaviors for instance heartbeat. @@ -1154,11 +1172,14 @@ func TestKubernetesServerBasics(t *testing.T) { auth := &fakeAuth{} + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -1354,10 +1375,12 @@ func TestKubernetesServerBasics(t *testing.T) { // always *before* closure is propagated to downstream handle, hence being safe to load // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + + // verify that metrics have been updated correctly + require.Zero(t, rc.count()) } func TestGetSender(t *testing.T) { - controller := NewController( &fakeAuth{}, usagereporter.DiscardUsageReporter{}, @@ -1468,3 +1491,37 @@ func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) { } } } + +type resourceCounter struct { + mu sync.Mutex + c map[string]int +} + +func (r *resourceCounter) onConnect(typ string) { + r.mu.Lock() + defer r.mu.Unlock() + if r.c == nil { + r.c = make(map[string]int) + } + r.c[typ]++ +} + +func (r *resourceCounter) onDisconnect(typ string, amount int) { + r.mu.Lock() + defer r.mu.Unlock() + if r.c == nil { + r.c = make(map[string]int) + } + r.c[typ] -= amount +} + +func (r *resourceCounter) count() int { + r.mu.Lock() + defer r.mu.Unlock() + + var count int + for _, v := range r.c { + count += v + } + return count +}