diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 302aae08d588b..4d23e5e8a65b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -448,10 +448,28 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { resultExpressions, planLater(child)) } else { + // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain + // more than one DISTINCT aggregate function, all of those functions will have the + // same column expressions. For example, it would be valid for functionsWithDistinct + // to be [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but + // [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct + // aggregates have different column expressions. + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val normalizedNamedDistinctExpressions = distinctExpressions.map { e => + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here + // because `distinctExpressions` is not extracted during logical phase. + NormalizeFloatingNumbers.normalize(e) match { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + } + AggUtils.planAggregateWithOneDistinct( normalizedGroupingExpressions, functionsWithDistinct, functionsWithoutDistinct, + distinctExpressions, + normalizedNamedDistinctExpressions, resultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 56a287d4d0279..761ac20e84744 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -135,20 +135,12 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], + distinctExpressions: Seq[Expression], + normalizedNamedDistinctExpressions: Seq[NamedExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expressions. - // For example, it would be valid for functionsWithDistinct to be - // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is - // disallowed because those two distinct aggregates have different column expressions. - val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children - val namedDistinctExpressions = distinctExpressions.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } - val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) + val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. @@ -159,7 +151,7 @@ object AggUtils { // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. createAggregate( - groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2293d4ae61aff..f7438f3ffec04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1012,4 +1012,20 @@ class DataFrameAggregateSuite extends QueryTest } } } + + test("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate") { + withTempView("view") { + val nan1 = java.lang.Float.intBitsToFloat(0x7f800001) + val nan2 = java.lang.Float.intBitsToFloat(0x7fffffff) + + Seq(("mithunr", Float.NaN), + ("mithunr", nan1), + ("mithunr", nan2), + ("abellina", 1.0f), + ("abellina", 2.0f)).toDF("uid", "score").createOrReplaceTempView("view") + + val df = spark.sql("select uid, count(distinct score) from view group by 1 order by 1 asc") + checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil) + } + } }