Commit 5a6367b

[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function
1 parent 132cbf0 commit 5a6367b

24 files changed: +239 -138 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 8 additions & 0 deletions
@@ -80,6 +80,14 @@ object AggregateExpression {
       filter,
       NamedExpression.newExprId)
   }
+
+  def containsAggregate(expr: Expression): Boolean = {
+    expr.find(isAggregate).isDefined
+  }
+
+  def isAggregate(expr: Expression): Boolean = {
+    expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
+  }
 }

 /**
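
Note: the two new helpers differ in scope. isAggregate tests only the node itself, while containsAggregate searches the whole expression tree. A minimal sketch of the distinction (our own demo code, assuming spark-catalyst with this commit on the classpath; it is not part of the change):

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
import org.apache.spark.sql.types.LongType

object ContainsAggregateDemo extends App {
  val a = AttributeReference("a", LongType)()

  // count(a) is an AggregateExpression at the top level, so both helpers return true.
  val countA = AggregateExpression(Count(a), Complete, isDistinct = false)
  assert(AggregateExpression.isAggregate(countA))
  assert(AggregateExpression.containsAggregate(countA))

  // count(a) + 1 is an Add node; only the recursive helper finds the aggregate inside it.
  val countPlusOne = Add(countA, Literal(1L))
  assert(!AggregateExpression.isAggregate(countPlusOne))
  assert(AggregateExpression.containsAggregate(countPlusOne))
}

RemoveRedundantAggregates (below) relies on the recursive variant for exactly this reason: an aggregate expression such as count(a) + 1 is not itself an AggregateExpression but still contains one.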

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala

Lines changed: 1 addition & 10 deletions
@@ -18,23 +18,14 @@
 package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule

 /**
  * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // One place where this optimization is invalid is an aggregation where the select
-    // list expression is a function of a grouping expression:
-    //
-    // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
-    //
-    // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
-    // optimization for Aggregates (although this misses some cases where the optimization
-    // can be made).
-    case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
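
Note: the removed guard protected queries like SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b), where simplifying the extraction down to column a alone would leave an aggregate expression that no longer refers to any grouping expression. With PullOutGroupingExpressions (introduced by this commit and run before the main optimizer batches), the struct is already materialized in a Project below the Aggregate, so the blanket skip of Aggregate nodes is no longer needed. A rough spark-shell check (the inline table and column names are illustrative, not part of the commit):

// spark-shell sketch; `spark` is the session predefined by the shell.
spark.sql(
  """SELECT struct(a, b).a
    |FROM VALUES (1, 2), (3, 4) AS tbl(a, b)
    |GROUP BY struct(a, b)""".stripMargin).explain(true)
// Expected shape: an Aggregate keyed on a _groupingexpression attribute over a Project
// that computes struct(a, b), so the extraction stays tied to the grouping key.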

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 7 deletions
@@ -148,6 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView,
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
+      PullOutGroupingExpressions,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager)) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -524,23 +525,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
   }

   private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
-    val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
+    val upperHasNoAggregateExpressions =
+      !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)

     lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
       lower
         .aggregateExpressions
         .filter(_.deterministic)
-        .filter(!isAggregate(_))
+        .filterNot(AggregateExpression.containsAggregate)
         .map(_.toAttribute)
     ))

     upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
   }
-
-  private def isAggregate(expr: Expression): Boolean = {
-    expr.find(e => e.isInstanceOf[AggregateExpression] ||
-      PythonUDF.isGroupedAggPandasUDF(e)).isDefined
-  }
 }

 /**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala (new file)

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule ensures that [[Aggregate]] nodes doesn't contain complex grouping expressions in the
+ * optimization phase.
+ *
+ * Complex grouping expressions are pulled out to a [[Project]] node under [[Aggregate]] and are
+ * referenced in both grouping expressions and aggregate expressions without aggregate functions.
+ * These references ensure that optimization rules don't change the aggregate expressions to invalid
+ * ones that no longer refer to any grouping expressions and also simplify the expression
+ * transformations on the node (need to transform the expression only once).
+ *
+ * For example, in the following query Spark shouldn't optimize the aggregate expression
+ * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`:
+ *   SELECT not(c IS NULL)
+ *   FROM t
+ *   GROUP BY c IS NULL
+ * Instead, the aggregate expression references a `_groupingexpression` attribute:
+ *   Aggregate [_groupingexpression#233], [NOT _groupingexpression#233 AS (NOT (c IS NULL))#230]
+ *   +- Project [isnull(c#219) AS _groupingexpression#233]
+ *      +- LocalRelation [c#219]
+ */
+object PullOutGroupingExpressions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transform {
+      case a: Aggregate if a.resolved =>
+        val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression]
+        val newGroupingExpressions = a.groupingExpressions
+          .filterNot(AggregateExpression.containsAggregate)
+          .map {
+            case e if AggregateExpression.isAggregate(e) => e
+            case e if !e.foldable && e.children.nonEmpty =>
+              complexGroupingExpressionMap
+                .getOrElseUpdate(e.canonicalized, Alias(e, s"_groupingexpression")())
+                .toAttribute
+            case o => o
+          }
+        if (complexGroupingExpressionMap.nonEmpty) {
+          def replaceComplexGroupingExpressions(e: Expression): Expression = {
+            e match {
+              case _ if AggregateExpression.isAggregate(e) => e
+              case _ if complexGroupingExpressionMap.contains(e.canonicalized) =>
+                complexGroupingExpressionMap.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+              case _ => e.mapChildren(replaceComplexGroupingExpressions)
+            }
+          }
+
+          val newAggregateExpressions = a.aggregateExpressions
+            .map(replaceComplexGroupingExpressions(_).asInstanceOf[NamedExpression])
+          val newChild = Project(a.child.output ++ complexGroupingExpressionMap.values, a.child)
+          Aggregate(newGroupingExpressions, newAggregateExpressions, newChild)
+        } else {
+          a
+        }
+    }
+  }
+}
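
Note: the plan shape described in the scaladoc above can be reproduced from spark-shell; the inline VALUES table below stands in for the table t in the example and is not part of the commit:

// spark-shell sketch; `spark` is the session predefined by the shell.
spark.sql(
  """SELECT not(c IS NULL)
    |FROM VALUES (1), (NULL), (3) AS t(c)
    |GROUP BY c IS NULL""".stripMargin).explain(true)
// The optimized plan should show an Aggregate keyed on a _groupingexpression attribute over a
// Project computing isnull(c) AS _groupingexpression, rather than the aggregate expression
// being rewritten to isnotnull(c) and losing its tie to the grouping expression.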

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 3 additions & 5 deletions
@@ -297,11 +297,9 @@ object PhysicalAggregation {
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
           // addExpr() always returns false for non-deterministic expressions and do not add them.
-          case agg: AggregateExpression
-            if !equivalentAggregateExpressions.addExpr(agg) => agg
-          case udf: PythonUDF
-            if PythonUDF.isGroupedAggPandasUDF(udf) &&
-              !equivalentAggregateExpressions.addExpr(udf) => udf
+          case a
+            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+            a
         }
       }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 3 additions & 9 deletions
@@ -36,6 +36,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {

   object Optimizer extends RuleExecutor[LogicalPlan] {
     val batches =
+      Batch("Finish Analysis", Once,
+        PullOutGroupingExpressions) ::
       Batch("collapse projections", FixedPoint(10),
         CollapseProject) ::
       Batch("Constant Folding", FixedPoint(10),
@@ -57,7 +59,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
   private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
     val optimized = Optimizer.execute(originalQuery.analyze)
     assert(optimized.resolved, "optimized plans must be still resolvable")
-    comparePlans(optimized, correctAnswer.analyze)
+    comparePlans(optimized, PullOutGroupingExpressions(correctAnswer.analyze))
   }

   test("explicit get from namedStruct") {
@@ -405,14 +407,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val arrayAggRel = relation.groupBy(
       CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
     checkRule(arrayAggRel, arrayAggRel)
-
-    // This could be done if we had a more complex rule that checks that
-    // the CreateMap does not come from key.
-    val originalQuery = relation
-      .groupBy('id)(
-        GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
-      )
-    checkRule(originalQuery, originalQuery)
   }

   test("SPARK-23500: namedStruct and getField in the same Project #1") {

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 10 additions & 0 deletions
@@ -179,3 +179,13 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(

 -- Aggregate with multiple distinct decimal columns
 SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
+
+-- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
+
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
+

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 23 additions & 1 deletion
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 62
+-- Number of queries: 64


 -- !query
@@ -642,3 +642,25 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
 struct<avg(DISTINCT decimal_col):decimal(13,4),sum(DISTINCT decimal_col):decimal(19,0)>
 -- !query output
 1.0000	1
+
+
+-- !query
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(NOT (a IS NULL)):boolean,c:bigint>
+-- !query output
+false	2
+true	7
+
+
+-- !query
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
+-- !query output
+0.7604953758285915	7
+1.0	2

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt

Lines changed: 12 additions & 12 deletions
@@ -199,19 +199,19 @@ Right keys [1]: [i_item_sk#16]
 Join condition: None

 (23) Project [codegen id : 8]
-Output [3]: [d_date#12, i_item_sk#16, i_item_desc#17]
+Output [3]: [d_date#12, i_item_sk#16, substr(i_item_desc#17, 1, 30) AS _groupingexpression#19]
 Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#16, i_item_desc#17]

 (24) HashAggregate [codegen id : 8]
-Input [3]: [d_date#12, i_item_sk#16, i_item_desc#17]
-Keys [3]: [substr(i_item_desc#17, 1, 30) AS substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12]
+Input [3]: [d_date#12, i_item_sk#16, _groupingexpression#19]
+Keys [3]: [_groupingexpression#19, i_item_sk#16, d_date#12]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#20]
-Results [4]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12, count#21]
+Results [4]: [_groupingexpression#19, i_item_sk#16, d_date#12, count#21]

 (25) HashAggregate [codegen id : 8]
-Input [4]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12, count#21]
-Keys [3]: [substr(i_item_desc#17, 1, 30)#19, i_item_sk#16, d_date#12]
+Input [4]: [_groupingexpression#19, i_item_sk#16, d_date#12, count#21]
+Keys [3]: [_groupingexpression#19, i_item_sk#16, d_date#12]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#22]
 Results [2]: [i_item_sk#16 AS item_sk#23, count(1)#22 AS count(1)#24]
@@ -406,19 +406,19 @@ Right keys [1]: [i_item_sk#56]
 Join condition: None

 (69) Project [codegen id : 25]
-Output [3]: [d_date#55, i_item_sk#56, i_item_desc#57]
+Output [3]: [d_date#55, i_item_sk#56, substr(i_item_desc#57, 1, 30) AS _groupingexpression#58]
 Input [4]: [ss_item_sk#54, d_date#55, i_item_sk#56, i_item_desc#57]

 (70) HashAggregate [codegen id : 25]
-Input [3]: [d_date#55, i_item_sk#56, i_item_desc#57]
-Keys [3]: [substr(i_item_desc#57, 1, 30) AS substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55]
+Input [3]: [d_date#55, i_item_sk#56, _groupingexpression#58]
+Keys [3]: [_groupingexpression#58, i_item_sk#56, d_date#55]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#59]
-Results [4]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55, count#60]
+Results [4]: [_groupingexpression#58, i_item_sk#56, d_date#55, count#60]

 (71) HashAggregate [codegen id : 25]
-Input [4]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55, count#60]
-Keys [3]: [substr(i_item_desc#57, 1, 30)#58, i_item_sk#56, d_date#55]
+Input [4]: [_groupingexpression#58, i_item_sk#56, d_date#55, count#60]
+Keys [3]: [_groupingexpression#58, i_item_sk#56, d_date#55]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#61]
 Results [2]: [i_item_sk#56 AS item_sk#23, count(1)#61 AS count(1)#62]

sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/simplified.txt

Lines changed: 4 additions & 4 deletions
@@ -34,8 +34,8 @@ WholeStageCodegen (36)
         Sort [item_sk]
           Project [item_sk]
             Filter [count(1)]
-              HashAggregate [substr(i_item_desc, 1, 30),i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
-                HashAggregate [i_item_desc,i_item_sk,d_date] [count,substr(i_item_desc, 1, 30),count]
+              HashAggregate [_groupingexpression,i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
+                HashAggregate [_groupingexpression,i_item_sk,d_date] [count,count]
                   Project [d_date,i_item_sk,i_item_desc]
                     SortMergeJoin [ss_item_sk,i_item_sk]
                       InputAdapter
@@ -177,8 +177,8 @@ WholeStageCodegen (36)
         Sort [item_sk]
           Project [item_sk]
             Filter [count(1)]
-              HashAggregate [substr(i_item_desc, 1, 30),i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
-                HashAggregate [i_item_desc,i_item_sk,d_date] [count,substr(i_item_desc, 1, 30),count]
+              HashAggregate [_groupingexpression,i_item_sk,d_date,count] [count(1),item_sk,count(1),count]
+                HashAggregate [_groupingexpression,i_item_sk,d_date] [count,count]
                   Project [d_date,i_item_sk,i_item_desc]
                     SortMergeJoin [ss_item_sk,i_item_sk]
                       InputAdapter
