Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -509,19 +509,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
/**
* Split the plan for a scalar subquery into the parts above the innermost query block
* (first part of returned value), the HAVING clause of the innermost query block
* (optional second part) and the parts below the HAVING CLAUSE (third part).
* (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part).
* When the third part is empty, it means the subquery is a non-aggregated single-row subquery.
*/
private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
private def splitSubquery(
plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = {
val topPart = ArrayBuffer.empty[LogicalPlan]
var bottomPart: LogicalPlan = plan
while (true) {
bottomPart match {
case havingPart @ Filter(_, aggPart: Aggregate) =>
return (topPart.toSeq, Option(havingPart), aggPart)
return (topPart.toSeq, Option(havingPart), Some(aggPart))

case aggPart: Aggregate =>
// No HAVING clause
return (topPart.toSeq, None, aggPart)
return (topPart.toSeq, None, Some(aggPart))

case p @ Project(_, child) =>
topPart += p
Expand All @@ -531,6 +533,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
topPart += s
bottomPart = child

case p: LogicalPlan if p.maxRows.exists(_ <= 1) =>
// Non-aggregated one row subquery.
return (topPart.toSeq, None, None)

case Filter(_, op) =>
throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op, " below filter")

Expand Down Expand Up @@ -561,72 +567,80 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
lazy val planWithoutCountBug = Project(
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
planWithoutCountBug
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
val (topPart, havingNode, aggNode) = splitSubquery(query)

// The next two cases add a leading column to the outer join input to make it
// possible to distinguish between the case when no tuples join and the case
// when the tuple that joins contains null values.
// The leading column always has the value TRUE.
val alwaysTrueExprId = NamedExpression.newExprId
val alwaysTrueExpr = Alias(Literal.TrueLiteral,
ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
BooleanType)(exprId = alwaysTrueExprId)

val aggValRef = query.output.head

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val subqueryResultExpr =
Alias(If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)()
subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
Project(
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

if (aggNode.isEmpty) {
// SPARK-40862: When the aggregate node is empty, it means the subquery produces
// at most one row and it is not subject to the COUNT bug.
planWithoutCountBug
} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
// Need to modify any operators below the join to pass through all columns
// referenced in the HAVING clause.
var subqueryRoot: UnaryNode = aggNode
val havingInputs: Seq[NamedExpression] = aggNode.output

topPart.reverse.foreach {
case Project(projList, _) =>
subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
case s @ SubqueryAlias(alias, _) =>
subqueryRoot = SubqueryAlias(alias, subqueryRoot)
case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op)
// Subquery might have the COUNT bug. Add appropriate corrections.
val aggregate = aggNode.get

// The next two cases add a leading column to the outer join input to make it
// possible to distinguish between the case when no tuples join and the case
// when the tuple that joins contains null values.
// The leading column always has the value TRUE.
val alwaysTrueExprId = NamedExpression.newExprId
val alwaysTrueExpr = Alias(Literal.TrueLiteral,
ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
BooleanType)(exprId = alwaysTrueExprId)

val aggValRef = query.output.head

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val subqueryResultExpr =
Alias(If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)()
subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
Project(
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
// Need to modify any operators below the join to pass through all columns
// referenced in the HAVING clause.
var subqueryRoot: UnaryNode = aggregate
val havingInputs: Seq[NamedExpression] = aggregate.output

topPart.reverse.foreach {
case Project(projList, _) =>
subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
case s@SubqueryAlias(alias, _) =>
subqueryRoot = SubqueryAlias(alias, subqueryRoot)
case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op)
}

// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)()

subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))

Project(
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
}

// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)()

subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))

Project(
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

}
}
}
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2491,4 +2491,21 @@ class SubquerySuite extends QueryTest
Row("a"))
}
}

test("SPARK-40862: correlated one-row subquery with non-deterministic expressions") {
  import org.apache.spark.sql.functions.udf
  withTempView("t1") {
    // Single-row view whose array column `a` is referenced from inside the scalar subquery.
    sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a")

    // Zero-argument UDF that always yields "a", explicitly flagged as non-deterministic
    // before being registered under the SQL name "func".
    val nondeterministicUdf = udf(() => "a").asNondeterministic()
    spark.udf.register("func", nondeterministicUdf)

    // The scalar subquery reads from an inline relation (no aggregate) that exposes the
    // `rank` map and the UDF result `str`; it sorts the outer array `a` by rank and
    // appends `str` to the first element.
    val query =
      """
        |SELECT (
        | SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] || str AS sorted
        | FROM (SELECT MAP('a', 1, 'b', 2) rank, func() AS str)
        |) FROM t1
        |""".stripMargin

    // 'a' has the lowest rank, so the expected value is 'a' || 'a'.
    checkAnswer(sql(query), Row("aa"))
  }
}
}