Skip to content

Commit 388ab53

Browse files
committed
[SPARK-2860][SQL] Fix coercion of CASE WHEN.
Author: Michael Armbrust <[email protected]>

Closes #1785 from marmbrus/caseNull and squashes the following commits:

126006d [Michael Armbrust] better error message
2fe357f [Michael Armbrust] Fix coercion of CASE WHEN.

(cherry picked from commit 6e821e3)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent e3fe657 commit 388ab53

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,21 @@ trait HiveTypeCoercion {
4949
BooleanCasts ::
5050
StringToIntegralCasts ::
5151
FunctionArgumentConversion ::
52-
CastNulls ::
52+
CaseWhenCoercion ::
5353
Division ::
5454
Nil
5555

56+
trait TypeWidening {
57+
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
58+
// Try and find a promotion rule that contains both types in question.
59+
val applicableConversion =
60+
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
61+
62+
// If found return the widest common type, otherwise None
63+
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
64+
}
65+
}
66+
5667
/**
5768
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
5869
* instances higher in the query tree.
@@ -133,16 +144,7 @@ trait HiveTypeCoercion {
133144
* - LongType to FloatType
134145
* - LongType to DoubleType
135146
*/
136-
object WidenTypes extends Rule[LogicalPlan] {
137-
138-
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
139-
// Try and find a promotion rule that contains both types in question.
140-
val applicableConversion =
141-
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
142-
143-
// If found return the widest common type, otherwise None
144-
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
145-
}
147+
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
146148

147149
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
148150
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -336,28 +338,34 @@ trait HiveTypeCoercion {
336338
}
337339

338340
/**
339-
* Ensures that NullType gets casted to some other types under certain circumstances.
341+
* Coerces the type of different branches of a CASE WHEN statement to a common type.
340342
*/
341-
object CastNulls extends Rule[LogicalPlan] {
343+
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
342344
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
343-
case cw @ CaseWhen(branches) =>
345+
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
344346
val valueTypes = branches.sliding(2, 2).map {
345-
case Seq(_, value) if value.resolved => Some(value.dataType)
346-
case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
347-
case _ => None
347+
case Seq(_, value) => value.dataType
348+
case Seq(elseVal) => elseVal.dataType
348349
}.toSeq
349-
if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
350-
val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get
350+
351+
logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
352+
353+
if (valueTypes.distinct.size > 1) {
354+
val commonType = valueTypes.reduce { (v1, v2) =>
355+
findTightestCommonType(v1, v2)
356+
.getOrElse(sys.error(
357+
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
358+
}
351359
val transformedBranches = branches.sliding(2, 2).map {
352-
case Seq(cond, value) if value.resolved && value.dataType == NullType =>
353-
Seq(cond, Cast(value, otherType))
354-
case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
355-
Seq(Cast(elseVal, otherType))
360+
case Seq(cond, value) if value.dataType != commonType =>
361+
Seq(cond, Cast(value, commonType))
362+
case Seq(elseVal) if elseVal.dataType != commonType =>
363+
Seq(Cast(elseVal, commonType))
356364
case s => s
357365
}.reduce(_ ++ _)
358366
CaseWhen(transformedBranches)
359367
} else {
360-
// It is possible to have more types due to the possibility of short-circuiting.
368+
// Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
361369
cw
362370
}
363371
}
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
1

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,9 @@ case class TestData(a: Int, b: String)
3232
*/
3333
class HiveQuerySuite extends HiveComparisonTest {
3434

35+
createQueryTest("null case",
36+
"SELECT case when(true) then 1 else null end FROM src LIMIT 1")
37+
3538
createQueryTest("single case",
3639
"""SELECT case when true then 1 else 2 end FROM src LIMIT 1""")
3740

0 commit comments

Comments
 (0)