Skip to content

Commit b0eda59

Browse files
fix(metrics): fix race when accessing metric registry (#2409)
A race condition was introduced in 5b04c98 (feat(metrics): track consumer-fetch-response-size) when passing the metric registry around to get additional metrics. Notably, `handleResponsePromise()` could access the registry after the broker has been closed and is tentatively being reopened. This triggers a data race because `b.metricRegistry` is being set during `Open()` (as it is part of the configuration). We fix this by reverting the addition of `handleResponsePromise()` as a method to `Broker`. Instead, we provide it with the metric registry as an argument. An alternative would have been to get the metric registry before the `select` call. However, removing it as a method make it clearer than this function is not allowed to access the broker internals as they are not protected by the lock and the broker may not be alive any more. All the following calls to `b.metricRegistry` are done while the lock is held: - inside `Open()`, the lock is held, including inside the goroutine - inside `Close()`, the lock is held - `AsyncProduce()` has a contract that it must be called while the broker is open, we keep a copy of the metric registry to use inside the callback - `sendInternal()`, has a contract that the lock should be held - `authenticateViaSASLv1()` is called from `Open()` and `sendWithPromise()`, both of them holding the lock - `sendAndReceiveSASLHandshake()` is called from - `authenticateViaSASLv0/v1()`, which are called from `Open()` and `sendWithPromise()` I am unsure about `responseReceiver()`, however, it is also calling `b.readFull()` which accesses `b.conn`, so I suppose it is safe. This leaves `sendAndReceive()` which is calling `send()`, which is calling `sendWithPromise()` which puts a lock. We move the lock to `sendAndReceive()` instead. `send()` is only called from `sendAndReceiver()` and we put a lock for `sendWithPromise()` other caller. The test has been stolen from #2393 from @samuelhewitt. #2393 is an alternative proposal using a RW lock to protect `b.metricRegistry`. Fix #2320
1 parent 67d977b commit b0eda59

File tree

3 files changed

+122
-10
lines changed

3 files changed

+122
-10
lines changed

broker.go

+14-9
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ type ProduceCallback func(*ProduceResponse, error)
429429
//
430430
// Make sure not to Close the broker in the callback as it will lead to a deadlock.
431431
func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error {
432+
metricRegistry := b.metricRegistry
432433
needAcks := request.RequiredAcks != NoResponse
433434
// Use a nil promise when no acks is required
434435
var promise *responsePromise
@@ -446,7 +447,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
446447
return
447448
}
448449

449-
if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil {
450+
if err := versionedDecode(packets, res, request.version(), metricRegistry); err != nil {
450451
// Malformed response
451452
cb(nil, err)
452453
return
@@ -459,6 +460,8 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
459460
}
460461
}
461462

463+
b.lock.Lock()
464+
defer b.lock.Unlock()
462465
return b.sendWithPromise(request, promise)
463466
}
464467

@@ -939,6 +942,7 @@ func (b *Broker) write(buf []byte) (n int, err error) {
939942
return b.conn.Write(buf)
940943
}
941944

