Skip to content

Commit 21ef0c3

Browse files
committed
Push down EqualTo through CaseWhen/If
1 parent f9f622f commit 21ef0c3

File tree

4 files changed

+168
-94
lines changed

4 files changed

+168
-94
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
9999
LikeSimplification,
100100
BooleanSimplification,
101101
SimplifyConditionals,
102+
PushFoldableIntoBranches,
102103
RemoveDispensableExpressions,
103104
SimplifyBinaryComparison,
104105
ReplaceNullWithFalseInPredicate,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -523,20 +523,33 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
523523
} else {
524524
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
525525
}
526+
}
527+
}
528+
}
526529

527-
case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal)
528-
if i.deterministic =>
529-
i.copy(trueValue = EqualTo(trueValue, right), falseValue = EqualTo(falseValue, right))
530-
531-
case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal)
532-
if c.deterministic && (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) =>
533-
c.copy(branches.map(b => b.copy(_2 = EqualTo(b._2, right))),
534-
elseValue.map(EqualTo(_, right)))
530+
/**
531+
* Push the foldable expression into (if / case) branches.
532+
*/
533+
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
534+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
535+
case q: LogicalPlan => q transformExpressionsUp {
536+
case b @ BinaryComparison(i @ If(_, trueValue, falseValue), right)
537+
if i.deterministic && trueValue.foldable && falseValue.foldable && right.foldable =>
538+
i.copy(
539+
trueValue = b.makeCopy(Array(trueValue, right)),
540+
falseValue = b.makeCopy(Array(falseValue, right)))
541+
542+
case b @ BinaryComparison(c @ CaseWhen(branches, elseValue), right) if c.deterministic &&
543+
right.foldable && (branches.map(_._2) ++ elseValue).forall(_.foldable) =>
544+
c.copy(
545+
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
546+
elseValue.map(e => b.makeCopy(Array(e, right))))
535547
}
536548
}
537549
}
538550

539551

