Skip to content

Commit 7ef284f

Browse files
Refactors: remove redundant passes, improve toString, mark transient.
1 parent f47ae7b commit 7ef284f

File tree

2 files changed

+22
-25
lines changed

2 files changed

+22
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
209209
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
210210
*
211211
* The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
212-
* translated to this form at parsing time i.e. CASE WHEN a=b THEN c ...).
212+
* translated to this form at parsing time. Namely, such a statement gets translated to
213+
* "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
213214
*
214-
* Note that branches are considered in consecutive pairs (cond, val), and the optional last element
215-
* is the val for the default catch-all case (if provided). Hence, `branches` consist of at least
216-
* two elements, and can have an odd or even length.
215+
* Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
216+
* element is the value for the default catch-all case (if provided). Hence, `branches` consists of
217+
* at least two elements, and can have an odd or even length.
217218
*/
218219
// scalastyle:on
219220
case class CaseWhen(branches: Seq[Expression]) extends Expression {
@@ -227,30 +228,27 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
227228
branches(1).dataType
228229
}
229230

230-
override def nullable = branches.sliding(2, 2).map {
231-
case Seq(cond, value) => value.nullable
232-
case Seq(elseValue) => elseValue.nullable
233-
}.reduce(_ || _)
231+
private lazy val branchesArr = branches.toArray
232+
@transient private lazy val predicates = branches
233+
.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
234+
@transient private lazy val values = branches
235+
.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
236+
237+
override def nullable = {
238+
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
239+
values.exists(_.nullable) || (values.length % 2 == 0)
240+
}
234241

235242
override lazy val resolved = {
236243
if (!childrenResolved) {
237244
false
238245
} else {
239-
val allCondBooleans = branches.sliding(2, 2).map {
240-
case Seq(cond, value) => cond.dataType == BooleanType
241-
case _ => true
242-
}.reduce(_ && _)
243-
val dataTypes = branches.sliding(2, 2).map {
244-
case Seq(cond, value) => value.dataType
245-
case Seq(elseValue) => elseValue.dataType
246-
}.toSeq
247-
val dataTypesEqual = dataTypes.distinct.size <= 1
246+
val allCondBooleans = predicates.forall(_.dataType == BooleanType)
247+
val dataTypesEqual = values.map(_.dataType).distinct.size <= 1
248248
allCondBooleans && dataTypesEqual
249249
}
250250
}
251251

252-
private lazy val branchesArr = branches.toArray
253-
254252
/** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */
255253
override def eval(input: Row): Any = {
256254
val len = branchesArr.length
@@ -272,11 +270,9 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
272270
}
273271

274272
override def toString = {
275-
val firstBranch = s"if (${branches(0)} == true) { ${branches(1)} }"
276-
val otherBranches = branches.sliding(2, 2).drop(1).map {
277-
case Seq(cond, value) => s" else if ($cond == true) { $value }"
278-
case Seq(elseValue) => s" else { $elseValue }"
273+
"CASE" + branches.sliding(2, 2).map {
274+
case Seq(cond, value) => s" WHEN $cond THEN $value"
275+
case Seq(elseValue) => s" ELSE $elseValue"
279276
}.mkString
280-
firstBranch ++ otherBranches
281277
}
282278
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ class HiveQuerySuite extends HiveComparisonTest {
192192
assert(actual === expected)
193193
}
194194

195-
// TODO: at what point in the whole pipeline should we throw an exception for this case?
195+
// TODO: adopt this test when Spark SQL has the functionality / framework to report errors.
196+
// See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion.
196197
ignore("non-boolean conditions in a CaseWhen are illegal") {
197198
intercept[Exception] {
198199
hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect()

0 commit comments

Comments (0)