945+
// b.lock must be haled by caller
942946
func (b *Broker) send(rb protocolBody, promiseResponse bool, responseHeaderVersion int16) (*responsePromise, error) {
943947
var promise *responsePromise
944948
if promiseResponse {
@@ -963,10 +967,8 @@ func makeResponsePromise(responseHeaderVersion int16) *responsePromise {
963967
return promise
964968
}
965969

970+
// b.lock must be held by caller
966971
func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error {
967-
b.lock.Lock()
968-
defer b.lock.Unlock()
969-
970972
if b.conn == nil {
971973
if b.connErr != nil {
972974
return b.connErr
@@ -1022,6 +1024,8 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
10221024
}
10231025

10241026
func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
1027+
b.lock.Lock()
1028+
defer b.lock.Unlock()
10251029
responseHeaderVersion := int16(-1)
10261030
if res != nil {
10271031
responseHeaderVersion = res.headerVersion()
@@ -1036,13 +1040,13 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
10361040
return nil
10371041
}
10381042

1039-
return b.handleResponsePromise(req, res, promise)
1043+
return handleResponsePromise(req, res, promise, b.metricRegistry)
10401044
}
10411045

1042-
func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error {
1046+
func handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise, metricRegistry metrics.Registry) error {
10431047
select {
10441048
case buf := <-promise.packets:
1045-
return versionedDecode(buf, res, req.version(), b.metricRegistry)
1049+
return versionedDecode(buf, res, req.version(), metricRegistry)
10461050
case err := <-promise.errors:
10471051
return err
10481052
}
@@ -1185,6 +1189,7 @@ func (b *Broker) authenticateViaSASLv0() error {
11851189
}
11861190

11871191
func (b *Broker) authenticateViaSASLv1() error {
1192+
metricRegistry := b.metricRegistry
11881193
if b.conf.Net.SASL.Handshake {
11891194
handshakeRequest := &SaslHandshakeRequest{Mechanism: string(b.conf.Net.SASL.Mechanism), Version: b.conf.Net.SASL.Version}
11901195
handshakeResponse := new(SaslHandshakeResponse)
@@ -1195,7 +1200,7 @@ func (b *Broker) authenticateViaSASLv1() error {
11951200
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
11961201
return handshakeErr
11971202
}
1198-
handshakeErr = b.handleResponsePromise(handshakeRequest, handshakeResponse, prom)
1203+
handshakeErr = handleResponsePromise(handshakeRequest, handshakeResponse, prom, metricRegistry)
11991204
if handshakeErr != nil {
12001205
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
12011206
return handshakeErr
@@ -1215,7 +1220,7 @@ func (b *Broker) authenticateViaSASLv1() error {
12151220
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
12161221
return nil, authErr
12171222
}
1218-
authErr = b.handleResponsePromise(authenticateRequest, authenticateResponse, prom)
1223+
authErr = handleResponsePromise(authenticateRequest, authenticateResponse, prom, metricRegistry)
12191224
if authErr != nil {
12201225
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
12211226
return nil, authErr

broker_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func TestSimpleBrokerCommunication(t *testing.T) {
123123
pendingNotify <- brokerMetrics{bytesRead, bytesWritten}
124124
})
125125
broker := NewBroker(mb.Addr())
126-
// Set the broker id in order to validate local broujhjker metrics
126+
// Set the broker id in order to validate local broker metrics
127127
broker.id = 0
128128
conf := NewTestConfig()
129129
conf.ApiVersionsRequest = false

consumer_group_test.go

+107
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package sarama
22

33
import (
44
"context"
5+
"errors"
56
"sync"
67
"testing"
8+
"time"
79
)
810

911
type handler struct {
@@ -93,3 +95,108 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) {
9395
}()
9496
wg.Wait()
9597
}
98+
99+
func TestConsume_RaceTest(t *testing.T) {
100+
const groupID = "test-group"
101+
const topic = "test-topic"
102+
const offsetStart = int64(1234)
103+
104+
cfg := NewConfig()
105+
cfg.Version = V2_8_1_0
106+
cfg.Consumer.Return.Errors = true
107+
108+
seedBroker := NewMockBroker(t, 1)
109+
110+
joinGroupResponse := &JoinGroupResponse{}
111+
112+
syncGroupResponse := &SyncGroupResponse{
113+
Version: 3, // sarama > 2.3.0.0 uses version 3
114+
}
115+
// Leverage mock response to get the MemberAssignment bytes
116+
mockSyncGroupResponse := NewMockSyncGroupResponse(t).SetMemberAssignment(&ConsumerGroupMemberAssignment{
117+
Version: 1,
118+
Topics: map[string][]int32{topic: {0}}, // map "test-topic" to partition 0
119+
UserData: []byte{0x01},
120+
})
121+
syncGroupResponse.MemberAssignment = mockSyncGroupResponse.MemberAssignment
122+
123+
heartbeatResponse := &HeartbeatResponse{
124+
Err: ErrNoError,
125+
}
126+
offsetFetchResponse := &OffsetFetchResponse{
127+
Version: 1,
128+
ThrottleTimeMs: 0,
129+
Err: ErrNoError,
130+
}
131+
offsetFetchResponse.AddBlock(topic, 0, &OffsetFetchResponseBlock{
132+
Offset: offsetStart,
133+
LeaderEpoch: 0,
134+
Metadata: "",
135+
Err: ErrNoError})
136+
137+
offsetResponse := &OffsetResponse{
138+
Version: 1,
139+
}
140+
offsetResponse.AddTopicPartition(topic, 0, offsetStart)
141+
142+
metadataResponse := new(MetadataResponse)
143+
metadataResponse.AddBroker(seedBroker.Addr(), seedBroker.BrokerID())
144+
metadataResponse.AddTopic("mismatched-topic", ErrUnknownTopicOrPartition)
145+
146+
handlerMap := map[string]MockResponse{
147+
"ApiVersionsRequest": NewMockApiVersionsResponse(t),
148+
"MetadataRequest": NewMockSequence(metadataResponse),
149+
"OffsetRequest": NewMockSequence(offsetResponse),
150+
"OffsetFetchRequest": NewMockSequence(offsetFetchResponse),
151+
"FindCoordinatorRequest": NewMockSequence(NewMockFindCoordinatorResponse(t).
152+
SetCoordinator(CoordinatorGroup, groupID, seedBroker)),
153+
"JoinGroupRequest": NewMockSequence(joinGroupResponse),
154+
"SyncGroupRequest": NewMockSequence(syncGroupResponse),
155+
"HeartbeatRequest": NewMockSequence(heartbeatResponse),
156+
}
157+
seedBroker.SetHandlerByMap(handlerMap)
158+
159+
cancelCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(4*time.Second))
160+
161+
defer seedBroker.Close()
162+
163+
retryWait := 20 * time.Millisecond
164+
var err error
165+
clientRetries := 0
166+
outerFor:
167+
for {
168+
_, err = NewConsumerGroup([]string{seedBroker.Addr()}, groupID, cfg)
169+
if err == nil {
170+
break
171+
}
172+
173+
if retryWait < time.Minute {
174+
retryWait *= 2
175+
}
176+
177+
clientRetries++
178+
179+
timer := time.NewTimer(retryWait)
180+
select {
181+
case <-cancelCtx.Done():
182+
err = cancelCtx.Err()
183+
timer.Stop()
184+
break outerFor
185+
case <-timer.C:
186+
}
187+
timer.Stop()
188+
}
189+
if err == nil {
190+
t.Fatalf("should not proceed to Consume")
191+
}
192+
193+
if clientRetries <= 0 {
194+
t.Errorf("clientRetries = %v; want > 0", clientRetries)
195+
}
196+
197+
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
198+
t.Fatal(err)
199+
}
200+
201+
cancel()
202+
}

0 commit comments

Comments
 (0)