Skip to content

Commit f7d29df

Browse files
committed
Update
1 parent 9938252 commit f7d29df

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
255255
// Setup unique distinct aggregate children.
256256
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
257257
val distinctAggChildAttrMap = distinctAggChildren.map { e =>
258-
ExpressionSet(Seq(e)) -> AttributeReference(e.sql, e.dataType, nullable = true)()
258+
e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
259259
}
260260
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
261261
// Setup all the filters in distinct aggregate.
@@ -293,9 +293,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
293293
val naf = if (af.children.forall(_.foldable)) {
294294
af
295295
} else {
296-
patchAggregateFunctionChildren(af) { x1 =>
297-
val es = ExpressionSet(Seq(x1))
298-
distinctAggChildAttrLookup.get(es)
296+
patchAggregateFunctionChildren(af) { x =>
297+
distinctAggChildAttrLookup.get(x.canonicalized)
299298
}
300299
}
301300
val newCondition = if (condition.isDefined) {

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
562562
// [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
563563
// aggregates have different column expressions.
564564
val distinctExpressions =
565-
functionsWithDistinct.flatMap(
566-
_.aggregateFunction.children.filterNot(_.foldable)).distinct
565+
functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable)
567566
val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
568567
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
569568
// because `distinctExpressions` is not extracted during logical phase.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,17 @@ object AggUtils {
219219
}
220220

221221
// 3. Create an Aggregate operator for partial aggregation (for distinct)
222-
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
222+
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
223+
distinctAttributes)
223224
val rewrittenDistinctFunctions = functionsWithDistinct.map {
224225
// Children of an AggregateFunction with DISTINCT keyword has already
225226
// been evaluated. At here, we need to replace original children
226227
// to AttributeReferences.
227228
case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
228-
aggregateFunction.transformDown(distinctColumnAttributeLookup)
229-
.asInstanceOf[AggregateFunction]
229+
aggregateFunction.transformDown {
230+
case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
231+
distinctColumnAttributeLookup(e.canonicalized)
232+
}.asInstanceOf[AggregateFunction]
230233
case agg =>
231234
throw new IllegalArgumentException(
232235
"Non-distinct aggregate is found in functionsWithDistinct " +

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
9595
// 2 distinct columns with different order
9696
val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
9797
assertNoExpand(query3.queryExecution.executedPlan)
98+
99+
// SPARK-40382: 1 distinct expression with cosmetic differences
100+
val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
101+
assertNoExpand(query4.queryExecution.executedPlan)
98102
}
99103
}
100104

0 commit comments

Comments
 (0)