Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package database_observability

import (
"context"
"database/sql"
"time"

"github.com/prometheus/client_golang/prometheus"
)

const (
	// ConnectionCheckInterval is the default period between two connectivity pings issued by the
	// connection_info monitor.
	ConnectionCheckInterval = time.Minute

	// ConnectionChecksThreshold is the default streak length used in both directions: this many
	// consecutive failed pings unregister the metric, and this many consecutive successful pings
	// (while it is unregistered) register it again.
	ConnectionChecksThreshold = 3
)

// ConnectionInfoMonitorConfig carries optional overrides for the monitor's timing parameters.
// It exists mainly so tests can run the monitor with short intervals; passing a nil
// *ConnectionInfoMonitorConfig to RunConnectionInfoMonitor selects the package defaults.
type ConnectionInfoMonitorConfig struct {
	// CheckInterval replaces ConnectionCheckInterval when it is greater than zero.
	CheckInterval time.Duration
	// ChecksThreshold replaces ConnectionChecksThreshold when it is greater than zero.
	ChecksThreshold int
}

// RunConnectionInfoMonitor starts a goroutine that pings db once immediately and then once per
// check interval (ConnectionCheckInterval unless overridden by config).
//
// While infoMetric is registered, a streak of `threshold` consecutive ping failures
// (ConnectionChecksThreshold unless overridden) unregisters it from registry. While it is
// unregistered, a streak of `threshold` consecutive ping successes re-registers it and sets the
// series identified by labelValues to 1. The metric is assumed to be registered when the monitor
// starts.
//
// Parameters:
//   - ctx: parent context; the goroutine runs until ctx (or the returned cancel) is done.
//   - labelValues: must match infoMetric's label names in number and order (the callers in this
//     package use provider_name, provider_region, provider_account, db_instance_identifier, engine,
//     engine_version); a mismatch makes WithLabelValues panic.
//   - onStopped: invoked when the goroutine exits; may be nil.
//   - config: optional overrides; nil or non-positive fields fall back to the defaults.
//
// The returned cancel function cancels the derived context handed to the goroutine; callers should
// invoke it from their Stop() so the goroutine exits.
//
// Once the context is cancelled the monitor stops touching the registry: a ping result observed
// after cancellation is discarded instead of being counted. A failure to re-register the metric is
// not fatal; registration is retried after the next successful ping.
func RunConnectionInfoMonitor(ctx context.Context, db *sql.DB, registry *prometheus.Registry, infoMetric *prometheus.GaugeVec, labelValues []string, onStopped func(), config *ConnectionInfoMonitorConfig) (cancel context.CancelFunc) {
	interval := ConnectionCheckInterval
	threshold := ConnectionChecksThreshold
	if config != nil {
		if config.CheckInterval > 0 {
			interval = config.CheckInterval
		}
		if config.ChecksThreshold > 0 {
			threshold = config.ChecksThreshold
		}
	}

	ctx, cancel = context.WithCancel(ctx)
	go func() {
		// Tolerate a nil callback instead of panicking when the goroutine unwinds.
		if onStopped != nil {
			defer onStopped()
		}
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		var consecutiveFailures, consecutiveSuccesses int
		metricRegistered := true
		for {
			pingErr := db.PingContext(ctx)
			if ctx.Err() != nil {
				// Shutting down: the ping outcome is an artifact of cancellation (or is stale), and
				// the owner is responsible for the registry from here on, so do not count it.
				return
			}

			switch {
			case pingErr != nil:
				consecutiveSuccesses = 0
				consecutiveFailures++
				if metricRegistered && consecutiveFailures >= threshold {
					registry.Unregister(infoMetric)
					metricRegistered = false
					consecutiveFailures = 0
				}
			case metricRegistered:
				// Healthy and already exposed: nothing to do but clear both streaks.
				consecutiveFailures = 0
				consecutiveSuccesses = 0
			default:
				consecutiveFailures = 0
				consecutiveSuccesses++
				if consecutiveSuccesses >= threshold {
					// Register (not MustRegister): a panic here would take down the whole process
					// from a background goroutine.
					regErr := registry.Register(infoMetric)
					_, alreadyRegistered := regErr.(prometheus.AlreadyRegisteredError)
					if regErr == nil || alreadyRegistered {
						infoMetric.WithLabelValues(labelValues...).Set(1)
						metricRegistered = true
						consecutiveSuccesses = 0
					}
					// On any other error the success streak is left intact so that registration is
					// attempted again after the next successful ping.
				}
			}

			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// next check
			}
		}
	}()
	return cancel
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package database_observability

