diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 144ae451dc..466814017e 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -146,7 +146,8 @@ func TestSpatialQueriesPrepared(t *testing.T, harness Harness) { // TestJoinQueries tests join queries against a provided harness. func TestJoinQueries(t *testing.T, harness Harness) { - harness.Setup(setup.MydbData, setup.MytableData, setup.Pk_tablesData, setup.OthertableData, setup.NiltableData, setup.XyData, setup.FooData, setup.Comp_index_tablesData) + harness.Setup(setup.MydbData, setup.MytableData, setup.Pk_tablesData, setup.OthertableData, setup.NiltableData, + setup.XyData, setup.FooData, setup.Comp_index_tablesData, setup.EmptytableData) e, err := harness.NewEngine(t) require.NoError(t, err) @@ -261,7 +262,8 @@ func TestQueriesPrepared(t *testing.T, harness Harness) { // TestJoinQueriesPrepared tests join queries as prepared statements against a provided harness. func TestJoinQueriesPrepared(t *testing.T, harness Harness) { - harness.Setup(setup.MydbData, setup.MytableData, setup.Pk_tablesData, setup.OthertableData, setup.NiltableData, setup.XyData, setup.FooData, setup.Comp_index_tablesData) + harness.Setup(setup.MydbData, setup.MytableData, setup.Pk_tablesData, setup.OthertableData, setup.NiltableData, + setup.XyData, setup.FooData, setup.Comp_index_tablesData, setup.EmptytableData) for _, tt := range queries.JoinQueryTests { if tt.Skip || tt.SkipPrepared { continue diff --git a/enginetest/queries/join_queries.go b/enginetest/queries/join_queries.go index 7751c8ebeb..430824aa0b 100644 --- a/enginetest/queries/join_queries.go +++ b/enginetest/queries/join_queries.go @@ -801,6 +801,61 @@ on w = 0;`, Query: "select * from comp_index_t0 a join comp_index_t0 b join comp_index_t0 c on a.v2 = b.pk and b.v2 = c.pk and c.v2 = 5", Expected: []sql.Row{}, }, + { + Query: "select * from mytable join othertable on 3 >= 2 where mytable.i = 1 order by othertable.i2", + Expected: []sql.Row{ + {1, "first row", "third", 1}, + {1, "first row", "second", 2}, + {1, "first row", "first", 3}, + }, + }, + { + Query: "select * from mytable join othertable on 3 < 2", + Expected: []sql.Row{}, + }, + { + Query: "select * from mytable left join emptytable on 3 >= 2", + Expected: []sql.Row{ + {1, "first row", nil, nil}, + {2, "second row", nil, nil}, + {3, "third row", nil, nil}, + }, + }, + { + Query: "select * from mytable left join othertable on 3 < 2", + Expected: []sql.Row{ + {1, "first row", nil, nil}, + {2, "second row", nil, nil}, + {3, "third row", nil, nil}, + }, + }, + { + Query: "select * from emptytable right join mytable on 3 >= 2", + Expected: []sql.Row{ + {nil, nil, 1, "first row"}, + {nil, nil, 2, "second row"}, + {nil, nil, 3, "third row"}, + }, + }, + { + Query: "select * from othertable right join mytable on 3 < 2", + Expected: []sql.Row{ + {nil, nil, 1, "first row"}, + {nil, nil, 2, "second row"}, + {nil, nil, 3, "third row"}, + }, + }, + { + Query: "select * from mytable full outer join othertable on 3 < 2", + Expected: []sql.Row{ + {1, "first row", nil, nil}, + {2, "second row", nil, nil}, + {3, "third row", nil, nil}, + {nil, nil, "first", 3}, + {nil, nil, "second", 2}, + {nil, nil, "third", 1}, + }, + }, } var JoinScriptTests = []ScriptTest{ diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index dd9e0be18e..500a6275d4 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -18598,19 +18598,15 @@ inner join pq on true " ├─ cmp: Eq\n" + " │ ├─ a.i:0!null\n" + " │ └─ b.i:2!null\n" + - " ├─ Filter\n" + - " │ ├─ GreaterThanOrEqual\n" + - " │ │ ├─ NOW()\n" + - " │ │ └─ coalesce(NULL (null),NULL (null),NOW())\n" + - " │ └─ TableAlias(a)\n" + - " │ └─ IndexedTableAccess(mytable)\n" + - " │ ├─ index: [mytable.i,mytable.s]\n" + - " │ ├─ static: [{[NULL, ∞), [NULL, ∞)}]\n" + - " │ ├─ colSet: (1,2)\n" + - " │ ├─ tableId: 1\n" + - " │ └─ Table\n" + - " │ ├─ name: mytable\n" + - " │ └─ columns: [i s]\n" + + " ├─ TableAlias(a)\n" + + " │ └─ IndexedTableAccess(mytable)\n" + + " │ ├─ index: [mytable.i,mytable.s]\n" + + " │ ├─ static: [{[NULL, ∞), [NULL, ∞)}]\n" + + " │ ├─ colSet: (1,2)\n" + + " │ ├─ tableId: 1\n" + + " │ └─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + " └─ TableAlias(b)\n" + " └─ IndexedTableAccess(mytable)\n" + " ├─ index: [mytable.i]\n" + @@ -18623,15 +18619,13 @@ inner join pq on true "", ExpectedEstimates: "Project\n" + " ├─ columns: [a.i, a.s]\n" + - " └─ MergeJoin (estimated cost=5.070 rows=2)\n" + + " └─ MergeJoin (estimated cost=6.090 rows=3)\n" + " ├─ cmp: (a.i = b.i)\n" + - " ├─ Filter\n" + - " │ ├─ (NOW() >= coalesce(NULL,NULL,NOW()))\n" + - " │ └─ TableAlias(a)\n" + - " │ └─ IndexedTableAccess(mytable)\n" + - " │ ├─ index: [mytable.i,mytable.s]\n" + - " │ ├─ filters: [{[NULL, ∞), [NULL, ∞)}]\n" + - " │ └─ columns: [i s]\n" + + " ├─ TableAlias(a)\n" + + " │ └─ IndexedTableAccess(mytable)\n" + + " │ ├─ index: [mytable.i,mytable.s]\n" + + " │ ├─ filters: [{[NULL, ∞), [NULL, ∞)}]\n" + + " │ └─ columns: [i s]\n" + " └─ TableAlias(b)\n" + " └─ IndexedTableAccess(mytable)\n" + " ├─ index: [mytable.i]\n" + @@ -18640,15 +18634,13 @@ inner join pq on true "", ExpectedAnalysis: "Project\n" + " ├─ columns: [a.i, a.s]\n" + - " └─ MergeJoin (estimated cost=5.070 rows=2) (actual rows=3 loops=1)\n" + + " └─ MergeJoin (estimated cost=6.090 rows=3) (actual rows=3 loops=1)\n" + " ├─ cmp: (a.i = b.i)\n" + - " ├─ Filter\n" + - " │ ├─ (NOW() >= coalesce(NULL,NULL,NOW()))\n" + - " │ └─ TableAlias(a)\n" + - " │ └─ IndexedTableAccess(mytable)\n" + - " │ ├─ index: [mytable.i,mytable.s]\n" + - " │ ├─ filters: [{[NULL, ∞), [NULL, ∞)}]\n" + - " │ └─ columns: [i s]\n" + + " ├─ TableAlias(a)\n" + + " │ └─ IndexedTableAccess(mytable)\n" + + " │ ├─ index: [mytable.i,mytable.s]\n" + + " │ ├─ filters: [{[NULL, ∞), [NULL, ∞)}]\n" + + " │ └─ columns: [i s]\n" + " └─ TableAlias(b)\n" + " └─ IndexedTableAccess(mytable)\n" + " ├─ index: [mytable.i]\n" + @@ -23850,11 +23842,9 @@ where i = a order by i;`, ExpectedPlan: "Sort(mytable.i:0!null ASC nullsFirst)\n" + " └─ LateralCrossJoin\n" + - " ├─ AND\n" + - " │ ├─ Eq\n" + - " │ │ ├─ mytable.i:0!null\n" + - " │ │ └─ sqa.a:2\n" + - " │ └─ true (tinyint(1))\n" + + " ├─ Eq\n" + + " │ ├─ mytable.i:0!null\n" + + " │ └─ sqa.a:2\n" + " ├─ ProcessTable\n" + " │ └─ Table\n" + " │ ├─ name: mytable\n" + @@ -23889,7 +23879,7 @@ order by i;`, "", ExpectedEstimates: "Sort(mytable.i ASC)\n" + " └─ LateralCrossJoin (estimated cost=302.000 rows=3)\n" + - " ├─ ((mytable.i = sqa.a) AND true)\n" + + " ├─ (mytable.i = sqa.a)\n" + " ├─ Table\n" + " │ └─ name: mytable\n" + " └─ CachedResults\n" + @@ -23914,7 +23904,7 @@ order by i;`, "", ExpectedAnalysis: "Sort(mytable.i ASC)\n" + " └─ LateralCrossJoin (estimated cost=302.000 rows=3) (actual rows=3 loops=1)\n" + - " ├─ ((mytable.i = sqa.a) AND true)\n" + + " ├─ (mytable.i = sqa.a)\n" + " ├─ Table\n" + " │ └─ name: mytable\n" + " └─ CachedResults\n" + @@ -23961,11 +23951,9 @@ where i = a order by i;`, ExpectedPlan: "Sort(mytable.i:0!null ASC nullsFirst)\n" + " └─ LateralCrossJoin\n" + - " ├─ AND\n" + - " │ ├─ Eq\n" + - " │ │ ├─ mytable.i:0!null\n" + - " │ │ └─ sqa2.a:2\n" + - " │ └─ true (tinyint(1))\n" + + " ├─ Eq\n" + + " │ ├─ mytable.i:0!null\n" + + " │ └─ sqa2.a:2\n" + " ├─ ProcessTable\n" + " │ └─ Table\n" + " │ ├─ name: mytable\n" + @@ -24011,7 +23999,7 @@ order by i;`, "", ExpectedEstimates: "Sort(mytable.i ASC)\n" + " └─ LateralCrossJoin (estimated cost=302.000 rows=3)\n" + - " ├─ ((mytable.i = sqa2.a) AND true)\n" + + " ├─ (mytable.i = sqa2.a)\n" + " ├─ Table\n" + " │ └─ name: mytable\n" + " └─ CachedResults\n" + @@ -24045,7 +24033,7 @@ order by i;`, "", ExpectedAnalysis: "Sort(mytable.i ASC)\n" + " └─ LateralCrossJoin (estimated cost=302.000 rows=3) (actual rows=2 loops=1)\n" + - " ├─ ((mytable.i = sqa2.a) AND true)\n" + + " ├─ (mytable.i = sqa2.a)\n" + " ├─ Table\n" + " │ └─ name: mytable\n" + " └─ CachedResults\n" + @@ -25719,4 +25707,256 @@ order by x, y; " └─ keys: p18.id\n" + "", }, + { + Query: `select * from mytable where 3 < 2`, + ExpectedPlan: "EmptyTable\n" + + "", + ExpectedEstimates: "EmptyTable\n" + + "", + ExpectedAnalysis: "EmptyTable\n" + + "", + }, + { + Query: `select * from mytable join othertable on 3 > 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:2!null, mytable.s:3!null, othertable.s2:0!null, othertable.i2:1!null]\n" + + " └─ CrossJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ ProcessTable\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ CrossJoin (estimated cost=10.090 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ CrossJoin (estimated cost=10.090 rows=3) (actual rows=9 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + }, + { + Query: `select * from mytable join othertable on 3 < 2`, + ExpectedPlan: "EmptyTable\n" + + "", + ExpectedEstimates: "EmptyTable\n" + + "", + ExpectedAnalysis: "EmptyTable\n" + + "", + }, + { + Query: `select * from mytable left join othertable on 3 > 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:0!null, mytable.s:1!null, othertable.s2:2!null, othertable.i2:3!null]\n" + + " └─ LeftOuterJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ ProcessTable\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=8.090 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=8.090 rows=3) (actual rows=9 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + }, + { + Query: `select * from mytable left join othertable on 3 < 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:0!null, mytable.s:1!null, othertable.s2:2!null, othertable.i2:3!null]\n" + + " └─ LeftOuterJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ EmptyTable\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=2.030 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ EmptyTable\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=2.030 rows=3) (actual rows=3 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ EmptyTable\n" + + "", + }, + { + Query: `select * from mytable right join othertable on 3 > 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:2!null, mytable.s:3!null, othertable.s2:0!null, othertable.i2:1!null]\n" + + " └─ LeftOuterJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ ProcessTable\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=8.090 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=8.090 rows=3) (actual rows=9 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ Table\n" + + " ├─ name: mytable\n" + + " └─ columns: [i s]\n" + + "", + }, + { + Query: `select * from mytable right join othertable on 3 < 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:2!null, mytable.s:3!null, othertable.s2:0!null, othertable.i2:1!null]\n" + + " └─ LeftOuterJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ EmptyTable\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=2.030 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ EmptyTable\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ LeftOuterJoin (estimated cost=2.030 rows=3) (actual rows=3 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: othertable\n" + + " │ └─ columns: [s2 i2]\n" + + " └─ EmptyTable\n" + + "", + }, + { + Query: `select * from mytable full outer join othertable on 3 > 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:0!null, mytable.s:1!null, othertable.s2:2!null, othertable.i2:3!null]\n" + + " └─ FullOuterJoin\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ ProcessTable\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ FullOuterJoin (estimated cost=16.180 rows=3)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ FullOuterJoin (estimated cost=16.180 rows=3) (actual rows=9 loops=1)\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + }, + { + Query: `select * from mytable full outer join othertable on 3 < 2`, + ExpectedPlan: "Project\n" + + " ├─ columns: [mytable.i:0!null, mytable.s:1!null, othertable.s2:2!null, othertable.i2:3!null]\n" + + " └─ FullOuterJoin\n" + + " ├─ false (tinyint(1))\n" + + " ├─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ ProcessTable\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ FullOuterJoin (estimated cost=16.180 rows=3)\n" + + " ├─ false\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [mytable.i, mytable.s, othertable.s2, othertable.i2]\n" + + " └─ FullOuterJoin (estimated cost=16.180 rows=3) (actual rows=6 loops=1)\n" + + " ├─ false\n" + + " ├─ Table\n" + + " │ ├─ name: mytable\n" + + " │ └─ columns: [i s]\n" + + " └─ Table\n" + + " ├─ name: othertable\n" + + " └─ columns: [s2 i2]\n" + + "", + }, } diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index aa26cd6aa0..dbfee6edfa 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -209,195 +209,269 @@ func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) { return tables, nullRejecting } -// simplifyFilters simplifies the expressions in Filter nodes where possible. This involves removing redundant parts of AND -// and OR expressions, as well as replacing evaluable expressions with their literal result. Filters that can -// statically be determined to be true or false are replaced with the child node or an empty result, respectively. +// simplifyFilters simplifies filter expressions in nodes where possible. Nodes with filter expressions that can be +// statically evaluated to true or false are transformed so that the expression no longer needs to be evaluated. func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { if !node.Resolved() { return node, transform.SameTree, nil } return transform.NodeWithOpaque(node, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { - filter, ok := node.(*plan.Filter) - if !ok { - return node, transform.SameTree, nil - } - - e, same, err := transform.Expr(filter.Expression, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - switch e := e.(type) { - case *plan.Subquery: - newQ, same, err := simplifyFilters(ctx, a, e.Query, scope, sel, qFlags) - if same || err != nil { - return e, transform.SameTree, err - } - return e.WithQuery(newQ), transform.NewTree, nil - case *expression.Between: - return expression.NewAnd( - expression.NewGreaterThanOrEqual(e.Val, e.Lower), - expression.NewLessThanOrEqual(e.Val, e.Upper), - ), transform.NewTree, nil - case *expression.Or: - if isTrue(ctx, e.LeftChild) { - return expression.NewLiteral(true, types.Boolean), transform.NewTree, nil + switch n := node.(type) { + case *plan.JoinNode: + if n.Filter != nil { + e, same, err := simplifyExpression(ctx, a, scope, sel, qFlags, n.Filter) + if err != nil { + return nil, transform.SameTree, err } - if isTrue(ctx, e.RightChild) { - return expression.NewLiteral(true, types.Boolean), transform.NewTree, nil + isTrue, isFalse := getDefiniteBoolValues(ctx, e) + joinType := n.JoinType() + if isTrue { + // If the filter always evaluates to true, convert to cross join if possible + switch joinType { + case plan.JoinTypeInner: + return plan.NewCrossJoin(n.Left(), n.Right()), transform.NewTree, nil + case plan.JoinTypeLateralInner: + return plan.NewLateralCrossJoin(n.Left(), n.Right()), transform.NewTree, nil + default: + // Remove filter. Filter does not need to be evaluated if always true + return n.WithFilter(nil), transform.NewTree, nil + } + } else if isFalse { + switch joinType { + case plan.JoinTypeFullOuter: + // Do nothing here. For a full outer join, we still want to return every row of both. + case plan.JoinTypeLeftOuter, plan.JoinTypeLateralLeft: + // In a left join, we still want all rows on the left side. But because the filter is always + // false, it will never match rows on the right side so we can treat it like it's empty + return plan.NewJoin(n.Left(), plan.NewEmptyTableWithSchema(n.Right().Schema()), joinType, nil), transform.NewTree, nil + default: + // For non-outer joins, a join condition that always evaluates to false would return an empty set + return plan.NewEmptyTableWithSchema(n.Schema()), transform.NewTree, nil + } } - if isFalse(ctx, e.LeftChild) && types.IsBoolean(e.RightChild.Type()) { - return e.RightChild, transform.NewTree, nil + if !same { + return n.WithFilter(e), transform.NewTree, nil } - if isFalse(ctx, e.RightChild) && types.IsBoolean(e.LeftChild.Type()) { - return e.LeftChild, transform.NewTree, nil - } + } + case *plan.Filter: + e, same, err := simplifyExpression(ctx, a, scope, sel, qFlags, n.Expression) + if err != nil { + return nil, transform.SameTree, err + } - return e, transform.SameTree, nil - case *expression.And: - if isFalse(ctx, e.LeftChild) { - return expression.NewLiteral(false, types.Boolean), transform.NewTree, nil - } + isTrue, isFalse := getDefiniteBoolValues(ctx, e) + // if the filter always evaluates to true, it can be removed + if isTrue { + return n.Child, transform.NewTree, nil + } + // if the filter always evaluates to false, the result is an empty table + if isFalse { + return plan.NewEmptyTableWithSchema(n.Schema()), transform.NewTree, nil + } - if isFalse(ctx, e.RightChild) { + if !same { + return plan.NewFilter(e, n.Child), transform.NewTree, nil + } + } + return node, transform.SameTree, nil + }) +} + +// simplifyExpressions replaces expressions that can be evaluated statically with their Literal value and removes +// redundant parts of AND and OR expressions. +func simplifyExpression(ctx *sql.Context, a *Analyzer, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags, e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + return transform.Expr(e, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch e := e.(type) { + // TODO: if the left and right children of Equals refer to the same field, simplify to true + case *plan.Subquery: + newQ, same, err := simplifyFilters(ctx, a, e.Query, scope, sel, qFlags) + if same || err != nil { + return e, transform.SameTree, err + } + return e.WithQuery(newQ), transform.NewTree, nil + case *expression.Between: + // TODO: if e.Lower and e.Upper refer to the same field, simplify to Equals(e.Val, e.Lower) + // TODO: if e.Val is the same field as e.Lower and e.Upper, simplify to true + + // TODO: Can be evaluated to true/false if e.Val, e.Lower, and e.Upper are all Literals. + // If e.Lower and e.Upper are both Literals: + // If e.Lower > e.Upper, simplify to false. If e.Lower == e.Upper, simplify to Equals(e.Val, e.Lower). + return expression.NewAnd( + expression.NewGreaterThanOrEqual(e.Val, e.Lower), + expression.NewLessThanOrEqual(e.Val, e.Upper), + ), transform.NewTree, nil + case *expression.Or: + leftIsTrue, leftIsFalse := getDefiniteBoolValues(ctx, e.LeftChild) + // if left side is true, the OR expression is true + if leftIsTrue { + return expression.NewLiteral(true, types.Boolean), transform.NewTree, nil + } + + rightIsTrue, rightIsFalse := getDefiniteBoolValues(ctx, e.RightChild) + // if right side is true, the OR expression is true + if rightIsTrue { + return expression.NewLiteral(true, types.Boolean), transform.NewTree, nil + } + + if leftIsFalse { + // if both sides are false, the OR expression is false + if rightIsFalse { return expression.NewLiteral(false, types.Boolean), transform.NewTree, nil } - - if isTrue(ctx, e.LeftChild) && types.IsBoolean(e.RightChild.Type()) { + // if left side is false, the value of the OR expression is determined by the right side + // TODO If RightChild is not a boolean type, it can be returned if converted to a boolean. Nil values + // must be preserved + if types.IsBoolean(e.RightChild.Type()) { return e.RightChild, transform.NewTree, nil } + } - if isTrue(ctx, e.RightChild) && types.IsBoolean(e.LeftChild.Type()) { - return e.LeftChild, transform.NewTree, nil - } + // if right side is false, the value of the OR expression is determined by the left side + // TODO If LeftChild is not a boolean type, it can be returned if converted to a boolean. Nil values must be + // preserved + if rightIsFalse && types.IsBoolean(e.LeftChild.Type()) { + return e.LeftChild, transform.NewTree, nil + } - return e, transform.SameTree, nil - case *expression.Like: - // if the charset is not utf8mb4, the last character used in optimization rule does not work - coll, _ := sql.GetCoercibility(ctx, e.LeftChild) - charset := coll.CharacterSet() - if charset != sql.CharacterSet_utf8mb4 { - return e, transform.SameTree, nil - } - // TODO: maybe more cases to simplify - r, ok := e.RightChild.(*expression.Literal) - if !ok { - return e, transform.SameTree, nil - } - // TODO: handle escapes - if e.Escape != nil { - return e, transform.SameTree, nil - } - val := r.Value() - valStr, ok := val.(string) - if !ok { - return e, transform.SameTree, nil - } - if len(valStr) == 0 { - return e, transform.SameTree, nil - } - // if there are single character wildcards, don't simplify - if strings.Count(valStr, "_")-strings.Count(valStr, "\\_") > 0 { - return e, transform.SameTree, nil - } - // if there are also no multiple character wildcards, this is just a plain equals - numWild := strings.Count(valStr, "%") - strings.Count(valStr, "\\%") - if numWild == 0 { - return expression.NewEquals(e.LeftChild, e.RightChild), transform.NewTree, nil - } - // if there are many multiple character wildcards, don't simplify - if numWild != 1 { - return e, transform.SameTree, nil - } - // if the last character is an escaped multiple character wildcard, don't simplify - if len(valStr) >= 2 && valStr[len(valStr)-2:] == "\\%" { - return e, transform.SameTree, nil - } - if valStr[len(valStr)-1] != '%' { - return e, transform.SameTree, nil - } - // TODO: like expression with just a wild card shouldn't even make it here; analyzer rule should just drop filter - if len(valStr) == 1 { - return e, transform.SameTree, nil + return e, transform.SameTree, nil + case *expression.And: + leftIsTrue, leftIsFalse := getDefiniteBoolValues(ctx, e.LeftChild) + // if left side is false, the AND expression is false + if leftIsFalse { + return expression.NewLiteral(false, types.Boolean), transform.NewTree, nil + } + + rightIsTrue, rightIsFalse := getDefiniteBoolValues(ctx, e.RightChild) + // if right side is false, the AND expression is false + if rightIsFalse { + return expression.NewLiteral(false, types.Boolean), transform.NewTree, nil + } + + if leftIsTrue { + // if both sides are true, the AND expression is true + if rightIsTrue { + return expression.NewLiteral(true, types.Boolean), transform.NewTree, nil } - valStr = valStr[:len(valStr)-1] - newRightLower := expression.NewLiteral(valStr, e.RightChild.Type()) - valStr += string(byte(255)) // append largest possible character as upper bound - newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type()) - newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper)) - return newExpr, transform.NewTree, nil - case *expression.Not: - if lit, ok := e.Child.(*expression.Literal); ok { - val, err := sql.ConvertToBool(ctx, lit.Value()) - if err != nil { - // error while converting, keep as is - return e, transform.SameTree, nil - } - return expression.NewLiteral(!val, e.Type()), transform.NewTree, nil + // if left side is true, the value of the AND expression is determined by the right side + // TODO If RightChild is not a boolean type, it can be returned if converted to a boolean. Nil values + // must be preserved + if types.IsBoolean(e.RightChild.Type()) { + return e.RightChild, transform.NewTree, nil } + } + + // if right side is true, the value of the AND expression is determined by the left side + // TODO If LeftChild is not a boolean type, it can be returned if converted to a boolean. Nil values must be + // preserved + if rightIsTrue && types.IsBoolean(e.LeftChild.Type()) { + return e.LeftChild, transform.NewTree, nil + } + + return e, transform.SameTree, nil + case *expression.Like: + // if the charset is not utf8mb4, the last character used in optimization rule does not work + coll, _ := sql.GetCoercibility(ctx, e.LeftChild) + charset := coll.CharacterSet() + if charset != sql.CharacterSet_utf8mb4 { + return e, transform.SameTree, nil + } + // TODO: maybe more cases to simplify + r, ok := e.RightChild.(*expression.Literal) + if !ok { + return e, transform.SameTree, nil + } + // TODO: handle escapes + if e.Escape != nil { return e, transform.SameTree, nil - case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst: + } + val := r.Value() + valStr, ok := val.(string) + if !ok { + return e, transform.SameTree, nil + } + if len(valStr) == 0 { + return e, transform.SameTree, nil + } + // if there are single character wildcards, don't simplify + if strings.Count(valStr, "_")-strings.Count(valStr, "\\_") > 0 { + return e, transform.SameTree, nil + } + // if there are also no multiple character wildcards, this is just a plain equals + numWild := strings.Count(valStr, "%") - strings.Count(valStr, "\\%") + if numWild == 0 { + return expression.NewEquals(e.LeftChild, e.RightChild), transform.NewTree, nil + } + // if there are many multiple character wildcards, don't simplify + if numWild != 1 { return e, transform.SameTree, nil - default: - if !isEvaluable(e) { + } + // if the last character is an escaped multiple character wildcard, don't simplify + if len(valStr) >= 2 && valStr[len(valStr)-2:] == "\\%" { + return e, transform.SameTree, nil + } + if valStr[len(valStr)-1] != '%' { + return e, transform.SameTree, nil + } + // TODO: like expression with just a wild card shouldn't even make it here; analyzer rule should just drop filter + if len(valStr) == 1 { + return e, transform.SameTree, nil + } + valStr = valStr[:len(valStr)-1] + newRightLower := expression.NewLiteral(valStr, e.RightChild.Type()) + valStr += string(byte(255)) // append largest possible character as upper bound + newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type()) + newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper)) + return newExpr, transform.NewTree, nil + case *expression.Not: + if lit, ok := e.Child.(*expression.Literal); ok { + val, err := sql.ConvertToBool(ctx, lit.Value()) + if err != nil { + // error while converting, keep as is return e, transform.SameTree, nil } - if conv, ok := e.(*expression.Convert); ok { - if types.IsBinaryType(conv.Type()) { - return e, transform.SameTree, nil - } - } - - // All other expressions types can be evaluated once and turned into literals for the rest of query execution - val, err := e.Eval(ctx, nil) - if err != nil { - return e, transform.SameTree, err + return expression.NewLiteral(!val, types.Boolean), transform.NewTree, nil + } + return e, transform.SameTree, nil + case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst: + return e, transform.SameTree, nil + default: + if !isEvaluable(e) { + return e, transform.SameTree, nil + } + if conv, ok := e.(*expression.Convert); ok { + if types.IsBinaryType(conv.Type()) { + return e, transform.SameTree, nil } - return expression.NewLiteral(val, e.Type()), transform.NewTree, nil } - }) - if err != nil { - return nil, transform.SameTree, err - } - if isFalse(ctx, e) { - emptyTable := plan.NewEmptyTableWithSchema(filter.Schema()) - return emptyTable, transform.NewTree, nil - } - - if isTrue(ctx, e) { - return filter.Child, transform.NewTree, nil - } - - if same { - return filter, transform.SameTree, nil + // All other expressions types can be evaluated once and turned into literals for the rest of query execution + val, err := e.Eval(ctx, nil) + if err != nil { + return e, transform.SameTree, err + } + return expression.NewLiteral(val, e.Type()), transform.NewTree, nil } - return plan.NewFilter(e, filter.Child), transform.NewTree, nil }) } -func isFalse(ctx *sql.Context, e sql.Expression) bool { - lit, ok := e.(*expression.Literal) - if !ok || lit == nil || lit.Value() == nil { - return false - } - val, err := sql.ConvertToBool(ctx, lit.Value()) - if err != nil { - return false - } - return !val -} - -func isTrue(ctx *sql.Context, e sql.Expression) bool { +// getDefiniteBoolValues gets the definite boolean values of an expression. isTrue will only be true if the expression +// is a non-nil Literal that evaluates to true, and isFalse will only be true if the expression is a non-nil Literal +// that evaluates to false. Both return values are necessary since nil values are neither true nor false. We also cannot +// yet evaluate the value of non-Literal expressions so they can neither be definitely true nor false. +func getDefiniteBoolValues(ctx *sql.Context, e sql.Expression) (isTrue, isFalse bool) { lit, ok := e.(*expression.Literal) if !ok || lit == nil || lit.Value() == nil { - return false + return false, false } val, err := sql.ConvertToBool(ctx, lit.Value()) if err != nil { - return false + return false, false } - return val + return val, !val } // pushNotFilters applies De'Morgan's laws to push NOT expressions as low diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index 13483666e1..244c0dfc1a 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -168,10 +168,11 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, } case *plan.Update: if n.IsJoin { - uj := n.Child.(*plan.UpdateJoin) - updateTargets := uj.UpdateTargets - for _, updateTarget := range updateTargets { - affectedTables = append(affectedTables, getTableName(updateTarget)) + if uj, ok := n.Child.(*plan.UpdateJoin); ok { + updateTargets := uj.UpdateTargets + for _, updateTarget := range updateTargets { + affectedTables = append(affectedTables, getTableName(updateTarget)) + } } } else { affectedTables = append(affectedTables, getTableName(n)) diff --git a/sql/plan/empty_table.go b/sql/plan/empty_table.go index d20732e5d2..f18b4d91fb 100644 --- a/sql/plan/empty_table.go +++ b/sql/plan/empty_table.go @@ -87,7 +87,7 @@ func (e *EmptyTable) Schema() sql.Schema { return e.schema } func (*EmptyTable) Children() []sql.Node { return nil } func (*EmptyTable) Resolved() bool { return true } func (*EmptyTable) IsReadOnly() bool { return true } -func (e *EmptyTable) String() string { return "EmptyTable" } +func (e *EmptyTable) String() string { return "EmptyTable\n" } // RowIter implements the sql.Node interface. func (*EmptyTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { diff --git a/sql/plan/join.go b/sql/plan/join.go index cae2f5e70d..eead5f6adf 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -439,6 +439,12 @@ func (j *JoinNode) WithComment(comment string) sql.Node { return &ret } +func (j *JoinNode) WithFilter(filter sql.Expression) *JoinNode { + ret := *j + ret.Filter = filter + return &ret +} + var _ sql.Describable = (*JoinNode)(nil) // Describe implements sql.Describable diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index bb71db6b4b..273be77daf 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -115,7 +115,7 @@ func (b *Builder) buildJoin(inScope *scope, te *ast.JoinTableExpr) (outScope *sc if b.canConvertToCrossJoin(te) { if rast, ok := te.RightExpr.(*ast.AliasedTableExpr); ok && rast.Lateral { var err error - outScope.node, err = b.f.buildJoin(leftScope.node, rightScope.node, plan.JoinTypeLateralCross, expression.NewLiteral(true, types.Boolean)) + outScope.node, err = b.f.buildJoin(leftScope.node, rightScope.node, plan.JoinTypeLateralCross, nil) if err != nil { b.handleErr(err) }