diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index 6d5f1593447..bf76b42092b 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -319,6 +319,18 @@ func TestInformationSchemaQuery(t *testing.T) { assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString()) } +func TestOffsetAndLimitWithOLAP(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + exec(t, conn, "insert into t1(id1, id2) values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)") + assertMatches(t, conn, "select id1 from t1 order by id1 limit 3 offset 2", "[[INT64(3)] [INT64(4)] [INT64(5)]]") + exec(t, conn, "set workload='olap'") + assertMatches(t, conn, "select id1 from t1 order by id1 limit 3 offset 2", "[[INT64(3)] [INT64(4)] [INT64(5)]]") +} + func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { t.Helper() qr := exec(t, conn, query) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index c43b6a4b28f..e712c182dc5 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -94,11 +94,14 @@ func (l *Limit) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.Bind if err != nil { return err } - if !l.Offset.IsNull() { - return fmt.Errorf("offset not supported for stream execute queries") + offset, err := l.fetchOffset(bindVars) + if err != nil { + return err } - bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count)) + // When offset is present, we hijack the limit value so we can calculate + // the offset in memory from the result of the scatter query with count + offset. + bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset)) err = l.Input.StreamExecute(vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { @@ -106,19 +109,31 @@ func (l *Limit) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.Bind return err } } - if len(qr.Rows) == 0 { + inputSize := len(qr.Rows) + if inputSize == 0 { return nil } + // we've still not seen all rows we need to see before we can return anything to the client + if offset > 0 { + if inputSize <= offset { + // not enough to return anything yet + offset -= inputSize + return nil + } + qr.Rows = qr.Rows[offset:] + offset = 0 + } + if count == 0 { - // Unreachable: this is just a failsafe. return io.EOF } // reduce count till 0. result := &sqltypes.Result{Rows: qr.Rows} - if count > len(result.Rows) { - count -= len(result.Rows) + resultSize := len(result.Rows) + if count > resultSize { + count -= resultSize return callback(result) } result.Rows = result.Rows[:count] @@ -140,7 +155,7 @@ func (l *Limit) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.Bind return nil } -// GetFields satisfies the Primtive interface. +// GetFields implements the Primitive interface. func (l *Limit) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { return l.Input.GetFields(vcursor, bindVars) } @@ -150,11 +165,16 @@ func (l *Limit) Inputs() []Primitive { return []Primitive{l.Input} } +//NeedsTransaction implements the Primitive interface. func (l *Limit) NeedsTransaction() bool { return l.Input.NeedsTransaction() } func (l *Limit) fetchCount(bindVars map[string]*querypb.BindVariable) (int, error) { + if l.Count.IsNull() { + return 0, nil + } + resolved, err := l.Count.ResolveValue(bindVars) if err != nil { return 0, err diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index e0dd4ddadaf..d7eb5077e71 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -432,6 +432,49 @@ func TestLimitStreamExecute(t *testing.T) { } } +func TestOffsetStreamExecute(t *testing.T) { + bindVars := make(map[string]*querypb.BindVariable) + fields := sqltypes.MakeTestFields( + "col1|col2", + "int64|varchar", + ) + inputResult := sqltypes.MakeTestResult( + fields, + "a|1", + "b|2", + "c|3", + "d|4", + "e|5", + "f|6", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{inputResult}, + } + + l := &Limit{ + Offset: int64PlanValue(2), + Count: int64PlanValue(3), + Input: fp, + } + + var results []*sqltypes.Result + err := l.StreamExecute(nil, bindVars, false, func(qr *sqltypes.Result) error { + results = append(results, qr) + return nil + }) + require.NoError(t, err) + wantResults := sqltypes.MakeTestStreamingResults( + fields, + "c|3", + "d|4", + "---", + "e|5", + ) + if !reflect.DeepEqual(results, wantResults) { + t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(wantResults)) + } +} + func TestLimitGetFields(t *testing.T) { result := sqltypes.MakeTestResult( sqltypes.MakeTestFields( diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index cbdb1362754..a0ffb4533ed 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -221,6 +221,78 @@ func TestStreamBuffering(t *testing.T) { } } +func TestStreamLimitOffset(t *testing.T) { + executor, sbc1, sbc2, _ := createLegacyExecutorEnv() + + // This test is similar to TestStreamUnsharded except that it returns a Result > 10 bytes, + // such that the splitting of the Result into multiple Result responses gets tested. + sbc1.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "id", Type: sqltypes.Int32}, + {Name: "textcol", Type: sqltypes.VarChar}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt32(1), + sqltypes.NewVarChar("1234"), + }, { + sqltypes.NewInt32(4), + sqltypes.NewVarChar("4567"), + }}, + }}) + + sbc2.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "id", Type: sqltypes.Int32}, + {Name: "textcol", Type: sqltypes.VarChar}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt32(2), + sqltypes.NewVarChar("2345"), + }}, + }}) + + results := make(chan *sqltypes.Result, 10) + err := executor.StreamExecute( + context.Background(), + "TestStreamLimitOffset", + NewSafeSession(masterSession), + "select id, textcol from user order by id limit 2 offset 2", + nil, + querypb.Target{ + TabletType: topodatapb.TabletType_MASTER, + }, + func(qr *sqltypes.Result) error { + results <- qr + return nil + }, + ) + close(results) + require.NoError(t, err) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "id", Type: sqltypes.Int32}, + {Name: "textcol", Type: sqltypes.VarChar}, + }, + + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt32(1), + sqltypes.NewVarChar("1234"), + }, { + sqltypes.NewInt32(1), + sqltypes.NewVarChar("foo"), + }}, + } + var gotResults []*sqltypes.Result + for r := range results { + gotResults = append(gotResults, r) + } + res := gotResults[0] + for i := 1; i < len(gotResults); i++ { + res.Rows = append(res.Rows, gotResults[i].Rows...) + } + utils.MustMatch(t, wantResult, res, "") +} + func TestSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 52 executor, _, _, _ := createLegacyExecutorEnv() diff --git a/go/vt/vtgate/planbuilder/limit.go b/go/vt/vtgate/planbuilder/limit.go index 74fdc754188..796850b3b18 100644 --- a/go/vt/vtgate/planbuilder/limit.go +++ b/go/vt/vtgate/planbuilder/limit.go @@ -18,11 +18,9 @@ package planbuilder import ( "errors" - "fmt" - - "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -93,17 +91,12 @@ func (l *limit) SetLimit(limit *sqlparser.Limit) error { } l.elimit.Count = pv - switch offset := limit.Offset.(type) { - case *sqlparser.Literal: - pv, err = sqlparser.NewPlanValue(offset) + if limit.Offset != nil { + pv, err = sqlparser.NewPlanValue(limit.Offset) if err != nil { - return err + return vterrors.Wrap(err, "unexpected expression in OFFSET") } l.elimit.Offset = pv - case nil: - // NOOP - default: - return fmt.Errorf("unexpected expression in LIMIT: %v", sqlparser.String(limit)) } l.input.SetUpperLimit(sqlparser.NewArgument([]byte(":__upper_limit")))