diff --git a/async_producer.go b/async_producer.go index 6a4f2adc5..97f2296da 100644 --- a/async_producer.go +++ b/async_producer.go @@ -86,9 +86,173 @@ type asyncProducer struct { txnmgr *transactionManager txLock sync.Mutex + // muter ensures per-partition ordering by preventing concurrent in-flight requests, + // mirroring Kafka's RecordAccumulator. + muter *partitionMuter + metricsRegistry metrics.Registry } +type partitionMuter struct { + mu sync.Mutex + cond *sync.Cond + closed bool + inFlightCounts map[string]map[int32]int // topic -> partition -> in-flight count + unmuteSignal chan struct{} +} + +func newPartitionMuter() *partitionMuter { + m := &partitionMuter{ + inFlightCounts: make(map[string]map[int32]int), + unmuteSignal: make(chan struct{}), + } + m.cond = sync.NewCond(&m.mu) + return m +} + +// isMuted reports whether the partition has an in-flight batch. +// Requires: m.mu held. +func (m *partitionMuter) isMuted(topic string, partition int32) bool { + return m.inFlightCounts[topic][partition] > 0 +} + +// isAnyMuted reports whether any partition in the set has an in-flight batch. +// Requires: m.mu held. +func (m *partitionMuter) isAnyMuted(set *produceSet) bool { + return set.anyPartition(func(topic string, partition int32, _ *partitionSet) bool { + return m.isMuted(topic, partition) + }) +} + +// mutePartition increments the in-flight count for a single partition. +// Requires: m.mu held. +func (m *partitionMuter) mutePartition(topic string, partition int32) { + if m.inFlightCounts[topic] == nil { + m.inFlightCounts[topic] = make(map[int32]int) + } + m.inFlightCounts[topic][partition]++ +} + +// muteSet increments the in-flight count for all partitions in the set. +// Requires: m.mu held. +func (m *partitionMuter) muteSet(set *produceSet) { + set.eachPartition(func(topic string, partition int32, _ *partitionSet) { + m.mutePartition(topic, partition) + }) +} + +// tryMute checks if any of the partitions in the given produceSet are already +// muted, returning false if they are, otherwise it reserves every partition in +// the set by bumping their in-flight counters +func (m *partitionMuter) tryMute(set *produceSet) bool { + if set == nil || set.empty() { + return false + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.isAnyMuted(set) { + return false + } + m.muteSet(set) + return true +} + +func (m *partitionMuter) tryMutePartition(topic string, partition int32) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if m.isMuted(topic, partition) { + return false + } + m.mutePartition(topic, partition) + return true +} + +// waitUntilMuted blocks until all partitions in the set can be muted, then mutes them. +// Returns false if the muter was closed before all partitions could be muted. +func (m *partitionMuter) waitUntilMuted(set *produceSet) bool { + if set == nil || set.empty() { + return false + } + + m.mu.Lock() + defer m.mu.Unlock() + + for { + if m.closed { + return false + } + if !m.isAnyMuted(set) { + break + } + m.cond.Wait() + } + + m.muteSet(set) + return true +} + +func (m *partitionMuter) awaitUnmuteChan(set *produceSet) (<-chan struct{}, bool) { + if set == nil || set.empty() { + return nil, false + } + + m.mu.Lock() + defer m.mu.Unlock() + + if !m.isAnyMuted(set) { + return nil, false + } + return m.unmuteSignal, true +} + +// unmute decrements the in-flight counter for all partitions in the set. 
+func (m *partitionMuter) unmute(set *produceSet) { + if set == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return + } + + set.eachPartition(func(topic string, partition int32, _ *partitionSet) { + partitions := m.inFlightCounts[topic] + if partitions == nil { + return + } + if partitions[partition] <= 1 { + delete(partitions, partition) + } else { + partitions[partition]-- + } + if len(partitions) == 0 { + delete(m.inFlightCounts, topic) + } + }) + close(m.unmuteSignal) + m.unmuteSignal = make(chan struct{}) + m.cond.Broadcast() +} + +// close shuts down the muter, waking any goroutines blocked in waitUntilMuted. +func (m *partitionMuter) close() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return + } + m.closed = true + close(m.unmuteSignal) + m.cond.Broadcast() +} + // NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration. func NewAsyncProducer(addrs []string, conf *Config) (AsyncProducer, error) { client, err := NewClient(addrs, conf) @@ -128,6 +292,7 @@ func newAsyncProducer(client Client) (AsyncProducer, error) { brokers: make(map[*Broker]*brokerProducer), brokerRefs: make(map[*brokerProducer]int), txnmgr: txnmgr, + muter: newPartitionMuter(), metricsRegistry: newCleanupRegistry(client.Config().MetricRegistry), } @@ -780,13 +945,13 @@ func (p *asyncProducer) newBrokerProducer(broker *Broker) *brokerProducer { ) bp := &brokerProducer{ - parent: p, - broker: broker, - input: input, - output: bridge, - responses: responses, - buffer: newProduceSet(p), - currentRetries: make(map[string]map[int32]error), + parent: p, + broker: broker, + input: input, + output: bridge, + responses: responses, + accumulatingBatch: newProduceSet(p), + currentRetries: make(map[string]map[int32]error), } go withRecover(bp.run) @@ -800,18 +965,17 @@ func (p *asyncProducer) newBrokerProducer(broker *Broker) *brokerProducer { // Count the in flight requests to know when we can close the pending channel safely wg.Add(1) - // Capture the current set to forward in the callback - sendResponse := func(set *produceSet) ProduceCallback { - return func(response *ProduceResponse, err error) { - // Forward the response to make sure we do not block the responseReceiver - pending <- &brokerProducerResponse{ - set: set, - err: err, - res: response, - } - wg.Done() + // capture the muted set. unmuting is deferred to handleResponse to ensure that + // retries block subsequent batches for the same partition. 
+ mutedSet := set + sendResponse := func(response *ProduceResponse, err error) { + pending <- &brokerProducerResponse{ + set: mutedSet, + err: err, + res: response, } - }(set) + wg.Done() + } if p.IsTransactional() { // Add partition to tx before sending current batch @@ -900,9 +1064,10 @@ type brokerProducer struct { responses <-chan *brokerProducerResponse abandoned chan struct{} - buffer *produceSet - timer *time.Timer - timerFired bool + accumulatingBatch *produceSet + flushingBatch *produceSet // batch that has been muted and is ready to send + timer *time.Timer + timerFired bool closing error currentRetries map[string]map[int32]error @@ -910,10 +1075,24 @@ type brokerProducer struct { func (bp *brokerProducer) run() { var output chan<- *produceSet - var timerChan <-chan time.Time Logger.Printf("producer/broker/%d starting up\n", bp.broker.ID()) for { + if bp.flushingBatch == nil && (bp.timerFired || bp.accumulatingBatch.readyToFlush()) { + bp.tryBuildFlushingBatch() + } + + var timerChan <-chan time.Time + if bp.timer != nil { + timerChan = bp.timer.C + } + + if bp.flushingBatch != nil { + output = bp.output + } else { + output = nil + } + select { case msg, ok := <-bp.input: if !ok { @@ -958,14 +1137,15 @@ func (bp *brokerProducer) run() { continue } - if bp.buffer.wouldOverflow(msg) { + if bp.accumulatingBatch.wouldOverflow(msg) { + Logger.Printf("producer/broker/%d maximum request accumulated, waiting for space\n", bp.broker.ID()) if err := bp.waitForSpace(msg, false); err != nil { bp.parent.retryMessage(msg, err) continue } } - if bp.parent.txnmgr.producerID != noProducerID && bp.buffer.producerEpoch != msg.producerEpoch { + if bp.parent.txnmgr.producerID != noProducerID && bp.accumulatingBatch.producerEpoch != msg.producerEpoch { // The epoch was reset, need to roll the buffer over Logger.Printf("producer/broker/%d detected epoch rollover, waiting for new buffer\n", bp.broker.ID()) if err := bp.waitForSpace(msg, true); err != nil { @@ -973,41 +1153,79 @@ func (bp *brokerProducer) run() { continue } } - if err := bp.buffer.add(msg); err != nil { + if err := bp.accumulatingBatch.add(msg); err != nil { bp.parent.returnError(msg, err) continue } if bp.parent.conf.Producer.Flush.Frequency > 0 && bp.timer == nil { bp.timer = time.NewTimer(bp.parent.conf.Producer.Flush.Frequency) - timerChan = bp.timer.C } case <-timerChan: bp.timerFired = true - case output <- bp.buffer: - bp.rollOver() - timerChan = nil + case output <- bp.flushingBatch: + bp.flushingBatch = nil case response, ok := <-bp.responses: if ok { bp.handleResponse(response) } } + } +} - if bp.timerFired || bp.buffer.readyToFlush() { - output = bp.output - } else { - output = nil - } +func (bp *brokerProducer) tryBuildFlushingBatch() bool { + if bp.flushingBatch != nil || bp.accumulatingBatch.empty() { + return false + } + if bp.parent.muter.tryMute(bp.accumulatingBatch) { + bp.flushingBatch = bp.accumulatingBatch + bp.rollOver() + return true + } + + partial := bp.accumulatingBatch.takePartitions(func(topic string, partition int32) bool { + return bp.parent.muter.tryMutePartition(topic, partition) + }) + if partial == nil { + return false } + bp.flushingBatch = partial + if bp.accumulatingBatch.empty() { + bp.rollOver() + } + return true } func (bp *brokerProducer) shutdown() { - for !bp.buffer.empty() { + // flush any ready buffer + for bp.flushingBatch != nil { select { case response := <-bp.responses: bp.handleResponse(response) - case bp.output <- bp.buffer: - bp.rollOver() + case bp.output <- bp.flushingBatch: + 
bp.flushingBatch = nil + } + } + // then flush the current buffer + for !bp.accumulatingBatch.empty() || bp.flushingBatch != nil { + if bp.flushingBatch == nil { + bp.tryBuildFlushingBatch() + } + var unmuteCh <-chan struct{} + var outputCh chan<- *produceSet + if bp.flushingBatch != nil { + outputCh = bp.output + } else if ch, blocked := bp.parent.muter.awaitUnmuteChan(bp.accumulatingBatch); blocked { + unmuteCh = ch + } + select { + case response, ok := <-bp.responses: + if ok { + bp.handleResponse(response) + } + case outputCh <- bp.flushingBatch: + bp.flushingBatch = nil + case <-unmuteCh: } } close(bp.output) @@ -1027,21 +1245,52 @@ func (bp *brokerProducer) needsRetry(msg *ProducerMessage) error { return bp.currentRetries[msg.Topic][msg.Partition] } +// waitForSpace makes space in the accumulating batch by flushing. It loops until the message fits. func (bp *brokerProducer) waitForSpace(msg *ProducerMessage, forceRollover bool) error { + if bp.accumulatingBatch.empty() && !forceRollover { + return nil + } + for { - select { - case response := <-bp.responses: - bp.handleResponse(response) - // handling a response can change our state, so re-check some things - if reason := bp.needsRetry(msg); reason != nil { - return reason - } else if !bp.buffer.wouldOverflow(msg) && !forceRollover { - return nil + if !bp.accumulatingBatch.wouldOverflow(msg) && !forceRollover { + return nil + } + + if bp.flushingBatch != nil { + select { + case response := <-bp.responses: + bp.handleResponse(response) + if reason := bp.needsRetry(msg); reason != nil { + return reason + } + case bp.output <- bp.flushingBatch: + bp.flushingBatch = nil + } + + continue + } + + if bp.accumulatingBatch.empty() { + if forceRollover { + bp.rollOver() } - case bp.output <- bp.buffer: - bp.rollOver() return nil } + + if bp.tryBuildFlushingBatch() { + continue + } + + if unmuteCh, blocked := bp.parent.muter.awaitUnmuteChan(bp.accumulatingBatch); blocked { + select { + case response := <-bp.responses: + bp.handleResponse(response) + if reason := bp.needsRetry(msg); reason != nil { + return reason + } + case <-unmuteCh: + } + } } } @@ -1051,7 +1300,7 @@ func (bp *brokerProducer) rollOver() { } bp.timer = nil bp.timerFired = false - bp.buffer = newProduceSet(bp.parent) + bp.accumulatingBatch = newProduceSet(bp.parent) } func (bp *brokerProducer) handleResponse(response *brokerProducerResponse) { @@ -1061,7 +1310,7 @@ func (bp *brokerProducer) handleResponse(response *brokerProducerResponse) { bp.handleSuccess(response.set, response.res) } - if bp.buffer.empty() { + if bp.accumulatingBatch.empty() { bp.rollOver() // this can happen if the response invalidated our buffer } } @@ -1070,6 +1319,7 @@ func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceRespo // we iterate through the blocks in the request set, not the response, so that we notice // if the response is missing a block completely var retryTopics []string + keepMuted := make(map[string]map[int32]struct{}) sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) { if response == nil { // this only happens when RequiredAcks is NoResponse, so we have to assume success @@ -1106,6 +1356,12 @@ func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceRespo bp.parent.returnErrors(pSet.msgs, block.Err) } else { retryTopics = append(retryTopics, topic) + if bp.parent.conf.Producer.Idempotent { + if keepMuted[topic] == nil { + keepMuted[topic] = make(map[int32]struct{}) + } + keepMuted[topic][partition] = struct{}{} + } } // 
Other non-retriable errors default: @@ -1141,19 +1397,28 @@ func (bp *brokerProducer) handleSuccess(sent *produceSet, response *ProduceRespo } bp.currentRetries[topic][partition] = block.Err if bp.parent.conf.Producer.Idempotent { - go bp.parent.retryBatch(topic, partition, pSet, block.Err) + go bp.parent.retryBatch(topic, partition, pSet, block.Err, true) } else { bp.parent.retryMessages(pSet.msgs, block.Err) } // dropping the following messages has the side effect of incrementing their retry count - bp.parent.retryMessages(bp.buffer.dropPartition(topic, partition), block.Err) + bp.parent.retryMessages(bp.accumulatingBatch.dropPartition(topic, partition), block.Err) } }) } + + unmuteSet := sent.copyFunc(func(topic string, partition int32) bool { + if partitions := keepMuted[topic]; partitions != nil { + _, kept := partitions[partition] + return !kept + } + return true + }) + bp.parent.muter.unmute(unmuteSet) } -func (p *asyncProducer) retryBatch(topic string, partition int32, pSet *partitionSet, kerr KError) { - Logger.Printf("Retrying batch for %v-%d because of %s\n", topic, partition, kerr) +func (p *asyncProducer) retryBatch(topic string, partition int32, pSet *partitionSet, retryErr error, alreadyMuted bool) { + Logger.Printf("Retrying batch for %v-%d because of %v\n", topic, partition, retryErr) produceSet := newProduceSet(p) produceSet.msgs[topic] = make(map[int32]*partitionSet) produceSet.msgs[topic][partition] = pSet @@ -1161,21 +1426,35 @@ func (p *asyncProducer) retryBatch(topic string, partition int32, pSet *partitio produceSet.bufferCount += len(pSet.msgs) for _, msg := range pSet.msgs { if msg.retries >= p.conf.Producer.Retry.Max { - p.returnErrors(pSet.msgs, kerr) + p.returnErrors(pSet.msgs, retryErr) + if alreadyMuted { + p.muter.unmute(produceSet) + } return } msg.retries++ } // it's expected that a metadata refresh has been requested prior to calling retryBatch - leader, err := p.client.Leader(topic, partition) - if err != nil { - Logger.Printf("Failed retrying batch for %v-%d because of %v while looking up for new leader\n", topic, partition, err) + leader, leaderErr := p.client.Leader(topic, partition) + if leaderErr != nil { + Logger.Printf("Failed retrying batch for %v-%d because of %v while looking up for new leader\n", topic, partition, leaderErr) for _, msg := range pSet.msgs { - p.returnError(msg, kerr) + p.returnError(msg, retryErr) + } + if alreadyMuted { + p.muter.unmute(produceSet) } return } + if !alreadyMuted { + if !p.muter.waitUntilMuted(produceSet) { + for _, msg := range pSet.msgs { + p.returnError(msg, retryErr) + } + return + } + } bp := p.getBrokerProducer(leader) bp.output <- produceSet p.unrefBrokerProducer(leader, bp) @@ -1187,18 +1466,57 @@ func (bp *brokerProducer) handleError(sent *produceSet, err error) { sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) { bp.parent.returnErrors(pSet.msgs, err) }) + bp.parent.muter.unmute(sent) } else { Logger.Printf("producer/broker/%d state change to [closing] because %s\n", bp.broker.ID(), err) bp.parent.abandonBrokerConnection(bp.broker) _ = bp.broker.Close() bp.closing = err + var retryTopics []string + retryTopicSeen := make(map[string]struct{}) sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) { - bp.parent.retryMessages(pSet.msgs, err) + if _, ok := retryTopicSeen[topic]; ok { + return + } + retryTopicSeen[topic] = struct{}{} + retryTopics = append(retryTopics, topic) }) - bp.buffer.eachPartition(func(topic string, partition int32, pSet *partitionSet) { 
+ if bp.parent.conf.Producer.Idempotent && len(retryTopics) > 0 { + refreshErr := bp.parent.client.RefreshMetadata(retryTopics...) + if refreshErr != nil { + Logger.Printf("Failed refreshing metadata because of %v\n", refreshErr) + } + } + keepMuted := make(map[string]map[int32]struct{}) + sent.eachPartition(func(topic string, partition int32, pSet *partitionSet) { + // keep partition marked as in-flight during retry (connection error) + if bp.currentRetries[topic] == nil { + bp.currentRetries[topic] = make(map[int32]error) + } + bp.currentRetries[topic][partition] = err + if bp.parent.conf.Producer.Idempotent { + if keepMuted[topic] == nil { + keepMuted[topic] = make(map[int32]struct{}) + } + keepMuted[topic][partition] = struct{}{} + go bp.parent.retryBatch(topic, partition, pSet, err, true) + } else { + bp.parent.retryMessages(pSet.msgs, err) + } + }) + bp.accumulatingBatch.eachPartition(func(topic string, partition int32, pSet *partitionSet) { bp.parent.retryMessages(pSet.msgs, err) }) bp.rollOver() + + unmuteSet := sent.copyFunc(func(topic string, partition int32) bool { + if partitions := keepMuted[topic]; partitions != nil { + _, kept := partitions[partition] + return !kept + } + return true + }) + bp.parent.muter.unmute(unmuteSet) } } @@ -1243,6 +1561,8 @@ func (p *asyncProducer) shutdown() { Logger.Println("producer/shutdown failed to close the embedded client:", err) } + p.muter.close() + close(p.input) close(p.retries) close(p.errors) diff --git a/async_producer_test.go b/async_producer_test.go index 32dc0622b..82be17064 100644 --- a/async_producer_test.go +++ b/async_producer_test.go @@ -14,6 +14,7 @@ import ( "github.com/fortytw2/leaktest" "github.com/rcrowley/go-metrics" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1612,6 +1613,384 @@ func TestBrokerProducerShutdown(t *testing.T) { mockBroker.Close() } +// TestBrokerProducerWaitForSpaceEmptyBufferRollover ensures forced rollovers with an empty buffer +// do not deadlock waiting for responses when no partitions are muted. 
+func TestBrokerProducerWaitForSpaceEmptyBufferRollover(t *testing.T) { + config := NewTestConfig() + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: &transactionManager{}, + } + + bp := &brokerProducer{ + parent: parent, + accumulatingBatch: newProduceSet(parent), + output: make(chan *produceSet, 1), + responses: make(chan *brokerProducerResponse), + } + + done := make(chan error, 1) + go func() { + done <- bp.waitForSpace(&ProducerMessage{Topic: "topic", Partition: 0}, true) + }() + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("waitForSpace blocked on empty buffer rollover") + } +} + +func awaitMuterBlocked(t *testing.T, m *partitionMuter, set *produceSet) { + t.Helper() + deadline := time.Now().Add(500 * time.Millisecond) + for { + if _, blocked := m.awaitUnmuteChan(set); blocked { + return + } + if time.Now().After(deadline) { + t.Fatal("timeout waiting for muter to block on set") + } + time.Sleep(5 * time.Millisecond) + } +} + +func assertNotDone[T any](t *testing.T, ch <-chan T, wait time.Duration) { + t.Helper() + time.Sleep(wait) + select { + case <-ch: + t.Fatal("channel should not be ready") + default: + } +} + +func assertDoneWithin[T any](t *testing.T, ch <-chan T, timeout time.Duration) T { + t.Helper() + select { + case v := <-ch: + return v + case <-time.After(timeout): + t.Fatal("timed out waiting for channel") + var zero T + return zero + } +} + +// TestBrokerProducerWaitForSpaceRespectsExternalUnmute ensures waitForSpace does not +// deadlock when partitions are muted by another producer and are unmuted elsewhere. +func TestBrokerProducerWaitForSpaceRespectsExternalUnmute(t *testing.T) { + config := NewTestConfig() + txnMgr := &transactionManager{ + producerID: 0, + producerEpoch: 0, + sequenceNumbers: make(map[string]int32), + } + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: txnMgr, + } + + externallyMutedSet := newProduceSet(parent) + safeAddMessage(t, externallyMutedSet, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("held")}) + if !parent.muter.tryMute(externallyMutedSet) { + t.Fatal("expected to mute partition") + } + + output := make(chan *produceSet, 1) + bp := &brokerProducer{ + parent: parent, + accumulatingBatch: newProduceSet(parent), + output: output, + responses: make(chan *brokerProducerResponse), + } + msg := &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("wait")} + safeAddMessage(t, bp.accumulatingBatch, msg) + + done := make(chan error, 1) + go func() { + done <- bp.waitForSpace(msg, true) + }() + + awaitMuterBlocked(t, parent.muter, bp.accumulatingBatch) + parent.muter.unmute(externallyMutedSet) + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("waitForSpace blocked waiting for external unmute") + } +} + +func TestBrokerProducerFlushSkipsMutedPartitions(t *testing.T) { + config := NewTestConfig() + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: &transactionManager{}, + } + bp := &brokerProducer{ + parent: parent, + accumulatingBatch: newProduceSet(parent), + currentRetries: make(map[string]map[int32]error), + } + + safeAddMessage(t, bp.accumulatingBatch, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("p0")}) + safeAddMessage(t, bp.accumulatingBatch, &ProducerMessage{Topic: "topic", 
Partition: 1, Value: StringEncoder("p1")}) + + blocked := newProduceSet(parent) + safeAddMessage(t, blocked, &ProducerMessage{Topic: "topic", Partition: 1, Value: StringEncoder("held")}) + if !parent.muter.tryMute(blocked) { + t.Fatal("expected to mute blocked partition") + } + defer parent.muter.unmute(blocked) + + if !bp.tryBuildFlushingBatch() { + t.Fatal("expected to flush available partitions") + } + if bp.flushingBatch == nil { + t.Fatal("expected flushing batch to be set") + } + if _, ok := bp.flushingBatch.msgs["topic"][1]; ok { + t.Fatal("expected muted partition to stay buffered") + } + if _, ok := bp.accumulatingBatch.msgs["topic"][0]; ok { + t.Fatal("expected unmuted partition to flush") + } + if _, ok := bp.accumulatingBatch.msgs["topic"][1]; !ok { + t.Fatal("expected muted partition to remain in accumulating batch") + } +} + +// TestBrokerProducerWaitForSpaceAllPartitionsMuted verifies that waitForSpace unblocks +// when all partitions in the accumulating batch are externally muted and later unmuted. +func TestBrokerProducerWaitForSpaceAllPartitionsMuted(t *testing.T) { + config := NewTestConfig() + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: &transactionManager{}, + } + + blockedSet := newProduceSet(parent) + safeAddMessage(t, blockedSet, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("held")}) + if !parent.muter.tryMute(blockedSet) { + t.Fatal("expected to mute partition externally") + } + + bp := &brokerProducer{ + parent: parent, + accumulatingBatch: newProduceSet(parent), + output: make(chan *produceSet, 1), + responses: make(chan *brokerProducerResponse), + currentRetries: make(map[string]map[int32]error), + } + safeAddMessage(t, bp.accumulatingBatch, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("waiting")}) + + done := make(chan error, 1) + go func() { + done <- bp.waitForSpace(&ProducerMessage{Topic: "topic", Partition: 0}, true) + }() + + assertNotDone(t, done, 50*time.Millisecond) + parent.muter.unmute(blockedSet) + if err := assertDoneWithin(t, done, 2*time.Second); err != nil { + t.Fatalf("expected nil error, got %v", err) + } +} + +// TestPartitionMuterCloseWakesWaitUntilMuted verifies that closing the muter wakes +// goroutines blocked in waitUntilMuted. +func TestPartitionMuterCloseWakesWaitUntilMuted(t *testing.T) { + config := NewTestConfig() + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: &transactionManager{}, + } + + blockedSet := newProduceSet(parent) + safeAddMessage(t, blockedSet, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("held")}) + if !parent.muter.tryMute(blockedSet) { + t.Fatal("expected to mute partition") + } + + waitSet := newProduceSet(parent) + safeAddMessage(t, waitSet, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("waiting")}) + + done := make(chan bool, 1) + go func() { + done <- parent.muter.waitUntilMuted(waitSet) + }() + + assertNotDone(t, done, 50*time.Millisecond) + parent.muter.close() + + select { + case result := <-done: + if result { + t.Fatal("expected waitUntilMuted to return false after close") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out") + } +} + +// TestBrokerProducerRollOverClearsTimer ensures timer events from a previous batch +// do not cause a flush of a fresh empty batch after rollOver. 
+func TestBrokerProducerRollOverClearsTimer(t *testing.T) { + defer leaktest.Check(t)() + + config := NewTestConfig() + config.Producer.Flush.Frequency = 10 * time.Millisecond + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + txnmgr: &transactionManager{}, + } + output := make(chan *produceSet, 2) + responses := make(chan *brokerProducerResponse) + input := make(chan *ProducerMessage) + bp := &brokerProducer{ + parent: parent, + broker: &Broker{id: 1}, + input: input, + output: output, + responses: responses, + accumulatingBatch: newProduceSet(parent), + currentRetries: make(map[string]map[int32]error), + } + + done := make(chan struct{}) + go func() { + bp.run() + close(done) + }() + + msg := &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("v")} + input <- msg + + select { + case first := <-output: + if first == nil || first.empty() { + t.Fatal("expected flushed batch to contain message") + } + case <-time.After(500 * time.Millisecond): + t.Fatal("expected batch flush after timer fired") + } + + select { + case <-output: + t.Fatal("unexpected flush after rollOver") + case <-time.After(50 * time.Millisecond): + } + + close(responses) + close(input) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("brokerProducer did not shut down") + } +} + +func TestRetryBatchRespectsPartitionMuter(t *testing.T) { + config := NewTestConfig() + config.Producer.Idempotent = true + txnMgr := &transactionManager{ + producerID: 0, + producerEpoch: 0, + sequenceNumbers: make(map[string]int32), + } + + parent := &asyncProducer{ + conf: config, + muter: newPartitionMuter(), + brokers: make(map[*Broker]*brokerProducer), + brokerRefs: make(map[*brokerProducer]int), + txnmgr: txnMgr, + } + leader := &Broker{} + parent.client = &stubLeaderClient{leader: leader, cfg: config} + + output := make(chan *produceSet, 1) + bp := &brokerProducer{ + parent: parent, + broker: leader, + output: output, + input: make(chan *ProducerMessage), + } + parent.brokers[leader] = bp + + retrySet := newProduceSet(parent) + safeAddMessage(t, retrySet, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("retry")}) + retryPartitionSet := retrySet.msgs["topic"][0] + if !parent.muter.tryMute(retrySet) { + t.Fatal("expected retry set to mute partitions") + } + parent.muter.unmute(retrySet) + + parent.retryBatch("topic", 0, retryPartitionSet, ErrNotEnoughReplicas, false) + + select { + case sent := <-output: + set := sent.msgs["topic"][0] + require.Equal(t, retryPartitionSet, set) + default: + t.Fatal("expected retry batch to be dispatched") + } + + contender := newProduceSet(parent) + safeAddMessage(t, contender, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("next")}) + if parent.muter.tryMute(contender) { + t.Fatal("expected partition to remain muted by retry batch") + } +} + +type stubLeaderClient struct { + cfg *Config + leader *Broker +} + +func (c *stubLeaderClient) Config() *Config { return c.cfg } +func (c *stubLeaderClient) Controller() (*Broker, error) { return nil, nil } +func (c *stubLeaderClient) RefreshController() (*Broker, error) { return nil, nil } +func (c *stubLeaderClient) Brokers() []*Broker { return nil } +func (c *stubLeaderClient) Broker(int32) (*Broker, error) { return nil, nil } +func (c *stubLeaderClient) Topics() ([]string, error) { return nil, nil } +func (c *stubLeaderClient) Partitions(string) ([]int32, error) { return nil, nil } +func (c *stubLeaderClient) WritablePartitions(string) ([]int32, error) { return 
nil, nil } +func (c *stubLeaderClient) Leader(topic string, partitionID int32) (*Broker, error) { + return c.leader, nil +} +func (c *stubLeaderClient) LeaderAndEpoch(string, int32) (*Broker, int32, error) { + return c.leader, 0, nil +} +func (c *stubLeaderClient) Replicas(string, int32) ([]int32, error) { return nil, nil } +func (c *stubLeaderClient) InSyncReplicas(string, int32) ([]int32, error) { return nil, nil } +func (c *stubLeaderClient) OfflineReplicas(string, int32) ([]int32, error) { return nil, nil } +func (c *stubLeaderClient) RefreshBrokers([]string) error { return nil } +func (c *stubLeaderClient) RefreshMetadata(...string) error { return nil } +func (c *stubLeaderClient) GetOffset(string, int32, int64) (int64, error) { return 0, nil } +func (c *stubLeaderClient) Coordinator(string) (*Broker, error) { return nil, nil } +func (c *stubLeaderClient) RefreshCoordinator(string) error { return nil } +func (c *stubLeaderClient) TransactionCoordinator(string) (*Broker, error) { return nil, nil } +func (c *stubLeaderClient) RefreshTransactionCoordinator(string) error { return nil } +func (c *stubLeaderClient) InitProducerID() (*InitProducerIDResponse, error) { return nil, nil } +func (c *stubLeaderClient) LeastLoadedBroker() *Broker { return c.leader } +func (c *stubLeaderClient) PartitionNotReadable(string, int32) bool { return false } +func (c *stubLeaderClient) Close() error { return nil } +func (c *stubLeaderClient) Closed() bool { return false } + type appendInterceptor struct { i int } @@ -2156,74 +2535,81 @@ func TestTxnCanAbort(t *testing.T) { config.Producer.Retry.Max = 1 config.Net.MaxOpenRequests = 1 - metadataLeader := new(MetadataResponse) - metadataLeader.Version = 4 - metadataLeader.ControllerID = broker.brokerID - metadataLeader.AddBroker(broker.Addr(), broker.BrokerID()) - metadataLeader.AddTopic("test-topic", ErrNoError) - metadataLeader.AddTopic("test-topic-2", ErrNoError) - metadataLeader.AddTopicPartition("test-topic", 0, broker.BrokerID(), nil, nil, nil, ErrNoError) - metadataLeader.AddTopicPartition("test-topic-2", 0, broker.BrokerID(), nil, nil, nil, ErrNoError) - broker.Returns(metadataLeader) + var ( + mu sync.Mutex + addPartitionsCount int + produceRequestsCount int + ) + + broker.SetHandlerFuncByMap(map[string]requestHandlerFunc{ + "MetadataRequest": func(req *request) encoderWithHeader { + resp := new(MetadataResponse) + resp.Version = 4 + resp.ControllerID = broker.BrokerID() + resp.AddBroker(broker.Addr(), broker.BrokerID()) + resp.AddTopic("test-topic", ErrNoError) + resp.AddTopic("test-topic-2", ErrNoError) + resp.AddTopicPartition("test-topic", 0, broker.BrokerID(), nil, nil, nil, ErrNoError) + resp.AddTopicPartition("test-topic-2", 0, broker.BrokerID(), nil, nil, nil, ErrNoError) + return resp + }, + "FindCoordinatorRequest": func(req *request) encoderWithHeader { + resp := new(FindCoordinatorResponse) + resp.Version = 1 + resp.Coordinator = &Broker{id: broker.BrokerID(), addr: broker.Addr()} + resp.Err = ErrNoError + return resp + }, + "InitProducerIDRequest": func(req *request) encoderWithHeader { + return &InitProducerIDResponse{ + Err: ErrNoError, + ProducerID: 1, + ProducerEpoch: 0, + } + }, + "AddPartitionsToTxnRequest": func(req *request) encoderWithHeader { + mu.Lock() + addPartitionsCount++ + count := addPartitionsCount + mu.Unlock() + + if count == 1 { + return &AddPartitionsToTxnResponse{ + Errors: map[string][]*PartitionError{ + "test-topic-2": {{Partition: 0, Err: ErrNoError}}, + }, + } + } + return &AddPartitionsToTxnResponse{ + Errors: 
map[string][]*PartitionError{ + "test-topic": {{Partition: 0, Err: ErrTopicAuthorizationFailed}}, + }, + } + }, + "ProduceRequest": func(req *request) encoderWithHeader { + mu.Lock() + produceRequestsCount++ + mu.Unlock() + + resp := new(ProduceResponse) + resp.Version = 3 + resp.AddTopicPartition("test-topic-2", 0, ErrNoError) + return resp + }, + "EndTxnRequest": func(req *request) encoderWithHeader { + return &EndTxnResponse{Err: ErrNoError} + }, + }) client, err := NewClient([]string{broker.Addr()}, config) require.NoError(t, err) defer client.Close() - findCoordinatorResponse := FindCoordinatorResponse{ - Coordinator: client.Brokers()[0], - Err: ErrNoError, - Version: 1, - } - broker.Returns(&findCoordinatorResponse) - - producerIdResponse := &InitProducerIDResponse{ - Err: ErrNoError, - ProducerID: 1, - ProducerEpoch: 0, - } - broker.Returns(producerIdResponse) - ap, err := NewAsyncProducerFromClient(client) producer := ap.(*asyncProducer) require.NoError(t, err) defer ap.Close() - broker.Returns(&AddPartitionsToTxnResponse{ - Errors: map[string][]*PartitionError{ - "test-topic-2": { - { - Partition: 0, - Err: ErrNoError, - }, - }, - }, - }) - - produceResponse := new(ProduceResponse) - produceResponse.Version = 3 - produceResponse.AddTopicPartition("test-topic-2", 0, ErrNoError) - broker.Returns(produceResponse) - - broker.Returns(&AddPartitionsToTxnResponse{ - Errors: map[string][]*PartitionError{ - "test-topic": { - { - Partition: 0, - Err: ErrTopicAuthorizationFailed, - }, - }, - }, - }) - - // now broker is closed due to error. will now reopen it - broker.Returns(metadataLeader) - - endTxnResponse := &EndTxnResponse{ - Err: ErrNoError, - } - broker.Returns(endTxnResponse) - require.Equal(t, ProducerTxnFlagReady, producer.txnmgr.status) err = producer.BeginTxn() @@ -2333,3 +2719,274 @@ ProducerLoop: log.Printf("Successfully produced: %d; errors: %d\n", successes, producerErrors) } + +// TestAsyncProducerRetryOrdering verifies that message ordering is preserved during retries, +// both with and without request pipelining (MaxOpenRequests=1 vs >1). 
+func TestAsyncProducerRetryOrdering(t *testing.T) { + const topic = "my_topic" + + extractValue := func(pr *ProduceRequest) string { + recordsByPartition := pr.records[topic] + if recordsByPartition == nil { + return "" + } + records := recordsByPartition[0] + if rb := records.RecordBatch; rb != nil && len(rb.Records) > 0 { + return string(rb.Records[0].Value) + } + if ms := records.MsgSet; ms != nil && len(ms.Messages) > 0 { + return string(ms.Messages[0].Msg.Value) + } + return "" + } + + tests := []struct { + name string + maxOpenRequests int + retryBackoff time.Duration + }{ + { + name: "no pipelining (MaxOpenRequests=1)", + maxOpenRequests: 1, + retryBackoff: 0, + }, + { + name: "with pipelining (MaxOpenRequests=5)", + maxOpenRequests: 5, + retryBackoff: 50 * time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition(topic, 0, leader.BrokerID(), nil, nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + var ( + mu sync.Mutex + produceAttempts int + valuesSeen []string + ) + + leader.setHandler(func(req *request) (res encoderWithHeader) { + switch typed := req.body.(type) { + case *MetadataRequest: + return metadataResponse + case *ProduceRequest: + mu.Lock() + defer mu.Unlock() + + produceAttempts++ + value := extractValue(typed) + valuesSeen = append(valuesSeen, value) + + resp := new(ProduceResponse) + resp.Version = typed.Version + switch produceAttempts { + case 1: + resp.AddTopicPartition(topic, 0, ErrNotLeaderForPartition) + case 2, 3: + resp.AddTopicPartition(topic, 0, ErrNoError) + default: + t.Errorf("unexpected attempt %d", produceAttempts) + resp.AddTopicPartition(topic, 0, ErrNoError) + } + return resp + default: + return nil + } + }) + + config := NewTestConfig() + config.Producer.Return.Successes = true + config.Producer.Retry.Max = 3 + config.Producer.Retry.Backoff = tt.retryBackoff + config.Producer.Flush.Messages = 1 + config.Producer.Flush.MaxMessages = 1 + config.Producer.Partitioner = NewManualPartitioner + config.Net.MaxOpenRequests = tt.maxOpenRequests + + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + msgValues := []string{"msg-0", "msg-1"} + for _, val := range msgValues { + producer.Input() <- &ProducerMessage{Topic: topic, Partition: 0, Value: StringEncoder(val)} + } + + expectResultsWithTimeout(t, producer, len(msgValues), 0, 10*time.Second) + + mu.Lock() + attempts := produceAttempts + seen := make([]string, len(valuesSeen)) + copy(seen, valuesSeen) + mu.Unlock() + + closeProducer(t, producer) + seedBroker.Close() + leader.Close() + + if attempts != 3 { + t.Errorf("expected 3 produce attempts, got %d", attempts) + } + + // Both configurations should maintain ordering: msg-0 (fail), msg-0 (retry), msg-1 + expectedOrder := []string{"msg-0", "msg-0", "msg-1"} + if !assert.Equal(t, expectedOrder, seen) { + t.Errorf("messages out of order: got %v, want %v", seen, expectedOrder) + } + }) + } +} + +// TestAsyncProducerPartitionUnmuting verifies that partitions are properly unmuted +// in all error paths: send errors, NoResponse acks, etc. Without proper unmuting, +// partitions remain muted and subsequent messages would block indefinitely. 
+func TestAsyncProducerPartitionUnmuting(t *testing.T) { + const topic = "test_topic" + + t.Run("NoResponse acks unmute partitions", func(t *testing.T) { + broker := NewMockBroker(t, 1) + defer broker.Close() + + metadataResponse := NewMockMetadataResponse(t). + SetBroker(broker.Addr(), broker.BrokerID()). + SetLeader(topic, 0, broker.BrokerID()) + broker.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": metadataResponse, + }) + + config := NewTestConfig() + config.Producer.RequiredAcks = NoResponse + config.Producer.Return.Successes = true + config.Producer.Flush.Messages = 1 + config.Net.MaxOpenRequests = 5 + + producer, err := NewAsyncProducer([]string{broker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 3; i++ { + producer.Input() <- &ProducerMessage{ + Topic: topic, + Partition: 0, + Value: StringEncoder("msg"), + } + } + + successCount := 0 + for i := 0; i < 3; i++ { + select { + case <-producer.Successes(): + successCount++ + case err := <-producer.Errors(): + t.Fatalf("unexpected error: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for success %d (got %d) - partitions may not be unmuted", i+1, successCount) + } + } + + if successCount != 3 { + t.Errorf("expected 3 successes, got %d", successCount) + } + + closeProducer(t, producer) + }) + + t.Run("retry keeps partition muted until queued", func(t *testing.T) { + broker := NewMockBroker(t, 1) + defer broker.Close() + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(broker.Addr(), broker.BrokerID()) + metadataResponse.AddTopicPartition(topic, 0, broker.BrokerID(), nil, nil, nil, ErrNoError) + + var attemptCount int + var mu sync.Mutex + firstRequestReceived := make(chan struct{}) + + broker.setHandler(func(req *request) (res encoderWithHeader) { + switch req.body.(type) { + case *MetadataRequest: + return metadataResponse + case *ProduceRequest: + mu.Lock() + attemptCount++ + attempt := attemptCount + mu.Unlock() + + if attempt == 1 { + close(firstRequestReceived) + } + + resp := new(ProduceResponse) + if attempt == 1 { + resp.AddTopicPartition(topic, 0, ErrNotLeaderForPartition) + } else { + resp.AddTopicPartition(topic, 0, ErrNoError) + } + return resp + } + return nil + }) + + config := NewTestConfig() + config.Producer.Return.Successes = true + config.Producer.Retry.Max = 1 + config.Producer.Retry.Backoff = 10 * time.Millisecond + config.Producer.Flush.Messages = 1 + config.Net.MaxOpenRequests = 5 + + producer, err := NewAsyncProducer([]string{broker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + producer.Input() <- &ProducerMessage{ + Topic: topic, + Partition: 0, + Value: StringEncoder("msg-0"), + } + + <-firstRequestReceived + + producer.Input() <- &ProducerMessage{ + Topic: topic, + Partition: 0, + Value: StringEncoder("msg-1"), + } + + var successCount int + for i := 0; i < 2; i++ { + select { + case <-producer.Successes(): + successCount++ + case err := <-producer.Errors(): + t.Fatalf("unexpected error: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for success %d/%d - partition may be deadlocked", successCount, 2) + } + } + + if successCount != 2 { + t.Errorf("expected 2 successes, got %d", successCount) + } + + mu.Lock() + attempts := attemptCount + mu.Unlock() + if attempts != 3 { + t.Errorf("expected 3 produce attempts, got %d", attempts) + } + + closeProducer(t, producer) + }) +} diff --git a/config.go b/config.go index 01ed81362..1bf01e348 100644 --- a/config.go +++ b/config.go @@ 
-195,7 +195,8 @@ type Config struct { // setting for the JVM producer. Partitioner PartitionerConstructor // If enabled, the producer will ensure that exactly one copy of each message is - // written. + // written, and it will enforce stricter ordering by requiring MaxOpenRequests=1 + // and WaitForAll acks, which can reduce throughput. Idempotent bool // Transaction specify Transaction struct { diff --git a/produce_set.go b/produce_set.go index 004fc6490..380f7f5ab 100644 --- a/produce_set.go +++ b/produce_set.go @@ -24,11 +24,15 @@ type produceSet struct { func newProduceSet(parent *asyncProducer) *produceSet { pid, epoch := parent.txnmgr.getProducerID() + return newProduceSetWithMeta(parent, pid, epoch) +} + +func newProduceSetWithMeta(parent *asyncProducer, producerID int64, producerEpoch int16) *produceSet { return &produceSet{ msgs: make(map[string]map[int32]*partitionSet), parent: parent, - producerID: pid, - producerEpoch: epoch, + producerID: producerID, + producerEpoch: producerEpoch, } } @@ -127,6 +131,52 @@ func (ps *produceSet) add(msg *ProducerMessage) error { return nil } +func (ps *produceSet) takePartitions(predicate func(topic string, partition int32) bool) *produceSet { + if ps.empty() { + return nil + } + out := newProduceSetWithMeta(ps.parent, ps.producerID, ps.producerEpoch) + for topic, partitions := range ps.msgs { + for partition, set := range partitions { + if !predicate(topic, partition) { + continue + } + if out.msgs[topic] == nil { + out.msgs[topic] = make(map[int32]*partitionSet) + } + out.msgs[topic][partition] = set + out.bufferBytes += set.bufferBytes + out.bufferCount += len(set.msgs) + ps.bufferBytes -= set.bufferBytes + ps.bufferCount -= len(set.msgs) + delete(partitions, partition) + } + if len(partitions) == 0 { + delete(ps.msgs, topic) + } + } + if out.empty() { + return nil + } + return out +} + +func (ps *produceSet) copyFunc(predicate func(topic string, partition int32) bool) *produceSet { + out := newProduceSetWithMeta(ps.parent, ps.producerID, ps.producerEpoch) + for topic, partitions := range ps.msgs { + for partition, set := range partitions { + if !predicate(topic, partition) { + continue + } + if out.msgs[topic] == nil { + out.msgs[topic] = make(map[int32]*partitionSet) + } + out.msgs[topic][partition] = set + } + } + return out +} + func (ps *produceSet) buildRequest() *ProduceRequest { req := &ProduceRequest{ RequiredAcks: ps.parent.conf.Producer.RequiredAcks, @@ -225,6 +275,17 @@ func (ps *produceSet) eachPartition(cb func(topic string, partition int32, pSet } } +func (ps *produceSet) anyPartition(predicate func(topic string, partition int32, pSet *partitionSet) bool) bool { + for topic, partitionSet := range ps.msgs { + for partition, set := range partitionSet { + if predicate(topic, partition, set) { + return true + } + } + } + return false +} + func (ps *produceSet) dropPartition(topic string, partition int32) []*ProducerMessage { if ps.msgs[topic] == nil { return nil diff --git a/produce_set_test.go b/produce_set_test.go index 8f580c27d..b8a7c5e54 100644 --- a/produce_set_test.go +++ b/produce_set_test.go @@ -11,6 +11,7 @@ func makeProduceSet() (*asyncProducer, *produceSet) { txnmgr, _ := newTransactionManager(conf, nil) parent := &asyncProducer{ conf: conf, + muter: newPartitionMuter(), txnmgr: txnmgr, } return parent, newProduceSet(parent)
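
A minimal package-internal sketch of the mute/unmute lifecycle the new partitionMuter enforces for the brokerProducer: reserve a batch's partitions before it goes on the wire, and release them once the response (or retry hand-off) is handled. It assumes it lives alongside the tests above, reusing the existing NewTestConfig and safeAddMessage helpers shown in this diff; the test name itself is hypothetical.

func TestPartitionMuterLifecycleSketch(t *testing.T) {
	parent := &asyncProducer{
		conf:   NewTestConfig(),
		muter:  newPartitionMuter(),
		txnmgr: &transactionManager{},
	}

	// Build a batch for topic/partition 0, as tryBuildFlushingBatch would.
	batch := newProduceSet(parent)
	safeAddMessage(t, batch, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("v0")})

	// Reserve every partition in the batch before handing it to the output channel.
	if !parent.muter.tryMute(batch) {
		t.Fatal("expected partitions to be free")
	}

	// A later batch for the same partition cannot be flushed while the first
	// request is still in flight; this is what preserves per-partition ordering.
	next := newProduceSet(parent)
	safeAddMessage(t, next, &ProducerMessage{Topic: "topic", Partition: 0, Value: StringEncoder("v1")})
	if parent.muter.tryMute(next) {
		t.Fatal("expected partition to stay muted while a batch is in flight")
	}

	// handleResponse/handleError (or a completed retryBatch) unmute the sent set,
	// releasing the partition so the next batch can be muted and sent.
	parent.muter.unmute(batch)
	if !parent.muter.tryMute(next) {
		t.Fatal("expected partition to be free after unmute")
	}
	parent.muter.unmute(next)
}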