@@ -39,7 +39,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3939 }
4040
4141 private val trueBranch = (TrueLiteral , Literal (5 ))
42- private val normalBranch = (NonFoldableLiteral (true ), Literal (10 ))
42+ private val normalBranch1 = (NonFoldableLiteral (true ), Literal (10 ))
43+ private val normalBranch2 = (NonFoldableLiteral (false ), Literal (3 ))
4344 private val unreachableBranch = (FalseLiteral , Literal (20 ))
4445 private val nullBranch = (Literal .create(null , NullType ), Literal (30 ))
4546
@@ -60,18 +61,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
6061 test(" remove unreachable branches" ) {
6162 // i.e. removing branches whose conditions are always false
6263 assertEquivalent(
63- CaseWhen (unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil , None ),
64- If (normalBranch._1, normalBranch._2, Literal (null , normalBranch._2.dataType)))
64+ CaseWhen (unreachableBranch :: normalBranch1 :: unreachableBranch ::
65+ normalBranch2 :: nullBranch :: Nil , None ),
66+ CaseWhen (normalBranch1 :: normalBranch2 :: Nil , None ))
6567 }
6668
6769 test(" simplify CaseWhen to If when there is only one branch" ) {
6870 assertEquivalent(
69- CaseWhen (normalBranch :: Nil , None ),
70- If (normalBranch ._1, normalBranch ._2, Literal (null , normalBranch._2.dataType )))
71+ CaseWhen (normalBranch1 :: Nil , Some ( Literal ( 30 )) ),
72+ If (normalBranch1 ._1, normalBranch1 ._2, Literal (30 )))
7173
7274 assertEquivalent(
73- CaseWhen (normalBranch :: Nil , Some (Literal (30 ))),
74- If (normalBranch._1, normalBranch._2, Literal (30 )))
75+ CaseWhen (normalBranch1 :: Nil , None ),
76+ If (normalBranch1._1, normalBranch1._2, Literal (null , normalBranch1._2.dataType)))
77+
78+ assertEquivalent(
79+ CaseWhen (unreachableBranch :: normalBranch1 :: unreachableBranch :: nullBranch :: Nil , None ),
80+ If (normalBranch1._1, normalBranch1._2, Literal (null , normalBranch1._2.dataType)))
7581 }
7682
7783 test(" remove entire CaseWhen if only the else branch is reachable" ) {
@@ -86,28 +92,28 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
8692
8793 test(" remove entire CaseWhen if the first branch is always true" ) {
8894 assertEquivalent(
89- CaseWhen (trueBranch :: normalBranch :: nullBranch :: Nil , None ),
95+ CaseWhen (trueBranch :: normalBranch1 :: nullBranch :: Nil , None ),
9096 Literal (5 ))
9197
9298 // Test branch elimination and simplification in combination
9399 assertEquivalent(
94- CaseWhen (unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
100+ CaseWhen (unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch1
95101 :: Nil , None ),
96102 Literal (5 ))
97103
98104 // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
99105 assertEquivalent(
100- CaseWhen (normalBranch :: trueBranch :: normalBranch :: Nil , None ),
101- CaseWhen (normalBranch :: trueBranch :: Nil , None ))
106+ CaseWhen (normalBranch1 :: trueBranch :: normalBranch1 :: Nil , None ),
107+ CaseWhen (normalBranch1 :: trueBranch :: Nil , None ))
102108 }
103109
104110 test(" simplify CaseWhen, prune branches following a definite true" ) {
105111 assertEquivalent(
106- CaseWhen (normalBranch :: unreachableBranch ::
112+ CaseWhen (normalBranch1 :: unreachableBranch ::
107113 unreachableBranch :: nullBranch ::
108- trueBranch :: normalBranch ::
114+ trueBranch :: normalBranch1 ::
109115 Nil ,
110116 None ),
111- CaseWhen (normalBranch :: trueBranch :: Nil , None ))
117+ CaseWhen (normalBranch1 :: trueBranch :: Nil , None ))
112118 }
113119}
0 commit comments