-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-34079][SQL] Merge non-correlated scalar subqueries #32298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
e0e39d5
0a7e0e2
e35cdc1
0cff7b2
22e833d
42add09
e63111d
c84f0ee
ee8f12a
17fd666
6134fa9
2828345
1f2f75c
a3e84a4
0fe66dc
100cb9c
9d8dd6b
f83f22b
41c0f0a
d10a8be
db34640
d081885
e98754a
bb623cf
ae1d84e
2eb14f1
060e4b7
d86d2c4
0a97c8b
61f2b34
532d05e
c488377
e0a7610
63c3709
dabbea4
4d97de5
cc8690e
83c78ca
3130913
3e8f7fa
252c9b1
fa5e786
e292732
963c423
9efaf2a
96a502d
5b91d61
8bcf515
87ba289
6d5a124
851ca29
96d0cab
a57ed32
0b34d83
4985d43
de9b312
13a2fad
92ce6e5
67ffae6
a32a85c
1bc8a45
224edef
a7fd1c5
a5eb5df
8457148
1ff64e4
13a1cdb
4da3fe6
96ed6fd
dbe81e2
ba299d5
65f3425
dc5e9b9
c64373b
3993eab
f93283d
8c5c9ac
3b7ad2c
c268580
169fd6b
19128ff
1c4d14b
2590edf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project} | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.catalyst.trees.TreePattern.{MULTI_SCALAR_SUBQUERY, SCALAR_SUBQUERY} | ||
|
|
||
| /** | ||
| * This rule tries to merge multiple non-correlated [[ScalarSubquery]]s into a | ||
| * [[MultiScalarSubquery]] to compute multiple scalar values once. | ||
| * | ||
| * The process is the following: | ||
| * - While traversing through the plan each [[ScalarSubquery]] plan is tried to merge into the cache | ||
| * of already seen subquery plans. If merge is possible then cache is updated with the merged | ||
| * subquery plan, if not then the new subquery plan is added to the cache. | ||
| * - The original [[ScalarSubquery]] expression is replaced to a reference pointing to its cached | ||
| * version in this form: `GetStructField(MultiScalarSubquery(SubqueryReference(...)))`. | ||
| * - A second traversal checks if a [[SubqueryReference]] is pointing to a subquery plan that | ||
| * returns multiple values and either replaces only [[SubqueryReference]] to the cached plan or | ||
| * restores the whole expression to its original [[ScalarSubquery]] form. | ||
| * - [[ReuseSubquery]] rule makes sure that merged subqueries are computed once. | ||
| * | ||
| * Eg. the following query: | ||
| * | ||
| * SELECT | ||
| * (SELECT avg(a) FROM t GROUP BY b), | ||
| * (SELECT sum(b) FROM t GROUP BY b) | ||
| * | ||
| * is optimized from: | ||
| * | ||
| * Project [scalar-subquery#231 [] AS scalarsubquery()#241, | ||
| * scalar-subquery#232 [] AS scalarsubquery()#242L] | ||
| * : :- Aggregate [b#234], [avg(a#233) AS avg(a)#236] | ||
| * : : +- Relation default.t[a#233,b#234] parquet | ||
| * : +- Aggregate [b#240], [sum(b#240) AS sum(b)#238L] | ||
| * : +- Project [b#240] | ||
| * : +- Relation default.t[a#239,b#240] parquet | ||
|
||
| * +- OneRowRelation | ||
| * | ||
| * to: | ||
| * | ||
| * Project [multi-scalar-subquery#231.avg(a) AS scalarsubquery()#241, | ||
| * multi-scalar-subquery#232.sum(b) AS scalarsubquery()#242L] | ||
| * : :- Aggregate [b#234], [avg(a#233) AS avg(a)#236, sum(b#234) AS sum(b)#238L] | ||
| * : : +- Project [a#233, b#234] | ||
| * : : +- Relation default.t[a#233,b#234] parquet | ||
| * : +- Aggregate [b#234], [avg(a#233) AS avg(a)#236, sum(b#234) AS sum(b)#238L] | ||
| * : +- Project [a#233, b#234] | ||
| * : +- Relation default.t[a#233,b#234] parquet | ||
| * +- OneRowRelation | ||
| */ | ||
object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {

