@@ -27,15 +27,18 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
2727import org .apache .spark .sql .catalyst .plans .PlanTest
2828import org .apache .spark .sql .catalyst .plans .logical ._
2929import org .apache .spark .sql .catalyst .rules ._
30- import org .apache .spark .sql .types .{BooleanType , IntegerType }
30+ import org .apache .spark .sql .types .{BooleanType , IntegerType , StringType }
3131
3232
3333class PushFoldableIntoBranchesSuite
3434 extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3535
3636 object Optimize extends RuleExecutor [LogicalPlan ] {
37- val batches = Batch (" PushFoldableIntoBranches" , FixedPoint (50 ),
38- BooleanSimplification , ConstantFolding , SimplifyConditionals , PushFoldableIntoBranches ) :: Nil
37+ val batches = Batch (" PushCastAndFoldableIntoBranches" , FixedPoint (50 ),
38+ BooleanSimplification ,
39+ ConstantFolding ,
40+ SimplifyConditionals ,
41+ PushCastAndFoldableIntoBranches ) :: Nil
3942 }
4043
4144 private val relation = LocalRelation (' a .int, ' b .int, ' c .boolean)
@@ -222,4 +225,36 @@ class PushFoldableIntoBranchesSuite
222225 assertEquivalent(EqualTo (Literal (4 ), ifExp), FalseLiteral )
223226 assertEquivalent(EqualTo (Literal (4 ), caseWhen), FalseLiteral )
224227 }
228+
229+ test(" Push down cast through If" ) {
230+ assertEquivalent(If (a, Literal (2 ), Literal (3 )).cast(StringType ),
231+ If (a, Literal (" 2" ), Literal (" 3" )))
232+ assertEquivalent(If (a, b, Literal (3 )).cast(StringType ),
233+ If (a, b.cast(StringType ), Literal (" 3" )))
234+ assertEquivalent(If (a, b, b + 1 ).cast(StringType ),
235+ If (a, b, b + 1 ).cast(StringType ))
236+ }
237+
238+ test(" Push down cast through CaseWhen" ) {
239+ assertEquivalent(
240+ CaseWhen (Seq ((a, Literal (1 ))), Some (Literal (3 ))).cast(StringType ),
241+ CaseWhen (Seq ((a, Literal (" 1" ))), Some (Literal (" 3" ))))
242+ assertEquivalent(
243+ CaseWhen (Seq ((a, Literal (1 ))), Some (b)).cast(StringType ),
244+ CaseWhen (Seq ((a, Literal (" 1" ))), Some (b.cast(StringType ))))
245+ assertEquivalent(
246+ CaseWhen (Seq ((a, b)), Some (b + 1 )).cast(StringType ),
247+ CaseWhen (Seq ((a, b)), Some (b + 1 )).cast(StringType ))
248+ }
249+
250+ test(" Push down cast with binary expression through If/CaseWhen" ) {
251+ assertEquivalent(EqualTo (If (a, Literal (2 ), Literal (3 )).cast(StringType ), Literal (" 4" )),
252+ FalseLiteral )
253+ assertEquivalent(
254+ EqualTo (CaseWhen (Seq ((a, Literal (1 ))), Some (Literal (3 ))).cast(StringType ), Literal (" 4" )),
255+ FalseLiteral )
256+ assertEquivalent(
257+ EqualTo (CaseWhen (Seq ((a, Literal (1 )), (c, Literal (2 ))), None ).cast(StringType ), Literal (" 4" )),
258+ CaseWhen (Seq ((a, FalseLiteral ), (c, FalseLiteral )), None ))
259+ }
225260}
0 commit comments