@@ -199,4 +199,97 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
199199 If (Factorial (5 ) > 100L , b, nullLiteral).eval(EmptyRow ))
200200 }
201201 }
202+
test(" SPARK-33798: simplify EqualTo(If, Literal) always false") {
  // An If whose two branches are integer literals; only its predicate refers to an attribute.
  val attrPredicate = EqualTo(UnresolvedAttribute(" a"), Literal(100))
  val literalIf = If(attrPredicate === Literal(1), Literal(2), Literal(3))
  val literalIfEqualsThree = EqualTo(literalIf, Literal(3))

  // Comparing against 4 (matches neither branch) is expected to become FalseLiteral,
  // while comparing against 3 (matches a branch) is expected to stay as it is.
  // The same two expectations are asserted for string-typed comparison literals.
  assertEquivalent(EqualTo(literalIf, Literal(4)), FalseLiteral)
  assertEquivalent(literalIfEqualsThree, literalIfEqualsThree)
  assertEquivalent(EqualTo(literalIf, Literal(" 4")), FalseLiteral)
  assertEquivalent(EqualTo(literalIf, Literal(" 3")), literalIfEqualsThree)

  // Do not simplify if it contains non foldable expressions.
  val againstNonFoldable = EqualTo(literalIf, NonFoldableLiteral(true))
  assertEquivalent(againstNonFoldable, againstNonFoldable)
  val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2))
  val nonFoldableEqualsOne = EqualTo(nonFoldable, Literal(1))
  assertEquivalent(nonFoldableEqualsOne, nonFoldableEqualsOne)

  // Do not simplify if it contains non-deterministic expressions.
  val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
  assert(!nonDeterministic.deterministic)
  val nonDeterministicEqualsMinusOne = EqualTo(nonDeterministic, Literal(-1))
  assertEquivalent(nonDeterministicEqualsMinusOne, nonDeterministicEqualsMinusOne)

  // null check, SPARK-33798 will not change these behaviors.
  val nullInt = Literal(null, IntegerType)
  val nullBoolean = Literal(null, BooleanType)

  // A null literal sitting in one (or both) of the If branches.
  assertEquivalent(EqualTo(If(FalseLiteral, nullInt, Literal(1)), Literal(1)), TrueLiteral)
  assertEquivalent(EqualTo(If(TrueLiteral, nullInt, Literal(1)), Literal(1)), nullBoolean)
  assertEquivalent(EqualTo(If(FalseLiteral, nullInt, nullInt), Literal(1)), nullBoolean)

  // A null literal on the comparison side.
  assertEquivalent(EqualTo(If(FalseLiteral, Literal(1), Literal(2)), nullInt), nullBoolean)
  assertEquivalent(EqualTo(If(TrueLiteral, Literal(1), Literal(2)), nullInt), nullBoolean)
}
241+
test(" SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") {
  val condA = EqualTo(UnresolvedAttribute(" a"), Literal(100))
  val attrB = UnresolvedAttribute(" b")
  val condC = EqualTo(UnresolvedAttribute(" c"), Literal(true))
  // CaseWhen whose branch values and else value are all integer literals (1, 2, else 3).
  val literalCaseWhen =
    CaseWhen(Seq((condA, Literal(1)), (condC, Literal(2))), Some(Literal(3)))
  val literalCaseWhenEqualsThree = EqualTo(literalCaseWhen, Literal(3))

  // Comparing against 4 (matches no branch value) is expected to become FalseLiteral,
  // while comparing against 3 (the else value) is expected to stay as it is.
  // The same two expectations are asserted for string-typed comparison literals.
  assertEquivalent(EqualTo(literalCaseWhen, Literal(4)), FalseLiteral)
  assertEquivalent(literalCaseWhenEqualsThree, literalCaseWhenEqualsThree)
  assertEquivalent(EqualTo(literalCaseWhen, Literal(" 4")), FalseLiteral)
  assertEquivalent(EqualTo(literalCaseWhen, Literal(" 3")), literalCaseWhenEqualsThree)

  // String-valued branches and no else branch.
  val stringCaseWhen = CaseWhen(Seq((condA, Literal(" 1")), (condC, Literal(" 2"))), None)
  assertEquivalent(EqualTo(stringCaseWhen, Literal(" 4")), FalseLiteral)

  // A conjunction of two such comparisons is expected to become FalseLiteral as well.
  val neverFive = EqualTo(literalCaseWhen, Literal(5))
  val neverSix = EqualTo(literalCaseWhen, Literal(6))
  assertEquivalent(And(neverFive, neverSix), FalseLiteral)

  // normalBranch combined with two literal-valued branches, no else branch.
  val withNormalBranch =
    CaseWhen(Seq(normalBranch, (condA, Literal(1)), (condC, Literal(1))), None)
  assertEquivalent(EqualTo(withNormalBranch, Literal(-1)), FalseLiteral)

  // Do not simplify if it contains non foldable expressions.
  val againstNonFoldable = EqualTo(literalCaseWhen, NonFoldableLiteral(true))
  assertEquivalent(againstNonFoldable, againstNonFoldable)
  val nonFoldable = CaseWhen(Seq(normalBranch, (condA, attrB)), None)
  val nonFoldableEqualsOne = EqualTo(nonFoldable, Literal(1))
  assertEquivalent(nonFoldableEqualsOne, nonFoldableEqualsOne)

  // Do not simplify if it contains non-deterministic expressions.
  val nonDeterministic =
    CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(attrB))
  assert(!nonDeterministic.deterministic)
  val nonDeterministicEqualsMinusOne = EqualTo(nonDeterministic, Literal(-1))
  assertEquivalent(nonDeterministicEqualsMinusOne, nonDeterministicEqualsMinusOne)

  val nullInt = Literal(null, IntegerType)
  val nullBoolean = Literal(null, BooleanType)
  // Branch value is a null literal, else value is 1.
  val nullThenOne = CaseWhen(Seq((condA, nullInt)), Some(Literal(1)))
  // Branch value and else value are both null literals.
  val nullThenNull = CaseWhen(Seq((condA, nullInt)), Some(nullInt))

  // null check, SPARK-33798 will change the following two behaviors.
  assertEquivalent(EqualTo(nullThenOne, Literal(2)), FalseLiteral)
  assertEquivalent(
    EqualTo(CaseWhen(Seq((condA, Literal(1))), Some(Literal(2))), nullInt),
    FalseLiteral)

  // The comparison literal equals the else value: expected to stay as it is.
  val nullThenOneEqualsOne = EqualTo(nullThenOne, Literal(1))
  assertEquivalent(nullThenOneEqualsOne, nullThenOneEqualsOne)

  // Every branch value is null: expected to become a null boolean literal.
  assertEquivalent(EqualTo(nullThenNull, Literal(1)), nullBoolean)
  assertEquivalent(EqualTo(nullThenNull, nullInt), nullBoolean)
}
202295}
0 commit comments