Commit 7d2b7e2

Translate CaseKeyWhen to CaseWhen at parsing time.
1 parent 47d406a commit 7d2b7e2
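
In SQL terms, the key-based form of CASE no longer gets its own expression class: the parser now turns `CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END` into the equivalent `CASE WHEN a = b THEN c ... END`, so only the generic CaseWhen needs an eval() path. Below is a minimal standalone sketch of that rewrite, using plain strings in place of Catalyst expression trees; the object and method names are illustrative, not Spark APIs.

// A standalone sketch of the parse-time rewrite introduced by this commit. Plain strings
// stand in for Catalyst expression trees; the real code in HiveQl.scala operates on AST nodes.
object CaseKeyRewriteSketch {
  // key:      the "a" in CASE a WHEN ... END
  // branches: the WHEN/THEN values in order, optionally followed by the ELSE value
  def rewrite(key: String, branches: Seq[String]): Seq[String] =
    branches.sliding(2, 2).flatMap {
      case Seq(whenVal, thenVal) => Seq(s"$key = $whenVal", thenVal) // (cond, value) pair
      case Seq(elseVal)          => Seq(elseVal)                     // optional trailing ELSE
    }.toSeq

  def main(args: Array[String]): Unit = {
    // CASE a WHEN b THEN c WHEN d THEN e ELSE f END
    println(rewrite("a", Seq("b", "c", "d", "e", "f")))
    // => List(a = b, c, a = d, e, f)  -- the flat branch list CaseWhen expects
  }
}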

2 files changed: +11 −66 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 65 deletions
@@ -208,6 +208,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
  * Refer to this link for the corresponding semantics:
  * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
  *
+ * The other form of case statements, "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END", gets
+ * translated to this form at parsing time (i.e. CASE WHEN a=b THEN c ...).
+ *
  * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
  * is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
  * two elements, and can have an odd or even length.
@@ -274,68 +277,3 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
     firstBranch ++ otherBranches
   }
 }
-
-/**
- * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
- * of case statements is separated out from the other type mainly due to performance reason: this
- * approach avoids branching (based on whether or not the key is provided) in eval().
- */
-case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends Expression {
-  type EvaluatedType = Any
-  def children = key +: branches
-  def references = children.flatMap(_.references).toSet
-  def dataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
-    }
-    branches(1).dataType
-  }
-
-  override def nullable = branches.sliding(2, 2).map {
-    case Seq(cond, value) => value.nullable
-    case Seq(elseValue) => elseValue.nullable
-  }.reduce(_ || _)
-
-
-  override lazy val resolved = {
-    lazy val dataTypes = branches.sliding(2, 2).map {
-      case Seq(cond, value) => value.dataType
-      case Seq(elseValue) => elseValue.dataType
-    }.toSeq
-    lazy val dataTypesEqual =
-      if (dataTypes.size <= 1) true else dataTypes.drop(1).map(_ == dataTypes(0)).reduce(_ && _)
-    if (!childrenResolved) false else dataTypesEqual
-  }
-
-  private lazy val branchesArr = branches.toArray
-
-  override def eval(input: Row): Any = {
-    val evaledKey = key.eval(input)
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    var res: Any = null
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == evaledKey) {
-        res = branchesArr(i + 1).eval(input)
-        return res
-      }
-      i += 2
-    }
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    res
-  }
-
-  override def toString = {
-    val keyString = key.toString
-    val firstBranch = s"if ($keyString == ${branches(0)}) { ${branches(1)} }"
-    val otherBranches = branches.sliding(2, 2).drop(1).map {
-      case Seq(cond, value) => s" else if ($keyString == $cond) { $value }"
-      case Seq(elseValue) => s" else { $elseValue }"
-    }.mkString
-    firstBranch ++ otherBranches
-  }
-}
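
The branch layout that survives in CaseWhen is worth spelling out: `branches` holds consecutive (cond, val) pairs, optionally followed by a single else value, and when nothing matches and no else is given the result defaults to null, per Hive's semantics. A minimal sketch of that contract, using ordinary Scala values rather than the Catalyst classes (names here are illustrative only):

object CaseWhenLayoutSketch {
  // branches: cond1, val1, cond2, val2, ..., [elseVal]; conds are pre-evaluated Booleans here.
  def evalCaseWhen(branches: Seq[Any]): Any = {
    var i = 0
    while (i < branches.length - 1) {
      if (branches(i) == true) return branches(i + 1) // first matching (cond, val) pair wins
      i += 2
    }
    // A trailing odd element is the else value; otherwise default to null (Hive semantics).
    if (i == branches.length - 1) branches(i) else null
  }

  def main(args: Array[String]): Unit = {
    // CASE WHEN false THEN "a" WHEN true THEN "b" ELSE "c" END
    println(evalCaseWhen(Seq(false, "a", true, "b", "c"))) // b
    // No match and no ELSE => null
    println(evalCaseWhen(Seq(false, "a")))                 // null
  }
}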

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 8 additions & 1 deletion
@@ -923,7 +923,14 @@ private[hive] object HiveQl {
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
       CaseWhen(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
-      CaseKeyWhen(nodeToExpr(branches(0)), branches.drop(1).map(nodeToExpr))
+      val transformed = branches.drop(1).sliding(2, 2).map {
+        case Seq(condVal, value) =>
+          // FIXME?: the key will get evaluated multiple times in CaseWhen's eval(). Optimize?
+          Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)),
+            nodeToExpr(value))
+        case Seq(elseVal) => Seq(nodeToExpr(elseVal))
+      }.toSeq.reduce(_ ++ _)
+      CaseWhen(transformed)
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>
