@@ -202,3 +202,81 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
202202
203203 override def toString = s " if ( $predicate) $trueValue else $falseValue"
204204}
205+
206+ // TODO: is it a good idea to put this class in this file?
207+ // CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END
208+ // When a = true, returns b; when c = true, return d; else return e
209+ case class Case (key : Option [Expression ], branches : Seq [Expression ]) extends Expression {
210+ // Branches are considered in consecutive pairs (cond, val), and the last element
211+ // is the val for the default catch-all case (w/o a companion condition, that is).
212+
213+ def children = key.toSeq ++ branches
214+
215+ override def nullable = branches
216+ .sliding(2 , 2 )
217+ .map {
218+ case Seq (cond, value) => value.nullable
219+ case Seq (elseValue) => elseValue.nullable
220+ }
221+ .reduce(_ || _)
222+
223+ def references = children.flatMap(_.references).toSet
224+
225+ override lazy val resolved = {
226+ val allBranchesEqual = branches.sliding(2 , 2 ).map {
227+ case Seq (cond, value) => value.dataType
228+ case Seq (elseValue) => elseValue.dataType
229+ }.reduce(_ == _)
230+ childrenResolved && allBranchesEqual
231+ }
232+
233+ def dataType = {
234+ if (! resolved) {
235+ throw new UnresolvedException (this , " cannot resolve due to differing types in some branches" )
236+ }
237+ branches(1 ).dataType
238+ }
239+
240+ type EvaluatedType = Any
241+
242+ override def eval (input : Row ): Any = {
243+ def slidingCheck (expectedVal : Any ): Any = {
244+ branches.sliding(2 , 2 ).foldLeft(None .asInstanceOf [Option [Any ]]) {
245+ case (Some (x), _) =>
246+ Some (x)
247+ case (None , Seq (cond, value)) =>
248+ if (cond.eval(input) == true ) Some (value.eval(input)) else None
249+ case (None , Seq (elseValue)) =>
250+ Some (elseValue.eval(input))
251+ }.getOrElse(null )
252+ // If all branches fail and an elseVal is not provided, the whole statement
253+ // evaluates to null, according to Hive's semantics.
254+ }
255+ // Check if any branch's cond evaluates either to the key (if provided), or to true.
256+ if (key.isDefined) {
257+ slidingCheck(key.get.eval(input))
258+ } else {
259+ slidingCheck(true )
260+ }
261+ }
262+
263+ override def toString = {
264+ var firstBranch = " "
265+ var otherBranches = " "
266+ if (key.isDefined) {
267+ val keyString = key.get.toString
268+ firstBranch = s " if ( $keyString == ${branches(0 )}) { ${branches(1 )} } "
269+ otherBranches = branches.sliding(2 , 2 ).drop(1 ).map {
270+ case Seq (cond, value) => s " \n else if ( $keyString == $cond) { $value } "
271+ case Seq (elseValue) => s " \n else { $elseValue } "
272+ }.mkString
273+ } else {
274+ firstBranch = s " if ( ${branches(0 )}) { ${branches(1 )} } "
275+ otherBranches = branches.sliding(2 , 2 ).drop(1 ).map {
276+ case Seq (cond, value) => s " \n else if ( $cond) { $value } "
277+ case Seq (elseValue) => s " \n else { $elseValue } "
278+ }.mkString
279+ }
280+ firstBranch ++ otherBranches
281+ }
282+ }
0 commit comments