@@ -200,15 +200,15 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
200200 }
201201 }
202202
203- test(" SPARK-33798: simplify EqualTo(If, Literal) always false " ) {
203+ test(" SPARK-33798: Push down EqualTo through If " ) {
204204 val a = EqualTo (UnresolvedAttribute (" a" ), Literal (100 ))
205205 val b = UnresolvedAttribute (" b" )
206206 val ifExp = If (a, Literal (2 ), Literal (3 ))
207207
208208 assertEquivalent(EqualTo (ifExp, Literal (4 )), FalseLiteral )
209- assertEquivalent(EqualTo (ifExp, Literal (3 )), EqualTo (ifExp, Literal ( 3 ) ))
209+ assertEquivalent(EqualTo (ifExp, Literal (3 )), If (a, FalseLiteral , TrueLiteral ))
210210 assertEquivalent(EqualTo (ifExp, Literal (" 4" )), FalseLiteral )
211- assertEquivalent(EqualTo (ifExp, Literal (" 3" )), EqualTo (ifExp, Literal ( 3 ) ))
211+ assertEquivalent(EqualTo (ifExp, Literal (" 3" )), If (a, FalseLiteral , TrueLiteral ))
212212
213213 // Do not simplify if it contains non foldable expressions.
214214 assertEquivalent(
@@ -220,43 +220,41 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
220220 assert(! nonDeterministic.deterministic)
221221 assertEquivalent(EqualTo (nonDeterministic, Literal (- 1 )), EqualTo (nonDeterministic, Literal (- 1 )))
222222
223- // Should not handle Null values.
223+ // Handle Null values.
224224 assertEquivalent(
225225 EqualTo (If (a, Literal (null , IntegerType ), Literal (1 )), Literal (1 )),
226- EqualTo ( If (a, Literal (null , IntegerType ), Literal ( 1 )), Literal ( 1 ) ))
226+ If (a, Literal (null , BooleanType ), TrueLiteral ))
227227 assertEquivalent(
228228 EqualTo (If (a, Literal (null , IntegerType ), Literal (1 )), Literal (2 )),
229- EqualTo ( If (a, Literal (null , IntegerType ), Literal ( 1 )), Literal ( 2 ) ))
229+ If (a, Literal (null , BooleanType ), FalseLiteral ))
230230 assertEquivalent(
231231 EqualTo (If (a, Literal (1 ), Literal (2 )), Literal (null , IntegerType )),
232- EqualTo ( If (a, Literal (1 ), Literal ( 2 )), Literal ( null , IntegerType ) ))
232+ Literal (null , BooleanType ))
233233 assertEquivalent(
234234 EqualTo (If (a, Literal (null , IntegerType ), Literal (null , IntegerType )), Literal (1 )),
235235 Literal (null , BooleanType ))
236236 }
237237
238- test(" SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false " ) {
238+ test(" SPARK-33798: Push down EqualTo through CaseWhen " ) {
239239 val a = EqualTo (UnresolvedAttribute (" a" ), Literal (100 ))
240240 val b = UnresolvedAttribute (" b" )
241241 val c = EqualTo (UnresolvedAttribute (" c" ), Literal (true ))
242242 val caseWhen = CaseWhen (Seq ((a, Literal (1 )), (c, Literal (2 ))), Some (Literal (3 )))
243243
244244 assertEquivalent(EqualTo (caseWhen, Literal (4 )), FalseLiteral )
245- assertEquivalent(EqualTo (caseWhen, Literal (3 )), EqualTo (caseWhen, Literal (3 )))
245+ assertEquivalent(EqualTo (caseWhen, Literal (3 )),
246+ CaseWhen (Seq ((a, FalseLiteral ), (c, FalseLiteral )), Some (TrueLiteral )))
246247 assertEquivalent(EqualTo (caseWhen, Literal (" 4" )), FalseLiteral )
247- assertEquivalent(EqualTo (caseWhen, Literal (" 3" )), EqualTo (caseWhen, Literal (3 )))
248+ assertEquivalent(EqualTo (caseWhen, Literal (" 3" )),
249+ CaseWhen (Seq ((a, FalseLiteral ), (c, FalseLiteral )), Some (TrueLiteral )))
248250 assertEquivalent(
249251 EqualTo (CaseWhen (Seq ((a, Literal (" 1" )), (c, Literal (" 2" ))), None ), Literal (" 4" )),
250- FalseLiteral )
252+ CaseWhen ( Seq ((a, FalseLiteral ), (c, FalseLiteral )), None ) )
251253
252254 assertEquivalent(
253255 And (EqualTo (caseWhen, Literal (5 )), EqualTo (caseWhen, Literal (6 ))),
254256 FalseLiteral )
255257
256- assertEquivalent(
257- EqualTo (CaseWhen (Seq (normalBranch, (a, Literal (1 )), (c, Literal (1 ))), None ), Literal (- 1 )),
258- FalseLiteral )
259-
260258 // Do not simplify if it contains non foldable expressions.
261259 assertEquivalent(EqualTo (caseWhen, NonFoldableLiteral (true )),
262260 EqualTo (caseWhen, NonFoldableLiteral (true )))
@@ -268,16 +266,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
268266 assert(! nonDeterministic.deterministic)
269267 assertEquivalent(EqualTo (nonDeterministic, Literal (- 1 )), EqualTo (nonDeterministic, Literal (- 1 )))
270268
271- // Should not handle Null values.
269+ // Handle Null values.
272270 assertEquivalent(
273271 EqualTo (CaseWhen (Seq ((a, Literal (null , IntegerType ))), Some (Literal (1 ))), Literal (2 )),
274- EqualTo ( CaseWhen (Seq ((a, Literal (null , IntegerType ))), Some (Literal ( 1 ))), Literal ( 2 )))
272+ CaseWhen (Seq ((a, Literal (null , BooleanType ))), Some (FalseLiteral )))
275273 assertEquivalent(
276274 EqualTo (CaseWhen (Seq ((a, Literal (1 ))), Some (Literal (2 ))), Literal (null , IntegerType )),
277- EqualTo ( CaseWhen ( Seq ((a, Literal (1 ))), Some ( Literal ( 2 ))), Literal ( null , IntegerType ) ))
275+ Literal (null , BooleanType ))
278276 assertEquivalent(
279277 EqualTo (CaseWhen (Seq ((a, Literal (null , IntegerType ))), Some (Literal (1 ))), Literal (1 )),
280- EqualTo ( CaseWhen (Seq ((a, Literal (null , IntegerType ))), Some (Literal ( 1 ))), Literal ( 1 )))
278+ CaseWhen (Seq ((a, Literal (null , BooleanType ))), Some (TrueLiteral )))
281279 assertEquivalent(
282280 EqualTo (CaseWhen (Seq ((a, Literal (null , IntegerType ))), Some (Literal (null , IntegerType ))),
283281 Literal (1 )),
0 commit comments