1717
1818package org .apache .spark .sql .catalyst .expressions
1919
20+ import scala .util .control .Breaks ._
21+
2022import org .apache .spark .sql .catalyst .analysis .UnresolvedException
2123import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
2224import org .apache .spark .sql .catalyst .types .BooleanType
@@ -202,31 +204,28 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
202204 override def toString = s " if ( $predicate) $trueValue else $falseValue"
203205}
204206
205- // TODO: break this down into two cases to eliminate branching during eval().
206-
207207/**
208- * Two types of Case statements: either "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END"
209- * or "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END", depending on whether `key` is defined.
208+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
210209 * Refer to this link for the corresponding semantics:
211210 * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
212211 *
213- * Note that branches are considered in consecutive pairs (cond, val), and the last element is the
214- * val for the default catch-all case (if provided). Hence, `branches` consist of at least two
215- * elements, and can have an odd or even length.
212+ * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
213+ * is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
214+ * two elements, and can have an odd or even length.
216215 */
217- case class CaseWhen (key : Option [Expression ], branches : Seq [Expression ]) extends Expression {
218- def children = key.toSeq ++ branches
216+ case class CaseWhen (branches : Seq [Expression ]) extends Expression {
217+ // TODO: need to check each branch's condition has Boolean type?
218+
/** Child expressions of this node: exactly the flat `branches` sequence. */
def children: Seq[Expression] = branches
219220
/**
 * Whether this CASE expression can evaluate to null.
 *
 * When no condition matches and no ELSE value is provided, the whole statement evaluates to
 * null, so a statement without a trailing ELSE is always nullable. Otherwise it is nullable
 * if any of the THEN/ELSE values is nullable.
 */
override def nullable: Boolean = {
  // `branches` is laid out as (cond, value)* [elseValue]; an odd length means ELSE is present.
  val hasElse = branches.length % 2 == 1
  // `exists` (unlike `reduce`) is also safe when `branches` is empty.
  !hasElse || branches.sliding(2, 2).exists {
    case Seq(_, value) => value.nullable
    case Seq(elseValue) => elseValue.nullable
  }
}
227225
/** Union of the references of every child expression. */
def references = children.flatMap(child => child.references).toSet
229227
228+ // TODO: fix resolved for identity function
230229 override lazy val resolved = {
231230 lazy val dataTypes = branches.sliding(2 , 2 )
232231 .map {
@@ -246,46 +245,97 @@ case class CaseWhen(key: Option[Expression], branches: Seq[Expression]) extends
246245
247246 type EvaluatedType = Any
248247
249- // TODO: change eval() to use while, etc.
250-
/**
 * Evaluates the conditions in order and returns the value paired with the first condition
 * that evaluates to `true`. If none does, returns the trailing ELSE value when one is present.
 *
 * If all branches fail and an elseVal is not provided, the whole statement defaults to null,
 * according to Hive's semantics.
 *
 * @param input the row every condition and value expression is evaluated against
 * @return the selected branch value, or null when nothing matches and there is no ELSE
 */
override def eval(input: Row): Any = {
  // Index into an array rather than the Seq so each lookup is constant-time.
  val branchesArr = branches.toArray
  val len = branchesArr.length
  var res: Any = null
  var matched = false
  var i = 0
  // Walk the (cond, value) pairs. A flag ends the loop on the first hit: calling `break`
  // outside of a `breakable` block would throw a control exception instead of exiting.
  while (!matched && i < len - 1) {
    if (branchesArr(i).eval(input) == true) {
      res = branchesArr(i + 1).eval(input)
      matched = true
    } else {
      i += 2
    }
  }
  // Stopping exactly on the last index without a match means there is an unpaired
  // trailing element, i.e. the ELSE value.
  if (!matched && i == len - 1) {
    res = branchesArr(i).eval(input)
  }
  res
}
271267
/**
 * Renders the statement as an if / else-if / else chain: the first (cond, value) pair is the
 * leading `if`, every later pair an `else if`, and an unpaired last element the `else`.
 */
override def toString = {
  val leading = s" if ( ${branches(0)} == true) { ${branches(1)} } "
  // Skip the first pair (already rendered above) and append the remaining fragments.
  branches.sliding(2, 2).drop(1).map {
    case Seq(cond, value) => s" else if ( $cond == true) { $value } "
    case Seq(elseValue) => s" else { $elseValue } "
  }.mkString(leading, "", "")
}
276+ }
277+
/**
 * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
 * of case statements is separated out from the other type mainly due to performance reason: this
 * approach avoids branching (based on whether or not the key is provided) in eval().
 *
 * `branches` holds consecutive (candidate, value) pairs, optionally followed by a single
 * trailing ELSE value, so it can have an odd or even length.
 *
 * @param key      expression whose evaluated value is compared against each candidate
 * @param branches flat sequence of (candidate, value) pairs plus an optional ELSE value
 */
case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends Expression {
  /** The key followed by every branch expression. */
  def children: Seq[Expression] = key +: branches

  /**
   * Whether this expression can evaluate to null. Without a trailing ELSE value the statement
   * yields null when no candidate matches, so it is always nullable in that case; otherwise it
   * is nullable if any THEN/ELSE value is.
   */
  override def nullable: Boolean = {
    // An odd number of branch elements means an unpaired trailing ELSE value is present.
    val hasElse = branches.length % 2 == 1
    !hasElse || branches.sliding(2, 2).exists {
      case Seq(_, value) => value.nullable
      case Seq(elseValue) => elseValue.nullable
    }
  }

  /** Union of the references of every child expression. */
  def references = children.flatMap(_.references).toSet

  /**
   * Resolved only when every child is resolved and all THEN/ELSE values share one data type.
   * Children are checked first so that `dataType` is never read from an unresolved child.
   */
  override lazy val resolved = {
    if (!childrenResolved) {
      false
    } else {
      val valueTypes = branches.sliding(2, 2).map {
        case Seq(_, value) => value.dataType
        case Seq(elseValue) => elseValue.dataType
      }.toSeq
      valueTypes.distinct.size == 1
    }
  }

  /**
   * The common data type of the branch values, taken from the first THEN value.
   * Throws an UnresolvedException when this expression is not resolved.
   */
  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(this, " cannot resolve due to differing types in some branches")
    }
    branches(1).dataType
  }

  type EvaluatedType = Any

  /**
   * Evaluates the key once, then returns the value paired with the first candidate whose
   * evaluated result equals it. If none does, returns the trailing ELSE value when present.
   *
   * If all branches fail and an elseVal is not provided, the whole statement defaults to
   * null, according to Hive's semantics.
   *
   * @param input the row the key and every branch expression are evaluated against
   */
  override def eval(input: Row): Any = {
    val evaledKey = key.eval(input)
    // Index into an array rather than the Seq so each lookup is constant-time.
    val branchesArr = branches.toArray
    val len = branchesArr.length
    var res: Any = null
    var matched = false
    var i = 0
    // A flag ends the loop on the first hit: calling `break` outside of a `breakable`
    // block would throw a control exception instead of exiting the loop.
    while (!matched && i < len - 1) {
      if (branchesArr(i).eval(input) == evaledKey) {
        res = branchesArr(i + 1).eval(input)
        matched = true
      } else {
        i += 2
      }
    }
    // Stopping exactly on the last index without a match means there is an unpaired
    // trailing element, i.e. the ELSE value.
    if (!matched && i == len - 1) {
      res = branchesArr(i).eval(input)
    }
    res
  }

  /**
   * Renders the statement as an if / else-if / else chain comparing the key with each
   * candidate; an unpaired last element is rendered as the `else`.
   */
  override def toString = {
    val keyString = key.toString
    val firstBranch = s" if ( $keyString == ${branches(0)}) { ${branches(1)} } "
    val otherBranches = branches.sliding(2, 2).drop(1).map {
      case Seq(cond, value) => s" else if ( $keyString == $cond) { $value } "
      case Seq(elseValue) => s" else { $elseValue } "
    }.mkString
    firstBranch ++ otherBranches
  }
}
0 commit comments