diff --git a/internal/component/database_observability/connection_info_monitor.go b/internal/component/database_observability/connection_info_monitor.go new file mode 100644 index 00000000000..3eed3b7f219 --- /dev/null +++ b/internal/component/database_observability/connection_info_monitor.go @@ -0,0 +1,87 @@ +package database_observability + +import ( + "context" + "database/sql" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +// ConnectionCheckInterval is how often the connection_info collector pings the DB to verify connectivity. +const ConnectionCheckInterval = 60 * time.Second + +// ConnectionChecksThreshold is the number of consecutive failed pings before unregistering the metric, +// and the number of consecutive successful pings before re-registering it after a disconnect. +const ConnectionChecksThreshold = 3 + +// ConnectionInfoMonitorConfig optionally overrides the default check interval and threshold. +// Used by tests to run the monitor with shorter intervals. If nil, defaults are used. +type ConnectionInfoMonitorConfig struct { + CheckInterval time.Duration + ChecksThreshold int +} + +// RunConnectionInfoMonitor starts a goroutine that pings db every ConnectionCheckInterval. +// After ConnectionChecksThreshold consecutive ping failures it unregisters infoMetric from registry. +// After ConnectionChecksThreshold consecutive ping successes (when the metric is unregistered) it re-registers +// infoMetric and sets it to 1 with the given labelValues. +// The goroutine runs until ctx is done. onStopped is called when the goroutine exits (e.g. when ctx is cancelled). +// RunConnectionInfoMonitor returns a cancel function that cancels the context passed to the goroutine; the caller +// should call cancel in Stop() to ensure the goroutine exits. +// labelValues must contain exactly 6 values in order: provider_name, provider_region, provider_account, +// db_instance_identifier, engine, engine_version. +// If config is non-nil, its CheckInterval and ChecksThreshold override the default constants (used for testing). +func RunConnectionInfoMonitor(ctx context.Context, db *sql.DB, registry *prometheus.Registry, infoMetric *prometheus.GaugeVec, labelValues []string, onStopped func(), config *ConnectionInfoMonitorConfig) (cancel context.CancelFunc) { + interval := ConnectionCheckInterval + threshold := ConnectionChecksThreshold + if config != nil { + if config.CheckInterval > 0 { + interval = config.CheckInterval + } + if config.ChecksThreshold > 0 { + threshold = config.ChecksThreshold + } + } + ctx, cancel = context.WithCancel(ctx) + go func() { + defer onStopped() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + var consecutiveFailures, consecutiveSuccesses int + metricRegistered := true + for { + if err := db.PingContext(ctx); err != nil { + consecutiveFailures++ + consecutiveSuccesses = 0 + if metricRegistered && consecutiveFailures >= threshold { + registry.Unregister(infoMetric) + metricRegistered = false + consecutiveFailures = 0 + } + } else { + consecutiveFailures = 0 + if metricRegistered { + consecutiveSuccesses = 0 + } else { + consecutiveSuccesses++ + if consecutiveSuccesses >= threshold { + registry.MustRegister(infoMetric) + infoMetric.WithLabelValues(labelValues...).Set(1) + metricRegistered = true + consecutiveSuccesses = 0 + } + } + } + + select { + case <-ctx.Done(): + return + case <-ticker.C: + // continue loop + } + } + }() + return cancel +} diff --git a/internal/component/database_observability/connection_info_monitor_test.go b/internal/component/database_observability/connection_info_monitor_test.go new file mode 100644 index 00000000000..6f71d030d99 --- /dev/null +++ b/internal/component/database_observability/connection_info_monitor_test.go @@ -0,0 +1,219 @@ +package database_observability + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +const testCheckInterval = 15 * time.Millisecond +const testThreshold = 3 + +func TestRunConnectionInfoMonitor_UnregistersAfterConsecutiveFailures(t *testing.T) { + defer goleak.VerifyNone(t) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + registry := prometheus.NewRegistry() + infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "database_observability", + Name: "connection_info", + Help: "Information about the connection", + }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) + require.NoError(t, registry.Register(infoMetric)) + + labelValues := []string{"aws", "us-east-1", "123456789", "my-db", "postgres", "15.0"} + infoMetric.WithLabelValues(labelValues...).Set(1) + + // Expect 3 pings, all failing + pingErr := errors.New("connection refused") + for i := 0; i < testThreshold; i++ { + mock.ExpectPing().WillReturnError(pingErr) + } + + ctx := context.Background() + onStopped := func() {} + config := &ConnectionInfoMonitorConfig{ + CheckInterval: testCheckInterval, + ChecksThreshold: testThreshold, + } + cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config) + defer cancel() + + // Wait for at least 3 tick intervals so the monitor performs 3 failed pings and unregisters + time.Sleep(testCheckInterval*time.Duration(testThreshold) + 20*time.Millisecond) + + // Metric should have been unregistered (not present in gather) + metrics, err := registry.Gather() + require.NoError(t, err) + var found bool + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.False(t, found, "metric should be unregistered after %d consecutive ping failures", testThreshold) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestRunConnectionInfoMonitor_ReregistersAfterConsecutiveSuccesses(t *testing.T) { + defer goleak.VerifyNone(t) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + registry := prometheus.NewRegistry() + infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "database_observability", + Name: "connection_info", + Help: "Information about the connection", + }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) + require.NoError(t, registry.Register(infoMetric)) + + labelValues := []string{"aws", "us-east-1", "123456789", "my-db", "mysql", "8.0.32"} + infoMetric.WithLabelValues(labelValues...).Set(1) + + // First 3 pings fail (metric gets unregistered), then many succeed (metric re-registers and stays up). + // Extra success expectations prevent sqlmock from returning errors for pings that occur while + // require.Eventually polls, which would re-trigger the failure threshold and unregister again. + pingErr := errors.New("connection refused") + for i := 0; i < testThreshold; i++ { + mock.ExpectPing().WillReturnError(pingErr) + } + for i := 0; i < 30; i++ { + mock.ExpectPing() + } + + ctx := context.Background() + onStopped := func() {} + config := &ConnectionInfoMonitorConfig{ + CheckInterval: testCheckInterval, + ChecksThreshold: testThreshold, + } + cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config) + defer cancel() + + // Poll until the metric is re-registered rather than sleeping a fixed duration, which is + // unreliable: extra pings after the mock expectations are exhausted return errors, causing + // the failure threshold to be hit again and the metric to be unregistered before we check. + var mf *dto.MetricFamily + require.Eventually(t, func() bool { + metrics, err := registry.Gather() + if err != nil { + return false + } + for _, m := range metrics { + if m.GetName() == "database_observability_connection_info" { + mf = m + return true + } + } + return false + }, 2*time.Second, testCheckInterval, "metric should be re-registered after %d consecutive successes", testThreshold) + + require.Len(t, mf.Metric, 1, "metric should have one series when present") + require.Equal(t, float64(1), mf.Metric[0].GetGauge().GetValue()) +} + +func TestRunConnectionInfoMonitor_MetricRemainsRegisteredWhilePingsSucceed(t *testing.T) { + defer goleak.VerifyNone(t) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + registry := prometheus.NewRegistry() + infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "database_observability", + Name: "connection_info", + Help: "Information about the connection", + }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) + require.NoError(t, registry.Register(infoMetric)) + + labelValues := []string{"unknown", "unknown", "unknown", "unknown", "postgres", "15.0"} + infoMetric.WithLabelValues(labelValues...).Set(1) + + // All pings succeed (allow at least 4 successful pings) + for i := 0; i < 4; i++ { + mock.ExpectPing() + } + + ctx := context.Background() + onStopped := func() {} + config := &ConnectionInfoMonitorConfig{ + CheckInterval: testCheckInterval, + ChecksThreshold: testThreshold, + } + cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config) + defer cancel() + + // Wait for a few tick intervals + time.Sleep(testCheckInterval*4 + 20*time.Millisecond) + + // Metric should still be registered with value 1 + metrics, err := registry.Gather() + require.NoError(t, err) + var mf *dto.MetricFamily + for _, m := range metrics { + if m.GetName() == "database_observability_connection_info" { + mf = m + break + } + } + require.NotNil(t, mf, "metric should remain registered while pings succeed") + require.Len(t, mf.Metric, 1) + require.Equal(t, float64(1), mf.Metric[0].GetGauge().GetValue()) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestRunConnectionInfoMonitor_CancelStopsGoroutine(t *testing.T) { + defer goleak.VerifyNone(t) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + registry := prometheus.NewRegistry() + infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "database_observability", + Name: "connection_info", + Help: "Information about the connection", + }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) + require.NoError(t, registry.Register(infoMetric)) + + labelValues := []string{"a", "b", "c", "d", "e", "f"} + infoMetric.WithLabelValues(labelValues...).Set(1) + + mock.ExpectPing() // at most one ping before we cancel + + ctx := context.Background() + stopped := make(chan struct{}) + onStopped := func() { close(stopped) } + config := &ConnectionInfoMonitorConfig{ + CheckInterval: testCheckInterval, + ChecksThreshold: testThreshold, + } + cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config) + + // Cancel immediately; onStopped should be called when the goroutine exits + cancel() + select { + case <-stopped: + // goroutine exited + case <-time.After(2 * time.Second): + t.Fatal("onStopped was not called after cancel") + } +} diff --git a/internal/component/database_observability/mysql/collector/connection_info.go b/internal/component/database_observability/mysql/collector/connection_info.go index d28e2b39050..01d5e4a690a 100644 --- a/internal/component/database_observability/mysql/collector/connection_info.go +++ b/internal/component/database_observability/mysql/collector/connection_info.go @@ -2,6 +2,7 @@ package collector import ( "context" + "database/sql" "net" "strings" @@ -18,6 +19,7 @@ type ConnectionInfoArguments struct { Registry *prometheus.Registry EngineVersion string CloudProvider *database_observability.CloudProvider + DB *sql.DB } type ConnectionInfo struct { @@ -26,8 +28,10 @@ type ConnectionInfo struct { EngineVersion string InfoMetric *prometheus.GaugeVec CloudProvider *database_observability.CloudProvider + dbConnection *sql.DB running *atomic.Bool + cancel context.CancelFunc } func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { @@ -45,6 +49,7 @@ func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { EngineVersion: args.EngineVersion, InfoMetric: infoMetric, CloudProvider: args.CloudProvider, + dbConnection: args.DB, running: &atomic.Bool{}, }, nil } @@ -105,7 +110,21 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { } c.running.Store(true) - c.InfoMetric.WithLabelValues(providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, c.EngineVersion).Set(1) + labelValues := []string{providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, c.EngineVersion} + c.InfoMetric.WithLabelValues(labelValues...).Set(1) + + if c.dbConnection != nil { + c.cancel = database_observability.RunConnectionInfoMonitor( + ctx, + c.dbConnection, + c.Registry, + c.InfoMetric, + labelValues, + func() { c.running.Store(false) }, + nil, + ) + } + return nil } @@ -114,6 +133,9 @@ func (c *ConnectionInfo) Stopped() bool { } func (c *ConnectionInfo) Stop() { + if c.cancel != nil { + c.cancel() + } c.Registry.Unregister(c.InfoMetric) c.running.Store(false) } diff --git a/internal/component/database_observability/mysql/collector/connection_info_test.go b/internal/component/database_observability/mysql/collector/connection_info_test.go index a3b91ec9958..b3db6be2045 100644 --- a/internal/component/database_observability/mysql/collector/connection_info_test.go +++ b/internal/component/database_observability/mysql/collector/connection_info_test.go @@ -4,7 +4,9 @@ import ( "fmt" "strings" "testing" + "time" + sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" @@ -92,6 +94,7 @@ func TestConnectionInfo(t *testing.T) { Registry: reg, EngineVersion: tc.engineVersion, CloudProvider: tc.cloudProvider, + DB: nil, // no DB in tests: goroutine not started, metric stays set }) require.NoError(t, err) require.NotNil(t, collector) @@ -103,3 +106,100 @@ func TestConnectionInfo(t *testing.T) { require.NoError(t, err) } } + +func TestConnectionInfo_StopUnregistersMetric(t *testing.T) { + defer goleak.VerifyNone(t) + + reg := prometheus.NewRegistry() + col, err := NewConnectionInfo(ConnectionInfoArguments{ + DSN: "user:pass@tcp(localhost:3306)/schema", + Registry: reg, + EngineVersion: "8.0.32", + DB: nil, + }) + require.NoError(t, err) + + err = col.Start(t.Context()) + require.NoError(t, err) + + // metric is present after Start + metrics, err := reg.Gather() + require.NoError(t, err) + var found bool + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.True(t, found, "metric should be registered after Start") + + col.Stop() + require.True(t, col.Stopped()) + + // metric is absent after Stop + metrics, err = reg.Gather() + require.NoError(t, err) + found = false + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.False(t, found, "metric should be unregistered after Stop") +} + +func TestConnectionInfo_MonitorStartedWithDB(t *testing.T) { + defer goleak.VerifyNone(t) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + // Allow at least one ping before we cancel + mock.ExpectPing() + + reg := prometheus.NewRegistry() + col, err := NewConnectionInfo(ConnectionInfoArguments{ + DSN: "user:pass@tcp(localhost:3306)/schema", + Registry: reg, + EngineVersion: "8.0.32", + DB: db, + }) + require.NoError(t, err) + + err = col.Start(t.Context()) + require.NoError(t, err) + require.False(t, col.Stopped()) + + // Metric is set immediately on Start + metrics, err := reg.Gather() + require.NoError(t, err) + var found bool + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.True(t, found, "metric should be registered after Start with DB") + + // Give the monitor goroutine time to perform at least one ping + time.Sleep(50 * time.Millisecond) + + col.Stop() + require.True(t, col.Stopped()) + + // Metric is unregistered after Stop + metrics, err = reg.Gather() + require.NoError(t, err) + found = false + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.False(t, found, "metric should be unregistered after Stop") +} diff --git a/internal/component/database_observability/mysql/component.go b/internal/component/database_observability/mysql/component.go index 24252d55123..c828b780ebf 100644 --- a/internal/component/database_observability/mysql/component.go +++ b/internal/component/database_observability/mysql/component.go @@ -685,6 +685,7 @@ func (c *Component) startCollectors(serverID string, engineVersion string, parse Registry: c.registry, EngineVersion: engineVersion, CloudProvider: cloudProviderInfo, + DB: c.dbConnection, }) if err != nil { logStartError(collector.ConnectionInfoName, "create", err) diff --git a/internal/component/database_observability/postgres/collector/connection_info.go b/internal/component/database_observability/postgres/collector/connection_info.go index b64fa3aafe0..70d8eb10bab 100644 --- a/internal/component/database_observability/postgres/collector/connection_info.go +++ b/internal/component/database_observability/postgres/collector/connection_info.go @@ -2,6 +2,7 @@ package collector import ( "context" + "database/sql" "regexp" "strings" @@ -19,6 +20,7 @@ type ConnectionInfoArguments struct { Registry *prometheus.Registry EngineVersion string CloudProvider *database_observability.CloudProvider + DB *sql.DB } type ConnectionInfo struct { @@ -27,8 +29,10 @@ type ConnectionInfo struct { EngineVersion string InfoMetric *prometheus.GaugeVec CloudProvider *database_observability.CloudProvider + dbConnection *sql.DB running *atomic.Bool + cancel context.CancelFunc } func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { @@ -46,6 +50,7 @@ func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { EngineVersion: args.EngineVersion, InfoMetric: infoMetric, CloudProvider: args.CloudProvider, + dbConnection: args.DB, running: &atomic.Bool{}, }, nil } @@ -111,7 +116,21 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { c.running.Store(true) - c.InfoMetric.WithLabelValues(providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, engineVersion).Set(1) + labelValues := []string{providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, engineVersion} + c.InfoMetric.WithLabelValues(labelValues...).Set(1) + + if c.dbConnection != nil { + c.cancel = database_observability.RunConnectionInfoMonitor( + ctx, + c.dbConnection, + c.Registry, + c.InfoMetric, + labelValues, + func() { c.running.Store(false) }, + nil, + ) + } + return nil } @@ -120,6 +139,9 @@ func (c *ConnectionInfo) Stopped() bool { } func (c *ConnectionInfo) Stop() { + if c.cancel != nil { + c.cancel() + } c.Registry.Unregister(c.InfoMetric) c.running.Store(false) } diff --git a/internal/component/database_observability/postgres/collector/connection_info_test.go b/internal/component/database_observability/postgres/collector/connection_info_test.go index 4cd76a02608..3217368be05 100644 --- a/internal/component/database_observability/postgres/collector/connection_info_test.go +++ b/internal/component/database_observability/postgres/collector/connection_info_test.go @@ -4,7 +4,9 @@ import ( "fmt" "strings" "testing" + "time" + sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" @@ -94,6 +96,7 @@ func TestConnectionInfo(t *testing.T) { Registry: reg, EngineVersion: tc.engineVersion, CloudProvider: tc.cloudProvider, + DB: nil, // no DB in tests: goroutine not started, metric stays set }) require.NoError(t, err) require.NotNil(t, collector) @@ -105,3 +108,104 @@ func TestConnectionInfo(t *testing.T) { require.NoError(t, err) } } + +func TestConnectionInfo_StopUnregistersMetric(t *testing.T) { + // The goroutine which deletes expired entries runs indefinitely, + // see https://github.com/hashicorp/golang-lru/blob/v2.0.7/expirable/expirable_lru.go#L79-L80 + defer goleak.VerifyNone(t, goleak.IgnoreTopFunction("github.com/hashicorp/golang-lru/v2/expirable.NewLRU[...].func1")) + + reg := prometheus.NewRegistry() + col, err := NewConnectionInfo(ConnectionInfoArguments{ + DSN: "postgres://user:pass@localhost:5432/mydb", + Registry: reg, + EngineVersion: "15.4", + DB: nil, + }) + require.NoError(t, err) + + err = col.Start(t.Context()) + require.NoError(t, err) + + // metric is present after Start + metrics, err := reg.Gather() + require.NoError(t, err) + var found bool + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.True(t, found, "metric should be registered after Start") + + col.Stop() + require.True(t, col.Stopped()) + + // metric is absent after Stop + metrics, err = reg.Gather() + require.NoError(t, err) + found = false + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.False(t, found, "metric should be unregistered after Stop") +} + +func TestConnectionInfo_MonitorStartedWithDB(t *testing.T) { + // The goroutine which deletes expired entries runs indefinitely, + // see https://github.com/hashicorp/golang-lru/blob/v2.0.7/expirable/expirable_lru.go#L79-L80 + defer goleak.VerifyNone(t, goleak.IgnoreTopFunction("github.com/hashicorp/golang-lru/v2/expirable.NewLRU[...].func1")) + + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + // Allow at least one ping before we cancel + mock.ExpectPing() + + reg := prometheus.NewRegistry() + col, err := NewConnectionInfo(ConnectionInfoArguments{ + DSN: "postgres://user:pass@localhost:5432/mydb", + Registry: reg, + EngineVersion: "15.4", + DB: db, + }) + require.NoError(t, err) + + err = col.Start(t.Context()) + require.NoError(t, err) + require.False(t, col.Stopped()) + + // Metric is set immediately on Start + metrics, err := reg.Gather() + require.NoError(t, err) + var found bool + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.True(t, found, "metric should be registered after Start with DB") + + // Give the monitor goroutine time to perform at least one ping + time.Sleep(50 * time.Millisecond) + + col.Stop() + require.True(t, col.Stopped()) + + // Metric is unregistered after Stop + metrics, err = reg.Gather() + require.NoError(t, err) + found = false + for _, mf := range metrics { + if mf.GetName() == "database_observability_connection_info" { + found = true + break + } + } + require.False(t, found, "metric should be unregistered after Stop") +} diff --git a/internal/component/database_observability/postgres/component.go b/internal/component/database_observability/postgres/component.go index 1831adc6840..c3cadc7aa82 100644 --- a/internal/component/database_observability/postgres/component.go +++ b/internal/component/database_observability/postgres/component.go @@ -512,6 +512,7 @@ func (c *Component) startCollectors(systemID string, engineVersion string, cloud Registry: c.registry, EngineVersion: engineVersion, CloudProvider: cloudProviderInfo, + DB: c.dbConnection, }) if err != nil { logStartError(collector.ConnectionInfoName, "create", err)