  // Entry point: only runs when both the merge feature and subquery reuse are enabled,
  // since the merged plans are deduplicated at execution time by `ReuseSubquery`.
  // NOTE(review): `scalarSubqueryMergeEabled` is misspelled (missing "n") in SQLConf;
  // renaming it requires updating both the SQLConf definition and this call site.
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (conf.scalarSubqueryMergeEabled && conf.subqueryReuseEnabled) {
      // Cache of candidate subquery plans; entries are updated in place as plans merge.
      val mergedSubqueries = ArrayBuffer.empty[LogicalPlan]
      removeReferences(mergeAndInsertReferences(plan, mergedSubqueries), mergedSubqueries)
    } else {
      plan
    }
  }

  /**
   * First traversal: replaces each non-correlated [[ScalarSubquery]] (no outer references,
   * i.e. `children.isEmpty`) with `GetStructField(MultiScalarSubquery(SubqueryReference), ordinal)`
   * where `ordinal` selects this subquery's value within the (possibly merged) cached plan.
   */
  private def mergeAndInsertReferences(
      plan: LogicalPlan,
      mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY), ruleId) {
      case s: ScalarSubquery if s.children.isEmpty =>
        val (mergedPlan, ordinal) = mergeAndGetReference(s.plan, mergedSubqueries)
        GetStructField(MultiScalarSubquery(mergedPlan, s.exprId), ordinal)
    }
  }

  /**
   * Temporary leaf node pointing into `mergedSubqueries` by index, so that later merges into the
   * same cache slot are visible to all references. Resolved away by [[removeReferences]].
   */
  case class SubqueryReference(
      index: Int,
      mergedSubqueries: ArrayBuffer[LogicalPlan]) extends LeafNode {
    // Only print the index; the cache itself would make plan strings unreadable.
    override def stringArgs: Iterator[Any] = Iterator(index)

    // Output is delegated to the current state of the referenced cached plan.
    override def output: Seq[Attribute] = mergedSubqueries(index).output
  }

  /**
   * Tries to merge `plan` into the first compatible cached plan. On success the cache slot is
   * overwritten with the merged plan and the returned ordinal locates `plan`'s single output
   * attribute inside the merged output. Otherwise `plan` is appended as a new cache entry with
   * ordinal 0 (it is the sole output of its own plan).
   */
  private def mergeAndGetReference(
      plan: LogicalPlan,
      mergedSubqueries: ArrayBuffer[LogicalPlan]): (SubqueryReference, Int) = {
    mergedSubqueries.zipWithIndex.collectFirst {
      Function.unlift { case (s, i) => mergePlans(plan, s).map(_ -> i) }
    }.map { case ((mergedPlan, outputMap), i) =>
      mergedSubqueries(i) = mergedPlan
      SubqueryReference(i, mergedSubqueries) ->
        mergedPlan.output.indexOf(outputMap(plan.output.head))
    }.getOrElse {
      mergedSubqueries += plan
      SubqueryReference(mergedSubqueries.length - 1, mergedSubqueries) -> 0
    }
  }

  /**
   * Recursively tries to merge `newPlan` into `existingPlan`. Returns the merged plan together
   * with an [[AttributeMap]] translating `newPlan`'s output attributes to attributes of the
   * merged plan, or None if the plans cannot be merged.
   *
   * Supported shapes: canonically-equal plans, Project-over-mergeable-child (on either or both
   * sides), and Aggregate pairs with semantically equal grouping expressions.
   */
  private def mergePlans(
      newPlan: LogicalPlan,
      existingPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
    (newPlan, existingPlan) match {
      // Identical plans (modulo expression ids): keep the existing one and map outputs 1:1.
      case (np, ep) if np.canonicalized == ep.canonicalized =>
        Some(ep -> AttributeMap(np.output.zip(ep.output)))
      case (np: Project, ep: Project) =>
        mergePlans(np.child, ep.child).map { case (mergedChild, outputMap) =>
          val newProjectList = replaceAttributes(np.projectList, outputMap)
          val newOutputMap = createOutputMap(np.projectList, newProjectList)
          // Existing projections come first so existing ordinals stay valid.
          Project(distinctExpressions(ep.projectList ++ newProjectList), mergedChild) ->
            newOutputMap
        }
      // Only the existing side has a Project: extend it with the attributes the new side needs.
      case (np, ep: Project) =>
        mergePlans(np, ep.child).map { case (mergedChild, outputMap) =>
          Project(distinctExpressions(ep.projectList ++ outputMap.values), mergedChild) -> outputMap
        }
      // Only the new side has a Project: widen the merged output with its rewritten projections.
      case (np: Project, ep) =>
        mergePlans(np.child, ep).map { case (mergedChild, outputMap) =>
          val newProjectList = replaceAttributes(np.projectList, outputMap)
          val newOutputMap = createOutputMap(np.projectList, newProjectList)
          Project(distinctExpressions(ep.output ++ newProjectList), mergedChild) -> newOutputMap
        }
      case (np: Aggregate, ep: Aggregate) =>
        mergePlans(np.child, ep.child).flatMap { case (mergedChild, outputMap) =>
          val newGroupingExpression = replaceAttributes(np.groupingExpressions, outputMap)
          // Aggregates are only mergeable when the grouping keys are semantically equal.
          if (ExpressionSet(newGroupingExpression) == ExpressionSet(ep.groupingExpressions)) {
            val newAggregateExpressions = replaceAttributes(np.aggregateExpressions, outputMap)
            val newOutputMap = createOutputMap(np.aggregateExpressions, newAggregateExpressions)
            Some(Aggregate(ep.groupingExpressions,
              distinctExpressions(ep.aggregateExpressions ++ newAggregateExpressions),
              mergedChild) -> newOutputMap)
          } else {
            None
          }
        }
      case _ =>
        None
    }
  }

  // Rewrites attribute references in `expressions` through `outputMap`, so expressions from the
  // new plan refer to the merged plan's attributes. Unmapped attributes are kept as-is.
  private def replaceAttributes[T <: Expression](
      expressions: Seq[T],
      outputMap: AttributeMap[Attribute]) = {
    expressions.map(_.transform {
      case a: Attribute => outputMap.getOrElse(a, a)
    }.asInstanceOf[T])
  }

  // Positional map from the attributes of `from` to the attributes of their rewritten
  // counterparts in `to`.
  private def createOutputMap(from: Seq[NamedExpression], to: Seq[NamedExpression]) = {
    AttributeMap(from.map(_.toAttribute).zip(to.map(_.toAttribute)))
  }

  // Deduplicates semantically equal expressions (e.g. a projection requested by both sides).
  private def distinctExpressions(expressions: Seq[NamedExpression]) = {
    ExpressionSet(expressions).toSeq.asInstanceOf[Seq[NamedExpression]]
  }

  /**
   * Second traversal: resolves every [[SubqueryReference]]. If the cached plan ended up with more
   * than one output value the merged `GetStructField(MultiScalarSubquery(...))` form is kept;
   * otherwise nothing was merged into it and the original [[ScalarSubquery]] is restored.
   */
  private def removeReferences(
      plan: LogicalPlan,
      mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(MULTI_SCALAR_SUBQUERY), ruleId) {
      case gsf @ GetStructField(mss @ MultiScalarSubquery(sr: SubqueryReference, _), _, _) =>
        // Recurse first: the cached plan may itself contain references to other cached plans.
        val dereferencedPlan = removeReferences(mergedSubqueries(sr.index), mergedSubqueries)
        if (dereferencedPlan.outputSet.size > 1) {
          gsf.copy(child = mss.copy(plan = dereferencedPlan))
        } else {
          ScalarSubquery(dereferencedPlan, exprId = mss.exprId)
        }
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1353,6 +1353,14 @@ object SQLConf { | |
| .booleanConf | ||
| .createWithDefault(true) | ||
|
|
||
| val SCALAR_SUBQUERY_MERGE_ENABLED = | ||
| buildConf("spark.sql.scalarSubqueyMerge.enabled") | ||
|
||
| .internal() | ||
| .doc("When true, the planner will try to merge scalar subqueries and re-use them.") | ||
| .version("3.2.0") | ||
| .booleanConf | ||
| .createWithDefault(true) | ||
|
|
||
| val REMOVE_REDUNDANT_PROJECTS_ENABLED = buildConf("spark.sql.execution.removeRedundantProjects") | ||
| .internal() | ||
| .doc("Whether to remove redundant project exec node based on children's output and " + | ||
|
|
@@ -3481,6 +3489,8 @@ class SQLConf extends Serializable with Logging { | |
|
|
||
| def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED) | ||
|
|
||
  // Whether the optimizer may merge non-correlated scalar subqueries (SCALAR_SUBQUERY_MERGE_ENABLED).
  // NOTE(review): the name is misspelled — "Eabled" should be "Enabled". Not renamed here because
  // MergeScalarSubqueries.apply calls this accessor; rename both sites together.
  def scalarSubqueryMergeEabled: Boolean = getConf(SCALAR_SUBQUERY_MERGE_ENABLED)
|
|
||
| def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) | ||
|
|
||
| def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{GetStructField, MultiScalarSubquery, ScalarSubquery} | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules._ | ||
|
|
||
class MergeScalarSubqueriesSuite extends PlanTest {

  // Single-rule executor so the test exercises MergeScalarSubqueries in isolation.
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil
  }

  test("Simple non-correlated scalar subquery merge") {
    val relation = LocalRelation('a.int, 'b.int)

    // Two scalar subqueries aggregating the same relation with identical grouping keys —
    // exactly the shape the rule is able to merge.
    val maxSubquery = relation.groupBy('b)(max('a))
    val sumSubquery = relation.groupBy('b)(sum('a))
    val originalQuery = relation.select(ScalarSubquery(maxSubquery), ScalarSubquery(sumSubquery))

    // After the rule runs, both aggregates are computed by one merged plan and each original
    // scalar subquery becomes a struct-field access into that plan's single row.
    val mergedPlan = relation.groupBy('b)(max('a), sum('a)).analyze
    val correctAnswer = relation.select(
      GetStructField(MultiScalarSubquery(mergedPlan), 0).as("scalarsubquery()"),
      GetStructField(MultiScalarSubquery(mergedPlan), 1).as("scalarsubquery()"))

    // checkAnalysis is disabled because the `Analyzer` is not prepared for `MultiScalarSubquery`
    // nodes; only the `Optimizer` can insert such a node into the plan.
    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have to create a new subquery expression? It seems like we can just use
`CreateNamedStruct` in `ScalarSubquery.plan`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you are right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dropped 'MultiScalarSubquery` in 1f2f75c, will change the docs and the PR description soon.