diff --git a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go index 7f63a90650d..2e279285698 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go @@ -392,11 +392,13 @@ func (rs *rowStreamer) streamQuery(send func(*binlogdatapb.VStreamRowsResponse) heartbeatTicker := time.NewTicker(rowStreamertHeartbeatInterval) defer heartbeatTicker.Stop() go func() { - select { - case <-rs.ctx.Done(): - return - case <-heartbeatTicker.C: - safeSend(&binlogdatapb.VStreamRowsResponse{Heartbeat: true}) + for { + select { + case <-rs.ctx.Done(): + return + case <-heartbeatTicker.C: + safeSend(&binlogdatapb.VStreamRowsResponse{Heartbeat: true}) + } } }() diff --git a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go index 371b3d814f7..935bad6d3c5 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go @@ -22,6 +22,7 @@ import ( "regexp" "strconv" "testing" + "time" "github.com/stretchr/testify/require" @@ -542,6 +543,78 @@ func TestStreamRowsCancel(t *testing.T) { } } +func TestStreamRowsHeartbeat(t *testing.T) { + if testing.Short() { + t.Skip() + } + + // Save original heartbeat interval and restore it after test + originalInterval := rowStreamertHeartbeatInterval + defer func() { + rowStreamertHeartbeatInterval = originalInterval + }() + + // Set a very short heartbeat interval for testing (100ms) + rowStreamertHeartbeatInterval = 10 * time.Millisecond + + execStatements(t, []string{ + "create table t1(id int, val varchar(128), primary key(id))", + "insert into t1 values (1, 'test1')", + "insert into t1 values (2, 'test2')", + "insert into t1 values (3, 'test3')", + "insert into t1 values (4, 'test4')", + "insert into t1 values (5, 'test5')", + }) + + defer execStatements(t, []string{ + "drop table t1", + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + heartbeatCount := 0 + dataReceived := false + + var options binlogdatapb.VStreamOptions + options.ConfigOverrides = make(map[string]string) + options.ConfigOverrides["vstream_dynamic_packet_size"] = "false" + options.ConfigOverrides["vstream_packet_size"] = "10" + + err := engine.StreamRows(ctx, "select * from t1", nil, func(rows *binlogdatapb.VStreamRowsResponse) error { + if rows.Heartbeat { + heartbeatCount++ + // After receiving at least 3 heartbeats, we can be confident the fix is working + if heartbeatCount >= 3 { + cancel() + return nil + } + } else if len(rows.Rows) > 0 { + dataReceived = true + } + // Add a small delay to allow heartbeats to be sent + time.Sleep(50 * time.Millisecond) + return nil + }, &options) + + // We expect context canceled error since we cancel after receiving heartbeats + if err != nil && err.Error() != "stream ended: context canceled" { + t.Errorf("unexpected error: %v", err) + } + + // Verify we received data + if !dataReceived { + t.Error("expected to receive data rows") + } + + // This is the critical test: we should receive multiple heartbeats + // Without the fix (missing for loop), we would only get 1 heartbeat + // With the fix, we should get at least 3 heartbeats + if heartbeatCount < 3 { + t.Errorf("expected at least 3 heartbeats, got %d. This indicates the heartbeat goroutine is not running continuously", heartbeatCount) + } +} + func checkStream(t *testing.T, query string, lastpk []sqltypes.Value, wantQuery string, wantStream []string) { t.Helper()