Skip to content

Commit 06b1bbb

Browse files
wangyum authored and cloud-fan committed
[SPARK-33798][SQL] Add new rule to push down the foldable expressions through CaseWhen/If
### What changes were proposed in this pull request?

This PR adds a new rule (`PushFoldableIntoBranches`) to push down the foldable expressions through `CaseWhen/If`.

This is a real case from production:

```sql
create table t1 using parquet as select * from range(100);
create table t2 using parquet as select * from range(200);

create temp view v1 as
select 'a' as event_type, * from t1
union all
select CASE WHEN id = 1 THEN 'b' WHEN id = 3 THEN 'c' end as event_type, * from t2

explain select * from v1 where event_type = 'a';
```

Before this PR:

```
== Physical Plan ==
Union
:- *(1) Project [a AS event_type#30533, id#30535L]
:  +- *(1) ColumnarToRow
:     +- FileScan parquet default.t1[id#30535L] Batched: true, DataFilters: [], Format: Parquet
+- *(2) Project [CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END AS event_type#30534, id#30536L]
   +- *(2) Filter (CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)
      +- *(2) ColumnarToRow
         +- FileScan parquet default.t2[id#30536L] Batched: true, DataFilters: [(CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)], Format: Parquet
```

After this PR:

```
== Physical Plan ==
*(1) Project [a AS event_type#8, id#4L]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#4L] Batched: true, DataFilters: [], Format: Parquet
```

### Why are the changes needed?

Improve query performance.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #30790 from wangyum/SPARK-33798.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent bc46d27 commit 06b1bbb

File tree

4 files changed

+274
-1
lines changed

4 files changed

+274
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,11 @@ abstract class BinaryExpression extends Expression {
636636
}
637637

638638

639+
/**
 * Extractor companion that lets any [[BinaryExpression]] be pattern matched as its
 * `(left, right)` pair of children.
 */