import (
"context"
"errors"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)

// Shortened monitor parameters so the tests complete in milliseconds rather than minutes.
const (
	// testCheckInterval is the ping period handed to the monitor under test.
	testCheckInterval = 15 * time.Millisecond
	// testThreshold is the failure/success streak length handed to the monitor under test.
	testThreshold = 3
)

// TestRunConnectionInfoMonitor_UnregistersAfterConsecutiveFailures checks that the metric disappears
// from the registry once testThreshold consecutive pings have failed.
func TestRunConnectionInfoMonitor_UnregistersAfterConsecutiveFailures(t *testing.T) {
	defer goleak.VerifyNone(t)

	db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
	require.NoError(t, err)
	defer db.Close()

	registry := prometheus.NewRegistry()
	infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "database_observability",
		Name:      "connection_info",
		Help:      "Information about the connection",
	}, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"})
	require.NoError(t, registry.Register(infoMetric))

	labelValues := []string{"aws", "us-east-1", "123456789", "my-db", "postgres", "15.0"}
	infoMetric.WithLabelValues(labelValues...).Set(1)

	// Expect testThreshold pings, all failing. Any further ping finds no expectation and fails too,
	// so the metric stays unregistered once the threshold has been reached.
	pingErr := errors.New("connection refused")
	for i := 0; i < testThreshold; i++ {
		mock.ExpectPing().WillReturnError(pingErr)
	}

	ctx := context.Background()
	onStopped := func() {}
	config := &ConnectionInfoMonitorConfig{
		CheckInterval:   testCheckInterval,
		ChecksThreshold: testThreshold,
	}
	cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config)
	defer cancel()

	// Poll until the metric is gone instead of sleeping for a fixed duration: a fixed sleep makes the
	// test depend on scheduler timing and flake on slow or heavily loaded machines.
	require.Eventually(t, func() bool {
		families, gatherErr := registry.Gather()
		if gatherErr != nil {
			return false
		}
		for _, mf := range families {
			if mf.GetName() == "database_observability_connection_info" {
				return false
			}
		}
		return true
	}, 2*time.Second, testCheckInterval, "metric should be unregistered after %d consecutive ping failures", testThreshold)

	// The metric is only unregistered after all queued failing pings were consumed.
	require.NoError(t, mock.ExpectationsWereMet())
}

