Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2425,7 +2425,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
ListQuery(plan, exprs, exprId, plan.output.length)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,11 @@ abstract class TypeCoercionBase {

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions, _))
if !i.resolved && lhs.length == sub.output.length =>
case i @ InSubquery(lhs, l: ListQuery)
if !i.resolved && lhs.length == l.plan.output.length =>
// LHS is the value expressions of IN subquery.
// RHS is the subquery output.
val rhs = sub.output
val rhs = l.plan.output

val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findWiderTypeForTwo(l.dataType, r.dataType)
Expand All @@ -383,8 +383,7 @@ abstract class TypeCoercionBase {
case (e, _) => e
}

val newSub = Project(castedRhs, sub)
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
} else {
i
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)

override def checkInputDataTypes(): TypeCheckResult = {
if (values.length != query.childOutputs.length) {
if (values.length != query.numCols) {
DataTypeMismatch(
errorSubClass = "IN_SUBQUERY_LENGTH_MISMATCH",
messageParameters = Map(
"leftLength" -> values.length.toString,
"rightLength" -> query.childOutputs.length.toString,
"rightLength" -> query.numCols.toString,
"leftColumns" -> values.map(toSQLExpr(_)).mkString(", "),
"rightColumns" -> query.childOutputs.map(toSQLExpr(_)).mkString(", ")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,19 @@ case class ListQuery(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty,
// The plan of a list query may have more columns after de-correlation, so we track the number
// of columns of the original plan in order to report the data type properly.
numCols: Int = -1,
joinCond: Seq[Expression] = Seq.empty,
hint: Option[HintInfo] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = if (childOutputs.length > 1) {
def childOutputs: Seq[Attribute] = plan.output.take(numCols)
override def dataType: DataType = if (numCols > 1) {
childOutputs.toStructType
} else {
childOutputs.head.dataType
plan.output.head.dataType
}
override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
override lazy val resolved: Boolean = childrenResolved && plan.resolved && numCols != -1
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
Expand All @@ -373,7 +376,7 @@ case class ListQuery(
plan.canonicalized,
outerAttrs.map(_.canonicalized),
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
numCols,
joinCond.map(_.canonicalized))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
return filterApplicationSidePlan
}
val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
ListQuery(aggregate, childOutputs = aggregate.output))
ListQuery(aggregate, numCols = aggregate.output.length))
Filter(filter, filterApplicationSidePlan)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
case ListQuery(sub, children, exprId, childOutputs, conditions, hint) if children.nonEmpty =>
case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
val joinCond = getJoinCondition(newCond, conditions)
ListQuery(newPlan, children, exprId, childOutputs, joinCond, hint)
ListQuery(newPlan, children, exprId, numCols, joinCond, hint)
case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1500,4 +1500,28 @@ class AnalysisSuite extends AnalysisTest with Matchers {
assert(refs.map(_.output).distinct.length == 3)
}
}

test("SPARK-43190: ListQuery.childOutput should be consistent with child output") {
val listQuery1 = ListQuery(testRelation2.select($"a"))
val listQuery2 = ListQuery(testRelation2.select($"b"))
val plan = testRelation3.where($"f".in(listQuery1) && $"f".in(listQuery2)).analyze
val resolvedCondition = plan.expressions.head
val finalPlan = testRelation2.join(testRelation3).where(resolvedCondition).analyze
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test takes the already-resolved ListQuery expressions, uses them to build a new plan, and analyzes that plan again; this second analysis is what triggers the bug. Without it the bug stays hidden, because DeduplicateRelations runs before ResolveSubqueries, so the plan output of a ListQuery is never changed again after the ListQuery has been resolved.

val resolvedListQueries = finalPlan.expressions.flatMap(_.collect {
case l: ListQuery => l
})
assert(resolvedListQueries.length == 2)

def collectLocalRelations(plan: LogicalPlan): Seq[LocalRelation] = plan.collect {
case l: LocalRelation => l
}
val localRelations = resolvedListQueries.flatMap(l => collectLocalRelations(l.plan))
assert(localRelations.length == 2)
// DeduplicateRelations should deduplicate plans in subquery expressions as well.
assert(localRelations.head.output != localRelations.last.output)

resolvedListQueries.foreach { l =>
assert(l.childOutputs == l.plan.output)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
case e: Exists =>
e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0))
case l: ListQuery =>
l.copy(
plan = normalizeExprIds(l.plan),
exprId = ExprId(0),
childOutputs = l.childOutputs.map(_.withExprId(ExprId(0))))
l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case OuterReference(a: AttributeReference) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
DynamicPruningExpression(expressions.InSubquery(
Seq(value), ListQuery(aggregate, childOutputs = aggregate.output)))
Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla

val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan)
DynamicPruningExpression(
InSubquery(pruningKeys, ListQuery(buildQuery, childOutputs = buildQuery.output)))
InSubquery(pruningKeys, ListQuery(buildQuery, numCols = buildQuery.output.length)))
}
}