Skip to content

Commit 39c350b

Browse files
committed
fix: prevent metrics leak with cleanup
Keep track all components metrics and unregister them on close
1 parent 3083a9b commit 39c350b

7 files changed

+161
-66
lines changed

async_producer.go

+15-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/eapache/go-resiliency/breaker"
1212
"github.com/eapache/queue"
13+
"github.com/rcrowley/go-metrics"
1314
)
1415

1516
// AsyncProducer publishes Kafka messages using a non-blocking API. It routes messages
@@ -122,6 +123,8 @@ type asyncProducer struct {
122123
brokerLock sync.Mutex
123124

124125
txnmgr *transactionManager
126+
127+
metricsRegistry metrics.Registry
125128
}
126129

127130
// NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration.
@@ -154,15 +157,16 @@ func newAsyncProducer(client Client) (AsyncProducer, error) {
154157
}
155158

156159
p := &asyncProducer{
157-
client: client,
158-
conf: client.Config(),
159-
errors: make(chan *ProducerError),
160-
input: make(chan *ProducerMessage),
161-
successes: make(chan *ProducerMessage),
162-
retries: make(chan *ProducerMessage),
163-
brokers: make(map[*Broker]*brokerProducer),
164-
brokerRefs: make(map[*brokerProducer]int),
165-
txnmgr: txnmgr,
160+
client: client,
161+
conf: client.Config(),
162+
errors: make(chan *ProducerError),
163+
input: make(chan *ProducerMessage),
164+
successes: make(chan *ProducerMessage),
165+
retries: make(chan *ProducerMessage),
166+
brokers: make(map[*Broker]*brokerProducer),
167+
brokerRefs: make(map[*brokerProducer]int),
168+
txnmgr: txnmgr,
169+
metricsRegistry: newCleanupRegistry(client.Config().MetricRegistry),
166170
}
167171

168172
// launch our singleton dispatchers
@@ -1134,6 +1138,8 @@ func (p *asyncProducer) shutdown() {
11341138
close(p.retries)
11351139
close(p.errors)
11361140
close(p.successes)
1141+
1142+
p.metricsRegistry.UnregisterAll()
11371143
}
11381144