// TestRunConnectionInfoMonitor_ReregistersAfterConsecutiveSuccesses drives the monitor through an
// outage (testThreshold failing pings) followed by recovery, and checks that the metric comes back
// with a single series whose value is 1.
func TestRunConnectionInfoMonitor_ReregistersAfterConsecutiveSuccesses(t *testing.T) {
	defer goleak.VerifyNone(t)

	db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
	require.NoError(t, err)
	defer db.Close()

	gaugeOpts := prometheus.GaugeOpts{
		Namespace: "database_observability",
		Name:      "connection_info",
		Help:      "Information about the connection",
	}
	labelNames := []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}
	infoMetric := prometheus.NewGaugeVec(gaugeOpts, labelNames)
	registry := prometheus.NewRegistry()
	require.NoError(t, registry.Register(infoMetric))

	labelValues := []string{"aws", "us-east-1", "123456789", "my-db", "mysql", "8.0.32"}
	infoMetric.WithLabelValues(labelValues...).Set(1)

	// Outage phase: testThreshold failing pings take the metric out of the registry.
	pingErr := errors.New("connection refused")
	for remaining := testThreshold; remaining > 0; remaining-- {
		mock.ExpectPing().WillReturnError(pingErr)
	}
	// Recovery phase: queue many more successes than strictly needed. Once expectations run out
	// sqlmock answers pings with an error, which would push the monitor back over the failure
	// threshold while require.Eventually is still polling.
	for queued := 0; queued < 30; queued++ {
		mock.ExpectPing()
	}

	config := &ConnectionInfoMonitorConfig{
		CheckInterval:   testCheckInterval,
		ChecksThreshold: testThreshold,
	}
	cancel := RunConnectionInfoMonitor(context.Background(), db, registry, infoMetric, labelValues, func() {}, config)
	defer cancel()

	// lookup returns the connection_info family if it is currently exposed, nil otherwise.
	lookup := func() *dto.MetricFamily {
		families, gatherErr := registry.Gather()
		if gatherErr != nil {
			return nil
		}
		for _, family := range families {
			if family.GetName() == "database_observability_connection_info" {
				return family
			}
		}
		return nil
	}

	// Poll rather than sleep: the moment of re-registration depends on ticker scheduling.
	var mf *dto.MetricFamily
	require.Eventually(t, func() bool {
		mf = lookup()
		return mf != nil
	}, 2*time.Second, testCheckInterval, "metric should be re-registered after %d consecutive successes", testThreshold)

	require.Len(t, mf.Metric, 1, "metric should have one series when present")
	require.Equal(t, float64(1), mf.Metric[0].GetGauge().GetValue())
}

// TestRunConnectionInfoMonitor_MetricRemainsRegisteredWhilePingsSucceed checks that a healthy
// connection never causes the metric to be removed or its value to change.
func TestRunConnectionInfoMonitor_MetricRemainsRegisteredWhilePingsSucceed(t *testing.T) {
	defer goleak.VerifyNone(t)

	db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
	require.NoError(t, err)
	defer db.Close()

	registry := prometheus.NewRegistry()
	infoMetric := prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "database_observability",
		Name:      "connection_info",
		Help:      "Information about the connection",
	}, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"})
	require.NoError(t, registry.Register(infoMetric))

	labelValues := []string{"unknown", "unknown", "unknown", "unknown", "postgres", "15.0"}
	infoMetric.WithLabelValues(labelValues...).Set(1)

	// Queue far more successful pings than the monitor can consume during the observation window.
	// With only a handful queued, sqlmock starts failing pings as soon as expectations run out, and
	// a slightly late Gather (a few milliseconds of scheduling delay) would then observe the metric
	// already unregistered by the failure threshold.
	const queuedPings = 100
	for i := 0; i < queuedPings; i++ {
		mock.ExpectPing()
	}

	ctx := context.Background()
	onStopped := func() {}
	config := &ConnectionInfoMonitorConfig{
		CheckInterval:   testCheckInterval,
		ChecksThreshold: testThreshold,
	}
	cancel := RunConnectionInfoMonitor(ctx, db, registry, infoMetric, labelValues, onStopped, config)
	defer cancel()

	// Let the monitor run for a few tick intervals (more pings than the threshold).
	time.Sleep(testCheckInterval*4 + 20*time.Millisecond)

	// Metric should still be registered with value 1
	metrics, err := registry.Gather()
	require.NoError(t, err)
	var mf *dto.MetricFamily
	for _, m := range metrics {
		if m.GetName() == "database_observability_connection_info" {
			mf = m
			break
		}
	}
	require.NotNil(t, mf, "metric should remain registered while pings succeed")
	require.Len(t, mf.Metric, 1)
	require.Equal(t, float64(1), mf.Metric[0].GetGauge().GetValue())

	// mock.ExpectationsWereMet is intentionally not asserted: surplus ping expectations are queued on
	// purpose and are expected to remain unconsumed.
}

