diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index c2a30adf327..16362bd24ce 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -1313,3 +1313,19 @@ func (cluster *LocalProcessCluster) EnableVTOrcRecoveries(t *testing.T) { vtorc.EnableGlobalRecoveries(t) } } + +// EnableGeneralLog enables generals logs on all the mysql server started by this cluster. +// This method should be used only for local debugging purpose. +func (cluster *LocalProcessCluster) EnableGeneralLog() error { + for _, ks := range cluster.Keyspaces { + for _, shard := range ks.Shards { + for _, vttablet := range shard.Vttablets { + _, err := vttablet.VttabletProcess.QueryTablet("set global general_log = 1", "", false) + if err != nil { + return err + } + } + } + } + return nil +} diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 849f81240e9..34f5417f8b5 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -196,7 +196,7 @@ func (mcmp *MySQLCompare) Exec(query string) *sqltypes.Result { mysqlQr, err := mcmp.MySQLConn.ExecuteFetch(query, 1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for query: "+query) - compareVitessAndMySQLResults(mcmp.t, query, vtQr, mysqlQr, false) + compareVitessAndMySQLResults(mcmp.t, query, mcmp.VtConn, vtQr, mysqlQr, false) return vtQr } @@ -222,7 +222,7 @@ func (mcmp *MySQLCompare) ExecWithColumnCompare(query string) *sqltypes.Result { mysqlQr, err := mcmp.MySQLConn.ExecuteFetch(query, 1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for query: "+query) - compareVitessAndMySQLResults(mcmp.t, query, vtQr, mysqlQr, true) + compareVitessAndMySQLResults(mcmp.t, query, mcmp.VtConn, vtQr, mysqlQr, true) return vtQr } @@ -241,7 +241,7 @@ func (mcmp *MySQLCompare) ExecAllowAndCompareError(query string) (*sqltypes.Resu // Since we allow errors, we don't want to compare results if one of the 
client failed. // Vitess and MySQL should always be agreeing whether the query returns an error or not. if vtErr == nil && mysqlErr == nil { - compareVitessAndMySQLResults(mcmp.t, query, vtQr, mysqlQr, false) + compareVitessAndMySQLResults(mcmp.t, query, mcmp.VtConn, vtQr, mysqlQr, false) } return vtQr, vtErr } diff --git a/go/test/endtoend/utils/mysql.go b/go/test/endtoend/utils/mysql.go index fc1a62d6736..6249d639a4d 100644 --- a/go/test/endtoend/utils/mysql.go +++ b/go/test/endtoend/utils/mysql.go @@ -154,7 +154,7 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error { return nil } -func compareVitessAndMySQLResults(t *testing.T, query string, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) { +func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) { if vtQr == nil && mysqlQr == nil { return } @@ -207,6 +207,10 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtQr, mysqlQr *sql for _, row := range mysqlQr.Rows { errStr += fmt.Sprintf("%s\n", row) } + if vtConn != nil { + qr := Exec(t, vtConn, fmt.Sprintf("vexplain plan %s", query)) + errStr += fmt.Sprintf("query plan: \n%s\n", qr.Rows[0][0].ToString()) + } t.Error(errStr) } diff --git a/go/test/endtoend/utils/utils.go b/go/test/endtoend/utils/utils.go index d3707b0bd3b..38d4e527d66 100644 --- a/go/test/endtoend/utils/utils.go +++ b/go/test/endtoend/utils/utils.go @@ -163,7 +163,7 @@ func ExecCompareMySQL(t *testing.T, vtConn, mysqlConn *mysql.Conn, query string) mysqlQr, err := mysqlConn.ExecuteFetch(query, 1000, true) require.NoError(t, err, "[MySQL Error] for query: "+query) - compareVitessAndMySQLResults(t, query, vtQr, mysqlQr, false) + compareVitessAndMySQLResults(t, query, vtConn, vtQr, mysqlQr, false) return vtQr } diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index f07fb734df8..ede9a3df9f8 
100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -73,10 +73,12 @@ func TestGroupBy(t *testing.T) { mcmp.Exec("insert into t3(id5, id6, id7) values(1,1,2), (2,2,4), (3,2,4), (4,1,2), (5,1,2), (6,3,6)") // test ordering and group by int column mcmp.AssertMatches("select id6, id7, count(*) k from t3 group by id6, id7 order by k", `[[INT64(3) INT64(6) INT64(1)] [INT64(2) INT64(4) INT64(2)] [INT64(1) INT64(2) INT64(3)]]`) + mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id6+id7, count(*) k from t3 group by id6+id7 order by k", `[[INT64(9) INT64(1)] [INT64(6) INT64(2)] [INT64(3) INT64(3)]]`) // Test the same queries in streaming mode utils.Exec(t, mcmp.VtConn, "set workload = olap") mcmp.AssertMatches("select id6, id7, count(*) k from t3 group by id6, id7 order by k", `[[INT64(3) INT64(6) INT64(1)] [INT64(2) INT64(4) INT64(2)] [INT64(1) INT64(2) INT64(3)]]`) + mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id6+id7, count(*) k from t3 group by id6+id7 order by k", `[[INT64(9) INT64(1)] [INT64(6) INT64(2)] [INT64(3) INT64(3)]]`) } func TestDistinct(t *testing.T) { diff --git a/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go b/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go new file mode 100644 index 00000000000..25bec1a39b4 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go @@ -0,0 +1,212 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aggregation + +import ( + "fmt" + "math/rand" + "strings" + "testing" + "time" + + "golang.org/x/exp/maps" + + "vitess.io/vitess/go/vt/log" +) + +type ( + column struct { + name string + typ string + } + tableT struct { + name string + columns []column + } +) + +func TestFuzzAggregations(t *testing.T) { + // This test randomizes values and queries, and checks that mysql returns the same values that Vitess does + mcmp, closer := start(t) + defer closer() + + noOfRows := rand.Intn(20) + var values []string + for i := 0; i < noOfRows; i++ { + values = append(values, fmt.Sprintf("(%d, 'name%d', 'value%d', %d)", i, i, i, i)) + } + t1Insert := fmt.Sprintf("insert into t1 (t1_id, name, value, shardKey) values %s;", strings.Join(values, ",")) + values = nil + noOfRows = rand.Intn(20) + for i := 0; i < noOfRows; i++ { + values = append(values, fmt.Sprintf("(%d, %d)", i, i)) + } + t2Insert := fmt.Sprintf("insert into t2 (id, shardKey) values %s;", strings.Join(values, ",")) + + mcmp.Exec(t1Insert) + mcmp.Exec(t2Insert) + + t.Cleanup(func() { + if t.Failed() { + fmt.Println(t1Insert) + fmt.Println(t2Insert) + } + }) + + schema := map[string]tableT{ + "t1": {name: "t1", columns: []column{ + {name: "t1_id", typ: "bigint"}, + {name: "name", typ: "varchar"}, + {name: "value", typ: "varchar"}, + {name: "shardKey", typ: "bigint"}, + }}, + "t2": {name: "t2", columns: []column{ + {name: "id", typ: "bigint"}, + {name: "shardKey", typ: "bigint"}, + }}, + } + + endBy := time.Now().Add(1 * time.Second) + schemaTables := maps.Values(schema) + + var queryCount int + for time.Now().Before(endBy) || t.Failed() { + tables := createTables(schemaTables) + query := randomQuery(tables, 3, 3) + mcmp.Exec(query) + if t.Failed() { + fmt.Println(query) + } + queryCount++ + } + log.Info("Queries successfully executed: %d", queryCount) +} + +func randomQuery(tables []tableT, maxAggrs, 
maxGroupBy int) string { + randomCol := func(tblIdx int) (string, string) { + tbl := tables[tblIdx] + col := randomEl(tbl.columns) + return fmt.Sprintf("tbl%d.%s", tblIdx, col.name), col.typ + } + predicates := createPredicates(tables, randomCol) + aggregates := createAggregations(tables, maxAggrs, randomCol) + grouping := createGroupBy(tables, maxGroupBy, randomCol) + sel := "select /*vt+ PLANNER=Gen4 */ " + strings.Join(aggregates, ", ") + " from " + + var tbls []string + for i, s := range tables { + tbls = append(tbls, fmt.Sprintf("%s as tbl%d", s.name, i)) + } + sel += strings.Join(tbls, ", ") + + if len(predicates) > 0 { + sel += " where " + sel += strings.Join(predicates, " and ") + } + if len(grouping) > 0 { + sel += " group by " + sel += strings.Join(grouping, ", ") + } + // we do it this way so we don't have to do only `only_full_group_by` queries + var noOfOrderBy int + if len(grouping) > 0 { + // panic on rand function call if value is 0 + noOfOrderBy = rand.Intn(len(grouping)) + } + if noOfOrderBy > 0 { + noOfOrderBy = 0 // TODO turning on ORDER BY here causes lots of failures to happen + } + if noOfOrderBy > 0 { + var orderBy []string + for noOfOrderBy > 0 { + noOfOrderBy-- + if rand.Intn(2) == 0 || len(grouping) == 0 { + orderBy = append(orderBy, randomEl(aggregates)) + } else { + orderBy = append(orderBy, randomEl(grouping)) + } + } + sel += " order by " + sel += strings.Join(orderBy, ", ") + } + return sel +} + +func createGroupBy(tables []tableT, maxGB int, randomCol func(tblIdx int) (string, string)) (grouping []string) { + noOfGBs := rand.Intn(maxGB) + for i := 0; i < noOfGBs; i++ { + tblIdx := rand.Intn(len(tables)) + col, _ := randomCol(tblIdx) + grouping = append(grouping, col) + } + return +} + +func createAggregations(tables []tableT, maxAggrs int, randomCol func(tblIdx int) (string, string)) (aggregates []string) { + aggregations := []func(string) string{ + func(_ string) string { return "count(*)" }, + func(e string) string { return 
fmt.Sprintf("count(%s)", e) }, + //func(e string) string { return fmt.Sprintf("sum(%s)", e) }, + //func(e string) string { return fmt.Sprintf("avg(%s)", e) }, + //func(e string) string { return fmt.Sprintf("min(%s)", e) }, + //func(e string) string { return fmt.Sprintf("max(%s)", e) }, + } + + noOfAggrs := rand.Intn(maxAggrs) + 1 + for i := 0; i < noOfAggrs; i++ { + tblIdx := rand.Intn(len(tables)) + e, _ := randomCol(tblIdx) + aggregates = append(aggregates, randomEl(aggregations)(e)) + } + return aggregates +} + +func createTables(schemaTables []tableT) []tableT { + noOfTables := rand.Intn(2) + 1 + var tables []tableT + + for i := 0; i < noOfTables; i++ { + tables = append(tables, randomEl(schemaTables)) + } + return tables +} + +func createPredicates(tables []tableT, randomCol func(tblIdx int) (string, string)) (predicates []string) { + for idx1 := range tables { + for idx2 := range tables { + if idx1 == idx2 { + continue + } + noOfPredicates := rand.Intn(2) + + for noOfPredicates > 0 { + col1, t1 := randomCol(idx1) + col2, t2 := randomCol(idx2) + if t1 != t2 { + continue + } + predicates = append(predicates, fmt.Sprintf("%s = %s", col1, col2)) + noOfPredicates-- + } + } + } + return predicates +} + +func randomEl[K any](in []K) K { + return in[rand.Intn(len(in))] +} diff --git a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go index 5c6c7503a97..7afca61e245 100644 --- a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go +++ b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go @@ -66,6 +66,8 @@ func TestOrderBy(t *testing.T) { mcmp.AssertMatches("select id1, id2 from t4 order by id2 desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) // test ordering of int column mcmp.AssertMatches("select id1, id2 from t4 order by id1 desc", 
`[[INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(5) VARCHAR("test")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) + // test ordering of complex column + mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) defer func() { utils.Exec(t, mcmp.VtConn, "set workload = oltp") @@ -75,4 +77,5 @@ func TestOrderBy(t *testing.T) { utils.Exec(t, mcmp.VtConn, "set workload = olap") mcmp.AssertMatches("select id1, id2 from t4 order by id2 desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) mcmp.AssertMatches("select id1, id2 from t4 order by id1 desc", `[[INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(5) VARCHAR("test")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) + mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) } diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index c5a89a25173..0c98bfb01b9 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2867,6 +2867,33 @@ type ( } CountStar struct { + _ bool + // TL;DR; This makes sure that reference equality checks works as expected + // + // You're correct that this might seem a bit strange at first glance. + // It's a quirk of Go's handling of empty structs. 
In Go, two instances of an empty struct are considered + // identical, which can be problematic when using these as keys in maps. + // They would be treated as the same key and potentially lead to incorrect map behavior. + // + // Here's a brief example: + // + // ```golang + // func TestWeirdGo(t *testing.T) { + // type CountStar struct{} + // + // cs1 := &CountStar{} + // cs2 := &CountStar{} + // if cs1 == cs2 { + // panic("what the what!?") + // } + // } + // ``` + // + // In the above code, cs1 and cs2, despite being distinct variables, would be treated as the same object. + // + // The solution we employed was to add a dummy field `_ bool` to the otherwise empty struct `CountStar`. + // This ensures that each instance of `CountStar` is treated as a separate object, + // even in the context of out semantic state which uses these objects as map keys. } Avg struct { diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 886e2620915..4138a5588c6 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -933,6 +933,11 @@ func writeEscapedString(buf *TrackedBuffer, original string) { buf.WriteByte('`') } +func CompliantString(in SQLNode) string { + s := String(in) + return compliantName(s) +} + func compliantName(in string) string { var buf strings.Builder for i, c := range in { diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index 2b6683a15d4..00fa7c8b740 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -75,11 +75,15 @@ func (r *ReservedVars) ReserveAll(names ...string) bool { // with the same name already exists, it'll be suffixed with a numberic identifier // to make it unique. 
func (r *ReservedVars) ReserveColName(col *ColName) string { - compliantName := col.CompliantName() - if r.fast && strings.HasPrefix(compliantName, r.prefix) { - compliantName = "_" + compliantName + reserveName := col.CompliantName() + if r.fast && strings.HasPrefix(reserveName, r.prefix) { + reserveName = "_" + reserveName } + return r.ReserveVariable(reserveName) +} + +func (r *ReservedVars) ReserveVariable(compliantName string) string { joinVar := []byte(compliantName) baseLen := len(joinVar) i := int64(1) diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index faa51a74346..4238aa28a1a 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -922,6 +922,16 @@ func (cached *Count) CachedSize(alloc bool) int64 { } return size } +func (cached *CountStar) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(8) + } + return size +} func (cached *CreateDatabase) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index c67a0951b35..a4c7f66b174 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -58,8 +58,8 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st } result := &sqltypes.Result{} if len(lresult.Rows) == 0 && wantfields { - for k := range jn.Vars { - joinVars[k] = sqltypes.NullBindVariable + for k, col := range jn.Vars { + joinVars[k] = bindvarForType(lresult.Fields[col].Type) } rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) if err != nil { @@ -93,6 +93,25 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } +func bindvarForType(t querypb.Type) *querypb.BindVariable { + bv := &querypb.BindVariable{ + Type: t, + Value: nil, + } + switch t { + case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, 
querypb.Type_UINT16, + querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64: + bv.Value = []byte("0") + case querypb.Type_FLOAT32, querypb.Type_FLOAT64: + bv.Value = []byte("0e0") + case querypb.Type_DECIMAL: + bv.Value = []byte("0.0") + default: + return sqltypes.NullBindVariable + } + return bv +} + // TryStreamExecute performs a streaming exec. func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { joinVars := make(map[string]*querypb.BindVariable) diff --git a/go/vt/vtgate/engine/join_test.go b/go/vt/vtgate/engine/join_test.go index 50ccb35ac7c..2df507f9512 100644 --- a/go/vt/vtgate/engine/join_test.go +++ b/go/vt/vtgate/engine/join_test.go @@ -237,9 +237,7 @@ func TestJoinExecuteNoResult(t *testing.T) { }, } r, err := jn.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) leftPrim.ExpectLog(t, []string{ `Execute true`, }) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 65d154869f0..b456f7df04d 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2495,10 +2495,12 @@ func TestEmptyJoin(t *testing.T) { sbc1.SetResults([]*sqltypes.Result{{ Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, + {Name: "col", Type: sqltypes.Int32}, }, }, { Fields: []*querypb.Field{ {Name: "id", Type: sqltypes.Int32}, + {Name: "col", Type: sqltypes.Int32}, }, }}) result, err := executorExec(executor, "select u1.id, u2.id from user u1 join user u2 on u2.id = u1.col where u1.id = 1", nil) @@ -2509,7 +2511,7 @@ func TestEmptyJoin(t *testing.T) { }, { Sql: "select u2.id from `user` as u2 where 1 != 1", BindVariables: map[string]*querypb.BindVariable{ - "u1_col": sqltypes.NullBindVariable, + "u1_col": sqltypes.Int32BindVariable(0), }, }} 
utils.MustMatch(t, wantQueries, sbc1.Queries) @@ -3680,7 +3682,7 @@ func TestSelectAggregationNoData(t *testing.T) { { sql: `select col, count(*) from user group by col limit 2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)", "int64|int64")), - expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit", + expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc", expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`, expRow: `[]`, }, @@ -3764,7 +3766,7 @@ func TestSelectAggregationData(t *testing.T) { { sql: `select col, count(*) from user group by col limit 2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)|weight_string(col)", "int64|int64|varbinary"), "1|2|NULL", "2|1|NULL", "3|4|NULL"), - expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit", + expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc", expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`, expRow: `[[INT64(1) INT64(16)] [INT64(2) INT64(8)]]`, }, diff --git a/go/vt/vtgate/planbuilder/aggregation_pushing.go b/go/vt/vtgate/planbuilder/aggregation_pushing.go index cc05c9e8377..a878e039b91 100644 --- a/go/vt/vtgate/planbuilder/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/aggregation_pushing.go @@ -69,7 +69,7 @@ func (hp *horizonPlanning) pushAggregation( pushed = false for _, grp := range grouping { - offset, wOffset, err := wrapAndPushExpr(ctx, grp.Inner, grp.WeightStrExpr, plan.input) + offset, wOffset, err := wrapAndPushExpr(ctx, grp.Inner, grp.SimplifiedExpr, plan.input) if err != nil { return nil, nil, nil, false, err } @@ -166,8 +166,8 @@ func pushAggrOnRoute( pos = newOffset(groupingCols[idx]) } - if expr.WeightStrExpr != nil && 
ctx.SemTable.NeedsWeightString(expr.Inner) { - wsExpr := weightStringFor(expr.WeightStrExpr) + if expr.SimplifiedExpr != nil && ctx.SemTable.NeedsWeightString(expr.Inner) { + wsExpr := weightStringFor(expr.SimplifiedExpr) wsCol, _, err := addExpressionToRoute(ctx, plan, &sqlparser.AliasedExpr{Expr: wsExpr}, true) if err != nil { return nil, nil, nil, err @@ -287,7 +287,7 @@ func (hp *horizonPlanning) pushAggrOnJoin( return nil, nil, err } l = sqlparser.NewIntLiteral(strconv.Itoa(offset + 1)) - rhsGrouping = append(rhsGrouping, operators.GroupBy{Inner: l}) + rhsGrouping = append(rhsGrouping, operators.NewGroupBy(l, nil, nil)) } // Next we push the aggregations to both sides diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 352043e8552..1375f5ff690 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -269,7 +269,7 @@ func (hp *horizonPlanning) planAggrUsingOA( var order []ops.OrderBy if hp.qp.CanPushDownSorting { - hp.qp.AlignGroupByAndOrderBy(ctx) + hp.qp.OldAlignGroupByAndOrderBy(ctx) // the grouping order might have changed, so we reload the grouping expressions grouping = hp.qp.GetGrouping() order = hp.qp.OrderExprs @@ -519,7 +519,8 @@ func (hp *horizonPlanning) handleDistinctAggr(ctx *plancontext.PlanningContext, continue } - inner, innerWS, err := hp.qp.GetSimplifiedExpr(expr.Func.GetArg()) + inner := expr.Func.GetArg() + innerWS := hp.qp.GetSimplifiedExpr(inner) if err != nil { return nil, nil, nil, err } @@ -535,11 +536,9 @@ func (hp *horizonPlanning) handleDistinctAggr(ctx *plancontext.PlanningContext, return nil, nil, nil, err } } - distincts = append(distincts, operators.GroupBy{ - Inner: inner, - WeightStrExpr: innerWS, - InnerIndex: expr.Index, - }) + groupBy := operators.NewGroupBy(inner, innerWS, nil) + groupBy.InnerIndex = expr.Index + distincts = append(distincts, groupBy) offsets = append(offsets, i) } return @@ -578,22 +577,16 @@ 
func newOffset(col int) offsets { func (hp *horizonPlanning) createGroupingsForColumns(columns []*sqlparser.ColName) ([]operators.GroupBy, error) { var lhsGrouping []operators.GroupBy for _, lhsColumn := range columns { - expr, wsExpr, err := hp.qp.GetSimplifiedExpr(lhsColumn) - if err != nil { - return nil, err - } + wsExpr := hp.qp.GetSimplifiedExpr(lhsColumn) - lhsGrouping = append(lhsGrouping, operators.GroupBy{ - Inner: expr, - WeightStrExpr: wsExpr, - }) + lhsGrouping = append(lhsGrouping, operators.NewGroupBy(lhsColumn, wsExpr, nil)) } return lhsGrouping, nil } func hasUniqueVindex(semTable *semantics.SemTable, groupByExprs []operators.GroupBy) bool { for _, groupByExpr := range groupByExprs { - if exprHasUniqueVindex(semTable, groupByExpr.WeightStrExpr) { + if exprHasUniqueVindex(semTable, groupByExpr.SimplifiedExpr) { return true } } @@ -620,7 +613,7 @@ func (hp *horizonPlanning) planOrderBy(ctx *plancontext.PlanningContext, orderEx orderExprs = orderExprsWithoutNils for _, order := range orderExprs { - if sqlparser.ContainsAggregation(order.WeightStrExpr) { + if sqlparser.ContainsAggregation(order.SimplifiedExpr) { ms, err := createMemorySortPlanOnAggregation(ctx, plan, orderExprs) if err != nil { return nil, err @@ -680,7 +673,7 @@ func planOrderByForRoute(ctx *plancontext.PlanningContext, orderExprs []ops.Orde } var wsExpr sqlparser.Expr if ctx.SemTable.NeedsWeightString(order.Inner.Expr) { - wsExpr = order.WeightStrExpr + wsExpr = order.SimplifiedExpr } offset, weightStringOffset, err := wrapAndPushExpr(ctx, order.Inner.Expr, wsExpr, plan) @@ -832,7 +825,7 @@ func createMemorySortPlanOnAggregation(ctx *plancontext.PlanningContext, plan *o return nil, vterrors.VT13001(fmt.Sprintf("expected to find ORDER BY expression (%s) in orderedAggregate", sqlparser.String(order.Inner))) } - collationID := ctx.SemTable.CollationForExpr(order.WeightStrExpr) + collationID := ctx.SemTable.CollationForExpr(order.SimplifiedExpr) ms.eMemorySort.OrderBy = 
append(ms.eMemorySort.OrderBy, engine.OrderByParams{ Col: offset, WeightStringCol: woffset, @@ -846,13 +839,13 @@ func createMemorySortPlanOnAggregation(ctx *plancontext.PlanningContext, plan *o func findExprInOrderedAggr(ctx *plancontext.PlanningContext, plan *orderedAggregate, order ops.OrderBy) (keyCol int, weightStringCol int, found bool) { for _, key := range plan.groupByKeys { - if ctx.SemTable.EqualsExpr(order.WeightStrExpr, key.Expr) || + if ctx.SemTable.EqualsExpr(order.SimplifiedExpr, key.Expr) || ctx.SemTable.EqualsExpr(order.Inner.Expr, key.Expr) { return key.KeyCol, key.WeightStringCol, true } } for _, aggregate := range plan.aggregates { - if ctx.SemTable.EqualsExpr(order.WeightStrExpr, aggregate.Original.Expr) || + if ctx.SemTable.EqualsExpr(order.SimplifiedExpr, aggregate.Original.Expr) || ctx.SemTable.EqualsExpr(order.Inner.Expr, aggregate.Original.Expr) { return aggregate.Col, -1, true } @@ -872,7 +865,7 @@ func (hp *horizonPlanning) createMemorySortPlan(ctx *plancontext.PlanningContext } for _, order := range orderExprs { - wsExpr := order.WeightStrExpr + wsExpr := order.SimplifiedExpr if !useWeightStr { wsExpr = nil } @@ -991,8 +984,8 @@ func (hp *horizonPlanning) addDistinct(ctx *plancontext.PlanningContext, plan lo groupByKeys = append(groupByKeys, grpParam) orderExprs = append(orderExprs, ops.OrderBy{ - Inner: &sqlparser.Order{Expr: inner}, - WeightStrExpr: aliasExpr.Expr}, + Inner: &sqlparser.Order{Expr: inner}, + SimplifiedExpr: aliasExpr.Expr}, ) } innerPlan, err := hp.planOrderBy(ctx, orderExprs, plan) @@ -1177,7 +1170,7 @@ func planGroupByGen4(ctx *plancontext.PlanningContext, groupExpr operators.Group // then we need to add that to the group by clause otherwise the query will fail on mysql with full_group_by error // as the weight_string function might not be functionally dependent on the group by. 
if wsAdded { - sel.AddGroupBy(weightStringFor(groupExpr.WeightStrExpr)) + sel.AddGroupBy(weightStringFor(groupExpr.SimplifiedExpr)) } return nil case *pulloutSubquery: diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 540d9b6771e..9580fd85075 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -23,7 +23,7 @@ import ( "strings" "vitess.io/vitess/go/slices2" - + "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" @@ -71,11 +71,55 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op ops.Operator, i return transformLimit(ctx, op) case *operators.Ordering: return transformOrdering(ctx, op) + case *operators.Aggregator: + return transformAggregator(ctx, op) } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToLogicalPlan)", op)) } +func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggregator) (logicalPlan, error) { + plan, err := transformToLogicalPlan(ctx, op.Source, false) + if err != nil { + return nil, err + } + + oa := &orderedAggregate{ + resultsBuilder: resultsBuilder{ + logicalPlanCommon: newBuilderCommon(plan), + weightStrings: make(map[*resultColumn]int), + }, + } + + for _, aggr := range op.Aggregations { + if aggr.OpCode == opcode.AggregateUnassigned { + return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original))) + } + oa.aggregates = append(oa.aggregates, &engine.AggregateParams{ + Opcode: aggr.OpCode, + Col: aggr.ColOffset, + Alias: aggr.Alias, + Expr: aggr.Func, + Original: aggr.Original, + OrigOpcode: aggr.OriginalOpCode, + }) + } + for _, groupBy := range op.Grouping { + oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ + KeyCol: groupBy.ColOffset, + 
WeightStringCol: groupBy.WSOffset, + Expr: groupBy.AsAliasedExpr().Expr, + CollationID: ctx.SemTable.CollationForExpr(groupBy.SimplifiedExpr), + }) + } + + if err != nil { + return nil, err + } + oa.truncateColumnCount = op.ResultColumns + return oa, nil +} + func transformOrdering(ctx *plancontext.PlanningContext, op *operators.Ordering) (logicalPlan, error) { plan, err := transformToLogicalPlan(ctx, op.Source, false) if err != nil { @@ -86,7 +130,9 @@ func transformOrdering(ctx *plancontext.PlanningContext, op *operators.Ordering) } func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, ordering *operators.Ordering) (logicalPlan, error) { - primitive := &engine.MemorySort{} + primitive := &engine.MemorySort{ + TruncateColumnCount: ordering.ResultColumns, + } ms := &memorySort{ resultsBuilder: resultsBuilder{ logicalPlanCommon: newBuilderCommon(src), @@ -97,7 +143,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin } for idx, order := range ordering.Order { - collationID := ctx.SemTable.CollationForExpr(order.WeightStrExpr) + collationID := ctx.SemTable.CollationForExpr(order.SimplifiedExpr) ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, engine.OrderByParams{ Col: ordering.Offset[idx], WeightStringCol: ordering.WOffset[idx], @@ -122,14 +168,44 @@ func transformProjection(ctx *plancontext.PlanningContext, op *operators.Project return useSimpleProjection(op, cols, src) } - expressions := slices2.Map(op.Columns, func(from operators.ProjExpr) sqlparser.Expr { + expressions := slices2.Map(op.Projections, func(from operators.ProjExpr) sqlparser.Expr { return from.GetExpr() }) + failed := false + evalengineExprs := slices2.Map(op.Projections, func(from operators.ProjExpr) evalengine.Expr { + switch e := from.(type) { + case operators.Eval: + return e.EExpr + case operators.Offset: + t := ctx.SemTable.ExprTypes[e.Expr] + return &evalengine.Column{ + Offset: e.Offset, + Type: t.Type, + Collation: 
collations.TypedCollation{}, + } + default: + failed = true + return nil + } + }) + var primitive *engine.Projection + columnNames := slices2.Map(op.Columns, func(from *sqlparser.AliasedExpr) string { + return from.ColumnName() + }) + + if !failed { + primitive = &engine.Projection{ + Cols: columnNames, + Exprs: evalengineExprs, + } + } + return &projection{ source: src, - columnNames: op.ColumnNames, + columnNames: columnNames, columns: expressions, + primitive: primitive, }, nil } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index e9d729cee28..b2f8a6a1766 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -100,6 +100,11 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { } } +func (qb *queryBuilder) addGroupBy(original sqlparser.Expr) { + sel := qb.sel.(*sqlparser.Select) + sel.GroupBy = append(sel.GroupBy, original) +} + func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) { sel := qb.sel.(*sqlparser.Select) sel.SelectExprs = append(sel.SelectExprs, projection) @@ -322,12 +327,41 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { return buildLimit(op, qb) case *Ordering: return buildOrdering(op, qb) + case *Aggregator: + return buildAggregation(op, qb) default: return vterrors.VT13001(fmt.Sprintf("do not know how to turn %T into SQL", op)) } return nil } +func buildAggregation(op *Aggregator, qb *queryBuilder) error { + err := buildQuery(op.Source, qb) + if err != nil { + return err + } + + qb.clearProjections() + + cols, err := op.GetColumns() + if err != nil { + return err + } + for _, column := range cols { + qb.addProjection(column) + } + + for _, by := range op.Grouping { + qb.addGroupBy(by.Inner) + simplified := by.SimplifiedExpr + if by.WSOffset != -1 { + qb.addGroupBy(weightStringFor(simplified)) + } + } + + return nil +} + func buildOrdering(op *Ordering, qb *queryBuilder) 
error { err := buildQuery(op.Source, qb) if err != nil { @@ -372,12 +406,8 @@ func buildProjection(op *Projection, qb *queryBuilder) error { qb.clearProjections() - for i, column := range op.Columns { - ae := &sqlparser.AliasedExpr{Expr: column.GetExpr()} - if op.ColumnNames[i] != "" { - ae.As = sqlparser.NewIdentifierCI(op.ColumnNames[i]) - } - qb.addProjection(ae) + for _, column := range op.Columns { + qb.addProjection(column) } // if the projection is on derived table, we use the select we have diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go new file mode 100644 index 00000000000..b9f08be73d9 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -0,0 +1,571 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "fmt" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func tryPushingDownAggregator(ctx *plancontext.PlanningContext, aggregator *Aggregator) (output ops.Operator, applyResult *rewrite.ApplyResult, err error) { + if aggregator.Pushed { + return aggregator, rewrite.SameTree, nil + } + aggregator.Pushed = true + switch src := aggregator.Source.(type) { + case *Route: + output, applyResult, err = pushDownAggregationThroughRoute(ctx, aggregator, src) + case *ApplyJoin: + output, applyResult, err = pushDownAggregationThroughJoin(ctx, aggregator, src) + default: + return aggregator, rewrite.SameTree, nil + } + + if applyResult != rewrite.SameTree && aggregator.Original { + aggregator.aggregateTheAggregates() + } + + return +} + +func (a *Aggregator) aggregateTheAggregates() { + for i, aggr := range a.Aggregations { + // Handle different aggregation operations when pushing down through a sharded route. + switch aggr.OpCode { + case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct: + // All count variations turn into SUM above the Route. + // Think of it as we are SUMming together a bunch of distributed COUNTs. 
+ aggr.OriginalOpCode, aggr.OpCode = aggr.OpCode, opcode.AggregateSum + a.Aggregations[i] = aggr + } + } +} + +func pushDownAggregationThroughRoute( + ctx *plancontext.PlanningContext, + aggregator *Aggregator, + route *Route, +) (ops.Operator, *rewrite.ApplyResult, error) { + // If the route is single-shard, or we are grouping by sharding keys, we can just push down the aggregation + if route.IsSingleShard() || overlappingUniqueVindex(ctx, aggregator.Grouping) { + return rewrite.Swap(aggregator, route, "pushDownAggregationThroughRoute") + } + + // Create a new aggregator to be placed below the route. + aggrBelowRoute := aggregator.Clone([]ops.Operator{route.Source}).(*Aggregator) + aggrBelowRoute.Pushed = false + aggrBelowRoute.Original = false + + // Set the source of the route to the new aggregator placed below the route. + route.Source = aggrBelowRoute + + if !aggregator.Original { + // we only keep the root aggregation, if this aggregator was created + // by splitting one and pushing under a join, we can get rid of this one + return aggregator.Source, rewrite.NewTree("push aggregation under route - remove original", aggregator), nil + } + + return aggregator, rewrite.NewTree("push aggregation under route - keep original", aggregator), nil +} + +func overlappingUniqueVindex(ctx *plancontext.PlanningContext, groupByExprs []GroupBy) bool { + for _, groupByExpr := range groupByExprs { + if exprHasUniqueVindex(ctx, groupByExpr.SimplifiedExpr) { + return true + } + } + return false +} + +func exprHasUniqueVindex(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { + return exprHasVindex(ctx, expr, true) +} + +func exprHasVindex(ctx *plancontext.PlanningContext, expr sqlparser.Expr, hasToBeUnique bool) bool { + col, isCol := expr.(*sqlparser.ColName) + if !isCol { + return false + } + ts := ctx.SemTable.RecursiveDeps(expr) + tableInfo, err := ctx.SemTable.TableInfoFor(ts) + if err != nil { + return false + } + vschemaTable := tableInfo.GetVindexTable() + for 
_, vindex := range vschemaTable.ColumnVindexes { + // TODO: Support composite vindexes (multicol, etc). + if len(vindex.Columns) > 1 || hasToBeUnique && !vindex.IsUnique() { + return false + } + if col.Name.Equal(vindex.Columns[0]) { + return true + } + } + return false +} + +/* +We push down aggregations using the logic from the paper Orthogonal Optimization of Subqueries and Aggregation, by +Cesar A. Galindo-Legaria and Milind M. Joshi from Microsoft Corp. + +It explains how one can split an aggregation into local aggregates that depend on only one side of the join. +The local aggregates can then be gathered together to produce the global +group by/aggregate query that the user asked for. + +In Vitess, this is particularly useful because it allows us to push aggregation down to the routes, even when +we have to join the results at the vtgate level. Instead of doing all the grouping and aggregation at the +vtgate level, we can offload most of the work to MySQL, and at the vtgate just summarize the results. 
+ +# For a query, such as + +select count(*) from R1 JOIN R2 on R1.id = R2.id + +Original: + + GB <- This is the original grouping, doing count(*) + | + JOIN + / \ + R1 R2 + +Transformed: + + rootAggr <- This grouping is now SUMing together the distributed `count(*)` we got back + | + Proj <- This projection makes sure that the columns are lined up as expected + | + Sort <- Here we are sorting the input so that the OrderedAggregate can do its thing + | + JOIN + / \ + lAggr rAggr + / \ + R1 R2 +*/ +func pushDownAggregationThroughJoin(ctx *plancontext.PlanningContext, rootAggr *Aggregator, join *ApplyJoin) (ops.Operator, *rewrite.ApplyResult, error) { + lhs := &joinPusher{ + orig: rootAggr, + pushed: &Aggregator{ + Source: join.LHS, + QP: rootAggr.QP, + }, + columns: initColReUse(len(rootAggr.Columns)), + tableID: TableID(join.LHS), + } + rhs := &joinPusher{ + orig: rootAggr, + pushed: &Aggregator{ + Source: join.RHS, + QP: rootAggr.QP, + }, + columns: initColReUse(len(rootAggr.Columns)), + tableID: TableID(join.RHS), + } + + joinColumns, output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, lhs, rhs) + if err != nil { + return nil, nil, err + } + + groupingJCs, err := splitGroupingToLeftAndRight(ctx, rootAggr, lhs, rhs) + if err != nil { + return nil, nil, err + } + joinColumns = append(joinColumns, groupingJCs...) 
+ + // We need to add any columns coming from the lhs of the join to the group by on that side + // If we don't, the LHS will not be able to return the column, and it can't be used to send down to the RHS + err = addColumnsFromLHSInJoinPredicates(ctx, rootAggr, join, lhs) + if err != nil { + return nil, nil, err + } + + join.LHS, join.RHS = lhs.pushed, rhs.pushed + join.ColumnsAST = joinColumns + + if !rootAggr.Original { + // we only keep the root aggregation, if this aggregator was created + // by splitting one and pushing under a join, we can get rid of this one + return output, rewrite.NewTree("push Aggregation under join - keep original", rootAggr), nil + } + + rootAggr.Source = output + return rootAggr, rewrite.NewTree("push Aggregation under join", rootAggr), nil +} + +func addColumnsFromLHSInJoinPredicates(ctx *plancontext.PlanningContext, rootAggr *Aggregator, join *ApplyJoin, lhs *joinPusher) error { + for _, pred := range join.JoinPredicates { + for _, expr := range pred.LHSExprs { + wexpr := rootAggr.QP.GetSimplifiedExpr(expr) + idx, found := canReuseColumn(ctx, lhs.pushed.Columns, expr, extractExpr) + if !found { + idx = len(lhs.pushed.Columns) + lhs.pushed.Columns = append(lhs.pushed.Columns, aeWrap(expr)) + } + _, found = canReuseColumn(ctx, lhs.pushed.Grouping, wexpr, func(by GroupBy) sqlparser.Expr { + return by.SimplifiedExpr + }) + + if found { + continue + } + + lhs.pushed.Grouping = append(lhs.pushed.Grouping, GroupBy{ + Inner: expr, + SimplifiedExpr: wexpr, + ColOffset: idx, + WSOffset: -1, + }) + } + } + return nil +} + +func splitGroupingToLeftAndRight(ctx *plancontext.PlanningContext, rootAggr *Aggregator, lhs, rhs *joinPusher) ([]JoinColumn, error) { + var groupingJCs []JoinColumn + + for _, groupBy := range rootAggr.Grouping { + deps := ctx.SemTable.RecursiveDeps(groupBy.Inner) + expr := groupBy.Inner + switch { + case deps.IsSolvedBy(lhs.tableID): + lhs.addGrouping(ctx, groupBy) + groupingJCs = append(groupingJCs, JoinColumn{ + Original: 
aeWrap(groupBy.Inner), + LHSExprs: []sqlparser.Expr{expr}, + }) + case deps.IsSolvedBy(rhs.tableID): + rhs.addGrouping(ctx, groupBy) + groupingJCs = append(groupingJCs, JoinColumn{ + Original: aeWrap(groupBy.Inner), + RHSExpr: expr, + }) + default: + return nil, vterrors.VT12001("grouping on columns from different sources") + } + } + return groupingJCs, nil +} + +// splitAggrColumnsToLeftAndRight pushes all aggregations on the aggregator above a join and +// pushes them to one or both sides of the join, and also provides the projections needed to re-assemble the +// aggregations that have been spread across the join +func splitAggrColumnsToLeftAndRight( + ctx *plancontext.PlanningContext, + aggregator *Aggregator, + join *ApplyJoin, + lhs, rhs *joinPusher, +) ([]JoinColumn, ops.Operator, error) { + builder := &aggBuilder{ + lhs: lhs, + rhs: rhs, + proj: &Projection{Source: join, FromAggr: true}, + outerJoin: join.LeftJoin, + } + +outer: + // we prefer adding the aggregations in the same order as the columns are declared + for colIdx, col := range aggregator.Columns { + for aggrIdx, aggr := range aggregator.Aggregations { + if aggr.ColOffset == colIdx { + aggrToKeep, err := builder.handleAggr(ctx, aggr) + if err != nil { + return nil, nil, err + } + aggregator.Aggregations[aggrIdx] = aggrToKeep + continue outer + } + } + builder.proj.addUnexploredExpr(col, col.Expr) + } + if builder.projectionRequired { + return builder.joinColumns, builder.proj, nil + } + + return builder.joinColumns, join, nil +} + +type ( + // aggBuilder is a helper struct that aids in pushing down an Aggregator through a join + // it accumulates the projections (if any) that need to be evaluated on top of the join + aggBuilder struct { + lhs, rhs *joinPusher + projectionRequired bool + joinColumns []JoinColumn + proj *Projection + outerJoin bool + } + // joinPusher is a helper struct that aids in pushing down an Aggregator into one side of a Join. 
+ // It creates a new Aggregator that is pushed down and keeps track of the column dependencies that the new Aggregator has. + joinPusher struct { + orig *Aggregator // The original Aggregator before pushing. + pushed *Aggregator // The new Aggregator created for push-down. + columns []int // List of column offsets used in the new Aggregator. + tableID semantics.TableSet // The TableSet denoting the side of the Join where the new Aggregator is pushed. + + // csAE keeps the copy of the countStar expression that has already been added to split an aggregation. + // No need to have multiple countStars, so we cache it here + csAE *sqlparser.AliasedExpr + } +) + +func (ab *aggBuilder) leftCountStar(ctx *plancontext.PlanningContext) *sqlparser.AliasedExpr { + ae, created := ab.lhs.countStar(ctx) + if created { + ab.joinColumns = append(ab.joinColumns, JoinColumn{ + Original: ae, + LHSExprs: []sqlparser.Expr{ae.Expr}, + }) + } + return ae +} + +func (ab *aggBuilder) rightCountStar(ctx *plancontext.PlanningContext) *sqlparser.AliasedExpr { + ae, created := ab.rhs.countStar(ctx) + if created { + ab.joinColumns = append(ab.joinColumns, JoinColumn{ + Original: ae, + RHSExpr: ae.Expr, + }) + } + return ae +} + +func (p *joinPusher) countStar(ctx *plancontext.PlanningContext) (*sqlparser.AliasedExpr, bool) { + if p.csAE != nil { + return p.csAE, false + } + cs := &sqlparser.CountStar{} + ae := aeWrap(cs) + csAggr := Aggr{ + Original: ae, + Func: cs, + OpCode: opcode.AggregateCountStar, + } + expr := p.addAggr(ctx, csAggr) + p.csAE = aeWrap(expr) + return p.csAE, true +} + +func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) { + switch aggr.OpCode { + case opcode.AggregateCountStar: + return ab.handleCountStar(ctx, aggr) + case opcode.AggregateMax, opcode.AggregateMin, opcode.AggregateRandom: + return ab.handlePushThroughAggregation(ctx, aggr) + case opcode.AggregateCount: + return ab.handleCount(ctx, aggr) + + case 
opcode.AggregateUnassigned: + return Aggr{}, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original))) + default: + return Aggr{}, errHorizonNotPlanned() + } +} + +// pushThroughLeft and pushThroughRight are used for extremums and random aggregations, +// which, unlike counts, are not split into pieces that must be recombined with arithmetic to aggregate the per-shard results. +// For these, we just copy the aggregation to one side of the join and then pick the max of the maxes returned +func (ab *aggBuilder) pushThroughLeft(aggr Aggr) { + ab.lhs.pushThroughAggr(aggr) + ab.joinColumns = append(ab.joinColumns, JoinColumn{ + Original: aggr.Original, + LHSExprs: []sqlparser.Expr{aggr.Original.Expr}, + }) +} +func (ab *aggBuilder) pushThroughRight(aggr Aggr) { + ab.rhs.pushThroughAggr(aggr) + ab.joinColumns = append(ab.joinColumns, JoinColumn{ + Original: aggr.Original, + RHSExpr: aggr.Original.Expr, + }) +} + +func (ab *aggBuilder) handlePushThroughAggregation(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) { + ab.proj.addUnexploredExpr(aggr.Original, aggr.Original.Expr) + + deps := ctx.SemTable.RecursiveDeps(aggr.Original.Expr) + switch { + case deps.IsSolvedBy(ab.lhs.tableID): + ab.pushThroughLeft(aggr) + return aggr, nil + case deps.IsSolvedBy(ab.rhs.tableID): + ab.pushThroughRight(aggr) + return aggr, nil + default: + return Aggr{}, vterrors.VT12001("aggregation on columns from different sources: " + sqlparser.String(aggr.Original.Expr)) + } +} + +func (ab *aggBuilder) handleCountStar(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) { + // Projection is necessary since we are going to need to do arithmetic to summarize the aggregates + ab.projectionRequired = true + + // Add the aggregate to both sides of the join. + lhsAE := ab.leftCountStar(ctx) + rhsAE := ab.rightCountStar(ctx) + + // We expect the expressions to be different on each side of the join, otherwise it's an error. 
+ if lhsAE.Expr == rhsAE.Expr { + panic(fmt.Sprintf("Need the two produced expressions to be different. %T %T", lhsAE, rhsAE)) + } + + rhsExpr := rhsAE.Expr + + // When dealing with outer joins, we don't want null values from the RHS to ruin the calculations we are doing, + // so we use the MySQL `coalesce` after the join is applied to multiply the count from LHS with 1. + if ab.outerJoin { + rhsExpr = coalesceFunc(rhsExpr) + } + + // The final COUNT is obtained by multiplying the counts from both sides. + // This is equivalent to transforming a "select count(*) from t1 join t2" into + // "select count_t1*count_t2 from + // (select count(*) as count_t1 from t1) as x, + // (select count(*) as count_t2 from t2) as y". + projExpr := &sqlparser.BinaryExpr{ + Operator: sqlparser.MultOp, + Left: lhsAE.Expr, + Right: rhsExpr, + } + projAE := &sqlparser.AliasedExpr{ + Expr: aggr.Original.Expr, + As: sqlparser.NewIdentifierCI(aggr.Original.ColumnName()), + } + + ab.proj.addUnexploredExpr(projAE, projExpr) + return aggr, nil +} + +func (ab *aggBuilder) handleCount(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) { + ab.projectionRequired = true + + expr := aggr.Original.Expr + deps := ctx.SemTable.RecursiveDeps(expr) + var otherSide sqlparser.Expr + + switch { + case deps.IsSolvedBy(ab.lhs.tableID): + ab.pushThroughLeft(aggr) + ae := ab.rightCountStar(ctx) + otherSide = ae.Expr + + case deps.IsSolvedBy(ab.rhs.tableID): + ab.pushThroughRight(aggr) + ae := ab.leftCountStar(ctx) + otherSide = ae.Expr + + default: + return Aggr{}, errHorizonNotPlanned() + } + + if ab.outerJoin { + otherSide = coalesceFunc(otherSide) + } + + projAE := &sqlparser.AliasedExpr{ + Expr: aggr.Original.Expr, + As: sqlparser.NewIdentifierCI(aggr.Original.ColumnName()), + } + ab.proj.addUnexploredExpr(projAE, &sqlparser.BinaryExpr{ + Operator: sqlparser.MultOp, + Left: expr, + Right: otherSide, + }) + return aggr, nil +} + +func coalesceFunc(e sqlparser.Expr) sqlparser.Expr { + // 
`coalesce(e,1)` will return `e` if `e` is not `NULL`, otherwise it will return `1` + return &sqlparser.FuncExpr{ + Name: sqlparser.NewIdentifierCI("coalesce"), + Exprs: sqlparser.SelectExprs{ + aeWrap(e), + aeWrap(sqlparser.NewIntLiteral("1")), + }, + } +} + +// addAggr creates a copy of the given aggregation, updates its column offset to point to the correct location in the new Aggregator, +// and adds it to the list of Aggregations of the new Aggregator. It also updates the semantic analysis information to reflect the new structure. +// It returns the expression of the aggregation as it should be used in the parent Aggregator. +func (p *joinPusher) addAggr(ctx *plancontext.PlanningContext, aggr Aggr) sqlparser.Expr { + copyAggr := aggr + expr := sqlparser.CloneExpr(aggr.Original.Expr) + copyAggr.Original = aeWrap(expr) + // copy dependencies so we can keep track of which side expressions need to be pushed to + ctx.SemTable.Direct[expr] = p.tableID + ctx.SemTable.Recursive[expr] = p.tableID + copyAggr.ColOffset = len(p.pushed.Columns) + p.pushed.Columns = append(p.pushed.Columns, copyAggr.Original) + p.pushed.Aggregations = append(p.pushed.Aggregations, copyAggr) + return expr +} + +// pushThroughAggr pushes through an aggregation without changing dependencies. +// Can be used for aggregations we can push in one piece +func (p *joinPusher) pushThroughAggr(aggr Aggr) { + p.pushed.Columns = append(p.pushed.Columns, aggr.Original) + p.pushed.Aggregations = append(p.pushed.Aggregations, aggr) +} + +// addGrouping creates a copy of the given GroupBy, updates its column offset to point to the correct location in the new Aggregator, +// and adds it to the list of GroupBy expressions of the new Aggregator. It also updates the semantic analysis information to reflect the new structure. +// It returns the expression of the GroupBy as it should be used in the parent Aggregator. 
+func (p *joinPusher) addGrouping(ctx *plancontext.PlanningContext, gb GroupBy) sqlparser.Expr { + copyGB := gb + expr := sqlparser.CloneExpr(gb.Inner) + // copy dependencies so we can keep track of which side expressions need to be pushed to + ctx.SemTable.CopyDependencies(gb.Inner, expr) + // if the column exists in the selection then copy it down to the pushed aggregator operator. + if copyGB.ColOffset != -1 { + offset := p.useColumn(copyGB.ColOffset) + copyGB.ColOffset = offset + } + p.pushed.Grouping = append(p.pushed.Grouping, copyGB) + return expr +} + +// useColumn checks whether the column corresponding to the given offset has been used in the new Aggregator. +// If it has not been used before, it adds the column to the new Aggregator +// and updates the columns mapping to reflect the new location of the column. +// It returns the offset of the column in the new Aggregator. +func (p *joinPusher) useColumn(offset int) int { + if p.columns[offset] == -1 { + p.columns[offset] = len(p.pushed.Columns) + // still haven't used this expression on this side + p.pushed.Columns = append(p.pushed.Columns, p.orig.Columns[offset]) + } + return p.columns[offset] +} + +func initColReUse(size int) []int { + cols := make([]int, size) + for i := 0; i < size; i++ { + cols[i] = -1 + } + return cols +} + +func extractExpr(expr *sqlparser.AliasedExpr) sqlparser.Expr { return expr.Expr } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go new file mode 100644 index 00000000000..b49a0cb1b07 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -0,0 +1,248 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "strings" + + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/semantics" + + "golang.org/x/exp/slices" + + "vitess.io/vitess/go/vt/vtgate/engine/opcode" + + "vitess.io/vitess/go/slices2" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +type ( + // Aggregator represents a GroupBy γ relational operator. + // Both all aggregations and no grouping, and the inverse + // of all grouping and no aggregations are valid configurations of this operator + Aggregator struct { + Source ops.Operator + Columns []*sqlparser.AliasedExpr + + Grouping []GroupBy + Aggregations []Aggr + + // Pushed will be set to true once this aggregation has been pushed deeper in the tree + Pushed bool + + // Original will only be true for the original aggregator created from the AST + Original bool + ResultColumns int + + QP *QueryProjection + // TableID will be non-nil for derived tables + TableID *semantics.TableSet + Alias string + } +) + +func (a *Aggregator) Clone(inputs []ops.Operator) ops.Operator { + return &Aggregator{ + Source: inputs[0], + Columns: slices.Clone(a.Columns), + Grouping: slices.Clone(a.Grouping), + Aggregations: slices.Clone(a.Aggregations), + Pushed: a.Pushed, + Original: a.Original, + ResultColumns: a.ResultColumns, + QP: a.QP, + } +} + +func (a *Aggregator) Inputs() []ops.Operator { + return []ops.Operator{a.Source} +} + +func (a *Aggregator) SetInputs(operators []ops.Operator) { + if len(operators) != 1 
{ + panic(fmt.Sprintf("unexpected number of operators as input in aggregator: %d", len(operators))) + } + a.Source = operators[0] +} + +func (a *Aggregator) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (ops.Operator, error) { + newOp, err := a.Source.AddPredicate(ctx, expr) + if err != nil { + return nil, err + } + a.Source = newOp + return a, nil +} + +func (a *Aggregator) addNoPushCol(expr *sqlparser.AliasedExpr, addToGroupBy bool) int { + offset := len(a.Columns) + a.Columns = append(a.Columns, expr) + + if addToGroupBy { + groupBy := NewGroupBy(expr.Expr, expr.Expr, expr) + groupBy.ColOffset = offset + a.Grouping = append(a.Grouping, groupBy) + } else { + a.Aggregations = append(a.Aggregations, Aggr{ + Original: expr, + Func: nil, + OpCode: opcode.AggregateRandom, + Alias: expr.As.String(), + ColOffset: offset, + }) + } + return offset +} + +func (a *Aggregator) isDerived() bool { + return a.TableID != nil +} + +func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { + if addToGroupBy { + return nil, 0, vterrors.VT13001("did not expect to add group by here") + } + if offset, found := canReuseColumn(ctx, a.Columns, expr.Expr, extractExpr); found { + return a, offset, nil + } + colName, isColName := expr.Expr.(*sqlparser.ColName) + for i, col := range a.Columns { + if isColName && colName.Name.EqualString(col.As.String()) { + return a, i, nil + } + } + + // If a weight string function is received from the operator above, check whether we group on the expression it wraps. + // If such a grouping is found, continue pushing the column down, but with addToGroupBy set to true so that it is added to the GROUP BY SQL further down in AddColumn. + // This also sets the weight string column offset so that we do not need to add it later in the aggregator operator's planOffsets. 
+ + // If the expression is a WeightStringFuncExpr, it checks if a GroupBy + // already exists on the argument of the expression. + // If it is found, the column offset for the WeightStringFuncExpr is set, + // and the column is marked to be added to the GroupBy in the recursive AddColumn call. + if wsExpr, isWS := expr.Expr.(*sqlparser.WeightStringFuncExpr); isWS { + idx := slices.IndexFunc(a.Grouping, func(by GroupBy) bool { + return ctx.SemTable.EqualsExprWithDeps(wsExpr.Expr, by.SimplifiedExpr) + }) + if idx >= 0 { + a.Grouping[idx].WSOffset = len(a.Columns) + addToGroupBy = true + } + } + + if !addToGroupBy { + a.Aggregations = append(a.Aggregations, Aggr{ + Original: expr, + Func: nil, + OpCode: opcode.AggregateRandom, + Alias: expr.As.String(), + ColOffset: len(a.Columns), + }) + } + a.Columns = append(a.Columns, expr) + expectedOffset := len(a.Columns) - 1 + newSrc, offset, err := a.Source.AddColumn(ctx, expr, false, addToGroupBy) + if err != nil { + return nil, 0, err + } + if offset != expectedOffset { + return nil, 0, vterrors.VT13001("the offset needs to be aligned here") + } + a.Source = newSrc + return a, offset, nil +} + +func (a *Aggregator) GetColumns() (columns []*sqlparser.AliasedExpr, err error) { + return a.Columns, nil +} + +func (a *Aggregator) Description() ops.OpDescription { + return ops.OpDescription{ + OperatorType: "Aggregator", + } +} + +func (a *Aggregator) ShortDescription() string { + columnns := slices2.Map(a.Columns, func(from *sqlparser.AliasedExpr) string { + return sqlparser.String(from) + }) + + if len(a.Grouping) == 0 { + return strings.Join(columnns, ", ") + } + + var grouping []string + for _, gb := range a.Grouping { + grouping = append(grouping, sqlparser.String(gb.SimplifiedExpr)) + } + + org := "" + if a.Original { + org = "ORG " + } + + return fmt.Sprintf("%s%s group by %s", org, strings.Join(columnns, ", "), strings.Join(grouping, ",")) +} + +func (a *Aggregator) GetOrdering() ([]ops.OrderBy, error) { + return 
a.Source.GetOrdering() +} + +var _ ops.Operator = (*Aggregator)(nil) + +func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) error { + addColumn := func(aliasedExpr *sqlparser.AliasedExpr, addToGroupBy bool) (int, error) { + newSrc, offset, err := a.Source.AddColumn(ctx, aliasedExpr, true, addToGroupBy) + if err != nil { + return 0, err + } + a.Source = newSrc + if offset == len(a.Columns) { + // if we get an offset at the end of our current column list, it means we added a new column + a.Columns = append(a.Columns, aliasedExpr) + } + return offset, nil + } + + for idx, gb := range a.Grouping { + if gb.ColOffset == -1 { + offset, err := addColumn(aeWrap(gb.Inner), false) + if err != nil { + return err + } + a.Grouping[idx].ColOffset = offset + } + if a.Grouping[idx].WSOffset != -1 || !ctx.SemTable.NeedsWeightString(gb.SimplifiedExpr) { + continue + } + + offset, err := addColumn(aeWrap(weightStringFor(gb.SimplifiedExpr)), true) + if err != nil { + return err + } + a.Grouping[idx].WSOffset = offset + } + + return nil +} + +func (a *Aggregator) setTruncateColumnCount(offset int) { + a.ResultColumns = offset +} diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 277a4aa81d4..6f67e549a7a 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -17,6 +17,9 @@ limitations under the License. 
package operators import ( + "fmt" + "strings" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -73,6 +76,7 @@ type JoinColumn struct { BvNames []string // the BvNames and LHSCols line up LHSExprs []sqlparser.Expr RHSExpr sqlparser.Expr + GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true } func NewApplyJoin(lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin { @@ -157,8 +161,8 @@ func (a *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlp return nil } -func (a *ApplyJoin) pushColLeft(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr) (int, error) { - newLHS, offset, err := a.LHS.AddColumn(ctx, e) +func (a *ApplyJoin) pushColLeft(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr, addToGroupBy bool) (int, error) { + newLHS, offset, err := a.LHS.AddColumn(ctx, e, true, addToGroupBy) if err != nil { return 0, err } @@ -166,8 +170,8 @@ func (a *ApplyJoin) pushColLeft(ctx *plancontext.PlanningContext, e *sqlparser.A return offset, nil } -func (a *ApplyJoin) pushColRight(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr) (int, error) { - newRHS, offset, err := a.RHS.AddColumn(ctx, e) +func (a *ApplyJoin) pushColRight(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr, addToGroupBy bool) (int, error) { + newRHS, offset, err := a.RHS.AddColumn(ctx, e, true, addToGroupBy) if err != nil { return 0, err } @@ -191,7 +195,7 @@ func joinColumnToExpr(column JoinColumn) sqlparser.Expr { return column.Original.Expr } -func (a *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr) (col JoinColumn, err error) { +func (a *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr, addToGroupBy bool) (col JoinColumn, err error) { defer func() { col.Original = e }() @@ -200,6 +204,7 @@ func (a *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, e *sqlpar both := 
lhs.Merge(rhs) expr := e.Expr deps := ctx.SemTable.RecursiveDeps(expr) + col.GroupBy = addToGroupBy switch { case deps.IsSolvedBy(lhs): @@ -218,11 +223,11 @@ func (a *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, e *sqlpar return } -func (a *ApplyJoin) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (a *ApplyJoin) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { if offset, found := canReuseColumn(ctx, a.ColumnsAST, expr.Expr, joinColumnToExpr); found { return a, offset, nil } - col, err := a.getJoinColumnFor(ctx, expr) + col, err := a.getJoinColumnFor(ctx, expr, addToGroupBy) if err != nil { return nil, 0, err } @@ -234,19 +239,19 @@ func (a *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) (err error) { for _, col := range a.ColumnsAST { // Read the type description for JoinColumn to understand the following code for i, lhsExpr := range col.LHSExprs { - offset, err := a.pushColLeft(ctx, aeWrap(lhsExpr)) + offset, err := a.pushColLeft(ctx, aeWrap(lhsExpr), col.GroupBy) if err != nil { return err } if col.RHSExpr == nil { - // if we don't have a RHS expr, it means that this is a pure LHS expression + // if we don't have an RHS expr, it means that this is a pure LHS expression a.addOffset(-offset - 1) } else { a.Vars[col.BvNames[i]] = offset } } if col.RHSExpr != nil { - offset, err := a.pushColRight(ctx, aeWrap(col.RHSExpr)) + offset, err := a.pushColRight(ctx, aeWrap(col.RHSExpr), col.GroupBy) if err != nil { return err } @@ -256,7 +261,7 @@ func (a *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) (err error) { for _, col := range a.JoinPredicates { for i, lhsExpr := range col.LHSExprs { - offset, err := a.pushColLeft(ctx, aeWrap(lhsExpr)) + offset, err := a.pushColLeft(ctx, aeWrap(lhsExpr), false) if err != nil { return err } @@ -298,7 +303,11 @@ func (a *ApplyJoin) Description() ops.OpDescription { 
} func (a *ApplyJoin) ShortDescription() string { - return sqlparser.String(a.Predicate) + pred := sqlparser.String(a.Predicate) + columns := slices2.Map(a.ColumnsAST, func(from JoinColumn) string { + return sqlparser.String(from.Original) + }) + return fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", ")) } func (jc JoinColumn) IsPureLeft() bool { diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index cda98da33dd..cd4f30f6d01 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -42,7 +42,7 @@ func (d *Delete) Introduces() semantics.TableSet { } // Clone implements the Operator interface -func (d *Delete) Clone(inputs []ops.Operator) ops.Operator { +func (d *Delete) Clone([]ops.Operator) ops.Operator { return &Delete{ QTable: d.QTable, VTable: d.VTable, diff --git a/go/vt/vtgate/planbuilder/operators/derived.go b/go/vt/vtgate/planbuilder/operators/derived.go index d12a928d81c..fd1d57cc63a 100644 --- a/go/vt/vtgate/planbuilder/operators/derived.go +++ b/go/vt/vtgate/planbuilder/operators/derived.go @@ -21,10 +21,9 @@ import ( "golang.org/x/exp/slices" - "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -158,7 +157,7 @@ func canBePushedDownIntoDerived(expr sqlparser.Expr) (canBePushed bool) { return } -func (d *Derived) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (d *Derived) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { col, ok := expr.Expr.(*sqlparser.ColName) if !ok { return nil, 0, vterrors.VT13001("cannot push non-colname expression to a derived table") @@ 
-179,7 +178,7 @@ func (d *Derived) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.Al d.Columns = append(d.Columns, col) // add it to the source if we were not already passing it through if i <= -1 { - newSrc, _, err := d.Source.AddColumn(ctx, aeWrap(sqlparser.NewColName(col.Name.String()))) + newSrc, _, err := d.Source.AddColumn(ctx, aeWrap(sqlparser.NewColName(col.Name.String())), true, addToGroupBy) if err != nil { return nil, 0, err } @@ -197,7 +196,7 @@ func canReuseColumn[T any]( f func(T) sqlparser.Expr, ) (offset int, found bool) { for offset, column := range columns { - if ctx.SemTable.EqualsExpr(col, f(column)) { + if ctx.SemTable.EqualsExprWithDeps(col, f(column)) { return offset, true } } diff --git a/go/vt/vtgate/planbuilder/operators/dml_planning.go b/go/vt/vtgate/planbuilder/operators/dml_planning.go index eec13340f4c..a9c5c4b8871 100644 --- a/go/vt/vtgate/planbuilder/operators/dml_planning.go +++ b/go/vt/vtgate/planbuilder/operators/dml_planning.go @@ -46,7 +46,7 @@ func getVindexInformation( var vindexesAndPredicates []*VindexPlusPredicates for _, colVindex := range table.Ordered { if lu, isLu := colVindex.Vindex.(vindexes.LookupBackfill); isLu && lu.IsBackfilling() { - // Checking if the Vindex is currently backfilling or not, if it isn't we can read from the vindex table + // Checking if the Vindex is currently backfilling or not, if it isn't we can read from the vindex table, // and we will be able to do a delete equal. Otherwise, we continue to look for next best vindex. 
continue } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 896ac18ec22..246a6702142 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -31,11 +31,13 @@ func BreakExpressionInLHSandRHS( lhs semantics.TableSet, ) (col JoinColumn, err error) { rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { - node, ok := cursor.Node().(*sqlparser.ColName) - if !ok { + node := cursor.Node() + reservedName := getReservedBVName(node) + if reservedName == "" { return } - deps := ctx.SemTable.RecursiveDeps(node) + nodeExpr := node.(sqlparser.Expr) + deps := ctx.SemTable.RecursiveDeps(nodeExpr) if deps.IsEmpty() { err = vterrors.VT13001("unknown column. has the AST been copied?") cursor.StopTreeWalk() @@ -45,14 +47,16 @@ func BreakExpressionInLHSandRHS( return } - node.Qualifier.Qualifier = sqlparser.NewIdentifierCS("") - col.LHSExprs = append(col.LHSExprs, node) - bvName := node.CompliantName() + col.LHSExprs = append(col.LHSExprs, nodeExpr) + bvName := ctx.GetArgumentFor(nodeExpr, func() string { + return ctx.ReservedVars.ReserveVariable(reservedName) + }) + col.BvNames = append(col.BvNames, bvName) arg := sqlparser.NewArgument(bvName) // we are replacing one of the sides of the comparison with an argument, // but we don't want to lose the type information we have, so we copy it over - ctx.SemTable.CopyExprInfo(node, arg) + ctx.SemTable.CopyExprInfo(nodeExpr, arg) cursor.Replace(arg) }, nil).(sqlparser.Expr) @@ -63,3 +67,14 @@ func BreakExpressionInLHSandRHS( col.RHSExpr = rewrittenExpr return } + +func getReservedBVName(node sqlparser.SQLNode) string { + switch node := node.(type) { + case *sqlparser.ColName: + node.Qualifier.Qualifier = sqlparser.NewIdentifierCS("") + return node.CompliantName() + case sqlparser.AggrFunc: + return sqlparser.CompliantString(node) + } + return "" +} diff --git 
a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index db45325825f..f2f07b3ee19 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -82,8 +82,8 @@ func (f *Filter) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.E return f, nil } -func (f *Filter) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { - newSrc, offset, err := f.Source.AddColumn(ctx, expr) +func (f *Filter) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, reuseExisting, addToGroupBy bool) (ops.Operator, int, error) { + newSrc, offset, err := f.Source.AddColumn(ctx, expr, reuseExisting, addToGroupBy) if err != nil { return nil, 0, err } @@ -99,9 +99,9 @@ func (f *Filter) GetOrdering() ([]ops.OrderBy, error) { return f.Source.GetOrdering() } -func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) { +func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { if len(f.Predicates) == 0 { - return f.Source, rewrite.NewTree, nil + return f.Source, rewrite.NewTree("filter with no predicates removed", f), nil } other, isFilter := f.Source.(*Filter) @@ -110,12 +110,12 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.Ap } f.Source = other.Source f.Predicates = append(f.Predicates, other.Predicates...) 
- return f, rewrite.NewTree, nil + return f, rewrite.NewTree("two filters merged into one", f), nil } func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) error { resolveColumn := func(col *sqlparser.ColName) (int, error) { - newSrc, offset, err := f.Source.AddColumn(ctx, aeWrap(col)) + newSrc, offset, err := f.Source.AddColumn(ctx, aeWrap(col), true, false) if err != nil { return 0, err } diff --git a/go/vt/vtgate/planbuilder/operators/helpers.go b/go/vt/vtgate/planbuilder/operators/helpers.go index 172e47cae74..1c472acf413 100644 --- a/go/vt/vtgate/planbuilder/operators/helpers.go +++ b/go/vt/vtgate/planbuilder/operators/helpers.go @@ -32,16 +32,16 @@ import ( func compact(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Operator, error) { type compactable interface { // Compact implement this interface for operators that have easy to see optimisations - Compact(ctx *plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) + Compact(ctx *plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) } - newOp, err := rewrite.BottomUpAll(op, TableID, func(op ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, rewrite.ApplyResult, error) { + newOp, err := rewrite.BottomUp(op, TableID, func(op ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { newOp, ok := op.(compactable) if !ok { return op, rewrite.SameTree, nil } return newOp.Compact(ctx) - }) + }, stopAtRoute) return newOp, err } @@ -103,7 +103,7 @@ func TablesUsed(op ops.Operator) []string { func UnresolvedPredicates(op ops.Operator, st *semantics.SemTable) (result []sqlparser.Expr) { type unresolved interface { // UnsolvedPredicates returns any predicates that have dependencies on the given Operator and - // on the outside of it (a parent Select expression, any other table not used by Operator, etc). + // on the outside of it (a parent Select expression, any other table not used by Operator, etc.). 
// This is used for sub-queries. An example query could be: // SELECT * FROM tbl WHERE EXISTS (SELECT 1 FROM otherTbl WHERE tbl.col = otherTbl.col) // The subquery would have one unsolved predicate: `tbl.col = otherTbl.col` @@ -148,14 +148,6 @@ func QualifiedString(ks *vindexes.Keyspace, s string) string { return fmt.Sprintf("%s.%s", ks.Name, s) } -func QualifiedStrings(ks *vindexes.Keyspace, ss []string) []string { - add, collect := collectSortedUniqueStrings() - for _, s := range ss { - add(QualifiedString(ks, s)) - } - return collect() -} - func QualifiedTableName(ks *vindexes.Keyspace, t sqlparser.TableName) string { return QualifiedIdentifier(ks, t.Name) } @@ -184,10 +176,6 @@ func SingleQualifiedString(ks *vindexes.Keyspace, s string) []string { return []string{QualifiedString(ks, s)} } -func SingleQualifiedTableName(ks *vindexes.Keyspace, t sqlparser.TableName) []string { - return SingleQualifiedIdentifier(ks, t.Name) -} - func collectSortedUniqueStrings() (add func(string), collect func() []string) { uniq := make(map[string]any) add = func(v string) { diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 94581e279ab..39efe2b9956 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -36,7 +36,7 @@ type Horizon struct { QP *QueryProjection } -func (h *Horizon) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (h *Horizon) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr, bool, bool) (ops.Operator, int, error) { return nil, 0, vterrors.VT13001("the Horizon operator cannot accept new columns") } diff --git a/go/vt/vtgate/planbuilder/operators/horizon_planning.go b/go/vt/vtgate/planbuilder/operators/horizon_planning.go index 5d3f700c65e..766783bf2e1 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_planning.go @@ 
-17,6 +17,10 @@ limitations under the License. package operators import ( + "fmt" + + "vitess.io/vitess/go/slices2" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" @@ -32,25 +36,87 @@ func errHorizonNotPlanned() error { var _errHorizonNotPlanned = vterrors.VT12001("query cannot be fully operator planned") +func tryHorizonPlanning(ctx *plancontext.PlanningContext, root ops.Operator) (output ops.Operator, err error) { + backup := Clone(root) + defer func() { + // If we encounter the _errHorizonNotPlanned error, we'll revert to using the old horizon planning strategy. + if err == _errHorizonNotPlanned { + // The only offset planning we did before was on joins. + // Therefore, we traverse the tree to find all joins and calculate the joinColumns offsets. + // Our fallback strategy is to clone the original operator tree, compute the join offsets, + // and allow the legacy horizonPlanner to handle this query using logical plans. + err = planOffsetsOnJoins(ctx, backup) + if err == nil { + output = backup + } + } + }() + + _, ok := root.(*Horizon) + + if !ok || len(ctx.SemTable.SubqueryMap) > 0 || len(ctx.SemTable.SubqueryRef) > 0 { + // we are not ready to deal with subqueries yet + return root, errHorizonNotPlanned() + } + + output, err = planHorizons(ctx, root) + if err != nil { + return nil, err + } + + output, err = planOffsets(ctx, output) + if err != nil { + return nil, err + } + + output, err = makeSureOutputIsCorrect(ctx, root, output) + if err != nil { + return nil, err + } + + return +} + // planHorizons is the process of figuring out how to perform the operations in the Horizon // If we can push it under a route - done. 
// If we can't, we will instead expand the Horizon into // smaller operators and try to push these down as far as possible func planHorizons(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { - visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, rewrite.ApplyResult, error) { + var err error + root, err = optimizeHorizonPlanning(ctx, root) + if err != nil { + return nil, err + } + + // Adding Ordering Op - This is needed if there is no explicit ordering and aggregation is performed on top of route. + // Adding Group by - This is needed if the grouping is performed on a join with a join condition then + // aggregation happening at route needs a group by to ensure only matching rows returns + // the aggregations otherwise returns no result. + root, err = addOrderBysAndGroupBysForAggregations(ctx, root) + if err != nil { + return nil, err + } + + root, err = optimizeHorizonPlanning(ctx, root) + if err != nil { + return nil, err + } + return root, nil +} + +func optimizeHorizonPlanning(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { + visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, *rewrite.ApplyResult, error) { switch in := in.(type) { case horizonLike: - op, err := pushOrExpandHorizon(ctx, in) - if err != nil { - return nil, false, err - } - return op, rewrite.NewTree, nil + return pushOrExpandHorizon(ctx, in) case *Projection: return tryPushingDownProjection(ctx, in) case *Limit: return tryPushingDownLimit(in) case *Ordering: return tryPushingDownOrdering(ctx, in) + case *Aggregator: + return tryPushingDownAggregator(ctx, in) default: return in, rewrite.SameTree, nil } @@ -68,21 +134,147 @@ func planHorizons(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Oper return newOp, nil } -func tryPushingDownOrdering(ctx *plancontext.PlanningContext, in *Ordering) (ops.Operator, rewrite.ApplyResult, error) { +func 
addOrderBysAndGroupBysForAggregations(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { + visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, *rewrite.ApplyResult, error) { + switch in := in.(type) { + case *Aggregator: + requireOrdering, err := needsOrdering(in, ctx) + if err != nil { + return nil, nil, err + } + if !requireOrdering { + return in, rewrite.SameTree, nil + } + in.Source = &Ordering{ + Source: in.Source, + Order: slices2.Map(in.Grouping, func(from GroupBy) ops.OrderBy { + return from.AsOrderBy() + }), + } + return in, rewrite.NewTree("added ordering before aggregation", in), nil + case *ApplyJoin: + _ = rewrite.Visit(in.RHS, func(op ops.Operator) error { + aggr, isAggr := op.(*Aggregator) + if !isAggr { + return nil + } + if len(aggr.Grouping) == 0 { + gb := sqlparser.NewIntLiteral(".0") + aggr.Grouping = append(aggr.Grouping, NewGroupBy(gb, gb, aeWrap(gb))) + } + return nil + }) + } + return in, rewrite.SameTree, nil + } + + return rewrite.TopDown(root, TableID, visitor, stopAtRoute) +} + +func needsOrdering(in *Aggregator, ctx *plancontext.PlanningContext) (bool, error) { + if len(in.Grouping) == 0 { + return false, nil + } + srcOrdering, err := in.Source.GetOrdering() + if err != nil { + return false, err + } + if len(srcOrdering) < len(in.Grouping) { + return true, nil + } + for idx, gb := range in.Grouping { + if !ctx.SemTable.EqualsExprWithDeps(srcOrdering[idx].SimplifiedExpr, gb.SimplifiedExpr) { + return true, nil + } + } + return false, nil +} + +func tryPushingDownOrdering(ctx *plancontext.PlanningContext, in *Ordering) (ops.Operator, *rewrite.ApplyResult, error) { switch src := in.Source.(type) { case *Route: - return swap(in, src) + return rewrite.Swap(in, src, "push ordering under route") case *ApplyJoin: if canPushLeft(ctx, src, in.Order) { // ApplyJoin is stable in regard to the columns coming from the LHS, // so if all the ordering columns come from the LHS, we can push down 
the Ordering there src.LHS, in.Source = in, src.LHS - return src, rewrite.NewTree, nil + return src, rewrite.NewTree("push down ordering on the LHS of a join", in), nil + } + case *Ordering: + // we'll just remove the order underneath. The top order replaces whatever was incoming + in.Source = src.Source + return in, rewrite.NewTree("remove double ordering", src), nil + case *Projection: + // we can move ordering under a projection if it's not introducing a column we're sorting by + for _, by := range in.Order { + if !fetchByOffset(by.SimplifiedExpr) { + return in, rewrite.SameTree, nil + } } + return rewrite.Swap(in, src, "push ordering under projection") + case *Aggregator: + if !src.QP.AlignGroupByAndOrderBy(ctx) { + return in, rewrite.SameTree, nil + } + + return pushOrderingUnderAggr(ctx, in, src) + } return in, rewrite.SameTree, nil } +func pushOrderingUnderAggr(ctx *plancontext.PlanningContext, order *Ordering, aggregator *Aggregator) (ops.Operator, *rewrite.ApplyResult, error) { + // Step 1: Align the GROUP BY and ORDER BY. + // Reorder the GROUP BY columns to match the ORDER BY columns. + // Since the GB clause is a set, we can reorder these columns freely. + var newGrouping []GroupBy + used := make([]bool, len(aggregator.Grouping)) + for _, orderExpr := range order.Order { + for grpIdx, by := range aggregator.Grouping { + if !used[grpIdx] && ctx.SemTable.EqualsExprWithDeps(by.SimplifiedExpr, orderExpr.SimplifiedExpr) { + newGrouping = append(newGrouping, by) + used[grpIdx] = true + } + } + } + + // Step 2: Add any missing columns from the ORDER BY. + // The ORDER BY column is not a set, but we can add more elements + // to the end without changing the semantics of the query. + if len(newGrouping) != len(aggregator.Grouping) { + // we are missing some groupings. 
We need to add them both to the new groupings list, but also to the ORDER BY + for i, added := range used { + if !added { + groupBy := aggregator.Grouping[i] + newGrouping = append(newGrouping, groupBy) + order.Order = append(order.Order, groupBy.AsOrderBy()) + } + } + } + + aggregator.Grouping = newGrouping + aggrSource, isOrdering := aggregator.Source.(*Ordering) + if isOrdering { + // Transform the query plan tree: + // From: Ordering(1) To: Aggregation + // | | + // Aggregation Ordering(1) + // | | + // Ordering(2) + // | + // + // + // Remove Ordering(2) from the plan tree, as it's redundant + // after pushing down the higher ordering. + order.Source = aggrSource.Source + aggrSource.Source = nil // removing from plan tree + aggregator.Source = order + return aggregator, rewrite.NewTree("push ordering under aggregation, removing extra ordering", aggregator), nil + } + return rewrite.Swap(order, aggregator, "push ordering under aggregation") +} + func canPushLeft(ctx *plancontext.PlanningContext, aj *ApplyJoin, order []ops.OrderBy) bool { lhs := TableID(aj.LHS) for _, order := range order { @@ -97,11 +289,14 @@ func canPushLeft(ctx *plancontext.PlanningContext, aj *ApplyJoin, order []ops.Or func tryPushingDownProjection( ctx *plancontext.PlanningContext, p *Projection, -) (ops.Operator, rewrite.ApplyResult, error) { +) (ops.Operator, *rewrite.ApplyResult, error) { switch src := p.Source.(type) { case *Route: - return swap(p, src) + return rewrite.Swap(p, src, "pushed projection under route") case *ApplyJoin: + if p.FromAggr { + return p, rewrite.SameTree, nil + } return pushDownProjectionInApplyJoin(ctx, p, src) case *Vindex: return pushDownProjectionInVindex(ctx, p, src) @@ -110,35 +305,27 @@ func tryPushingDownProjection( } } -func swap(a, b ops.Operator) (ops.Operator, rewrite.ApplyResult, error) { - op, err := rewrite.Swap(a, b) - if err != nil { - return nil, false, err - } - return op, rewrite.NewTree, nil -} - func pushDownProjectionInVindex( ctx 
*plancontext.PlanningContext, p *Projection, src *Vindex, -) (ops.Operator, rewrite.ApplyResult, error) { - for _, column := range p.Columns { +) (ops.Operator, *rewrite.ApplyResult, error) { + for _, column := range p.Projections { expr := column.GetExpr() - _, _, err := src.AddColumn(ctx, aeWrap(expr)) + _, _, err := src.AddColumn(ctx, aeWrap(expr), true, false) if err != nil { - return nil, false, err + return nil, nil, err } } - return src, rewrite.NewTree, nil + return src, rewrite.NewTree("push projection into vindex", p), nil } type projector struct { cols []ProjExpr - names []string + names []*sqlparser.AliasedExpr } -func (p *projector) add(e ProjExpr, alias string) { +func (p *projector) add(e ProjExpr, alias *sqlparser.AliasedExpr) { p.cols = append(p.cols, e) p.names = append(p.names, alias) } @@ -151,21 +338,25 @@ func pushDownProjectionInApplyJoin( ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, -) (ops.Operator, rewrite.ApplyResult, error) { +) (ops.Operator, *rewrite.ApplyResult, error) { + if src.LeftJoin { + // we can't push down expression evaluation to the rhs if we are not sure if it will even be executed + return p, rewrite.SameTree, nil + } lhs, rhs := &projector{}, &projector{} src.ColumnsAST = nil - for idx := 0; idx < len(p.Columns); idx++ { - err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, p.Columns[idx], p.ColumnNames[idx]) + for idx := 0; idx < len(p.Projections); idx++ { + err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, p.Projections[idx], p.Columns[idx]) if err != nil { - return nil, false, err + return nil, nil, err } } if p.TableID != nil { err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs) if err != nil { - return nil, false, err + return nil, nil, err } } @@ -174,15 +365,15 @@ func pushDownProjectionInApplyJoin( // Create and update the Projection operators for the left and right children, if needed. 
src.LHS, err = createProjectionWithTheseColumns(src.LHS, lhs, p.TableID, p.Alias) if err != nil { - return nil, false, err + return nil, nil, err } src.RHS, err = createProjectionWithTheseColumns(src.RHS, rhs, p.TableID, p.Alias) if err != nil { - return nil, false, err + return nil, nil, err } - return src, rewrite.NewTree, nil + return src, rewrite.NewTree("split projection to either side of join", src), nil } // splitProjectionAcrossJoin creates JoinColumns for all projections, @@ -192,7 +383,7 @@ func splitProjectionAcrossJoin( join *ApplyJoin, lhs, rhs *projector, in ProjExpr, - colName string, + colName *sqlparser.AliasedExpr, ) error { expr := in.GetExpr() @@ -202,7 +393,7 @@ func splitProjectionAcrossJoin( } // Get a JoinColumn for the current expression. - col, err := join.getJoinColumnFor(ctx, &sqlparser.AliasedExpr{Expr: expr, As: sqlparser.NewIdentifierCI(colName)}) + col, err := join.getJoinColumnFor(ctx, colName, false) if err != nil { return err } @@ -215,9 +406,9 @@ func splitProjectionAcrossJoin( rhs.add(in, colName) case col.IsMixedLeftAndRight(): for _, lhsExpr := range col.LHSExprs { - lhs.add(&Expr{E: lhsExpr}, "") + lhs.add(&UnexploredExpression{E: lhsExpr}, aeWrap(lhsExpr)) } - rhs.add(&Expr{E: col.RHSExpr}, colName) + rhs.add(&UnexploredExpression{E: col.RHSExpr}, &sqlparser.AliasedExpr{Expr: col.RHSExpr, As: colName.As}) } // Add the new JoinColumn to the ApplyJoin's ColumnsAST. 
@@ -266,7 +457,7 @@ func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Proje alias := sqlparser.UnescapedString(out) predicate.LHSExprs[idx] = sqlparser.NewColNameWithQualifier(alias, derivedTblName) - lhs.add(&Expr{E: out}, alias) + lhs.add(&UnexploredExpression{E: out}, &sqlparser.AliasedExpr{Expr: out, As: sqlparser.NewIdentifierCI(alias)}) } } return nil @@ -298,8 +489,8 @@ func createProjectionWithTheseColumns( if err != nil { return nil, err } - proj.ColumnNames = p.names - proj.Columns = p.cols + proj.Columns = p.names + proj.Projections = p.cols proj.TableID = tableID proj.Alias = alias return proj, nil @@ -310,22 +501,25 @@ func stopAtRoute(operator ops.Operator) rewrite.VisitRule { return rewrite.VisitRule(!isRoute) } -func tryPushingDownLimit(in *Limit) (ops.Operator, rewrite.ApplyResult, error) { +func tryPushingDownLimit(in *Limit) (ops.Operator, *rewrite.ApplyResult, error) { switch src := in.Source.(type) { case *Route: return tryPushingDownLimitInRoute(in, src) case *Projection: - return swap(in, src) + return rewrite.Swap(in, src, "push limit under projection") + case *Aggregator: + return in, rewrite.SameTree, nil default: - if in.Pushed { - return in, rewrite.SameTree, nil - } return setUpperLimit(in) } } -func setUpperLimit(in *Limit) (ops.Operator, rewrite.ApplyResult, error) { - visitor := func(op ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, rewrite.ApplyResult, error) { +func setUpperLimit(in *Limit) (ops.Operator, *rewrite.ApplyResult, error) { + if in.Pushed { + return in, rewrite.SameTree, nil + } + in.Pushed = true + visitor := func(op ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { return op, rewrite.SameTree, nil } shouldVisit := func(op ops.Operator) rewrite.VisitRule { @@ -348,45 +542,45 @@ func setUpperLimit(in *Limit) (ops.Operator, rewrite.ApplyResult, error) { _, err := rewrite.TopDown(in.Source, TableID, visitor, shouldVisit) if err != nil { - return 
nil, false, err + return nil, nil, err } return in, rewrite.SameTree, nil } -func tryPushingDownLimitInRoute(in *Limit, src *Route) (ops.Operator, rewrite.ApplyResult, error) { +func tryPushingDownLimitInRoute(in *Limit, src *Route) (ops.Operator, *rewrite.ApplyResult, error) { if src.IsSingleShard() { - return swap(in, src) + return rewrite.Swap(in, src, "limit pushed into single sharded route") } return setUpperLimit(in) } -func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in horizonLike) (ops.Operator, error) { +func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in horizonLike) (ops.Operator, *rewrite.ApplyResult, error) { if derived, ok := in.(*Derived); ok { if len(derived.ColumnAliases) > 0 { - return nil, errHorizonNotPlanned() + return nil, nil, errHorizonNotPlanned() } } rb, isRoute := in.src().(*Route) if isRoute && rb.IsSingleShard() { - return rewrite.Swap(in, rb) + return rewrite.Swap(in, rb, "push horizon into route") } sel, isSel := in.selectStatement().(*sqlparser.Select) if !isSel { - return nil, errHorizonNotPlanned() + return nil, nil, errHorizonNotPlanned() } qp, err := in.getQP(ctx) if err != nil { - return nil, err + return nil, nil, err } needsOrdering := len(qp.OrderExprs) > 0 canPushDown := isRoute && sel.Having == nil && !needsOrdering && !qp.NeedsAggregation() && !sel.Distinct && sel.Limit == nil if canPushDown { - return rewrite.Swap(in, rb) + return rewrite.Swap(in, rb, "push horizon into route") } return expandHorizon(ctx, in) @@ -400,70 +594,164 @@ type horizonLike interface { getQP(ctx *plancontext.PlanningContext) (*QueryProjection, error) } -func expandHorizon(ctx *plancontext.PlanningContext, horizon horizonLike) (ops.Operator, error) { +func expandHorizon(ctx *plancontext.PlanningContext, horizon horizonLike) (ops.Operator, *rewrite.ApplyResult, error) { sel, isSel := horizon.selectStatement().(*sqlparser.Select) if !isSel { - return nil, errHorizonNotPlanned() + return nil, nil, errHorizonNotPlanned() + } + qp, 
err := horizon.getQP(ctx) + if err != nil { + return nil, nil, err + } + + if sel.Having != nil || qp.NeedsDistinct() || sel.Distinct { + return nil, nil, errHorizonNotPlanned() + } + + op, err := createProjectionFromSelect(ctx, horizon) + if err != nil { + return nil, nil, err + } + + if sel.Limit != nil { + op = &Limit{ + Source: op, + AST: sel.Limit, + } + } + + return op, rewrite.NewTree("expand horizon into smaller components", op), nil +} + +func checkInvalid(aggregations []Aggr, horizon horizonLike) error { + for _, aggregation := range aggregations { + if aggregation.Distinct { + return errHorizonNotPlanned() + } + } + if _, isDerived := horizon.(*Derived); isDerived { + return errHorizonNotPlanned() } + return nil +} + +func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizonLike) (out ops.Operator, err error) { qp, err := horizon.getQP(ctx) if err != nil { return nil, err } - src := horizon.src() + if !qp.NeedsAggregation() { + projX, err := createProjectionWithoutAggr(qp, horizon.src()) + if err != nil { + return nil, err + } + if derived, isDerived := horizon.(*Derived); isDerived { + id := derived.TableId + projX.TableID = &id + projX.Alias = derived.Alias + } + out = projX + if qp.OrderExprs != nil { + out = &Ordering{ + Source: out, + Order: qp.OrderExprs, + } + } - if qp.NeedsAggregation() || sel.Having != nil || qp.NeedsDistinct() || sel.Distinct { - return nil, errHorizonNotPlanned() + return out, nil } - var op ops.Operator - proj, err := createProjectionFromSelect(src, qp.SelectExprs) + err = checkAggregationSupported(horizon) if err != nil { return nil, err } + + aggregations, err := qp.AggregationExpressions(ctx) + if err != nil { + return nil, err + } + + if err := checkInvalid(aggregations, horizon); err != nil { + return nil, err + } + + a := &Aggregator{ + Source: horizon.src(), + Original: true, + QP: qp, + Grouping: qp.GetGrouping(), + Aggregations: aggregations, + } + if derived, isDerived := horizon.(*Derived); 
isDerived { id := derived.TableId - proj.TableID = &id - proj.Alias = derived.Alias + a.TableID = &id + a.Alias = derived.Alias } - op = proj - if qp.OrderExprs != nil { - op = &Ordering{ - Source: op, - Order: qp.OrderExprs, +outer: + for colIdx, expr := range qp.SelectExprs { + ae, err := expr.GetAliasedExpr() + if err != nil { + return nil, err + } + for idx, groupBy := range a.Grouping { + if ae == groupBy.aliasedExpr { + a.Columns = append(a.Columns, ae) + a.Grouping[idx].ColOffset = colIdx + continue outer + } } + for idx, aggr := range a.Aggregations { + if ae == aggr.Original { + a.Columns = append(a.Columns, ae) + a.Aggregations[idx].ColOffset = colIdx + continue outer + } + } + return nil, vterrors.VT13001(fmt.Sprintf("Could not find the %v in aggregation in the original query", expr)) } - if sel.Limit != nil { - op = &Limit{ - Source: op, - AST: sel.Limit, - } + // If ordering is required, create an Ordering operation. + if len(qp.OrderExprs) > 0 { + return &Ordering{ + Source: a, + Order: qp.OrderExprs, + }, nil } - return op, nil + return a, nil } -func createProjectionFromSelect(src ops.Operator, selectExprs []SelectExpr) (*Projection, error) { +func createProjectionWithoutAggr(qp *QueryProjection, src ops.Operator) (*Projection, error) { proj := &Projection{ Source: src, } - for _, e := range selectExprs { + for _, e := range qp.SelectExprs { if _, isStar := e.Col.(*sqlparser.StarExpr); isStar { return nil, errHorizonNotPlanned() } - expr, err := e.GetAliasedExpr() + ae, err := e.GetAliasedExpr() + if err != nil { return nil, err } - proj.Columns = append(proj.Columns, Expr{E: expr.Expr}) - colName := "" - if !expr.As.IsEmpty() { - colName = expr.ColumnName() + expr := ae.Expr + if sqlparser.ContainsAggregation(expr) { + aggr, ok := expr.(sqlparser.AggrFunc) + if !ok { + // need to add logic to extract aggregations and push them to the top level + return nil, errHorizonNotPlanned() + } + expr = aggr.GetArg() + if expr == nil { + expr = 
sqlparser.NewIntLiteral("1") + } } - proj.ColumnNames = append(proj.ColumnNames, colName) + + proj.addUnexploredExpr(ae, expr) } return proj, nil } @@ -471,3 +759,39 @@ func createProjectionFromSelect(src ops.Operator, selectExprs []SelectExpr) (*Pr func aeWrap(e sqlparser.Expr) *sqlparser.AliasedExpr { return &sqlparser.AliasedExpr{Expr: e} } + +func makeSureOutputIsCorrect(ctx *plancontext.PlanningContext, oldHorizon ops.Operator, output ops.Operator) (ops.Operator, error) { + // next we use the original Horizon to make sure that the output columns line up with what the user asked for + // in the future, we'll tidy up the results. for now, we are just failing these queries and going back to the + // old horizon planning instead + cols, err := output.GetColumns() + if err != nil { + return nil, err + } + + horizon := oldHorizon.(*Horizon) + + sel := sqlparser.GetFirstSelect(horizon.Select) + + if len(sel.SelectExprs) == len(cols) { + return output, nil + } + + if tryTruncateColumnsAt(output, len(sel.SelectExprs)) { + return output, nil + } + + qp, err := horizon.getQP(ctx) + if err != nil { + return nil, err + } + proj, err := createProjectionWithoutAggr(qp, output) + if err != nil { + return nil, err + } + err = proj.passThroughAllColumns(ctx) + if err != nil { + return nil, err + } + return proj, nil +} diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index e4cc94a6068..5470f3ac378 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -61,7 +61,7 @@ func (j *Join) SetInputs(ops []ops.Operator) { j.LHS, j.RHS = ops[0], ops[1] } -func (j *Join) Compact(ctx *plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) { +func (j *Join) Compact(ctx *plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { if j.LeftJoin { // we can't merge outer joins into a single QG return j, rewrite.SameTree, nil @@ -84,7 +84,7 @@ func (j *Join) 
Compact(ctx *plancontext.PlanningContext) (ops.Operator, rewrite. return nil, rewrite.SameTree, err } } - return newOp, rewrite.NewTree, nil + return newOp, rewrite.NewTree("merge querygraphs into a single one", newOp), nil } func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs ops.Operator) (ops.Operator, error) { diff --git a/go/vt/vtgate/planbuilder/operators/joins.go b/go/vt/vtgate/planbuilder/operators/joins.go index 01ce66ba814..d027d4115a7 100644 --- a/go/vt/vtgate/planbuilder/operators/joins.go +++ b/go/vt/vtgate/planbuilder/operators/joins.go @@ -106,7 +106,7 @@ func AddPredicate( // matched no rows on the right-hand, if we are later going to remove all the rows where the right-hand // side did not match, we might as well turn the join into an inner join. // -// This is based on the paper "Canonical Abstraction for Outerjoin Optimization" by J Rao et al +// This is based on the paper "Canonical Abstraction for Outerjoin Optimization" by J Rao et al. func canConvertToInner(ctx *plancontext.PlanningContext, expr sqlparser.Expr, rhs semantics.TableSet) bool { isColNameFromRHS := func(e sqlparser.Expr) bool { return sqlparser.IsColName(e) && ctx.SemTable.RecursiveDeps(e).IsSolvedBy(rhs) diff --git a/go/vt/vtgate/planbuilder/operators/limit.go b/go/vt/vtgate/planbuilder/operators/limit.go index a13a6fe9f00..a2531b4bde5 100644 --- a/go/vt/vtgate/planbuilder/operators/limit.go +++ b/go/vt/vtgate/planbuilder/operators/limit.go @@ -56,8 +56,8 @@ func (l *Limit) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Ex return l, nil } -func (l *Limit) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { - newSrc, offset, err := l.Source.AddColumn(ctx, expr) +func (l *Limit) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, reuseExisting, addToGroupBy bool) (ops.Operator, int, error) { + newSrc, offset, err := l.Source.AddColumn(ctx, expr, reuseExisting, addToGroupBy) if err 
!= nil { return nil, 0, err } diff --git a/go/vt/vtgate/planbuilder/operators/merging.go b/go/vt/vtgate/planbuilder/operators/merging.go index 6edc2400668..846d9f11784 100644 --- a/go/vt/vtgate/planbuilder/operators/merging.go +++ b/go/vt/vtgate/planbuilder/operators/merging.go @@ -237,7 +237,7 @@ func (s *subQueryMerger) markPredicateInOuterRouting(outer *ShardedRouting, inne // predicates list, so this might be a no-op. subQueryWasPredicate := false for i, predicate := range outer.SeenPredicates { - if s.ctx.SemTable.EqualsExpr(predicate, s.subq.ExtractedSubquery) { + if s.ctx.SemTable.EqualsExprWithDeps(predicate, s.subq.ExtractedSubquery) { outer.SeenPredicates = append(outer.SeenPredicates[:i], outer.SeenPredicates[i+1:]...) subQueryWasPredicate = true diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 1337f0ec7b9..f4287f281ed 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -17,9 +17,10 @@ limitations under the License. 
package operators import ( + "fmt" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -32,20 +33,16 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera planOffsets(ctx *plancontext.PlanningContext) error } - visitor := func(in ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, rewrite.ApplyResult, error) { + visitor := func(in ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { var err error switch op := in.(type) { - case *Horizon: - return nil, false, vterrors.VT13001("should not see Horizons here") - case *Derived: - return nil, false, vterrors.VT13001("should not see Derived here") + case *Derived, *Horizon: + return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in)) case offsettable: err = op.planOffsets(ctx) - case *Projection: - return op.planOffsetsForProjection(ctx) } if err != nil { - return nil, false, err + return nil, nil, err } return in, rewrite.SameTree, nil } @@ -62,64 +59,15 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera return op, nil } -func (p *Projection) planOffsetsForProjection(ctx *plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) { - var err error - for i, col := range p.Columns { - rewritten := sqlparser.CopyOnRewrite(col.GetExpr(), nil, func(cursor *sqlparser.CopyOnWriteCursor) { - column := cursor.Node() - expr, ok := column.(sqlparser.Expr) - if !ok { - return - } - if !fetchByOffset(column) { - return - } - - newSrc, offset, terr := p.Source.AddColumn(ctx, aeWrap(expr)) - if terr != nil { - err = terr - return - } - p.Source = newSrc - cursor.Replace(sqlparser.NewOffset(offset, expr)) - }, nil).(sqlparser.Expr) - if err != nil { - return nil, 
false, err - } - - offset, ok := rewritten.(*sqlparser.Offset) - if ok { - // we got a pure offset back. No need to do anything else - p.Columns[i] = Offset{ - Expr: col.GetExpr(), - Offset: offset.V, - } - continue - } - - eexpr, err := evalengine.Translate(rewritten, nil) - if err != nil { - return nil, false, err - } - - p.Columns[i] = Eval{ - Expr: rewritten, - EExpr: eexpr, - } - } - - return p, rewrite.SameTree, nil -} - func (p *Projection) passThroughAllColumns(ctx *plancontext.PlanningContext) error { - for i, col := range p.Columns { - newSrc, offset, err := p.Source.AddColumn(ctx, aeWrap(col.GetExpr())) + for i, col := range p.Projections { + newSrc, offset, err := p.Source.AddColumn(ctx, aeWrap(col.GetExpr()), true, false) if err != nil { return err } p.Source = newSrc - p.Columns[i] = Offset{ + p.Projections[i] = Offset{ Expr: col.GetExpr(), Offset: offset, } diff --git a/go/vt/vtgate/planbuilder/operators/operator.go b/go/vt/vtgate/planbuilder/operators/operator.go index 680525cfd9c..b4037739d67 100644 --- a/go/vt/vtgate/planbuilder/operators/operator.go +++ b/go/vt/vtgate/planbuilder/operators/operator.go @@ -22,7 +22,7 @@ The operators go through a few phases while planning: It will contain logical joins - we still haven't decided on the join algorithm to use yet. At the leaves, it will contain QueryGraphs - these are the tables in the FROM clause that we can easily do join ordering on. The logical tree will represent the full query, - including projections, grouping, ordering and so on. + including projections, grouping, ordering and so on. 2. Physical Once the logical plan has been fully built, we go bottom up and plan which routes that will be used. 
During this phase, we will also decide which join algorithms should be used on the vtgate level @@ -83,76 +83,6 @@ func PlanQuery(ctx *plancontext.PlanningContext, selStmt sqlparser.Statement) (o return op, err } -func tryHorizonPlanning(ctx *plancontext.PlanningContext, root ops.Operator) (output ops.Operator, err error) { - backup := Clone(root) - defer func() { - if err == errHorizonNotPlanned() { - err = planOffsetsOnJoins(ctx, backup) - if err == nil { - output = backup - } - } - }() - - _, ok := root.(*Horizon) - - if !ok || len(ctx.SemTable.SubqueryMap) > 0 || len(ctx.SemTable.SubqueryRef) > 0 { - // we are not ready to deal with subqueries yet - return root, errHorizonNotPlanned() - } - - output, err = planHorizons(ctx, root) - if err != nil { - return nil, err - } - - output, err = planOffsets(ctx, output) - if err != nil { - return nil, err - } - - output, err = makeSureOutputIsCorrect(ctx, root, output) - if err != nil { - return nil, err - } - - return -} - -func makeSureOutputIsCorrect(ctx *plancontext.PlanningContext, oldHorizon ops.Operator, output ops.Operator) (ops.Operator, error) { - // next we use the original Horizon to make sure that the output columns line up with what the user asked for - // in the future, we'll tidy up the results. 
for now, we are just failing these queries and going back to the - // old horizon planning instead - cols, err := output.GetColumns() - if err != nil { - return nil, err - } - - horizon := oldHorizon.(*Horizon) - - sel := sqlparser.GetFirstSelect(horizon.Select) - if len(sel.SelectExprs) != len(cols) { - route := getRouteIfPassThroughColumns(output) - if route != nil { - route.ResultColumns = len(sel.SelectExprs) - return output, nil - } - } - qp, err := horizon.getQP(ctx) - if err != nil { - return nil, err - } - proj, err := createProjectionFromSelect(output, qp.SelectExprs) - if err != nil { - return nil, err - } - err = proj.passThroughAllColumns(ctx) - if err != nil { - return nil, err - } - return proj, nil -} - // Inputs implements the Operator interface func (noInputs) Inputs() []ops.Operator { return nil @@ -166,7 +96,7 @@ func (noInputs) SetInputs(ops []ops.Operator) { } // AddColumn implements the Operator interface -func (noColumns) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (noColumns) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr, bool, bool) (ops.Operator, int, error) { return nil, 0, vterrors.VT13001("the noColumns operator cannot accept columns") } @@ -179,27 +109,29 @@ func (noPredicates) AddPredicate(*plancontext.PlanningContext, sqlparser.Expr) ( return nil, vterrors.VT13001("the noColumns operator cannot accept predicates") } -// getRouteIfPassThroughColumns will return the route that is feeding this operator, -// if it can be reached only through operators that pass through all columns -// This is used to limit the number of columns passed through by asking the route -// to truncate the results -func getRouteIfPassThroughColumns(op ops.Operator) *Route { - route, isRoute := op.(*Route) - if isRoute { - return route +// tryTruncateColumnsAt will see if we can truncate the columns by just asking the operator to do it for us +func tryTruncateColumnsAt(op ops.Operator, 
truncateAt int) bool { + type columnTruncator interface { + setTruncateColumnCount(offset int) + } + + truncator, ok := op.(columnTruncator) + if ok { + truncator.setTruncateColumnCount(truncateAt) + return true } inputs := op.Inputs() if len(inputs) != 1 { - return nil + return false } switch op.(type) { case *Limit: // empty by design default: - return nil + return false } - return getRouteIfPassThroughColumns(inputs[0]) + return tryTruncateColumnsAt(inputs[0], truncateAt) } diff --git a/go/vt/vtgate/planbuilder/operators/operator_funcs.go b/go/vt/vtgate/planbuilder/operators/operator_funcs.go index ee23175eac3..7f7aaff29c5 100644 --- a/go/vt/vtgate/planbuilder/operators/operator_funcs.go +++ b/go/vt/vtgate/planbuilder/operators/operator_funcs.go @@ -59,7 +59,7 @@ func RemovePredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op o var keep []sqlparser.Expr for _, e := range sqlparser.SplitAndExpression(nil, op.Predicate) { - if ctx.SemTable.EqualsExpr(expr, e) { + if ctx.SemTable.EqualsExprWithDeps(expr, e) { isRemoved = true } else { keep = append(keep, e) @@ -75,7 +75,7 @@ func RemovePredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op o case *Filter: idx := -1 for i, predicate := range op.Predicates { - if ctx.SemTable.EqualsExpr(predicate, expr) { + if ctx.SemTable.EqualsExprWithDeps(predicate, expr) { idx = i } } diff --git a/go/vt/vtgate/planbuilder/operators/ops/op.go b/go/vt/vtgate/planbuilder/operators/ops/op.go index ab45fc43cab..d68daed439e 100644 --- a/go/vt/vtgate/planbuilder/operators/ops/op.go +++ b/go/vt/vtgate/planbuilder/operators/ops/op.go @@ -45,7 +45,7 @@ type ( // AddColumn tells an operator to also output an additional column specified. // The offset to the column is returned. 
- AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (Operator, int, error) + AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, reuseExisting, addToGroupBy bool) (Operator, int, error) GetColumns() ([]*sqlparser.AliasedExpr, error) @@ -57,8 +57,10 @@ type ( // OrderBy contains the expression to used in order by and also if ordering is needed at VTGate level then what the weight_string function expression to be sent down for evaluation. OrderBy struct { - Inner *sqlparser.Order - WeightStrExpr sqlparser.Expr + Inner *sqlparser.Order + + // See GroupBy#SimplifiedExpr for more details about this + SimplifiedExpr sqlparser.Expr } OpDescription struct { diff --git a/go/vt/vtgate/planbuilder/operators/ordering.go b/go/vt/vtgate/planbuilder/operators/ordering.go index f16a7989499..60eff4bbaaf 100644 --- a/go/vt/vtgate/planbuilder/operators/ordering.go +++ b/go/vt/vtgate/planbuilder/operators/ordering.go @@ -33,7 +33,8 @@ type Ordering struct { Offset []int WOffset []int - Order []ops.OrderBy + Order []ops.OrderBy + ResultColumns int } func (o *Ordering) Clone(inputs []ops.Operator) ops.Operator { @@ -62,8 +63,8 @@ func (o *Ordering) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser return o, nil } -func (o *Ordering) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { - newSrc, offset, err := o.Source.AddColumn(ctx, expr) +func (o *Ordering) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, reuseExisting, addToGroupBy bool) (ops.Operator, int, error) { + newSrc, offset, err := o.Source.AddColumn(ctx, expr, reuseExisting, addToGroupBy) if err != nil { return nil, 0, err } @@ -81,20 +82,20 @@ func (o *Ordering) GetOrdering() ([]ops.OrderBy, error) { func (o *Ordering) planOffsets(ctx *plancontext.PlanningContext) error { for _, order := range o.Order { - newSrc, offset, err := o.Source.AddColumn(ctx, aeWrap(order.Inner.Expr)) + newSrc, offset, 
err := o.Source.AddColumn(ctx, aeWrap(order.SimplifiedExpr), true, false) if err != nil { return err } o.Source = newSrc o.Offset = append(o.Offset, offset) - if !ctx.SemTable.NeedsWeightString(order.WeightStrExpr) { + if !ctx.SemTable.NeedsWeightString(order.SimplifiedExpr) { o.WOffset = append(o.WOffset, -1) continue } - wsExpr := &sqlparser.WeightStringFuncExpr{Expr: order.WeightStrExpr} - newSrc, offset, err = o.Source.AddColumn(ctx, aeWrap(wsExpr)) + wsExpr := &sqlparser.WeightStringFuncExpr{Expr: order.SimplifiedExpr} + newSrc, offset, err = o.Source.AddColumn(ctx, aeWrap(wsExpr), true, false) if err != nil { return err } @@ -118,3 +119,7 @@ func (o *Ordering) ShortDescription() string { }) return strings.Join(ordering, ", ") } + +func (o *Ordering) setTruncateColumnCount(offset int) { + o.ResultColumns = offset +} diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 756a2f98bf7..58904df9273 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -35,13 +35,20 @@ type ( // Projection is used when we need to evaluate expressions on the vtgate // It uses the evalengine to accomplish its goal Projection struct { - Source ops.Operator - ColumnNames []string - Columns []ProjExpr + Source ops.Operator + + // Columns contain the expressions as viewed from the outside of this operator + Columns []*sqlparser.AliasedExpr + + // Projections will contain the actual evaluations we need to + // do if this operator is still above a route after optimisation + Projections []ProjExpr // TableID will be non-nil for derived tables TableID *semantics.TableSet Alias string + + FromAggr bool } ProjExpr interface { @@ -60,38 +67,58 @@ type ( EExpr evalengine.Expr } - // Expr is used before we have planned, or if we are able to push this down to mysql - Expr struct { + // UnexploredExpression is used before we have planned - one of two end results are 
possible for it + // - we are able to push this projection under a route, and then this is not used at all - we'll just + // use the ColumnNames field of the Projection struct + // - we have to evaluate this on the vtgate, and either it's just a copy from the input, + // or it's an evalengine expression that we have to evaluate + UnexploredExpression struct { E sqlparser.Expr } ) -func (p *Projection) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { - colAsExpr := func(pe ProjExpr) sqlparser.Expr { return pe.GetExpr() } - if offset, found := canReuseColumn(ctx, p.Columns, expr.Expr, colAsExpr); found { +var _ selectExpressions = (*Projection)(nil) + +func (p *Projection) addUnexploredExpr(ae *sqlparser.AliasedExpr, e sqlparser.Expr) int { + p.Projections = append(p.Projections, UnexploredExpression{E: e}) + p.Columns = append(p.Columns, ae) + return len(p.Projections) - 1 +} + +func (p *Projection) addNoPushCol(expr *sqlparser.AliasedExpr, _ bool) int { + return p.addUnexploredExpr(expr, expr.Expr) +} + +func (p *Projection) isDerived() bool { + return p.TableID != nil +} + +func (p *Projection) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { + if offset, found := canReuseColumn(ctx, p.Columns, expr.Expr, extractExpr); found { return p, offset, nil } - sourceOp, offset, err := p.Source.AddColumn(ctx, expr) + sourceOp, offset, err := p.Source.AddColumn(ctx, expr, true, addToGroupBy) if err != nil { return nil, 0, err } p.Source = sourceOp - p.Columns = append(p.Columns, Offset{Offset: offset, Expr: expr.Expr}) - p.ColumnNames = append(p.ColumnNames, expr.As.String()) - return p, len(p.Columns) - 1, nil + p.Projections = append(p.Projections, Offset{Offset: offset, Expr: expr.Expr}) + p.Columns = append(p.Columns, expr) + return p, len(p.Projections) - 1, nil } -func (po Offset) GetExpr() sqlparser.Expr { return po.Expr } -func (po Eval) 
GetExpr() sqlparser.Expr { return po.Expr } -func (po Expr) GetExpr() sqlparser.Expr { return po.E } +func (po Offset) GetExpr() sqlparser.Expr { return po.Expr } +func (po Eval) GetExpr() sqlparser.Expr { return po.Expr } +func (po UnexploredExpression) GetExpr() sqlparser.Expr { return po.E } func (p *Projection) Clone(inputs []ops.Operator) ops.Operator { return &Projection{ Source: inputs[0], - ColumnNames: slices.Clone(p.ColumnNames), Columns: slices.Clone(p.Columns), + Projections: slices.Clone(p.Projections), TableID: p.TableID, Alias: p.Alias, + FromAggr: p.FromAggr, } } @@ -113,22 +140,11 @@ func (p *Projection) AddPredicate(ctx *plancontext.PlanningContext, expr sqlpars return p, nil } -func (p *Projection) expressions() (result []*sqlparser.AliasedExpr) { - for i, col := range p.Columns { - expr := col.GetExpr() - result = append(result, &sqlparser.AliasedExpr{ - Expr: expr, - As: sqlparser.NewIdentifierCI(p.ColumnNames[i]), - }) - } - return -} - func (p *Projection) GetColumns() ([]*sqlparser.AliasedExpr, error) { if p.TableID != nil { return nil, nil } - return p.expressions(), nil + return p.Columns, nil } func (p *Projection) GetOrdering() ([]ops.OrderBy, error) { @@ -138,7 +154,7 @@ func (p *Projection) GetOrdering() ([]ops.OrderBy, error) { // AllOffsets returns a slice of integer offsets for all columns in the Projection // if all columns are of type Offset. If any column is not of type Offset, it returns nil. 
func (p *Projection) AllOffsets() (cols []int) { - for _, c := range p.Columns { + for _, c := range p.Projections { offset, ok := c.(Offset) if !ok { return nil @@ -151,12 +167,12 @@ func (p *Projection) AllOffsets() (cols []int) { func (p *Projection) Description() ops.OpDescription { var columns []string - for i, col := range p.Columns { - alias := p.ColumnNames[i] - if alias == "" { - columns = append(columns, sqlparser.String(col.GetExpr())) + for i, col := range p.Projections { + aliasExpr := p.Columns[i] + if aliasExpr.Expr == col.GetExpr() { + columns = append(columns, sqlparser.String(aliasExpr)) } else { - columns = append(columns, fmt.Sprintf("%s AS %s", sqlparser.String(col.GetExpr()), alias)) + columns = append(columns, fmt.Sprintf("%s AS %s", sqlparser.String(col.GetExpr()), aliasExpr.ColumnName())) } } @@ -178,20 +194,18 @@ func (p *Projection) ShortDescription() string { if p.Alias != "" { columns = append(columns, "derived["+p.Alias+"]") } - for i, column := range p.Columns { - expr := sqlparser.String(column.GetExpr()) - alias := p.ColumnNames[i] - if alias == "" { - columns = append(columns, expr) - continue + for i, col := range p.Projections { + aliasExpr := p.Columns[i] + if aliasExpr.Expr == col.GetExpr() { + columns = append(columns, sqlparser.String(aliasExpr)) + } else { + columns = append(columns, fmt.Sprintf("%s AS %s", sqlparser.String(col.GetExpr()), aliasExpr.ColumnName())) } - columns = append(columns, fmt.Sprintf("%s AS %s", expr, alias)) - } return strings.Join(columns, ", ") } -func (p *Projection) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) { +func (p *Projection) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { switch src := p.Source.(type) { case *Route: return p.compactWithRoute(src) @@ -201,10 +215,10 @@ func (p *Projection) Compact(*plancontext.PlanningContext) (ops.Operator, rewrit return p, rewrite.SameTree, nil } -func (p *Projection) 
compactWithJoin(src *ApplyJoin) (ops.Operator, rewrite.ApplyResult, error) { +func (p *Projection) compactWithJoin(src *ApplyJoin) (ops.Operator, *rewrite.ApplyResult, error) { var newColumns []int var newColumnsAST []JoinColumn - for _, col := range p.Columns { + for _, col := range p.Projections { offset, ok := col.(Offset) if !ok { return p, rewrite.SameTree, nil @@ -216,11 +230,11 @@ func (p *Projection) compactWithJoin(src *ApplyJoin) (ops.Operator, rewrite.Appl src.Columns = newColumns src.ColumnsAST = newColumnsAST - return src, rewrite.NewTree, nil + return src, rewrite.NewTree("remove projection from before join", src), nil } -func (p *Projection) compactWithRoute(rb *Route) (ops.Operator, rewrite.ApplyResult, error) { - for i, col := range p.Columns { +func (p *Projection) compactWithRoute(rb *Route) (ops.Operator, *rewrite.ApplyResult, error) { + for i, col := range p.Projections { offset, ok := col.(Offset) if !ok || offset.Offset != i { return p, rewrite.SameTree, nil @@ -228,12 +242,72 @@ func (p *Projection) compactWithRoute(rb *Route) (ops.Operator, rewrite.ApplyRes } columns, err := rb.GetColumns() if err != nil { - return nil, false, err + return nil, nil, err } - if len(columns) == len(p.Columns) { - return rb, rewrite.NewTree, nil + if len(columns) == len(p.Projections) { + return rb, rewrite.NewTree("remove projection from before route", rb), nil } rb.ResultColumns = len(columns) return rb, rewrite.SameTree, nil } + +func stopAtAggregations(node, _ sqlparser.SQLNode) bool { + _, aggr := node.(sqlparser.AggrFunc) + b := !aggr + return b +} + +func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) error { + var err error + offsetter := func(cursor *sqlparser.CopyOnWriteCursor) { + expr, ok := cursor.Node().(sqlparser.Expr) + if !ok || !fetchByOffset(expr) { + return + } + + newSrc, offset, terr := p.Source.AddColumn(ctx, aeWrap(expr), true, false) + if terr != nil { + err = terr + return + } + p.Source = newSrc + 
cursor.Replace(sqlparser.NewOffset(offset, expr)) + } + + for i, col := range p.Projections { + _, unexplored := col.(UnexploredExpression) + if !unexplored { + continue + } + + // first step is to replace the expressions we expect to get from our input with the offsets for these + rewritten := sqlparser.CopyOnRewrite(col.GetExpr(), stopAtAggregations, offsetter, nil).(sqlparser.Expr) + if err != nil { + return err + } + + offset, ok := rewritten.(*sqlparser.Offset) + if ok { + // we got a pure offset back. No need to do anything else + p.Projections[i] = Offset{ + Expr: col.GetExpr(), + Offset: offset.V, + } + continue + } + + // for everything else, we'll turn to the evalengine + eexpr, err := evalengine.Translate(rewritten, nil) + if err != nil { + return err + } + + p.Projections[i] = Eval{ + Expr: rewritten, + EExpr: eexpr, + } + } + + return nil +} diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index 35d167b9efd..a764ca3db89 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -184,7 +184,7 @@ func (qg *QueryGraph) UnsolvedPredicates(_ *semantics.SemTable) []sqlparser.Expr } // Clone implements the Operator interface -func (qg *QueryGraph) Clone(inputs []ops.Operator) ops.Operator { +func (qg *QueryGraph) Clone([]ops.Operator) ops.Operator { result := &QueryGraph{ Tables: nil, innerJoins: nil, diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 270b557bf28..82032d52768 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -22,13 +22,16 @@ import ( "sort" "strings" + "golang.org/x/exp/slices" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" 
"vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - popcode "vitess.io/vitess/go/vt/vtgate/engine/opcode" + "vitess.io/vitess/go/vt/vtgate/engine/opcode" ) type ( @@ -41,39 +44,58 @@ type ( // QueryProjection contains the information about the projections, group by and order by expressions used to do horizon planning. QueryProjection struct { // If you change the contents here, please update the toString() method - SelectExprs []SelectExpr - HasAggr bool - Distinct bool - groupByExprs []GroupBy - OrderExprs []ops.OrderBy - CanPushDownSorting bool - HasStar bool + SelectExprs []SelectExpr + HasAggr bool + Distinct bool + groupByExprs []GroupBy + OrderExprs []ops.OrderBy + HasStar bool // AddedColumn keeps a counter for expressions added to solve HAVING expressions the user is not selecting AddedColumn int + + hasCheckedAlignment bool + + // TODO Remove once all of horizon planning is done on the operators + CanPushDownSorting bool } // GroupBy contains the expression to used in group by and also if grouping is needed at VTGate level then what the weight_string function expression to be sent down for evaluation. GroupBy struct { - Inner sqlparser.Expr - WeightStrExpr sqlparser.Expr + Inner sqlparser.Expr + + // The simplified expressions is the "unaliased expression". + // In the following query, the group by has the inner expression + // `x` and the `SimplifiedExpr` is `table.col + 10`: + // select table.col + 10 as x, count(*) from tbl group by x + SimplifiedExpr sqlparser.Expr // The index at which the user expects to see this column. 
Set to nil, if the user does not ask for it InnerIndex *int // The original aliased expression that this group by is referring aliasedExpr *sqlparser.AliasedExpr + + // points to the column on the same aggregator + ColOffset int + WSOffset int } // Aggr encodes all information needed for aggregation functions Aggr struct { Original *sqlparser.AliasedExpr Func sqlparser.AggrFunc - OpCode popcode.AggregateOpcode - Alias string + OpCode opcode.AggregateOpcode + + // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing opcode while pushing them down + OriginalOpCode opcode.AggregateOpcode + + Alias string // The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it Index *int Distinct bool + + ColOffset int // points to the column on the same aggregator } AggrRewriter struct { @@ -83,13 +105,24 @@ type ( } ) +// NewGroupBy creates a new group by from the given fields. +func NewGroupBy(inner, simplified sqlparser.Expr, aliasedExpr *sqlparser.AliasedExpr) GroupBy { + return GroupBy{ + Inner: inner, + SimplifiedExpr: simplified, + aliasedExpr: aliasedExpr, + ColOffset: -1, + WSOffset: -1, + } +} + func (b GroupBy) AsOrderBy() ops.OrderBy { return ops.OrderBy{ Inner: &sqlparser.Order{ Expr: b.Inner, Direction: sqlparser.AscOrder, }, - WeightStrExpr: b.WeightStrExpr, + SimplifiedExpr: b.SimplifiedExpr, } } @@ -98,18 +131,18 @@ func (b GroupBy) AsAliasedExpr() *sqlparser.AliasedExpr { return b.aliasedExpr } col, isColName := b.Inner.(*sqlparser.ColName) - if isColName && b.WeightStrExpr != b.Inner { + if isColName && b.SimplifiedExpr != b.Inner { return &sqlparser.AliasedExpr{ - Expr: b.WeightStrExpr, + Expr: b.SimplifiedExpr, As: col.Name, } } - if !isColName && b.WeightStrExpr != b.Inner { + if !isColName && b.SimplifiedExpr != b.Inner { panic("this should not happen - different inner and weighStringExpr and not a column alias") } return &sqlparser.AliasedExpr{ - Expr: b.WeightStrExpr, + Expr: 
b.SimplifiedExpr, } } @@ -148,21 +181,14 @@ func CreateQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) } for _, group := range sel.GroupBy { selectExprIdx, aliasExpr := qp.FindSelectExprIndexForExpr(ctx, group) - expr, weightStrExpr, err := qp.GetSimplifiedExpr(group) - if err != nil { - return nil, err - } + weightStrExpr := qp.GetSimplifiedExpr(group) err = checkForInvalidGroupingExpressions(weightStrExpr) if err != nil { return nil, err } - groupBy := GroupBy{ - Inner: expr, - WeightStrExpr: weightStrExpr, - InnerIndex: selectExprIdx, - aliasedExpr: aliasExpr, - } + groupBy := NewGroupBy(group, weightStrExpr, aliasExpr) + groupBy.InnerIndex = selectExprIdx qp.groupByExprs = append(qp.groupByExprs, groupBy) } @@ -207,7 +233,7 @@ func (ar *AggrRewriter) RewriteUp() func(*sqlparser.Cursor) bool { ar.Err = err return false } - if ar.st.EqualsExpr(ae.Expr, fExp) { + if ar.st.EqualsExprWithDeps(ae.Expr, fExp) { cursor.Replace(sqlparser.NewOffset(offset, fExp)) return true } @@ -285,20 +311,14 @@ func CreateQPFromUnion(union *sqlparser.Union) (*QueryProjection, error) { func (qp *QueryProjection) addOrderBy(orderBy sqlparser.OrderBy) error { canPushDownSorting := true for _, order := range orderBy { - expr, weightStrExpr, err := qp.GetSimplifiedExpr(order.Expr) - if err != nil { - return err - } + weightStrExpr := qp.GetSimplifiedExpr(order.Expr) if sqlparser.IsNull(weightStrExpr) { // ORDER BY null can safely be ignored continue } qp.OrderExprs = append(qp.OrderExprs, ops.OrderBy{ - Inner: &sqlparser.Order{ - Expr: expr, - Direction: order.Direction, - }, - WeightStrExpr: weightStrExpr, + Inner: sqlparser.CloneRefOfOrder(order), + SimplifiedExpr: weightStrExpr, }) canPushDownSorting = canPushDownSorting && !sqlparser.ContainsAggregation(weightStrExpr) } @@ -308,9 +328,7 @@ func (qp *QueryProjection) addOrderBy(orderBy sqlparser.OrderBy) error { // GetGrouping returns a copy of the grouping parameters of the QP func (qp *QueryProjection) 
GetGrouping() []GroupBy { - out := make([]GroupBy, len(qp.groupByExprs)) - copy(out, qp.groupByExprs) - return out + return slices.Clone(qp.groupByExprs) } func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { @@ -333,7 +351,7 @@ func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext if err != nil { return false } - if ctx.SemTable.EqualsExpr(groupByExpr.WeightStrExpr, exp) { + if ctx.SemTable.EqualsExprWithDeps(groupByExpr.SimplifiedExpr, exp) { return true } } @@ -341,34 +359,29 @@ func (qp *QueryProjection) isExprInGroupByExprs(ctx *plancontext.PlanningContext } // GetSimplifiedExpr takes an expression used in ORDER BY or GROUP BY, and returns an expression that is simpler to evaluate -func (qp *QueryProjection) GetSimplifiedExpr(e sqlparser.Expr) (expr sqlparser.Expr, weightStrExpr sqlparser.Expr, err error) { +func (qp *QueryProjection) GetSimplifiedExpr(e sqlparser.Expr) sqlparser.Expr { // If the ORDER BY is against a column alias, we need to remember the expression // behind the alias. The weightstring(.) calls needs to be done against that expression and not the alias. // Eg - select music.foo as bar, weightstring(music.foo) from music order by bar colExpr, isColName := e.(*sqlparser.ColName) - if !isColName { - return e, e, nil + if !(isColName && colExpr.Qualifier.IsEmpty()) { + // we are only interested in unqualified column names. 
if it's not a column name and not + return e } - if sqlparser.IsNull(e) { - return e, nil, nil - } - - if colExpr.Qualifier.IsEmpty() { - for _, selectExpr := range qp.SelectExprs { - aliasedExpr, isAliasedExpr := selectExpr.Col.(*sqlparser.AliasedExpr) - if !isAliasedExpr { - continue - } - isAliasExpr := !aliasedExpr.As.IsEmpty() - if isAliasExpr && colExpr.Name.Equal(aliasedExpr.As) { - return e, aliasedExpr.Expr, nil - } + for _, selectExpr := range qp.SelectExprs { + aliasedExpr, isAliasedExpr := selectExpr.Col.(*sqlparser.AliasedExpr) + if !isAliasedExpr { + continue + } + aliased := !aliasedExpr.As.IsEmpty() + if aliased && colExpr.Name.Equal(aliasedExpr.As) { + return aliasedExpr.Expr } } - return e, e, nil + return e } // toString should only be used for tests @@ -514,13 +527,13 @@ func (qp *QueryProjection) NeedsDistinct() bool { func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningContext) (out []Aggr, err error) { orderBy: for _, orderExpr := range qp.OrderExprs { - orderExpr := orderExpr.WeightStrExpr + orderExpr := orderExpr.SimplifiedExpr for _, expr := range qp.SelectExprs { col, ok := expr.Col.(*sqlparser.AliasedExpr) if !ok { continue } - if ctx.SemTable.EqualsExpr(col.Expr, orderExpr) { + if ctx.SemTable.EqualsExprWithDeps(col.Expr, orderExpr) { continue orderBy // we found the expression we were looking for! } } @@ -531,6 +544,9 @@ orderBy: qp.AddedColumn++ } + // Here we go over the expressions we are returning. Since we know we are aggregating, + // all expressions have to be either grouping expressions or aggregate expressions. 
+ // If we find an expression that is neither, we treat is as a special aggregation function AggrRandom for idx, expr := range qp.SelectExprs { aliasedExpr, err := expr.GetAliasedExpr() if err != nil { @@ -543,7 +559,7 @@ orderBy: if !qp.isExprInGroupByExprs(ctx, expr) { out = append(out, Aggr{ Original: aliasedExpr, - OpCode: popcode.AggregateRandom, + OpCode: opcode.AggregateRandom, Alias: aliasedExpr.ColumnName(), Index: &idxCopy, }) @@ -555,32 +571,29 @@ orderBy: return nil, vterrors.VT12001("in scatter query: complex aggregate expression") } - opcode, found := popcode.SupportedAggregates[strings.ToLower(fnc.AggrName())] - if !found { - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", fnc.AggrName())) - } + code := opcode.SupportedAggregates[strings.ToLower(fnc.AggrName())] - if opcode == popcode.AggregateCount { + if code == opcode.AggregateCount { if _, isStar := fnc.(*sqlparser.CountStar); isStar { - opcode = popcode.AggregateCountStar + code = opcode.AggregateCountStar } } aggr, _ := aliasedExpr.Expr.(sqlparser.AggrFunc) if aggr.IsDistinct() { - switch opcode { - case popcode.AggregateCount: - opcode = popcode.AggregateCountDistinct - case popcode.AggregateSum: - opcode = popcode.AggregateSumDistinct + switch code { + case opcode.AggregateCount: + code = opcode.AggregateCountDistinct + case opcode.AggregateSum: + code = opcode.AggregateSumDistinct } } out = append(out, Aggr{ Original: aliasedExpr, Func: aggr, - OpCode: opcode, + OpCode: code, Alias: aliasedExpr.ColumnName(), Index: &idxCopy, Distinct: aggr.IsDistinct(), @@ -605,19 +618,15 @@ func (qp *QueryProjection) FindSelectExprIndexForExpr(ctx *plancontext.PlanningC return &idx, aliasedExpr } } - if ctx.SemTable.EqualsExpr(aliasedExpr.Expr, expr) { + if ctx.SemTable.EqualsExprWithDeps(aliasedExpr.Expr, expr) { return &idx, aliasedExpr } } return nil, nil } -// AlignGroupByAndOrderBy aligns the group by and order by columns, so they are in the same order -// The 
GROUP BY clause is a set - the order between the elements does not make any difference, -// so we can simply re-arrange the column order -// We are also free to add more ORDER BY columns than the user asked for which we leverage, -// so the input is already ordered according to the GROUP BY columns used -func (qp *QueryProjection) AlignGroupByAndOrderBy(ctx *plancontext.PlanningContext) { +// OldAlignGroupByAndOrderBy TODO Remove once all of horizon planning is done on the operators +func (qp *QueryProjection) OldAlignGroupByAndOrderBy(ctx *plancontext.PlanningContext) { // The ORDER BY can be performed before the OA var newGrouping []GroupBy @@ -635,7 +644,7 @@ func (qp *QueryProjection) AlignGroupByAndOrderBy(ctx *plancontext.PlanningConte used := make([]bool, len(qp.groupByExprs)) for _, orderExpr := range qp.OrderExprs { for i, groupingExpr := range qp.groupByExprs { - if !used[i] && ctx.SemTable.EqualsExpr(groupingExpr.WeightStrExpr, orderExpr.WeightStrExpr) { + if !used[i] && ctx.SemTable.EqualsExpr(groupingExpr.SimplifiedExpr, orderExpr.SimplifiedExpr) { newGrouping = append(newGrouping, groupingExpr) used[i] = true } @@ -656,6 +665,42 @@ func (qp *QueryProjection) AlignGroupByAndOrderBy(ctx *plancontext.PlanningConte qp.groupByExprs = newGrouping } +// AlignGroupByAndOrderBy aligns the group by and order by columns, so they are in the same order +// The GROUP BY clause is a set - the order between the elements does not make any difference, +// so we can simply re-arrange the column order +// We are also free to add more ORDER BY columns than the user asked for which we leverage, +// so the input is already ordered according to the GROUP BY columns used +func (qp *QueryProjection) AlignGroupByAndOrderBy(ctx *plancontext.PlanningContext) bool { + if qp.hasCheckedAlignment { + return false + } + qp.hasCheckedAlignment = true + newGrouping := make([]GroupBy, 0, len(qp.groupByExprs)) + used := make([]bool, len(qp.groupByExprs)) + +outer: + for _, orderBy := 
range qp.OrderExprs { + for gidx, groupBy := range qp.groupByExprs { + if ctx.SemTable.EqualsExprWithDeps(groupBy.SimplifiedExpr, orderBy.SimplifiedExpr) { + newGrouping = append(newGrouping, groupBy) + used[gidx] = true + continue outer + } + } + return false + } + + // if we get here, it means that all the OrderBy expressions are also in the GroupBy clause + for gidx, gb := range qp.groupByExprs { + if !used[gidx] { + newGrouping = append(newGrouping, gb) + qp.OrderExprs = append(qp.OrderExprs, gb.AsOrderBy()) + } + } + qp.groupByExprs = newGrouping + return true +} + // AddGroupBy does just that func (qp *QueryProjection) AddGroupBy(by GroupBy) { qp.groupByExprs = append(qp.groupByExprs, by) @@ -665,6 +710,19 @@ func (qp *QueryProjection) GetColumnCount() int { return len(qp.SelectExprs) - qp.AddedColumn } +// checkAggregationSupported checks if the aggregation is supported on the given operator tree or not. +// We don't currently support planning for operators having derived tables. 
+func checkAggregationSupported(op ops.Operator) error { + return rewrite.Visit(op, func(operator ops.Operator) error { + _, isDerived := operator.(*Derived) + projection, isProjection := operator.(*Projection) + if isDerived || (isProjection && projection.TableID != nil) { + return errHorizonNotPlanned() + } + return nil + }) +} + func checkForInvalidGroupingExpressions(expr sqlparser.Expr) error { return sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { if _, isAggregate := node.(sqlparser.AggrFunc); isAggregate { diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go b/go/vt/vtgate/planbuilder/operators/queryprojection_test.go index 7502e182060..311b4cdea91 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection_test.go @@ -50,15 +50,15 @@ func TestQP(t *testing.T) { { sql: "select 1, count(1) from user order by 1", expOrder: []ops.OrderBy{ - {Inner: &sqlparser.Order{Expr: sqlparser.NewIntLiteral("1")}, WeightStrExpr: sqlparser.NewIntLiteral("1")}, + {Inner: &sqlparser.Order{Expr: sqlparser.NewIntLiteral("1")}, SimplifiedExpr: sqlparser.NewIntLiteral("1")}, }, }, { sql: "select id from user order by col, id, 1", expOrder: []ops.OrderBy{ - {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("col")}, WeightStrExpr: sqlparser.NewColName("col")}, - {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("id")}, WeightStrExpr: sqlparser.NewColName("id")}, - {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("id")}, WeightStrExpr: sqlparser.NewColName("id")}, + {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("col")}, SimplifiedExpr: sqlparser.NewColName("col")}, + {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("id")}, SimplifiedExpr: sqlparser.NewColName("id")}, + {Inner: &sqlparser.Order{Expr: sqlparser.NewColName("id")}, SimplifiedExpr: sqlparser.NewColName("id")}, }, }, { @@ -66,7 +66,7 @@ func TestQP(t *testing.T) { expOrder: []ops.OrderBy{ { Inner: 
&sqlparser.Order{Expr: sqlparser.NewColName("full_name")}, - WeightStrExpr: &sqlparser.FuncExpr{ + SimplifiedExpr: &sqlparser.FuncExpr{ Name: sqlparser.NewIdentifierCI("CONCAT"), Exprs: sqlparser.SelectExprs{ &sqlparser.AliasedExpr{Expr: sqlparser.NewColName("last_name")}, @@ -101,7 +101,7 @@ func TestQP(t *testing.T) { require.Equal(t, len(tcase.expOrder), len(qp.OrderExprs), "not enough order expressions in QP") for index, expOrder := range tcase.expOrder { assert.True(t, sqlparser.Equals.SQLNode(expOrder.Inner, qp.OrderExprs[index].Inner), "want: %+v, got %+v", sqlparser.String(expOrder.Inner), sqlparser.String(qp.OrderExprs[index].Inner)) - assert.True(t, sqlparser.Equals.SQLNode(expOrder.WeightStrExpr, qp.OrderExprs[index].WeightStrExpr), "want: %v, got %v", sqlparser.String(expOrder.WeightStrExpr), sqlparser.String(qp.OrderExprs[index].WeightStrExpr)) + assert.True(t, sqlparser.Equals.SQLNode(expOrder.SimplifiedExpr, qp.OrderExprs[index].SimplifiedExpr), "want: %v, got %v", sqlparser.String(expOrder.SimplifiedExpr), sqlparser.String(qp.OrderExprs[index].SimplifiedExpr)) } } }) diff --git a/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go b/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go index 1f40b8a813e..0d81e34fabf 100644 --- a/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go +++ b/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go @@ -17,6 +17,8 @@ limitations under the License. 
package rewrite import ( + "fmt" + "golang.org/x/exp/slices" "vitess.io/vitess/go/vt/vterrors" @@ -30,7 +32,7 @@ type ( op ops.Operator, // op is the operator being visited lhsTables semantics.TableSet, // lhsTables contains the TableSet for all table on the LHS of our parent isRoot bool, // isRoot will be true for the root of the operator tree - ) (ops.Operator, ApplyResult, error) + ) (ops.Operator, *ApplyResult, error) // ShouldVisit is used when we want to control which nodes and ancestors to visit and which to skip ShouldVisit func(ops.Operator) VisitRule @@ -38,23 +40,52 @@ type ( // ApplyResult tracks modifications to node and expression trees. // Only return SameTree when it is acceptable to return the original // input and discard the returned result as a performance improvement. - ApplyResult bool + ApplyResult struct { + Transformations []Rewrite + } + + Rewrite struct { + Message string + Op ops.Operator + } // VisitRule signals to the rewriter if the children of this operator should be visited or not VisitRule bool ) -const ( - SameTree ApplyResult = false - NewTree ApplyResult = true +var ( + SameTree *ApplyResult = nil +) +const ( VisitChildren VisitRule = true SkipChildren VisitRule = false ) +func NewTree(message string, op ops.Operator) *ApplyResult { + if DebugOperatorTree { + fmt.Println(">>>>>>>> " + message) + } + return &ApplyResult{Transformations: []Rewrite{{Message: message, Op: op}}} +} + +func (ar *ApplyResult) Merge(other *ApplyResult) *ApplyResult { + if ar == nil { + return other + } + if other == nil { + return ar + } + return &ApplyResult{Transformations: append(ar.Transformations, other.Transformations...)} +} + +func (ar *ApplyResult) Changed() bool { + return ar != nil +} + // Visit allows for the walking of the operator tree. 
If any error is returned, the walk is aborted func Visit(root ops.Operator, visitor func(ops.Operator) error) error { - _, _, err := breakableTopDown(root, func(op ops.Operator) (ops.Operator, ApplyResult, VisitRule, error) { + _, _, err := breakableTopDown(root, func(op ops.Operator) (ops.Operator, *ApplyResult, VisitRule, error) { err := visitor(op) if err != nil { return nil, SameTree, SkipChildren, err @@ -80,6 +111,8 @@ func BottomUp( return op, nil } +var DebugOperatorTree = false + // FixedPointBottomUp rewrites an operator tree much like BottomUp does, // but does the rewriting repeatedly, until a tree walk is done with no changes to the tree. func FixedPointBottomUp( @@ -88,15 +121,20 @@ func FixedPointBottomUp( visit VisitF, shouldVisit ShouldVisit, ) (op ops.Operator, err error) { - id := NewTree + var id *ApplyResult op = root - for id == NewTree { + // will loop while the rewriting changes anything + for ok := true; ok; ok = id != SameTree { + if DebugOperatorTree { + fmt.Println(ops.ToTree(op)) + } // Continue the top-down rewriting process as long as changes were made during the last traversal op, id, err = bottomUp(op, semantics.EmptyTableSet(), resolveID, visit, shouldVisit, true) if err != nil { return nil, err } } + return op, nil } @@ -141,29 +179,29 @@ func TopDown( } // Swap takes a tree like a->b->c and swaps `a` and `b`, so we end up with b->a->c -func Swap(a, b ops.Operator) (ops.Operator, error) { - c := b.Inputs() +func Swap(parent, child ops.Operator, message string) (ops.Operator, *ApplyResult, error) { + c := child.Inputs() if len(c) != 1 { - return nil, vterrors.VT13001("Swap can only be used on single input operators") + return nil, nil, vterrors.VT13001("Swap can only be used on single input operators") } - aInputs := slices.Clone(a.Inputs()) + aInputs := slices.Clone(parent.Inputs()) var tmp ops.Operator for i, in := range aInputs { - if in == b { + if in == child { tmp = aInputs[i] aInputs[i] = c[0] break } } if tmp == nil { - 
return nil, vterrors.VT13001("Swap can only be used when the second argument is an input to the first") + return nil, nil, vterrors.VT13001("Swap can only be used when the second argument is an input to the first") } - b.SetInputs([]ops.Operator{a}) - a.SetInputs(aInputs) + child.SetInputs([]ops.Operator{parent}) + parent.SetInputs(aInputs) - return b, nil + return child, NewTree(message, parent), nil } func bottomUp( @@ -173,13 +211,13 @@ func bottomUp( rewriter VisitF, shouldVisit ShouldVisit, isRoot bool, -) (ops.Operator, ApplyResult, error) { +) (ops.Operator, *ApplyResult, error) { if !shouldVisit(root) { return root, SameTree, nil } oldInputs := root.Inputs() - anythingChanged := false + var anythingChanged *ApplyResult newInputs := make([]ops.Operator, len(oldInputs)) childID := rootID @@ -198,54 +236,50 @@ func bottomUp( } in, changed, err := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false) if err != nil { - return nil, SameTree, err - } - if changed == NewTree { - anythingChanged = true + return nil, nil, err } + anythingChanged = anythingChanged.Merge(changed) newInputs[i] = in } - if anythingChanged { + if anythingChanged.Changed() { root = root.Clone(newInputs) } newOp, treeIdentity, err := rewriter(root, rootID, isRoot) if err != nil { - return nil, SameTree, err + return nil, nil, err } - if anythingChanged { - treeIdentity = NewTree - } - return newOp, treeIdentity, nil + anythingChanged = anythingChanged.Merge(treeIdentity) + return newOp, anythingChanged, nil } func breakableTopDown( in ops.Operator, - rewriter func(ops.Operator) (ops.Operator, ApplyResult, VisitRule, error), -) (ops.Operator, ApplyResult, error) { + rewriter func(ops.Operator) (ops.Operator, *ApplyResult, VisitRule, error), +) (ops.Operator, *ApplyResult, error) { newOp, identity, visit, err := rewriter(in) if err != nil || visit == SkipChildren { return newOp, identity, err } - anythingChanged := identity == NewTree + var anythingChanged *ApplyResult 
oldInputs := newOp.Inputs() newInputs := make([]ops.Operator, len(oldInputs)) for i, oldInput := range oldInputs { newInputs[i], identity, err = breakableTopDown(oldInput, rewriter) - anythingChanged = anythingChanged || identity == NewTree + anythingChanged = anythingChanged.Merge(identity) if err != nil { return nil, SameTree, err } } - if anythingChanged { - return newOp.Clone(newInputs), NewTree, nil + if anythingChanged.Changed() { + return newOp, SameTree, nil } - return newOp, SameTree, nil + return newOp.Clone(newInputs), anythingChanged, nil } // topDown is a helper function that recursively traverses the operator tree from the @@ -258,22 +292,21 @@ func topDown( rewriter VisitF, shouldVisit ShouldVisit, isRoot bool, -) (ops.Operator, ApplyResult, error) { - newOp, treeIdentity, err := rewriter(root, rootID, isRoot) +) (ops.Operator, *ApplyResult, error) { + newOp, anythingChanged, err := rewriter(root, rootID, isRoot) if err != nil { - return nil, false, err + return nil, nil, err } if !shouldVisit(root) { - return newOp, treeIdentity, nil + return newOp, anythingChanged, nil } - if treeIdentity == NewTree { + if anythingChanged.Changed() { root = newOp } oldInputs := root.Inputs() - anythingChanged := treeIdentity == NewTree newInputs := make([]ops.Operator, len(oldInputs)) childID := rootID @@ -285,16 +318,14 @@ func topDown( } in, changed, err := topDown(operator, childID, resolveID, rewriter, shouldVisit, false) if err != nil { - return nil, false, err - } - if changed == NewTree { - anythingChanged = true + return nil, nil, err } + anythingChanged = anythingChanged.Merge(changed) newInputs[i] = in } - if anythingChanged { - return root.Clone(newInputs), NewTree, nil + if anythingChanged != SameTree { + return root.Clone(newInputs), anythingChanged, nil } return root, SameTree, nil diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index eed7b31d8be..bcff3ffa652 100644 --- 
a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -18,6 +18,7 @@ package operators import ( "fmt" + "strings" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -25,7 +26,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -537,13 +537,12 @@ func createProjection(src ops.Operator) (*Projection, error) { return nil, err } for _, col := range cols { - proj.Columns = append(proj.Columns, Expr{E: col.Expr}) - proj.ColumnNames = append(proj.ColumnNames, col.As.String()) + proj.addUnexploredExpr(col, col.Expr) } return proj, nil } -func (r *Route) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (r *Route) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { removeKeyspaceFromSelectExpr(expr) // check if columns is already added. 
@@ -551,23 +550,58 @@ func (r *Route) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.Alia if err != nil { return nil, 0, err } - colAsExpr := func(e *sqlparser.AliasedExpr) sqlparser.Expr { return e.Expr } + colAsExpr := func(e *sqlparser.AliasedExpr) sqlparser.Expr { + return e.Expr + } if offset, found := canReuseColumn(ctx, cols, expr.Expr, colAsExpr); found { return r, offset, nil } - proj, exists := r.Source.(*Projection) - if !exists { - proj, err = createProjection(r.Source) - if err != nil { - return nil, 0, err + // if column is not already present, we check if we can easily find a projection + // or aggregation in our source that we can add to + if ok, offset := addColumnToInput(r.Source, expr, addToGroupBy); ok { + return r, offset, nil + } + + // If no-one could be found, we probably don't have one yet, so we add one here + src, err := createProjection(r.Source) + if err != nil { + return nil, 0, err + } + r.Source = src + + // And since we are under the route, we don't need to continue pushing anything further down + offset := src.addNoPushCol(expr, false) + if err != nil { + return nil, 0, err + } + return r, offset, nil +} + +type selectExpressions interface { + addNoPushCol(expr *sqlparser.AliasedExpr, addToGroupBy bool) int + isDerived() bool +} + +func addColumnToInput(operator ops.Operator, expr *sqlparser.AliasedExpr, addToGroupBy bool) (bool, int) { + switch op := operator.(type) { + case *CorrelatedSubQueryOp: + return addColumnToInput(op.Outer, expr, addToGroupBy) + case *Limit: + return addColumnToInput(op.Source, expr, addToGroupBy) + case *Ordering: + return addColumnToInput(op.Source, expr, addToGroupBy) + case selectExpressions: + if op.isDerived() { + // if the only thing we can push to is a derived table, + // we have to add a new projection and can't build on this one + return false, 0 } - r.Source = proj + offset := op.addNoPushCol(expr, addToGroupBy) + return true, offset + default: + return false, 0 } - // add the new 
column - proj.Columns = append(proj.Columns, Expr{E: expr.Expr}) - proj.ColumnNames = append(proj.ColumnNames, expr.As.String()) - return r, len(proj.Columns) - 1, nil } func (r *Route) GetColumns() ([]*sqlparser.AliasedExpr, error) { @@ -633,9 +667,9 @@ func (r *Route) planOffsets(ctx *plancontext.PlanningContext) (err error) { WOffset: -1, Direction: order.Inner.Direction, } - if ctx.SemTable.NeedsWeightString(order.WeightStrExpr) { - wrap := aeWrap(weightStringFor(order.WeightStrExpr)) - _, offset, err = r.AddColumn(ctx, wrap) + if ctx.SemTable.NeedsWeightString(order.SimplifiedExpr) { + wrap := aeWrap(weightStringFor(order.SimplifiedExpr)) + _, offset, err = r.AddColumn(ctx, wrap, true, false) if err != nil { return err } @@ -653,12 +687,12 @@ func weightStringFor(expr sqlparser.Expr) sqlparser.Expr { func (r *Route) getOffsetFor(ctx *plancontext.PlanningContext, order ops.OrderBy, columns []*sqlparser.AliasedExpr) (int, error) { for idx, column := range columns { - if sqlparser.Equals.Expr(order.WeightStrExpr, column.Expr) { + if sqlparser.Equals.Expr(order.SimplifiedExpr, column.Expr) { return idx, nil } } - _, offset, err := r.AddColumn(ctx, aeWrap(order.Inner.Expr)) + _, offset, err := r.AddColumn(ctx, aeWrap(order.Inner.Expr), true, false) if err != nil { return 0, err } @@ -676,9 +710,30 @@ func (r *Route) Description() ops.OpDescription { } func (r *Route) ShortDescription() string { + first := r.Routing.OpCode().String() + ks := r.Routing.Keyspace() - if ks == nil { - return r.Routing.OpCode().String() + if ks != nil { + first = fmt.Sprintf("%s on %s", r.Routing.OpCode().String(), ks.Name) + } + + orderBy, err := r.Source.GetOrdering() + if err != nil { + return first + } + + ordering := "" + if len(orderBy) > 0 { + var oo []string + for _, o := range orderBy { + oo = append(oo, sqlparser.String(o.Inner)) + } + ordering = " order by " + strings.Join(oo, ",") } - return fmt.Sprintf("%s on %s", r.Routing.OpCode().String(), ks.Name) + + return first + 
ordering +} + +func (r *Route) setTruncateColumnCount(offset int) { + r.ResultColumns = offset } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index bf0ea8f2a30..b99b0d55d71 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -49,18 +49,18 @@ type ( // Here we try to merge query parts into the same route primitives. At the end of this process, // all the operators in the tree are guaranteed to be PhysicalOperators func transformToPhysical(ctx *plancontext.PlanningContext, in ops.Operator) (ops.Operator, error) { - op, err := rewrite.BottomUpAll(in, TableID, func(operator ops.Operator, ts semantics.TableSet, _ bool) (ops.Operator, rewrite.ApplyResult, error) { + op, err := rewrite.BottomUpAll(in, TableID, func(operator ops.Operator, ts semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { switch op := operator.(type) { case *QueryGraph: return optimizeQueryGraph(ctx, op) case *Join: return optimizeJoin(ctx, op) case *Derived: - return optimizeDerived(ctx, op) + return pushDownDerived(ctx, op) case *SubQuery: return optimizeSubQuery(ctx, op, ts) case *Filter: - return optimizeFilter(op) + return pushDownFilter(op) default: return operator, rewrite.SameTree, nil } @@ -73,18 +73,15 @@ func transformToPhysical(ctx *plancontext.PlanningContext, in ops.Operator) (ops return compact(ctx, op) } -func optimizeFilter(op *Filter) (ops.Operator, rewrite.ApplyResult, error) { - if route, ok := op.Source.(*Route); ok { - // let's push the filter into the route - op.Source = route.Source - route.Source = op - return route, rewrite.NewTree, nil +func pushDownFilter(op *Filter) (ops.Operator, *rewrite.ApplyResult, error) { + if _, ok := op.Source.(*Route); ok { + return rewrite.Swap(op, op.Source, "push filter into Route") } return op, rewrite.SameTree, nil } -func optimizeDerived(ctx *plancontext.PlanningContext, 
op *Derived) (ops.Operator, rewrite.ApplyResult, error) { +func pushDownDerived(ctx *plancontext.PlanningContext, op *Derived) (ops.Operator, *rewrite.ApplyResult, error) { innerRoute, ok := op.Source.(*Route) if !ok { return op, rewrite.SameTree, nil @@ -95,22 +92,15 @@ func optimizeDerived(ctx *plancontext.PlanningContext, op *Derived) (ops.Operato return op, rewrite.SameTree, nil } - op.Source = innerRoute.Source - innerRoute.Source = op - - return innerRoute, rewrite.NewTree, nil + return rewrite.Swap(op, op.Source, "push derived under route") } -func optimizeJoin(ctx *plancontext.PlanningContext, op *Join) (ops.Operator, rewrite.ApplyResult, error) { - join, err := mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), !op.LeftJoin) - if err != nil { - return nil, false, err - } - return join, rewrite.NewTree, nil +func optimizeJoin(ctx *plancontext.PlanningContext, op *Join) (ops.Operator, *rewrite.ApplyResult, error) { + return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), !op.LeftJoin) } -func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (result ops.Operator, changed rewrite.ApplyResult, err error) { - changed = rewrite.NewTree +func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (result ops.Operator, changed *rewrite.ApplyResult, err error) { + switch { case ctx.PlannerVersion == querypb.ExecuteOptions_Gen4Left2Right: result, err = leftToRightSolve(ctx, op) @@ -125,6 +115,7 @@ func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (resul result = newFilter(result, ctx.SemTable.AndExpressions(unresolved...)) } + changed = rewrite.NewTree("solved query graph", result) return } @@ -252,7 +243,7 @@ func leftToRightSolve(ctx *plancontext.PlanningContext, qg *QueryGraph) (ops.Ope continue } joinPredicates := qg.GetPredicates(TableID(acc), TableID(plan)) - acc, err = mergeOrJoin(ctx, acc, plan, joinPredicates, true) + acc, _, err = 
mergeOrJoin(ctx, acc, plan, joinPredicates, true) if err != nil { return nil, err } @@ -386,7 +377,7 @@ func getJoinFor(ctx *plancontext.PlanningContext, cm opCacheMap, lhs, rhs ops.Op return cachedPlan, nil } - join, err := mergeOrJoin(ctx, lhs, rhs, joinPredicates, true) + join, _, err := mergeOrJoin(ctx, lhs, rhs, joinPredicates, true) if err != nil { return nil, err } @@ -413,30 +404,38 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op ops.Operator) b return required } -func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPredicates []sqlparser.Expr, inner bool) (ops.Operator, error) { +func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPredicates []sqlparser.Expr, inner bool) (ops.Operator, *rewrite.ApplyResult, error) { newPlan, err := Merge(ctx, lhs, rhs, joinPredicates, newJoinMerge(ctx, joinPredicates, inner)) if err != nil { - return nil, err + return nil, nil, err } if newPlan != nil { - return newPlan, nil + return newPlan, rewrite.NewTree("merge routes into single operator", newPlan), nil } if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) { if !inner { - return nil, vterrors.VT12001("LEFT JOIN with derived tables") + return nil, nil, vterrors.VT12001("LEFT JOIN with derived tables") } if requiresSwitchingSides(ctx, lhs) { - return nil, vterrors.VT12001("JOIN between derived tables") + return nil, nil, vterrors.VT12001("JOIN between derived tables") } join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner) - return pushJoinPredicates(ctx, joinPredicates, join) + newOp, err := pushJoinPredicates(ctx, joinPredicates, join) + if err != nil { + return nil, nil, err + } + return newOp, rewrite.NewTree("merge routes, but switch sides", newOp), nil } join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner) - return pushJoinPredicates(ctx, joinPredicates, join) + newOp, err := pushJoinPredicates(ctx, joinPredicates, join) + if err != nil { + return nil, nil, err + } + return 
newOp, rewrite.NewTree("logical join to applyJoin ", newOp), nil } func operatorsToRoutes(a, b ops.Operator) (*Route, *Route) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 26f6fae23a5..28dd005e2f1 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -27,7 +27,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -func optimizeSubQuery(ctx *plancontext.PlanningContext, op *SubQuery, ts semantics.TableSet) (ops.Operator, rewrite.ApplyResult, error) { +func optimizeSubQuery(ctx *plancontext.PlanningContext, op *SubQuery, ts semantics.TableSet) (ops.Operator, *rewrite.ApplyResult, error) { var unmerged []*SubQueryOp // first loop over the subqueries and try to merge them into the outer plan @@ -44,7 +44,7 @@ func optimizeSubQuery(ctx *plancontext.PlanningContext, op *SubQuery, ts semanti } merged, err := tryMergeSubQueryOp(ctx, outer, innerOp, newInner, preds, newSubQueryMerge(ctx, newInner), ts) if err != nil { - return nil, false, err + return nil, nil, err } if merged != nil { @@ -65,20 +65,20 @@ func optimizeSubQuery(ctx *plancontext.PlanningContext, op *SubQuery, ts semanti if inner.ExtractedSubquery.OpCode == int(popcode.PulloutExists) { correlatedTree, err := createCorrelatedSubqueryOp(ctx, innerOp, outer, preds, inner.ExtractedSubquery) if err != nil { - return nil, false, err + return nil, nil, err } outer = correlatedTree continue } - return nil, false, vterrors.VT12001("cross-shard correlated subquery") + return nil, nil, vterrors.VT12001("cross-shard correlated subquery") } for _, tree := range unmerged { tree.Outer = outer outer = tree } - return outer, rewrite.NewTree, nil + return outer, rewrite.NewTree("merged subqueries", outer), nil } func unresolvedAndSource(ctx *plancontext.PlanningContext, op ops.Operator) ([]sqlparser.Expr, ops.Operator) { @@ -94,65 +94,6 @@ func 
unresolvedAndSource(ctx *plancontext.PlanningContext, op ops.Operator) ([]s return preds, op } -func mergeSubQueryOp(ctx *plancontext.PlanningContext, outer *Route, inner *Route, subq *SubQueryInner, mergedRouting Routing) (*Route, error) { - subq.ExtractedSubquery.Merged = true - - switch outerRouting := outer.Routing.(type) { - case *ShardedRouting: - return mergeSubQueryFromTableRouting(ctx, outer, inner, outerRouting, subq) - default: - outer.Routing = mergedRouting - } - - outer.MergedWith = append(outer.MergedWith, inner) - - return outer, nil -} - -func mergeSubQueryFromTableRouting( - ctx *plancontext.PlanningContext, - outer, inner *Route, - outerRouting *ShardedRouting, - subq *SubQueryInner, -) (*Route, error) { - // When merging an inner query with its outer query, we can remove the - // inner query from the list of predicates that can influence routing of - // the outer query. - // - // Note that not all inner queries necessarily are part of the routing - // predicates list, so this might be a no-op. - subQueryWasPredicate := false - for i, predicate := range outerRouting.SeenPredicates { - if ctx.SemTable.EqualsExpr(predicate, subq.ExtractedSubquery) { - outerRouting.SeenPredicates = append(outerRouting.SeenPredicates[:i], outerRouting.SeenPredicates[i+1:]...) - - subQueryWasPredicate = true - - // The `ExtractedSubquery` of an inner query is unique (due to the uniqueness of bind variable names) - // so we can stop after the first match. - break - } - } - - err := outerRouting.resetRoutingSelections(ctx) - if err != nil { - return nil, err - } - - if subQueryWasPredicate { - if innerTR, isTR := inner.Routing.(*ShardedRouting); isTR { - // Copy Vindex predicates from the inner route to the upper route. - // If we can route based on some of these predicates, the routing can improve - outerRouting.VindexPreds = append(outerRouting.VindexPreds, innerTR.VindexPreds...) 
- } - - if inner.Routing.OpCode() == engine.None { - outer.Routing = &NoneRouting{keyspace: outerRouting.keyspace} - } - } - return outer, nil -} - func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.SelectStatement, op ops.Operator) bool { validVindex := func(expr sqlparser.Expr) bool { sc := findColumnVindex(ctx, op, expr) @@ -353,7 +294,9 @@ func rewriteColumnsInSubqueryOpForJoin( // get the bindVariable for that column name and replace it in the subquery typ, _, _ := ctx.SemTable.TypeForExpr(node) - bindVar := ctx.ReservedVars.ReserveColName(node) + bindVar := ctx.GetArgumentFor(node, func() string { + return ctx.ReservedVars.ReserveColName(node) + }) cursor.Replace(sqlparser.NewTypedArgument(bindVar, typ)) // check whether the bindVariable already exists in the joinVars of the other tree _, alreadyExists := outerTree.Vars[bindVar] @@ -361,7 +304,7 @@ func rewriteColumnsInSubqueryOpForJoin( return true } // if it does not exist, then push this as an output column there and add it to the joinVars - newInnerOp, offset, err := resultInnerOp.AddColumn(ctx, aeWrap(node)) + newInnerOp, offset, err := resultInnerOp.AddColumn(ctx, aeWrap(node), true, false) if err != nil { rewriteError = err return false @@ -413,7 +356,7 @@ func createCorrelatedSubqueryOp( // we do so by checking that the column names are the same and their recursive dependencies are the same // so the column names `user.a` and `a` would be considered equal as long as both are bound to the same table for colName, bindVar := range bindVars { - if ctx.SemTable.EqualsExpr(node, colName) { + if ctx.SemTable.EqualsExprWithDeps(node, colName) { cursor.Replace(sqlparser.NewArgument(bindVar)) return true } @@ -427,7 +370,7 @@ func createCorrelatedSubqueryOp( bindVars[node] = bindVar // if it does not exist, then push this as an output column in the outerOp and add it to the joinVars - newOuterOp, offset, err := resultOuterOp.AddColumn(ctx, aeWrap(node)) + newOuterOp, offset, err := 
resultOuterOp.AddColumn(ctx, aeWrap(node), true, false) if err != nil { rewriteError = err return true diff --git a/go/vt/vtgate/planbuilder/operators/table.go b/go/vt/vtgate/planbuilder/operators/table.go index 30ae90735a4..92735a055cd 100644 --- a/go/vt/vtgate/planbuilder/operators/table.go +++ b/go/vt/vtgate/planbuilder/operators/table.go @@ -63,7 +63,10 @@ func (to *Table) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser.Exp return newFilter(to, expr), nil } -func (to *Table) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (to *Table) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { + if addToGroupBy { + return nil, 0, vterrors.VT13001("tried to add group by to a table") + } offset, err := addColumn(ctx, to, expr.Expr) if err != nil { return nil, 0, err @@ -97,7 +100,7 @@ func (to *Table) TablesUsed() []string { func addColumn(ctx *plancontext.PlanningContext, op ColNameColumns, e sqlparser.Expr) (int, error) { col, ok := e.(*sqlparser.ColName) if !ok { - return 0, vterrors.VT13001("cannot push this expression to a table/vindex") + return 0, vterrors.VT13001("cannot push this expression to a table/vindex: %s", sqlparser.String(e)) } sqlparser.RemoveKeyspaceFromColName(col) cols := op.GetColNames() diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index 5dd7f0720df..4fff00ef819 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -153,9 +153,9 @@ func (u *Union) GetSelectFor(source int) (*sqlparser.Select, error) { } } -func (u *Union) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.ApplyResult, error) { +func (u *Union) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { var newSources []ops.Operator - anythingChanged := false + var anythingChanged 
*rewrite.ApplyResult for _, source := range u.Sources { var other *Union horizon, ok := source.(*Horizon) @@ -169,7 +169,7 @@ func (u *Union) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.App newSources = append(newSources, source) continue } - anythingChanged = true + anythingChanged = anythingChanged.Merge(rewrite.NewTree("merged UNIONs", other)) switch { case len(other.Ordering) == 0 && !other.Distinct: fallthrough @@ -181,15 +181,11 @@ func (u *Union) Compact(*plancontext.PlanningContext) (ops.Operator, rewrite.App newSources = append(newSources, other) } } - if anythingChanged { + if anythingChanged != rewrite.SameTree { u.Sources = newSources } - identity := rewrite.SameTree - if anythingChanged { - identity = rewrite.NewTree - } - return u, identity, nil + return u, anythingChanged, nil } func (u *Union) NoLHSTableSet() {} diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 3d8fa523188..f9c831860f1 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -43,7 +43,7 @@ func (u *Update) Introduces() semantics.TableSet { } // Clone implements the Operator interface -func (u *Update) Clone(inputs []ops.Operator) ops.Operator { +func (u *Update) Clone([]ops.Operator) ops.Operator { return &Update{ QTable: u.QTable, VTable: u.VTable, diff --git a/go/vt/vtgate/planbuilder/operators/vindex.go b/go/vt/vtgate/planbuilder/operators/vindex.go index d2eb3fdf98d..79104fc7364 100644 --- a/go/vt/vtgate/planbuilder/operators/vindex.go +++ b/go/vt/vtgate/planbuilder/operators/vindex.go @@ -62,7 +62,11 @@ func (v *Vindex) Clone([]ops.Operator) ops.Operator { return &clone } -func (v *Vindex) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr) (ops.Operator, int, error) { +func (v *Vindex) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { + if addToGroupBy 
{ + return nil, 0, vterrors.VT13001("tried to add group by to a table") + } + offset, err := addColumn(ctx, v, expr.Expr) if err != nil { return nil, 0, err diff --git a/go/vt/vtgate/planbuilder/ordering.go b/go/vt/vtgate/planbuilder/ordering.go index 5abf2823e9e..2a8613620e7 100644 --- a/go/vt/vtgate/planbuilder/ordering.go +++ b/go/vt/vtgate/planbuilder/ordering.go @@ -103,17 +103,17 @@ func planOAOrdering(pb *primitiveBuilder, orderBy v3OrderBy, oa *orderedAggregat case *sqlparser.CastExpr: col, ok := expr.Expr.(*sqlparser.ColName) if !ok { - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %s", sqlparser.String(expr))) + return nil, complexOrderBy(sqlparser.String(expr)) } orderByCol = col.Metadata.(*column) case *sqlparser.ConvertExpr: col, ok := expr.Expr.(*sqlparser.ColName) if !ok { - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %s", sqlparser.String(expr))) + return nil, complexOrderBy(sqlparser.String(expr)) } orderByCol = col.Metadata.(*column) default: - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %v", sqlparser.String(expr))) + return nil, complexOrderBy(sqlparser.String(expr)) } // Match orderByCol against the group by columns. 
@@ -294,7 +294,7 @@ func planRouteOrdering(orderBy v3OrderBy, node *route) (logicalPlan, error) { case *sqlparser.UnaryExpr: col, ok := expr.Expr.(*sqlparser.ColName) if !ok { - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %s", sqlparser.String(expr))) + return nil, complexOrderBy(sqlparser.String(expr)) } c := col.Metadata.(*column) for i, rc := range node.resultColumns { @@ -304,7 +304,7 @@ func planRouteOrdering(orderBy v3OrderBy, node *route) (logicalPlan, error) { } } default: - return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %s", sqlparser.String(expr))) + return nil, complexOrderBy(sqlparser.String(expr)) } // If column is not found, then the order by is referencing // a column that's not on the select list. @@ -350,3 +350,7 @@ func planRouteOrdering(orderBy v3OrderBy, node *route) (logicalPlan, error) { } return newMergeSort(node), nil } + +func complexOrderBy(s string) error { + return vterrors.VT12001(fmt.Sprintf("in scatter query: complex ORDER BY expression: %s", s)) +} diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 8259bc59cbf..cf6c5b944a8 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -39,6 +39,7 @@ import ( "vitess.io/vitess/go/test/utils" vschemapb "vitess.io/vitess/go/vt/proto/vschema" + oprewriters "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/mysql/collations" @@ -289,6 +290,7 @@ func TestViews(t *testing.T) { } func TestOne(t *testing.T) { + oprewriters.DebugOperatorTree = true vschema := &vschemaWrapper{ v: loadSchema(t, "vschemas/schema.json", true), } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 653691a4697..e8045ce0b04 100644 --- 
a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -35,21 +35,26 @@ type PlanningContext struct { SkipPredicates map[sqlparser.Expr]any PlannerVersion querypb.ExecuteOptions_PlannerVersion RewriteDerivedExpr bool + + // If we during planning have turned this expression into an argument name, + // we can continue using the same argument name + ReservedArguments map[sqlparser.Expr]string } func NewPlanningContext(reservedVars *sqlparser.ReservedVars, semTable *semantics.SemTable, vschema VSchema, version querypb.ExecuteOptions_PlannerVersion) *PlanningContext { ctx := &PlanningContext{ - ReservedVars: reservedVars, - SemTable: semTable, - VSchema: vschema, - JoinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, - SkipPredicates: map[sqlparser.Expr]any{}, - PlannerVersion: version, + ReservedVars: reservedVars, + SemTable: semTable, + VSchema: vschema, + JoinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + SkipPredicates: map[sqlparser.Expr]any{}, + PlannerVersion: version, + ReservedArguments: map[sqlparser.Expr]string{}, } return ctx } -func (c PlanningContext) IsSubQueryToReplace(e sqlparser.Expr) bool { +func (c *PlanningContext) IsSubQueryToReplace(e sqlparser.Expr) bool { ext, ok := e.(*sqlparser.Subquery) if !ok { return false @@ -61,3 +66,14 @@ func (c PlanningContext) IsSubQueryToReplace(e sqlparser.Expr) bool { } return false } + +func (ctx *PlanningContext) GetArgumentFor(expr sqlparser.Expr, f func() string) string { + for key, name := range ctx.ReservedArguments { + if ctx.SemTable.EqualsExpr(key, expr) { + return name + } + } + bvName := f() + ctx.ReservedArguments[expr] = bvName + return bvName +} diff --git a/go/vt/vtgate/planbuilder/projection.go b/go/vt/vtgate/planbuilder/projection.go index a3274d3d0e4..2444b916d48 100644 --- a/go/vt/vtgate/planbuilder/projection.go +++ b/go/vt/vtgate/planbuilder/projection.go @@ -42,6 +42,13 @@ var _ logicalPlan = (*projection)(nil) // 
WireupGen4 implements the logicalPlan interface func (p *projection) WireupGen4(ctx *plancontext.PlanningContext) error { + if p.primitive != nil { + // if primitive is not nil, it means that the horizon planning in the operator phase already + // created all the needed evalengine expressions. + // we don't need to do anything here, let's just shortcut out of this call + return p.source.WireupGen4(ctx) + } + columns := make([]evalengine.Expr, 0, len(p.columns)) for _, expr := range p.columns { convert, err := evalengine.Translate(expr, &evalengine.Config{ diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 6573a46d495..78dda15aa76 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -1,4 +1,293 @@ [ + { + "comment": "count(*) spread across join", + "query": "select count(*) from user join user_extra on user.foo = user_extra.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user join user_extra on user.foo = user_extra.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + 
"Sharded": true + }, + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra where user_extra.bar = :user_foo group by .0", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "sum spread across join", + "query": "select sum(user.col) from user join user_extra on user.foo = user_extra.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select sum(user.col) from user join user_extra on user.foo = user_extra.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS sum(`user`.col)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as sum(`user`.col)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:1,R:1", + "JoinVars": { + "user_foo": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.foo, sum(`user`.col), weight_string(`user`.foo) from `user` where 1 != 1 group by `user`.foo, weight_string(`user`.foo)", + "Query": "select `user`.foo, sum(`user`.col), weight_string(`user`.foo) from `user` group by `user`.foo, weight_string(`user`.foo)", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1", + "Query": "select 1, count(*) from user_extra where user_extra.bar = :user_foo group by 1", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "count spread across join", + "query": "select count(user.col) 
from user join user_extra on user.foo = user_extra.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select count(user.col) from user join user_extra on user.foo = user_extra.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count(0) AS count(`user`.col)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(`user`.col)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(`user`.col), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(`user`.col), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra where user_extra.bar = :user_foo group by .0", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "max spread across join", + "query": "select max(user.col) from user join user_extra on user.foo = user_extra.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select max(user.col) from user join user_extra on user.foo = user_extra.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "max(0) AS max(`user`.col)", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0", + "JoinVars": { + 
"user_foo": 1 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select max(`user`.col), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select max(`user`.col), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", + "Query": "select 1 from user_extra where user_extra.bar = :user_foo group by .0", + "Table": "user_extra" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "min spread across join RHS", + "query": "select min(user_extra.col) from user join user_extra on user.foo = user_extra.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select min(user_extra.col) from user join user_extra on user.foo = user_extra.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "min(0) AS min(user_extra.col)", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_foo": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select min(user_extra.col) from user_extra where 1 != 1 group by .0", + "Query": "select min(user_extra.col) from user_extra where user_extra.bar = :user_foo group 
by .0", + "Table": "user_extra" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "group by a unique vindex should revert to simple route, and having clause should find the correct symbols.", "query": "select id, count(*) c from user group by id having max(col) > 10", @@ -714,15 +1003,15 @@ }, { "comment": "scatter aggregate multiple group by (columns)", - "query": "select a, b, count(*) from user group by b, a", + "query": "select a, b, count(*) from user group by a, b", "v3-plan": { "QueryType": "SELECT", - "Original": "select a, b, count(*) from user group by b, a", + "Original": "select a, b, count(*) from user group by a, b", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count(2) AS count", - "GroupBy": "1, 0", + "GroupBy": "0, 1", "Inputs": [ { "OperatorType": "Route", @@ -731,9 +1020,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*), weight_string(b), weight_string(a) from `user` where 1 != 1 group by b, a, weight_string(b), weight_string(a)", - "OrderBy": "(1|3) ASC, (0|4) ASC", - "Query": "select a, b, count(*), weight_string(b), weight_string(a) from `user` group by b, a, weight_string(b), weight_string(a) order by b asc, a asc", + "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, b, weight_string(a), weight_string(b)", + "OrderBy": "(0|3) ASC, (1|4) ASC", + "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, b, weight_string(a), weight_string(b) order by a asc, b asc", "ResultColumns": 3, "Table": "`user`" } @@ -742,7 +1031,7 @@ }, "gen4-plan": { "QueryType": "SELECT", - "Original": "select a, b, count(*) from user group by b, a", + "Original": "select a, b, count(*) from user group by a, b", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -757,9 +1046,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": 
"select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b)", + "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, b, weight_string(a), weight_string(b)", "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a), b, weight_string(b) order by a asc, b asc", + "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, b, weight_string(a), weight_string(b) order by a asc, b asc", "Table": "`user`" } ] @@ -804,7 +1093,7 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(2) AS count(*)", - "GroupBy": "(0|3), (1|4)", + "GroupBy": "(1|3), (0|4)", "ResultColumns": 3, "Inputs": [ { @@ -814,9 +1103,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b)", - "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a), b, weight_string(b) order by a asc, b asc", + "FieldQuery": "select a, b, count(*), weight_string(b), weight_string(a) from `user` where 1 != 1 group by b, a, weight_string(b), weight_string(a)", + "OrderBy": "(1|3) ASC, (0|4) ASC", + "Query": "select a, b, count(*), weight_string(b), weight_string(a) from `user` group by b, a, weight_string(b), weight_string(a) order by b asc, a asc", "Table": "`user`" } ] @@ -861,7 +1150,7 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(2) AS count(*)", - "GroupBy": "(0|3), (1|4)", + "GroupBy": "(1|3), (0|4)", "ResultColumns": 3, "Inputs": [ { @@ -871,9 +1160,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 
!= 1 group by a, weight_string(a), b, weight_string(b)", - "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a), b, weight_string(b) order by a asc, b asc", + "FieldQuery": "select a, b, count(*), weight_string(b), weight_string(a) from `user` where 1 != 1 group by b, a, weight_string(b), weight_string(a)", + "OrderBy": "(1|3) ASC, (0|4) ASC", + "Query": "select a, b, count(*), weight_string(b), weight_string(a) from `user` group by b, a, weight_string(b), weight_string(a) order by b asc, a asc", "Table": "`user`" } ] @@ -1038,9 +1327,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b), c, weight_string(c)", + "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, b, c, weight_string(a), weight_string(b), weight_string(c)", "OrderBy": "(0|5) ASC, (1|6) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, weight_string(a), b, weight_string(b), c, weight_string(c) order by a asc, b asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by a asc, b asc, c asc", "Table": "`user`" } ] @@ -1095,9 +1384,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b), c, weight_string(c)", + "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, b, c, weight_string(a), weight_string(b), 
weight_string(c)", "OrderBy": "(0|5) ASC, (1|6) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, weight_string(a), b, weight_string(b), c, weight_string(c) order by a asc, b asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by a asc, b asc, c asc", "Table": "`user`" } ] @@ -1142,7 +1431,7 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(4) AS count(*)", - "GroupBy": "(3|8), (1|6), (0|5), (2|7)", + "GroupBy": "(3|5), (1|6), (0|7), (2|8)", "ResultColumns": 5, "Inputs": [ { @@ -1152,9 +1441,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c), weight_string(d) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b), c, weight_string(c), d, weight_string(d)", - "OrderBy": "(3|8) ASC, (1|6) ASC, (0|5) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c), weight_string(d) from `user` group by a, weight_string(a), b, weight_string(b), c, weight_string(c), d, weight_string(d) order by d asc, b asc, a asc, c asc", + "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by a, b, c, d, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", + "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by a, b, c, d, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", "Table": "`user`" } ] @@ -1199,7 +1488,7 @@ "OperatorType": "Aggregate", "Variant": "Ordered", 
"Aggregates": "sum_count_star(4) AS count(*)", - "GroupBy": "(3|8), (1|6), (0|5), (2|7)", + "GroupBy": "(3|5), (1|6), (0|7), (2|8)", "ResultColumns": 5, "Inputs": [ { @@ -1209,9 +1498,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c), weight_string(d) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b), c, weight_string(c), d, weight_string(d)", - "OrderBy": "(3|8) ASC, (1|6) ASC, (0|5) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c), weight_string(d) from `user` group by a, weight_string(a), b, weight_string(b), c, weight_string(c), d, weight_string(d) order by d asc, b asc, a asc, c asc", + "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by c, b, a, d, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", + "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by c, b, a, d, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", "Table": "`user`" } ] @@ -1256,7 +1545,7 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(3) AS count(*)", - "GroupBy": "(0|4), (2|6), (1|5)", + "GroupBy": "(0|4), (2|5), (1|6)", "ResultColumns": 4, "Inputs": [ { @@ -1266,9 +1555,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, c, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, weight_string(a), b, weight_string(b), c, weight_string(c)", - "OrderBy": "(0|4) DESC, (2|6) DESC, (1|5) ASC", - "Query": "select a, b, c, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, 
weight_string(a), b, weight_string(b), c, weight_string(c) order by a desc, c desc, b asc", + "FieldQuery": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` where 1 != 1 group by c, b, a, weight_string(a), weight_string(c), weight_string(b)", + "OrderBy": "(0|4) DESC, (2|5) DESC, (1|6) ASC", + "Query": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` group by c, b, a, weight_string(a), weight_string(c), weight_string(b) order by a desc, c desc, b asc", "Table": "`user`" } ] @@ -1339,7 +1628,7 @@ }, "FieldQuery": "select col, count(*) from `user` where 1 != 1 group by col", "OrderBy": "0 ASC", - "Query": "select col, count(*) from `user` group by col order by col asc limit :__upper_limit", + "Query": "select col, count(*) from `user` group by col order by col asc", "Table": "`user`" } ] @@ -1881,42 +2170,33 @@ "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as a", - "[COLUMN 1]" - ], + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.a, weight_string(`user`.a) from `user` where 1 != 1 group by `user`.a, weight_string(`user`.a)", - "OrderBy": "(0|1) ASC", - "Query": "select `user`.a, weight_string(`user`.a) from `user` group by `user`.a, weight_string(`user`.a) order by `user`.a asc", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra", - "Table": "user_extra" - } - ] + "OperatorType": "Route", + "Variant": 
"Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.a, weight_string(`user`.a) from `user` where 1 != 1 group by `user`.a, weight_string(`user`.a)", + "OrderBy": "(0|1) ASC", + "Query": "select `user`.a, weight_string(`user`.a) from `user` group by `user`.a, weight_string(`user`.a) order by `user`.a asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", + "Query": "select 1 from user_extra group by .0", + "Table": "user_extra" } ] } @@ -2926,13 +3206,13 @@ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)" + "[COLUMN 0] * [COLUMN 1] as count(*)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:1", + "JoinColumnIndexes": "L:0,R:0", "TableName": "`user`_user_extra", "Inputs": [ { @@ -2953,8 +3233,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1", - "Query": "select 1, count(*) from user_extra group by 1", + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra group by .0", "Table": "user_extra" } ] @@ -3044,15 +3324,15 @@ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 0] as a", - "[COLUMN 2] * COALESCE([COLUMN 3], INT64(1)) as count(*)", - "[COLUMN 1]" + "[COLUMN 2] as a", + "[COLUMN 0] * [COLUMN 1] as count(*)", + "[COLUMN 3] as weight_string(`user`.a)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:1,L:2,L:0,R:1", + "JoinColumnIndexes": "L:0,R:0,L:1,L:2", "TableName": "`user`_user_extra", "Inputs": [ { @@ -3074,8 +3354,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1", - "Query": "select 1, count(*) from user_extra group 
by 1", + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra group by .0", "Table": "user_extra" } ] @@ -3107,15 +3387,15 @@ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 0] as a", - "[COLUMN 2] * COALESCE([COLUMN 3], INT64(1)) as count(user_extra.a)", - "[COLUMN 1]" + "[COLUMN 2] as a", + "[COLUMN 0] * [COLUMN 1] as count(user_extra.a)", + "[COLUMN 3] as weight_string(`user`.a)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:1,L:2,L:0,R:1", + "JoinColumnIndexes": "R:0,L:0,L:1,L:2", "TableName": "`user`_user_extra", "Inputs": [ { @@ -3137,8 +3417,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1, count(user_extra.a) from user_extra where 1 != 1 group by 1", - "Query": "select 1, count(user_extra.a) from user_extra group by 1", + "FieldQuery": "select count(user_extra.a) from user_extra where 1 != 1 group by .0", + "Query": "select count(user_extra.a) from user_extra group by .0", "Table": "user_extra" } ] @@ -3170,23 +3450,23 @@ { "OperatorType": "Projection", "Expressions": [ - "([COLUMN 2] * COALESCE([COLUMN 3], INT64(1))) * COALESCE([COLUMN 4], INT64(1)) as count(u.textcol1)", - "([COLUMN 5] * COALESCE([COLUMN 6], INT64(1))) * COALESCE([COLUMN 7], INT64(1)) as count(ue.foo)", - "[COLUMN 0] as bar", - "[COLUMN 1]" + "[COLUMN 0] * [COLUMN 1] as count(u.textcol1)", + "[COLUMN 2] * [COLUMN 3] as count(ue.foo)", + "[COLUMN 4] as bar", + "[COLUMN 5] as weight_string(us.bar)" ], "Inputs": [ { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "(0|1) ASC", + "OrderBy": "(4|5) ASC", "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:0,R:1,L:1,R:2,R:3,L:2,R:4,R:5", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1,R:2,R:3", "JoinVars": { - "u_foo": 0 + "u_foo": 2 }, "TableName": "`user`_user_extra_unsharded", "Inputs": [ @@ -3197,40 +3477,51 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select u.foo, 
count(u.textcol1), count(*), weight_string(u.foo) from `user` as u where 1 != 1 group by u.foo, weight_string(u.foo)", - "Query": "select u.foo, count(u.textcol1), count(*), weight_string(u.foo) from `user` as u group by u.foo, weight_string(u.foo)", + "FieldQuery": "select count(u.textcol1), count(*), u.foo from `user` as u where 1 != 1 group by u.foo", + "Query": "select count(u.textcol1), count(*), u.foo from `user` as u group by u.foo", "Table": "`user`" }, { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "R:1,R:2,L:1,R:0,L:2,R:0", - "JoinVars": { - "ue_bar": 0 - }, - "TableName": "user_extra_unsharded", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)", + "[COLUMN 2] * [COLUMN 1] as count(ue.foo)", + "[COLUMN 3] as bar", + "[COLUMN 4] as weight_string(us.bar)" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select ue.bar, count(*), count(ue.foo), weight_string(ue.bar) from user_extra as ue where 1 != 1 group by ue.bar, weight_string(ue.bar)", - "Query": "select ue.bar, count(*), count(ue.foo), weight_string(ue.bar) from user_extra as ue where ue.bar = :u_foo group by ue.bar, weight_string(ue.bar)", - "Table": "user_extra" - }, - { - "OperatorType": "Route", - "Variant": "Unsharded", - "Keyspace": { - "Name": "main", - "Sharded": false + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,R:2", + "JoinVars": { + "ue_bar": 2 }, - "FieldQuery": "select count(*), us.bar, weight_string(us.bar) from unsharded as us where 1 != 1 group by us.bar, weight_string(us.bar)", - "Query": "select count(*), us.bar, weight_string(us.bar) from unsharded as us where us.baz = :ue_bar group by us.bar, weight_string(us.bar)", - "Table": "unsharded" + "TableName": "user_extra_unsharded", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + 
"Sharded": true + }, + "FieldQuery": "select count(*), count(ue.foo), ue.bar from user_extra as ue where 1 != 1 group by ue.bar", + "Query": "select count(*), count(ue.foo), ue.bar from user_extra as ue where ue.bar = :u_foo group by ue.bar", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select count(*), us.bar, weight_string(us.bar) from unsharded as us where 1 != 1 group by us.bar, weight_string(us.bar)", + "Query": "select count(*), us.bar, weight_string(us.bar) from unsharded as us where us.baz = :ue_bar group by us.bar, weight_string(us.bar)", + "Table": "unsharded" + } + ] } ] } @@ -3656,48 +3947,36 @@ "ResultColumns": 4, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as col", - "[COLUMN 3] as min(user_extra.foo)", - "[COLUMN 1] as bar", - "[COLUMN 4] as max(user_extra.bar)", - "[COLUMN 2]" - ], + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,R:1,L:0,L:1,L:2", + "JoinVars": { + "user_col": 0 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,L:2,R:1,R:2", - "JoinVars": { - "user_col": 0 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col, `user`.bar, weight_string(`user`.bar) from `user` where 1 != 1 group by `user`.col, `user`.bar, weight_string(`user`.bar)", - "OrderBy": "0 ASC, (1|2) ASC", - "Query": "select `user`.col, `user`.bar, weight_string(`user`.bar) from `user` group by `user`.col, `user`.bar, weight_string(`user`.bar) order by `user`.col asc, `user`.bar asc", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - 
"Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1, min(user_extra.foo), max(user_extra.bar) from user_extra where 1 != 1 group by 1", - "Query": "select 1, min(user_extra.foo), max(user_extra.bar) from user_extra where user_extra.bar = :user_col group by 1", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.col, `user`.bar, weight_string(`user`.bar) from `user` where 1 != 1 group by `user`.col, `user`.bar, weight_string(`user`.bar)", + "OrderBy": "0 ASC, (1|2) ASC", + "Query": "select `user`.col, `user`.bar, weight_string(`user`.bar) from `user` group by `user`.col, `user`.bar, weight_string(`user`.bar) order by `user`.col asc, `user`.bar asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select min(user_extra.foo), max(user_extra.bar) from user_extra where 1 != 1 group by .0", + "Query": "select min(user_extra.foo), max(user_extra.bar) from user_extra where user_extra.bar = :user_col group by .0", + "Table": "user_extra" } ] } @@ -4328,50 +4607,40 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "random(0) AS col", - "GroupBy": "(2|1)", + "GroupBy": "(1|2)", "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 2] as col", - "[COLUMN 1]", - "[COLUMN 0] as id" - ], + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,L:2", + "JoinVars": { + "user_col": 0 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:1,L:2,L:0", - "JoinVars": { - "user_col": 0 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col, 
`user`.id, weight_string(`user`.id) from `user` where 1 != 1 group by `user`.col, `user`.id, weight_string(`user`.id)", - "OrderBy": "(1|2) ASC", - "Query": "select `user`.col, `user`.id, weight_string(`user`.id) from `user` group by `user`.col, `user`.id, weight_string(`user`.id) order by `user`.id asc", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1 group by 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col group by 1", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.col, `user`.id, weight_string(`user`.id) from `user` where 1 != 1 group by `user`.id, `user`.col, weight_string(`user`.id)", + "OrderBy": "(1|2) ASC", + "Query": "select `user`.col, `user`.id, weight_string(`user`.id) from `user` group by `user`.id, `user`.col, weight_string(`user`.id) order by `user`.id asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", + "Query": "select 1 from user_extra where user_extra.col = :user_col group by .0", + "Table": "user_extra" } ] } @@ -4391,22 +4660,28 @@ "QueryType": "SELECT", "Original": "select id, b as id, count(*) from user order by id", "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "random(0) AS id, random(1) AS id, sum_count_star(2) AS count(*)", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|3) ASC", "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, b as id, count(*), weight_string(b) from `user` where 1 != 1", - "OrderBy": "(1|3) ASC", - "Query": "select id, b as id, count(*), weight_string(b) from `user` order by id asc", - "Table": 
"`user`" + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "random(0) AS id, random(1) AS id, sum_count_star(2) AS count(*), random(3)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, b as id, count(*), weight_string(b) from `user` where 1 != 1", + "Query": "select id, b as id, count(*), weight_string(b) from `user`", + "Table": "`user`" + } + ] } ] }, @@ -4480,42 +4755,33 @@ "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as id", - "[COLUMN 1]" - ], + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1 group by id, weight_string(id)", - "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` group by id, weight_string(id) order by id asc", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra", - "Table": "user_extra" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, weight_string(id) from `user` where 1 != 1 group by id, weight_string(id)", + "OrderBy": "(0|1) ASC", + "Query": "select `user`.id, weight_string(id) from `user` group by id, weight_string(id) order by id asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + 
"FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", + "Query": "select 1 from user_extra group by .0", + "Table": "user_extra" } ] } @@ -4845,23 +5111,23 @@ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 0] as id", - "[COLUMN 2] as name", - "[COLUMN 3] * COALESCE([COLUMN 4], INT64(1)) as count(m.predef1)", - "[COLUMN 1]" + "[COLUMN 3] as id", + "[COLUMN 0] as name", + "[COLUMN 1] * [COLUMN 2] as count(m.predef1)", + "[COLUMN 4] as weight_string(u.id)" ], "Inputs": [ { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "(0|1) ASC", + "OrderBy": "(3|4) ASC", "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:2,R:3,R:0,L:1,R:1", + "JoinColumnIndexes": "R:0,L:0,R:1,R:2,R:3", "JoinVars": { - "m_order": 0 + "m_order": 1 }, "TableName": "user_extra_`user`", "Inputs": [ @@ -4872,8 +5138,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select m.`order`, count(m.predef1), weight_string(m.`order`) from user_extra as m where 1 != 1 group by m.`order`, weight_string(m.`order`)", - "Query": "select m.`order`, count(m.predef1), weight_string(m.`order`) from user_extra as m group by m.`order`, weight_string(m.`order`)", + "FieldQuery": "select count(m.predef1), m.`order` from user_extra as m where 1 != 1 group by m.`order`", + "Query": "select count(m.predef1), m.`order` from user_extra as m group by m.`order`", "Table": "user_extra" }, { @@ -4925,9 +5191,9 @@ { "OperatorType": "Join", "Variant": "LeftJoin", - "JoinColumnIndexes": "L:1,R:1", + "JoinColumnIndexes": "L:0,R:0", "JoinVars": { - "u_col": 0 + "u_col": 1 }, "TableName": "`user`_user_extra", "Inputs": [ @@ -4938,8 +5204,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select u.col, count(u.id) from `user` as u where 1 != 1 group by u.col", - "Query": "select u.col, count(u.id) from `user` as u group by u.col", + "FieldQuery": "select count(u.id), u.col from `user` as u where 1 != 1 group by u.col", + "Query": "select count(u.id), u.col from 
`user` as u group by u.col", "Table": "`user`" }, { @@ -4949,8 +5215,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1, count(*) from user_extra as ue where 1 != 1 group by 1", - "Query": "select 1, count(*) from user_extra as ue where ue.col = :u_col group by 1", + "FieldQuery": "select count(*) from user_extra as ue where 1 != 1 group by .0", + "Query": "select count(*) from user_extra as ue where ue.col = :u_col group by .0", "Table": "user_extra" } ] @@ -5136,5 +5402,740 @@ "user.user" ] } + }, + { + "comment": "Scatter order by is complex with aggregates in select", + "query": "select col, count(*) from user group by col order by col+1", + "v3-plan": "VT12001: unsupported: in scatter query: complex ORDER BY expression: col + 1", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select col, count(*) from user group by col order by col+1", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(1) AS count(*), random(2) AS col + 1, random(3)", + "GroupBy": "0", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col, count(*), col + 1, weight_string(col + 1) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*), col + 1, weight_string(col + 1) from `user` group by col order by col asc", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "scatter aggregate complex order by", + "query": "select id from user group by id order by id+1", + "v3-plan": "VT12001: unsupported: in scatter query: complex ORDER BY expression: id + 1", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select id from user group by id order by id+1", + "Instructions": { + "OperatorType": "Route", + "Variant": 
"Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, id + 1, weight_string(id + 1) from `user` where 1 != 1 group by id", + "OrderBy": "(1|2) ASC", + "Query": "select id, id + 1, weight_string(id + 1) from `user` group by id order by id + 1 asc", + "ResultColumns": 1, + "Table": "`user`" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "select expression does not directly depend on grouping expression", + "query": "select a from user group by a+1", + "v3-plan": "VT12001: unsupported: in scatter query: only simple references are allowed", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select a from user group by a+1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "random(0) AS a", + "GroupBy": "(1|2)", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a, a + 1, weight_string(a + 1) from `user` where 1 != 1 group by a + 1, weight_string(a + 1)", + "OrderBy": "(1|2) ASC", + "Query": "select a, a + 1, weight_string(a + 1) from `user` group by a + 1, weight_string(a + 1) order by a + 1 asc", + "Table": "`user`" + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "inner join with scalar aggregation", + "query": "select count(*) from user join music on user.foo = music.bar", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user join music on user.foo = music.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + 
"user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo group by .0", + "Table": "music" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "left outer join with scalar aggregation", + "query": "select count(*) from user left join music on user.foo = music.bar", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo 
group by .0", + "Table": "music" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "inner join with left grouping", + "query": "select count(*) from user left join music on user.foo = music.bar group by user.col", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar group by user.col", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "1", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)", + "[COLUMN 2] as col" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0,L:2", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo, `user`.col from `user` where 1 != 1 group by `user`.col, `user`.foo", + "OrderBy": "2 ASC", + "Query": "select count(*), `user`.foo, `user`.col from `user` group by `user`.col, `user`.foo order by `user`.col asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo group by .0", + "Table": "music" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "inner join with right grouping", + "query": "select count(*) from user left join music on user.foo = music.bar group by music.col", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar group by music.col", + "Instructions": { + 
"OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "(1|2)", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)", + "[COLUMN 2] as col", + "[COLUMN 3] as weight_string(music.col)" + ], + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0,R:1,R:2", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), music.col, weight_string(music.col) from music where 1 != 1 group by music.col, weight_string(music.col)", + "Query": "select count(*), music.col, weight_string(music.col) from music where music.bar = :user_foo group by music.col, weight_string(music.col)", + "Table": "music" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "left outer join with left grouping", + "query": "select count(*) from user left join music on user.foo = music.bar group by user.col", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar group by user.col", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "1", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * 
COALESCE([COLUMN 1], INT64(1)) as count(*)", + "[COLUMN 2] as col" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0,L:2", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo, `user`.col from `user` where 1 != 1 group by `user`.col, `user`.foo", + "OrderBy": "2 ASC", + "Query": "select count(*), `user`.foo, `user`.col from `user` group by `user`.col, `user`.foo order by `user`.col asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo group by .0", + "Table": "music" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "left outer join with right grouping", + "query": "select count(*) from user left join music on user.foo = music.bar group by music.col", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar group by music.col", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(0) AS count(*)", + "GroupBy": "(1|2)", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)", + "[COLUMN 2] as col", + "[COLUMN 3] as weight_string(music.col)" + ], + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0,R:1,R:2", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + 
"OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), music.col, weight_string(music.col) from music where 1 != 1 group by music.col, weight_string(music.col)", + "Query": "select count(*), music.col, weight_string(music.col) from music where music.bar = :user_foo group by music.col, weight_string(music.col)", + "Table": "music" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } + }, + { + "comment": "3 table inner join with scalar aggregation", + "query": "select count(*) from user join music on user.foo = music.bar join user_extra on user.foo = user_extra.baz", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user join music on user.foo = music.bar join user_extra on user.foo = user_extra.baz", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_extra_baz": 1 + }, + "TableName": "user_extra_`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), user_extra.baz from user_extra where 1 != 1 group by user_extra.baz", + "Query": "select count(*), user_extra.baz from user_extra group by user_extra.baz", + "Table": "user_extra" 
+ }, + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` where `user`.foo = :user_extra_baz group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo group by .0", + "Table": "music" + } + ] + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "3 table with mixed join with scalar aggregation", + "query": "select count(*) from user left join music on user.foo = music.bar join user_extra on user.foo = user_extra.baz", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*) from user left join music on user.foo = music.bar join user_extra on user.foo = user_extra.baz", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music_user_extra", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * COALESCE([COLUMN 1], INT64(1)) as count(*)", + "[COLUMN 2] as foo" + ], + "Inputs": [ + { + "OperatorType": 
"Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0,L:1", + "JoinVars": { + "user_foo": 1 + }, + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.foo from `user` where 1 != 1 group by `user`.foo", + "Query": "select count(*), `user`.foo from `user` group by `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music where music.bar = :user_foo group by .0", + "Table": "music" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra where user_extra.baz = :user_foo group by .0", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "ordering have less column than grouping columns, grouping gets rearranged as order by and missing columns gets added to ordering", + "query": "select u.col, u.intcol, count(*) from user u join music group by 1,2 order by 2", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": { + "QueryType": "SELECT", + "Original": "select u.col, u.intcol, count(*) from user u join music group by 1,2 order by 2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(2) AS count(*)", + "GroupBy": "1, 0", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 2] as col", + "[COLUMN 3] as intcol", + "[COLUMN 0] * [COLUMN 1] as count(*)" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": 
"Join", + "JoinColumnIndexes": "L:0,R:0,L:1,L:2", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), u.col, u.intcol from `user` as u where 1 != 1 group by u.col, u.intcol", + "OrderBy": "2 ASC, 1 ASC", + "Query": "select count(*), u.col, u.intcol from `user` as u group by u.col, u.intcol order by u.intcol asc, u.col asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) from music where 1 != 1 group by .0", + "Query": "select count(*) from music group by .0", + "Table": "music" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index f08690d619e..bfe1fb9e24a 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -1794,7 +1794,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col and user_extra.user_id = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.user_id = :user_col and user_extra.col = :user_col", "Table": "user_extra", "Values": [ ":user_col" @@ -1880,7 +1880,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where 1 = 1 and user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col and 1 = 1", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 3571535c742..061def2c043 100644 --- 
a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -3203,7 +3203,7 @@ "Sharded": false }, "FieldQuery": "select 1 from unsharded where 1 != 1", - "Query": "select 1 from unsharded where unsharded.col1 = :t_col1 and unsharded.id = :t_id", + "Query": "select 1 from unsharded where unsharded.id = :t_id and unsharded.col1 = :t_col1", "Table": "unsharded" } ] @@ -5485,7 +5485,7 @@ "Sharded": false }, "FieldQuery": "select 1 from unsharded where 1 != 1", - "Query": "select 1 from unsharded where unsharded.col1 = :t_col1 and unsharded.a = :t_col1", + "Query": "select 1 from unsharded where unsharded.a = :t_col1 and unsharded.col1 = :t_col1", "Table": "unsharded" } ] @@ -5872,35 +5872,43 @@ "QueryType": "SELECT", "Original": "select user_extra.col+1 from user left join user_extra on user.col = user_extra.col", "Instructions": { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "R:0", - "JoinVars": { - "user_col": 0 - }, - "TableName": "`user`_user_extra", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] + INT64(1) as user_extra.col + 1" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_col": 0 }, - "FieldQuery": "select user_extra.col + 1 from user_extra where 1 != 1", - "Query": "select user_extra.col + 1 from user_extra where user_extra.col = :user_col", - "Table": "user_extra" + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 
`user`.col from `user` where 1 != 1", + "Query": "select `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from user_extra where 1 != 1", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] } ] }, @@ -5923,35 +5931,44 @@ "TableName": "`user`_user_extra_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "L:0,R:0", - "JoinVars": { - "user_col": 1 - }, - "TableName": "`user`_user_extra", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] as id", + "[COLUMN 1] + INT64(1) as user_extra.col + 1" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.id, `user`.col from `user` where 1 != 1", - "Query": "select `user`.id, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_col": 1 }, - "FieldQuery": "select user_extra.col + 1 from user_extra where 1 != 1", - "Query": "select user_extra.col + 1 from user_extra where user_extra.col = :user_col", - "Table": "user_extra" + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from user_extra where 1 != 1", + "Query": "select user_extra.col 
from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] } ] }, @@ -5981,36 +5998,43 @@ "QueryType": "SELECT", "Original": "select user.foo+user_extra.col+1 from user left join user_extra on user.col = user_extra.col", "Instructions": { - "OperatorType": "Join", - "Variant": "LeftJoin", - "JoinColumnIndexes": "R:0", - "JoinVars": { - "user_col": 1, - "user_foo": 0 - }, - "TableName": "`user`_user_extra", + "OperatorType": "Projection", + "Expressions": [ + "([COLUMN 0] + [COLUMN 1]) + INT64(1) as `user`.foo + user_extra.col + 1" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.foo, `user`.col from `user` where 1 != 1", - "Query": "select `user`.foo, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_col": 1 }, - "FieldQuery": "select :user_foo + user_extra.col + 1 from user_extra where 1 != 1", - "Query": "select :user_foo + user_extra.col + 1 from user_extra where user_extra.col = :user_col", - "Table": "user_extra" + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.foo, `user`.col from `user` where 1 != 1", + "Query": "select `user`.foo, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from user_extra where 1 != 1", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] } ] }, diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json 
b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index 0578d035526..6363203b88a 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -628,7 +628,7 @@ "Sharded": false }, "FieldQuery": "select 1 from information_schema.table_constraints as tc where 1 != 1", - "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_schema = :__vtschemaname /* VARCHAR */ and tc.constraint_name = :cc_constraint_name", + "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name and tc.constraint_schema = :__vtschemaname /* VARCHAR */", "SysTableTableName": "[tc_table_name:VARCHAR(\"table_name\")]", "SysTableTableSchema": "[VARCHAR(\"table_schema\"), :cc_constraint_schema]", "Table": "information_schema.table_constraints" diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json index 8b56936b2b2..620fb256730 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json @@ -40,23 +40,30 @@ "QueryType": "SELECT", "Original": "select a, b, count(*) from user group by a order by b", "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "Aggregates": "random(1) AS b, sum_count_star(2) AS count(*)", - "GroupBy": "(0|3)", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|3) ASC", "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, 
weight_string(a)", - "OrderBy": "(1|4) ASC, (0|3) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a) order by b asc, a asc", - "Table": "`user`" + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "random(1) AS b, sum_count_star(2) AS count(*), random(3)", + "GroupBy": "(0|4)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a, b, count(*), weight_string(b), weight_string(a) from `user` where 1 != 1 group by a, weight_string(a)", + "OrderBy": "(0|4) ASC", + "Query": "select a, b, count(*), weight_string(b), weight_string(a) from `user` group by a, weight_string(a) order by a asc", + "Table": "`user`" + } + ] } ] }, @@ -179,14 +186,14 @@ "Instructions": { "OperatorType": "Sort", "Variant": "Memory", - "OrderBy": "1 ASC, (0|3) ASC, 2 ASC", + "OrderBy": "(1|3) ASC, (0|4) ASC, 2 ASC", "ResultColumns": 3, "Inputs": [ { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "random(1) AS b, sum_count_star(2) AS k", - "GroupBy": "(0|3)", + "Aggregates": "random(1) AS b, sum_count_star(2) AS k, random(3)", + "GroupBy": "(0|4)", "Inputs": [ { "OperatorType": "Route", @@ -195,9 +202,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*) as k, weight_string(a) from `user` where 1 != 1 group by a, weight_string(a)", - "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*) as k, weight_string(a) from `user` group by a, weight_string(a) order by a asc", + "FieldQuery": "select a, b, count(*) as k, weight_string(b), weight_string(a) from `user` where 1 != 1 group by a, weight_string(a)", + "OrderBy": "(0|4) ASC", + "Query": "select a, b, count(*) as k, weight_string(b), weight_string(a) from `user` group by a, weight_string(a) order by a asc", "Table": "`user`" } ] @@ -278,7 +285,7 @@ }, "FieldQuery": "select a, b, count(*) as k, weight_string(a) from `user` 
where 1 != 1 group by a, weight_string(a)", "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*) as k, weight_string(a) from `user` group by a, weight_string(a) order by a asc", + "Query": "select a, b, count(*) as k, weight_string(a) from `user` group by a, weight_string(a) order by a asc limit :__upper_limit", "Table": "`user`" } ] @@ -494,45 +501,38 @@ "QueryType": "SELECT", "Original": "select id from (select user.id, user.col from user join user_extra) as t order by id", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|1) ASC", + "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(0|1) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", - "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra", - "Table": "user_extra" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", + "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t", + "Table": "`user`" + }, + { + "OperatorType": 
"Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" } ] } @@ -604,58 +604,49 @@ "QueryType": "SELECT", "Original": "select user.col1 as a, user.col2 b, music.col3 c from user, music where user.id = music.id and user.id = 1 order by c", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0, - 1, - 2 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(3|4) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,R:0,R:1", + "JoinVars": { + "user_id": 2 + }, + "TableName": "`user`_music", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,R:0,R:1,R:2", - "JoinVars": { - "user_id": 2 + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_music", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where 1 != 1", - "Query": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where `user`.id = 1", - "Table": "`user`", - "Values": [ - "INT64(1)" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select music.col3 as c, c, weight_string(music.col3) from music where 1 != 1", - "Query": "select music.col3 as c, c, weight_string(music.col3) from music where music.id = :user_id", - "Table": "music", - "Values": [ - ":user_id" - ], - "Vindex": "music_user_map" - } - ] + "FieldQuery": "select `user`.col1 as a, `user`.col2 as b, `user`.id from 
`user` where 1 != 1", + "Query": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3 as c, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3 as c, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" } ] } @@ -727,58 +718,49 @@ "QueryType": "SELECT", "Original": "select user.col1 as a, user.col2, music.col3 from user join music on user.id = music.id where user.id = 1 order by 1 asc, 3 desc, 2 asc", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0, - 1, - 2 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|3) ASC, (2|4) DESC, (1|5) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(3|4) ASC, (2|5) DESC, (1|6) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,R:0,L:2,R:1,L:3", + "JoinVars": { + "user_id": 4 + }, + "TableName": "`user`_music", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,R:0,L:2,L:3,R:1,L:4", - "JoinVars": { - "user_id": 5 + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_music", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col1 as a, `user`.col2, a, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where 1 != 1", - "Query": "select `user`.col1 as a, `user`.col2, a, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where `user`.id = 
1", - "Table": "`user`", - "Values": [ - "INT64(1)" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select music.col3, weight_string(music.col3) from music where 1 != 1", - "Query": "select music.col3, weight_string(music.col3) from music where music.id = :user_id", - "Table": "music", - "Values": [ - ":user_id" - ], - "Vindex": "music_user_map" - } - ] + "FieldQuery": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where 1 != 1", + "Query": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + "INT64(1)" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" } ] } @@ -839,47 +821,38 @@ "QueryType": "SELECT", "Original": "select u.a, u.textcol1, un.col2 from user u join unsharded un order by u.textcol1, un.col2", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0, - 1, - 2 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "1 ASC COLLATE latin1_swedish_ci, (2|3) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "1 ASC COLLATE latin1_swedish_ci, (2|3) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,R:0,R:1", + "TableName": "`user`_unsharded", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,R:0,R:1", - "TableName": "`user`_unsharded", 
- "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u.a, u.textcol1 from `user` as u where 1 != 1", - "Query": "select u.a, u.textcol1 from `user` as u", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Unsharded", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select un.col2, weight_string(un.col2) from unsharded as un where 1 != 1", - "Query": "select un.col2, weight_string(un.col2) from unsharded as un", - "Table": "unsharded" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.a, u.textcol1 from `user` as u where 1 != 1", + "Query": "select u.a, u.textcol1 from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select un.col2, weight_string(un.col2) from unsharded as un where 1 != 1", + "Query": "select un.col2, weight_string(un.col2) from unsharded as un", + "Table": "unsharded" } ] } @@ -940,47 +913,38 @@ "QueryType": "SELECT", "Original": "select u.a, u.textcol1, un.col2 from unsharded un join user u order by u.textcol1, un.col2", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0, - 1, - 2 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "1 ASC COLLATE latin1_swedish_ci, (2|3) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "1 ASC COLLATE latin1_swedish_ci, (2|3) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,R:1,L:0,L:1", + "TableName": "unsharded_`user`", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "R:0,R:1,L:0,L:1", - "TableName": "unsharded_`user`", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Unsharded", - "Keyspace": { 
- "Name": "main", - "Sharded": false - }, - "FieldQuery": "select un.col2, weight_string(un.col2) from unsharded as un where 1 != 1", - "Query": "select un.col2, weight_string(un.col2) from unsharded as un", - "Table": "unsharded" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u.a, u.textcol1 from `user` as u where 1 != 1", - "Query": "select u.a, u.textcol1 from `user` as u", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select un.col2, weight_string(un.col2) from unsharded as un where 1 != 1", + "Query": "select un.col2, weight_string(un.col2) from unsharded as un", + "Table": "unsharded" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.a, u.textcol1 from `user` as u where 1 != 1", + "Query": "select u.a, u.textcol1 from `user` as u", + "Table": "`user`" } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 855e470cda4..0efd31d29b8 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -2626,22 +2626,28 @@ "QueryType": "SELECT", "Original": "select count(id), num from user order by 2", "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "sum_count(0) AS count(id), random(1) AS num", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|2) ASC", "ResultColumns": 2, "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select count(id), num, weight_string(num) from `user` where 1 != 1", - "OrderBy": "(1|2) ASC", - "Query": "select count(id), num, weight_string(num) from `user` 
order by num asc", - "Table": "`user`" + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count(0) AS count(id), random(1) AS num, random(2)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(id), num, weight_string(num) from `user` where 1 != 1", + "Query": "select count(id), num, weight_string(num) from `user`", + "Table": "`user`" + } + ] } ] }, @@ -2855,23 +2861,30 @@ "QueryType": "SELECT", "Original": "select col, count(*) from user group by col order by c1", "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "Aggregates": "sum_count_star(1) AS count(*), random(2) AS c1", - "GroupBy": "0", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", "ResultColumns": 2, "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, count(*), c1, weight_string(c1) from `user` where 1 != 1 group by col", - "OrderBy": "(2|3) ASC, 0 ASC", - "Query": "select col, count(*), c1, weight_string(c1) from `user` group by col order by c1 asc, col asc", - "Table": "`user`" + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count_star(1) AS count(*), random(2) AS c1, random(3)", + "GroupBy": "0", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col, count(*), c1, weight_string(c1) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*), c1, weight_string(c1) from `user` group by col order by col asc", + "Table": "`user`" + } + ] } ] }, @@ -3224,48 +3237,41 @@ "QueryType": "SELECT", "Original": "select id from user, user_extra order by coalesce(user.col, user_extra.col)", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], + 
"OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|2) ASC", + "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(1|2) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1", + "JoinVars": { + "user_col": 1 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,R:1", - "JoinVars": { - "user_col": 1 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, `user`.col from `user` where 1 != 1", - "Query": "select id, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra where 1 != 1", - "Query": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra", - "Table": "user_extra" - } - ] + "FieldQuery": "select id, `user`.col from `user` where 1 != 1", + "Query": "select id, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra where 1 != 1", + "Query": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra", + "Table": "user_extra" } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/rails_cases.json b/go/vt/vtgate/planbuilder/testdata/rails_cases.json index 4af17d59204..94ed9961b87 100644 --- 
a/go/vt/vtgate/planbuilder/testdata/rails_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/rails_cases.json @@ -177,7 +177,7 @@ "Sharded": true }, "FieldQuery": "select 1 from book6s_order2s where 1 != 1", - "Query": "select 1 from book6s_order2s where book6s_order2s.book6_id = :book6s_id and book6s_order2s.order2_id = :order2s_id", + "Query": "select 1 from book6s_order2s where book6s_order2s.order2_id = :order2s_id and book6s_order2s.book6_id = :book6s_id", "Table": "book6s_order2s", "Values": [ ":book6s_id" diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index b2c998d55d0..f69f0c32ed4 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -3,7 +3,7 @@ "comment": "TPC-H query 1", "query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus", "v3-plan": "VT12001: unsupported: in scatter query: complex aggregate expression", - "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'avg'" + "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(l_quantity) as avg_qty'" }, { "comment": "TPC-H query 2", @@ -369,7 +369,7 @@ "Sharded": true }, "FieldQuery": "select s_nationkey, count(*), weight_string(s_nationkey) from supplier where 1 != 1 group by s_nationkey, weight_string(s_nationkey)", - "Query": "select s_nationkey, count(*), weight_string(s_nationkey) from supplier where s_suppkey = :l_suppkey and s_nationkey = :c_nationkey group by s_nationkey, 
weight_string(s_nationkey)", + "Query": "select s_nationkey, count(*), weight_string(s_nationkey) from supplier where s_nationkey = :c_nationkey and s_suppkey = :l_suppkey group by s_nationkey, weight_string(s_nationkey)", "Table": "supplier", "Values": [ ":l_suppkey" @@ -654,7 +654,7 @@ "Sharded": true }, "FieldQuery": "select n2.n_name as cust_nation, count(*), weight_string(n2.n_name), weight_string(n2.n_name) from nation as n2 where 1 != 1 group by cust_nation, weight_string(cust_nation)", - "Query": "select n2.n_name as cust_nation, count(*), weight_string(n2.n_name), weight_string(n2.n_name) from nation as n2 where n2.n_nationkey = :c_nationkey and (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') group by cust_nation, weight_string(cust_nation)", + "Query": "select n2.n_name as cust_nation, count(*), weight_string(n2.n_name), weight_string(n2.n_name) from nation as n2 where (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') and n2.n_nationkey = :c_nationkey group by cust_nation, weight_string(cust_nation)", "Table": "nation", "Values": [ ":c_nationkey" @@ -1416,7 +1416,7 @@ "Sharded": true }, "FieldQuery": "select 1, count(*) as numwait from orders where 1 != 1 group by 1", - "Query": "select 1, count(*) as numwait from orders where o_orderstatus = 'F' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate limit 1) and o_orderkey = :l1_l_orderkey group by 1", + "Query": "select 1, count(*) as numwait from orders where o_orderstatus = 'F' and o_orderkey = :l1_l_orderkey and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where 
l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate limit 1) group by 1", "Table": "orders", "Values": [ ":l1_l_orderkey" @@ -1442,7 +1442,7 @@ "Sharded": true }, "FieldQuery": "select s_nationkey, count(*) as numwait, weight_string(s_nationkey), s_name, weight_string(s_name) from supplier where 1 != 1 group by s_nationkey, weight_string(s_nationkey), s_name, weight_string(s_name)", - "Query": "select s_nationkey, count(*) as numwait, weight_string(s_nationkey), s_name, weight_string(s_name) from supplier where exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate limit 1) and s_suppkey = :l1_l_suppkey group by s_nationkey, weight_string(s_nationkey), s_name, weight_string(s_name)", + "Query": "select s_nationkey, count(*) as numwait, weight_string(s_nationkey), s_name, weight_string(s_name) from supplier where s_suppkey = :l1_l_suppkey and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate limit 1) group by s_nationkey, weight_string(s_nationkey), s_name, weight_string(s_name)", "Table": "supplier", "Values": [ ":l1_l_suppkey" @@ -1457,7 +1457,7 @@ "Sharded": true }, "FieldQuery": "select 1, count(*) as numwait from nation where 1 != 1 group by 1", - "Query": "select 1, count(*) as numwait from nation where n_name = 'SAUDI ARABIA' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > 
l3.l_commitdate limit 1) and n_nationkey = :s_nationkey group by 1", + "Query": "select 1, count(*) as numwait from nation where n_name = 'SAUDI ARABIA' and n_nationkey = :s_nationkey and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate limit 1) group by 1", "Table": "nation", "Values": [ ":s_nationkey" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index e9cc1c9db96..d88bcb93ff3 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -68,12 +68,6 @@ "v3-plan": "VT12001: unsupported: '*' expression in cross-shard query", "gen4-plan": "cannot use column offsets in group statement when using `*`" }, - { - "comment": "complex group by expression", - "query": "select a from user group by a+1", - "v3-plan": "VT12001: unsupported: in scatter query: only simple references are allowed", - "gen4-plan": "VT13001: [BUG] in scatter query: complex ORDER BY expression: a + 1" - }, { "comment": "Complex aggregate expression on scatter", "query": "select 1+count(*) from user", @@ -85,23 +79,11 @@ "v3-plan": "VT12001: unsupported: only one expression is allowed inside aggregates: count(a, b)", "gen4-plan": "VT03001: aggregate functions take a single argument 'count(a, b)'" }, - { - "comment": "scatter aggregate complex order by", - "query": "select id from user group by id order by id+1", - "v3-plan": "VT12001: unsupported: in scatter query: complex ORDER BY expression: id + 1", - "gen4-plan": "VT13001: [BUG] in scatter query: complex ORDER BY expression: id + 1" - }, - { - "comment": "Scatter order by is complex with aggregates in select", - "query": "select col, count(*) from user group by col order by 
col+1", - "v3-plan": "VT12001: unsupported: in scatter query: complex ORDER BY expression: col + 1", - "gen4-plan": "VT13001: [BUG] in scatter query: complex ORDER BY expression: col + 1" - }, { "comment": "Aggregate detection (group_concat)", "query": "select group_concat(user.a) from user join user_extra", "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", - "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'group_concat'" + "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'group_concat(`user`.a)'" }, { "comment": "subqueries not supported in group by", @@ -391,7 +373,7 @@ "comment": "avg function on scatter query", "query": "select avg(id) from user", "v3-plan": "VT12001: unsupported: in scatter query: complex aggregate expression", - "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'avg'" + "gen4-plan": "VT12001: unsupported: in scatter query: aggregation function 'avg(id)'" }, { "comment": "scatter aggregate with ambiguous aliases", @@ -492,5 +474,29 @@ "comment": "Assignment expression in delete statement", "query": "delete from user where x = (@val := 42)", "plan": "VT12001: unsupported: Assignment expression" + }, + { + "comment": "grouping column could be coming from multiple sides", + "query": "select count(*) from user, user_extra group by user.id+user_extra.id", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": "VT12001: unsupported: grouping on columns from different sources" + }, + { + "comment": "aggregate on input from both sides", + "query": "select sum(user.foo+user_extra.bar) from user, user_extra", + "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", + "gen4-plan": "VT12001: unsupported: aggregation on columns from different sources" + }, + { + "comment": "combine the output of two aggregations in the final result", + "query": "select greatest(sum(user.foo), sum(user_extra.bar)) from user join 
 user_extra on user.col = user_extra.col", +    "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", +    "gen4-plan": "VT12001: unsupported: in scatter query: complex aggregate expression" +  }, +  { +    "comment": "extremum on input from both sides", +    "query": "select max(u.foo*ue.bar) from user u join user_extra ue", +    "v3-plan": "VT12001: unsupported: cross-shard query with aggregates", +    "gen4-plan": "VT12001: unsupported: aggregation on columns from different sources: max(u.foo * ue.bar)" +  } ] diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index acbec4c8df2..887c5fef067 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -22,10 +22,9 @@ import ( "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" - - "vitess.io/vitess/go/vt/sqlparser" ) type ( @@ -446,6 +445,24 @@ func (st *SemTable) EqualsExpr(a, b sqlparser.Expr) bool { return st.ASTEquals().Expr(a, b) } +// EqualsExprWithDeps compares two expressions taking into account their semantic +// information. Dependency data typically pertains only to column expressions, +// this method considers them for all expression types. The method checks +// if dependency information exists for both expressions. If it does, the dependencies +// must match. If we are missing dependency information for either expression, the two are considered equal. +func (st *SemTable) EqualsExprWithDeps(a, b sqlparser.Expr) bool { + eq := st.ASTEquals().Expr(a, b) + if !eq { + return false + } + adeps := st.DirectDeps(a) + bdeps := st.DirectDeps(b) + if adeps.IsEmpty() || bdeps.IsEmpty() || adeps == bdeps { + return true + } + return false +} + func (st *SemTable) ContainsExpr(e sqlparser.Expr, expres []sqlparser.Expr) bool { for _, expre := range expres { if st.EqualsExpr(e, expre) {