Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ object TypeCoercion {
i
}

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ package object dsl {
case c: CreateNamedStruct => InSubquery(c.valExprs, l)
case other => InSubquery(Seq(other), l)
}
case _ => In(expr, list)
case _ => expr match {
case c: CreateNamedStruct => In(c.valExprs, list)
case other => In(Seq(other), list)
}
}

def like(other: Expression): Expression = Like(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object Canonicalize {
case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)

// order the list in the In operator
case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode()))
case In(values, list) if list.length > 1 => In(values, list.sortBy(_.hashCode()))

case _ => e
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.Block
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -138,21 +140,26 @@ case class Not(child: Expression)
override def sql: String = s"(NOT ${child.sql})"
}

/**
* Evaluates to `true` if `values` are returned in `query`'s result set.
*/
case class InSubquery(values: Seq[Expression], query: ListQuery)
extends Predicate with Unevaluable {
abstract class InBase extends Predicate {
def values: Seq[Expression]

@transient protected lazy val isMultiValued = values.length > 1

@transient private lazy val value: Expression = if (values.length > 1) {
@transient lazy val value: Expression = if (isMultiValued) {
CreateNamedStruct(values.zipWithIndex.flatMap {
case (v: NamedExpression, _) => Seq(Literal(v.name), v)
case (v, idx) => Seq(Literal(s"_$idx"), v)
})
} else {
values.head
}
}

/**
* Evaluates to `true` if `values` are returned in `query`'s result set.
*/
case class InSubquery(values: Seq[Expression], query: ListQuery)
extends InBase with Unevaluable {

override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
Expand Down Expand Up @@ -202,7 +209,11 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.",
usage = """
expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN. Otherwise, if
spark.sql.legacy.inOperator.falseForNullField is false and any of the elements or fields of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any of the elements or fields ...

We should explicitly mention multi-column IN, which is different from a in (b, c, ...) while a is struct type.

the elements is null it returns null, else it returns false.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remember that multi-line string doesn't work with ExpressionDescription. Can you verify it with DESCRIBE FUNCTION?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the point here is that only a string literal works, so it doesn't work concat and/or interpolation. This just puts the string on different lines, ie. the output is:

scala> sql("DESCRIBE FUNCTION IN").show(false)
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|function_desc                                                                                                                                                                                                                                                             |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Function: in                                                                                                                                                                                                                                                              |
|Class: org.apache.spark.sql.catalyst.expressions.In                                                                                                                                                                                                                       |
|Usage: 
    expr1 in(expr2, expr3, ...) - Returns true if `expr` equals to any valN. Otherwise, if
      spark.sql.legacy.inOperator.falseForNullField is false and any of the elements or fields of
      the elements is null it returns null, else it returns false.
  |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should use

"""
  |xxx
""".stripMargin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that doesn't work. We cannot do stripMargin. We can just put a string.

""",
arguments = """
Arguments:
* expr1, expr2, expr3, ... - the arguments must be same type.
Expand All @@ -219,7 +230,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
true
""")
// scalastyle:on line.size.limit
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case class In(values: Seq[Expression], list: Seq[Expression]) extends InBase {

require(list != null, "list should not be null")

Expand All @@ -234,24 +245,29 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}

override def children: Seq[Expression] = value +: list
override def children: Seq[Expression] = values ++ list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

override def nullable: Boolean = children.exists(_.nullable)
override def nullable: Boolean = if (isMultiValued && !SQLConf.get.inFalseForNullField) {
children.exists(_.nullable) ||
children.exists(_.dataType.asInstanceOf[StructType].exists(_.nullable))
} else {
children.exists(_.nullable)
}
override def foldable: Boolean = children.forall(_.foldable)

override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
if (evaluatedValue == null) {
if (checkNullEval(evaluatedValue)) {
null
} else {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == null) {
if (checkNullEval(v)) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
Expand All @@ -265,6 +281,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}

@transient lazy val checkNullGenCode: (ExprCode) => Block = {
if (isMultiValued && !SQLConf.get.inFalseForNullField) {
e => code"${e.isNull} || ${e.value}.anyNull()"
} else {
e => code"${e.isNull}"
}
}

@transient lazy val checkNullEval: (Any) => Boolean = {
if (isMultiValued && !SQLConf.get.inFalseForNullField) {
input => input == null || input.asInstanceOf[InternalRow].anyNull
} else {
input => input == null
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaDataType = CodeGenerator.javaType(value.dataType)
val valueGen = value.genCode(ctx)
Expand All @@ -283,7 +315,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val listCode = listGen.map(x =>
s"""
|${x.code}
|if (${x.isNull}) {
|if (${checkNullGenCode(x)}) {
| $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
| $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
Expand Down Expand Up @@ -316,7 +348,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
|if (!(${checkNullGenCode(valueGen)})) {
| $tmpResult = $NOT_MATCHED;
| $javaDataType $valueArg = ${valueGen.value};
| do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,27 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
* 1. Converts the predicate to false when the list is empty and
* the value is not nullable.
* 2. Removes literal repetitions.
* 3. Replaces [[In (value, seq[Literal])]] with optimized version
* 3. Replaces [[In (values, seq[Literal])]] with optimized version
* [[InSet (value, HashSet[Literal])]] which is much faster.
*/
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if list.isEmpty =>
case i @ In(_, list) if list.isEmpty =>
// When v is not nullable, the following expression will be optimized
// to FalseLiteral which is tested in OptimizeInSuite.scala
If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
case expr @ In(v, list) if expr.inSetConvertible =>
If(IsNotNull(i.value), FalseLiteral, Literal(null, BooleanType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to look at inFalseForNullField right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rigth, I fixed it and added a test, thanks.

case expr @ In(_, list) if expr.inSetConvertible =>
val newList = ExpressionSet(list).toSeq
if (newList.length == 1
// TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
// TODO: we exclude them in this rule.
&& !v.isInstanceOf[CreateNamedStructLike]
&& !expr.value.isInstanceOf[CreateNamedStructLike]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  @transient protected lazy val isMultiValued = values.length > 1
  @transient lazy val value: Expression = if (isMultiValued) {
    CreateNamedStruct(values.zipWithIndex.flatMap {
      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
      case (v, idx) => Seq(Literal(s"_$idx"), v)
    })
  } else {
    values.head
  }
}

According to the implementation, expr.value.isInstanceOf[CreateNamedStructLike] means expr.values.length > 1, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, rigth

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we use expr.values.length == 1 here to make it more clear?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really, because expr.value.isInstanceOf[CreateNamedStructLike] means:

  • either expr.values.length == 1;
  • or expr.values.head.isInstanceOf[CreateNamedStructLike];

Basically there are 2 cases: one where we have several attributes in the value before IN; the other when there is a single value before IN but the value is a struct. expr.value.isInstanceOf[CreateNamedStructLike] catches both. I can add a comment explaining these 2 cases if you think is needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, you mean
expr.values.length > 1 => expr.value.isInstanceOf[CreateNamedStructLike]
but expr.value.isInstanceOf[CreateNamedStructLike] can't => expr.values.length > 1

Can you give an example?

Based on my understanding, the code here is trying to optimize a case when it's not a multi-value in and the list has only one element.

Copy link
Contributor Author

@mgaido91 mgaido91 Oct 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I mean that. An example is:

select 1 from (select struct('a', 1, 'b', '2') as a1) t1 where a1 in ((...), ...);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for your case, it's not CreateNamedStructLike, but just a struct type column?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, because of optimizations, it is a CreateNamedStructLike

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I think for this case we should optimize it.

Anyway it follows the previous behavior, we can change it later.

&& !newList.head.isInstanceOf[CreateNamedStructLike]) {
EqualTo(v, newList.head)
EqualTo(expr.value, newList.head)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we do this only when value is not a CreateNamedStructLike, so we don't go here if there are multi-values

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we update the match here? I think it should be In(Seq(vaue) ...) now

Copy link
Contributor Author

@mgaido91 mgaido91 Oct 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, sorry, we can't do that, otherwise we would skip the other possible optimizations here, eg. converting to InSet, reducing the list of values, etc.etc.

What should be done, instead, is doing the same change to InSet, so that the way nulls are handled is coherent.

} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
InSet(expr.value, HashSet() ++ hSet)
} else if (newList.length < list.length) {
expr.copy(list = newList)
} else { // newList.length == list.length && newList.length > 1
Expand Down Expand Up @@ -527,7 +527,7 @@ object NullPropagation extends Rule[LogicalPlan] {
}

// If the value expression is NULL then transform the In expression to null literal.
case In(Literal(null, _), _) => Literal.create(null, BooleanType)
case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)

// Non-leaf NullIntolerant expressions will return null, if at least one of its children is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case SqlBaseParser.IN if ctx.query != null =>
invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
invertIfNotDefined(Like(e, expression(ctx.pattern)))
case SqlBaseParser.RLIKE =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)

case In(ar: Attribute, expList)
if expList.forall(e => e.isInstanceOf[Literal]) =>
case In(Seq(ar: Attribute), expList) if expList.forall(e => e.isInstanceOf[Literal]) =>
// Expression [In (value, seq[Literal])] will be replaced with optimized version
// [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
// Here we convert In into InSet anyway, because they share the same processing logic.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_IN_FALSE_FOR_NULL_FIELD =
buildConf("spark.sql.legacy.inOperator.falseForNullField")
.internal()
.doc("When set to true (default), the IN operator returns false when comparing literal " +
"structs containing a null field. When set to false, it returns null, instead. This is " +
"important especially when using NOT IN as in the second case, it filters out the rows " +
"when a null is present in a filed; while in the first one, those rows are returned.")
.booleanConf
.createWithDefault(true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we set false as default to follow SQL standard? and be consistent with in-subquery

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I agree, let me switch it, thanks.


val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint")
.doc("If it is set to true, the div operator returns always a bigint. This behavior was " +
"inherited from Hive. Otherwise, the return type is the data type of the operands.")
Expand Down Expand Up @@ -1978,6 +1988,8 @@ class SQLConf extends Serializable with Logging {

def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED)

def inFalseForNullField: Boolean = getConf(SQLConf.LEGACY_IN_FALSE_FOR_NULL_FIELD)

def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG)

/** ********************** SQLConf functionality methods ************ */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,21 +280,22 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisSuccess(plan)
}

test("SPARK-8654: different types in inlist but can be converted to a common type") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
val plan = Project(
Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisSuccess(plan)
}

test("SPARK-8654: check type compatibility error") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1432,16 +1432,16 @@ class TypeCoercionSuite extends AnalysisTest {
// InConversion
val inConversion = TypeCoercion.InConversion(conf)
ruleTest(inConversion,
In(UnresolvedAttribute("a"), Seq(Literal(1))),
In(UnresolvedAttribute("a"), Seq(Literal(1)))
In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))),
In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))
)
ruleTest(inConversion,
In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))),
In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1)))
)
ruleTest(inConversion,
In(Literal("a"), Seq(Literal(1), Literal("b"))),
In(Cast(Literal("a"), StringType),
In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))),
In(Seq(Cast(Literal("a"), StringType)),
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
checkAnswer(tbl2, Seq.empty, Set(part1, part2))
checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2))
checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2))
checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ class CanonicalizeSuite extends SparkFunSuite {
val range = Range(1, 1, 1, 1)
val idAttr = range.output.head

val in1 = In(idAttr, Seq(Literal(1), Literal(2)))
val in2 = In(idAttr, Seq(Literal(2), Literal(1)))
val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3)))
val in1 = In(Seq(idAttr), Seq(Literal(1), Literal(2)))
val in2 = In(Seq(idAttr), Seq(Literal(2), Literal(1)))
val in3 = In(Seq(idAttr), Seq(Literal(1), Literal(2), Literal(3)))

assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash())
assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash())

assert(range.where(in1).sameResult(range.where(in2)))
assert(!range.where(in1).sameResult(range.where(in3)))

val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
val arrays1 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))),
CreateArray(Seq(Literal(2), Literal(1)))))
val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))),
val arrays2 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(2), Literal(1))),
CreateArray(Seq(Literal(1), Literal(2)))))
val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
val arrays3 = In(Seq(idAttr), Seq(CreateArray(Seq(Literal(1), Literal(2))),
CreateArray(Seq(Literal(3), Literal(1)))))

assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash())
Expand Down
Loading