// TestRunConnectionInfoMonitor_CancelStopsGoroutine checks that calling the returned cancel function
// makes the monitor goroutine exit and invoke onStopped.
func TestRunConnectionInfoMonitor_CancelStopsGoroutine(t *testing.T) {
	defer goleak.VerifyNone(t)

	db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
	require.NoError(t, err)
	defer db.Close()

	gaugeOpts := prometheus.GaugeOpts{
		Namespace: "database_observability",
		Name:      "connection_info",
		Help:      "Information about the connection",
	}
	labelNames := []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}
	infoMetric := prometheus.NewGaugeVec(gaugeOpts, labelNames)
	registry := prometheus.NewRegistry()
	require.NoError(t, registry.Register(infoMetric))

	labelValues := []string{"a", "b", "c", "d", "e", "f"}
	infoMetric.WithLabelValues(labelValues...).Set(1)

	// The monitor may get one ping in before it notices the cancellation.
	mock.ExpectPing()

	// stopped is closed by the onStopped callback when the goroutine exits.
	stopped := make(chan struct{})
	config := &ConnectionInfoMonitorConfig{
		CheckInterval:   testCheckInterval,
		ChecksThreshold: testThreshold,
	}
	cancel := RunConnectionInfoMonitor(
		context.Background(),
		db,
		registry,
		infoMetric,
		labelValues,
		func() { close(stopped) },
		config,
	)

	// Cancel right away and wait (bounded) for the exit notification.
	cancel()
	deadline := time.After(2 * time.Second)
	select {
	case <-deadline:
		t.Fatal("onStopped was not called after cancel")
	case <-stopped:
		// goroutine exited
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package collector

import (
"context"
"database/sql"
"net"
"strings"

Expand All @@ -18,6 +19,7 @@ type ConnectionInfoArguments struct {
Registry *prometheus.Registry
EngineVersion string
CloudProvider *database_observability.CloudProvider
DB *sql.DB
}

type ConnectionInfo struct {
Expand All @@ -26,8 +28,10 @@ type ConnectionInfo struct {
EngineVersion string
InfoMetric *prometheus.GaugeVec
CloudProvider *database_observability.CloudProvider
dbConnection *sql.DB

running *atomic.Bool
cancel context.CancelFunc
}

func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) {
Expand All @@ -45,6 +49,7 @@ func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) {
EngineVersion: args.EngineVersion,
InfoMetric: infoMetric,
CloudProvider: args.CloudProvider,
dbConnection: args.DB,
running: &atomic.Bool{},
}, nil
}
Expand Down Expand Up @@ -105,7 +110,21 @@ func (c *ConnectionInfo) Start(ctx context.Context) error {
}
c.running.Store(true)

c.InfoMetric.WithLabelValues(providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, c.EngineVersion).Set(1)
labelValues := []string{providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, c.EngineVersion}
c.InfoMetric.WithLabelValues(labelValues...).Set(1)

if c.dbConnection != nil {
c.cancel = database_observability.RunConnectionInfoMonitor(
ctx,
c.dbConnection,
c.Registry,
c.InfoMetric,
labelValues,
func() { c.running.Store(false) },
nil,
)
}

return nil
}

Expand All @@ -114,6 +133,9 @@ func (c *ConnectionInfo) Stopped() bool {
}

// Stop shuts the collector down: it cancels the connection monitor goroutine if one was started,
// removes InfoMetric from the registry, and marks the collector as not running.
// It does not wait for the monitor goroutine to finish exiting.
func (c *ConnectionInfo) Stop() {
	// cancel is only set when Start launched a monitor (i.e. a DB handle was supplied).
	if stopMonitor := c.cancel; stopMonitor != nil {
		stopMonitor()
	}

	registry := c.Registry
	registry.Unregister(c.InfoMetric)

	c.running.Store(false)
}
Loading
Loading