@@ -432,7 +432,7 @@ package object dsl {
    def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
      val aliasedExprs = aggregateExprs.map {
        case ne: NamedExpression => ne
-       case e => Alias(e, e.toString)()
+       case e => UnresolvedAlias(e)
[Review comment from Contributor Author]
It seems to be a small bug in the test: the name will be an unresolved string if no alias is specified.
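To make the naming difference concrete, a minimal sketch (assuming spark-catalyst on the classpath; the max('b) expression is chosen to mirror the AnalysisErrorSuite change below):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.aggregate.Max

// Aliasing an unresolved expression by its toString bakes the unresolved form
// into the name: UnresolvedAttribute prints with a leading quote, so the alias
// name becomes "max('b)".
val maxB = Max(UnresolvedAttribute("b")).toAggregateExpression()
val eagerlyNamed = Alias(maxB, maxB.toString)()

// UnresolvedAlias defers naming to the analyzer, which assigns "max(b)" once
// the attribute is resolved.
val deferredNaming = UnresolvedAlias(maxB)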

      }
      Aggregate(groupingExprs, aliasedExprs, logicalPlan)
    }
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
 * This rule is applied by both the normal Optimizer and the AQE Optimizer. It optimizes plans
 * using their max-row statistics:
 *   - if the max rows of the child of a sort is less than or equal to 1, remove the sort
 *   - if the max rows per partition of the child of a local sort is less than or equal to 1,
 *     remove the local sort
 *   - if the max rows of the child of an aggregate is less than or equal to 1 and the aggregate
 *     is grouping-only (including the rewritten distinct plan), convert the aggregate to a project
 *   - if the max rows of the child of an aggregate is less than or equal to 1,
 *     set distinct to false in all aggregate expressions
 */
object OptimizeOneRowPlan extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child
      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child
      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) =>
        Project(agg.aggregateExpressions, child)
      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) =>
        agg.transformExpressions {
          case aggExpr: AggregateExpression if aggExpr.isDistinct =>
            aggExpr.copy(isDistinct = false)
        }
    }
  }
}
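For reference, a minimal sketch of the three rewrites in isolation (assuming spark-catalyst test utilities on the classpath; oneRow is a hypothetical single-row relation, so its maxRows is Some(1)):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val oneRow = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1)))

// Sort over an at-most-one-row child is dropped entirely.
val sortRemoved = OptimizeOneRowPlan(oneRow.orderBy($"a".asc).analyze)

// A grouping-only aggregate over an at-most-one-row child becomes a Project.
val aggToProject = OptimizeOneRowPlan(oneRow.groupBy($"a")($"a").analyze)

// DISTINCT is redundant when the aggregate sees at most one input row.
val distinctDropped = OptimizeOneRowPlan(oneRow.groupBy()(sumDistinct($"a").as("s")).analyze)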
@@ -238,6 +238,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
      // PropagateEmptyRelation can change the nullability of an attribute from nullable to
      // non-nullable when an empty relation child of a Union is removed
      UpdateAttributeNullability) :+
+     Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+
      // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
      Batch("Check Cartesian Products", Once,
        CheckCartesianProducts) :+
@@ -1390,15 +1391,14 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
 * Removes Sort operations if they don't affect the final output ordering.
 * Note that changes in the final output ordering may affect the file size (SPARK-32318).
 * This rule handles the following cases:
-* 1) if the child maximum number of rows less than or equal to 1
-* 2) if the sort order is empty or the sort order does not have any reference
-* 3) if the Sort operator is a local sort and the child is already sorted
-* 4) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
+* 1) if the sort order is empty or the sort order does not have any reference
+* 2) if the Sort operator is a local sort and the child is already sorted
+* 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
 *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators
-* 5) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
+* 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
 *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
 *    and the Join condition is deterministic
-* 6) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
+* 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
 *    RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
 *    and the aggregate function is order irrelevant
 */
@@ -1407,7 +1407,6 @@ object EliminateSorts extends Rule[LogicalPlan] {
      _.containsPattern(SORT))(applyLocally)

  private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
-   case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => recursiveRemoveSort(child)
[Review comment from Contributor]
Hmm, are you sure the new rule can fully cover this?

[Reply from Contributor Author]
I think it does. EliminateLimits only runs Once, and the added rule runs to fixedPoint. It's not harmful, since we have transformWithPruning.
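For context, a hedged sketch of the Once vs. fixedPoint distinction being discussed (hypothetical executor; it mirrors how the batches are wired up elsewhere in this PR):

import org.apache.spark.sql.catalyst.optimizer.{EliminateSorts, OptimizeOneRowPlan}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// A Once batch makes a single pass over the plan; a FixedPoint batch keeps
// reapplying its rules until the plan stops changing or the iteration cap is
// hit, so a one-row Sort exposed by a later rewrite still gets cleaned up.
object DemoExecutor extends RuleExecutor[LogicalPlan] {
  override protected def batches: Seq[Batch] =
    Batch("Single pass", Once, EliminateSorts) ::
    Batch("Until stable", FixedPoint(100), OptimizeOneRowPlan) :: Nil
}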

    case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
      val newOrders = orders.filterNot(_.child.foldable)
      if (newOrders.isEmpty) {
@@ -121,6 +121,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" ::
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
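For context, registration here is what enables the ruleId-based pruning used by transformUpWithPruning; a hedged sketch of the lookup a Rule performs (assuming the catalyst rules API is visible from the call site):

import org.apache.spark.sql.catalyst.rules.RuleIdCollection

// Rule.ruleId resolves lazily through this lookup, and getRuleId throws for
// names missing from the collection, so a rule that wants tree-pattern
// pruning must be listed here first.
val id = RuleIdCollection.getRuleId(
  "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan")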
@@ -305,7 +305,7 @@ class AnalysisErrorSuite extends AnalysisTest {
      .where(sum($"b") > 0)
      .orderBy($"havingCondition".asc),
    "MISSING_COLUMN",
-   Array("havingCondition", "max('b)"))
+   Array("havingCondition", "max(b)"))

  errorTest(
    "unresolved star expansion in max",
@@ -422,14 +422,4 @@ class EliminateSortsSuite extends AnalysisTest {
      comparePlans(optimized, correctAnswer)
    }
  }
-
-  test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") {
-    comparePlans(
-      Optimize.execute(testRelation.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze,
-      testRelation.groupBy()(count(1).as("cnt")).analyze)
-
-    comparePlans(
-      Optimize.execute(testRelation.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze,
-      testRelation.limit(Literal(1)).analyze)
-  }
}
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class OptimizeOneRowPlanSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Replace Operators", Once, ReplaceDistinctWithAggregate) ::
      Batch("Eliminate Sorts", Once, EliminateSorts) ::
      Batch("Optimize One Row Plan", FixedPoint(10), OptimizeOneRowPlan) :: Nil
  }

  private val t1 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1)))
  private val t2 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1), Row(2)))

  test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") {
    comparePlans(
      Optimize.execute(t2.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze,
      t2.groupBy()(count(1).as("cnt")).analyze)

    comparePlans(
      Optimize.execute(t2.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze,
      t2.limit(Literal(1)).analyze)
  }

  test("Remove sort") {
    // remove local sort
    val plan1 = LocalLimit(0, t1).union(LocalLimit(0, t2)).sortBy($"a".desc).analyze
    val expected = LocalLimit(0, t1).union(LocalLimit(0, t2)).analyze
    comparePlans(Optimize.execute(plan1), expected)

    // do not remove
    val plan2 = t2.orderBy($"a".desc).analyze
    comparePlans(Optimize.execute(plan2), plan2)

    val plan3 = t2.sortBy($"a".desc).analyze
    comparePlans(Optimize.execute(plan3), plan3)
  }

  test("Convert group only aggregate to project") {
    val plan1 = t1.groupBy($"a")($"a").analyze
    comparePlans(Optimize.execute(plan1), t1.select($"a").analyze)

    val plan2 = t1.groupBy($"a" + 1)($"a" + 1).analyze
    comparePlans(Optimize.execute(plan2), t1.select($"a" + 1).analyze)

    // do not remove
    val plan3 = t2.groupBy($"a")($"a").analyze
    comparePlans(Optimize.execute(plan3), plan3)

    val plan4 = t1.groupBy($"a")(sum($"a")).analyze
    comparePlans(Optimize.execute(plan4), plan4)

    val plan5 = t1.groupBy()(sum($"a")).analyze
    comparePlans(Optimize.execute(plan5), plan5)
  }

  test("Remove distinct in aggregate expression") {
    val plan1 = t1.groupBy($"a")(sumDistinct($"a").as("s")).analyze
    val expected1 = t1.groupBy($"a")(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan1), expected1)

    val plan2 = t1.groupBy()(sumDistinct($"a").as("s")).analyze
    val expected2 = t1.groupBy()(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan2), expected2)

    // do not remove
    val plan3 = t2.groupBy($"a")(sumDistinct($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan3), plan3)
  }

  test("Remove in complex case") {
    val plan1 = t1.groupBy($"a")($"a").orderBy($"a".asc).analyze
    val expected1 = t1.select($"a").analyze
    comparePlans(Optimize.execute(plan1), expected1)

    val plan2 = t1.groupBy($"a")(sumDistinct($"a").as("s")).orderBy($"s".asc).analyze
    val expected2 = t1.groupBy($"a")(sum($"a").as("s")).analyze
    comparePlans(Optimize.execute(plan2), expected2)
  }
}
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
-import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits}
+import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
@@ -40,7 +40,8 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
      ConvertToLocalRelation,
      UpdateAttributeNullability),
    Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
-   Batch("Eliminate Limits", fixedPoint, EliminateLimits)
+   Batch("Eliminate Limits", fixedPoint, EliminateLimits),
+   Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)
  )

  final override protected def batches: Seq[Batch] = {
@@ -28,6 +28,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@@ -126,6 +127,12 @@ class AdaptiveQueryExecSuite
    }
  }

+  private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
+    collect(plan) {
+      case agg: BaseAggregateExec => agg
+    }
+  }

  private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
    collect(plan) {
      case l: CollectLimitExec => l
@@ -2484,6 +2491,53 @@ class AdaptiveQueryExecSuite
      }
    }
  }

test("SPARK-38162: Optimize one row plan in AQE Optimizer") {
withTempView("v") {
spark.sparkContext.parallelize(
(1 to 4).map(i => TestData(i, i.toString)), 2)
.toDF("c1", "c2").createOrReplaceTempView("v")

// remove sort
val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
"""
|SELECT * FROM v where c1 = 1 order by c1, c2
|""".stripMargin)
assert(findTopLevelSort(origin1).size == 1)
assert(findTopLevelSort(adaptive1).isEmpty)

// convert group only aggregate to project
val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
"""
|SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
[Review comment from Contributor]
What happens if there is no /*+ repartition(c1) */?

[Reply from Contributor Author]
Nothing happens: the aggregate node is inside the logical query stage, so we cannot optimize it on the logical side:

  LogicalQueryStage(logicalAgg: Aggregate, physicalAgg: BaseAggregateExec)

And the plan inside physicalAgg:

  BaseAggregateExec final
    ShuffleQueryStage
      Exchange
        BaseAggregateExec partial
|""".stripMargin)
assert(findTopLevelAggregate(origin2).size == 2)
assert(findTopLevelAggregate(adaptive2).isEmpty)

// remove distinct in aggregate
val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
"""
|SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
[Review comment from Contributor]
Same question here.
|""".stripMargin)
assert(findTopLevelAggregate(origin3).size == 4)
assert(findTopLevelAggregate(adaptive3).size == 2)

// do not optimize if the aggregate is inside query stage
val (origin4, adaptive4) = runAdaptiveAndVerifyResult(
"""
|SELECT distinct c1 FROM v where c1 = 1
|""".stripMargin)
assert(findTopLevelAggregate(origin4).size == 2)
assert(findTopLevelAggregate(adaptive4).size == 2)

val (origin5, adaptive5) = runAdaptiveAndVerifyResult(
"""
|SELECT sum(distinct c1) FROM v where c1 = 1
|""".stripMargin)
assert(findTopLevelAggregate(origin5).size == 4)
assert(findTopLevelAggregate(adaptive5).size == 4)
}
}
}
