From 34b04862fb6e478dd3c3a9361388f82d51ce752d Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Wed, 24 Sep 2025 14:34:20 +0000 Subject: [PATCH 1/5] fix: ensure callbacks are not called after `VStream` returns Signed-off-by: Arthur Schreiber --- go/vt/vtgate/vstream_manager.go | 7 +- go/vt/vtgate/vstream_manager_test.go | 572 ++++++++++++++------------- 2 files changed, 298 insertions(+), 281 deletions(-) diff --git a/go/vt/vtgate/vstream_manager.go b/go/vt/vtgate/vstream_manager.go index dc0f9ad329a..bbf92a9b595 100644 --- a/go/vt/vtgate/vstream_manager.go +++ b/go/vt/vtgate/vstream_manager.go @@ -321,7 +321,11 @@ func (vs *vstream) stream(ctx context.Context) error { ctx, vs.cancel = context.WithCancel(ctx) defer vs.cancel() - go vs.sendEvents(ctx) + vs.wg.Add(1) + go func() { + defer vs.wg.Done() + vs.sendEvents(ctx) + }() // Make a copy first, because the ShardGtids list can change once streaming starts. copylist := append(([]*binlogdatapb.ShardGtid)(nil), vs.vgtid.ShardGtids...) @@ -359,6 +363,7 @@ func (vs *vstream) sendEvents(ctx context.Context) { } return nil } + for { select { case <-ctx.Done(): diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 69173b56735..4f04f2f554e 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -20,17 +20,14 @@ import ( "context" "fmt" "os" - "reflect" "runtime/pprof" "strings" "sync" - "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/protobuf/proto" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/test/utils" @@ -113,16 +110,25 @@ func TestVStreamSkew(t *testing.T) { vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{Keyspace: ks, Gtid: "pos", Shard: "20-40"}) go stream(sbc1, ks, "20-40", tcase.numEventsPerShard, tcase.shard1idx) } - ch := startVStream(ctx, t, vsm, vgtid, &vtgatepb.VStreamFlags{MinimizeSkew: true}) - var receivedEvents []*binlogdatapb.VEvent - for len(receivedEvents) < int(want) { - select { - case <-time.After(1 * time.Minute): - require.FailNow(t, "test timed out") - case response := <-ch: - receivedEvents = append(receivedEvents, response.Events...) + + vstreamCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + + receivedEvents := make([]*binlogdatapb.VEvent, 0) + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{MinimizeSkew: true}, func(events []*binlogdatapb.VEvent) error { + receivedEvents = append(receivedEvents, events...) + + if int64(len(receivedEvents)) == want { + // Stop streaming after receiving both expected responses. + cancel() } - } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + require.Equal(t, int(want), int(len(receivedEvents))) require.Equal(t, tcase.expectedDelays, vsm.GetTotalStreamDelay()-previousDelays) previousDelays = vsm.GetTotalStreamDelay() @@ -187,23 +193,23 @@ func TestVStreamEventsExcludeKeyspaceFromTableName(t *testing.T) { Gtid: "pos", }}, } - ch := make(chan *binlogdatapb.VStreamResponse) - go func() { - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: true}, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - wantErr := "context canceled" - if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("vstream end: %v, must contain %v", err.Error(), wantErr) + + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: true}, func(events []*binlogdatapb.VEvent) error { + receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedResponses) == 2 { + // Stop streaming after receiving both expected responses. + cancel() } - ch <- nil - }() - verifyEvents(t, ch, want1, want2) - // Ensure the go func error return was verified. - cancel() - <-ch + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.ElementsMatch(t, []*binlogdatapb.VStreamResponse{want1, want2}, receivedResponses) } func TestVStreamEvents(t *testing.T) { @@ -262,23 +268,23 @@ func TestVStreamEvents(t *testing.T) { Gtid: "pos", }}, } - ch := make(chan *binlogdatapb.VStreamResponse) - go func() { - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - wantErr := "context canceled" - if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("vstream end: %v, must contain %v", err.Error(), wantErr) + + receivedEvents := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedEvents = append(receivedEvents, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedEvents) == 2 { + // Stop streaming after receiving both expected responses. + cancel() } - ch <- nil - }() - verifyEvents(t, ch, want1, want2) - // Ensure the go func error return was verified. - cancel() - <-ch + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.ElementsMatch(t, []*binlogdatapb.VStreamResponse{want1, want2}, receivedEvents) } func BenchmarkVStreamEvents(b *testing.B) { @@ -339,53 +345,34 @@ func BenchmarkVStreamEvents(b *testing.B) { Gtid: "pos", }}, } - start := make(chan struct{}) - ch := make(chan *binlogdatapb.VStreamResponse) - go func() { - close(start) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, - &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: tt.excludeKeyspaceFromTableName}, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - wantErr := "context canceled" - if err == nil || !strings.Contains(err.Error(), wantErr) { - b.Errorf("vstream end: %v, must contain %v", err.Error(), wantErr) - } - ch <- nil - }() - // Start the timer when the VStream begins - <-start + // Start the timer and CPU profile after all setup is done b.ResetTimer() if os.Getenv("PROFILE_CPU") == "true" { pprof.StartCPUProfile(f) } received := 0 - for { - resp := <-ch - if resp == nil { - close(ch) - break - } - received += len(resp.Events) + err = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: tt.excludeKeyspaceFromTableName}, func(events []*binlogdatapb.VEvent) error { + received += len(events) + if received >= totalEvents { - b.Logf("Received events %d, expected total %d", received, totalEvents) - b.StopTimer() - if os.Getenv("PROFILE_CPU") == "true" { - pprof.StopCPUProfile() - } cancel() } - } - if received < totalEvents { - b.Errorf("expected at least %d events, got %d", totalEvents, received) + return nil + }) + + b.Logf("Received events %d, expected total %d", received, totalEvents) + b.StopTimer() + if os.Getenv("PROFILE_CPU") == "true" { + pprof.StopCPUProfile() } - cancel() - <-ch + require.Error(b, err) + require.ErrorIs(b, vterrors.UnwrapAll(err), context.Canceled) + + require.GreaterOrEqual(b, received, totalEvents) }) } } @@ -415,7 +402,6 @@ func TestVStreamChunks(t *testing.T) { rowEncountered := false doneCounting := false - var rowCount, ddlCount atomic.Int32 vgtid := &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ Keyspace: ks, @@ -427,7 +413,10 @@ func TestVStreamChunks(t *testing.T) { Gtid: "pos", }}, } - _ = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + + var rowCount, ddlCount int + + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { switch events[0].Type { case binlogdatapb.VEventType_ROW: if doneCounting { @@ -435,30 +424,39 @@ func TestVStreamChunks(t *testing.T) { return fmt.Errorf("unexpected event: %v", events[0]) } rowEncountered = true - rowCount.Add(1) + rowCount += 1 + case binlogdatapb.VEventType_COMMIT: if !rowEncountered { t.Errorf("Unexpected event, COMMIT after non-rows: %v", events[0]) return fmt.Errorf("unexpected event: %v", events[0]) } doneCounting = true + case binlogdatapb.VEventType_DDL: if !doneCounting && rowEncountered { t.Errorf("Unexpected event, DDL during ROW events: %v", events[0]) return fmt.Errorf("unexpected event: %v", events[0]) } - ddlCount.Add(1) + ddlCount += 1 + default: t.Errorf("Unexpected event: %v", events[0]) return fmt.Errorf("unexpected event: %v", events[0]) } - if rowCount.Load() == int32(100) && ddlCount.Load() == int32(100) { + + if rowCount == 100 && ddlCount == 100 { cancel() } + return nil }) - assert.Equal(t, int32(100), rowCount.Load()) - assert.Equal(t, int32(100), ddlCount.Load()) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 100, rowCount) + require.Equal(t, 100, ddlCount) } func TestVStreamMulti(t *testing.T) { @@ -498,15 +496,31 @@ func TestVStreamMulti(t *testing.T) { Gtid: "pos", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - <-ch - response := <-ch + + receivedEvents := make([]*binlogdatapb.VEvent, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedEvents = append(receivedEvents, events...) + + if len(receivedEvents) == 4 { + // Stop streaming after receiving both expected responses. + cancel() + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 4, len(receivedEvents)) + var got *binlogdatapb.VGtid - for _, ev := range response.Events { + for _, ev := range receivedEvents { if ev.Type == binlogdatapb.VEventType_VGTID { got = ev.Vgtid } } + want := &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ Keyspace: ks, @@ -518,9 +532,8 @@ func TestVStreamMulti(t *testing.T) { Gtid: "gtid02", }}, } - if !proto.Equal(got, want) { - t.Errorf("VGtid:\n%v, want\n%v", got, want) - } + + require.ElementsMatch(t, got.ShardGtids, want.ShardGtids) } func TestVStreamsMetrics(t *testing.T) { @@ -566,52 +579,58 @@ func TestVStreamsMetrics(t *testing.T) { Gtid: "pos", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - <-ch - <-ch + expectedLabels1 := "TestVStream.-20.PRIMARY" expectedLabels2 := "TestVStream.20-40.PRIMARY" - wantVStreamsCreated := map[string]int64{ + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) + + // While the VStream is running, we should see one active stream per shard. + require.Equal(t, map[string]int64{ + expectedLabels1: 1, + expectedLabels2: 1, + }, vsm.vstreamsCount.Counts()) + + if len(receivedResponses) == 2 { + // Stop streaming after receiving both expected responses. + cancel() + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 2, len(receivedResponses)) + + // After the streams end, the count should go back to zero. + require.Equal(t, map[string]int64{ + expectedLabels1: 0, + expectedLabels2: 0, + }, vsm.vstreamsCount.Counts()) + + require.Equal(t, map[string]int64{ expectedLabels1: 1, expectedLabels2: 1, - } - waitForMetricsMatch(t, vsm.vstreamsCreated.Counts, wantVStreamsCreated) + }, vsm.vstreamsCreated.Counts()) - wantVStreamsLag := map[string]int64{ + require.Equal(t, map[string]int64{ expectedLabels1: 5, expectedLabels2: 7, - } - waitForMetricsMatch(t, vsm.vstreamsLag.Counts, wantVStreamsLag) - - wantVStreamsCount := map[string]int64{ - expectedLabels1: 1, - expectedLabels2: 1, - } - waitForMetricsMatch(t, vsm.vstreamsCount.Counts, wantVStreamsCount) + }, vsm.vstreamsLag.Counts()) - wantVEventsCount := map[string]int64{ + require.Equal(t, map[string]int64{ expectedLabels1: 2, expectedLabels2: 2, - } - waitForMetricsMatch(t, vsm.vstreamsEventsStreamed.Counts, wantVEventsCount) + }, vsm.vstreamsEventsStreamed.Counts()) - wantVStreamsEndedWithErrors := map[string]int64{ + require.Equal(t, map[string]int64{ expectedLabels1: 0, expectedLabels2: 0, - } - waitForMetricsMatch(t, vsm.vstreamsEndedWithErrors.Counts, wantVStreamsEndedWithErrors) -} - -func waitForMetricsMatch(t *testing.T, getActual func() map[string]int64, want map[string]int64) { - deadline := time.Now().Add(1 * time.Second) - for time.Now().Before(deadline) { - if reflect.DeepEqual(getActual(), want) { - return - } - time.Sleep(10 * time.Millisecond) - } - assert.Equal(t, want, getActual(), "metrics did not match within timeout") + }, vsm.vstreamsEndedWithErrors.Counts()) } func TestVStreamsMetricsErrors(t *testing.T) { @@ -657,23 +676,11 @@ func TestVStreamsMetricsErrors(t *testing.T) { } results := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + results = append(results, &binlogdatapb.VStreamResponse{Events: events}) - var err error - ch := make(chan *binlogdatapb.VStreamResponse) - wg := sync.WaitGroup{} - wg.Go(func() { - for res := range ch { - results = append(results, res) - } - }) - wg.Go(func() { - err = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - close(ch) + return nil }) - wg.Wait() require.Error(t, err) require.ErrorContains(t, err, wantErr) @@ -751,8 +758,6 @@ func TestVStreamRetriableErrors(t *testing.T) { {Type: binlogdatapb.VEventType_COMMIT}, } - want := &binlogdatapb.VStreamResponse{Events: commit} - for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -793,43 +798,24 @@ func TestVStreamRetriableErrors(t *testing.T) { }}, } - ch := make(chan *binlogdatapb.VStreamResponse) - done := make(chan struct{}) - go func() { - err := vsm.VStream(ctx, topodatapb.TabletType_REPLICA, vgtid, nil, &vtgatepb.VStreamFlags{Cells: strings.Join(cells, ",")}, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - wantErr := "context canceled" + err := vsm.VStream(ctx, topodatapb.TabletType_REPLICA, vgtid, nil, &vtgatepb.VStreamFlags{Cells: strings.Join(cells, ",")}, func(events []*binlogdatapb.VEvent) error { + defer cancel() - if !tcase.shouldRetry { - wantErr = tcase.msg - } + require.Equal(t, 1, len(events)) + require.Equal(t, commit, events) - if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("vstream end: %v, must contain %v", err.Error(), wantErr) - } - close(done) - }() - - Loop: - for { - if tcase.shouldRetry { - select { - case event := <-ch: - got := event.CloneVT() - if !proto.Equal(got, want) { - t.Errorf("got different vstream event than expected") - } - cancel() - case <-done: - // The goroutine has completed, so break out of the loop - break Loop - } - } else { - <-done - break Loop - } + return nil + }) + + if tcase.shouldRetry { + // Expect a cancel error because the stream was retried and our callback + // was called. + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + } else { + // Expect the original error because no retry was done. + require.Error(t, err) + require.ErrorContains(t, err, tcase.msg) } }) } @@ -882,8 +868,23 @@ func TestVStreamShouldNotSendSourceHeartbeats(t *testing.T) { Gtid: "pos", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - verifyEvents(t, ch, want) + + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedResponses) == 1 { + // Stop streaming after receiving the expected response. + cancel() + } + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 1, len(receivedResponses)) + require.EqualExportedValues(t, want, receivedResponses[0]) } func TestVStreamJournalOneToMany(t *testing.T) { @@ -968,14 +969,32 @@ func TestVStreamJournalOneToMany(t *testing.T) { Gtid: "pos", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - verifyEvents(t, ch, want1) - // The following two events from the different shards can come in any order. - // But the resulting VGTID should be the same after both are received. - <-ch - got := <-ch - wantevent := &binlogdatapb.VEvent{ + receivedEvents := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedEvents = append(receivedEvents, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedEvents) == 3 { + // Stop streaming after receiving all expected responses. + cancel() + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 3, len(receivedEvents)) + + // First event should be the first transaction from the first shard. + require.EqualExportedValues(t, want1, receivedEvents[0]) + + // The second and third events can come in any order. + // So instead of comparing them directly, we simply verify that the GTID + // after the last event is the expected combined GTID. + + require.EqualExportedValues(t, &binlogdatapb.VEvent{ Type: binlogdatapb.VEventType_VGTID, Vgtid: &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ @@ -988,13 +1007,7 @@ func TestVStreamJournalOneToMany(t *testing.T) { Gtid: "gtid04", }}, }, - } - gotEvent := got.Events[0] - gotEvent.Keyspace = "" - gotEvent.Shard = "" - if !proto.Equal(gotEvent, wantevent) { - t.Errorf("vgtid: %v, want %v", got.Events[0], wantevent) - } + }, receivedEvents[2].Events[0]) } func TestVStreamJournalManyToOne(t *testing.T) { @@ -1087,12 +1100,25 @@ func TestVStreamJournalManyToOne(t *testing.T) { Gtid: "pos1020", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - // The following two events from the different shards can come in any order. - // But the resulting VGTID should be the same after both are received. - <-ch - got := <-ch - wantevent := &binlogdatapb.VEvent{ + + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedResponses) == 3 { + // Stop streaming after receiving all expected responses. + cancel() + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 3, len(receivedResponses)) + + require.EqualExportedValues(t, &binlogdatapb.VEvent{ Type: binlogdatapb.VEventType_VGTID, Vgtid: &binlogdatapb.VGtid{ ShardGtids: []*binlogdatapb.ShardGtid{{ @@ -1105,14 +1131,9 @@ func TestVStreamJournalManyToOne(t *testing.T) { Gtid: "gtid04", }}, }, - } - gotEvent := got.Events[0] - gotEvent.Keyspace = "" - gotEvent.Shard = "" - if !proto.Equal(gotEvent, wantevent) { - t.Errorf("vgtid: %v, want %v", got.Events[0], wantevent) - } - verifyEvents(t, ch, want1) + }, receivedResponses[1].Events[0]) + + require.EqualExportedValues(t, want1, receivedResponses[2]) } func TestVStreamJournalNoMatch(t *testing.T) { @@ -1239,8 +1260,29 @@ func TestVStreamJournalNoMatch(t *testing.T) { Gtid: "pos", }}, } - ch := startVStream(ctx, t, vsm, vgtid, nil) - verifyEvents(t, ch, want1, wantjn1, want2, wantjn2, want3) + + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) + + if len(receivedResponses) == 5 { + // Stop streaming after receiving all expected responses. + cancel() + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + + require.Equal(t, 5, len(receivedResponses)) + + require.EqualExportedValues(t, want1, receivedResponses[0]) + require.EqualExportedValues(t, wantjn1, receivedResponses[1]) + require.EqualExportedValues(t, want2, receivedResponses[2]) + require.EqualExportedValues(t, wantjn2, receivedResponses[3]) + require.EqualExportedValues(t, want3, receivedResponses[4]) } func TestVStreamJournalPartialMatch(t *testing.T) { @@ -1290,14 +1332,14 @@ func TestVStreamJournalPartialMatch(t *testing.T) { Gtid: "pos1020", }}, } + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { t.Errorf("unexpected events: %v", events) return nil }) - wantErr := "not all journaling participants are in the stream" - if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("vstream end: %v, must contain %v", err, wantErr) - } + + require.Error(t, err) + require.Contains(t, err.Error(), "not all journaling participants are in the stream") // Try a different order (different code path) send = []*binlogdatapb.VEvent{ @@ -1319,14 +1361,14 @@ func TestVStreamJournalPartialMatch(t *testing.T) { }}, } sbc2.AddVStreamEvents(send, nil) + err = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { t.Errorf("unexpected events: %v", events) return nil }) - if err == nil || !strings.Contains(err.Error(), wantErr) { - t.Errorf("vstream end: %v, must contain %v", err, wantErr) - } - cancel() + + require.Error(t, err) + require.Contains(t, err.Error(), "not all journaling participants are in the stream") } func TestResolveVStreamParams(t *testing.T) { @@ -1575,27 +1617,25 @@ func TestVStreamIdleHeartbeat(t *testing.T) { } for _, tcase := range testcases { t.Run(tcase.name, func(t *testing.T) { - var mu sync.Mutex var heartbeatCount int - ctx, cancel := context.WithCancel(ctx) - go func() { - vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, - func(events []*binlogdatapb.VEvent) error { - mu.Lock() - defer mu.Unlock() - for _, event := range events { - if event.Type == binlogdatapb.VEventType_HEARTBEAT { - heartbeatCount++ - } - } - return nil - }) - }() - time.Sleep(time.Duration(4500) * time.Millisecond) - mu.Lock() - defer mu.Unlock() + + ctx, cancel := context.WithTimeout(ctx, time.Duration(4500)*time.Millisecond) + defer cancel() + + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, func(events []*binlogdatapb.VEvent) error { + for _, event := range events { + if event.Type == binlogdatapb.VEventType_HEARTBEAT { + heartbeatCount++ + } + } + + return nil + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) + require.Equalf(t, heartbeatCount, tcase.want, "got %d, want %d", heartbeatCount, tcase.want) - cancel() }) } } @@ -1978,26 +2018,28 @@ func TestVStreamManagerHealthCheckResponseHandling(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - done := make(chan struct{}) - go func() { - sctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - defer close(done) - // SandboxConn's VStream implementation always waits for the context to timeout. - err := vsm.VStream(sctx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { - require.Fail(t, "unexpected event", "Received unexpected events: %v", events) - return nil - }) - if tc.wantErr != "" { // Otherwise we simply expect the context to timeout - if !strings.Contains(logger.String(), tc.wantErr) { - require.Fail(t, "unexpected vstream error", "vstream ended with error: %v, which did not contain: %s", err, tc.wantErr) - } - } - }() if tc.wantErr != "" { source.SetStreamHealthResponse(tc.hcRes) } - <-done + + sctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + // SandboxConn's VStream implementation always waits for the context to timeout. + err := vsm.VStream(sctx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { + require.Fail(t, "unexpected event", "Received unexpected events: %v", events) + return nil + }) + + if tc.wantErr != "" { + require.Error(t, err) + require.Contains(t, logger.String(), tc.wantErr) + } else { + // Otherwise we simply expect the context to timeout + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) + } + logger.Clear() }) } @@ -2009,36 +2051,6 @@ func newTestVStreamManager(ctx context.Context, hc discovery.HealthCheck, serv s return newVStreamManager(srvResolver, serv, cell) } -func startVStream(ctx context.Context, t *testing.T, vsm *vstreamManager, vgtid *binlogdatapb.VGtid, flags *vtgatepb.VStreamFlags) <-chan *binlogdatapb.VStreamResponse { - t.Helper() - if flags == nil { - flags = &vtgatepb.VStreamFlags{} - } - ch := make(chan *binlogdatapb.VStreamResponse) - go func() { - _ = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, flags, func(events []*binlogdatapb.VEvent) error { - ch <- &binlogdatapb.VStreamResponse{Events: events} - return nil - }) - }() - return ch -} - -func verifyEvents(t *testing.T, ch <-chan *binlogdatapb.VStreamResponse, wants ...*binlogdatapb.VStreamResponse) { - t.Helper() - for i, want := range wants { - val := <-ch - got := val.CloneVT() - require.NotNil(t, got) - for _, event := range got.Events { - event.Timestamp = 0 - } - if !proto.Equal(got, want) { - t.Errorf("vstream(%d):\n%v, want\n%v", i, got, want) - } - } -} - func getVEvents(keyspace, shard string, count, idx int64) []*binlogdatapb.VEvent { mu.Lock() defer mu.Unlock() From 754e07e1c00bff8875001c98b9304475d63221ca Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Thu, 25 Sep 2025 12:22:23 +0000 Subject: [PATCH 2/5] Clean up contexts used in test cases Signed-off-by: Arthur Schreiber --- go/vt/vtgate/vstream_manager_test.go | 117 ++++++++++++++++++--------- 1 file changed, 77 insertions(+), 40 deletions(-) diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 4f04f2f554e..49226cc7213 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -111,8 +111,8 @@ func TestVStreamSkew(t *testing.T) { go stream(sbc1, ks, "20-40", tcase.numEventsPerShard, tcase.shard1idx) } - vstreamCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 1*time.Minute) + defer vstreamCancel() receivedEvents := make([]*binlogdatapb.VEvent, 0) err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{MinimizeSkew: true}, func(events []*binlogdatapb.VEvent) error { @@ -120,7 +120,7 @@ func TestVStreamSkew(t *testing.T) { if int64(len(receivedEvents)) == want { // Stop streaming after receiving both expected responses. - cancel() + vstreamCancel() } return nil @@ -139,6 +139,7 @@ func TestVStreamSkew(t *testing.T) { func TestVStreamEventsExcludeKeyspaceFromTableName(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + cell := "aa" ks := "TestVStream" _ = createSandbox(ks) @@ -194,13 +195,16 @@ func TestVStreamEventsExcludeKeyspaceFromTableName(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: true}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: true}, func(events []*binlogdatapb.VEvent) error { receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedResponses) == 2 { // Stop streaming after receiving both expected responses. - cancel() + vstreamCancel() } return nil @@ -269,13 +273,16 @@ func TestVStreamEvents(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedEvents := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedEvents = append(receivedEvents, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedEvents) == 2 { // Stop streaming after receiving both expected responses. - cancel() + vstreamCancel() } return nil @@ -352,12 +359,15 @@ func BenchmarkVStreamEvents(b *testing.B) { pprof.StartCPUProfile(f) } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + received := 0 - err = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: tt.excludeKeyspaceFromTableName}, func(events []*binlogdatapb.VEvent) error { + err = vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{ExcludeKeyspaceFromTableName: tt.excludeKeyspaceFromTableName}, func(events []*binlogdatapb.VEvent) error { received += len(events) if received >= totalEvents { - cancel() + vstreamCancel() } return nil @@ -414,9 +424,11 @@ func TestVStreamChunks(t *testing.T) { }}, } - var rowCount, ddlCount int + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + var rowCount, ddlCount int + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { switch events[0].Type { case binlogdatapb.VEventType_ROW: if doneCounting { @@ -446,7 +458,7 @@ func TestVStreamChunks(t *testing.T) { } if rowCount == 100 && ddlCount == 100 { - cancel() + vstreamCancel() } return nil @@ -497,13 +509,16 @@ func TestVStreamMulti(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedEvents := make([]*binlogdatapb.VEvent, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedEvents = append(receivedEvents, events...) if len(receivedEvents) == 4 { // Stop streaming after receiving both expected responses. - cancel() + vstreamCancel() } return nil @@ -583,8 +598,11 @@ func TestVStreamsMetrics(t *testing.T) { expectedLabels1 := "TestVStream.-20.PRIMARY" expectedLabels2 := "TestVStream.20-40.PRIMARY" + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) // While the VStream is running, we should see one active stream per shard. @@ -595,7 +613,7 @@ func TestVStreamsMetrics(t *testing.T) { if len(receivedResponses) == 2 { // Stop streaming after receiving both expected responses. - cancel() + vstreamCancel() } return nil @@ -675,10 +693,18 @@ func TestVStreamsMetricsErrors(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + results := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { results = append(results, &binlogdatapb.VStreamResponse{Events: events}) + if len(results) == 2 { + // We should never actually see 2 responses come in + vstreamCancel() + } + return nil }) @@ -798,8 +824,11 @@ func TestVStreamRetriableErrors(t *testing.T) { }}, } - err := vsm.VStream(ctx, topodatapb.TabletType_REPLICA, vgtid, nil, &vtgatepb.VStreamFlags{Cells: strings.Join(cells, ",")}, func(events []*binlogdatapb.VEvent) error { - defer cancel() + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_REPLICA, vgtid, nil, &vtgatepb.VStreamFlags{Cells: strings.Join(cells, ",")}, func(events []*binlogdatapb.VEvent) error { + defer vstreamCancel() require.Equal(t, 1, len(events)) require.Equal(t, commit, events) @@ -819,7 +848,6 @@ func TestVStreamRetriableErrors(t *testing.T) { } }) } - } func TestVStreamShouldNotSendSourceHeartbeats(t *testing.T) { @@ -869,13 +897,16 @@ func TestVStreamShouldNotSendSourceHeartbeats(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedResponses) == 1 { // Stop streaming after receiving the expected response. - cancel() + vstreamCancel() } return nil }) @@ -970,13 +1001,16 @@ func TestVStreamJournalOneToMany(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedEvents := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedEvents = append(receivedEvents, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedEvents) == 3 { // Stop streaming after receiving all expected responses. - cancel() + vstreamCancel() } return nil @@ -1101,13 +1135,16 @@ func TestVStreamJournalManyToOne(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedResponses) == 3 { // Stop streaming after receiving all expected responses. - cancel() + vstreamCancel() } return nil @@ -1261,13 +1298,16 @@ func TestVStreamJournalNoMatch(t *testing.T) { }}, } + vstreamCtx, vstreamCancel := context.WithCancel(ctx) + defer vstreamCancel() + receivedResponses := make([]*binlogdatapb.VStreamResponse, 0) - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { receivedResponses = append(receivedResponses, &binlogdatapb.VStreamResponse{Events: events}) if len(receivedResponses) == 5 { // Stop streaming after receiving all expected responses. - cancel() + vstreamCancel() } return nil @@ -1334,8 +1374,7 @@ func TestVStreamJournalPartialMatch(t *testing.T) { } err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { - t.Errorf("unexpected events: %v", events) - return nil + return fmt.Errorf("unexpected events: %v", events) }) require.Error(t, err) @@ -1363,8 +1402,7 @@ func TestVStreamJournalPartialMatch(t *testing.T) { sbc2.AddVStreamEvents(send, nil) err = vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { - t.Errorf("unexpected events: %v", events) - return nil + return fmt.Errorf("unexpected events: %v", events) }) require.Error(t, err) @@ -1619,10 +1657,10 @@ func TestVStreamIdleHeartbeat(t *testing.T) { t.Run(tcase.name, func(t *testing.T) { var heartbeatCount int - ctx, cancel := context.WithTimeout(ctx, time.Duration(4500)*time.Millisecond) - defer cancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, time.Duration(4500)*time.Millisecond) + defer vstreamCancel() - err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, func(events []*binlogdatapb.VEvent) error { + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, func(events []*binlogdatapb.VEvent) error { for _, event := range events { if event.Type == binlogdatapb.VEventType_HEARTBEAT { heartbeatCount++ @@ -2022,13 +2060,12 @@ func TestVStreamManagerHealthCheckResponseHandling(t *testing.T) { source.SetStreamHealthResponse(tc.hcRes) } - sctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 5*time.Second) + defer vstreamCancel() // SandboxConn's VStream implementation always waits for the context to timeout. - err := vsm.VStream(sctx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { - require.Fail(t, "unexpected event", "Received unexpected events: %v", events) - return nil + err := vsm.VStream(vstreamCtx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { + return fmt.Errorf("unexpected events: %v", events) }) if tc.wantErr != "" { From 5221fa0314dba627960849e8d8983832f43f7bfb Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Thu, 25 Sep 2025 12:24:56 +0000 Subject: [PATCH 3/5] Ensure errors in the callback cancel the VStream call. Signed-off-by: Arthur Schreiber --- go/vt/vtgate/vstream_manager.go | 8 +++- go/vt/vtgate/vstream_manager_test.go | 55 ++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/go/vt/vtgate/vstream_manager.go b/go/vt/vtgate/vstream_manager.go index bbf92a9b595..2cd9526f31e 100644 --- a/go/vt/vtgate/vstream_manager.go +++ b/go/vt/vtgate/vstream_manager.go @@ -319,11 +319,17 @@ func (vsm *vstreamManager) GetTotalStreamDelay() int64 { func (vs *vstream) stream(ctx context.Context) error { ctx, vs.cancel = context.WithCancel(ctx) - defer vs.cancel() vs.wg.Add(1) go func() { defer vs.wg.Done() + + // sendEvents returns either if the given context has been canceled or if + // an error is returned from the callback. If the callback returns an error, + // we need to cancel the context to stop the other stream goroutines + // and to unblock the VStream call. + defer vs.cancel() + vs.sendEvents(ctx) }() diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 49226cc7213..34a551ab1d8 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -733,6 +733,61 @@ func TestVStreamsMetricsErrors(t *testing.T) { require.LessOrEqual(t, errorCounts["TestVStream.20-40.PRIMARY"], int64(1)) } +func TestVStreamErrorInCallback(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Use a unique cell to avoid parallel tests interfering with each other's metrics + cell := "ac" + ks := "TestVStream" + _ = createSandbox(ks) + hc := discovery.NewFakeHealthCheck(nil) + st := getSandboxTopo(ctx, cell, ks, []string{"-20", "20-40"}) + vsm := newTestVStreamManager(ctx, hc, st, cell) + vsm.vstreamsCreated.ResetAll() + vsm.vstreamsLag.ResetAll() + vsm.vstreamsCount.ResetAll() + vsm.vstreamsEventsStreamed.ResetAll() + vsm.vstreamsEndedWithErrors.ResetAll() + sbc0 := hc.AddTestTablet(cell, "1.1.1.1", 1001, ks, "-20", topodatapb.TabletType_PRIMARY, true, 1, nil) + addTabletToSandboxTopo(t, ctx, st, ks, "-20", sbc0.Tablet()) + sbc1 := hc.AddTestTablet(cell, "1.1.1.2", 1002, ks, "20-40", topodatapb.TabletType_PRIMARY, true, 1, nil) + addTabletToSandboxTopo(t, ctx, st, ks, "20-40", sbc1.Tablet()) + + send1 := []*binlogdatapb.VEvent{ + {Type: binlogdatapb.VEventType_GTID, Gtid: "gtid01"}, + {Type: binlogdatapb.VEventType_COMMIT, Timestamp: 10, CurrentTime: 15 * 1e9}, + } + sbc0.AddVStreamEvents(send1, nil) + + send2 := []*binlogdatapb.VEvent{ + {Type: binlogdatapb.VEventType_GTID, Gtid: "gtid02"}, + {Type: binlogdatapb.VEventType_COMMIT, Timestamp: 10, CurrentTime: 17 * 1e9}, + } + sbc1.AddVStreamEvents(send2, nil) + + vgtid := &binlogdatapb.VGtid{ + ShardGtids: []*binlogdatapb.ShardGtid{{ + Keyspace: ks, + Shard: "-20", + Gtid: "pos", + }, { + Keyspace: ks, + Shard: "20-40", + Gtid: "pos", + }}, + } + + expectedError := fmt.Errorf("callback error") + + err := vsm.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error { + return expectedError + }) + + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), expectedError) +} + func TestVStreamRetriableErrors(t *testing.T) { type testCase struct { name string From 6b033133a681cca15c3b419ea91f23101ff869ef Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Fri, 26 Sep 2025 09:38:02 +0000 Subject: [PATCH 4/5] Showcase using `synctest`. Signed-off-by: Arthur Schreiber --- go/vt/vtgate/vstream_manager_test.go | 155 ++++++++++++++------------- 1 file changed, 81 insertions(+), 74 deletions(-) diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 34a551ab1d8..6b3b4d834bd 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -24,6 +24,7 @@ import ( "strings" "sync" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -83,55 +84,57 @@ func TestVStreamSkew(t *testing.T) { cell := "aa" for idx, tcase := range tcases { t.Run("", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ks := fmt.Sprintf("TestVStreamSkew-%d", idx) - _ = createSandbox(ks) - hc := discovery.NewFakeHealthCheck(nil) - st := getSandboxTopo(ctx, cell, ks, []string{"-20", "20-40"}) - vsm := newTestVStreamManager(ctx, hc, st, cell) - vgtid := &binlogdatapb.VGtid{ShardGtids: []*binlogdatapb.ShardGtid{}} - want := int64(0) - var sbc0, sbc1 *sandboxconn.SandboxConn - if tcase.shard0idx != 0 { - sbc0 = hc.AddTestTablet(cell, "1.1.1.1", 1001, ks, "-20", topodatapb.TabletType_PRIMARY, true, 1, nil) - addTabletToSandboxTopo(t, ctx, st, ks, "-20", sbc0.Tablet()) - sbc0.VStreamCh = make(chan *binlogdatapb.VEvent) - want += 2 * tcase.numEventsPerShard - vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{Keyspace: ks, Gtid: "pos", Shard: "-20"}) - go stream(sbc0, ks, "-20", tcase.numEventsPerShard, tcase.shard0idx) - } - if tcase.shard1idx != 0 { - sbc1 = hc.AddTestTablet(cell, "1.1.1.1", 1002, ks, "20-40", topodatapb.TabletType_PRIMARY, true, 1, nil) - addTabletToSandboxTopo(t, ctx, st, ks, "20-40", sbc1.Tablet()) - sbc1.VStreamCh = make(chan *binlogdatapb.VEvent) - want += 2 * tcase.numEventsPerShard - vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{Keyspace: ks, Gtid: "pos", Shard: "20-40"}) - go stream(sbc1, ks, "20-40", tcase.numEventsPerShard, tcase.shard1idx) - } + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ks := fmt.Sprintf("TestVStreamSkew-%d", idx) + _ = createSandbox(ks) + hc := discovery.NewFakeHealthCheck(nil) + st := getSandboxTopo(ctx, cell, ks, []string{"-20", "20-40"}) + vsm := newTestVStreamManager(ctx, hc, st, cell) + vgtid := &binlogdatapb.VGtid{ShardGtids: []*binlogdatapb.ShardGtid{}} + want := int64(0) + var sbc0, sbc1 *sandboxconn.SandboxConn + if tcase.shard0idx != 0 { + sbc0 = hc.AddTestTablet(cell, "1.1.1.1", 1001, ks, "-20", topodatapb.TabletType_PRIMARY, true, 1, nil) + addTabletToSandboxTopo(t, ctx, st, ks, "-20", sbc0.Tablet()) + sbc0.VStreamCh = make(chan *binlogdatapb.VEvent) + want += 2 * tcase.numEventsPerShard + vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{Keyspace: ks, Gtid: "pos", Shard: "-20"}) + go stream(sbc0, ks, "-20", tcase.numEventsPerShard, tcase.shard0idx) + } + if tcase.shard1idx != 0 { + sbc1 = hc.AddTestTablet(cell, "1.1.1.1", 1002, ks, "20-40", topodatapb.TabletType_PRIMARY, true, 1, nil) + addTabletToSandboxTopo(t, ctx, st, ks, "20-40", sbc1.Tablet()) + sbc1.VStreamCh = make(chan *binlogdatapb.VEvent) + want += 2 * tcase.numEventsPerShard + vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{Keyspace: ks, Gtid: "pos", Shard: "20-40"}) + go stream(sbc1, ks, "20-40", tcase.numEventsPerShard, tcase.shard1idx) + } - vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 1*time.Minute) - defer vstreamCancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 1*time.Minute) + defer vstreamCancel() - receivedEvents := make([]*binlogdatapb.VEvent, 0) - err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{MinimizeSkew: true}, func(events []*binlogdatapb.VEvent) error { - receivedEvents = append(receivedEvents, events...) + receivedEvents := make([]*binlogdatapb.VEvent, 0) + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{MinimizeSkew: true}, func(events []*binlogdatapb.VEvent) error { + receivedEvents = append(receivedEvents, events...) - if int64(len(receivedEvents)) == want { - // Stop streaming after receiving both expected responses. - vstreamCancel() - } + if int64(len(receivedEvents)) == want { + // Stop streaming after receiving both expected responses. + vstreamCancel() + } - return nil - }) + return nil + }) - require.Error(t, err) - require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.Canceled) - require.Equal(t, int(want), int(len(receivedEvents))) - require.Equal(t, tcase.expectedDelays, vsm.GetTotalStreamDelay()-previousDelays) - previousDelays = vsm.GetTotalStreamDelay() + require.Equal(t, int(want), int(len(receivedEvents))) + require.Equal(t, tcase.expectedDelays, vsm.GetTotalStreamDelay()-previousDelays) + previousDelays = vsm.GetTotalStreamDelay() + }) }) } } @@ -1710,25 +1713,27 @@ func TestVStreamIdleHeartbeat(t *testing.T) { } for _, tcase := range testcases { t.Run(tcase.name, func(t *testing.T) { - var heartbeatCount int + synctest.Test(t, func(t *testing.T) { + var heartbeatCount int - vstreamCtx, vstreamCancel := context.WithTimeout(ctx, time.Duration(4500)*time.Millisecond) - defer vstreamCancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, time.Duration(4500)*time.Millisecond) + defer vstreamCancel() - err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, func(events []*binlogdatapb.VEvent) error { - for _, event := range events { - if event.Type == binlogdatapb.VEventType_HEARTBEAT { - heartbeatCount++ + err := vsm.VStream(vstreamCtx, topodatapb.TabletType_PRIMARY, vgtid, nil, &vtgatepb.VStreamFlags{HeartbeatInterval: tcase.heartbeatInterval}, func(events []*binlogdatapb.VEvent) error { + for _, event := range events { + if event.Type == binlogdatapb.VEventType_HEARTBEAT { + heartbeatCount++ + } } - } - return nil - }) + return nil + }) - require.Error(t, err) - require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) - require.Equalf(t, heartbeatCount, tcase.want, "got %d, want %d", heartbeatCount, tcase.want) + require.Equalf(t, heartbeatCount, tcase.want, "got %d, want %d", heartbeatCount, tcase.want) + }) }) } } @@ -2111,28 +2116,30 @@ func TestVStreamManagerHealthCheckResponseHandling(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - if tc.wantErr != "" { - source.SetStreamHealthResponse(tc.hcRes) - } + synctest.Test(t, func(t *testing.T) { + if tc.wantErr != "" { + source.SetStreamHealthResponse(tc.hcRes) + } - vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 5*time.Second) - defer vstreamCancel() + vstreamCtx, vstreamCancel := context.WithTimeout(ctx, 5*time.Second) + defer vstreamCancel() - // SandboxConn's VStream implementation always waits for the context to timeout. - err := vsm.VStream(vstreamCtx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { - return fmt.Errorf("unexpected events: %v", events) - }) + // SandboxConn's VStream implementation always waits for the context to timeout. + err := vsm.VStream(vstreamCtx, tabletType, vgtid, nil, nil, func(events []*binlogdatapb.VEvent) error { + return fmt.Errorf("unexpected events: %v", events) + }) - if tc.wantErr != "" { - require.Error(t, err) - require.Contains(t, logger.String(), tc.wantErr) - } else { - // Otherwise we simply expect the context to timeout - require.Error(t, err) - require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) - } + if tc.wantErr != "" { + require.Error(t, err) + require.Contains(t, logger.String(), tc.wantErr) + } else { + // Otherwise we simply expect the context to timeout + require.Error(t, err) + require.ErrorIs(t, vterrors.UnwrapAll(err), context.DeadlineExceeded) + } - logger.Clear() + logger.Clear() + }) }) } } From 39705b4ba7df0e402d5bc8a3e7ccde5701f3318d Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Mon, 29 Sep 2025 17:08:59 +0000 Subject: [PATCH 5/5] Add more synctest usage. Signed-off-by: Arthur Schreiber --- go/test/utils/noleak.go | 2 + go/vt/vtgate/executor_scatter_stats_test.go | 40 ++-- go/vt/vtgate/executor_select_test.go | 83 +++---- go/vt/vtgate/executor_test.go | 25 +- go/vt/vtgate/plugin_mysql_server_test.go | 54 +++-- go/vt/vtgate/tabletgateway_flaky_test.go | 239 ++++++++++---------- go/vt/vtgate/vstream_manager_test.go | 1 + 7 files changed, 232 insertions(+), 212 deletions(-) diff --git a/go/test/utils/noleak.go b/go/test/utils/noleak.go index 41e1a42b960..ea8dc10513d 100644 --- a/go/test/utils/noleak.go +++ b/go/test/utils/noleak.go @@ -72,6 +72,8 @@ func ensureNoLeaks() error { func ensureNoGoroutines() error { var ignored = []goleak.Option{ + goleak.IgnoreTopFunction("internal/synctest.Run"), + goleak.IgnoreTopFunction("testing/synctest.testingSynctestTest"), goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"), goleak.IgnoreTopFunction("github.com/golang/glog.(*loggingT).flushDaemon"), goleak.IgnoreTopFunction("vitess.io/vitess/go/vt/dbconfigs.init.0.func1"), diff --git a/go/vt/vtgate/executor_scatter_stats_test.go b/go/vt/vtgate/executor_scatter_stats_test.go index 487b2ea4df6..115b35aa36c 100644 --- a/go/vt/vtgate/executor_scatter_stats_test.go +++ b/go/vt/vtgate/executor_scatter_stats_test.go @@ -19,7 +19,7 @@ package vtgate import ( "net/http/httptest" "testing" - "time" + "testing/synctest" "github.com/stretchr/testify/require" @@ -53,29 +53,31 @@ func TestScatterStatsWithSingleScatterQuery(t *testing.T) { } func TestScatterStatsHttpWriting(t *testing.T) { - executor, _, _, _, ctx := createExecutorEnv(t) - session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) + synctest.Test(t, func(t *testing.T) { + executor, _, _, _, ctx := createExecutorEnv(t) + session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: "@primary"}) - _, err := executorExecSession(ctx, executor, session, "select * from user", nil) - require.NoError(t, err) + _, err := executorExecSession(ctx, executor, session, "select * from user", nil) + require.NoError(t, err) - _, err = executorExecSession(ctx, executor, session, "select * from user where Id = 15", nil) - require.NoError(t, err) + _, err = executorExecSession(ctx, executor, session, "select * from user where Id = 15", nil) + require.NoError(t, err) - _, err = executorExecSession(ctx, executor, session, "select * from user where Id > 15", nil) - require.NoError(t, err) + _, err = executorExecSession(ctx, executor, session, "select * from user where Id > 15", nil) + require.NoError(t, err) - query4 := "select * from user as u1 join user as u2 on u1.Id = u2.Id" - _, err = executorExecSession(ctx, executor, session, query4, nil) - require.NoError(t, err) + query4 := "select * from user as u1 join user as u2 on u1.Id = u2.Id" + _, err = executorExecSession(ctx, executor, session, query4, nil) + require.NoError(t, err) - time.Sleep(500 * time.Millisecond) + synctest.Wait() - recorder := httptest.NewRecorder() - executor.WriteScatterStats(recorder) + recorder := httptest.NewRecorder() + executor.WriteScatterStats(recorder) - // Here we are checking that the template was executed correctly. - // If it wasn't, instead of html, we'll get an error message - require.Contains(t, recorder.Body.String(), "select * from `user` as u1 join `user` as u2 on u1.Id = u2.Id") - require.NoError(t, err) + // Here we are checking that the template was executed correctly. + // If it wasn't, instead of html, we'll get an error message + require.Contains(t, recorder.Body.String(), "select * from `user` as u1 join `user` as u2 on u1.Id = u2.Id") + require.NoError(t, err) + }) } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 4310e7c98eb..89e599003b6 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -25,6 +25,7 @@ import ( "strconv" "strings" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -3476,46 +3477,50 @@ func TestSelectFromInformationSchema(t *testing.T) { } func TestStreamOrderByWithMultipleResults(t *testing.T) { - ctx := utils.LeakCheckContext(t) - - // Special setup: Don't use createExecutorEnv. - cell := "aa" - hc := discovery.NewFakeHealthCheck(nil) - u := createSandbox(KsTestUnsharded) - s := createSandbox(KsTestSharded) - s.VSchema = executorVSchema - u.VSchema = unshardedVSchema - serv := newSandboxForCells(ctx, []string{cell}) - resolver := newTestResolver(ctx, hc, serv, cell) - shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} - count := 1 - for _, shard := range shards { - sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) - sbc.SetResults([]*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count, count)), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count+10, count)), - }) - count++ - } - queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) - plans := DefaultPlanCache() - executor := NewExecutor(ctx, vtenv.NewTestEnv(), serv, cell, resolver, createExecutorConfigWithNormalizer(), false, plans, nil, querypb.ExecuteOptions_Gen4, NewDynamicViperConfig()) - executor.SetQueryLogger(queryLogger) - defer executor.Close() - // some sleep for all goroutines to start - time.Sleep(100 * time.Millisecond) - before := runtime.NumGoroutine() + synctest.Test(t, func(t *testing.T) { + ctx := utils.LeakCheckContext(t) - query := "select id, col from user order by id" - gotResult, err := executorStream(ctx, executor, query) - require.NoError(t, err) - - wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), - "1|1", "2|2", "3|3", "4|4", "5|5", "6|6", "7|7", "8|8", "11|1", "12|2", "13|3", "14|4", "15|5", "16|6", "17|7", "18|8") - assert.Equal(t, fmt.Sprintf("%v", wantResult.Rows), fmt.Sprintf("%v", gotResult.Rows)) - // some sleep to close all goroutines. - time.Sleep(100 * time.Millisecond) - assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering") + // Special setup: Don't use createExecutorEnv. + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + u := createSandbox(KsTestUnsharded) + s := createSandbox(KsTestSharded) + s.VSchema = executorVSchema + u.VSchema = unshardedVSchema + serv := newSandboxForCells(ctx, []string{cell}) + resolver := newTestResolver(ctx, hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + count := 1 + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc.SetResults([]*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count, count)), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count+10, count)), + }) + count++ + } + queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) + plans := DefaultPlanCache() + executor := NewExecutor(ctx, vtenv.NewTestEnv(), serv, cell, resolver, createExecutorConfigWithNormalizer(), false, plans, nil, querypb.ExecuteOptions_Gen4, NewDynamicViperConfig()) + executor.SetQueryLogger(queryLogger) + defer executor.Close() + + // some sleep for all goroutines to start + synctest.Wait() + before := runtime.NumGoroutine() + + query := "select id, col from user order by id" + gotResult, err := executorStream(ctx, executor, query) + require.NoError(t, err) + + wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), + "1|1", "2|2", "3|3", "4|4", "5|5", "6|6", "7|7", "8|8", "11|1", "12|2", "13|3", "14|4", "15|5", "16|6", "17|7", "18|8") + assert.Equal(t, fmt.Sprintf("%v", wantResult.Rows), fmt.Sprintf("%v", gotResult.Rows)) + + // some sleep to close all goroutines. + synctest.Wait() + assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering") + }) } func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index c47b585ef18..0fe3531fa5e 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -28,6 +28,7 @@ import ( "sort" "strings" "testing" + "testing/synctest" "time" "unsafe" @@ -1579,17 +1580,19 @@ func TestExecutorUnrecognized(t *testing.T) { } func TestExecutorDeniedErrorNoBuffer(t *testing.T) { - executor, sbc1, _, _, ctx := createExecutorEnv(t) - sbc1.EphemeralShardErr = errors.New("enforce denied tables") - - vschemaWaitTimeout = 500 * time.Millisecond - - session := econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: "@primary"}) - startExec := time.Now() - _, err := executorExecSession(ctx, executor, session, "select * from user", nil) - require.NoError(t, err, "enforce denied tables not buffered") - endExec := time.Now() - require.GreaterOrEqual(t, endExec.Sub(startExec).Milliseconds(), int64(500)) + synctest.Test(t, func(t *testing.T) { + executor, sbc1, _, _, ctx := createExecutorEnv(t) + sbc1.EphemeralShardErr = errors.New("enforce denied tables") + + vschemaWaitTimeout = 500 * time.Millisecond + + session := econtext.NewAutocommitSession(&vtgatepb.Session{TargetString: "@primary"}) + startExec := time.Now() + _, err := executorExecSession(ctx, executor, session, "select * from user", nil) + require.NoError(t, err, "enforce denied tables not buffered") + endExec := time.Now() + require.GreaterOrEqual(t, endExec.Sub(startExec).Milliseconds(), int64(500)) + }) } // TestVSchemaStats makes sure the building and displaying of the diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index a77d66d32b0..b6e945dad51 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -26,7 +26,7 @@ import ( "strings" "syscall" "testing" - "time" + "testing/synctest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -299,35 +299,39 @@ func TestInitTLSConfigWithServerCA(t *testing.T) { } func testInitTLSConfig(t *testing.T, serverCA bool) { - // Create the certs. - ctx := utils.LeakCheckContext(t) - - root := t.TempDir() - tlstest.CreateCA(root) - tlstest.CreateCRL(root, tlstest.CA) - tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + synctest.Test(t, func(t *testing.T) { + // Create the certs. + ctx := utils.LeakCheckContext(t) + + root := t.TempDir() + tlstest.CreateCA(root) + tlstest.CreateCRL(root, tlstest.CA) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + + serverCACert := "" + if serverCA { + serverCACert = path.Join(root, "ca-cert.pem") + } - serverCACert := "" - if serverCA { - serverCACert = path.Join(root, "ca-cert.pem") - } + srv := &mysqlServer{tcpListener: &mysql.Listener{}} + if err := initTLSConfig(ctx, srv, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil { + t.Fatalf("init tls config failure due to: +%v", err) + } - srv := &mysqlServer{tcpListener: &mysql.Listener{}} - if err := initTLSConfig(ctx, srv, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil { - t.Fatalf("init tls config failure due to: +%v", err) - } + serverConfig := srv.tcpListener.TLSConfig.Load() + if serverConfig == nil { + t.Fatalf("init tls config shouldn't create nil server config") + } - serverConfig := srv.tcpListener.TLSConfig.Load() - if serverConfig == nil { - t.Fatalf("init tls config shouldn't create nil server config") - } + srv.sigChan <- syscall.SIGHUP - srv.sigChan <- syscall.SIGHUP - time.Sleep(100 * time.Millisecond) // wait for signal handler + // wait for signal handler + synctest.Wait() - if srv.tcpListener.TLSConfig.Load() == serverConfig { - t.Fatalf("init tls config should have been recreated after SIGHUP") - } + if srv.tcpListener.TLSConfig.Load() == serverConfig { + t.Fatalf("init tls config should have been recreated after SIGHUP") + } + }) } // TestKillMethods test the mysql plugin for kill method calls. diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index 124997bea9e..7b6bf332644 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -18,6 +18,7 @@ package vtgate import ( "testing" + "testing/synctest" "time" econtext "vitess.io/vitess/go/vt/vtgate/executorcontext" @@ -138,130 +139,132 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { // TestGatewayBufferingWhileReparenting is used to test that the buffering mechanism buffers the queries when a PRS happens // the healthchecks that happen during a PRS are simulated in this test func TestGatewayBufferingWhileReparenting(t *testing.T) { - ctx := utils.LeakCheckContext(t) - - buffer.SetBufferingModeInTestingEnv(true) - defer func() { - buffer.SetBufferingModeInTestingEnv(false) - }() - - keyspace := "ks1" - shard := "-80" - tabletType := topodatapb.TabletType_PRIMARY - host := "1.1.1.1" - hostReplica := "1.1.1.2" - port := int32(1001) - portReplica := int32(1002) - target := &querypb.Target{ - Keyspace: keyspace, - Shard: shard, - TabletType: tabletType, - } - - ts := &econtext.FakeTopoServer{} - // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel - hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) - // create a new tablet gateway - tg := NewTabletGateway(ctx, hc, ts, "cell") - defer tg.Close(ctx) - - // add a primary tablet which is serving - sbc := hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) - // also add a replica which is serving - sbcReplica := hc.AddTestTablet("cell", hostReplica, portReplica, keyspace, shard, topodatapb.TabletType_REPLICA, true, 0, nil) - - // add a result to the sandbox connection - sqlResult1 := &sqltypes.Result{ - Fields: []*querypb.Field{{ - Name: "col1", - Type: sqltypes.VarChar, - Charset: uint32(collations.MySQL8().DefaultConnectionCharset()), - }}, - RowsAffected: 1, - Rows: [][]sqltypes.Value{{ - sqltypes.MakeTrusted(sqltypes.VarChar, []byte("bb")), - }}, - } - sbc.SetResults([]*sqltypes.Result{sqlResult1}) - - // run a query that we indeed get the result added to the sandbox connection back - // this also checks that the query reaches the primary tablet and not the replica - res, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil) - require.NoError(t, err) - require.Equal(t, res, sqlResult1) - - // get the primary and replica tablet from the fake health check - tablets := hc.GetAllTablets() - var primaryTablet *topodatapb.Tablet - var replicaTablet *topodatapb.Tablet - - for _, tablet := range tablets { - if tablet.Type == topodatapb.TabletType_PRIMARY { - primaryTablet = tablet - } else { - replicaTablet = tablet + synctest.Test(t, func(t *testing.T) { + ctx := utils.LeakCheckContext(t) + + buffer.SetBufferingModeInTestingEnv(true) + defer func() { + buffer.SetBufferingModeInTestingEnv(false) + }() + + keyspace := "ks1" + shard := "-80" + tabletType := topodatapb.TabletType_PRIMARY + host := "1.1.1.1" + hostReplica := "1.1.1.2" + port := int32(1001) + portReplica := int32(1002) + target := &querypb.Target{ + Keyspace: keyspace, + Shard: shard, + TabletType: tabletType, } - } - require.NotNil(t, primaryTablet) - require.NotNil(t, replicaTablet) - // broadcast its state initially - hc.Broadcast(primaryTablet) - // set the serving type for the primary tablet false and broadcast it so that the buffering code registers this change - hc.SetServing(primaryTablet, false) - // We call the broadcast twice to ensure that the change has been processed by the keyspace event watcher. - // The second broadcast call is blocking until the first one has been processed. - hc.Broadcast(primaryTablet) - hc.Broadcast(primaryTablet) - - require.Len(t, tg.hc.GetHealthyTabletStats(target), 0, "GetHealthyTabletStats has tablets even though it shouldn't") - _, shouldStartBuffering := tg.kev.ShouldStartBufferingForTarget(ctx, target) - require.True(t, shouldStartBuffering) - - // add a result to the sandbox connection of the new primary - sbcReplica.SetResults([]*sqltypes.Result{sqlResult1}) - - // execute the query in a go routine since it should be buffered, and check that it eventually succeed - queryChan := make(chan struct{}) - go func() { - res, err = tg.Execute(ctx, target, "query", nil, 0, 0, nil) - queryChan <- struct{}{} - }() + ts := &econtext.FakeTopoServer{} + // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel + hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) + // create a new tablet gateway + tg := NewTabletGateway(ctx, hc, ts, "cell") + defer tg.Close(ctx) + + // add a primary tablet which is serving + sbc := hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) + // also add a replica which is serving + sbcReplica := hc.AddTestTablet("cell", hostReplica, portReplica, keyspace, shard, topodatapb.TabletType_REPLICA, true, 0, nil) + + // add a result to the sandbox connection + sqlResult1 := &sqltypes.Result{ + Fields: []*querypb.Field{{ + Name: "col1", + Type: sqltypes.VarChar, + Charset: uint32(collations.MySQL8().DefaultConnectionCharset()), + }}, + RowsAffected: 1, + Rows: [][]sqltypes.Value{{ + sqltypes.MakeTrusted(sqltypes.VarChar, []byte("bb")), + }}, + } + sbc.SetResults([]*sqltypes.Result{sqlResult1}) - // set the serving type for the new primary tablet true and broadcast it so that the buffering code registers this change - // this should stop the buffering and the query executed in the go routine should work. This should be done with some delay so - // that we know that the query was buffered - time.Sleep(1 * time.Second) - // change the tablets types to simulate a PRS. - hc.SetTabletType(primaryTablet, topodatapb.TabletType_REPLICA) - hc.Broadcast(primaryTablet) - hc.SetTabletType(replicaTablet, topodatapb.TabletType_PRIMARY) - hc.SetPrimaryTimestamp(replicaTablet, 100) // We set a higher timestamp than before to simulate a PRS. - hc.SetServing(replicaTablet, true) - hc.Broadcast(replicaTablet) - - timeout := time.After(1 * time.Minute) -outer: - for { - select { - case <-timeout: - require.Fail(t, "timed out - could not verify the new primary") - case <-time.After(10 * time.Millisecond): - newPrimary, shouldBuffer := tg.kev.ShouldStartBufferingForTarget(ctx, target) - if newPrimary != nil && newPrimary.Uid == replicaTablet.Alias.Uid && !shouldBuffer { - break outer + // run a query that we indeed get the result added to the sandbox connection back + // this also checks that the query reaches the primary tablet and not the replica + res, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil) + require.NoError(t, err) + require.Equal(t, res, sqlResult1) + + // get the primary and replica tablet from the fake health check + tablets := hc.GetAllTablets() + var primaryTablet *topodatapb.Tablet + var replicaTablet *topodatapb.Tablet + + for _, tablet := range tablets { + if tablet.Type == topodatapb.TabletType_PRIMARY { + primaryTablet = tablet + } else { + replicaTablet = tablet + } + } + require.NotNil(t, primaryTablet) + require.NotNil(t, replicaTablet) + + // broadcast its state initially + hc.Broadcast(primaryTablet) + // set the serving type for the primary tablet false and broadcast it so that the buffering code registers this change + hc.SetServing(primaryTablet, false) + // We call the broadcast twice to ensure that the change has been processed by the keyspace event watcher. + // The second broadcast call is blocking until the first one has been processed. + hc.Broadcast(primaryTablet) + hc.Broadcast(primaryTablet) + + require.Len(t, tg.hc.GetHealthyTabletStats(target), 0, "GetHealthyTabletStats has tablets even though it shouldn't") + _, shouldStartBuffering := tg.kev.ShouldStartBufferingForTarget(ctx, target) + require.True(t, shouldStartBuffering) + + // add a result to the sandbox connection of the new primary + sbcReplica.SetResults([]*sqltypes.Result{sqlResult1}) + + // execute the query in a go routine since it should be buffered, and check that it eventually succeed + queryChan := make(chan struct{}) + go func() { + res, err = tg.Execute(ctx, target, "query", nil, 0, 0, nil) + queryChan <- struct{}{} + }() + + // set the serving type for the new primary tablet true and broadcast it so that the buffering code registers this change + // this should stop the buffering and the query executed in the go routine should work. This should be done with some delay so + // that we know that the query was buffered + time.Sleep(1 * time.Second) + // change the tablets types to simulate a PRS. + hc.SetTabletType(primaryTablet, topodatapb.TabletType_REPLICA) + hc.Broadcast(primaryTablet) + hc.SetTabletType(replicaTablet, topodatapb.TabletType_PRIMARY) + hc.SetPrimaryTimestamp(replicaTablet, 100) // We set a higher timestamp than before to simulate a PRS. + hc.SetServing(replicaTablet, true) + hc.Broadcast(replicaTablet) + + timeout := time.After(1 * time.Minute) + outer: + for { + select { + case <-timeout: + require.Fail(t, "timed out - could not verify the new primary") + case <-time.After(10 * time.Millisecond): + newPrimary, shouldBuffer := tg.kev.ShouldStartBufferingForTarget(ctx, target) + if newPrimary != nil && newPrimary.Uid == replicaTablet.Alias.Uid && !shouldBuffer { + break outer + } } } - } - // wait for the query to execute before checking for results - select { - case <-queryChan: - require.NoError(t, err) - require.Equal(t, sqlResult1, res) - case <-time.After(15 * time.Second): - t.Fatalf("timed out waiting for query to execute") - } + // wait for the query to execute before checking for results + select { + case <-queryChan: + require.NoError(t, err) + require.Equal(t, sqlResult1, res) + case <-time.After(15 * time.Second): + t.Fatalf("timed out waiting for query to execute") + } + }) } // TestInconsistentStateDetectedBuffering simulates the case where we have used up all our buffering retries and in the diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 6b3b4d834bd..3c6d24b55fb 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -58,6 +58,7 @@ func TestVStreamSkew(t *testing.T) { time.Sleep(time.Duration(idx*100) * time.Millisecond) } } + type skewTestCase struct { numEventsPerShard int64 shard0idx, shard1idx int64