Skip to content

Commit a31d782

Browse files
Finish up Case.
1 parent 6cf335d commit a31d782

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,81 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
202202

203203
override def toString = s"if ($predicate) $trueValue else $falseValue"
204204
}
205+
206+
// TODO: is it a good idea to put this class in this file?
207+
// CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END
208+
// When a = true, returns b; when c = true, returns d; otherwise returns e
209+
/**
 * CASE expression, in both of its SQL forms:
 *
 *   CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END       (key = None)
 *   CASE k WHEN a THEN b [WHEN c THEN d]* [ELSE e] END     (key = Some(k))
 *
 * Without a key, the value of the first branch whose condition evaluates to `true` is returned.
 * With a key, the value of the first branch whose condition evaluates to a value equal to the
 * key is returned. If no branch matches, the ELSE value is returned, or null when there is no
 * ELSE (Hive's semantics).
 *
 * @param key      optional expression that every WHEN operand is compared against
 * @param branches flattened sequence of consecutive (cond, value) pairs; a trailing unpaired
 *                 element, if present, is the value of the ELSE branch
 */
case class Case(key: Option[Expression], branches: Seq[Expression]) extends Expression {
  // Branches are considered in consecutive pairs (cond, val), and the last element
  // is the val for the default catch-all case (w/o a companion condition, that is).

  def children = key.toSeq ++ branches

  /** The (cond, value) pairs, in declaration order; an unpaired trailing element is excluded. */
  @transient private[this] lazy val whenBranches: Seq[(Expression, Expression)] =
    branches.grouped(2).collect { case Seq(cond, value) => (cond, value) }.toList

  /** The ELSE value: present exactly when `branches` has an odd number of elements. */
  @transient private[this] lazy val elseBranch: Option[Expression] =
    if (branches.size % 2 == 1) Some(branches.last) else None

  /** Every expression whose result may become the result of this CASE. */
  @transient private[this] lazy val branchValues: Seq[Expression] =
    whenBranches.map(_._2) ++ elseBranch.toSeq

  // Without an ELSE, eval returns null when no branch matches, so the result is nullable
  // regardless of the nullability of the individual branch values.
  override def nullable = elseBranch.isEmpty || branchValues.exists(_.nullable)

  def references = children.flatMap(_.references).toSet

  // Resolved only when all children are resolved and all branch values share one data type.
  // `&&` short-circuits, so data types are never requested from unresolved children.
  override lazy val resolved =
    childrenResolved && branchValues.map(_.dataType).distinct.size <= 1

  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
    }
    // All branch values have the same type once resolved, so the first one is representative.
    branchValues.headOption
      .map(_.dataType)
      .getOrElse(throw new UnresolvedException(this, "cannot resolve a CASE without branches"))
  }

  type EvaluatedType = Any

  override def eval(input: Row): Any = {
    // Decide how the evaluated WHEN operand of a branch is tested.
    val matches: Any => Boolean = key match {
      case Some(k) =>
        // The key is evaluated once per row. A null key matches no branch, since in SQL
        // a comparison with NULL is never true.
        val keyValue = k.eval(input)
        (condValue: Any) => keyValue != null && condValue == keyValue
      case None =>
        (condValue: Any) => condValue == true
    }

    // The iterator keeps evaluation lazy: conditions after the first match, and values of
    // non-matching branches, are never evaluated.
    whenBranches.iterator
      .find { case (cond, _) => matches(cond.eval(input)) }
      .map { case (_, value) => value.eval(input) }
      .orElse(elseBranch.map(_.eval(input)))
      // If all branches fail and an elseVal is not provided, the whole statement
      // evaluates to null, according to Hive's semantics.
      .getOrElse(null)
  }

  override def toString = {
    // With a key, every condition is rendered as "key == cond".
    val keyPrefix = key.map(k => s"$k == ").getOrElse("")
    val whenStrings = whenBranches.zipWithIndex.map { case ((cond, value), i) =>
      val keyword = if (i == 0) "if" else "\nelse if"
      s"$keyword ($keyPrefix$cond) { $value }"
    }
    val elseString = elseBranch.map(value => s"\nelse { $value }").toSeq
    (whenStrings ++ elseString).mkString
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ExpressionEvaluationSuite extends FunSuite {
3535
/**
3636
* Checks for three-valued-logic. Based on:
3737
* http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29
38-
*
38+
* I.e. in the chain (total order) "False < Unknown < True", OR is the least upper bound and AND is the greatest lower bound.
3939
* p q p OR q p AND q p = q
4040
* True True True True True
4141
* True False True False False

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ private[hive] case class AddJar(jarPath: String) extends Command
4848

4949
/**
 * Command that carries a single file path.
 * NOTE(review): presumably produced for HiveQL `ADD FILE` statements -- confirm against the parser.
 */
private[hive] case class AddFile(filePath: String) extends Command
5050

51+
// FIXME: add back private
5152
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
52-
private[hive] object HiveQl {
53+
object HiveQl {
5354
protected val nativeCommands = Seq(
5455
"TOK_DESCFUNCTION",
5556
"TOK_DESCTABLE",
@@ -798,6 +799,8 @@ private[hive] object HiveQl {
798799
val IN = "(?i)IN".r
799800
val DIV = "(?i)DIV".r
800801
val BETWEEN = "(?i)BETWEEN".r
802+
val WHEN = "(?i)WHEN".r
803+
val CASE = "(?i)CASE".r
801804

802805
protected def nodeToExpr(node: Node): Expression = node match {
803806
/* Attribute References */
@@ -904,6 +907,12 @@ private[hive] object HiveQl {
904907
case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
905908
case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
906909

910+
/* Case statements */
911+
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
912+
Case(None, branches.map(nodeToExpr))
913+
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
914+
Case(Some(nodeToExpr(branches(0))), branches.drop(1).map(nodeToExpr))
915+
907916
/* Complex datatype manipulation */
908917
case Token("[", child :: ordinal :: Nil) =>
909918
GetItem(nodeToExpr(child), nodeToExpr(ordinal))

0 commit comments

Comments
 (0)