Commit e0e39d5 (parent: bb5459f)

[SPARK-34079][SQL] Merging non-correlated scalar subqueries to multi-column scalar subqueries for better reuse
11 files changed: +556 -1386 lines
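In a minimal sketch of the targeted scenario (the SparkSession `spark` and the table `t(a INT, b INT)` are assumptions for illustration), two non-correlated scalar subqueries that aggregate the same relation can now be planned as one merged, multi-column subquery:

// Minimal sketch, assuming a SparkSession `spark` and a table `t(a INT, b INT)`.
val df = spark.sql(
  """SELECT
    |  (SELECT avg(a) FROM t) AS avg_a,
    |  (SELECT sum(b) FROM t) AS sum_b
    |""".stripMargin)
df.explain()  // with this patch, both subqueries should share one reused plan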

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 23 additions & 0 deletions
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.BitSet

@@ -258,6 +259,28 @@ object ScalarSubquery {
   }
 }

+case class MultiScalarSubquery(
+    plan: LogicalPlan,
+    exprId: ExprId = NamedExpression.newExprId)
+  extends SubqueryExpression(plan, Seq.empty, exprId) with LeafLike[Expression] with Unevaluable {
+  override def dataType: DataType = {
+    assert(plan.schema.nonEmpty, "Multi-column scalar subquery should have columns")
+    plan.schema
+  }
+
+  override def nullable: Boolean = true
+
+  override def withNewPlan(plan: LogicalPlan): MultiScalarSubquery = copy(plan = plan)
+
+  override def toString: String = s"multi-scalar-subquery#${exprId.id}"
+
+  override lazy val canonicalized: Expression = {
+    MultiScalarSubquery(
+      plan.canonicalized,
+      ExprId(0))
+  }
+}
+
 /**
  * A [[ListQuery]] expression defines the query which we want to search in an IN subquery
  * expression. It should and can only be used in conjunction with an IN expression.
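The design hinge above is that `dataType` is the subquery plan's whole schema, a `StructType`, rather than a single column's type, so one merged subquery can feed several output columns and each value is then picked out with a `GetStructField`. A hedged spark-shell sketch (this patch on the classpath and a table `t(a INT, b INT)` are assumptions):

// Sketch only: build the expression by hand to inspect its struct type.
import org.apache.spark.sql.catalyst.expressions.{GetStructField, MultiScalarSubquery}

val mergedPlan = spark.sql("SELECT avg(a) AS avg_a, sum(b) AS sum_b FROM t")
  .queryExecution.optimizedPlan
val mss = MultiScalarSubquery(mergedPlan)
mss.dataType            // struct<avg_a:double, sum_b:bigint>
GetStructField(mss, 0)  // references the first merged column, avg_a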
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule tries to merge non-correlated [[ScalarSubquery]]s into [[MultiScalarSubquery]]s.
+ * Mergeable [[ScalarSubquery]]s are then replaced with their corresponding
+ * [[MultiScalarSubquery]] and the [[ReuseSubquery]] rule makes sure that merged subqueries are
+ * computed only once.
+ *
+ * E.g. the following query:
+ *
+ * SELECT
+ *   (SELECT avg(a) FROM t GROUP BY b),
+ *   (SELECT sum(b) FROM t GROUP BY b)
+ *
+ * is optimized from:
+ *
+ * Project [scalar-subquery#231 [] AS scalarsubquery()#241,
+ *          scalar-subquery#232 [] AS scalarsubquery()#242L]
+ * :  :- Aggregate [b#234], [avg(a#233) AS avg(a)#236]
+ * :  :  +- Relation default.t[a#233,b#234] parquet
+ * :  +- Aggregate [b#240], [sum(b#240) AS sum(b)#238L]
+ * :     +- Project [b#240]
+ * :        +- Relation default.t[a#239,b#240] parquet
+ * +- OneRowRelation
+ *
+ * to:
+ *
+ * Project [multi-scalar-subquery#231.avg(a) AS scalarsubquery()#241,
+ *          multi-scalar-subquery#232.sum(b) AS scalarsubquery()#242L]
+ * :  :- Aggregate [b#234], [avg(a#233) AS avg(a)#236, sum(b#234) AS sum(b)#238L]
+ * :  :  +- Project [a#233, b#234]
+ * :  :     +- Relation default.t[a#233,b#234] parquet
+ * :  +- Aggregate [b#234], [avg(a#233) AS avg(a)#236, sum(b#234) AS sum(b)#238L]
+ * :     +- Project [a#233, b#234]
+ * :        +- Relation default.t[a#233,b#234] parquet
+ * +- OneRowRelation
+ */
+object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (conf.scalarSubqueryMergeEnabled && conf.subqueryReuseEnabled) {
+      val mergedSubqueries = ArrayBuffer.empty[LogicalPlan]
+      removeReferences(mergeAndInsertReferences(plan, mergedSubqueries), mergedSubqueries)
+    } else {
+      plan
+    }
+  }
+
+  private def mergeAndInsertReferences(
+      plan: LogicalPlan,
+      mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
+    plan.transformAllExpressions {
+      case s: ScalarSubquery if s.children.isEmpty =>
+        val (mergedPlan, ordinal) = mergeAndGetReference(s.plan, mergedSubqueries)
+        GetStructField(MultiScalarSubquery(mergedPlan, s.exprId), ordinal)
+    }
+  }
+
+  case class SubqueryReference(
+      index: Int,
+      mergedSubqueries: ArrayBuffer[LogicalPlan]) extends LeafNode {
+    override def stringArgs: Iterator[Any] = Iterator(index)
+
+    override def output: Seq[Attribute] = mergedSubqueries(index).output
+  }
+
+  private def mergeAndGetReference(
+      plan: LogicalPlan,
+      mergedSubqueries: ArrayBuffer[LogicalPlan]): (SubqueryReference, Int) = {
+    mergedSubqueries.zipWithIndex.collectFirst {
+      Function.unlift { case (s, i) => mergePlans(plan, s).map(_ -> i) }
+    }.map { case ((mergedPlan, outputMap), i) =>
+      mergedSubqueries(i) = mergedPlan
+      SubqueryReference(i, mergedSubqueries) ->
+        mergedPlan.output.indexOf(outputMap(plan.output.head))
+    }.getOrElse {
+      mergedSubqueries += plan
+      SubqueryReference(mergedSubqueries.length - 1, mergedSubqueries) -> 0
+    }
+  }
+
+  private def mergePlans(
+      newPlan: LogicalPlan,
+      existingPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
+    (newPlan, existingPlan) match {
+      case (np, ep) if np.canonicalized == ep.canonicalized =>
+        Some(ep -> AttributeMap(np.output.zip(ep.output)))
+      case (np: Project, ep: Project) =>
+        mergePlans(np.child, ep.child).map { case (mergedChild, outputMap) =>
+          val newProjectList = replaceAttributes(np.projectList, outputMap)
+          val newOutputMap = createOutputMap(np.projectList, newProjectList)
+          Project(distinctExpressions(ep.projectList ++ newProjectList), mergedChild) ->
+            newOutputMap
+        }
+      case (np, ep: Project) =>
+        mergePlans(np, ep.child).map { case (mergedChild, outputMap) =>
+          Project(distinctExpressions(ep.projectList ++ outputMap.values), mergedChild) -> outputMap
+        }
+      case (np: Project, ep) =>
+        mergePlans(np.child, ep).map { case (mergedChild, outputMap) =>
+          val newProjectList = replaceAttributes(np.projectList, outputMap)
+          val newOutputMap = createOutputMap(np.projectList, newProjectList)
+          Project(distinctExpressions(ep.output ++ newProjectList), mergedChild) -> newOutputMap
+        }
+      case (np: Aggregate, ep: Aggregate) =>
+        mergePlans(np.child, ep.child).flatMap { case (mergedChild, outputMap) =>
+          val newGroupingExpression = replaceAttributes(np.groupingExpressions, outputMap)
+          if (ExpressionSet(newGroupingExpression) == ExpressionSet(ep.groupingExpressions)) {
+            val newAggregateExpressions = replaceAttributes(np.aggregateExpressions, outputMap)
+            val newOutputMap = createOutputMap(np.aggregateExpressions, newAggregateExpressions)
+            Some(Aggregate(ep.groupingExpressions,
+              distinctExpressions(ep.aggregateExpressions ++ newAggregateExpressions),
+              mergedChild) -> newOutputMap)
+          } else {
+            None
+          }
+        }
+      case _ =>
+        None
+    }
+  }
+
+  private def replaceAttributes[T <: Expression](
+      expressions: Seq[T],
+      outputMap: AttributeMap[Attribute]) = {
+    expressions.map(_.transform {
+      case a: Attribute => outputMap.getOrElse(a, a)
+    }.asInstanceOf[T])
+  }
+
+  private def createOutputMap(from: Seq[NamedExpression], to: Seq[NamedExpression]) = {
+    AttributeMap(from.map(_.toAttribute).zip(to.map(_.toAttribute)))
+  }
+
+  private def distinctExpressions(expressions: Seq[NamedExpression]) = {
+    ExpressionSet(expressions).toSeq.asInstanceOf[Seq[NamedExpression]]
+  }
+
+  private def removeReferences(
+      plan: LogicalPlan,
+      mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
+    plan.transformUp {
+      case other => other.transformExpressionsUp {
+        case gsf @ GetStructField(mss @ MultiScalarSubquery(sr: SubqueryReference, _), _, _) =>
+          val dereferencedPlan = removeReferences(mergedSubqueries(sr.index), mergedSubqueries)
+          if (dereferencedPlan.outputSet.size > 1) {
+            gsf.copy(child = mss.copy(plan = dereferencedPlan))
+          } else {
+            ScalarSubquery(dereferencedPlan, exprId = mss.exprId)
+          }
+        case s: SubqueryExpression => s.withNewPlan(removeReferences(s.plan, mergedSubqueries))
+      }
+    }
+  }
+}
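To check that the rule fired end to end, a hedged sketch (the table `t(a, b)` is assumed and the exact plan text varies by version):

// Hedged check: with both flags at their defaults, the optimized plan should
// contain multi-scalar-subquery references instead of two independent
// scalar-subquery nodes.
val qe = spark.sql(
  "SELECT (SELECT avg(a) FROM t), (SELECT sum(b) FROM t)").queryExecution
assert(qe.optimizedPlan.toString.contains("multi-scalar-subquery"))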

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 0 deletions
@@ -232,6 +232,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ColumnPruning,
       CollapseProject,
       RemoveNoopOperators) :+
+    Batch("MergeScalarSubqueries", Once,
+      MergeScalarSubqueries) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -1353,6 +1353,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

+  val SCALAR_SUBQUERY_MERGE_ENABLED =
+    buildConf("spark.sql.scalarSubqueryMerge.enabled")
+      .internal()
+      .doc("When true, the planner will try to merge scalar subqueries and reuse them.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val REMOVE_REDUNDANT_PROJECTS_ENABLED = buildConf("spark.sql.execution.removeRedundantProjects")
     .internal()
     .doc("Whether to remove redundant project exec node based on children's output and " +

@@ -3473,6 +3481,8 @@ class SQLConf extends Serializable with Logging {

   def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED)

+  def scalarSubqueryMergeEnabled: Boolean = getConf(SCALAR_SUBQUERY_MERGE_ENABLED)
+
   def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)

   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
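Merging only happens when both this flag and subquery reuse are on (see the guard in `MergeScalarSubqueries.apply`); a minimal runtime toggle, shown only to make the dependency explicit:

// Both default to true; the second key is the existing subquery-reuse flag.
spark.conf.set("spark.sql.scalarSubqueryMerge.enabled", "true")
spark.conf.set("spark.sql.execution.reuseSubquery", "true")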

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala

Lines changed: 7 additions & 0 deletions
@@ -121,6 +121,13 @@ case class InsertAdaptiveSparkPlan(
         val subquery = SubqueryExec.createForScalarSubquery(
           s"subquery#${exprId.id}", executedPlan)
         subqueryMap.put(exprId.id, subquery)
+      case expressions.MultiScalarSubquery(p, exprId)
+          if !subqueryMap.contains(exprId.id) =>
+        val executedPlan = compileSubquery(p)
+        verifyAdaptivePlan(executedPlan, p)
+        val subquery = SubqueryExec.createForScalarSubquery(
+          s"subquery#${exprId.id}", executedPlan)
+        subqueryMap.put(exprId.id, subquery)
       case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(query)

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ case class PlanAdaptiveSubqueries(
     plan.transformAllExpressions {
       case expressions.ScalarSubquery(_, _, exprId) =>
         execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
+      case expressions.MultiScalarSubquery(_, exprId) =>
+        execution.MultiScalarSubqueryExec(subqueryMap(exprId.id), exprId)
       case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
         val expr = if (values.length == 1) {
           values.head
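Together, these two AQE hooks mirror the non-adaptive `PlanSubqueries` path below: `InsertAdaptiveSparkPlan` compiles each merged plan once per `exprId`, and `PlanAdaptiveSubqueries` rewires the expression to the cached exec. A hedged sketch of exercising it (table `t(a, b)` assumed):

// Hedged sketch: the merge should survive adaptive query execution as well.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.sql("SELECT (SELECT avg(a) FROM t), (SELECT sum(b) FROM t)").collect()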

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 50 additions & 0 deletions
@@ -106,6 +106,50 @@ case class ScalarSubquery(
   }
 }

+case class MultiScalarSubqueryExec(
+    plan: BaseSubqueryExec,
+    exprId: ExprId)
+  extends ExecSubqueryExpression with LeafLike[Expression] {
+
+  override def dataType: DataType = plan.schema
+  override def nullable: Boolean = true
+  override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields)
+  override def withNewPlan(query: BaseSubqueryExec): MultiScalarSubqueryExec = copy(plan = query)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case s: MultiScalarSubqueryExec => plan.sameResult(s.plan)
+    case _ => false
+  }
+
+  // The first row collected from `plan`, kept as a single struct value.
+  @volatile private var result: Any = _
+  @volatile private var updated: Boolean = false
+
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
+    if (rows.length > 1) {
+      sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
+    }
+    if (rows.length == 1) {
+      result = rows(0)
+    } else {
+      // If no rows are returned, the result should be null.
+      result = null
+    }
+    updated = true
+  }
+
+  override def eval(input: InternalRow): Any = {
+    require(updated, s"$this has not finished")
+    result
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    require(updated, s"$this has not finished")
+    Literal.create(result, dataType).doGenCode(ctx, ev)
+  }
+}
+
 /**
  * The physical node of in-subquery. This is for Dynamic Partition Pruning only, as in-subquery
  * coming from the original query will always be converted to joins.

@@ -183,6 +227,12 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
         SubqueryExec.createForScalarSubquery(
           s"scalar-subquery#${subquery.exprId.id}", executedPlan),
         subquery.exprId)
+      case subquery: expressions.MultiScalarSubquery =>
+        val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
+        MultiScalarSubqueryExec(
+          SubqueryExec.createForScalarSubquery(
+            s"multi-scalar-subquery#${subquery.exprId.id}", executedPlan),
+          subquery.exprId)
       case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
         val expr = if (values.length == 1) {
           values.head
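`eval` hands back the entire first row as one struct value; the `GetStructField` planted by the optimizer then selects a column by ordinal. A self-contained sketch of that contract, with made-up values:

// Illustration only: the shape of the value updateResult() stores and eval() returns.
import org.apache.spark.sql.catalyst.InternalRow

val row: InternalRow = InternalRow(4.5, 45L)  // e.g. struct<avg(a):double, sum(b):bigint>
row.getDouble(0)  // 4.5 -> feeds multi-scalar-subquery#X.avg(a)
row.getLong(1)    // 45  -> feeds multi-scalar-subquery#X.sum(b)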
