
Commit aaf632b

revert PR#10896 and PR#14865
## What changes were proposed in this pull request?

According to the discussion in the original PR #10896 and the new approach PR #14876, we decided to revert these two PRs and go with the new approach.

## How was this patch tested?

N/A

Author: Wenchen Fan <[email protected]>

Closes #14909 from cloud-fan/revert.
1 parent 7a5000f commit aaf632b

File tree

8 files changed: +223, -277 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 12 additions & 5 deletions
@@ -259,17 +259,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       }

       val aggregateOperator =
-        if (functionsWithDistinct.isEmpty) {
+        if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+          if (functionsWithDistinct.nonEmpty) {
+            sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+              "aggregate functions which don't support partial aggregation.")
+          } else {
+            aggregate.AggUtils.planAggregateWithoutPartial(
+              groupingExpressions,
+              aggregateExpressions,
+              resultExpressions,
+              planLater(child))
+          }
+        } else if (functionsWithDistinct.isEmpty) {
           aggregate.AggUtils.planAggregateWithoutDistinct(
             groupingExpressions,
             aggregateExpressions,
             resultExpressions,
             planLater(child))
         } else {
-          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-            sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-              "aggregate functions which don't support partial aggregation.")
-          }
           aggregate.AggUtils.planAggregateWithOneDistinct(
             groupingExpressions,
             functionsWithDistinct,
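For orientation, the restored branch order is: aggregate functions that cannot be partially aggregated are handled first (erroring out if distinct functions are also present), and only then does the distinct/non-distinct split apply. A minimal, self-contained Scala sketch of that decision, with an illustrative AggSpec type standing in for Spark's aggregate expressions (not the actual API):

    case class AggSpec(supportsPartial: Boolean, isDistinct: Boolean)

    // Mirrors the restored if / else-if / else above; the returned strings name
    // the AggUtils entry point each branch would call.
    def choosePlanner(aggs: Seq[AggSpec]): String =
      if (aggs.exists(!_.supportsPartial)) {
        if (aggs.exists(_.isDistinct)) {
          sys.error("Distinct columns cannot exist in Aggregate operator containing " +
            "aggregate functions which don't support partial aggregation.")
        } else "planAggregateWithoutPartial"
      } else if (!aggs.exists(_.isDistinct)) {
        "planAggregateWithoutDistinct"
      } else {
        "planAggregateWithOneDistinct"
      }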

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 135 additions & 115 deletions
Large diffs are not rendered by default.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala

Lines changed: 0 additions & 56 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 21 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}

@@ -41,7 +42,11 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }

   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -55,6 +60,21 @@ case class HashAggregateExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
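The reinstated requiredChildDistribution override is the contract EnsureRequirements consults when deciding whether to insert an exchange. A toy, self-contained model of the three-way mapping (stand-in types, not the catalyst classes; comments state the assumed meaning of each case):

    sealed trait Distribution
    case object AllTuples extends Distribution                   // all rows in a single partition
    case class Clustered(keys: Seq[String]) extends Distribution // equal keys co-located
    case object Unspecified extends Distribution                 // any partitioning is acceptable

    def requiredDistribution(exprs: Option[Seq[String]]): List[Distribution] = exprs match {
      case Some(es) if es.isEmpty => AllTuples :: Nil     // global aggregate: no grouping keys
      case Some(es)               => Clustered(es) :: Nil // final aggregate: cluster by keys
      case None                   => Unspecified :: Nil   // e.g. a partial, map-side aggregate
    }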

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 22 additions & 2 deletions
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils

@@ -37,11 +38,30 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends AggregateExec {
+  extends UnaryExecNode {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
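SortAggregateExec additionally keeps its requiredChildOrdering on the grouping keys: with input sorted by key, each group is complete as soon as the key changes, so a single pass with constant state suffices. A toy single-pass sum over pre-sorted pairs illustrating that pattern (not the operator's real iterator):

    // Sums values per key, assuming the input is already sorted by key.
    def sortAggregate(sortedRows: Seq[(Int, Long)]): Seq[(Int, Long)] = {
      val out = Seq.newBuilder[(Int, Long)]
      var current: Option[Int] = None
      var sum = 0L
      for ((key, value) <- sortedRows) {
        if (current.contains(key)) sum += value
        else {
          current.foreach(k => out += k -> sum) // key changed: previous group is finished
          current = Some(key)
          sum = value
        }
      }
      current.foreach(k => out += k -> sum) // flush the last group
      out.result()
    }

For example, sortAggregate(Seq(1 -> 2L, 1 -> 3L, 2 -> 1L)) yields Seq(1 -> 5L, 2 -> 1L).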

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 12 additions & 27 deletions
@@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.AggUtils
-import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf

 /**

@@ -153,31 +151,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    assert(requiredChildDistributions.length == operator.children.length)
-    assert(requiredChildOrderings.length == operator.children.length)
+    var children: Seq[SparkPlan] = operator.children
+    assert(requiredChildDistributions.length == children.length)
+    assert(requiredChildOrderings.length == children.length)

-    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
-      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
-
-    var (parent, children) = operator match {
-      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
-        // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
-        // aggregation and a shuffle are added as children.
-        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(
-          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
-      case _ =>
-        // Ensure that the operator's children satisfy their output distribution requirements:
-        val childrenWithDist = operator.children.zip(requiredChildDistributions)
-        val newChildren = childrenWithDist.map {
-          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-            child
-          case (child, BroadcastDistribution(mode)) =>
-            BroadcastExchangeExec(mode, child)
-          case (child, distribution) =>
-            createShuffleExchange(distribution, child)
-        }
-        (operator, newChildren)
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map {
+      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+        child
+      case (child, BroadcastDistribution(mode)) =>
+        BroadcastExchangeExec(mode, child)
+      case (child, distribution) =>
+        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
     }

     // If the operator has multiple children and specifies child output distributions (e.g. join),

@@ -270,7 +255,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       }
     }

-    parent.withNewChildren(children)
+    operator.withNewChildren(children)
   }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
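After the revert the rule is uniform again: for each child, keep it if its partitioning already satisfies the requirement, broadcast it for a broadcast requirement, otherwise shuffle it; there is no aggregate-specific rewrite. A compact model of that match with assumed stand-in types (not the execution classes):

    sealed trait Plan
    case class Leaf(partitioning: String) extends Plan
    case class Shuffle(child: Plan) extends Plan
    case class Broadcast(child: Plan) extends Plan

    // Requirements are illustrative strings, e.g. "broadcast", "cluster:key", "any".
    def satisfies(plan: Plan, req: String): Boolean = plan match {
      case Leaf(p) => req == "any" || p == req
      case _       => false
    }

    // Wrap each child whose output does not already satisfy its requirement.
    def ensureDistribution(children: Seq[Plan], required: Seq[String]): Seq[Plan] =
      children.zip(required).map {
        case (child, req) if satisfies(child, req) => child            // nothing to do
        case (child, "broadcast")                  => Broadcast(child) // broadcast requirement
        case (child, _)                            => Shuffle(child)   // shuffle otherwise
      }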

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 8 deletions
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }

   /**
-   * Verifies that there is a single Aggregation for `df`
+   * Verifies that there is no Exchange between the Aggregations for `df`
    */
-  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
+  private def verifyNonExchangingAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
+        atFirstAgg = !atFirstAgg
+      case _ =>
         if (atFirstAgg) {
-          fail("Should not have back to back Aggregates")
+          fail("Should not have operators between the two aggregations")
         }
-        atFirstAgg = true
-      case _ =>
     }
   }

@@ -1292,10 +1292,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
    val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingSingleAgg(df3)
-    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
-    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")
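Note on the revised helper: atFirstAgg now acts as a toggle. The first HashAggregateExec turns it on, the second turns it off, and any other operator visited while it is on, i.e. between the pair, fails the test. This matches the reverted planner, which again emits a partial/final aggregate pair rather than the single aggregate the deleted verifyNonExchangingSingleAgg expected.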

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 14 additions & 63 deletions
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}

@@ -38,84 +37,36 @@ class PlannerSuite extends SharedSQLContext {

   setupTestData()

-  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
-    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
-    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
-      .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    planned.collect { case n if n.nodeName contains "Aggregate" => n }
+    val plannedOption = Aggregation(query).headOption
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be four aggregate operator for
+    // distinct aggregations.
+    assert(
+      aggregations.size == 2 || aggregations.size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }

   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    assert(testPartialAggregationPlan(query).size == 2,
-      s"The plan of query $query does not have partial aggregations.")
+    testPartialAggregationPlan(query)
   }

   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
     testPartialAggregationPlan(query)
-    // For the new aggregation code path, there will be four aggregate operator for distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
   }

   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    // For the new aggregation code path, there will be four aggregate operator for distinct
-    // aggregations.
-    assert(testPartialAggregationPlan(query).size == 4,
-      s"The plan of query $query does not have partial aggregations.")
-  }
-
-  test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
-    withTempView("testSortBasedPartialAggregation") {
-      val schema = StructType(
-        StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
-      val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
-      spark.createDataFrame(rowRDD, schema)
-        .createOrReplaceTempView("testSortBasedPartialAggregation")
-
-      // This test assumes a query below uses sort-based aggregations
-      val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
-        .queryExecution.executedPlan
-      // This line extracts both SortAggregate and Sort operators
-      val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
-      val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
-      assert(extractedOps.size == 4 && aggOps.size == 2,
-        s"The plan $planned does not have correct sort-based partial aggregate pairs.")
-    }
-  }
-
-  test("non-partial aggregation for aggregates") {
-    withTempView("testNonPartialAggregation") {
-      val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
-      val row = Row.fromSeq(Seq.fill(1)(null))
-      val rowRDD = sparkContext.parallelize(row :: Nil)
-      spark.createDataFrame(rowRDD, schema).repartition($"value")
-        .createOrReplaceTempView("testNonPartialAggregation")
-
-      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
-        .queryExecution.executedPlan
-
-      // If input data are already partitioned and the same columns are used in grouping keys and
-      // aggregation values, no partial aggregation exist in query plans.
-      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
-
-      val planned2 = sql(
-        """
-          |SELECT t.value, SUM(DISTINCT t.value)
-          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
-          |GROUP BY t.value
-        """.stripMargin).queryExecution.executedPlan
-
-      val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
-      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
-    }
+    testPartialAggregationPlan(query)
   }

   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
