Skip to content

Commit d3b072e

Browse files
committed
fix
1 parent 186a3d5 commit d3b072e

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
486486
case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
487487
case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)
488488

489-
case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) => cond
490-
case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) => Not(cond)
489+
case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) =>
490+
if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
491+
case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) =>
492+
if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
491493

492494
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
493495
// If there are branches that are always false, remove them.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,10 @@ class PushFoldableIntoBranchesSuite
272272

273273
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
274274
assertEquivalent(
275-
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)), 'a > 10)
275+
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)),
276+
'a > 10 <=> TrueLiteral)
276277
assertEquivalent(
277-
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)), 'a <= 10)
278+
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)),
279+
Not('a > 10 <=> TrueLiteral))
278280
}
279281
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,38 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
245245
}
246246

247247
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
248-
Seq(IsNull('a), GreaterThan(Rand(0), 1.0)).foreach { cond =>
249-
assertEquivalent(CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral), cond)
248+
// verify the boolean equivalence of all transformations involved
249+
val fields = Seq(
250+
'cond.boolean.notNull,
251+
'cond_nullable.boolean,
252+
'a.boolean,
253+
'b.boolean
254+
)
255+
val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
256+
257+
val exprs = Seq(
258+
// actual expressions of the transformations: original -> transformed
259+
CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond,
260+
CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond,
261+
CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true),
262+
CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true)))
263+
264+
// check plans
265+
for ((originalExpr, expectedExpr) <- exprs) {
266+
assertEquivalent(originalExpr, expectedExpr)
267+
}
268+
269+
// check evaluation
270+
val binaryBooleanValues = Seq(true, false)
271+
val ternaryBooleanValues = Seq(true, false, null)
272+
for (condVal <- binaryBooleanValues;
273+
condNullableVal <- ternaryBooleanValues;
274+
aVal <- ternaryBooleanValues;
275+
bVal <- ternaryBooleanValues;
276+
(originalExpr, expectedExpr) <- exprs) {
277+
val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
278+
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
279+
checkEvaluation(originalExpr, optimizedVal, inputRow)
250280
}
251-
assertEquivalent(CaseWhen(Seq((IsNull('a), FalseLiteral)), TrueLiteral), IsNotNull('a))
252-
assertEquivalent(CaseWhen(Seq((GreaterThan(Rand(0), 1.0), FalseLiteral)), TrueLiteral),
253-
LessThanOrEqual(Rand(0), 1.0))
254281
}
255282
}

0 commit comments

Comments
 (0)