@@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
353353 override def toString : String = s " if ( $predicate) $trueValue else $falseValue"
354354}
355355
356+ trait CaseWhenLike extends Expression {
357+ self : Product =>
358+
359+ type EvaluatedType = Any
360+
361+ // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
362+ // element is the value for the default catch-all case (if provided).
363+ // Hence, `branches` consists of at least two elements, and can have an odd or even length.
364+ def branches : Seq [Expression ]
365+
366+ @ transient lazy val whenList =
367+ branches.sliding(2 , 2 ).collect { case Seq (whenExpr, _) => whenExpr }.toSeq
368+ @ transient lazy val thenList =
369+ branches.sliding(2 , 2 ).collect { case Seq (_, thenExpr) => thenExpr }.toSeq
370+ val elseValue = if (branches.length % 2 == 0 ) None else Option (branches.last)
371+
372+ // both then and else val should be considered.
373+ def valueTypes : Seq [DataType ] = (thenList ++ elseValue).map(_.dataType)
374+ def valueTypesEqual : Boolean = valueTypes.distinct.size <= 1
375+
376+ override def dataType : DataType = {
377+ if (! resolved) {
378+ throw new UnresolvedException (this , " cannot resolve due to differing types in some branches" )
379+ }
380+ valueTypes.head
381+ }
382+
383+ override def nullable : Boolean = {
384+ // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
385+ thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true ))
386+ }
387+ }
388+
356389// scalastyle:off
357390/**
358391 * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
359392 * Refer to this link for the corresponding semantics:
360393 * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
361- *
362- * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
363- * translated to this form at parsing time. Namely, such a statement gets translated to
364- * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
365- *
366- * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
367- * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
368- * at least two elements, and can have an odd or even length.
369394 */
370395// scalastyle:on
371- case class CaseWhen (branches : Seq [Expression ]) extends Expression {
372- type EvaluatedType = Any
396+ case class CaseWhen (branches : Seq [Expression ]) extends CaseWhenLike {
397+
398+ // Use private[this] Array to speed up evaluation.
399+ @ transient private [this ] lazy val branchesArr = branches.toArray
373400
374401 override def children : Seq [Expression ] = branches
375402
376- override def dataType : DataType = {
377- if (! resolved) {
378- throw new UnresolvedException (this , " cannot resolve due to differing types in some branches" )
403+ override lazy val resolved : Boolean =
404+ childrenResolved &&
405+ whenList.forall(_.dataType == BooleanType ) &&
406+ valueTypesEqual
407+
408+ /** Written in imperative fashion for performance considerations. */
409+ override def eval (input : Row ): Any = {
410+ val len = branchesArr.length
411+ var i = 0
412+ // If all branches fail and an elseVal is not provided, the whole statement
413+ // defaults to null, according to Hive's semantics.
414+ while (i < len - 1 ) {
415+ if (branchesArr(i).eval(input) == true ) {
416+ return branchesArr(i + 1 ).eval(input)
417+ }
418+ i += 2
419+ }
420+ var res : Any = null
421+ if (i == len - 1 ) {
422+ res = branchesArr(i).eval(input)
379423 }
380- branches( 1 ).dataType
424+ return res
381425 }
382426
427+ override def toString : String = {
428+ " CASE" + branches.sliding(2 , 2 ).map {
429+ case Seq (cond, value) => s " WHEN $cond THEN $value"
430+ case Seq (elseValue) => s " ELSE $elseValue"
431+ }.mkString
432+ }
433+ }
434+
435+ // scalastyle:off
436+ /**
437+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
438+ * Refer to this link for the corresponding semantics:
439+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
440+ */
441+ // scalastyle:on
442+ case class CaseKeyWhen (key : Expression , branches : Seq [Expression ]) extends CaseWhenLike {
443+
444+ // Use private[this] Array to speed up evaluation.
383445 @ transient private [this ] lazy val branchesArr = branches.toArray
384- @ transient private [this ] lazy val predicates =
385- branches.sliding(2 , 2 ).collect { case Seq (cond, _) => cond }.toSeq
386- @ transient private [this ] lazy val values =
387- branches.sliding(2 , 2 ).collect { case Seq (_, value) => value }.toSeq
388- @ transient private [this ] lazy val elseValue =
389- if (branches.length % 2 == 0 ) None else Option (branches.last)
390446
391- override def nullable : Boolean = {
392- // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
393- values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true ))
394- }
447+ override def children : Seq [Expression ] = key +: branches
395448
396- override lazy val resolved : Boolean = {
397- if (! childrenResolved) {
398- false
399- } else {
400- val allCondBooleans = predicates.forall(_.dataType == BooleanType )
401- // both then and else val should be considered.
402- val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
403- allCondBooleans && dataTypesEqual
404- }
405- }
449+ override lazy val resolved : Boolean =
450+ childrenResolved && valueTypesEqual
406451
407452 /** Written in imperative fashion for performance considerations. */
408453 override def eval (input : Row ): Any = {
454+ val evaluatedKey = key.eval(input)
409455 val len = branchesArr.length
410456 var i = 0
411457 // If all branches fail and an elseVal is not provided, the whole statement
412458 // defaults to null, according to Hive's semantics.
413- var res : Any = null
414459 while (i < len - 1 ) {
415- if (branchesArr(i).eval(input) == true ) {
416- res = branchesArr(i + 1 ).eval(input)
417- return res
460+ if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
461+ return branchesArr(i + 1 ).eval(input)
418462 }
419463 i += 2
420464 }
465+ var res : Any = null
421466 if (i == len - 1 ) {
422467 res = branchesArr(i).eval(input)
423468 }
424- res
469+ return res
470+ }
471+
472+ private def equalNullSafe (l : Any , r : Any ) = {
473+ if (l == null && r == null ) {
474+ true
475+ } else if (l == null || r == null ) {
476+ false
477+ } else {
478+ l == r
479+ }
425480 }
426481
427482 override def toString : String = {
428- " CASE" + branches.sliding(2 , 2 ).map {
483+ s " CASE $key " + branches.sliding(2 , 2 ).map {
429484 case Seq (cond, value) => s " WHEN $cond THEN $value"
430485 case Seq (elseValue) => s " ELSE $elseValue"
431486 }.mkString
0 commit comments