552+
540553
/**
541554
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
542555
* For example, when the expression is just checking to see if a string starts with a given
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
23+
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
25+
import org.apache.spark.sql.catalyst.plans.PlanTest
26+
import org.apache.spark.sql.catalyst.plans.logical._
27+
import org.apache.spark.sql.catalyst.rules._
28+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
29+
30+
31+
class PushFoldableIntoBranchesSuite
32+
extends PlanTest with ExpressionEvalHelper with PredicateHelper {
33+
34+
object Optimize extends RuleExecutor[LogicalPlan] {
35+
val batches = Batch("PushFoldableIntoBranches", FixedPoint(50),
36+
BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil
37+
}
38+
39+
private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
40+
private val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
41+
private val b = UnresolvedAttribute("b")
42+
private val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
43+
private val ifExp = If(a, Literal(2), Literal(3))
44+
private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
45+
46+
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
47+
val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze
48+
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, relation).analyze)
49+
comparePlans(actual, correctAnswer)
50+
}
51+
52+
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
53+
54+
test("SPARK-33798: Push down EqualTo through If") {
55+
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
56+
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
57+
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
58+
assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral))
59+
60+
// Do not simplify if it contains non foldable expressions.
61+
assertEquivalent(
62+
EqualTo(If(a, b, Literal(2)), Literal(2)),
63+
EqualTo(If(a, b, Literal(2)), Literal(2)))
64+
65+
// Do not simplify if it contains non-deterministic expressions.
66+
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
67+
assert(!nonDeterministic.deterministic)
68+
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
69+
70+
// Handle Null values.
71+
assertEquivalent(
72+
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)),
73+
If(a, Literal(null, BooleanType), TrueLiteral))
74+
assertEquivalent(
75+
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)),
76+
If(a, Literal(null, BooleanType), FalseLiteral))
77+
assertEquivalent(
78+
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)),
79+
Literal(null, BooleanType))
80+
assertEquivalent(
81+
EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
82+
Literal(null, BooleanType))
83+
}
84+
85+
test("SPARK-33798: Push down other BinaryComparison through If") {
86+
assertEquivalent(EqualNullSafe(ifExp, Literal(4)), FalseLiteral)
87+
assertEquivalent(GreaterThan(ifExp, Literal(4)), FalseLiteral)
88+
assertEquivalent(GreaterThanOrEqual(ifExp, Literal(4)), FalseLiteral)
89+
assertEquivalent(LessThan(ifExp, Literal(4)), TrueLiteral)
90+
assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral)
91+
}
92+
93+
test("SPARK-33798: Push down EqualTo through CaseWhen") {
94+
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
95+
assertEquivalent(EqualTo(caseWhen, Literal(3)),
96+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
97+
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
98+
assertEquivalent(EqualTo(caseWhen, Literal("3")),
99+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
100+
assertEquivalent(
101+
EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")),
102+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
103+
104+
assertEquivalent(
105+
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
106+
FalseLiteral)
107+
108+
// Do not simplify if it contains non foldable expressions.
109+
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
110+
EqualTo(caseWhen, NonFoldableLiteral(true)))
111+
val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
112+
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
113+
114+
// Do not simplify if it contains non-deterministic expressions.
115+
val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b))
116+
assert(!nonDeterministic.deterministic)
117+
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
118+
119+
// Handle Null values.
120+
assertEquivalent(
121+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
122+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral)))
123+
assertEquivalent(
124+
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
125+
Literal(null, BooleanType))
126+
assertEquivalent(
127+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
128+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral)))
129+
assertEquivalent(
130+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
131+
Literal(1)),
132+
Literal(null, BooleanType))
133+
assertEquivalent(
134+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
135+
Literal(null, IntegerType)),
136+
Literal(null, BooleanType))
137+
}
138+
139+
test("SPARK-33798: Push down other BinaryComparison through CaseWhen") {
140+
assertEquivalent(EqualNullSafe(caseWhen, Literal(4)), FalseLiteral)
141+
assertEquivalent(GreaterThan(caseWhen, Literal(4)), FalseLiteral)
142+
assertEquivalent(GreaterThanOrEqual(caseWhen, Literal(4)), FalseLiteral)
143+
assertEquivalent(LessThan(caseWhen, Literal(4)), TrueLiteral)
144+
assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral)
145+
}
146+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -199,90 +199,4 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
199199
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
200200
}
201201
}
202-
203-
test("SPARK-33798: Push down EqualTo through If") {
204-
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
205-
val b = UnresolvedAttribute("b")
206-
val ifExp = If(a, Literal(2), Literal(3))
207-
208-
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
209-
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
210-
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
211-
assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral))
212-
213-
// Do not simplify if it contains non foldable expressions.
214-
assertEquivalent(
215-
EqualTo(If(a, b, Literal(2)), Literal(2)),
216-
EqualTo(If(a, b, Literal(2)), Literal(2)))
217-
218-
// Do not simplify if it contains non-deterministic expressions.
219-
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
220-
assert(!nonDeterministic.deterministic)
221-
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
222-
223-
// Handle Null values.
224-
assertEquivalent(
225-
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)),
226-
If(a, Literal(null, BooleanType), TrueLiteral))
227-
assertEquivalent(
228-
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)),
229-
If(a, Literal(null, BooleanType), FalseLiteral))
230-
assertEquivalent(
231-
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)),
232-
Literal(null, BooleanType))
233-
assertEquivalent(
234-
EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
235-
Literal(null, BooleanType))
236-
}
237-
238-
test("SPARK-33798: Push down EqualTo through CaseWhen") {
239-
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
240-
val b = UnresolvedAttribute("b")
241-
val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
242-
val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
243-
244-
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
245-
assertEquivalent(EqualTo(caseWhen, Literal(3)),
246-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
247-
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
248-
assertEquivalent(EqualTo(caseWhen, Literal("3")),
249-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
250-
assertEquivalent(
251-
EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")),
252-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
253-
254-
assertEquivalent(
255-
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
256-
FalseLiteral)
257-
258-
// Do not simplify if it contains non foldable expressions.
259-
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
260-
EqualTo(caseWhen, NonFoldableLiteral(true)))
261-
val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
262-
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
263-
264-
// Do not simplify if it contains non-deterministic expressions.
265-
val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b))
266-
assert(!nonDeterministic.deterministic)
267-
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
268-
269-
// Handle Null values.
270-
assertEquivalent(
271-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
272-
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral)))
273-
assertEquivalent(
274-
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
275-
Literal(null, BooleanType))
276-
assertEquivalent(
277-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
278-
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral)))
279-
assertEquivalent(
280-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
281-
Literal(1)),
282-
Literal(null, BooleanType))
283-
assertEquivalent(
284-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
285-
Literal(null, IntegerType)),
286-
Literal(null, BooleanType))
287-
}
288202
}

0 commit comments

Comments
 (0)