[SPARK-38162][SQL] Optimize one row plan in normal and AQE Optimizer #35473
OptimizeOneRowPlan.scala (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
 * This rule is applied by both the normal and the AQE Optimizer. It optimizes plans using the
 * max rows of their children:
 *   - if the max rows of the child of a sort is less than or equal to 1, remove the sort
 *   - if the max rows per partition of the child of a local sort is less than or equal to 1,
 *     remove the local sort
 *   - if the max rows of the child of an aggregate is less than or equal to 1 and the
 *     aggregate is grouping only (including the rewritten distinct plan), convert the
 *     aggregate to a project
 *   - if the max rows of the child of an aggregate is less than or equal to 1, set distinct
 *     to false in all aggregate expressions
 */
object OptimizeOneRowPlan extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child
      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child
      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) =>
        Project(agg.aggregateExpressions, child)
      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) =>
        agg.transformExpressions {
          case aggExpr: AggregateExpression if aggExpr.isDistinct =>
            aggExpr.copy(isDistinct = false)
        }
    }
  }
}
```
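For illustration (not part of the diff), here is a minimal sketch of the rule applied directly to one-row plans, assuming Spark's catalyst test DSL is on the classpath; names like `oneRow` are illustrative only:

```scala
// Minimal sketch (assumes the catalyst test DSL used by Spark's own
// optimizer tests); `oneRow` is a hypothetical name.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// A relation statically known to hold at most one row (maxRows = 1).
val oneRow = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1)))

// Sort over a <= 1-row child: the rule returns the child, dropping the Sort.
OptimizeOneRowPlan(oneRow.orderBy($"a".asc).analyze)

// Group-only Aggregate over a <= 1-row child: rewritten into a Project.
OptimizeOneRowPlan(oneRow.groupBy($"a")($"a").analyze)

// DISTINCT inside an aggregate function is dropped: sum(DISTINCT a) => sum(a).
OptimizeOneRowPlan(oneRow.groupBy($"a")(sumDistinct($"a").as("s")).analyze)
```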
Optimizer.scala

```diff
@@ -238,6 +238,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       // PropagateEmptyRelation can change the nullability of an attribute from nullable to
       // non-nullable when an empty relation child of a Union is removed
       UpdateAttributeNullability) :+
+    Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
```
```diff
@@ -1390,15 +1391,14 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
  * Removes Sort operations if they don't affect the final output ordering.
  * Note that changes in the final output ordering may affect the file size (SPARK-32318).
  * This rule handles the following cases:
- * 1) if the child maximum number of rows less than or equal to 1
- * 2) if the sort order is empty or the sort order does not have any reference
- * 3) if the Sort operator is a local sort and the child is already sorted
- * 4) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
+ * 1) if the sort order is empty or the sort order does not have any reference
+ * 2) if the Sort operator is a local sort and the child is already sorted
+ * 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators
- * 5) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
+ * 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
  *    and the Join condition is deterministic
- * 6) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
+ * 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
  *    and the aggregate function is order irrelevant
  */
```
```diff
@@ -1407,7 +1407,6 @@ object EliminateSorts extends Rule[LogicalPlan] {
     _.containsPattern(SORT))(applyLocally)

   private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
-    case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => recursiveRemoveSort(child)
     case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
       val newOrders = orders.filterNot(_.child.foldable)
       if (newOrders.isEmpty) {
```

Review thread on the removed `maxRows` case:

Contributor: hmm, are you sure the new rule can fully cover this?

Contributor (author): I think it is.
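As context for the review question above (not part of the diff), a minimal sketch of case 3 of the rewritten doc comment, under the same catalyst test DSL assumptions: an inner Sort separated from an outer Sort only by a deterministic Project cannot affect the final output ordering, so EliminateSorts removes it.

```scala
// Minimal sketch (assumes the catalyst test DSL); `relation` is illustrative.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation($"a".int, $"b".int)

// Sort -> Project -> Sort: the outer Sort overwrites the inner ordering,
// so the inner Sort is removed while the outer one is kept.
val plan = relation.orderBy($"a".asc).select($"a").orderBy($"a".desc).analyze
val optimized = EliminateSorts(plan)
// optimized is equivalent to relation.select($"a").orderBy($"a".desc)
```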
OptimizeOneRowPlanSuite.scala (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class OptimizeOneRowPlanSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Replace Operators", Once, ReplaceDistinctWithAggregate) ::
      Batch("Eliminate Sorts", Once, EliminateSorts) ::
      Batch("Optimize One Row Plan", FixedPoint(10), OptimizeOneRowPlan) :: Nil
  }

  private val t1 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1)))
  private val t2 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1), Row(2)))

  test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") {
    comparePlans(
      Optimize.execute(t2.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze,
      t2.groupBy()(count(1).as("cnt")).analyze)

    comparePlans(
      Optimize.execute(t2.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze,
      t2.limit(Literal(1)).analyze)
  }

  test("Remove sort") {
    // remove local sort
    val plan1 = LocalLimit(0, t1).union(LocalLimit(0, t2)).sortBy($"a".desc).analyze
    val expected = LocalLimit(0, t1).union(LocalLimit(0, t2)).analyze
    comparePlans(Optimize.execute(plan1), expected)

    // do not remove
    val plan2 = t2.orderBy($"a".desc).analyze
    comparePlans(Optimize.execute(plan2), plan2)

    val plan3 = t2.sortBy($"a".desc).analyze
    comparePlans(Optimize.execute(plan3), plan3)
  }

  test("Convert group only aggregate to project") {
    val plan1 = t1.groupBy($"a")($"a").analyze
    comparePlans(Optimize.execute(plan1), t1.select($"a").analyze)

    val plan2 = t1.groupBy($"a" + 1)($"a" + 1).analyze
    comparePlans(Optimize.execute(plan2), t1.select($"a" + 1).analyze)

    // do not remove
    val plan3 = t2.groupBy($"a")($"a").analyze
    comparePlans(Optimize.execute(plan3), plan3)

    val plan4 = t1.groupBy($"a")(sum($"a")).analyze
    comparePlans(Optimize.execute(plan4), plan4)

    val plan5 = t1.groupBy()(sum($"a")).analyze
    comparePlans(Optimize.execute(plan5), plan5)
  }

  test("Remove distinct in aggregate expression") {
    val plan1 = t1.groupBy($"a")(sumDistinct($"a").as("s")).analyze
    val expected1 = t1.groupBy($"a")(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan1), expected1)

    val plan2 = t1.groupBy()(sumDistinct($"a").as("s")).analyze
    val expected2 = t1.groupBy()(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan2), expected2)

    // do not remove
    val plan3 = t2.groupBy($"a")(sumDistinct($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan3), plan3)
  }

  test("Remove in complex case") {
    val plan1 = t1.groupBy($"a")($"a").orderBy($"a".asc).analyze
    val expected1 = t1.select($"a").analyze
    comparePlans(Optimize.execute(plan1), expected1)

    val plan2 = t1.groupBy($"a")(sumDistinct($"a").as("s")).orderBy($"s".asc).analyze
    val expected2 = t1.groupBy($"a")(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan2), expected2)
  }
}
```
AdaptiveQueryExecSuite.scala

```diff
@@ -28,6 +28,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
```
```diff
@@ -126,6 +127,12 @@ class AdaptiveQueryExecSuite
     }
   }

+  private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
+    collect(plan) {
+      case agg: BaseAggregateExec => agg
+    }
+  }
+
   private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
     collect(plan) {
       case l: CollectLimitExec => l
```
```diff
@@ -2484,6 +2491,53 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-38162: Optimize one row plan in AQE Optimizer") {
+    withTempView("v") {
+      spark.sparkContext.parallelize(
+        (1 to 4).map(i => TestData(i, i.toString)), 2)
+        .toDF("c1", "c2").createOrReplaceTempView("v")
+
+      // remove sort
+      val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT * FROM v where c1 = 1 order by c1, c2
+          |""".stripMargin)
+      assert(findTopLevelSort(origin1).size == 1)
+      assert(findTopLevelSort(adaptive1).isEmpty)
+
+      // convert group only aggregate to project
+      val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+          |""".stripMargin)
+      assert(findTopLevelAggregate(origin2).size == 2)
+      assert(findTopLevelAggregate(adaptive2).isEmpty)
+
+      // remove distinct in aggregate
+      val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+          |""".stripMargin)
+      assert(findTopLevelAggregate(origin3).size == 4)
+      assert(findTopLevelAggregate(adaptive3).size == 2)
+
+      // do not optimize if the aggregate is inside query stage
+      val (origin4, adaptive4) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT distinct c1 FROM v where c1 = 1
+          |""".stripMargin)
+      assert(findTopLevelAggregate(origin4).size == 2)
+      assert(findTopLevelAggregate(adaptive4).size == 2)
+
+      val (origin5, adaptive5) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT sum(distinct c1) FROM v where c1 = 1
+          |""".stripMargin)
+      assert(findTopLevelAggregate(origin5).size == 4)
+      assert(findTopLevelAggregate(adaptive5).size == 4)
+    }
+  }
 }

 /**
```

Review thread on the second query's `/*+ repartition(c1) */` hint:

Contributor: what happens if there is no `repartition` hint?

Contributor (author): nothing happens; the aggregate node is inside the logical query stage, so we cannot optimize it at the logical side. The plan inside the physical stage is:

```
BaseAggregateExec (final)
  ShuffleQueryStage
    Exchange
      BaseAggregateExec (partial)
```

Review thread on the third query (`sum(distinct c1)`):

Contributor: same question
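One operational side note (not part of the PR): since the rule is also registered in the AQE optimizer, it can be excluded there independently; a hedged sketch, assuming the `spark.sql.adaptive.optimizer.excludedRules` setting behaves as documented for your Spark version:

```scala
// Hedged sketch: exclude OptimizeOneRowPlan from the AQE optimizer only.
// Verify the config name and behavior against your Spark version's docs.
spark.conf.set(
  "spark.sql.adaptive.optimizer.excludedRules",
  "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan")
```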
Contributor: it seems a small bug in the test: the name will be an unresolved string if no alias is specified.