Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
| IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
{ case c ~ t ~ f => If(c, t, f) }
| CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
(ELSE ~> expression).? <~ END ^^ {
case casePart ~ altPart ~ elsePart =>
val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
}
CaseWhen(altExprs ++ elsePart.toList)
val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(whenExpr, thenExpr)
} ++ elsePart
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
}
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,31 +631,24 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
case Seq(_, value) => value.dataType
case Seq(elseVal) => elseVal.dataType
}.toSeq

logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")

if (valueTypes.distinct.size > 1) {
val commonType = valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2)
.getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = branches.sliding(2, 2).map {
case Seq(cond, value) if value.dataType != commonType =>
Seq(cond, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
CaseWhen(transformedBranches)
} else {
// Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
cw
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
def childrenResolved: Boolean = !children.exists(!_.resolved)
def childrenResolved: Boolean = children.forall(_.resolved)

/**
* Returns a string representation of this expression that does not have developer centric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}

/**
 * Behavior shared by CASE expressions whose branches are stored as a flat sequence of
 * (when, then) pairs, optionally followed by a single trailing ELSE value.
 *
 * Implementors supply `branches`; everything else here is derived from it.
 */
trait CaseWhenLike extends Expression {
  self: Product =>

  type EvaluatedType = Any

  // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
  // element is the value for the default catch-all case (if provided).
  // Hence, `branches` consists of at least two elements, and can have an odd or even length.
  def branches: Seq[Expression]

  /** The first element of every complete (when, then) pair, in order. */
  @transient lazy val whenList: Seq[Expression] =
    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq

  /** The second element of every complete (when, then) pair, in order. */
  @transient lazy val thenList: Seq[Expression] =
    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq

  /**
   * The trailing ELSE value, present only when `branches` has an odd length.
   *
   * Kept lazy (and transient) like `whenList` / `thenList`: a strict `val` here would read the
   * abstract `branches` while the trait is still being initialized, which only works when the
   * implementor provides `branches` as a constructor parameter.
   */
  @transient lazy val elseValue: Option[Expression] =
    if (branches.length % 2 == 0) None else Option(branches.last)

  // Both the THEN values and the ELSE value contribute to the result type.
  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)

  /** True when all THEN / ELSE values share one data type (or there are no values at all). */
  def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1

  /**
   * The data type of the first THEN value.
   *
   * Throws an `UnresolvedException` while this expression is not yet resolved.
   */
  override def dataType: DataType = {
    if (!resolved) {
      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
    }
    valueTypes.head
  }

  override def nullable: Boolean = {
    // Without an ELSE value the whole statement can fall through to null, so a missing
    // `elseValue` counts as nullable (`forall` on `None` is true).
    thenList.exists(_.nullable) || elseValue.forall(_.nullable)
  }
}

// scalastyle:off
/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
*
* The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
* translated to this form at parsing time. Namely, such a statement gets translated to
* "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
*
* Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
* element is the value for the default catch-all case (if provided). Hence, `branches` consists of
* at least two elements, and can have an odd or even length.
*/
// scalastyle:on
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

// Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray

override def children: Seq[Expression] = branches

override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
override lazy val resolved: Boolean =
childrenResolved &&
whenList.forall(_.dataType == BooleanType) &&
valueTypesEqual

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
while (i < len - 1) {
if (branchesArr(i).eval(input) == true) {
return branchesArr(i + 1).eval(input)
}
i += 2
}
var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
}
branches(1).dataType
return res
}

override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
}

// scalastyle:off
/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
*/
// scalastyle:on
case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {

// Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray
@transient private[this] lazy val predicates =
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
@transient private[this] lazy val values =
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
@transient private[this] lazy val elseValue =
if (branches.length % 2 == 0) None else Option(branches.last)

override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
override def children: Seq[Expression] = key +: branches

override lazy val resolved: Boolean = {
if (!childrenResolved) {
false
} else {
val allCondBooleans = predicates.forall(_.dataType == BooleanType)
// both then and else val should be considered.
val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
allCondBooleans && dataTypesEqual
}
}
override lazy val resolved: Boolean =
childrenResolved && valueTypesEqual

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
val evaluatedKey = key.eval(input)
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
var res: Any = null
while (i < len - 1) {
if (branchesArr(i).eval(input) == true) {
res = branchesArr(i + 1).eval(input)
return res
if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
return branchesArr(i + 1).eval(input)
}
i += 2
}
var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
}
res
return res
}

/**
 * Null-safe equality: two nulls are equal, a null never equals a non-null value, and two
 * non-null values are compared with `==`.
 */
