@@ -49,10 +49,21 @@ trait HiveTypeCoercion {
4949 BooleanCasts ::
5050 StringToIntegralCasts ::
5151 FunctionArgumentConversion ::
52- CastNulls ::
52+ CaseWhenCoercion ::
5353 Division ::
5454 Nil
5555
/**
 * Mix-in that gives type-coercion rules a shared way of choosing one common type for
 * two operand types.
 */
trait TypeWidening {
  /**
   * Finds the tightest type that both `t1` and `t2` can be widened to.
   *
   * @param t1 first candidate type
   * @param t2 second candidate type
   * @return `Some(commonType)` when the types are equal, when one of them is `NullType`,
   *         or when a single promotion chain in `HiveTypeCoercion.allPromotions` contains
   *         both; `None` otherwise
   */
  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] =
    (t1, t2) match {
      // Identical types need no widening, even when no promotion chain mentions them.
      case (a, b) if a == b => Some(a)

      // NULL places no constraint on the result, so the other side decides. This keeps
      // the "cast NULL to the other branch's type" behaviour working for types that
      // are not part of any promotion chain.
      // NOTE(review): assumes NullType never sits above another type inside a chain.
      case (NullType, other) => Some(other)
      case (other, NullType) => Some(other)

      case _ =>
        // Look for a promotion chain that contains both types in question.
        HiveTypeCoercion.allPromotions
          .find(p => p.contains(t1) && p.contains(t2))
          // Of the two, the one appearing later in the chain is taken as the wider type.
          .flatMap(_.filter(t => t == t1 || t == t2).lastOption)
    }
}
66+
5667 /**
5768 * Applies any changes to [[AttributeReference ]] data types that are made by other rules to
5869 * instances higher in the query tree.
@@ -133,16 +144,7 @@ trait HiveTypeCoercion {
133144 * - LongType to FloatType
134145 * - LongType to DoubleType
135146 */
136- object WidenTypes extends Rule [LogicalPlan ] {
137-
138- def findTightestCommonType (t1 : DataType , t2 : DataType ): Option [DataType ] = {
139- // Try and find a promotion rule that contains both types in question.
140- val applicableConversion =
141- HiveTypeCoercion .allPromotions.find(p => p.contains(t1) && p.contains(t2))
142-
143- // If found return the widest common type, otherwise None
144- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
145- }
147+ object WidenTypes extends Rule [LogicalPlan ] with TypeWidening {
146148
147149 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
148150 case u @ Union (left, right) if u.childrenResolved && ! u.resolved =>
@@ -336,28 +338,34 @@ trait HiveTypeCoercion {
336338 }
337339
338340 /**
339- * Ensures that NullType gets casted to some other types under certain circumstances .
341+ * Coerces the type of different branches of a CASE WHEN statement to a common type .
340342 */
341- object CastNulls extends Rule [LogicalPlan ] {
343+ object CaseWhenCoercion extends Rule [LogicalPlan ] with TypeWidening {
/**
 * Rewrites each matching `CaseWhen` expression in `plan` so that all of its value
 * branches share one data type, wrapping the branches that differ in a `Cast`.
 *
 * Fails via `sys.error` when two branch types have no common type according to
 * `findTightestCommonType`.
 *
 * @param plan the logical plan whose expressions are inspected
 * @return the plan with coerced `CaseWhen` expressions; other expressions are untouched
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
  // Only fire while the CaseWhen itself is unresolved but every branch is resolved
  // (presumably a branch's dataType is only meaningful once it is resolved).
  case cw @ CaseWhen(branches) if !cw.resolved && branches.forall(_.resolved) =>
    // `branches` is a flat list of (condition, value) pairs, optionally followed by a
    // single trailing ELSE value. Group it once and reuse the grouping below.
    val groups = branches.grouped(2).toList
    val valueTypes = groups.map {
      case Seq(_, value) => value.dataType
      case Seq(elseVal) => elseVal.dataType
    }
    val distinctTypes = valueTypes.distinct

    logDebug(s"Input values for CASE WHEN coercion ${valueTypes.mkString(",")}")

    if (distinctTypes.size > 1) {
      // Fold over the distinct types only, so two equal types are never compared and a
      // failure always names two genuinely different types.
      val commonType = distinctTypes.reduce { (v1, v2) =>
        findTightestCommonType(v1, v2).getOrElse(sys.error(
          s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
      }
      // Cast only the value positions; conditions are passed through unchanged.
      val transformedBranches = groups.flatMap {
        case Seq(cond, value) if value.dataType != commonType =>
          Seq(cond, Cast(value, commonType))
        case Seq(elseVal) if elseVal.dataType != commonType =>
          Seq(Cast(elseVal, commonType))
        case s => s
      }
      CaseWhen(transformedBranches)
    } else {
      // Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
      cw
    }
}
0 commit comments