11391145
func (p *asyncProducer) bumpIdempotentProducerEpoch() {

broker.go

+22-40
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ type Broker struct {
3333
responses chan *responsePromise
3434
done chan bool
3535

36-
registeredMetrics map[string]struct{}
37-
36+
metricRegistry metrics.Registry
3837
incomingByteRate metrics.Meter
3938
requestRate metrics.Meter
4039
fetchRate metrics.Meter
@@ -174,6 +173,8 @@ func (b *Broker) Open(conf *Config) error {
174173

175174
b.lock.Lock()
176175

176+
b.metricRegistry = newCleanupRegistry(conf.MetricRegistry)
177+
177178
go withRecover(func() {
178179
defer func() {
179180
b.lock.Unlock()
@@ -208,15 +209,15 @@ func (b *Broker) Open(conf *Config) error {
208209
b.conf = conf
209210

210211
// Create or reuse the global metrics shared between brokers
211-
b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry)
212-
b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry)
213-
b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", conf.MetricRegistry)
214-
b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry)
215-
b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry)
216-
b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry)
217-
b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry)
218-
b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry)
219-
b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", conf.MetricRegistry)
212+
b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", b.metricRegistry)
213+
b.requestRate = metrics.GetOrRegisterMeter("request-rate", b.metricRegistry)
214+
b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", b.metricRegistry)
215+
b.requestSize = getOrRegisterHistogram("request-size", b.metricRegistry)
216+
b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", b.metricRegistry)
217+
b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", b.metricRegistry)
218+
b.responseRate = metrics.GetOrRegisterMeter("response-rate", b.metricRegistry)
219+
b.responseSize = getOrRegisterHistogram("response-size", b.metricRegistry)
220+
b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", b.metricRegistry)
220221
// Do not gather metrics for seeded broker (only used during bootstrap) because they share
221222
// the same id (-1) and are already exposed through the global metrics above
222223
if b.id >= 0 && !metrics.UseNilMetrics {
@@ -319,7 +320,7 @@ func (b *Broker) Close() error {
319320
b.done = nil
320321
b.responses = nil
321322

322-
b.unregisterMetrics()
323+
b.metricRegistry.UnregisterAll()
323324

324325
if err == nil {
325326
DebugLogger.Printf("Closed connection to broker %s\n", b.addr)
@@ -435,7 +436,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
435436
return
436437
}
437438

438-
if err := versionedDecode(packets, res, request.version(), b.conf.MetricRegistry); err != nil {
439+
if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil {
439440
// Malformed response
440441
cb(nil, err)
441442
return
@@ -979,7 +980,7 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
979980
}
980981

981982
req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
982-
buf, err := encode(req, b.conf.MetricRegistry)
983+
buf, err := encode(req, b.metricRegistry)
983984
if err != nil {
984985
return err
985986
}
@@ -1029,7 +1030,7 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
10291030
func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error {
10301031
select {
10311032
case buf := <-promise.packets:
1032-
return versionedDecode(buf, res, req.version(), b.conf.MetricRegistry)
1033+
return versionedDecode(buf, res, req.version(), b.metricRegistry)
10331034
case err := <-promise.errors:
10341035
return err
10351036
}
@@ -1121,7 +1122,7 @@ func (b *Broker) responseReceiver() {
11211122
}
11221123

11231124
decodedHeader := responseHeader{}
1124-
err = versionedDecode(header, &decodedHeader, response.headerVersion, b.conf.MetricRegistry)
1125+
err = versionedDecode(header, &decodedHeader, response.headerVersion, b.metricRegistry)
11251126
if err != nil {
11261127
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
11271128
dead = err
@@ -1243,7 +1244,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
12431244
rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
12441245

12451246
req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
1246-
buf, err := encode(req, b.conf.MetricRegistry)
1247+
buf, err := encode(req, b.metricRegistry)
12471248
if err != nil {
12481249
return err
12491250
}
@@ -1280,7 +1281,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
12801281
b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
12811282
res := &SaslHandshakeResponse{}
12821283

1283-
err = versionedDecode(payload, res, 0, b.conf.MetricRegistry)
1284+
err = versionedDecode(payload, res, 0, b.metricRegistry)
12841285
if err != nil {
12851286
Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
12861287
return err
@@ -1622,38 +1623,19 @@ func (b *Broker) registerMetrics() {
16221623
b.brokerThrottleTime = b.registerHistogram("throttle-time-in-ms")
16231624
}
16241625

1625-
func (b *Broker) unregisterMetrics() {
1626-
for name := range b.registeredMetrics {
1627-
b.conf.MetricRegistry.Unregister(name)
1628-
}
1629-
b.registeredMetrics = nil
1630-
}
1631-
16321626
func (b *Broker) registerMeter(name string) metrics.Meter {
16331627
nameForBroker := getMetricNameForBroker(name, b)
1634-
if b.registeredMetrics == nil {
1635-
b.registeredMetrics = map[string]struct{}{}
1636-
}
1637-
b.registeredMetrics[nameForBroker] = struct{}{}
1638-
return metrics.GetOrRegisterMeter(nameForBroker, b.conf.MetricRegistry)
1628+
return metrics.GetOrRegisterMeter(nameForBroker, b.metricRegistry)
16391629
}
16401630

16411631
func (b *Broker) registerHistogram(name string) metrics.Histogram {
16421632
nameForBroker := getMetricNameForBroker(name, b)
1643-
if b.registeredMetrics == nil {
1644-
b.registeredMetrics = map[string]struct{}{}
1645-
}
1646-
b.registeredMetrics[nameForBroker] = struct{}{}
1647-
return getOrRegisterHistogram(nameForBroker, b.conf.MetricRegistry)
1633+
return getOrRegisterHistogram(nameForBroker, b.metricRegistry)
16481634
}
16491635

16501636
func (b *Broker) registerCounter(name string) metrics.Counter {
16511637
nameForBroker := getMetricNameForBroker(name, b)
1652-
if b.registeredMetrics == nil {
1653-
b.registeredMetrics = map[string]struct{}{}
1654-
}
1655-
b.registeredMetrics[nameForBroker] = struct{}{}
1656-
return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
1638+
return metrics.GetOrRegisterCounter(nameForBroker, b.metricRegistry)
16571639
}
16581640

16591641
func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {

client_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"syscall"
99
"testing"
1010
"time"
11+
12+
"github.com/rcrowley/go-metrics"
1113
)
1214

1315
func safeClose(t testing.TB, c io.Closer) {
@@ -1096,3 +1098,25 @@ func TestInitProducerIDConnectionRefused(t *testing.T) {
10961098

10971099
safeClose(t, client)
10981100
}
1101+
1102+
func TestMetricsCleanup(t *testing.T) {
1103+
seedBroker := NewMockBroker(t, 1)
1104+
seedBroker.Returns(new(MetadataResponse))
1105+
1106+
config := NewTestConfig()
1107+
metrics.GetOrRegisterMeter("a", config.MetricRegistry)
1108+
1109+
client, err := NewClient([]string{seedBroker.Addr()}, config)
1110+
if err != nil {
1111+
t.Fatal(err)
1112+
}
1113+
safeClose(t, client)
1114+
1115+
// Wait async close
1116+
time.Sleep(10 * time.Millisecond)
1117+
1118+
all := config.MetricRegistry.GetAll()
1119+
if len(all) != 1 || all["a"] == nil {
1120+
t.Errorf("excepted 1 metric, found: %v", all)
1121+
}
1122+
}

consumer.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ type consumer struct {
104104
children map[string]map[int32]*partitionConsumer
105105
brokerConsumers map[*Broker]*brokerConsumer
106106
client Client
107+
metricRegistry metrics.Registry
107108
lock sync.Mutex
108109
}
109110

@@ -136,12 +137,14 @@ func newConsumer(client Client) (Consumer, error) {
136137
conf: client.Config(),
137138
children: make(map[string]map[int32]*partitionConsumer),
138139
brokerConsumers: make(map[*Broker]*brokerConsumer),
140+
metricRegistry: newCleanupRegistry(client.Config().MetricRegistry),
139141
}
140142

141143
return c, nil
142144
}
143145

144146
func (c *consumer) Close() error {
147+
c.metricRegistry.UnregisterAll()
145148
return c.client.Close()
146149
}
147150

@@ -678,13 +681,9 @@ func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMes
678681
}
679682

680683
func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
681-
var (
682-
metricRegistry = child.conf.MetricRegistry
683-
consumerBatchSizeMetric metrics.Histogram
684-
)
685-
686-
if metricRegistry != nil {
687-
consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", metricRegistry)
684+
var consumerBatchSizeMetric metrics.Histogram
685+
if child.consumer != nil && child.consumer.metricRegistry != nil {
686+
consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", child.consumer.metricRegistry)
688687
}
689688

690689
// If request was throttled and empty we log and return without error
@@ -709,7 +708,9 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
709708
return nil, err
710709
}
711710

712-
consumerBatchSizeMetric.Update(int64(nRecs))
711+
if consumerBatchSizeMetric != nil {
712+
consumerBatchSizeMetric.Update(int64(nRecs))
713+
}
713714

714715
if block.PreferredReadReplica != invalidPreferredReplicaID {
715716
child.preferredReadReplica = block.PreferredReadReplica

consumer_group.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ type consumerGroup struct {
9191
closeOnce sync.Once
9292

9393
userData []byte
94+
95+
metricRegistry metrics.Registry
9496
}
9597

9698
// NewConsumerGroup creates a new consumer group the given broker addresses and configuration.
@@ -129,13 +131,14 @@ func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) {
129131
}
130132

131133
cg := &consumerGroup{
132-
client: client,
133-
consumer: consumer,
134-
config: config,
135-
groupID: groupID,
136-
errors: make(chan error, config.ChannelBufferSize),
137-
closed: make(chan none),
138-
userData: config.Consumer.Group.Member.UserData,
134+
client: client,
135+
consumer: consumer,
136+
config: config,
137+
groupID: groupID,
138+
errors: make(chan error, config.ChannelBufferSize),
139+
closed: make(chan none),
140+
userData: config.Consumer.Group.Member.UserData,
141+
metricRegistry: newCleanupRegistry(config.MetricRegistry),
139142
}
140143
if client.Config().Consumer.Group.InstanceId != "" && config.Version.IsAtLeast(V2_3_0_0) {
141144
cg.groupInstanceId = &client.Config().Consumer.Group.InstanceId
@@ -167,6 +170,8 @@ func (c *consumerGroup) Close() (err error) {
167170
if e := c.client.Close(); e != nil {
168171
err = e
169172
}
173+
174+
c.metricRegistry.UnregisterAll()
170175
})
171176
return
172177
}
@@ -261,7 +266,7 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler
261266
}
262267

263268
var (
264-
metricRegistry = c.config.MetricRegistry
269+
metricRegistry = c.metricRegistry
265270
consumerGroupJoinTotal metrics.Counter
266271
consumerGroupJoinFailed metrics.Counter
267272
consumerGroupSyncTotal metrics.Counter

0 commit comments

Comments
 (0)