
Commit 2b0cc4e

maropu authored and hvanhovell committed
[SPARK-12978][SQL] Skip unnecessary final group-by when input data already clustered with group-by keys
This ticket targets an optimization that skips the unnecessary final group-by operation when the input data is already clustered by the group-by keys.

Without the optimization:

```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)], output=[col0#159,sum#200,sum#201,count#202L])
   +- TungstenExchange hashpartitioning(col0#159,200), None
      +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```

With the optimization:

```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenExchange hashpartitioning(col0#159,200), None
   +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```

Author: Takeshi YAMAMURO <[email protected]>

Closes #10896 from maropu/SkipGroupbySpike.
1 parent 6b8cb1f commit 2b0cc4e
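For context, a minimal sketch (not part of this commit) of the query shape the optimization targets, assuming a local SparkSession and illustrative column names; it mirrors the `verifyNonExchangingSingleAgg` checks added to `DataFrameSuite` further down. Repartitioning by the grouping key clusters the input, so the planner can emit a single aggregate instead of a Partial/Final pair.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, sum}

object SkipFinalGroupByExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("skip-final-groupby").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1, 1.0), ("a", 2, 2.0), ("b", 3, 3.0)).toDF("col0", "col1", "col2")

    // Clustering the input on the grouping key up front means the aggregate's required
    // child distribution is already satisfied, so no extra map-side partial aggregate
    // (and no extra exchange) should be planned above the repartition.
    df.repartition($"col0")
      .groupBy("col0")
      .agg(sum("col1"), avg("col2"))
      .explain()

    spark.stop()
  }
}
```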

8 files changed (+257, -224 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 5 additions & 12 deletions
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }

         val aggregateOperator =
-          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-            if (functionsWithDistinct.nonEmpty) {
-              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-                "aggregate functions which don't support partial aggregation.")
-            } else {
-              aggregate.AggUtils.planAggregateWithoutPartial(
-                groupingExpressions,
-                aggregateExpressions,
-                resultExpressions,
-                planLater(child))
-            }
-          } else if (functionsWithDistinct.isEmpty) {
+          if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
+            if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+                "aggregate functions which don't support partial aggregation.")
+            }
             aggregate.AggUtils.planAggregateWithOneDistinct(
               groupingExpressions,
               functionsWithDistinct,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 115 additions & 135 deletions
Large diffs are not rendered by default.
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A base class for aggregate implementation.
+ */
+abstract class AggregateExec extends UnaryExecNode {
+
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
+  def groupingExpressions: Seq[NamedExpression]
+  def aggregateExpressions: Seq[AggregateExpression]
+  def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
+  def resultExpressions: Seq[NamedExpression]
+
+  protected[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 21 deletions
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
+  extends AggregateExec with CodegenSupport {

   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -60,21 +55,6 @@
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 2 additions & 22 deletions
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils

@@ -38,30 +37,11 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
+  extends AggregateExec {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 26 additions & 12 deletions
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.AggUtils
+import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf

 /**
@@ -151,18 +153,30 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    var children: Seq[SparkPlan] = operator.children
-    assert(requiredChildDistributions.length == children.length)
-    assert(requiredChildOrderings.length == children.length)
+    assert(requiredChildDistributions.length == operator.children.length)
+    assert(requiredChildOrderings.length == operator.children.length)

-    // Ensure that the operator's children satisfy their output distribution requirements:
-    children = children.zip(requiredChildDistributions).map {
-      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
-      case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
-      case (child, distribution) =>
-        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
+      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
+
+    var (parent, children) = operator match {
+      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
+        // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
+        // aggregation and a shuffle are added as children.
+        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
+        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+      case _ =>
+        // Ensure that the operator's children satisfy their output distribution requirements:
+        val childrenWithDist = operator.children.zip(requiredChildDistributions)
+        val newChildren = childrenWithDist.map {
+          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+            child
+          case (child, BroadcastDistribution(mode)) =>
+            BroadcastExchangeExec(mode, child)
+          case (child, distribution) =>
+            createShuffleExchange(distribution, child)
+        }
+        (operator, newChildren)
     }

     // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +260,7 @@
       }
     }

-    operator.withNewChildren(children)
+    parent.withNewChildren(children)
   }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
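As a reading aid, here is a simplified, self-contained sketch of the decision the new `PartialAggregate` branch encodes, using toy plan types rather than Spark's actual classes (in particular, `AggUtils.createMapMergeAggregatePair` is only mimicked): a map-side partial aggregate plus a shuffle is inserted only when the child is not already clustered on the grouping keys.

```scala
object PartialAggregationSketch {
  sealed trait Plan
  case class Scan(partitionedBy: Set[String]) extends Plan
  case class Shuffle(keys: Set[String], child: Plan) extends Plan
  case class Agg(keys: Set[String], partial: Boolean, child: Plan) extends Plan

  // Toy analogue of the rule: skip the map-side partial aggregate when the child's
  // partitioning keys are a (non-empty) subset of the grouping keys, i.e. rows that
  // share grouping keys are already co-located.
  def planAggregate(keys: Set[String], child: Plan): Plan = {
    val alreadyClustered = child match {
      case Scan(p)       => p.nonEmpty && p.subsetOf(keys)
      case Shuffle(k, _) => k.nonEmpty && k.subsetOf(keys)
      case _             => false
    }
    if (alreadyClustered) {
      // Single complete aggregate: no partial aggregate, no extra shuffle.
      Agg(keys, partial = false, child)
    } else {
      // Otherwise aggregate map-side first, shuffle the partial results, then merge.
      Agg(keys, partial = false, Shuffle(keys, Agg(keys, partial = true, child)))
    }
  }

  def main(args: Array[String]): Unit = {
    println(planAggregate(Set("col0"), Scan(Set("col0"))))  // no partial aggregate needed
    println(planAggregate(Set("col0"), Scan(Set.empty)))    // partial aggregate + shuffle + merge
  }
}
```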

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 8 additions & 7 deletions
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }

   /**
-   * Verifies that there is no Exchange between the Aggregations for `df`
+   * Verifies that there is a single Aggregation for `df`
    */
-  private def verifyNonExchangingAgg(df: DataFrame) = {
+  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
-        atFirstAgg = !atFirstAgg
-      case _ =>
         if (atFirstAgg) {
-          fail("Should not have operators between the two aggregations")
+          fail("Should not have back to back Aggregates")
         }
+        atFirstAgg = true
+      case _ =>
     }
   }

@@ -1292,9 +1292,10 @@
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingAgg(df3)
-    verifyNonExchangingAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingSingleAgg(df3)
+    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
+    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 44 additions & 15 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {

   setupTestData()

-  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
     val planner = spark.sessionState.planner
     import planner._
-    val plannedOption = Aggregation(query).headOption
-    val planned =
-      plannedOption.getOrElse(
-        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    // For the new aggregation code path, there will be four aggregate operator for
-    // distinct aggregations.
-    assert(
-      aggregations.size == 2 || aggregations.size == 4,
-      s"The plan of query $query does not have partial aggregations.")
+    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+      .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    planned.collect { case n if n.nodeName contains "Aggregate" => n }
   }

   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    assert(testPartialAggregationPlan(query).size == 2,
+      s"The plan of query $query does not have partial aggregations.")
   }

   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
     testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operator for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }

   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operator for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
+  test("non-partial aggregation for aggregates") {
+    withTempView("testNonPartialAggregation") {
+      val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
+      val row = Row.fromSeq(Seq.fill(1)(null))
+      val rowRDD = sparkContext.parallelize(row :: Nil)
+      spark.createDataFrame(rowRDD, schema).repartition($"value")
+        .createOrReplaceTempView("testNonPartialAggregation")
+
+      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+        .queryExecution.executedPlan
+
+      // If input data are already partitioned and the same columns are used in grouping keys and
+      // aggregation values, no partial aggregation exist in query plans.
+      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+      val planned2 = sql(
+        """
+          |SELECT t.value, SUM(DISTINCT t.value)
+          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+          |GROUP BY t.value
+        """.stripMargin).queryExecution.executedPlan
+
+      val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
+    }
   }

   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
