From 7403f159ccb6b8feb17b2c4b4df08794323fb24a Mon Sep 17 00:00:00 2001 From: Matt Nolf Date: Wed, 11 Mar 2026 11:24:20 +0000 Subject: [PATCH] unregister connection_info metric when DB connection disconnected Co-Authored-By: Claude Sonnet 4.6 --- .../mysql/collector/connection_info.go | 77 ++++++++++++++++--- .../mysql/collector/connection_info_test.go | 10 +++ .../database_observability/mysql/component.go | 1 + .../postgres/collector/connection_info.go | 75 +++++++++++++++--- .../collector/connection_info_test.go | 10 +++ .../postgres/component.go | 1 + 6 files changed, 155 insertions(+), 19 deletions(-) diff --git a/internal/component/database_observability/mysql/collector/connection_info.go b/internal/component/database_observability/mysql/collector/connection_info.go index d28e2b39050..5bcbc2d24d9 100644 --- a/internal/component/database_observability/mysql/collector/connection_info.go +++ b/internal/component/database_observability/mysql/collector/connection_info.go @@ -2,8 +2,10 @@ package collector import ( "context" + "database/sql" "net" "strings" + "time" "github.com/go-sql-driver/mysql" "github.com/grafana/alloy/internal/component/database_observability" @@ -11,9 +13,13 @@ import ( "go.uber.org/atomic" ) -const ConnectionInfoName = "connection_info" +const ( + ConnectionInfoName = "connection_info" + connectionInfoInterval = 5 * time.Minute +) type ConnectionInfoArguments struct { + DB *sql.DB DSN string Registry *prometheus.Registry EngineVersion string @@ -21,6 +27,7 @@ type ConnectionInfoArguments struct { } type ConnectionInfo struct { + dbConnection *sql.DB DSN string Registry *prometheus.Registry EngineVersion string @@ -28,6 +35,8 @@ type ConnectionInfo struct { CloudProvider *database_observability.CloudProvider running *atomic.Bool + ctx context.Context + cancel context.CancelFunc } func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { @@ -37,9 +46,8 @@ func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { Help: "Information about the connection", }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) - args.Registry.MustRegister(infoMetric) - return &ConnectionInfo{ + dbConnection: args.DB, DSN: args.DSN, Registry: args.Registry, EngineVersion: args.EngineVersion, @@ -54,6 +62,41 @@ func (c *ConnectionInfo) Name() string { } func (c *ConnectionInfo) Start(ctx context.Context) error { + labels, err := c.buildLabels() + if err != nil { + return err + } + + c.running.Store(true) + ctx, cancel := context.WithCancel(ctx) + c.ctx = ctx + c.cancel = cancel + + c.ping(ctx, labels) + + go func() { + defer func() { + c.Registry.Unregister(c.InfoMetric) + c.running.Store(false) + }() + + ticker := time.NewTicker(connectionInfoInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.ping(c.ctx, labels) + } + } + }() + + return nil +} + +func (c *ConnectionInfo) buildLabels() (prometheus.Labels, error) { var ( providerName = "unknown" providerRegion = "unknown" @@ -82,9 +125,8 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { } else { cfg, err := mysql.ParseDSN(c.DSN) if err != nil { - return err + return nil, err } - host, _, err := net.SplitHostPort(cfg.Addr) if err == nil && host != "" { if strings.HasSuffix(host, "rds.amazonaws.com") { @@ -103,10 +145,24 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { } } } - c.running.Store(true) - c.InfoMetric.WithLabelValues(providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, c.EngineVersion).Set(1) - return nil + return prometheus.Labels{ + "provider_name": providerName, + "provider_region": providerRegion, + "provider_account": providerAccount, + "db_instance_identifier": dbInstanceIdentifier, + "engine": engine, + "engine_version": c.EngineVersion, + }, nil +} + +func (c *ConnectionInfo) ping(ctx context.Context, labels prometheus.Labels) { + if err := c.dbConnection.PingContext(ctx); err != nil { + c.Registry.Unregister(c.InfoMetric) + return + } + _ = c.Registry.Register(c.InfoMetric) + c.InfoMetric.With(labels).Set(1) } func (c *ConnectionInfo) Stopped() bool { @@ -114,6 +170,7 @@ func (c *ConnectionInfo) Stopped() bool { } func (c *ConnectionInfo) Stop() { - c.Registry.Unregister(c.InfoMetric) - c.running.Store(false) + if c.cancel != nil { + c.cancel() + } } diff --git a/internal/component/database_observability/mysql/collector/connection_info_test.go b/internal/component/database_observability/mysql/collector/connection_info_test.go index a3b91ec9958..30b6b58da54 100644 --- a/internal/component/database_observability/mysql/collector/connection_info_test.go +++ b/internal/component/database_observability/mysql/collector/connection_info_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" @@ -85,9 +86,16 @@ func TestConnectionInfo(t *testing.T) { } for _, tc := range testCases { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + mock.ExpectPing() + reg := prometheus.NewRegistry() collector, err := NewConnectionInfo(ConnectionInfoArguments{ + DB: db, DSN: tc.dsn, Registry: reg, EngineVersion: tc.engineVersion, @@ -98,8 +106,10 @@ func TestConnectionInfo(t *testing.T) { err = collector.Start(t.Context()) require.NoError(t, err) + defer collector.Stop() err = testutil.GatherAndCompare(reg, strings.NewReader(tc.expectedMetrics)) require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) } } diff --git a/internal/component/database_observability/mysql/component.go b/internal/component/database_observability/mysql/component.go index d446fc47e70..8cf9ec683bd 100644 --- a/internal/component/database_observability/mysql/component.go +++ b/internal/component/database_observability/mysql/component.go @@ -634,6 +634,7 @@ func (c *Component) startCollectors(serverID string, engineVersion string, parse // Connection Info collector is always enabled ciCollector, err := collector.NewConnectionInfo(collector.ConnectionInfoArguments{ + DB: c.dbConnection, DSN: string(c.args.DataSourceName), Registry: c.registry, EngineVersion: engineVersion, diff --git a/internal/component/database_observability/postgres/collector/connection_info.go b/internal/component/database_observability/postgres/collector/connection_info.go index b64fa3aafe0..e4b0bdc0964 100644 --- a/internal/component/database_observability/postgres/collector/connection_info.go +++ b/internal/component/database_observability/postgres/collector/connection_info.go @@ -2,19 +2,25 @@ package collector import ( "context" + "database/sql" "regexp" "strings" + "time" "github.com/grafana/alloy/internal/component/database_observability" "github.com/prometheus/client_golang/prometheus" "go.uber.org/atomic" ) -const ConnectionInfoName = "connection_info" +const ( + ConnectionInfoName = "connection_info" + connectionInfoInterval = 5 * time.Minute +) var engineVersionRegex = regexp.MustCompile(`(?P^[1-9]+\.[1-9]+)(?P.*)?$`) type ConnectionInfoArguments struct { + DB *sql.DB DSN string Registry *prometheus.Registry EngineVersion string @@ -22,6 +28,7 @@ type ConnectionInfoArguments struct { } type ConnectionInfo struct { + dbConnection *sql.DB DSN string Registry *prometheus.Registry EngineVersion string @@ -29,6 +36,8 @@ type ConnectionInfo struct { CloudProvider *database_observability.CloudProvider running *atomic.Bool + ctx context.Context + cancel context.CancelFunc } func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { @@ -38,9 +47,8 @@ func NewConnectionInfo(args ConnectionInfoArguments) (*ConnectionInfo, error) { Help: "Information about the connection", }, []string{"provider_name", "provider_region", "provider_account", "db_instance_identifier", "engine", "engine_version"}) - args.Registry.MustRegister(infoMetric) - return &ConnectionInfo{ + dbConnection: args.DB, DSN: args.DSN, Registry: args.Registry, EngineVersion: args.EngineVersion, @@ -55,6 +63,41 @@ func (c *ConnectionInfo) Name() string { } func (c *ConnectionInfo) Start(ctx context.Context) error { + labels, err := c.buildLabels() + if err != nil { + return err + } + + c.running.Store(true) + ctx, cancel := context.WithCancel(ctx) + c.ctx = ctx + c.cancel = cancel + + c.ping(ctx, labels) + + go func() { + defer func() { + c.Registry.Unregister(c.InfoMetric) + c.running.Store(false) + }() + + ticker := time.NewTicker(connectionInfoInterval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.ping(c.ctx, labels) + } + } + }() + + return nil +} + +func (c *ConnectionInfo) buildLabels() (prometheus.Labels, error) { var ( providerName = "unknown" providerRegion = "unknown" @@ -84,7 +127,7 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { } else { parts, err := ParseURL(c.DSN) if err != nil { - return err + return nil, err } if host, ok := parts["host"]; ok { if strings.HasSuffix(host, "rds.amazonaws.com") { @@ -109,10 +152,23 @@ func (c *ConnectionInfo) Start(ctx context.Context) error { engineVersion = matches[1] } - c.running.Store(true) + return prometheus.Labels{ + "provider_name": providerName, + "provider_region": providerRegion, + "provider_account": providerAccount, + "db_instance_identifier": dbInstanceIdentifier, + "engine": engine, + "engine_version": engineVersion, + }, nil +} - c.InfoMetric.WithLabelValues(providerName, providerRegion, providerAccount, dbInstanceIdentifier, engine, engineVersion).Set(1) - return nil +func (c *ConnectionInfo) ping(ctx context.Context, labels prometheus.Labels) { + if err := c.dbConnection.PingContext(ctx); err != nil { + c.Registry.Unregister(c.InfoMetric) + return + } + _ = c.Registry.Register(c.InfoMetric) + c.InfoMetric.With(labels).Set(1) } func (c *ConnectionInfo) Stopped() bool { @@ -120,6 +176,7 @@ func (c *ConnectionInfo) Stopped() bool { } func (c *ConnectionInfo) Stop() { - c.Registry.Unregister(c.InfoMetric) - c.running.Store(false) + if c.cancel != nil { + c.cancel() + } } diff --git a/internal/component/database_observability/postgres/collector/connection_info_test.go b/internal/component/database_observability/postgres/collector/connection_info_test.go index 4cd76a02608..1012a93be5d 100644 --- a/internal/component/database_observability/postgres/collector/connection_info_test.go +++ b/internal/component/database_observability/postgres/collector/connection_info_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" @@ -87,9 +88,16 @@ func TestConnectionInfo(t *testing.T) { } for _, tc := range testCases { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + mock.ExpectPing() + reg := prometheus.NewRegistry() collector, err := NewConnectionInfo(ConnectionInfoArguments{ + DB: db, DSN: tc.dsn, Registry: reg, EngineVersion: tc.engineVersion, @@ -100,8 +108,10 @@ func TestConnectionInfo(t *testing.T) { err = collector.Start(t.Context()) require.NoError(t, err) + defer collector.Stop() err = testutil.GatherAndCompare(reg, strings.NewReader(tc.expectedMetrics)) require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) } } diff --git a/internal/component/database_observability/postgres/component.go b/internal/component/database_observability/postgres/component.go index 1831adc6840..40117a55b99 100644 --- a/internal/component/database_observability/postgres/component.go +++ b/internal/component/database_observability/postgres/component.go @@ -508,6 +508,7 @@ func (c *Component) startCollectors(systemID string, engineVersion string, cloud // Connection Info collector is always enabled ciCollector, err := collector.NewConnectionInfo(collector.ConnectionInfoArguments{ + DB: c.dbConnection, DSN: string(c.args.DataSourceName), Registry: c.registry, EngineVersion: engineVersion,