diff --git a/go/vt/vttablet/tabletserver/messager/message_manager.go b/go/vt/vttablet/tabletserver/messager/message_manager.go index 573be17bb9b..1c4af748ee3 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager.go @@ -356,8 +356,10 @@ func (mm *messageManager) send(receiver *receiverWithStatus, qr *sqltypes.Result tabletenv.LogError() mm.wg.Done() }() + receiverClosed := false if err := receiver.receiver.Send(qr); err != nil { if err == io.EOF { + receiverClosed = true // No need to call Cancel. messageReceiver already // does that before returning this error. mm.unsubscribe(receiver.receiver) @@ -381,10 +383,18 @@ func (mm *messageManager) send(receiver *receiverWithStatus, qr *sqltypes.Result for i, row := range qr.Rows { ids[i] = row[0].ToString() } - // postpone should discard, but this is a safety measure - // in case it fails. + mm.cache.Discard(ids) - go postpone(mm.tsv, mm.name.String(), mm.ackWaitTime, ids) + if receiverClosed { + // If the receiver ended the stream, we want the messages + // to be resent ASAP without postponement. Setting messagesPending + // will trigger the poller as soon as the cache is clear. + mm.mu.Lock() + mm.messagesPending = true + mm.mu.Unlock() + } else { + go postpone(mm.tsv, mm.name.String(), mm.ackWaitTime, ids) + } } // postpone is a non-member because it should be called asynchronously and should diff --git a/go/vt/vttablet/tabletserver/messager/message_manager_test.go b/go/vt/vttablet/tabletserver/messager/message_manager_test.go index 43a088d4e42..1ad39b9a4e0 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager_test.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager_test.go @@ -213,8 +213,8 @@ func TestMessageManagerSend(t *testing.T) { } // Ensure Postpone got called. 
- if got := <-ch; got != mmTable.Name.String() { - t.Errorf("Postpone: %s, want %v", got, mmTable.Name) + if got, want := <-ch, "postpone"; got != want { + t.Errorf("Postpone: %s, want %v", got, want) } // Verify item has been removed from cache. @@ -244,7 +244,7 @@ func TestMessageManagerSend(t *testing.T) { continue } mm.mu.Unlock() - return + break } mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("4")}}) @@ -256,6 +256,43 @@ func TestMessageManagerSend(t *testing.T) { <-r1.ch } +func TestMessageManagerSendEOF(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + tsv := newFakeTabletServer() + mm := newMessageManager(tsv, mmTable, newMMConnPool(db)) + mm.Open() + defer mm.Close() + r1 := newTestReceiver(0) + ctx, cancel := context.WithCancel(context.Background()) + mm.Subscribe(ctx, r1.rcv) + // Pull field info. + <-r1.ch + + mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("1"), sqltypes.NULL}}) + // Wait for send to enqueue + r1.WaitForCount(2) + + // Now cancel, which will send an EOF to the sender. + cancel() + // Poll (up to ~100ms) until the EOF has been handled and messagesPending is set + messagesWerePending := false + for i := 0; i < 10; i++ { + runtime.Gosched() + mm.mu.Lock() + if mm.messagesPending { + messagesWerePending = true + mm.mu.Unlock() + break + } + mm.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + if !messagesWerePending { + t.Error("Send with EOF did not trigger pending messages") + } +} + func TestMessageManagerBatchSend(t *testing.T) { db := fakesqldb.New(t) defer db.Close() @@ -509,8 +546,8 @@ func TestMessageManagerPurge(t *testing.T) { mm.Open() defer mm.Close() // Ensure Purge got called. 
- if got := <-ch; got != mmTable.Name.String() { - t.Errorf("Postpone: %s, want %v", got, mmTable.Name) + if got, want := <-ch, "purge"; got != want { + t.Errorf("Purge: %s, want %v", got, want) } } @@ -583,14 +620,14 @@ func (fts *fakeTabletServer) SetChannel(ch chan string) { func (fts *fakeTabletServer) PostponeMessages(ctx context.Context, target *querypb.Target, name string, ids []string) (count int64, err error) { if fts.ch != nil { - fts.ch <- name + fts.ch <- "postpone" } return 0, nil } func (fts *fakeTabletServer) PurgeMessages(ctx context.Context, target *querypb.Target, name string, timeCutoff int64) (count int64, err error) { if fts.ch != nil { - fts.ch <- name + fts.ch <- "purge" } return 0, nil }