diff --git a/go/vt/vttablet/tabletserver/vstreamer/uvstreamer_test.go b/go/vt/vttablet/tabletserver/vstreamer/uvstreamer_test.go index 0ba97a68153..36a3813c0e1 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/uvstreamer_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/uvstreamer_test.go @@ -47,6 +47,7 @@ import ( "fmt" "regexp" "strings" + "sync" "testing" "time" @@ -75,7 +76,7 @@ var testState = &state{} var positions map[string]string var allEvents []*binlogdatapb.VEvent - +var muAllEvents sync.Mutex var callbacks map[string]func() func TestVStreamCopyFilterValidations(t *testing.T) { @@ -114,12 +115,12 @@ func TestVStreamCopyFilterValidations(t *testing.T) { } return uvs } - var testFilter = func(rules []*binlogdatapb.Rule, tablePKs []*binlogdatapb.TableLastPK, expected []string, err string) { + var testFilter = func(rules []*binlogdatapb.Rule, tablePKs []*binlogdatapb.TableLastPK, expected []string, expectedError string) { uvs := getUVStreamer(&binlogdatapb.Filter{Rules: rules}, tablePKs) - if err == "" { + if expectedError == "" { require.NoError(t, uvs.init()) } else { - require.Error(t, uvs.init(), err) + require.Error(t, uvs.init(), expectedError) return } require.Equal(t, len(expected), len(uvs.plans)) @@ -135,10 +136,10 @@ func TestVStreamCopyFilterValidations(t *testing.T) { } type TestCase struct { - rules []*binlogdatapb.Rule - tablePKs []*binlogdatapb.TableLastPK - expected []string - err string + rules []*binlogdatapb.Rule + tablePKs []*binlogdatapb.TableLastPK + expected []string + expectedError string } var testCases []*TestCase @@ -159,7 +160,7 @@ func TestVStreamCopyFilterValidations(t *testing.T) { for _, tc := range testCases { log.Infof("Running %v", tc.rules) - testFilter(tc.rules, tc.tablePKs, tc.expected, tc.err) + testFilter(tc.rules, tc.tablePKs, tc.expected, tc.expectedError) } } @@ -196,9 +197,7 @@ func TestVStreamCopyCompleteFlow(t *testing.T) { // Test event called after t1 copy is complete callbacks["OTHER.*Copy Start t2"] = func() { conn, err := env.Mysqld.GetDbaConnection(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer conn.Close() log.Info("Inserting row for fast forward to find, locking t2") @@ -212,9 +211,7 @@ func TestVStreamCopyCompleteFlow(t *testing.T) { callbacks["OTHER.*Copy Start t3"] = func() { conn, err := env.Mysqld.GetDbaConnection(ctx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer conn.Close() log.Info("Inserting row for fast forward to find, locking t3") @@ -268,7 +265,8 @@ func TestVStreamCopyCompleteFlow(t *testing.T) { case <-ctx.Done(): log.Infof("Received context.Done, ending test") } - + muAllEvents.Lock() + defer muAllEvents.Unlock() printAllEvents("End of test") if len(allEvents) != numExpectedEvents { log.Errorf("Received %d events, expected %d", len(allEvents), numExpectedEvents) @@ -280,7 +278,6 @@ func TestVStreamCopyCompleteFlow(t *testing.T) { log.Infof("Successfully received %d events", numExpectedEvents) } validateReceivedEvents(t) - } func validateReceivedEvents(t *testing.T) { @@ -389,6 +386,8 @@ func startVStreamCopy(ctx context.Context, t *testing.T, filter *binlogdatapb.Fi go func() { err := engine.Stream(ctx, pos, tablePKs, filter, func(evs []*binlogdatapb.VEvent) error { //t.Logf("Received events: %v", evs) + muAllEvents.Lock() + defer muAllEvents.Unlock() for _, ev := range evs { if ev.Type == binlogdatapb.VEventType_HEARTBEAT { continue