object BinaryExpression {
  def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some(e.left -> e.right)
}
642+
643+
639644
/**
640645
* A [[BinaryExpression]] that is an operator, with two properties:
641646
*

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
9999
LikeSimplification,
100100
BooleanSimplification,
101101
SimplifyConditionals,
102+
PushFoldableIntoBranches,
102103
RemoveDispensableExpressions,
103104
SimplifyBinaryComparison,
104105
ReplaceNullWithFalseInPredicate,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet
2121
import scala.collection.mutable.{ArrayBuffer, Stack}
2222

2323
import org.apache.spark.sql.catalyst.analysis._
24-
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, _}
2525
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
2626
import org.apache.spark.sql.catalyst.expressions.aggregate._
2727
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -528,6 +528,48 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
528528
}
529529

530530

531+
/**
 * Push the foldable expression into (if / case) branches.
 *
 * A [[BinaryExpression]] that has an [[If]] / [[CaseWhen]] on one side and a foldable expression
 * on the other side is rewritten so that the binary expression is applied to every branch value
 * instead, e.g. `If(cond, 2, 3) = 4` becomes `If(cond, 2 = 4, 3 = 4)`. The branch conditions are
 * left untouched; only the branch values are wrapped.
 *
 * A [[CaseWhen]] without an ELSE clause produces null when no branch matches. The original
 * expression then evaluates `op(null, foldable)`, which is only guaranteed to be null for
 * [[NullIntolerant]] operators. For every other operator (e.g. `<=>`, `AND`, `OR`) the implicit
 * null is turned into an explicit ELSE value before pushing, so the rewritten expression keeps
 * the result of the original one on the "no branch matched" path.
 */
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {

  // To be conservative here: it's only a guaranteed win if all but at most only one branch
  // end up being not foldable.
  private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = {
    val (foldables, others) = exprs.partition(_.foldable)
    foldables.nonEmpty && others.length < 2
  }

  /**
   * Builds the ELSE value of the rewritten [[CaseWhen]].
   *
   * @param b    the binary expression being pushed into the branches
   * @param c    the CASE WHEN that `b` is pushed into
   * @param push wraps a single branch value with `b` and the foldable operand
   * @return the pushed ELSE value, or `None` when leaving the ELSE absent is equivalent
   */
  private def pushIntoElse(
      b: BinaryExpression,
      c: CaseWhen)(push: Expression => Expression): Option[Expression] = {
    (c.elseValue, b) match {
      case (Some(e), _) => Some(push(e))
      // op(null, x) is null for a null-intolerant op, which is what a missing ELSE yields.
      case (None, _: NullIntolerant) => None
      // Otherwise op(null, x) may be non-null, so the implicit null must be made explicit.
      case (None, _) => Some(push(Literal(null, c.dataType)))
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case q: LogicalPlan => q transformExpressionsUp {
      // If(...) <op> foldable
      case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
          if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
        i.copy(
          trueValue = b.makeCopy(Array(trueValue, right)),
          falseValue = b.makeCopy(Array(falseValue, right)))

      // foldable <op> If(...)
      case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
          if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
        i.copy(
          trueValue = b.makeCopy(Array(left, trueValue)),
          falseValue = b.makeCopy(Array(left, falseValue)))

      // CaseWhen(...) <op> foldable
      case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
          if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
        c.copy(
          branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
          pushIntoElse(b, c)(e => b.makeCopy(Array(e, right))))

      // foldable <op> CaseWhen(...)
      case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
          if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
        c.copy(
          branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
          pushIntoElse(b, c)(e => b.makeCopy(Array(left, e))))
    }
  }
}
571+
572+
531573
/**
532574
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
533575
* For example, when the expression is just checking to see if a string starts with a given
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import java.sql.Date
21+
22+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
23+
import org.apache.spark.sql.catalyst.dsl.expressions._
24+
import org.apache.spark.sql.catalyst.dsl.plans._
25+
import org.apache.spark.sql.catalyst.expressions._
26+
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
27+
import org.apache.spark.sql.catalyst.plans.PlanTest
28+
import org.apache.spark.sql.catalyst.plans.logical._
29+
import org.apache.spark.sql.catalyst.rules._
30+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
31+
32+
33+
class PushFoldableIntoBranchesSuite
34+
extends PlanTest with ExpressionEvalHelper with PredicateHelper {
35+
36+
// Optimizer under test: a single fixed-point batch (at most 50 iterations) that runs
// PushFoldableIntoBranches together with the simplification rules listed before it.
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches = List(
    Batch(
      "PushFoldableIntoBranches",
      FixedPoint(50),
      BooleanSimplification,
      ConstantFolding,
      SimplifyConditionals,
      PushFoldableIntoBranches))
}
40+
41+
// Test relation with two int columns (`a`, `b`) and one boolean column (`c`).
private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
42+
// Predicate `a = 100`; used below as a branch condition.
private val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
43+
// Plain reference to column `b`; used below as a branch value that is not foldable.
private val b = UnresolvedAttribute("b")
44+
// Predicate `c = true`; used below as a second branch condition.
private val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
45+
// IF(a = 100, 2, 3): both branch values are literals.
private val ifExp = If(a, Literal(2), Literal(3))
46+
// CASE WHEN a = 100 THEN 1 WHEN c = true THEN 2 ELSE 3 END: all branch values are literals.
private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
47+
48+
/**
 * Asserts that optimizing a projection of `e1` yields the same plan as the analyzed,
 * non-optimized projection of `e2`. Both expressions are aliased as "out" over `relation`.
 *
 * @param e1 expression that is run through `Optimize`
 * @param e2 expression the optimized result is expected to be equal to
 */
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
  def projectOver(expr: Expression): LogicalPlan =
    Project(Seq(Alias(expr, "out")()), relation).analyze

  val expected = projectOver(e2)
  val optimized = Optimize.execute(projectOver(e1))
  comparePlans(optimized, expected)
}
53+
54+
// Covers `If(...) = literal`: the comparison is expected to be applied per branch value.
test("Push down EqualTo through If") {
55+
// ifExp yields 2 or 3, so comparing with 4 is expected to collapse to false entirely.
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
56+
// Only the false branch (3) matches, so the If is kept with boolean branch values.
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
57+
58+
// The rewrite is applied when at most one branch value is not foldable.
59+
assertEquivalent(
60+
EqualTo(If(a, b, Literal(2)), Literal(2)),
61+
If(a, EqualTo(b, Literal(2)), TrueLiteral))
// Both branch values are non-foldable here, so the expression is expected to stay as is.
62+
assertEquivalent(
63+
EqualTo(If(a, b, b + 1), Literal(2)),
64+
EqualTo(If(a, b, b + 1), Literal(2)))
65+
66+
// The rewrite is also expected when the If condition is non-deterministic.
67+
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2))
68+
assert(!nonDeterministic.deterministic)
69+
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
70+
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral))
// The expected plan keeps the If even though both branch values are false.
71+
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
72+
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral))
73+
74+
// Null handling: a null branch value or a null operand gives a null boolean result.
75+
assertEquivalent(
76+
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)),
77+
If(a, Literal(null, BooleanType), TrueLiteral))
78+
assertEquivalent(
79+
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)),
80+
If(a, Literal(null, BooleanType), FalseLiteral))
81+
assertEquivalent(
82+
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)),
83+
Literal(null, BooleanType))
84+
assertEquivalent(
85+
EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
86+
Literal(null, BooleanType))
87+
}
88+
89+
// Covers the remaining comparison operators against ifExp (which yields 2 or 3).
// Every comparison with 4 has the same outcome on both branches, so a single boolean
// literal is expected.
test("Push down other BinaryComparison through If") {
90+
assertEquivalent(EqualNullSafe(ifExp, Literal(4)), FalseLiteral)
91+
assertEquivalent(GreaterThan(ifExp, Literal(4)), FalseLiteral)
92+
assertEquivalent(GreaterThanOrEqual(ifExp, Literal(4)), FalseLiteral)
93+
assertEquivalent(LessThan(ifExp, Literal(4)), TrueLiteral)
94+
assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral)
95+
}
96+
97+
// Covers arithmetic and boolean operators: the operator is expected to be evaluated on each
// literal branch value, leaving an If over the computed literals.
test("Push down other BinaryOperator through If") {
98+
assertEquivalent(Add(ifExp, Literal(4)), If(a, Literal(6), Literal(7)))
99+
assertEquivalent(Subtract(ifExp, Literal(4)), If(a, Literal(-2), Literal(-1)))
100+
assertEquivalent(Multiply(ifExp, Literal(4)), If(a, Literal(8), Literal(12)))
101+
assertEquivalent(Pmod(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
102+
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
103+
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
104+
If(a, Literal(2.0), Literal(3.0)))
105+
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral),
106+
If(a, FalseLiteral, TrueLiteral))
// OR with true is expected to collapse to true regardless of the If.
107+
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
108+
}
109+
110+
// Covers binary expressions that are neither comparisons nor operators
// (BRound, StartsWith, FindInSet, AddMonths).
test("Push down other BinaryExpression through If") {
111+
// Both branch values give the same result, so a single literal is expected.
assertEquivalent(BRound(If(a, Literal(1.23), Literal(1.24)), Literal(1)), Literal(1.2))
112+
assertEquivalent(StartsWith(If(a, Literal("ab"), Literal("ac")), Literal("a")), TrueLiteral)
113+
assertEquivalent(FindInSet(If(a, Literal("ab"), Literal("ac")), Literal("a")), Literal(0))
// The branch results differ here, so the If is kept over the shifted dates.
114+
assertEquivalent(
115+
AddMonths(If(a, Literal(Date.valueOf("2020-01-01")), Literal(Date.valueOf("2021-01-01"))),
116+
Literal(1)),
117+
If(a, Literal(Date.valueOf("2020-02-01")), Literal(Date.valueOf("2021-02-01"))))
118+
}
119+
120+
// Covers `CaseWhen(...) = literal`: the comparison is expected to be applied per branch value
// and to the ELSE value when one is present.
test("Push down EqualTo through CaseWhen") {
121+
// caseWhen yields 1, 2 or 3, so comparing with 4 is expected to collapse to false.
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
122+
assertEquivalent(EqualTo(caseWhen, Literal(3)),
123+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
// Without an ELSE clause the expected result also has no ELSE clause.
124+
assertEquivalent(
125+
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)),
126+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
127+
128+
// Two pushed-down comparisons combined with AND are expected to collapse to false.
assertEquivalent(
129+
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
130+
FalseLiteral)
131+
132+
// The rewrite is applied when at most one branch value is not foldable.
133+
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)),
134+
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None))
// Two non-foldable branch values: expected to stay unchanged.
135+
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)),
136+
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)))
// No foldable branch value at all: expected to stay unchanged.
137+
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)),
138+
EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)))
139+
140+
// The rewrite is also expected when a branch condition is non-deterministic.
141+
val nonDeterministic =
142+
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2)))
143+
assert(!nonDeterministic.deterministic)
144+
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
145+
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral)))
146+
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
147+
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral)))
148+
149+
// Null handling: a null branch value or a null operand gives a null boolean result.
150+
assertEquivalent(
151+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
152+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral)))
153+
assertEquivalent(
154+
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
155+
Literal(null, BooleanType))
156+
assertEquivalent(
157+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
158+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral)))
159+
assertEquivalent(
160+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
161+
Literal(1)),
162+
Literal(null, BooleanType))
163+
assertEquivalent(
164+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
165+
Literal(null, IntegerType)),
166+
Literal(null, BooleanType))
167+
}
168+
169+
// Covers the remaining comparison operators against caseWhen (which yields 1, 2 or 3).
// Every comparison with 4 has the same outcome on all branches, so a single boolean
// literal is expected.
test("Push down other BinaryComparison through CaseWhen") {
170+
assertEquivalent(EqualNullSafe(caseWhen, Literal(4)), FalseLiteral)
171+
assertEquivalent(GreaterThan(caseWhen, Literal(4)), FalseLiteral)
172+
assertEquivalent(GreaterThanOrEqual(caseWhen, Literal(4)), FalseLiteral)
173+
assertEquivalent(LessThan(caseWhen, Literal(4)), TrueLiteral)
174+
assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral)
175+
}
176+
177+
// Covers arithmetic and boolean operators: the operator is expected to be evaluated on each
// literal branch value and on the ELSE value.
test("Push down other BinaryOperator through CaseWhen") {
178+
assertEquivalent(Add(caseWhen, Literal(4)),
179+
CaseWhen(Seq((a, Literal(5)), (c, Literal(6))), Some(Literal(7))))
180+
assertEquivalent(Subtract(caseWhen, Literal(4)),
181+
CaseWhen(Seq((a, Literal(-3)), (c, Literal(-2))), Some(Literal(-1))))
182+
assertEquivalent(Multiply(caseWhen, Literal(4)),
183+
CaseWhen(Seq((a, Literal(4)), (c, Literal(8))), Some(Literal(12))))
184+
assertEquivalent(Pmod(caseWhen, Literal(4)),
185+
CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))))
186+
assertEquivalent(Remainder(caseWhen, Literal(4)),
187+
CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))))
188+
assertEquivalent(Divide(CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0))),
189+
Literal(1.0)),
190+
CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0))))
191+
assertEquivalent(And(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)),
192+
TrueLiteral),
193+
CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)))
// OR with true is expected to collapse to true regardless of the CaseWhen.
194+
assertEquivalent(Or(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)),
195+
TrueLiteral), TrueLiteral)
196+
}
197+
198+
// Covers binary expressions that are neither comparisons nor operators
// (BRound, StartsWith, FindInSet, AddMonths) applied to a CaseWhen with an ELSE value.
test("Push down other BinaryExpression through CaseWhen") {
199+
// All branch values give the same result, so a single literal is expected.
assertEquivalent(
200+
BRound(CaseWhen(Seq((a, Literal(1.23)), (c, Literal(1.24))), Some(Literal(1.25))),
201+
Literal(1)),
202+
Literal(1.2))
203+
assertEquivalent(
204+
StartsWith(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))),
205+
Literal("a")),
206+
TrueLiteral)
207+
assertEquivalent(
208+
FindInSet(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))),
209+
Literal("a")),
210+
Literal(0))
// The branch results differ here, so the CaseWhen is kept over the shifted dates.
211+
assertEquivalent(
212+
AddMonths(CaseWhen(Seq((a, Literal(Date.valueOf("2020-01-01"))),
213+
(c, Literal(Date.valueOf("2021-01-01")))),
214+
Some(Literal(Date.valueOf("2022-01-01")))),
215+
Literal(1)),
216+
CaseWhen(Seq((a, Literal(Date.valueOf("2020-02-01"))),
217+
(c, Literal(Date.valueOf("2021-02-01")))),
218+
Some(Literal(Date.valueOf("2022-02-01")))))
219+
}
220+
221+
// Covers the mirrored operand order: the foldable literal is on the left-hand side and the
// If / CaseWhen is on the right-hand side.
test("Push down BinaryExpression through If/CaseWhen backwards") {
222+
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
223+
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
224+
}
225+
}

0 commit comments

Comments
 (0)