Skip to content

Commit 912cdda

Browse files
committed
fix
1 parent 480a92b commit 912cdda

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,12 @@ abstract class UnaryExpression extends Expression {
536536
}
537537
}
538538

539+
540+
object UnaryExpression {
541+
def unapply(e: UnaryExpression): Option[Expression] = Some(e.child)
542+
}
543+
544+
539545
/**
540546
* An expression with two inputs and one output. The output is by default evaluated to null
541547
* if any input is evaluated to null.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
9999
LikeSimplification,
100100
BooleanSimplification,
101101
SimplifyConditionals,
102-
PushCastAndFoldableIntoBranches,
102+
PushFoldableIntoBranches,
103103
RemoveDispensableExpressions,
104104
SimplifyBinaryComparison,
105105
ReplaceNullWithFalseInPredicate,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -529,9 +529,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
529529

530530

531531
/**
532-
* Push the cast and foldable expression into (if / case) branches.
532+
* Push the foldable expression into (if / case) branches.
533533
*/
534-
object PushCastAndFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
534+
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
535535

536536
// To be conservative here: it's only a guaranteed win if all but at most only one branch
537537
// end up being not foldable.
@@ -542,41 +542,41 @@ object PushCastAndFoldableIntoBranches extends Rule[LogicalPlan] with PredicateH
542542

543543
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
544544
case q: LogicalPlan => q transformExpressionsUp {
545-
case Cast(i @ If(_, trueValue, falseValue), dataType, timeZoneId)
546-
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
545+
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
546+
if !u.isInstanceOf[Alias] && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
547547
i.copy(
548-
trueValue = Cast(trueValue, dataType, timeZoneId),
549-
falseValue = Cast(falseValue, dataType, timeZoneId))
548+
trueValue = u.withNewChildren(Array(trueValue)),
549+
falseValue = u.withNewChildren(Array(falseValue)))
550550

551-
case Cast(c @ CaseWhen(branches, elseValue), dataType, timeZoneId)
552-
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
551+
case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
552+
if !u.isInstanceOf[Alias] && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
553553
c.copy(
554-
branches.map(e => e.copy(_2 = Cast(e._2, dataType, timeZoneId))),
555-
elseValue.map(e => Cast(e, dataType, timeZoneId)))
554+
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
555+
elseValue.map(e => u.withNewChildren(Array(e))))
556556

557557
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
558558
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
559559
i.copy(
560-
trueValue = b.makeCopy(Array(trueValue, right)),
561-
falseValue = b.makeCopy(Array(falseValue, right)))
560+
trueValue = b.withNewChildren(Array(trueValue, right)),
561+
falseValue = b.withNewChildren(Array(falseValue, right)))
562562

563563
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
564564
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
565565
i.copy(
566-
trueValue = b.makeCopy(Array(left, trueValue)),
567-
falseValue = b.makeCopy(Array(left, falseValue)))
566+
trueValue = b.withNewChildren(Array(left, trueValue)),
567+
falseValue = b.withNewChildren(Array(left, falseValue)))
568568

569569
case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
570570
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
571571
c.copy(
572-
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
573-
elseValue.map(e => b.makeCopy(Array(e, right))))
572+
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
573+
elseValue.map(e => b.withNewChildren(Array(e, right))))
574574

575575
case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
576576
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
577577
c.copy(
578-
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
579-
elseValue.map(e => b.makeCopy(Array(left, e))))
578+
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
579+
elseValue.map(e => b.withNewChildren(Array(left, e))))
580580
}
581581
}
582582
}
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,12 @@ import org.apache.spark.sql.catalyst.rules._
3030
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
3131

3232

33-
class PushCastAndFoldableIntoBranchesSuite
33+
class PushFoldableIntoBranchesSuite
3434
extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3535

3636
object Optimize extends RuleExecutor[LogicalPlan] {
37-
val batches = Batch("PushCastAndFoldableIntoBranches", FixedPoint(50),
38-
BooleanSimplification,
39-
ConstantFolding,
40-
SimplifyConditionals,
41-
PushCastAndFoldableIntoBranches) :: Nil
37+
val batches = Batch("PushFoldableIntoBranchesSuite", FixedPoint(50),
38+
BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil
4239
}
4340

4441
private val relation = LocalRelation('a.int, 'b.int, 'c.boolean)
@@ -226,16 +223,14 @@ class PushCastAndFoldableIntoBranchesSuite
226223
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
227224
}
228225

229-
test("Push down cast through If") {
226+
test("Push down cast through If/CaseWhen") {
230227
assertEquivalent(If(a, Literal(2), Literal(3)).cast(StringType),
231228
If(a, Literal("2"), Literal("3")))
232229
assertEquivalent(If(a, b, Literal(3)).cast(StringType),
233230
If(a, b.cast(StringType), Literal("3")))
234231
assertEquivalent(If(a, b, b + 1).cast(StringType),
235232
If(a, b, b + 1).cast(StringType))
236-
}
237233

238-
test("Push down cast through CaseWhen") {
239234
assertEquivalent(
240235
CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType),
241236
CaseWhen(Seq((a, Literal("1"))), Some(Literal("3"))))
@@ -247,6 +242,13 @@ class PushCastAndFoldableIntoBranchesSuite
247242
CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType))
248243
}
249244

245+
test("Push down abs through If/CaseWhen") {
246+
assertEquivalent(Abs(If(a, Literal(-2), Literal(-3))), If(a, Literal(2), Literal(3)))
247+
assertEquivalent(
248+
Abs(CaseWhen(Seq((a, Literal(-1))), Some(Literal(-3)))),
249+
CaseWhen(Seq((a, Literal(1))), Some(Literal(3))))
250+
}
251+
250252
test("Push down cast with binary expression through If/CaseWhen") {
251253
assertEquivalent(EqualTo(If(a, Literal(2), Literal(3)).cast(StringType), Literal("4")),
252254
FalseLiteral)

0 commit comments

Comments
 (0)