Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1347,11 +1347,17 @@ class Analyzer(
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
case i @ In(values, Seq(l @ ListQuery(_, _, exprId, _)))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
In(value, Seq(expr))
val subqueryOutputNum = expr.asInstanceOf[ListQuery].childOutputs.length
if (values.length != subqueryOutputNum) {
throw new AnalysisException(s"${i.sql} has ${values.length} values, but the " +
s"subquery has $subqueryOutputNum output values.")
}
In(values, Seq(expr))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,27 +406,16 @@ object TypeCoercion {
* Analysis Exception will be raised at the type checking phase.
*/
case class InConversion(conf: SQLConf) extends TypeCoercionRule {
private def flattenExpr(expr: Expression): Seq[Expression] = {
expr match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
// flatten the named struct to get the list of expressions.
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
}

override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)

// LHS is the value expressions of IN subquery.
case i @ In(lhs, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && lhs.length == sub.output.length =>
// RHS is the subquery output.
val rhs = sub.output

Expand All @@ -442,27 +431,26 @@ object TypeCoercion {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
val castedLhs = lhs.zip(commonTypes).map {
val newLhs = lhs.zip(commonTypes).map {
case (e, dt) if e.dataType != dt => Cast(e, dt)
case (e, _) => e
}

// Before constructing the In expression, wrap the multi values in LHS
// in a CreatedNamedStruct.
val newLhs = castedLhs match {
case Seq(lhs) => lhs
case _ => CreateStruct(castedLhs)
}

val newSub = Project(castedRhs, sub)
In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
} else {
i
}

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case i @ In(a, b) if b.exists(_.dataType != i.value.dataType) =>
findWiderCommonType(i.value.dataType +: b.map(_.dataType)) match {
case Some(finalDataType: StructType) if i.values.length > 1 =>
val newValues = a.zip(finalDataType.fields.map(_.dataType)).map {
case (expr, dataType) => Cast(expr, dataType)
}
In(newValues, b.map(Cast(_, finalDataType)))
case Some(finalDataType) =>
In(a.map(Cast(_, finalDataType)), b.map(Cast(_, finalDataType)))
case None => i
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ package object dsl {
def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
def =!= (other: Expression): Predicate = Not(EqualTo(expr, other))

def in(list: Expression*): Expression = In(expr, list)
def in(list: Expression*): Expression = expr match {
case c: CreateNamedStruct => In(c.valExprs, list)
case other => In(Seq(other), list)
}

def like(other: Expression): Expression = Like(expr, other)
def rlike(other: Expression): Expression = RLike(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,33 +161,38 @@ case class Not(child: Expression)
true
""")
// scalastyle:on line.size.limit
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case class In(values: Seq[Expression], list: Seq[Expression]) extends Predicate {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have an analyzer rule to deal with In(CreateStruct(...), ListQuery(...)), to unpack the CreateStruct, or pack the ListQuery? Then we don't need to change In.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, as the value can be replaced later by other rules. So we do need to have a Seq[Expression] here, instead of a single expression. Another possible option which I haven't checked, but I think it may be feasible is to create a new kind of Expression (eg. InValues) we can use only for this specific case. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on InValues. Maybe call it InSubquery

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not a subquery, this is the "left part" of IN, so I don't really agree on InSubquery, but if you have another suggestion I am happy to follow it. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean case class InSubquery(values: Seq[Expression], subquery: ListSubquery), it's not just the left part.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, I think right behavior is the one which both Postgres and Hive have (and it is also the same of Oracle/MySQL, in which we don't have structs). What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we should treat (...) specially if it's in front of In, but I'm wondering if we need to do the same thing for =.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. The behavior when comparing structs in not uniform among different DBs. Hive doesn't allow = on structs. Postgres and Presto does, but their behavior with nulls is not consistent and it is different from ours. In particular, comparing a struct containing a null returns null on Postgres and causes an exception in Presto (we return false instead). This is causing another problem which has been reported in another JIRA for which we can return results different from Postgres and Oracle (SPARK-24395).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this specific case, instead, I'll update this PR creating the new ad-hoc expression for the values in front of IN if you agree, as we have to deal not only with the subquery case. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM


require(list != null, "list should not be null")

@transient lazy val value = if (values.length > 1) {
CreateNamedStruct(values.zipWithIndex.flatMap {
case (v: NamedExpression, _) => Seq(Literal(v.name), v)
case (v, idx) => Seq(Literal(s"_$idx"), v)
})
} else {
values.head
}

override def checkInputDataTypes(): TypeCheckResult = {
val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
ignoreNullability = true))
if (mismatchOpt.isDefined) {
list match {
case ListQuery(_, _, _, childOutputs) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
if (valExprs.length != childOutputs.length) {
if (values.length != childOutputs.length) {
TypeCheckResult.TypeCheckFailure(
s"""
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${valExprs.length}.
|#columns in left hand side: ${values.length}.
|#columns in right hand side: ${childOutputs.length}.
|Left side columns:
|[${valExprs.map(_.sql).mkString(", ")}].
|[${values.map(_.sql).mkString(", ")}].
|Right side columns:
|[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
} else {
val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
val mismatchedColumns = values.zip(childOutputs).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
Expand All @@ -199,7 +204,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|Mismatched columns:
|[${mismatchedColumns.mkString(", ")}]
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|[${values.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
}
Expand All @@ -212,7 +217,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}

override def children: Seq[Expression] = value +: list
override def children: Seq[Expression] = values ++: list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

Expand Down Expand Up @@ -307,9 +312,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}

override def sql: String = {
val childrenSQL = children.map(_.sql)
val valueSQL = childrenSQL.head
val listSQL = childrenSQL.tail.mkString(", ")
val valueSQL = value.sql
val listSQL = list.map(_.sql).mkString(", ")
s"($valueSQL IN ($listSQL))"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,18 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
* 1. Converts the predicate to false when the list is empty and
* the value is not nullable.
* 2. Removes literal repetitions.
* 3. Replaces [[In (value, seq[Literal])]] with optimized version
* 3. Replaces [[In (values, seq[Literal])]] with optimized version
* [[InSet (value, HashSet[Literal])]] which is much faster.
*/
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral
case expr @ In(v, list) if expr.inSetConvertible =>
case i @ In(_, list) if list.isEmpty && !i.value.nullable => FalseLiteral
case expr @ In(_, list) if expr.inSetConvertible =>
val newList = ExpressionSet(list).toSeq
if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
InSet(expr.value, HashSet() ++ hSet)
} else if (newList.size < list.size) {
expr.copy(list = newList)
} else { // newList.length == list.length
Expand Down Expand Up @@ -493,7 +493,7 @@ object NullPropagation extends Rule[LogicalPlan] {
}

// If the value expression is NULL then transform the In expression to null literal.
case In(Literal(null, _), _) => Literal.create(null, BooleanType)
case In(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)

// Non-leaf NullIntolerant expressions will return null, if at least one of its children is
// a null literal.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand All @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
private def getValueExpression(e: Expression): Seq[Expression] = {
e match {
case cns : CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
}

private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
// SPARK-21835: It is possibly that the two sides of the join have conflicting attributes,
// the produced join then becomes unresolved and break structural integrity. We should
Expand Down Expand Up @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
case (p, In(values, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
case (p, Not(In(values, Seq(ListQuery(sub, conditions, _, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.

// Note that will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
Expand Down Expand Up @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
newPlan = dedupJoin(
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
exists
case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
case In(values, Seq(ListQuery(sub, conditions, _, _))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case not => Not(e)
}

def getValueExpressions(e: Expression): Seq[Expression] = e match {
case c: CreateNamedStruct => c.valExprs
case other => Seq(other)
}

// Create the predicate.
ctx.kind.getType match {
case SqlBaseParser.BETWEEN =>
Expand All @@ -1094,9 +1099,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
GreaterThanOrEqual(e, expression(ctx.lower)),
LessThanOrEqual(e, expression(ctx.upper))))
case SqlBaseParser.IN if ctx.query != null =>
invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query)))))
invertIfNotDefined(In(getValueExpressions(e), Seq(ListQuery(plan(ctx.query)))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
invertIfNotDefined(In(getValueExpressions(e), ctx.expression.asScala.map(expression)))
case SqlBaseParser.LIKE =>
invertIfNotDefined(Like(e, expression(ctx.pattern)))
case SqlBaseParser.RLIKE =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
case op @ GreaterThanOrEqual(l: Literal, ar: Attribute) =>
evaluateBinary(LessThanOrEqual(ar, l), ar, l, update)

case In(ar: Attribute, expList)
case In(Seq(ar: Attribute), expList)
if expList.forall(e => e.isInstanceOf[Literal]) =>
// Expression [In (value, seq[Literal])] will be replaced with optimized version
// [InSet (value, HashSet[Literal])] in Optimizer, but only for list.size > 10.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val plan = Project(
Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()),
Seq(a, Alias(In(Seq(a), Seq(ListQuery(LocalRelation(b)))), "c")()),
LocalRelation(a))
assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
}
Expand All @@ -530,12 +530,13 @@ class AnalysisErrorSuite extends AnalysisTest {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", BooleanType)()
val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType),
val plan1 = Filter(Cast(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), BooleanType),
LocalRelation(a))
assertAnalysisError(plan1,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)

val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
val plan2 = Filter(
Or(Not(In(Seq(a), Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
assertAnalysisError(plan2,
"Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,21 +275,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisSuccess(plan)
}

test("SPARK-8654: different types in inlist but can be converted to a common type") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
LocalRelation()
)
val plan = Project(
Alias(In(Seq(Literal(null)), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil,
LocalRelation())
assertAnalysisSuccess(plan)
}

test("SPARK-8654: check type compatibility error") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil,
val plan = Project(Alias(In(Seq(Literal(null)), Seq(Literal(true), Literal(1))), "a")() :: Nil,
LocalRelation()
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest {
val t2 = LocalRelation(b)

test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
val expr = Filter(In(Seq(a), Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
val m = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr))
}.getMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1240,16 +1240,16 @@ class TypeCoercionSuite extends AnalysisTest {
// InConversion
val inConversion = TypeCoercion.InConversion(conf)
ruleTest(inConversion,
In(UnresolvedAttribute("a"), Seq(Literal(1))),
In(UnresolvedAttribute("a"), Seq(Literal(1)))
In(Seq(UnresolvedAttribute("a")), Seq(Literal(1))),
In(Seq(UnresolvedAttribute("a")), Seq(Literal(1)))
)
ruleTest(inConversion,
In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1))),
In(Seq(Literal("test")), Seq(UnresolvedAttribute("a"), Literal(1)))
)
ruleTest(inConversion,
In(Literal("a"), Seq(Literal(1), Literal("b"))),
In(Cast(Literal("a"), StringType),
In(Seq(Literal("a")), Seq(Literal(1), Literal("b"))),
In(Seq(Cast(Literal("a"), StringType)),
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
checkAnswer(tbl2, Seq.empty, Set(part1, part2))
checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
checkAnswer(tbl2, Seq(In(Seq('a.int * 10), Seq(30))), Set(part2))
checkAnswer(tbl2, Seq(Not(In(Seq('a.int), Seq(4)))), Set(part1, part2))
checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
Expand Down
Loading