Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
LikeSimplification,
BooleanSimplification,
SimplifyConditionals,
PushFoldableIntoBranches,
PushCastAndFoldableIntoBranches,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
ReplaceNullWithFalseInPredicate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {


/**
* Push the foldable expression into (if / case) branches.
* Push the cast and foldable expression into (if / case) branches.
*/
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
object PushCastAndFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {

// To be conservative here: the rewrite is only a guaranteed win if at most one branch
// ends up being not foldable.
Expand All @@ -542,6 +542,18 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case Cast(i @ If(_, trueValue, falseValue), dataType, timeZoneId)
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = Cast(trueValue, dataType, timeZoneId),
falseValue = Cast(falseValue, dataType, timeZoneId))

case Cast(c @ CaseWhen(branches, elseValue), dataType, timeZoneId)
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = Cast(e._2, dataType, timeZoneId))),
elseValue.map(e => Cast(e, dataType, timeZoneId)))

case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}


class PushFoldableIntoBranchesSuite
class PushCastAndFoldableIntoBranchesSuite
extends PlanTest with ExpressionEvalHelper with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("PushFoldableIntoBranches", FixedPoint(50),
BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil
val batches = Batch("PushCastAndFoldableIntoBranches", FixedPoint(50),
BooleanSimplification,
ConstantFolding,
SimplifyConditionals,
PushCastAndFoldableIntoBranches) :: Nil
}

private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
Expand Down Expand Up @@ -222,4 +225,36 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
}

test("Push down cast through If") {
  // Both branch values are literals: the expected plan carries string literals in
  // the branches and no remaining cast.
  val literalBranches = If(a, Literal(2), Literal(3))
  assertEquivalent(
    literalBranches.cast(StringType),
    If(a, Literal("2"), Literal("3")))

  // Exactly one branch value is not a literal: the expected plan has the cast
  // moved onto that branch, while the literal branch becomes a string literal.
  val singleAttributeBranch = If(a, b, Literal(3))
  assertEquivalent(
    singleAttributeBranch.cast(StringType),
    If(a, b.cast(StringType), Literal("3")))

  // Neither branch value is a literal: the expression is expected to come out of
  // the optimizer unchanged, with the cast still wrapping the whole If.
  val untouched = If(a, b, b + 1).cast(StringType)
  assertEquivalent(untouched, untouched)
}

test("Push down cast through CaseWhen") {
  // Branch value and else value are both literals: the expected plan holds string
  // literals everywhere and no remaining cast.
  val literalValues = CaseWhen(Seq((a, Literal(1))), Some(Literal(3)))
  assertEquivalent(
    literalValues.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(Literal("3"))))

  // Only the else value is not a literal: the expected plan has the cast moved
  // onto the else value, while the literal branch value becomes a string literal.
  val attributeElse = CaseWhen(Seq((a, Literal(1))), Some(b))
  assertEquivalent(
    attributeElse.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(b.cast(StringType))))

  // Neither the branch value nor the else value is a literal: the expression is
  // expected to come out of the optimizer unchanged.
  val untouched = CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType)
  assertEquivalent(untouched, untouched)
}

test("Push down cast with binary expression through If/CaseWhen") {
  // Right-hand side shared by every comparison below; none of the casted literal
  // branch values used in this test equals it.
  val four = Literal("4")

  // If with two literal branches, cast to string and compared against "4":
  // the whole comparison is expected to reduce to false.
  val castedIf = If(a, Literal(2), Literal(3)).cast(StringType)
  assertEquivalent(EqualTo(castedIf, four), FalseLiteral)

  // Same shape with a CaseWhen that has an else value.
  val castedCaseWithElse = CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType)
  assertEquivalent(EqualTo(castedCaseWithElse, four), FalseLiteral)

  // CaseWhen without an else value: the expected plan keeps the CaseWhen shape,
  // with every branch value replaced by false and still no else value.
  val castedCaseNoElse = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType)
  assertEquivalent(
    EqualTo(castedCaseNoElse, four),
    CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_warehouse_sk#3, inv_quantity_on_hand#4, i_item_id#6, d_date#10,
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#13, i_item_id#6, d_date#10]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_date_sk#1, inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id#9, d_date#13]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Loading