diff --git a/go/vt/tabletserver/dbconn.go b/go/vt/tabletserver/dbconn.go index 7f14f1e6a7c..416b4d78e57 100644 --- a/go/vt/tabletserver/dbconn.go +++ b/go/vt/tabletserver/dbconn.go @@ -5,7 +5,6 @@ package tabletserver import ( - "errors" "fmt" "time" @@ -78,7 +77,7 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel return nil, NewTabletErrorSQL(ErrFatal, vtrpcpb.ErrorCode_INTERNAL_ERROR, err) } } - return nil, NewTabletErrorSQL(ErrFatal, vtrpcpb.ErrorCode_INTERNAL_ERROR, errors.New("dbconn.Exec: unreachable code")) + panic("unreachable") } func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { @@ -101,6 +100,38 @@ func (dbc *DBConn) ExecOnce(ctx context.Context, query string, maxrows int, want // Stream executes the query and streams the results. func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqltypes.Result) error, streamBufferSize int) error { + span := trace.NewSpanFromContext(ctx) + span.StartClient("DBConn.Stream") + defer span.Finish() + + for attempt := 1; attempt <= 2; attempt++ { + resultSent := false + err := dbc.streamOnce( + ctx, + query, + func(r *sqltypes.Result) error { + resultSent = true + return callback(r) + }, + streamBufferSize, + ) + switch { + case err == nil: + return nil + case !IsConnErr(err) || resultSent || attempt == 2: + // MySQL error that isn't due to a connection issue + return err + } + err2 := dbc.reconnect() + if err2 != nil { + dbc.pool.checker.CheckMySQL() + return err + } + } + panic("unreachable") +} + +func (dbc *DBConn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, streamBufferSize int) error { dbc.current.Set(query) defer dbc.current.Set("") diff --git a/go/vt/tabletserver/dbconn_test.go b/go/vt/tabletserver/dbconn_test.go index 1744247dd3a..604397cf4be 100644 --- a/go/vt/tabletserver/dbconn_test.go +++ b/go/vt/tabletserver/dbconn_test.go @@ -6,6 +6,7 @@ package tabletserver import ( "fmt" + "strings" "testing" "time" @@ -120,4 +121,15 @@ func TestDBConnStream(t *testing.T) { t.Fatalf("should not get an error, err: %v", err) } testUtils.checkEqual(t, expectedResult, &result) + // Stream fail + db.EnableConnFail() + err = dbConn.Stream( + ctx, sql, func(r *sqltypes.Result) error { + return nil + }, 10) + db.DisableConnFail() + want := "connection fail" + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("Error: %v, must contain %s\n", err, want) + } }