
Commit f2bcb9d

WIP
1 parent 5906f75 commit f2bcb9d

5 files changed: +533 -7 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 6 additions & 2 deletions

@@ -202,6 +202,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
   override def toString = s"if ($predicate) $trueValue else $falseValue"
 }

+// TODO: break this down into two cases to eliminate branching during eval().
+
 /**
  * Two types of Case statements: either "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END"
  * or "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END", depending on whether `key` is defined.
@@ -212,7 +214,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
  * val for the default catch-all case (if provided). Hence, `branches` consist of at least two
  * elements, and can have an odd or even length.
  */
-case class Case(key: Option[Expression], branches: Seq[Expression]) extends Expression {
+case class CaseWhen(key: Option[Expression], branches: Seq[Expression]) extends Expression {
   def children = key.toSeq ++ branches

   override def nullable = branches
@@ -232,7 +234,7 @@ case class Case(key: Option[Expression], branches: Seq[Expression]) extends Expression {
       case Seq(elseValue) => elseValue.dataType
     }.toSeq
     lazy val dataTypesEqual = dataTypes.drop(1).map(_ == dataTypes(0)).reduce(_ && _)
-    childrenResolved && dataTypesEqual
+    if (dataTypes.size == 1) true else childrenResolved && dataTypesEqual
   }

   def dataType = {
@@ -244,6 +246,8 @@ case class Case(key: Option[Expression], branches: Seq[Expression]) extends Expression {

   type EvaluatedType = Any

+  // TODO: change eval() to use while, etc.
+
   override def eval(input: Row): Any = {
     def slidingCheck(expectedVal: Any): Any = {
       branches.sliding(2, 2).foldLeft(None.asInstanceOf[Option[Any]]) {
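
The scaladoc above spells out how `branches` flattens the WHEN/THEN pairs plus an optional trailing ELSE value, and the truncated hunk shows `eval` walking them with `sliding(2, 2)`. Below is a minimal, self-contained sketch of that layout and walk; it operates on already-evaluated plain values rather than Catalyst expressions, and `evalCaseWhen` is a hypothetical stand-in, not the implementation in this commit.

// Hypothetical, standalone sketch (not Catalyst code): the branches layout
// described in the scaladoc, walked with sliding(2, 2) as in eval() above.
object CaseWhenSketch {
  // branches = Seq(cond1, value1, cond2, value2, ..., optional elseValue)
  def evalCaseWhen(branches: Seq[Any]): Any = {
    val result = branches.sliding(2, 2).foldLeft(None: Option[Any]) {
      case (found @ Some(_), _) => found                            // already matched earlier
      case (None, Seq(cond, value)) if cond == true => Some(value)  // WHEN cond THEN value
      case (None, Seq(_, _)) => None                                // this condition was false
      case (None, Seq(elseValue)) => Some(elseValue)                // trailing ELSE value
    }
    result.orNull
  }

  def main(args: Array[String]): Unit = {
    // CASE WHEN false THEN "a" WHEN true THEN "b" ELSE "c" END
    println(evalCaseWhen(Seq(false, "a", true, "b", "c")))  // prints b
    // CASE WHEN false THEN "a" END  (no match, no ELSE)
    println(evalCaseWhen(Seq(false, "a")))                  // prints null
  }
}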

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala

Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ package object util {
   }

   /* FIX ME
-  implicit class debugLogging(a: AnyRef) {
+  implicit class debugLogging(a: Any) {
     def debugLogging() {
       org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG)
     }
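
A rough standalone sketch of what the commented-out FIX ME block would allow once enabled is below; a println stands in for the log4j call so the sketch runs without log4j on the classpath, and the example values are hypothetical. The AnyRef-to-Any widening in this hunk is what lets the extension apply to primitive values as well.

// Hypothetical usage sketch of the FIX ME block above; the real body calls
// org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(Level.DEBUG),
// replaced here with a println so the sketch runs without log4j.
object DebugLoggingSketch {
  implicit class debugLogging(a: Any) {
    def debugLogging(): Unit =
      println(s"would set logger ${a.getClass.getName} to DEBUG")
  }

  def main(args: Array[String]): Unit = {
    "some string".debugLogging() // an AnyRef, fine under either signature
    42.debugLogging()            // Int boxes to java.lang.Integer; accepted because the parameter is Any
  }
}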

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 2 deletions

@@ -909,9 +909,9 @@ object HiveQl {

     /* Case statements */
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
-      Case(None, branches.map(nodeToExpr))
+      CaseWhen(None, branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
-      Case(Some(nodeToExpr(branches(0))), branches.drop(1).map(nodeToExpr))
+      CaseWhen(Some(nodeToExpr(branches(0))), branches.drop(1).map(nodeToExpr))

     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>
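
As a rough illustration of the two shapes above, the sketch below shows which arguments each HiveQL CASE form feeds to CaseWhen, per the scaladoc in predicates.scala; the stand-in types and string placeholders are hypothetical, not Catalyst or Hive classes.

// Hypothetical stand-ins (not Catalyst/Hive classes) showing how the two
// CASE forms map onto CaseWhen(key, branches) per the scaladoc above.
object HiveCaseMappingSketch {
  type Expr = String // stand-in for what nodeToExpr would return
  case class CaseWhen(key: Option[Expr], branches: Seq[Expr])

  def main(args: Array[String]): Unit = {
    // CASE WHEN a THEN b ELSE e END   -> no key, branches = a, b, e
    println(CaseWhen(None, Seq("a", "b", "e")))
    // CASE x WHEN b THEN c ELSE f END -> key = x, branches = b, c, f
    println(CaseWhen(Some("x"), Seq("b", "c", "f")))
  }
}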
