Skip to content

Commit 19b0a83

Browse files
committed
Simplify EqualTo(CaseWhen/If, Literal) always false
1 parent 40c37d6 commit 19b0a83

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
470470
case _ => false
471471
}
472472

473+
private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = {
474+
exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean])
475+
}
476+
473477
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
474478
case q: LogicalPlan => q transformExpressionsUp {
475479
case If(TrueLiteral, trueValue, _) => trueValue
@@ -523,6 +527,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
523527
} else {
524528
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
525529
}
530+
531+
case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal)
532+
if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) =>
533+
FalseLiteral
534+
535+
case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic &&
536+
(branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) &&
537+
isAlwaysFalse(branches.map(_._2) ++ elseValue, right) =>
538+
FalseLiteral
526539
}
527540
}
528541
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,97 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
199199
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
200200
}
201201
}
202+
203+
test("SPARK-33798: simplify EqualTo(If, Literal) always false") {
204+
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
205+
val ifExp = If(a === Literal(1), Literal(2), Literal(3))
206+
207+
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
208+
assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3)))
209+
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
210+
assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3)))
211+
212+
// Do not simplify if it contains non foldable expressions.
213+
assertEquivalent(EqualTo(ifExp, NonFoldableLiteral(true)),
214+
EqualTo(ifExp, NonFoldableLiteral(true)))
215+
val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2))
216+
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
217+
218+
// Do not simplify if it contains non-deterministic expressions.
219+
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
220+
assert(!nonDeterministic.deterministic)
221+
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
222+
223+
// null check, SPARK-33798 will not change these behaviors.
224+
assertEquivalent(
225+
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)),
226+
TrueLiteral)
227+
assertEquivalent(
228+
EqualTo(If(TrueLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)),
229+
Literal(null, BooleanType))
230+
assertEquivalent(
231+
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
232+
Literal(null, BooleanType))
233+
234+
assertEquivalent(
235+
EqualTo(If(FalseLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)),
236+
Literal(null, BooleanType))
237+
assertEquivalent(
238+
EqualTo(If(TrueLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)),
239+
Literal(null, BooleanType))
240+
}
241+
242+
test("SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") {
243+
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
244+
val b = UnresolvedAttribute("b")
245+
val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
246+
val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
247+
248+
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
249+
assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3)))
250+
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
251+
assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3)))
252+
assertEquivalent(
253+
EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")),
254+
FalseLiteral)
255+
256+
assertEquivalent(
257+
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
258+
FalseLiteral)
259+
260+
assertEquivalent(
261+
EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)),
262+
FalseLiteral)
263+
264+
// Do not simplify if it contains non foldable expressions.
265+
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
266+
EqualTo(caseWhen, NonFoldableLiteral(true)))
267+
val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
268+
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
269+
270+
// Do not simplify if it contains non-deterministic expressions.
271+
val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b))
272+
assert(!nonDeterministic.deterministic)
273+
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
274+
275+
// null check, SPARK-33798 will change the following two behaviors.
276+
assertEquivalent(
277+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
278+
FalseLiteral)
279+
assertEquivalent(
280+
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
281+
FalseLiteral)
282+
283+
assertEquivalent(
284+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
285+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)))
286+
assertEquivalent(
287+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
288+
Literal(1)),
289+
Literal(null, BooleanType))
290+
assertEquivalent(
291+
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
292+
Literal(null, IntegerType)),
293+
Literal(null, BooleanType))
294+
}
202295
}

0 commit comments

Comments
 (0)