Skip to content

Commit 94922d7

Browse files
maropuliancheng
authored andcommitted
[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial aggregations
## What changes were proposed in this pull request? Partial aggregations are generated in `EnsureRequirements`, but the planner fails to check if partial aggregation satisfies sort requirements. For the following query: ``` val df2 = (0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2") spark.sql("select max(b) from t2 group by a").explain(true) ``` Now, the SortAggregator won't insert Sort operator before partial aggregation, this will break sort-based partial aggregation. ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- LocalTableScan [a#5, b#6] ``` Actually, a correct plan is: ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- *Sort [a#5 ASC], false, 0 +- LocalTableScan [a#5, b#6] ``` ## How was this patch tested? Added tests in `PlannerSuite`. Author: Takeshi YAMAMURO <[email protected]> Closes #14865 from maropu/SPARK-17289.
1 parent 8fb445d commit 94922d7

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
164164
// If an aggregation needs a shuffle and support partial aggregations, a map-side partial
165165
// aggregation and a shuffle are added as children.
166166
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
167-
(mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
167+
(mergeAgg, createShuffleExchange(
168+
requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
168169
case _ =>
169170
// Ensure that the operator's children satisfy their output distribution requirements:
170171
val childrenWithDist = operator.children.zip(requiredChildDistributions)

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.rdd.RDD
21-
import org.apache.spark.sql.{execution, DataFrame, Row}
21+
import org.apache.spark.sql.{execution, Row}
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.Inner
2525
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
2626
import org.apache.spark.sql.catalyst.plans.physical._
27+
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
2728
import org.apache.spark.sql.execution.columnar.InMemoryRelation
2829
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
2930
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
7071
s"The plan of query $query does not have partial aggregations.")
7172
}
7273

74+
test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
75+
withTempView("testSortBasedPartialAggregation") {
76+
val schema = StructType(
77+
StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
78+
val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
79+
spark.createDataFrame(rowRDD, schema)
80+
.createOrReplaceTempView("testSortBasedPartialAggregation")
81+
82+
// This test assumes a query below uses sort-based aggregations
83+
val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
84+
.queryExecution.executedPlan
85+
// This line extracts both SortAggregate and Sort operators
86+
val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
87+
val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
88+
assert(extractedOps.size == 4 && aggOps.size == 2,
89+
s"The plan $planned does not have correct sort-based partial aggregate pairs.")
90+
}
91+
}
92+
7393
test("non-partial aggregation for aggregates") {
7494
withTempView("testNonPartialAggregation") {
7595
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)

0 commit comments

Comments
 (0)