Skip to content

Commit b837e37

Browse files
committed
fix
1 parent 3c5f3da commit b837e37

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
525525
} else {
526526
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
527527
}
528+
529+
case e @ CaseWhen(_, elseValue) if elseValue.isEmpty =>
530+
e.copy(elseValue = Some(Literal.create(null, e.dataType)))
528531
}
529532
}
530533
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,20 @@ class PushFoldableIntoBranchesSuite
122122
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
123123
assertEquivalent(
124124
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)),
125-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
125+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Literal.create(null, BooleanType)))
126126

127127
assertEquivalent(
128128
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
129129
FalseLiteral)
130130

131131
// Push down at most one branch is not foldable expressions.
132132
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)),
133-
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None))
133+
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)),
134+
Literal.create(null, BooleanType)))
134135
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)),
135-
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)))
136+
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), Literal.create(null, IntegerType)), Literal(1)))
136137
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)),
137-
EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)))
138+
CaseWhen(Seq((a, b === Literal(1))), Literal.create(null, BooleanType)))
138139

139140
// Push down non-deterministic expressions.
140141
val nonDeterministic =

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
115115
val expectedBranches = Seq(
116116
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
117117
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
118-
val expectedCond = CaseWhen(expectedBranches)
118+
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
119119

120120
testFilter(originalCond, expectedCond)
121121
testJoin(originalCond, expectedCond)
@@ -246,7 +246,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
246246
val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) ->
247247
CaseWhen(
248248
Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral),
249-
FalseLiteral)))
249+
FalseLiteral)),
250+
FalseLiteral)
250251
testFilter(originalCond = condition, expectedCond = expectedCond)
251252
testJoin(originalCond = condition, expectedCond = expectedCond)
252253
testDelete(originalCond = condition, expectedCond = expectedCond)
@@ -394,7 +395,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
394395
val nonAllFalseBranches = Seq(
395396
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
396397
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
397-
val nonAllFalseCond = CaseWhen(nonAllFalseBranches)
398+
val nonAllFalseCond = CaseWhen(nonAllFalseBranches, FalseLiteral)
398399

399400
testFilter(allFalseCond, FalseLiteral)
400401
testJoin(allFalseCond, FalseLiteral)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
8383
// i.e. removing branches whose conditions are always false
8484
assertEquivalent(
8585
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
86-
CaseWhen(normalBranch :: Nil, None))
86+
CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType)))
8787
}
8888

8989
test("remove entire CaseWhen if only the else branch is reachable") {
@@ -215,4 +215,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
215215
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
216216
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
217217
}
218+
219+
test("SPARK-33847: Replace None of elseValue inside CaseWhen to null literal") {
220+
assertEquivalent(
221+
CaseWhen(normalBranch :: Nil, None),
222+
CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType)))
223+
}
218224
}

0 commit comments

Comments
 (0)