private def equalNullSafe(l: Any, r: Any) = {
  val leftIsNull = l == null
  val rightIsNull = r == null
  // If either side is null the answer depends only on whether both are.
  if (leftIsNull || rightIsNull) leftIsNull && rightIsNull else l == r
}

override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
s"CASE $key" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
}

test("case key when") {
  // Row values by ordinal: 0 -> null, 1 -> 1, 2 -> 2, 3 -> "a", 4 -> "b", 5 -> "c".
  val row = create_row(null, 1, 2, "a", "b", "c")

  // Column references, each bound via `.at(n)` to ordinal n of `row`.
  val nullIntCol = 'a.int.at(0)
  val oneCol = 'a.int.at(1)
  val twoCol = 'a.int.at(2)
  val strACol = 'a.string.at(3)
  val strBCol = 'a.string.at(4)
  val strCCol = 'a.string.at(5)

  val nullLit = Literal.create(null, BooleanType)
  val oneLit = Literal(1)
  val strALit = Literal("a")

  // Null key, no matching WHEN: falls through to the ELSE value.
  checkEvaluation(CaseKeyWhen(nullIntCol, Seq(oneCol, strACol, strBCol)), "b", row)
  // Null key matches a null WHEN (null-safe comparison) before reaching the ELSE value.
  checkEvaluation(
    CaseKeyWhen(nullIntCol, Seq(oneCol, strACol, nullLit, strBCol, strCCol)), "b", row)
  // Key 1 matches the literal 1 in the first WHEN.
  checkEvaluation(CaseKeyWhen(oneCol, Seq(oneLit, strACol, strBCol)), "a", row)
  // Non-null key never matches a null WHEN: ELSE value is returned.
  checkEvaluation(CaseKeyWhen(oneCol, Seq(nullIntCol, strACol, strBCol)), "b", row)
  // String key "a" matches the literal "a".
  checkEvaluation(CaseKeyWhen(strACol, Seq(strALit, oneCol, twoCol)), 1, row)
  // Neither null nor "b" equals "a": ELSE literal 3 is returned.
  checkEvaluation(
    CaseKeyWhen(strACol, Seq(nullIntCol, twoCol, strBCol, oneCol, Literal(3))), 3, row)

  // Same checks with a literal as the key.
  checkEvaluation(CaseKeyWhen(oneLit, Seq(oneCol, strACol, strBCol)), "a", row)
  // First WHEN ("b") misses, second WHEN ("a") matches.
  checkEvaluation(CaseKeyWhen(strALit, Seq(strBCol, oneCol, strACol, twoCol)), 2, row)
  // No WHEN matches and there is no ELSE: the result is null.
  checkEvaluation(CaseKeyWhen(oneLit, Seq(strBCol, oneCol, strACol, twoCol)), null, row)
  // Null literal key matches the null column in the second WHEN.
  checkEvaluation(CaseKeyWhen(nullLit, Seq(strBCol, oneCol, nullIntCol, twoCol)), 2, row)
}

test("complex type") {
val row = create_row(
"^Ba*n", // 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* TODO: This can be optimized to use broadcast join when replacementMap is large.
*/
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
lit(target).cast(col.dataType).expr :: Nil
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
}.toSeq
new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}

private def convertToDouble(v: Any): Double = v match {
Expand Down
12 changes: 2 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1249,16 +1249,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
CaseWhen(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
val transformed = branches.drop(1).sliding(2, 2).map {
case Seq(condVal, value) =>
// FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
// Hence effectful / non-deterministic key expressions are *not* supported at the moment.
// We should consider adding new Expressions to get around this.
Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)),
nodeToExpr(value))
case Seq(elseVal) => Seq(nodeToExpr(elseVal))
}.toSeq.reduce(_ ++ _)
CaseWhen(transformed)
val keyExpr = nodeToExpr(branches.head)
CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))

/* Complex datatype manipulation */
case Token("[", child :: ordinal :: Nil) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,11 @@ class SQLQuerySuite extends QueryTest {
(6, "c", 0, 6)
).map(i => Row(i._1, i._2, i._3, i._4)))
}

test("test case key when") {
  // Temp table `t` holds (k, v) = (1, "1") through (5, "5").
  val keyedRows = (1 to 5).map(k => (k, k.toString))
  keyedRows.toDF("k", "v").registerTempTable("t")

  // Keys 2 and 4 hit their WHEN branches; every other key takes the ELSE value 0.
  val result = sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t")
  val expected = Seq(Row(0, "1"), Row(22, "2"), Row(0, "3"), Row(44, "4"), Row(0, "5"))
  checkAnswer(result, expected)
}
}