diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go
index baedbb90aba90..92cfed2758836 100644
--- a/pkg/distributor/distributor.go
+++ b/pkg/distributor/distributor.go
@@ -454,29 +454,6 @@ type streamTracker struct {
 	failed    atomic.Int32
 }
 
-// TODO taken from Cortex, see if we can refactor out an usable interface.
-type pushTracker struct {
-	streamsPending atomic.Int32
-	streamsFailed  atomic.Int32
-	done           chan struct{}
-	err            chan error
-}
-
-// doneWithResult records the result of a stream push.
-// If err is nil, the stream push is considered successful.
-// If err is not nil, the stream push is considered failed.
-func (p *pushTracker) doneWithResult(err error) {
-	if err == nil {
-		if p.streamsPending.Dec() == 0 {
-			p.done <- struct{}{}
-		}
-	} else {
-		if p.streamsFailed.Inc() == 1 {
-			p.err <- err
-		}
-	}
-}
-
 func (d *Distributor) waitSimulatedLatency(ctx context.Context, tenantID string, start time.Time) {
 	latency := d.validator.SimulatedPushLatency(tenantID)
 	if latency > 0 {
@@ -754,10 +731,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 	const maxExpectedReplicationSet = 5 // typical replication factor 3 plus one for inactive plus one for luck
 	var descs [maxExpectedReplicationSet]ring.InstanceDesc
 
-	tracker := pushTracker{
-		done: make(chan struct{}, 1), // buffer avoids blocking if caller terminates - sendSamples() only sends once on each
-		err:  make(chan error, 1),
-	}
+	tracker := newBasicPushTracker()
 	streamsToWrite := 0
 	if d.cfg.IngesterEnabled {
 		streamsToWrite += len(streams)
@@ -766,7 +740,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 		streamsToWrite += len(streams)
 	}
 	// We must correctly set streamsPending before beginning any writes to ensure we don't have a race between finishing all of one path before starting the other.
-	tracker.streamsPending.Store(int32(streamsToWrite))
+	tracker.Add(int32(streamsToWrite))
 
 	if d.cfg.KafkaEnabled {
 		subring, err := d.partitionRing.PartitionRing().ShuffleShard(tenantID, d.validator.IngestionPartitionsTenantShardSize(tenantID))
@@ -774,7 +748,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 			return nil, err
 		}
 		// We don't need to create a new context like the ingester writes, because we don't return unless all writes have succeeded.
-		d.sendStreamsToKafka(ctx, streams, tenantID, &tracker, subring)
+		d.sendStreamsToKafka(ctx, streams, tenantID, tracker, subring)
 	}
 
 	if d.cfg.IngesterEnabled {
@@ -823,7 +797,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 			case d.ingesterTasks <- pushIngesterTask{
 				ingester:      ingester,
 				streamTracker: samples,
-				pushTracker:   &tracker,
+				pushTracker:   tracker,
 				ctx:           localCtx,
 				cancel:        cancel,
 			}:
@@ -833,14 +807,11 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 		}
 	}
 
-	select {
-	case err := <-tracker.err:
+	if err := tracker.Wait(ctx); err != nil {
 		return nil, err
-	case <-tracker.done:
-		return &logproto.PushResponse{}, validationErr
-	case <-ctx.Done():
-		return nil, ctx.Err()
 	}
+
+	return &logproto.PushResponse{}, validationErr
 }
 
 // missingEnforcedLabels returns true if the stream is missing any of the required labels.
@@ -1135,7 +1106,7 @@ func (d *Distributor) truncateLines(vContext validationContext, stream *logproto
 
 type pushIngesterTask struct {
 	streamTracker []*streamTracker
-	pushTracker   *pushTracker
+	pushTracker   PushTracker
 	ingester      ring.InstanceDesc
 	ctx           context.Context
 	cancel        context.CancelFunc
@@ -1172,12 +1143,12 @@ func (d *Distributor) sendStreams(task pushIngesterTask) {
 			if task.streamTracker[i].failed.Inc() <= int32(task.streamTracker[i].maxFailures) {
 				continue
 			}
-			task.pushTracker.doneWithResult(err)
+			task.pushTracker.Done(err)
 		} else {
 			if task.streamTracker[i].succeeded.Inc() != int32(task.streamTracker[i].minSuccess) {
 				continue
 			}
-			task.pushTracker.doneWithResult(nil)
+			task.pushTracker.Done(nil)
 		}
 	}
 }
@@ -1209,14 +1180,14 @@ func (d *Distributor) sendStreamsErr(ctx context.Context, ingester ring.Instance
 	return err
 }
 
-func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker *pushTracker, subring *ring.PartitionRing) {
+func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker PushTracker, subring *ring.PartitionRing) {
 	for _, s := range streams {
 		go func(s KeyedStream) {
 			err := d.sendStreamToKafka(ctx, s, tenant, subring)
 			if err != nil {
 				err = fmt.Errorf("failed to write stream to kafka: %w", err)
 			}
-			tracker.doneWithResult(err)
+			tracker.Done(err)
 		}(s)
 	}
 }
diff --git a/pkg/distributor/tracker.go b/pkg/distributor/tracker.go
new file mode 100644
index 0000000000000..5df085f3d376a
--- /dev/null
+++ b/pkg/distributor/tracker.go
@@ -0,0 +1,103 @@
+package distributor
+
+import (
+	"context"
+	"sync"
+)
+
+// PushTracker is an interface to track the status of pushes and wait on
+// their completion.
+type PushTracker interface {
+	// Add increments the number of pushes. It must not be called after the
+	// last call to [Done] has completed.
+	Add(int32)
+
+	// Done decrements the number of pushes. It accepts an optional error
+	// if the push failed.
+	Done(err error)
+
+	// Wait until all pushes are done or a push fails, whichever happens
+	// first.
+	Wait(ctx context.Context) error
+}
+
+type basicPushTracker struct {
+	mtx      sync.Mutex    // protects the fields below.
+	n        int32         // the number of pushes.
+	firstErr error         // the first reported error from a push.
+	doneCh   chan struct{} // closed when all pushes are done.
+	errCh    chan struct{} // closed when an error is reported.
+	done     bool          // fast path, equivalent to select { case <-t.doneCh: default: }
+}
+
+// newBasicPushTracker returns a new, initialized [basicPushTracker].
+func newBasicPushTracker() *basicPushTracker {
+	return &basicPushTracker{
+		doneCh: make(chan struct{}),
+		errCh:  make(chan struct{}),
+	}
+}
+
+// Add implements the [PushTracker] interface.
+func (t *basicPushTracker) Add(n int32) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+	if t.done {
+		panic("Add called after last call to Done")
+	}
+	t.n += n
+	if t.n < 0 {
+		// We panic on negative counters just like [sync.WaitGroup].
+		panic("Negative counter")
+	}
+}
+
+// Done implements the [PushTracker] interface.
+func (t *basicPushTracker) Done(err error) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+	if t.n <= 0 {
+		// We panic here just like [sync.WaitGroup].
+		panic("Done called more times than Add")
+	}
+	if err != nil && t.firstErr == nil {
+		// errCh can never be closed twice because t.firstErr transitions
+		// from nil to non-nil at most once.
+		t.firstErr = err
+		close(t.errCh)
+	}
+	t.n--
+	if t.n == 0 {
+		close(t.doneCh)
+		t.done = true
+	}
+}
+
+// Wait implements the [PushTracker] interface.
+func (t *basicPushTracker) Wait(ctx context.Context) error {
+	t.mtx.Lock()
+	// We need the mutex here: t.n can still be modified while doneCh has
+	// not been closed, and t.firstErr can still be modified while neither
+	// doneCh nor errCh has been closed.
+	if t.firstErr != nil || t.n == 0 {
+		// We need to store firstErr before releasing the mutex for the
+		// same reason.
+		res := t.firstErr
+		t.mtx.Unlock()
+		return res
+	}
+	t.mtx.Unlock()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-t.doneCh:
+		// Must return t.firstErr as doneCh is also closed if the last push
+		// failed. We don't need the mutex here as t.firstErr is never
+		// modified after doneCh is closed.
+		return t.firstErr
+	case <-t.errCh:
+		// We don't need the mutex here either as t.firstErr is never modified
+		// after errCh is closed.
+		return t.firstErr
+	}
+}
diff --git a/pkg/distributor/tracker_test.go b/pkg/distributor/tracker_test.go
new file mode 100644
index 0000000000000..6d151c629086c
--- /dev/null
+++ b/pkg/distributor/tracker_test.go
@@ -0,0 +1,139 @@
+package distributor
+
+import (
+	"context"
+	"errors"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestBasicPushTracker(t *testing.T) {
+	t.Run("a new tracker that has never been incremented should never block", func(t *testing.T) {
+		tracker := newBasicPushTracker()
+		ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.NoError(t, tracker.Wait(ctx))
+	})
+
+	t.Run("an expired context should return a context deadline exceeded error", func(t *testing.T) {
+		tracker := newBasicPushTracker()
+		tracker.Add(1)
+		ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
+		t.Cleanup(cancel)
+		require.EqualError(t, tracker.Wait(ctx), "context deadline exceeded")
+	})
+
+	t.Run("a done tracker with no errors should return nil", func(t *testing.T) {
+		tracker := newBasicPushTracker()
+		tracker.Add(1)
+		tracker.Done(nil)
+		ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.NoError(t, tracker.Wait(ctx))
+	})
+
+	t.Run("a done tracker with an error should return the error", func(t *testing.T) {
+		tracker := newBasicPushTracker()
+		tracker.Add(1)
+		tracker.Done(errors.New("an error occurred"))
+		ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.EqualError(t, tracker.Wait(ctx), "an error occurred")
+	})
+
+	t.Run("a done tracker should return the first error that occurred", func(t *testing.T) {
+		tracker := newBasicPushTracker()
+		tracker.Add(2)
+		tracker.Done(errors.New("an error occurred"))
+		tracker.Done(errors.New("another error occurred"))
+		ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.EqualError(t, tracker.Wait(ctx), "an error occurred")
+	})
+
+	t.Run("a done tracker should return at least one error", func(t *testing.T) {
+		t1 := newBasicPushTracker()
+		t1.Add(2)
+		t1.Done(nil)
+		t1.Done(errors.New("an error occurred"))
+		ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.EqualError(t, t1.Wait(ctx), "an error occurred")
+		// And now test the opposite sequence.
+		t2 := newBasicPushTracker()
+		t2.Add(2)
+		t2.Done(errors.New("an error occurred"))
+		t2.Done(nil)
+		ctx, cancel = context.WithTimeout(t.Context(), time.Second)
+		t.Cleanup(cancel)
+		require.EqualError(t, t2.Wait(ctx), "an error occurred")
+	})
+
+	t.Run("more Done than Add should panic", func(t *testing.T) {
+		// Should panic if Done is called before Add.
+		require.PanicsWithValue(t, "Done called more times than Add", func() {
+			tracker := newBasicPushTracker()
+			tracker.Done(nil)
+		})
+		// Should panic if Done is called more times than Add.
+		require.PanicsWithValue(t, "Done called more times than Add", func() {
+			tracker := newBasicPushTracker()
+			tracker.Add(1)
+			tracker.Done(nil)
+			tracker.Done(nil)
+		})
+	})
+
+	t.Run("Add after Done should panic", func(t *testing.T) {
+		require.PanicsWithValue(t, "Add called after last call to Done", func() {
+			tracker := newBasicPushTracker()
+			tracker.Add(1)
+			tracker.Done(nil)
+			tracker.Add(1)
+		})
+	})
+
+	t.Run("Negative counter should panic", func(t *testing.T) {
+		require.PanicsWithValue(t, "Negative counter", func() {
+			tracker := newBasicPushTracker()
+			tracker.Add(-1)
+		})
+	})
+}
+
+// Run with go test -fuzz=FuzzBasicPushTracker.
+func FuzzBasicPushTracker(f *testing.F) {
+	f.Add(uint16(100))
+	f.Fuzz(func(t *testing.T, n uint16) {
+		wg := sync.WaitGroup{}
+		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+		t.Cleanup(cancel)
+		tracker := newBasicPushTracker()
+		tracker.Add(int32(n))
+		// Create a random number of waiters (bound evaluated once).
+		for i, numWaiters := 0, rand.Intn(100); i < numWaiters; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				// Sleep a random time up to 100ms.
+				time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
+				require.NoError(t, tracker.Wait(ctx))
+			}()
+		}
+		// Done must be called exactly n times, so this count cannot be random.
+		for i := 0; i < int(n); i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				// Sleep a random time up to 100ms too.
+				time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
+				tracker.Done(nil)
+			}()
+		}
+		wg.Wait()
+	})
+}
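
Not part of the patch above: a minimal sketch of how a caller is expected to drive the new PushTracker, mirroring the Add/Done/Wait sequence that PushWithResolver now uses. The examplePushAll and pushOne helpers and the []string payloads are hypothetical and exist only for illustration; only PushTracker, newBasicPushTracker, Add, Done and Wait come from the diff.

```go
package distributor

import (
	"context"
	"time"
)

// examplePushAll is a hypothetical caller, assumed to sit next to tracker.go
// in package distributor. It is a sketch, not code from this PR.
func examplePushAll(ctx context.Context, streams []string) error {
	tracker := newBasicPushTracker()

	// The full count must be registered before any write can call Done,
	// otherwise an early Done could drive the counter to zero (or panic)
	// while other pushes are still being started.
	tracker.Add(int32(len(streams)))

	for _, s := range streams {
		go func(s string) {
			// pushOne stands in for a real write to an ingester or Kafka.
			tracker.Done(pushOne(ctx, s))
		}(s)
	}

	// Wait returns nil once every push has reported success, the first
	// reported push error otherwise, or ctx.Err() if the context ends first.
	waitCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return tracker.Wait(waitCtx)
}

// pushOne is a hypothetical single-stream write used only for illustration.
func pushOne(_ context.Context, _ string) error { return nil }
```

In the patch itself this pattern appears once in PushWithResolver for both write paths: the Kafka and ingester writes share a single tracker whose count covers both, so Wait only returns once every write on both paths has reported, replacing the old buffered done/err channel pair.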