Skip to content

Commit bbc1e25

Browse files
committed
Push the cast into (if / case) branches.
1 parent 554600c commit bbc1e25

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
9999
LikeSimplification,
100100
BooleanSimplification,
101101
SimplifyConditionals,
102-
PushFoldableIntoBranches,
102+
PushCastAndFoldableIntoBranches,
103103
RemoveDispensableExpressions,
104104
SimplifyBinaryComparison,
105105
ReplaceNullWithFalseInPredicate,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,9 +529,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
529529

530530

531531
/**
532-
* Push the foldable expression into (if / case) branches.
532+
* Push the cast and foldable expression into (if / case) branches.
533533
*/
534-
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
534+
object PushCastAndFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
535535

536536
// To be conservative here: it's only a guaranteed win if all but at most one branch
537537
// end up being foldable.
@@ -542,6 +542,18 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
542542

543543
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
544544
case q: LogicalPlan => q transformExpressionsUp {
545+
case Cast(i @ If(_, trueValue, falseValue), dataType, timeZoneId)
546+
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
547+
i.copy(
548+
trueValue = Cast(trueValue, dataType, timeZoneId),
549+
falseValue = Cast(falseValue, dataType, timeZoneId))
550+
551+
case Cast(c @ CaseWhen(branches, elseValue), dataType, timeZoneId)
552+
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
553+
c.copy(
554+
branches.map(e => e.copy(_2 = Cast(e._2, dataType, timeZoneId))),
555+
elseValue.map(e => Cast(e, dataType, timeZoneId)))
556+
545557
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
546558
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
547559
i.copy(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,18 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
2727
import org.apache.spark.sql.catalyst.plans.PlanTest
2828
import org.apache.spark.sql.catalyst.plans.logical._
2929
import org.apache.spark.sql.catalyst.rules._
30-
import org.apache.spark.sql.types.{BooleanType, IntegerType}
30+
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
3131

3232

3333
class PushFoldableIntoBranchesSuite
3434
extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3535

3636
object Optimize extends RuleExecutor[LogicalPlan] {
37-
val batches = Batch("PushFoldableIntoBranches", FixedPoint(50),
38-
BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil
37+
val batches = Batch("PushCastAndFoldableIntoBranches", FixedPoint(50),
38+
BooleanSimplification,
39+
ConstantFolding,
40+
SimplifyConditionals,
41+
PushCastAndFoldableIntoBranches) :: Nil
3942
}
4043

4144
private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
@@ -222,4 +225,36 @@ class PushFoldableIntoBranchesSuite
222225
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
223226
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
224227
}
228+
229+
test("Push down cast through If") {
230+
assertEquivalent(If(a, Literal(2), Literal(3)).cast(StringType),
231+
If(a, Literal("2"), Literal("3")))
232+
assertEquivalent(If(a, b, Literal(3)).cast(StringType),
233+
If(a, b.cast(StringType), Literal("3")))
234+
assertEquivalent(If(a, b, b + 1).cast(StringType),
235+
If(a, b, b + 1).cast(StringType))
236+
}
237+
238+
test("Push down cast through CaseWhen") {
239+
assertEquivalent(
240+
CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType),
241+
CaseWhen(Seq((a, Literal("1"))), Some(Literal("3"))))
242+
assertEquivalent(
243+
CaseWhen(Seq((a, Literal(1))), Some(b)).cast(StringType),
244+
CaseWhen(Seq((a, Literal("1"))), Some(b.cast(StringType))))
245+
assertEquivalent(
246+
CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType),
247+
CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType))
248+
}
249+
250+
test("Push down cast with binary expression through If/CaseWhen") {
251+
assertEquivalent(EqualTo(If(a, Literal(2), Literal(3)).cast(StringType), Literal("4")),
252+
FalseLiteral)
253+
assertEquivalent(
254+
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType), Literal("4")),
255+
FalseLiteral)
256+
assertEquivalent(
257+
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")),
258+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
259+
}
225260
}

0 commit comments

Comments
 (0)