diff --git a/go/test/endtoend/vtgate/found_rows_test.go b/go/test/endtoend/vtgate/found_rows_test.go new file mode 100644 index 00000000000..0f7143a3b28 --- /dev/null +++ b/go/test/endtoend/vtgate/found_rows_test.go @@ -0,0 +1,66 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vtgate + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" +) + +func TestFoundRows(t *testing.T) { + defer cluster.PanicHandler(t) + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.Nil(t, err) + defer conn.Close() + + exec(t, conn, "insert into t2(id3,id4) values(1,2), (2,2), (3,3), (4,3), (5,3)") + + runTests := func(workload string) { + assertFoundRowsValue(t, conn, "select * from t2", workload, 5) + assertFoundRowsValue(t, conn, "select * from t2 limit 2", workload, 2) + assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 limit 2", workload, 5) + assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id3 = 4 limit 2", workload, 1) + assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS * from t2 where id4 = 3 limit 2", workload, 3) + assertFoundRowsValue(t, conn, "select SQL_CALC_FOUND_ROWS id4, count(id3) from t2 where id3 = 3 group by id4 limit 1", workload, 1) + } + + runTests("oltp") + exec(t, conn, "set workload = olap") + runTests("olap") + + // cleanup test data + exec(t, conn, "set workload = oltp") + exec(t, conn, "delete from t2") + exec(t, conn, "delete from t2_id4_idx") +} + +func assertFoundRowsValue(t *testing.T, conn *mysql.Conn, query, workload string, count int) { + exec(t, conn, query) + qr := exec(t, conn, "select found_rows()") + got := fmt.Sprintf("%v", qr.Rows) + want := fmt.Sprintf(`[[UINT64(%d)]]`, count) + assert.Equalf(t, want, got, "Workload: %s\nQuery:%s\n", workload, query) +} diff --git a/go/test/endtoend/vtgate/lookup_test.go b/go/test/endtoend/vtgate/lookup_test.go index 29829410d38..636be7f2d0a 100644 --- a/go/test/endtoend/vtgate/lookup_test.go +++ b/go/test/endtoend/vtgate/lookup_test.go @@ -410,14 +410,8 @@ func TestHashLookupMultiInsertIgnore(t *testing.T) { defer conn2.Close() // DB should start out clean - qr := exec(t, conn, "select count(*) from t2_id4_idx") - if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want { - t.Errorf("select:\n%v want\n%v", got, want) - } - qr = exec(t, conn, "select count(*) from t2") - if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(0)]]"; got != want { - t.Errorf("select:\n%v want\n%v", got, want) - } + assertMatches(t, conn, "select count(*) from t2_id4_idx", "[[INT64(0)]]") + assertMatches(t, conn, "select count(*) from t2", "[[INT64(0)]]") // Try inserting a bunch of ids at once exec(t, conn, "begin") @@ -425,14 +419,8 @@ func TestHashLookupMultiInsertIgnore(t *testing.T) { exec(t, conn, "commit") // Verify - qr = exec(t, conn, "select id3, id4 from t2 order by id3") - if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want { - t.Errorf("select:\n%v want\n%v", got, want) - } - qr = exec(t, conn, "select id3, id4 from t2_id4_idx order by id3") - if got, want := fmt.Sprintf("%v", qr.Rows), "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]"; got != want { - t.Errorf("select:\n%v want\n%v", got, want) - } + assertMatches(t, conn, "select id3, id4 from t2 order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]") + assertMatches(t, conn, "select id3, id4 from t2_id4_idx order by id3", "[[INT64(10) INT64(20)] [INT64(30) INT64(40)] [INT64(50) INT64(60)]]") } func TestConsistentLookupUpdate(t *testing.T) { diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 04448731c4c..f931a40b5c2 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -53,6 +53,10 @@ type noopVCursor struct { ctx context.Context } +func (t noopVCursor) SetFoundRows(u uint64) { + panic("implement me") +} + func (t noopVCursor) InTransactionAndIsDML() bool { panic("implement me") } @@ -199,6 +203,10 @@ type loggingVCursor struct { resolvedTargetTabletType topodatapb.TabletType } +func (f *loggingVCursor) SetFoundRows(u uint64) { + panic("implement me") +} + func (f *loggingVCursor) InTransactionAndIsDML() bool { return false } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 4e2e40e7ed7..f5bdf6a6fe5 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -116,6 +116,7 @@ type ( SetSQLSelectLimit(int64) SetTransactionMode(vtgatepb.TransactionMode) SetWorkload(querypb.ExecuteOptions_Workload) + SetFoundRows(uint64) } // Plan represents the execution strategy for a given query. diff --git a/go/vt/vtgate/engine/sql_calc_found_rows.go b/go/vt/vtgate/engine/sql_calc_found_rows.go new file mode 100644 index 00000000000..65c3f7574dc --- /dev/null +++ b/go/vt/vtgate/engine/sql_calc_found_rows.go @@ -0,0 +1,124 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +var _ Primitive = (*SQLCalcFoundRows)(nil) + +//SQLCalcFoundRows is a primitive to execute limit and count query as per their individual plan. +type SQLCalcFoundRows struct { + LimitPrimitive Primitive + CountPrimitive Primitive +} + +//RouteType implements the Primitive interface +func (s SQLCalcFoundRows) RouteType() string { + return "SQLCalcFoundRows" +} + +//GetKeyspaceName implements the Primitive interface +func (s SQLCalcFoundRows) GetKeyspaceName() string { + return s.LimitPrimitive.GetKeyspaceName() +} + +//GetTableName implements the Primitive interface +func (s SQLCalcFoundRows) GetTableName() string { + return s.LimitPrimitive.GetTableName() +} + +//Execute implements the Primitive interface +func (s SQLCalcFoundRows) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + limitQr, err := s.LimitPrimitive.Execute(vcursor, bindVars, wantfields) + if err != nil { + return nil, err + } + countQr, err := s.CountPrimitive.Execute(vcursor, bindVars, false) + if err != nil { + return nil, err + } + if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar") + } + fr, err := evalengine.ToUint64(countQr.Rows[0][0]) + if err != nil { + return nil, err + } + vcursor.Session().SetFoundRows(fr) + return limitQr, nil +} + +//StreamExecute implements the Primitive interface +func (s SQLCalcFoundRows) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + err := s.LimitPrimitive.StreamExecute(vcursor, bindVars, wantfields, callback) + if err != nil { + return err + } + + var fr *uint64 + + err = s.CountPrimitive.StreamExecute(vcursor, bindVars, wantfields, func(countQr *sqltypes.Result) error { + if len(countQr.Rows) == 0 && countQr.Fields != nil { + // this is the fields, which we can ignore + return nil + } + if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar") + } + toUint64, err := evalengine.ToUint64(countQr.Rows[0][0]) + if err != nil { + return err + } + fr = &toUint64 + return nil + }) + if err != nil { + return err + } + if fr == nil { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query for SQL_CALC_FOUND_ROWS never returned a value") + } + vcursor.Session().SetFoundRows(*fr) + return nil +} + +//GetFields implements the Primitive interface +func (s SQLCalcFoundRows) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return s.LimitPrimitive.GetFields(vcursor, bindVars) +} + +//NeedsTransaction implements the Primitive interface +func (s SQLCalcFoundRows) NeedsTransaction() bool { + return s.LimitPrimitive.NeedsTransaction() +} + +//Inputs implements the Primitive interface +func (s SQLCalcFoundRows) Inputs() []Primitive { + return []Primitive{s.LimitPrimitive, s.CountPrimitive} +} + +func (s SQLCalcFoundRows) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "SQL_CALC_FOUND_ROWS", + } +} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index ab2ef1e757a..b5e81ec4b93 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -163,7 +163,9 @@ func saveSessionStats(safeSession *SafeSession, stmtType sqlparser.StatementType if err != nil { return } - safeSession.FoundRows = result.RowsAffected + if !safeSession.foundRowsHandled { + safeSession.FoundRows = result.RowsAffected + } if result.InsertID > 0 { safeSession.LastInsertId = result.InsertID } @@ -997,6 +999,7 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession result := &sqltypes.Result{} byteCount := 0 seenResults := false + var foundRows uint64 err = plan.Instructions.StreamExecute(vcursor, bindVars, true, func(qr *sqltypes.Result) error { // If the row has field info, send it separately. // TODO(sougou): this behavior is for handling tests because @@ -1009,8 +1012,10 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession seenResults = true } + foundRows += uint64(len(qr.Rows)) for _, row := range qr.Rows { result.Rows = append(result.Rows, row) + for _, col := range row { byteCount += col.Len() } @@ -1038,6 +1043,12 @@ func (e *Executor) StreamExecute(ctx context.Context, method string, safeSession logStats.ExecuteTime = time.Since(execStart) e.updateQueryCounts(plan.Instructions.RouteType(), plan.Instructions.GetKeyspaceName(), plan.Instructions.GetTableName(), int64(logStats.ShardQueries)) + // save session stats for future queries + if !safeSession.foundRowsHandled { + safeSession.FoundRows = foundRows + } + safeSession.RowCount = -1 + return err } diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 5f5dabc26da..c1f12e64dd6 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -305,7 +305,7 @@ func buildRoutePlan(stmt sqlparser.Statement, vschema ContextVSchema, f func(sta func createInstructionFor(query string, stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { switch stmt := stmt.(type) { case *sqlparser.Select: - return buildRoutePlan(stmt, vschema, buildSelectPlan) + return buildRoutePlan(stmt, vschema, buildSelectPlan(query)) case *sqlparser.Insert: return buildRoutePlan(stmt, vschema, buildInsertPlan) case *sqlparser.Update: diff --git a/go/vt/vtgate/planbuilder/expr.go b/go/vt/vtgate/planbuilder/expr.go index c289b073f29..8c86472fc21 100644 --- a/go/vt/vtgate/planbuilder/expr.go +++ b/go/vt/vtgate/planbuilder/expr.go @@ -104,7 +104,7 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := node.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, pb.st); err != nil { + if err := spb.processSelect(stmt, pb.st, ""); err != nil { return false, err } case *sqlparser.Union: @@ -230,7 +230,7 @@ func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(nodes ...sqlparser.SQ return true, nil } spb := newPrimitiveBuilder(pb.vschema, pb.jt) - if err := spb.processSelect(nodeType, pb.st); err != nil { + if err := spb.processSelect(nodeType, pb.st, ""); err != nil { samePlan = false return false, err } diff --git a/go/vt/vtgate/planbuilder/from.go b/go/vt/vtgate/planbuilder/from.go index 467e83ac752..49699be6786 100644 --- a/go/vt/vtgate/planbuilder/from.go +++ b/go/vt/vtgate/planbuilder/from.go @@ -96,7 +96,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := expr.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, nil); err != nil { + if err := spb.processSelect(stmt, nil, ""); err != nil { return err } case *sqlparser.Union: diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index f4295350152..90ff1c57068 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" @@ -33,26 +35,27 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" ) -// buildSelectPlan is the new function to build a Select plan. -func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { - sel := stmt.(*sqlparser.Select) +func buildSelectPlan(query string) func(sqlparser.Statement, ContextVSchema) (engine.Primitive, error) { + return func(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { + sel := stmt.(*sqlparser.Select) - p, err := handleDualSelects(sel, vschema) - if err != nil { - return nil, err - } - if p != nil { - return p, nil - } + p, err := handleDualSelects(sel, vschema) + if err != nil { + return nil, err + } + if p != nil { + return p, nil + } - pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) - if err := pb.processSelect(sel, nil); err != nil { - return nil, err - } - if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { - return nil, err + pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) + if err := pb.processSelect(sel, nil, query); err != nil { + return nil, err + } + if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { + return nil, err + } + return pb.bldr.Primitive(), nil } - return pb.bldr.Primitive(), nil } // processSelect builds a primitive tree for the given query or subquery. @@ -90,15 +93,26 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // The LIMIT clause is the last construct of a query. If it cannot be // pushed into a route, then a primitive is created on top of any // of the above trees to make it discard unwanted rows. -func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { +func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, query string) error { // Check and error if there is any locking function present in select expression. for _, expr := range sel.SelectExprs { if aExpr, ok := expr.(*sqlparser.AliasedExpr); ok && sqlparser.IsLockingFunc(aExpr.Expr) { return vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "%v allowed only with dual", sqlparser.String(aExpr)) } } - if sel.SQLCalcFoundRows && sel.Limit != nil { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "sql_calc_found_rows not yet fully supported") + if sel.SQLCalcFoundRows { + if outer != nil || query == "" { + return mysql.NewSQLError(mysql.ERCantUseOptionHere, "42000", "Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS'") + } + sel.SQLCalcFoundRows = false + if sel.Limit != nil { + builder, err := buildSQLCalcFoundRowsPlan(query, sel, outer, pb.vschema) + if err != nil { + return err + } + pb.bldr = builder + return nil + } } if err := pb.processTableExprs(sel.From); err != nil { @@ -147,6 +161,59 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) return nil } +func buildSQLCalcFoundRowsPlan(query string, sel *sqlparser.Select, outer *symtab, vschema ContextVSchema) (builder, error) { + ljt := newJointab(sqlparser.GetBindvars(sel)) + frpb := newPrimitiveBuilder(vschema, ljt) + err := frpb.processSelect(sel, outer, "") + if err != nil { + return nil, err + } + + statement, err := sqlparser.Parse(query) + if err != nil { + return nil, err + } + sel2 := statement.(*sqlparser.Select) + + sel2.SQLCalcFoundRows = false + sel2.OrderBy = nil + sel2.Limit = nil + + countStartExpr := []sqlparser.SelectExpr{&sqlparser.AliasedExpr{ + Expr: &sqlparser.FuncExpr{ + Name: sqlparser.NewColIdent("count"), + Exprs: []sqlparser.SelectExpr{&sqlparser.StarExpr{}}, + }, + }} + if sel2.GroupBy == nil && sel2.Having == nil { + // if there is no grouping, we can use the same query and + // just replace the SELECT sub-clause to have a single count(*) + sel2.SelectExprs = countStartExpr + } else { + // when there is grouping, we have to move the original query into a derived table. + // select id, sum(12) from user group by id => + // select count(*) from (select id, sum(12) from user group by id) t + sel3 := &sqlparser.Select{ + SelectExprs: countStartExpr, + From: []sqlparser.TableExpr{ + &sqlparser.AliasedTableExpr{ + Expr: &sqlparser.Subquery{Select: sel2}, + As: sqlparser.NewTableIdent("t"), + }, + }, + } + sel2 = sel3 + } + + cjt := newJointab(sqlparser.GetBindvars(sel2)) + countpb := newPrimitiveBuilder(vschema, cjt) + err = countpb.processSelect(sel2, outer, "") + if err != nil { + return nil, err + } + return &sqlCalcFoundRows{LimitQuery: frpb.bldr, CountQuery: countpb.bldr, ljt: ljt, cjt: cjt}, nil +} + func handleDualSelects(sel *sqlparser.Select, vschema ContextVSchema) (engine.Primitive, error) { if !isOnlyDual(sel) { return nil, nil diff --git a/go/vt/vtgate/planbuilder/sql_calc_found_rows.go b/go/vt/vtgate/planbuilder/sql_calc_found_rows.go new file mode 100644 index 00000000000..07dff25cf15 --- /dev/null +++ b/go/vt/vtgate/planbuilder/sql_calc_found_rows.go @@ -0,0 +1,125 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" +) + +var _ builder = (*sqlCalcFoundRows)(nil) + +type sqlCalcFoundRows struct { + LimitQuery, CountQuery builder + ljt, cjt *jointab +} + +//Wireup implements the builder interface +func (s *sqlCalcFoundRows) Wireup(builder, *jointab) error { + err := s.LimitQuery.Wireup(s.LimitQuery, s.ljt) + if err != nil { + return err + } + return s.CountQuery.Wireup(s.CountQuery, s.cjt) +} + +//Primitive implements the builder interface +func (s *sqlCalcFoundRows) Primitive() engine.Primitive { + return engine.SQLCalcFoundRows{ + LimitPrimitive: s.LimitQuery.Primitive(), + CountPrimitive: s.CountQuery.Primitive(), + } +} + +// All the methods below are not implemented. They should not be called on a sqlCalcFoundRows builder + +//Order implements the builder interface +func (s *sqlCalcFoundRows) Order() int { + return s.LimitQuery.Order() +} + +//ResultColumns implements the builder interface +func (s *sqlCalcFoundRows) ResultColumns() []*resultColumn { + return s.LimitQuery.ResultColumns() +} + +//Reorder implements the builder interface +func (s *sqlCalcFoundRows) Reorder(order int) { + s.LimitQuery.Reorder(order) +} + +//First implements the builder interface +func (s *sqlCalcFoundRows) First() builder { + return s.LimitQuery.First() +} + +//PushFilter implements the builder interface +func (s *sqlCalcFoundRows) PushFilter(*primitiveBuilder, sqlparser.Expr, string, builder) error { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.PushFilter") +} + +//PushSelect implements the builder interface +func (s *sqlCalcFoundRows) PushSelect(*primitiveBuilder, *sqlparser.AliasedExpr, builder) (rc *resultColumn, colNumber int, err error) { + return nil, 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.PushSelect") +} + +//MakeDistinct implements the builder interface +func (s *sqlCalcFoundRows) MakeDistinct() error { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.MakeDistinct") +} + +//PushGroupBy implements the builder interface +func (s *sqlCalcFoundRows) PushGroupBy(sqlparser.GroupBy) error { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.PushGroupBy") +} + +//PushOrderBy implements the builder interface +func (s *sqlCalcFoundRows) PushOrderBy(sqlparser.OrderBy) (builder, error) { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.PushOrderBy") +} + +//SetUpperLimit implements the builder interface +func (s *sqlCalcFoundRows) SetUpperLimit(count sqlparser.Expr) { + s.LimitQuery.SetUpperLimit(count) +} + +//PushMisc implements the builder interface +func (s *sqlCalcFoundRows) PushMisc(sel *sqlparser.Select) { + s.LimitQuery.PushMisc(sel) +} + +//SupplyVar implements the builder interface +func (s *sqlCalcFoundRows) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { + s.LimitQuery.SupplyVar(from, to, col, varname) +} + +//SupplyCol implements the builder interface +func (s *sqlCalcFoundRows) SupplyCol(col *sqlparser.ColName) (*resultColumn, int) { + return s.LimitQuery.SupplyCol(col) +} + +//SupplyWeightString implements the builder interface +func (s *sqlCalcFoundRows) SupplyWeightString(int) (weightcolNumber int, err error) { + return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.SupplyWeightString") +} + +//PushLock implements the builder interface +func (s *sqlCalcFoundRows) PushLock(string) error { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unreachable: sqlCalcFoundRows.PushLock") +} diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 35750f6e37d..95b70a3ef3d 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -1387,3 +1387,173 @@ ] } } + +# sql_calc_found_rows without limit +"select sql_calc_found_rows * from music where user_id = 1" +{ + "QueryType": "SELECT", + "Original": "select sql_calc_found_rows * from music where user_id = 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from music where 1 != 1", + "Query": "select * from music where user_id = 1", + "Table": "music", + "Values": [ + 1 + ], + "Vindex": "user_index" + } +} + +# sql_calc_found_rows with limit +"select sql_calc_found_rows * from music limit 100" +{ + "QueryType": "SELECT", + "Original": "select sql_calc_found_rows * from music limit 100", + "Instructions": { + "OperatorType": "SQL_CALC_FOUND_ROWS", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 100, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from music where 1 != 1", + "Query": "select * from music limit :__upper_limit", + "Table": "music" + } + ] + }, + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "Distinct": "false", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1", + "Query": "select count(*) from music", + "Table": "music" + } + ] + } + ] + } +} + +# sql_calc_found_rows with SelectEqualUnique plans +"select sql_calc_found_rows * from music where user_id = 1 limit 2" +{ + "QueryType": "SELECT", + "Original": "select sql_calc_found_rows * from music where user_id = 1 limit 2", + "Instructions": { + "OperatorType": "SQL_CALC_FOUND_ROWS", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from music where 1 != 1", + "Query": "select * from music where user_id = 1 limit 2", + "Table": "music", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1", + "Query": "select count(*) from music where user_id = 1", + "Table": "music", + "Values": [ + 1 + ], + "Vindex": "user_index" + } + ] + } +} + +# sql_calc_found_rows with group by and having +"select sql_calc_found_rows user_id, count(id) from music group by user_id having count(user_id) = 1 order by user_id limit 2" +{ + "QueryType": "SELECT", + "Original": "select sql_calc_found_rows user_id, count(id) from music group by user_id having count(user_id) = 1 order by user_id limit 2", + "Instructions": { + "OperatorType": "SQL_CALC_FOUND_ROWS", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_id, count(id) from music where 1 != 1 group by user_id", + "Query": "select user_id, count(id) from music group by user_id having count(user_id) = 1 order by user_id asc limit :__upper_limit", + "Table": "music" + } + ] + }, + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "Distinct": "false", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from (select user_id, count(id) from music where 1 != 1 group by user_id) as t where 1 != 1", + "Query": "select count(*) from (select user_id, count(id) from music group by user_id having count(user_id) = 1) as t", + "Table": "music" + } + ] + } + ] + } +} + +# sql_calc_found_rows in sub queries +"select * from music where user_id IN (select sql_calc_found_rows * from music limit 10)" +"Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS' (errno 1234) (sqlstate 42000)" + +# sql_calc_found_rows in derived table +"select sql_calc_found_rows * from (select sql_calc_found_rows * from music limit 10) t limit 1" +"Incorrect usage/placement of 'SQL_CALC_FOUND_ROWS' (errno 1234) (sqlstate 42000)" + + diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.txt b/go/vt/vtgate/planbuilder/testdata/union_cases.txt index 8a88dc48f0c..0af47b4cae2 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.txt @@ -279,4 +279,3 @@ ] } } - diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 7dd4584e1f5..f2b9a155564 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -423,3 +423,7 @@ # insert using select get_lock from table "insert into user(pattern) SELECT GET_LOCK('xyz1', 10)" "unsupported: insert into select" + +# union with SQL_CALC_FOUND_ROWS +"(select sql_calc_found_rows id from user where id = 1 limit 1) union select id from user where id = 1" +"SQL_CALC_FOUND_ROWS not supported with union" diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index 42f692cedb6..e96178f9331 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -20,6 +20,9 @@ import ( "errors" "fmt" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/vt/sqlparser" @@ -86,13 +89,16 @@ func (pb *primitiveBuilder) processPart(part sqlparser.SelectStatement, outer *s case *sqlparser.Union: return pb.processUnion(part, outer) case *sqlparser.Select: + if part.SQLCalcFoundRows { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "SQL_CALC_FOUND_ROWS not supported with union") + } if !hasParens { err := checkOrderByAndLimit(part) if err != nil { return err } } - return pb.processSelect(part, outer) + return pb.processSelect(part, outer, "") case *sqlparser.ParenSelect: err := pb.processPart(part.Select, outer, true) if err != nil { diff --git a/go/vt/vtgate/safe_session.go b/go/vt/vtgate/safe_session.go index c33d0375b6b..2d4259b34ae 100644 --- a/go/vt/vtgate/safe_session.go +++ b/go/vt/vtgate/safe_session.go @@ -38,6 +38,10 @@ type SafeSession struct { mustRollback bool autocommitState autocommitState commitOrder vtgatepb.CommitOrder + + // this is a signal that found_rows has already been handles by the primitives, + // and doesn't have to be updated by the executor + foundRowsHandled bool *vtgatepb.Session } diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 09ea335ba7d..4aaa7aae770 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -526,6 +526,11 @@ func (vc *vcursorImpl) SysVarSetEnabled() bool { return *sysVarSetEnabled } +func (vc *vcursorImpl) SetFoundRows(foundRows uint64) { + vc.safeSession.FoundRows = foundRows + vc.safeSession.foundRowsHandled = true +} + // ParseDestinationTarget parses destination target string and sets default keyspace if possible. func parseDestinationTarget(targetString string, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, defaultTabletType)