Skip to content

Commit f9f622f

Browse files
committed
Push down EqualTo through CaseWhen/If
1 parent 0a48048 commit f9f622f

File tree

2 files changed

+23
-33
lines changed

2 files changed

+23
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -470,15 +470,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
470470
case _ => false
471471
}
472472

473-
private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = {
474-
exps.forall {
475-
case l: Literal =>
476-
val res = EqualTo(l, equalTo).eval(EmptyRow)
477-
res != null && !res.asInstanceOf[Boolean]
478-
case _ => false
479-
}
480-
}
481-
482473
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
483474
case q: LogicalPlan => q transformExpressionsUp {
484475
case If(TrueLiteral, trueValue, _) => trueValue
@@ -533,13 +524,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
533524
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
534525
}
535526

536-
case EqualTo(i @ If(_, trueValue, falseValue), right: Literal)
537-
if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) =>
538-
FalseLiteral
527+
case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal)
528+
if i.deterministic =>
529+
i.copy(trueValue = EqualTo(trueValue, right), falseValue = EqualTo(falseValue, right))
539530

540531
case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal)
541-
if c.deterministic && isAlwaysFalse(branches.map(_._2) ++ elseValue, right) =>
542-
FalseLiteral
532+
if c.deterministic && (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) =>
533+
c.copy(branches.map(b => b.copy(_2 = EqualTo(b._2, right))),
534+
elseValue.map(EqualTo(_, right)))
543535
}
544536
}
545537
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,15 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
200200
}
201201
}
202202

203-
test("SPARK-33798: simplify EqualTo(If, Literal) always false") {
203+
test("SPARK-33798: Push down EqualTo through If") {
204204
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
205205
val b = UnresolvedAttribute("b")
206206
val ifExp = If(a, Literal(2), Literal(3))
207207

208208
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
209-
assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3)))
209+
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
210210
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
211-
assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3)))
211+
assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral))
212212

213213
// Do not simplify if it contains non foldable expressions.
214214
assertEquivalent(
@@ -220,43 +220,41 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
220220
assert(!nonDeterministic.deterministic)
221221
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
222222

223-
// Should not handle Null values.
223+
// Handle Null values.
224224
assertEquivalent(
225225
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)),
226-
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)))
226+
If(a, Literal(null, BooleanType), TrueLiteral))
227227
assertEquivalent(
228228
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)),
229-
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)))
229+
If(a, Literal(null, BooleanType), FalseLiteral))
230230
assertEquivalent(
231231
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)),
232-
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)))
232+
Literal(null, BooleanType))
233233
assertEquivalent(
234234
EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
235235
Literal(null, BooleanType))
236236
}
237237

238-
test("SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") {
238+
test("SPARK-33798: Push down EqualTo through CaseWhen") {
239239
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
240240
val b = UnresolvedAttribute("b")
241241
val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
242242
val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
243243

244244
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
245-
assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3)))
245+
assertEquivalent(EqualTo(caseWhen, Literal(3)),
246+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
246247
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
247-
assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3)))
248+
assertEquivalent(EqualTo(caseWhen, Literal("3")),
249+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
248250
assertEquivalent(
249251
EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")),
250-
FalseLiteral)
252+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
251253

252254
assertEquivalent(
253255
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
254256
FalseLiteral)
255257

256-
assertEquivalent(
257-
EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)),
258-
FalseLiteral)
259-
260258
// Do not simplify if it contains non foldable expressions.
261259
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
262260
EqualTo(caseWhen, NonFoldableLiteral(true)))
@@ -268,16 +266,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
268266
assert(!nonDeterministic.deterministic)
269267
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))
270268

271-
// Should not handle Null values.
269+
// Handle Null values.
272270
assertEquivalent(
273271
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
274-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)))
272+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral)))
275273
assertEquivalent(
276274
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
277-
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)))
275+
Literal(null, BooleanType))
278276
assertEquivalent(
279277
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
280-
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)))
278+
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral)))
281279
assertEquivalent(
282280
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
283281
Literal(1)),

0 commit comments

Comments
 (0)