@@ -529,9 +529,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
529529
530530
531531/**
532- * Push the cast and foldable expression into (if / case) branches.
532+ * Push the foldable expression into (if / case) branches.
533533 */
534- object PushCastAndFoldableIntoBranches extends Rule [LogicalPlan ] with PredicateHelper {
534+ object PushFoldableIntoBranches extends Rule [LogicalPlan ] with PredicateHelper {
535535
536536 // To be conservative here: it's only a guaranteed win if all but at most only one branch
537537 // end up being not foldable.
@@ -542,41 +542,41 @@ object PushCastAndFoldableIntoBranches extends Rule[LogicalPlan] with PredicateH
542542
543543 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
544544 case q : LogicalPlan => q transformExpressionsUp {
545- case Cast (i @ If (_, trueValue, falseValue), dataType, timeZoneId )
546- if atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
545+ case u @ UnaryExpression (i @ If (_, trueValue, falseValue))
546+ if ! u. isInstanceOf [ Alias ] && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
547547 i.copy(
548- trueValue = Cast ( trueValue, dataType, timeZoneId ),
549- falseValue = Cast ( falseValue, dataType, timeZoneId ))
548+ trueValue = u.withNewChildren( Array ( trueValue) ),
549+ falseValue = u.withNewChildren( Array ( falseValue) ))
550550
551- case Cast (c @ CaseWhen (branches, elseValue), dataType, timeZoneId )
552- if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
551+ case u @ UnaryExpression (c @ CaseWhen (branches, elseValue))
552+ if ! u. isInstanceOf [ Alias ] && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
553553 c.copy(
554- branches.map(e => e.copy(_2 = Cast ( e._2, dataType, timeZoneId ))),
555- elseValue.map(e => Cast (e, dataType, timeZoneId )))
554+ branches.map(e => e.copy(_2 = u.withNewChildren( Array ( e._2) ))),
555+ elseValue.map(e => u.withNewChildren( Array (e) )))
556556
557557 case b @ BinaryExpression (i @ If (_, trueValue, falseValue), right)
558558 if right.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
559559 i.copy(
560- trueValue = b.makeCopy (Array (trueValue, right)),
561- falseValue = b.makeCopy (Array (falseValue, right)))
560+ trueValue = b.withNewChildren (Array (trueValue, right)),
561+ falseValue = b.withNewChildren (Array (falseValue, right)))
562562
563563 case b @ BinaryExpression (left, i @ If (_, trueValue, falseValue))
564564 if left.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
565565 i.copy(
566- trueValue = b.makeCopy (Array (left, trueValue)),
567- falseValue = b.makeCopy (Array (left, falseValue)))
566+ trueValue = b.withNewChildren (Array (left, trueValue)),
567+ falseValue = b.withNewChildren (Array (left, falseValue)))
568568
569569 case b @ BinaryExpression (c @ CaseWhen (branches, elseValue), right)
570570 if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
571571 c.copy(
572- branches.map(e => e.copy(_2 = b.makeCopy (Array (e._2, right)))),
573- elseValue.map(e => b.makeCopy (Array (e, right))))
572+ branches.map(e => e.copy(_2 = b.withNewChildren (Array (e._2, right)))),
573+ elseValue.map(e => b.withNewChildren (Array (e, right))))
574574
575575 case b @ BinaryExpression (left, c @ CaseWhen (branches, elseValue))
576576 if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
577577 c.copy(
578- branches.map(e => e.copy(_2 = b.makeCopy (Array (left, e._2)))),
579- elseValue.map(e => b.makeCopy (Array (left, e))))
578+ branches.map(e => e.copy(_2 = b.withNewChildren (Array (left, e._2)))),
579+ elseValue.map(e => b.withNewChildren (Array (left, e))))
580580 }
581581 }
582582}
0 commit comments