diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ab7d3afce8f2..c908e8272c7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Objects + import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -97,7 +99,7 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ -case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def children: Seq[Expression] = value +: list override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. @@ -107,21 +109,71 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate with C val evaluatedValue = value.eval(input) list.exists(e => e.eval(input) == evaluatedValue) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (list.isEmpty) { + s""" + ${ev.primitive} = false; + ${ev.isNull} = false; + """ + } else { + val valueGen = value.gen(ctx) + val listGen = list.map(_.gen(ctx)) + val listCode = listGen.map(x => + s""" + if (!${ev.primitive}) { + ${x.code} + if (${classOf[Objects].getName}.equals(${valueGen.primitive}, ${x.primitive})) { + ${ev.primitive} = true; + } + } + """).foldLeft("")((a, b) => a + "\n" + b) + s""" + ${valueGen.code} + boolean ${ev.primitive} = false; + boolean ${ev.isNull} = false; + $listCode + """ + } + } + } +/** + * Helper companion object in order to support code generation. + */ +object InSet { + + @transient var hset: Set[Any] = null + + def check(o: Any): Boolean = { + hset.contains(o) + } +} /** * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate with CodegenFallback { +case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" + InSet.hset = hset + override def eval(input: InternalRow): Any = { - hset.contains(child.eval(input)) + InSet.check(child.eval(input)) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childGen = child.gen(ctx) + s""" + ${childGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = + ${classOf[InSet].getName.stripSuffix("$")}.check(${childGen.primitive}); + """ } }