From 3e9beae32bda84811f9fd838ec84c446f7f15774 Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Sat, 28 Dec 2024 17:42:27 -0800 Subject: [PATCH] kgo: fix accounting when topics/partitions are {,un}paused for PollRecords Topics or partitions that were paused (in `takeNBuffered` specifically, which is used by PollRecords) did not have proper accounting for stripped (removed) topics or partitions. PollFetches does not have this problem because the logic is different -- we build a big fetch, call our accounting logic, and then remove the stripped topics & partitions. Instead in takeNBuffered, we skip topics and partitions as we build a fetch and then once the fetch is built, do the accounting. The fix is to build a separate internal-only fetch of everything that was stripped and pass that to our accounting hooks. Fixes #865. --- pkg/kgo/consumer_direct_test.go | 78 +++++++++++++++++++++++++++++++++ pkg/kgo/hooks.go | 4 +- pkg/kgo/source.go | 22 +++++++++- 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go index 736730e8..cde63f54 100644 --- a/pkg/kgo/consumer_direct_test.go +++ b/pkg/kgo/consumer_direct_test.go @@ -6,6 +6,8 @@ import ( "fmt" "reflect" "sort" + "strconv" + "sync" "sync/atomic" "testing" "time" @@ -605,3 +607,79 @@ func TestIssue810(t *testing.T) { } } } + +func TestIssue865(t *testing.T) { + t.Parallel() + + t1, cleanup1 := tmpTopicPartitions(t, 1) + defer cleanup1() + t2, cleanup2 := tmpTopicPartitions(t, 1) + defer cleanup2() + + cl, _ := newTestClient( + UnknownTopicRetries(-1), + ConsumeTopics(t1, t2), + FetchMaxWait(100*time.Millisecond), + ) + defer cl.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + nrecs = 10_000 + flushEvery = 1000 + pollAmount = 100 + ) + + var wg sync.WaitGroup + for i := 0; i < nrecs; i++ { + r1 := StringRecord(strconv.Itoa(i)) + r1.Topic = t1 + wg.Add(1) + cl.Produce(ctx, r1, func(_ *Record, err 
error) { + defer wg.Done() + if err != nil { + t.Error(err) + } + }) + + r2 := StringRecord(strconv.Itoa(i)) + r2.Topic = t2 + wg.Add(1) + cl.Produce(ctx, r2, func(_ *Record, err error) { + defer wg.Done() + if err != nil { + t.Error(err) + } + }) + + if i%flushEvery == 0 { + cl.Flush(ctx) + } + } + + wg.Wait() + + for i := 2 * nrecs; i > 0; { + ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + fs := cl.PollRecords(ctx, pollAmount) + cancel() + cl.ResumeFetchTopics(t2) + fs.EachRecord(func(r *Record) { + i-- + if r.Topic == t2 { + cl.PauseFetchTopics(t2) + } + }) + } + + nrecbuf, nbytebuf := cl.BufferedFetchRecords(), cl.BufferedFetchBytes() + + if nrecbuf != 0 { + t.Errorf("got rec buffered %d != 0", nrecbuf) + } + if nbytebuf != 0 { + t.Errorf("got byte buffered %d != 0", nbytebuf) + } +} diff --git a/pkg/kgo/hooks.go b/pkg/kgo/hooks.go index aeff4f19..0150bab3 100644 --- a/pkg/kgo/hooks.go +++ b/pkg/kgo/hooks.go @@ -350,7 +350,9 @@ type HookProduceRecordPartitioned interface { // As an example, if using HookProduceRecordBuffered for a gauge of how many // record bytes are buffered, this hook can be used to decrement the gauge. // -// Note that this hook will slow down high-volume producing a bit. +// Note that this hook will slow down high-volume producing a bit. As well, +// records that were buffered but are paused (and stripped internally before +// being returned to the user) will still be passed to this hook. type HookProduceRecordUnbuffered interface { // OnProduceRecordUnbuffered is passed a record that is just about to // have its produce promise called, as well as the error that the diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 12732e90..fcb198ec 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -489,8 +489,11 @@ func (s *source) discardBuffered() { // This returns the number of records taken and whether the source has been // completely drained.
func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { - var r Fetch - var taken int + var ( + r Fetch + rstrip Fetch + taken int + ) b := &s.buffered bf := &b.fetch @@ -500,6 +503,7 @@ func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { // If the topic is outright paused, we allowUsable all // partitions in the topic and skip the topic entirely. if paused.has(t.Topic, -1) { + rstrip.Topics = append(rstrip.Topics, *t) bf.Topics = bf.Topics[1:] for _, pCursor := range b.usedOffsets[t.Topic] { pCursor.from.allowUsable() @@ -517,6 +521,15 @@ func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { rt = &r.Topics[len(r.Topics)-1] rt.Partitions = nil } + var rtstrip *FetchTopic + ensureTopicStripped := func() { + if rtstrip != nil { + return + } + rstrip.Topics = append(rstrip.Topics, *t) + rtstrip = &rstrip.Topics[len(rstrip.Topics)-1] + rtstrip.Partitions = nil + } tCursors := b.usedOffsets[t.Topic] @@ -524,6 +537,8 @@ func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { p := &t.Partitions[0] if paused.has(t.Topic, p.Partition) { + ensureTopicStripped() + rtstrip.Partitions = append(rtstrip.Partitions, *p) t.Partitions = t.Partitions[1:] pCursor := tCursors[p.Partition] pCursor.from.allowUsable() @@ -577,6 +592,9 @@ func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) { } } + if len(rstrip.Topics) > 0 { + s.hook(&rstrip, false, true) + } s.hook(&r, false, true) // unbuffered, polled drained := len(bf.Topics) == 0