@@ -209,11 +209,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
209209 * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
210210 *
211211 * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
212- * translated to this form at parsing time i.e. CASE WHEN a=b THEN c ...).
212+ * translated to this form at parsing time. Namely, such a statement gets translated to
213+ * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
213214 *
214- * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
215- * is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
216- * two elements, and can have an odd or even length.
215+ * Note that ` branches` are considered in consecutive pairs (cond, val), and the optional last
216+ * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
217+ * at least two elements, and can have an odd or even length.
217218 */
218219// scalastyle:on
219220case class CaseWhen (branches : Seq [Expression ]) extends Expression {
@@ -227,30 +228,27 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
227228 branches(1 ).dataType
228229 }
229230
230- override def nullable = branches.sliding(2 , 2 ).map {
231- case Seq (cond, value) => value.nullable
232- case Seq (elseValue) => elseValue.nullable
233- }.reduce(_ || _)
231+ private lazy val branchesArr = branches.toArray
232+ @ transient private lazy val predicates = branches
233+ .sliding(2 , 2 ).collect { case Seq (cond, _) => cond }.toSeq
234+ @ transient private lazy val values = branches
235+ .sliding(2 , 2 ).collect { case Seq (_, value) => value }.toSeq
236+
237+ override def nullable = {
238+ // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
239+ values.exists(_.nullable) || (values.length % 2 == 0 )
240+ }
234241
235242 override lazy val resolved = {
236243 if (! childrenResolved) {
237244 false
238245 } else {
239- val allCondBooleans = branches.sliding(2 , 2 ).map {
240- case Seq (cond, value) => cond.dataType == BooleanType
241- case _ => true
242- }.reduce(_ && _)
243- val dataTypes = branches.sliding(2 , 2 ).map {
244- case Seq (cond, value) => value.dataType
245- case Seq (elseValue) => elseValue.dataType
246- }.toSeq
247- val dataTypesEqual = dataTypes.distinct.size <= 1
246+ val allCondBooleans = predicates.forall(_.dataType == BooleanType )
247+ val dataTypesEqual = values.map(_.dataType).distinct.size <= 1
248248 allCondBooleans && dataTypesEqual
249249 }
250250 }
251251
252- private lazy val branchesArr = branches.toArray
253-
254252 /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */
255253 override def eval (input : Row ): Any = {
256254 val len = branchesArr.length
@@ -272,11 +270,9 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
272270 }
273271
274272 override def toString = {
275- val firstBranch = s " if ( ${branches(0 )} == true) { ${branches(1 )} } "
276- val otherBranches = branches.sliding(2 , 2 ).drop(1 ).map {
277- case Seq (cond, value) => s " else if ( $cond == true) { $value } "
278- case Seq (elseValue) => s " else { $elseValue } "
273+ " CASE" + branches.sliding(2 , 2 ).map {
274+ case Seq (cond, value) => s " WHEN $cond THEN $value"
275+ case Seq (elseValue) => s " ELSE $elseValue"
279276 }.mkString
280- firstBranch ++ otherBranches
281277 }
282278}
0 commit comments