-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-33847][SQL] Simplify CaseWhen if elseValue is None #30852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
09edff5
ca2a224
3c5f3da
b837e37
d07344f
81c38f8
7f3529d
1684d67
0933019
d20fede
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -258,4 +258,22 @@ class PushFoldableIntoBranchesSuite | |
| EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")), | ||
| CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) | ||
| } | ||
|
|
||
| test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { | ||
|   // Local builders are `def`s so every use constructs a fresh expression, | ||
|   // exactly as the inlined expressions would. | ||
|   def nullInt = Literal.create(null, IntegerType) | ||
|   def nullBoolean = Literal.create(null, BooleanType) | ||
|   def randomCond = LessThan(Rand(1), Literal(0.5)) | ||
|
|
||
|   // The only branch yields a null integer and there is no else value: | ||
|   // comparing the CaseWhen with a literal is expected to be equivalent to a null boolean, | ||
|   // both for an attribute condition and for a Rand-based condition. | ||
|   assertEquivalent( | ||
|     EqualTo(CaseWhen(Seq((a, nullInt))), Literal(2)), | ||
|     nullBoolean) | ||
|   assertEquivalent( | ||
|     EqualTo(CaseWhen(Seq((randomCond, nullInt))), Literal(2)), | ||
|     nullBoolean) | ||
|
|
||
|   // Same expectation when the branch value is the string "str" cast to IntegerType. | ||
|   assertEquivalent( | ||
|     EqualTo(CaseWhen(Seq((a, Literal("str")))).cast(IntegerType), Literal(2)), | ||
|     nullBoolean) | ||
|   assertEquivalent( | ||
|     EqualTo(CaseWhen(Seq((randomCond, Literal("str")))).cast(IntegerType), Literal(2)), | ||
|     nullBoolean) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException | |
| import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} | ||
| import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, EqualTo, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Not, Or, UnresolvedNamedLambdaVariable} | ||
| import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} | ||
| import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} | ||
|
|
@@ -38,6 +38,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
| ConstantFolding, | ||
| BooleanSimplification, | ||
| SimplifyConditionals, | ||
| PushFoldableIntoBranches, | ||
| ReplaceNullWithFalseInPredicate) :: Nil | ||
| } | ||
|
|
||
|
|
@@ -222,10 +223,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
| Literal(null, IntegerType), | ||
| Literal(3)), | ||
| FalseLiteral) | ||
| testFilter(originalCond = condition, expectedCond = condition) | ||
| testJoin(originalCond = condition, expectedCond = condition) | ||
| testDelete(originalCond = condition, expectedCond = condition) | ||
| testUpdate(originalCond = condition, expectedCond = condition) | ||
| val expectedCond = If( | ||
|
||
| UnresolvedAttribute("i") > Literal(10), | ||
| Not(UnresolvedAttribute("i") === Literal(15)), | ||
| FalseLiteral) | ||
| testFilter(originalCond = condition, expectedCond = expectedCond) | ||
| testJoin(originalCond = condition, expectedCond = expectedCond) | ||
| testDelete(originalCond = condition, expectedCond = expectedCond) | ||
| testUpdate(originalCond = condition, expectedCond = expectedCond) | ||
| } | ||
|
|
||
| test("inability to replace null in non-boolean values of CaseWhen") { | ||
|
|
@@ -237,8 +242,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
| TrueLiteral, | ||
| FalseLiteral) | ||
| val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) | ||
| val expectedCond = | ||
| CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen))) | ||
| val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> | ||
| CaseWhen(Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), FalseLiteral))) | ||
| testFilter(originalCond = condition, expectedCond = expectedCond) | ||
| testJoin(originalCond = condition, expectedCond = expectedCond) | ||
| testDelete(originalCond = condition, expectedCond = expectedCond) | ||
|
|
@@ -253,10 +258,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
| Literal(3)), | ||
| TrueLiteral, | ||
| FalseLiteral) | ||
| val expectedCond = Literal(5) > If( | ||
| UnresolvedAttribute("i") === Literal(15), | ||
| Literal(null, IntegerType), | ||
| Literal(3)) | ||
| val expectedCond = Not(UnresolvedAttribute("i") === Literal(15)) | ||
| testFilter(originalCond = condition, expectedCond = expectedCond) | ||
| testJoin(originalCond = condition, expectedCond = expectedCond) | ||
| testDelete(originalCond = condition, expectedCond = expectedCond) | ||
|
|
@@ -380,6 +382,52 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
| testProjection(originalExpr = column, expectedExpr = column) | ||
| } | ||
|
|
||
| test("replace None of elseValue inside CaseWhen if all branches are FalseLiteral") { | ||
|   // Fresh reference to column "i" on every use. | ||
|   def i = UnresolvedAttribute("i") | ||
|
|
||
|   // Every branch is FalseLiteral and no else value is supplied. | ||
|   val onlyFalse = CaseWhen(Seq( | ||
|     (i < Literal(10), FalseLiteral), | ||
|     (i > Literal(40), FalseLiteral))) | ||
|
|
||
|   // One branch is TrueLiteral and an explicit FalseLiteral else value is supplied. | ||
|   val mixed = CaseWhen( | ||
|     Seq( | ||
|       (i < Literal(10), FalseLiteral), | ||
|       (i > Literal(40), TrueLiteral)), | ||
|     FalseLiteral) | ||
|
|
||
|   // The all-false CaseWhen is expected to become FalseLiteral in each predicate position. | ||
|   testFilter(originalCond = onlyFalse, expectedCond = FalseLiteral) | ||
|   testJoin(originalCond = onlyFalse, expectedCond = FalseLiteral) | ||
|   testDelete(originalCond = onlyFalse, expectedCond = FalseLiteral) | ||
|   testUpdate(originalCond = onlyFalse, expectedCond = FalseLiteral) | ||
|
|
||
|   // The mixed CaseWhen is expected to be left unchanged. | ||
|   testFilter(originalCond = mixed, expectedCond = mixed) | ||
|   testJoin(originalCond = mixed, expectedCond = mixed) | ||
|   testDelete(originalCond = mixed, expectedCond = mixed) | ||
|   testUpdate(originalCond = mixed, expectedCond = mixed) | ||
| } | ||
|
|
||
| test("replace None of elseValue inside CaseWhen if all branches are null") { | ||
|   // Fresh expressions on every use. | ||
|   def i = UnresolvedAttribute("i") | ||
|   def nullBoolean = Literal.create(null, BooleanType) | ||
|
|
||
|   // Every branch yields a null boolean and no else value is supplied. | ||
|   val allNullCond = CaseWhen(Seq( | ||
|     (i < Literal(10), nullBoolean), | ||
|     (i > Literal(40), nullBoolean))) | ||
|
|
||
|   // In each predicate position the condition is expected to become FalseLiteral. | ||
|   testFilter(originalCond = allNullCond, expectedCond = FalseLiteral) | ||
|   testJoin(originalCond = allNullCond, expectedCond = FalseLiteral) | ||
|   testDelete(originalCond = allNullCond, expectedCond = FalseLiteral) | ||
|   testUpdate(originalCond = allNullCond, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| test("replace None of elseValue inside CaseWhen with PushFoldableIntoBranches") { | ||
|   // Fresh reference to column "i" on every use. | ||
|   def i = UnresolvedAttribute("i") | ||
|
|
||
|   // String-valued branches with no else value, compared against a string | ||
|   // that matches neither branch value. | ||
|   val stringBranches = Seq( | ||
|     (i < Literal(10), Literal("a")), | ||
|     (i > Literal(40), Literal("b"))) | ||
|   val comparison = EqualTo(CaseWhen(stringBranches), "c") | ||
|
|
||
|   // In each predicate position the comparison is expected to become FalseLiteral. | ||
|   testFilter(originalCond = comparison, expectedCond = FalseLiteral) | ||
|   testJoin(originalCond = comparison, expectedCond = FalseLiteral) | ||
|   testDelete(originalCond = comparison, expectedCond = FalseLiteral) | ||
|   testUpdate(originalCond = comparison, expectedCond = FalseLiteral) | ||
| } | ||
|
|
||
| // Checks `originalCond` against `expectedCond` when the condition is applied via `where`. | ||
| private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { | ||
|   // `_.where(_)` is the two-argument function (relation, condition) => relation.where(condition); | ||
|   // the comparison itself is delegated to the suite's `test` helper (defined outside this view). | ||
|   test(_.where(_), originalCond, expectedCond) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -215,4 +215,24 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P | |
| If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), | ||
| LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) | ||
| } | ||
|
|
||
| test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { | ||
|   // Fresh null integer literal on every use, matching the inlined form. | ||
|   def nullInt = Literal.create(null, IntegerType) | ||
|
|
||
|   // No else value and a single null-valued branch: expected to be equivalent to the null | ||
|   // literal, for an attribute condition as well as a Rand-based one. | ||
|   assertEquivalent( | ||
|     CaseWhen((GreaterThan('a, 1), nullInt) :: Nil, None), | ||
|     nullInt) | ||
|   assertEquivalent( | ||
|     CaseWhen((GreaterThan(Rand(0), 1), nullInt) :: Nil, None), | ||
|     nullInt) | ||
|
|
||
|   // Explicit else value equal to every branch value: expected to be equivalent to that value. | ||
|   assertEquivalent( | ||
|     CaseWhen((GreaterThan('a, 1), nullInt) :: Nil, Some(nullInt)), | ||
|     nullInt) | ||
|   assertEquivalent( | ||
|     CaseWhen( | ||
|       (GreaterThan('a, 1), Literal(20)) :: (GreaterThan('b, 1), Literal(20)) :: Nil, | ||
|       Some(Literal(20))), | ||
|     Literal(20)) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.