Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,12 @@ abstract class UnaryExpression extends Expression {
}
}


/**
 * Extractor companion for [[UnaryExpression]], enabling patterns of the form
 * `UnaryExpression(child)`.
 */
object UnaryExpression {
  /** Always succeeds, yielding the child of the given unary expression. */
  def unapply(e: UnaryExpression): Option[Expression] = {
    val onlyChild = e.child
    Some(onlyChild)
  }
}


/**
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,29 +542,42 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case a: Alias => a // Skip an alias.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alias is not the only exception: we can't apply this optimization to Generator either, because the logical plan Generate requires the explicit type Generator.

It has happened many times that an optimization rule introduced bugs because it used a denylist instead of an allowlist. Let's avoid similar mistakes here and explicitly list the expressions we should support.

An initial list off the top of my head:

  1. IsNull, IsNotNull
  2. UnaryMathExpression
  3. String2StringExpression
  4. Cast
  5. BinaryComparison
  6. BinaryArithmetic

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think BinaryExpression also has this issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, but it's safer to start with an allowlist. We can extend it later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = u.withNewChildren(Array(trueValue)),
falseValue = u.withNewChildren(Array(falseValue)))

case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
elseValue.map(e => u.withNewChildren(Array(e))))

case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.makeCopy(Array(trueValue, right)),
falseValue = b.makeCopy(Array(falseValue, right)))
trueValue = b.withNewChildren(Array(trueValue, right)),
falseValue = b.withNewChildren(Array(falseValue, right)))

case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.makeCopy(Array(left, trueValue)),
falseValue = b.makeCopy(Array(left, falseValue)))
trueValue = b.withNewChildren(Array(left, trueValue)),
falseValue = b.withNewChildren(Array(left, falseValue)))

case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
elseValue.map(e => b.makeCopy(Array(e, right))))
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
elseValue.map(e => b.withNewChildren(Array(e, right))))

case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
elseValue.map(e => b.makeCopy(Array(left, e))))
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
elseValue.map(e => b.withNewChildren(Array(left, e))))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}


class PushFoldableIntoBranchesSuite
Expand Down Expand Up @@ -222,4 +222,41 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
}

test("Push down cast through If/CaseWhen") {
  // Fresh copies of the inputs that are expected to come back unchanged.
  def castOverUnfoldableIf = If(a, b, b + 1).cast(StringType)
  def castOverUnfoldableCaseWhen = CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType)

  // If: two literal branches.
  val literalIf = If(a, Literal(2), Literal(3))
  assertEquivalent(literalIf.cast(StringType), If(a, Literal("2"), Literal("3")))

  // If: one attribute branch and one literal branch.
  val mixedIf = If(a, b, Literal(3))
  assertEquivalent(mixedIf.cast(StringType), If(a, b.cast(StringType), Literal("3")))

  // If: neither branch is a literal, so the expression is expected to stay as written.
  assertEquivalent(castOverUnfoldableIf, castOverUnfoldableIf)

  // CaseWhen: literal branch value and literal else value.
  val literalCaseWhen = CaseWhen(Seq((a, Literal(1))), Some(Literal(3)))
  assertEquivalent(
    literalCaseWhen.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(Literal("3"))))

  // CaseWhen: literal branch value with an attribute else value.
  val mixedCaseWhen = CaseWhen(Seq((a, Literal(1))), Some(b))
  assertEquivalent(
    mixedCaseWhen.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(b.cast(StringType))))

  // CaseWhen: no literal values, so the expression is expected to stay as written.
  assertEquivalent(castOverUnfoldableCaseWhen, castOverUnfoldableCaseWhen)
}

test("Push down abs through If/CaseWhen") {
  // Abs over an If whose branches are both negative literals.
  val negativeIf = If(a, Literal(-2), Literal(-3))
  val positiveIf = If(a, Literal(2), Literal(3))
  assertEquivalent(Abs(negativeIf), positiveIf)

  // Abs over a CaseWhen whose branch value and else value are negative literals.
  val negativeCaseWhen = CaseWhen(Seq((a, Literal(-1))), Some(Literal(-3)))
  val positiveCaseWhen = CaseWhen(Seq((a, Literal(1))), Some(Literal(3)))
  assertEquivalent(Abs(negativeCaseWhen), positiveCaseWhen)
}

test("Push down cast with binary expression through If/CaseWhen") {
  // Comparing a casted If of literals 2/3 against "4" is expected to reduce to false.
  val castedIf = If(a, Literal(2), Literal(3)).cast(StringType)
  assertEquivalent(EqualTo(castedIf, Literal("4")), FalseLiteral)

  // Same comparison over a CaseWhen with an explicit else value.
  val castedCaseWhenWithElse =
    CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType)
  assertEquivalent(EqualTo(castedCaseWhenWithElse, Literal("4")), FalseLiteral)

  // Without an else value, the expected result keeps the CaseWhen with false branches.
  val castedCaseWhenNoElse =
    CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType)
  val expectedNoElse = CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)
  assertEquivalent(EqualTo(castedCaseWhenNoElse, Literal("4")), expectedNoElse)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_warehouse_sk#3, inv_quantity_on_hand#4, i_item_id#6, d_date#10,
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#13, i_item_id#6, d_date#10]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_date_sk#1, inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id#9, d_date#13]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Loading