Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.util.Objects

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -97,7 +99,7 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
/** Child expressions of this predicate: the tested `value` first, then every element of `list`. */
override def children: Seq[Expression] = {
  value +: list
}

override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
Expand All @@ -107,21 +109,71 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate with C
val evaluatedValue = value.eval(input)
list.exists(e => e.eval(input) == evaluatedValue)
}

/**
 * Generates Java source that evaluates this IN predicate.
 *
 * The emitted code declares two boolean variables named by `ev.primitive` (the result) and
 * `ev.isNull` (always `false`). For a non-empty `list` it first runs the code for `value`, then,
 * for each element in order, runs that element's code and compares it with `value` using
 * `java.util.Objects.equals`. Once a match has set the result to `true`, the remaining elements
 * are skipped by the `if (!result)` guard.
 *
 * NOTE(review): the `isNull` flags produced for `value` and for the list elements are not
 * consulted here, so null inputs are compared as ordinary values - confirm this is the intended
 * IN semantics.
 *
 * @param ctx code generation context handed to the child expressions
 * @param ev  holder naming the result and null-flag variables to declare
 * @return the generated Java code fragment
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  if (list.isEmpty) {
    // Nothing can match an empty list. The result variables still have to be *declared* here,
    // exactly as in the non-empty branch; bare assignments would reference undeclared names
    // in the generated Java.
    s"""
      boolean ${ev.primitive} = false;
      boolean ${ev.isNull} = false;
    """
  } else {
    val valueGen = value.gen(ctx)
    val listGen = list.map(_.gen(ctx))
    val equalsFn = s"${classOf[Objects].getName}.equals"
    // One guarded comparison per list element; the guard skips work after the first match.
    val listCode = listGen.map { x =>
      s"""
        if (!${ev.primitive}) {
          ${x.code}
          if ($equalsFn(${valueGen.primitive}, ${x.primitive})) {
            ${ev.primitive} = true;
          }
        }
      """
    }.mkString("\n")
    s"""
      ${valueGen.code}
      boolean ${ev.primitive} = false;
      boolean ${ev.isNull} = false;
      $listCode
    """
  }
}

}

/**
* Helper companion object in order to support code generation.
*/
object InSet {

@transient var hset: Set[Any] = null
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxin Is there a better way to expose hset to the codeGen stuff?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work when you have multiple queries. Let me see ...


/**
 * Reports whether `o` is a member of the set currently held in this object's `hset` variable.
 *
 * `hset` is a single mutable variable that starts out as `null` and is reassigned whenever an
 * `InSet` expression is constructed, so the answer reflects whichever set was assigned last.
 * Calling this before any assignment dereferences `null`.
 *
 * @param o the value to look up
 * @return `true` if the current set contains `o`
 */
def check(o: Any): Boolean = hset.contains(o)
}

/**
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with Predicate with CodegenFallback {
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {

override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
/** Renders as the child expression, the keyword `INSET`, then the set values as `(v1,v2,...)`. */
override def toString: String = {
  val renderedSet = hset.mkString("(", ",", ")")
  s"$child INSET $renderedSet"
}

InSet.hset = hset

/**
 * Evaluates `child` against `input` and tests the result for membership in this expression's
 * own `hset`.
 *
 * The lookup deliberately uses the instance's constructor argument rather than the companion
 * object's shared `hset` variable: that variable is overwritten every time any `InSet` is
 * constructed, so reading it here could answer with another expression's set. The child is
 * evaluated exactly once per call.
 *
 * @param input the row to evaluate the child expression against
 * @return `true` if the child's value is contained in `hset`, otherwise `false`
 */
override def eval(input: InternalRow): Any = {
  hset.contains(child.eval(input))
}

/**
 * Generates Java source for this predicate.
 *
 * The emitted code runs the child's generated code, declares the null flag named by
 * `ev.isNull` as `false`, and declares the result named by `ev.primitive` as the outcome of a
 * `check(...)` call on the class named by `classOf[InSet]` (with any trailing `$` removed),
 * passing the child's value.
 *
 * NOTE(review): presumably that call resolves to the companion object's `check`, which reads
 * the companion's shared mutable `hset` rather than this instance's set - confirm, since that
 * variable is reassigned by every `InSet` construction.
 *
 * @param ctx code generation context handed to the child expression
 * @param ev  holder naming the result and null-flag variables to declare
 * @return the generated Java code fragment
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val childGen = child.gen(ctx)
  // Fully qualified name used to reach `check` from the generated Java.
  val checkerClass = classOf[InSet].getName.stripSuffix("$")
  s"""
    ${childGen.code}
    boolean ${ev.isNull} = false;
    boolean ${ev.primitive} =
      $checkerClass.check(${childGen.primitive});
  """
}
}

Expand Down