Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -1286,8 +1286,10 @@ class Analyzer(
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !sub.resolved =>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya If we modified this to

case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved

would we still require the following case statement? The following case looks a little
strange, since we are in the resolveSubqueries routine and yet check for sub.resolved == true.

@viirya viirya Aug 24, 2017

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered changing resolveSubQuery to avoid re-analyzing an already resolved plan. But since it happens just once, it is probably not a big deal, so in the end I left it untouched.

val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
In(value, Seq(expr))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ object TypeCoercion {

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)
Expand Down Expand Up @@ -434,7 +434,8 @@ object TypeCoercion {
case _ => CreateStruct(castedLhs)
}

In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
val newSub = Project(castedRhs, sub)
In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
} else {
i
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,55 +138,14 @@ case class Not(child: Expression)
case class In(value: Expression, list: Seq[Expression]) extends Predicate {

require(list != null, "list should not be null")

override def checkInputDataTypes(): TypeCheckResult = {
list match {
case ListQuery(sub, _, _) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
if (valExprs.length != sub.output.length) {
TypeCheckResult.TypeCheckFailure(
s"""
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${valExprs.length}.
|#columns in right hand side: ${sub.output.length}.
|Left side columns:
|[${valExprs.map(_.sql).mkString(", ")}].
|Right side columns:
|[${sub.output.map(_.sql).mkString(", ")}].
""".stripMargin)
} else {
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
}
if (mismatchedColumns.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the left hand side of an IN subquery
|is not compatible with the data type of the output of the subquery
|Mismatched columns:
|[${mismatchedColumns.mkString(", ")}]
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
case _ =>
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,23 @@ object ScalarSubquery {
case class ListQuery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
override def dataType: DataType = plan.schema.fields.head.dataType
override def dataType: DataType = if (childOutputs.length > 1) {
childOutputs.toStructType
} else {
childOutputs.head.dataType
}
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id} $conditionString"
override lazy val canonicalized: Expression = {
ListQuery(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
Expand Down Expand Up @@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case In(value, Seq(ListQuery(sub, conditions, _))) =>
case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
Expand Down Expand Up @@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
case Exists(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, newCond, exprId)
case ListQuery(sub, _, exprId) =>
case ListQuery(sub, _, exprId, childOutputs) =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ListQuery(newPlan, newCond, exprId)
ListQuery(newPlan, newCond, exprId, childOutputs)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Plan-level test for the `PullupCorrelatedPredicates` optimizer rule.
 *
 * The suite builds an analyzed plan containing an IN predicate over a subquery,
 * runs the rule over it, and asserts that the plan is still resolved afterwards.
 */
class PullupCorrelatedPredicatesSuite extends PlanTest {

  /** Rule executor with a single batch that runs `PullupCorrelatedPredicates` once. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("PullupCorrelatedPredicates", Once, PullupCorrelatedPredicates))
  }

  // Relation filtered by the IN predicate; provides columns 'a and 'b.
  val testRelation = LocalRelation('a.int, 'b.double)
  // Relation read by the subquery; provides columns 'c and 'd.
  val testRelation2 = LocalRelation('c.int, 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
    // The subquery's filter compares 'b (a column of testRelation) with its own 'd.
    val subquery = testRelation2.where('b < 'd).select('c)
    val inPredicate = In('a, Seq(ListQuery(subquery)))
    val analyzed = testRelation.where(inPredicate).select('a).analyze
    // Precondition: analysis alone must already yield a resolved plan.
    assert(analyzed.resolved)

    // Running the rule must leave the plan resolved.
    val rewritten = Optimize.execute(analyzed)
    assert(rewritten.resolved)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,7 @@ t1a IN (SELECT t2a, t2b
struct<>
-- !query 5 output
org.apache.spark.sql.AnalysisException
cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch:
The number of columns in the left hand side of an IN subquery does not match the
number of columns in the output of subquery.
#columns in left hand side: 1.
#columns in right hand side: 2.
Left side columns:
[t1.`t1a`].
Right side columns:
[t2.`t2a`, t2.`t2b`].
;
cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: Arguments must be same type but were: IntegerType != StructType(StructField(t2a,IntegerType,false), StructField(t2b,IntegerType,false));

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new message is confusing for users who write an IN subquery.



-- !query 6
Expand All @@ -94,13 +85,4 @@ WHERE
struct<>
-- !query 6 output
org.apache.spark.sql.AnalysisException
cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch:
The number of columns in the left hand side of an IN subquery does not match the
number of columns in the output of subquery.
#columns in left hand side: 2.
#columns in right hand side: 1.
Left side columns:
[t1.`t1a`, t1.`t1b`].
Right side columns:
[t2.`t2a`].
;
cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: Arguments must be same type but were: StructType(StructField(t1a,IntegerType,false), StructField(t1b,IntegerType,false)) != IntegerType;