Commit be54bc8

Rewrite eval() to a low-level implementation. Separate two CASE stmts.
1 parent f2bcb9d commit be54bc8

3 files changed, +102 -51 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 2 deletions
@@ -28,8 +28,6 @@ abstract class Expression extends TreeNode[Expression] {
   /** The narrowest possible type that is produced when this expression is evaluated. */
   type EvaluatedType <: Any
 
-  def dataType: DataType
-
   /**
    * Returns true when an expression is a candidate for static evaluation before the query is
    * executed.
@@ -59,6 +57,9 @@ abstract class Expression extends TreeNode[Expression] {
    */
   lazy val resolved: Boolean = childrenResolved
 
+  /** This is invalid to query if `resolved` is false. */
+  def dataType: DataType
+
   /**
    * Returns true if all the children of this expression have been resolved to a specific schema
    * and false if any still contains any unresolved placeholders.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 97 additions & 47 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.util.control.Breaks._
+
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.types.BooleanType
@@ -202,31 +204,28 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def toString = s"if ($predicate) $trueValue else $falseValue"
 }
 
-// TODO: break this down into two cases to eliminate branching during eval().
-
 /**
- * Two types of Case statements: either "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END"
- * or "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END", depending on whether `key` is defined.
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
  * Refer to this link for the corresponding semantics:
  * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
  *
- * Note that branches are considered in consecutive pairs (cond, val), and the last element is the
- * val for the default catch-all case (if provided). Hence, `branches` consist of at least two
- * elements, and can have an odd or even length.
+ * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
+ * is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
+ * two elements, and can have an odd or even length.
  */
-case class CaseWhen(key: Option[Expression], branches: Seq[Expression]) extends Expression {
-  def children = key.toSeq ++ branches
+case class CaseWhen(branches: Seq[Expression]) extends Expression {
+  // TODO: need to check each branch's condition has Boolean type?
+
+  def children = branches
 
-  override def nullable = branches
-    .sliding(2, 2)
-    .map {
+  override def nullable = branches.sliding(2, 2).map {
     case Seq(cond, value) => value.nullable
     case Seq(elseValue) => elseValue.nullable
-  }
-  .reduce(_ || _)
+  }.reduce(_ || _)
 
   def references = children.flatMap(_.references).toSet
 
+  // TODO: fix resolved for identity function
   override lazy val resolved = {
     lazy val dataTypes = branches.sliding(2, 2)
       .map {
@@ -246,46 +245,97 @@ case class CaseWhen(key: Option[Expression], branches: Seq[Expression]) extends
 
   type EvaluatedType = Any
 
-  // TODO: change eval() to use while, etc.
-
   override def eval(input: Row): Any = {
-    def slidingCheck(expectedVal: Any): Any = {
-      branches.sliding(2, 2).foldLeft(None.asInstanceOf[Option[Any]]) {
-        case (Some(x), _) =>
-          Some(x)
-        case (None, Seq(cond, value)) =>
-          if (cond.eval(input) == expectedVal) Some(value.eval(input)) else None
-        case (None, Seq(elseValue)) =>
-          Some(elseValue.eval(input))
-      }.getOrElse(null)
-      // If all branches fail and an elseVal is not provided, the whole statement
-      // evaluates to null, according to Hive's semantics.
+    val branchesArr = branches.toArray
+    val len = branchesArr.length
+    var i = 0
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    var res: Any = null
+    while (i < len - 1) {
+      if (branches(i).eval(input) == true) {
+        res = branches(i + 1).eval(input)
+        break
+      }
+      i += 2
     }
-    // Check if any branch's cond evaluates either to the key (if provided), or to true.
-    if (key.isDefined) {
-      slidingCheck(key.get.eval(input))
-    } else {
-      slidingCheck(true)
+    if (i == len - 1) {
+      res = branches(i).eval(input)
     }
+    res
   }
 
   override def toString = {
-    var firstBranch = ""
-    var otherBranches = ""
-    if (key.isDefined) {
-      val keyString = key.get.toString
-      firstBranch = s"if ($keyString == ${branches(0)}) { ${branches(1)} }"
-      otherBranches = branches.sliding(2, 2).drop(1).map {
-        case Seq(cond, value) => s" else if ($keyString == $cond) { $value }"
-        case Seq(elseValue) => s" else { $elseValue }"
-      }.mkString
-    } else {
-      firstBranch = s"if (${branches(0)}) { ${branches(1)} }"
-      otherBranches = branches.sliding(2, 2).drop(1).map {
-        case Seq(cond, value) => s" else if ($cond) { $value }"
-        case Seq(elseValue) => s" else { $elseValue }"
-      }.mkString
+    val firstBranch = s"if (${branches(0)} == true) { ${branches(1)} }"
+    val otherBranches = branches.sliding(2, 2).drop(1).map {
+      case Seq(cond, value) => s" else if ($cond == true) { $value }"
+      case Seq(elseValue) => s" else { $elseValue }"
+    }.mkString
+    firstBranch ++ otherBranches
+  }
+}
+
+/**
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
+ * of case statements is separated out from the other type mainly due to performance reason: this
+ * approach avoids branching (based on whether or not the key is provided) in eval().
+ */
+case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends Expression {
+  def children = key +: branches
+
+  override def nullable = branches.sliding(2, 2).map {
+    case Seq(cond, value) => value.nullable
+    case Seq(elseValue) => elseValue.nullable
+  }.reduce(_ || _)
+
+  def references = children.flatMap(_.references).toSet
+
+  override lazy val resolved = {
+    lazy val dataTypes = branches.sliding(2, 2).map {
+      case Seq(cond, value) => value.dataType
+      case Seq(elseValue) => elseValue.dataType
+    }.toSeq
+    lazy val dataTypesEqual = dataTypes.drop(1).map(_ == dataTypes(0)).reduce(_ && _)
+    if (dataTypes.size == 1) true else childrenResolved && dataTypesEqual
+  }
+
+  def dataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
     }
+    branches(1).dataType
+  }
+
+  type EvaluatedType = Any
+
+  override def eval(input: Row): Any = {
+    val evaledKey = key.eval(input)
+    val branchesArr = branches.toArray
+    val len = branchesArr.length
+    var i = 0
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    var res: Any = null
+    while (i < len - 1) {
+      if (branches(i).eval(input) == evaledKey) {
+        res = branches(i + 1).eval(input)
+        break
+      }
+      i += 2
+    }
+    if (i == len - 1) {
+      res = branches(i).eval(input)
+    }
+    res
+  }
+
+  override def toString = {
+    val keyString = key.toString
+    val firstBranch = s"if ($keyString == ${branches(0)}) { ${branches(1)} }"
+    val otherBranches = branches.sliding(2, 2).drop(1).map {
+      case Seq(cond, value) => s" else if ($keyString == $cond) { $value }"
+      case Seq(elseValue) => s" else { $elseValue }"
+    }.mkString
     firstBranch ++ otherBranches
   }
 }
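Both rewritten eval() methods index the flat `branches` sequence directly instead of folding over `sliding(2, 2)` groups. Below is a standalone sketch (plain Scala, with string placeholders standing in for evaluated expressions) of how that layout is grouped for `nullable`/`resolved` and walked by the new loop; note that `break` from `scala.util.control.Breaks` unwinds to the nearest enclosing `breakable` block:

import scala.util.control.Breaks._

object BranchLayoutSketch extends App {
  // CASE WHEN a THEN b WHEN c THEN d ELSE e END is encoded as a flat sequence:
  val branches = Seq("a", "b", "c", "d", "e")   // cond, val, cond, val, elseVal

  // nullable/resolved view the layout as (cond, value) pairs plus an optional trailing else:
  println(branches.sliding(2, 2).toList)        // List(List(a, b), List(c, d), List(e))

  // eval() walks the pairs with an index that advances by two; stopping at index
  // len - 1 means no condition matched and a trailing else value exists.
  val len = branches.length
  var i = 0
  var res: Any = null                           // Hive semantics: no match and no else => null
  breakable {
    while (i < len - 1) {
      if (branches(i) == "c") {                 // stand-in for cond.eval(input) == true
        res = branches(i + 1)
        break
      }
      i += 2
    }
  }
  if (i == len - 1) {
    res = branches(i)
  }
  println(res)                                  // prints "d"
}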

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 2 deletions
@@ -909,9 +909,9 @@ object HiveQl {
 
     /* Case statements */
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
-      CaseWhen(None, branches.map(nodeToExpr))
+      CaseWhen(branches.map(nodeToExpr))
    case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
-      CaseWhen(Some(nodeToExpr(branches(0))), branches.drop(1).map(nodeToExpr))
+      CaseKeyWhen(nodeToExpr(branches(0)), branches.drop(1).map(nodeToExpr))
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>
