diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 47260cfb59bb1..4226a1b0aea73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -349,11 +349,23 @@ abstract class Optimizer(catalogManager: CatalogManager)
  */
 object EliminateDistinct extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions {
-    case ae: AggregateExpression if ae.isDistinct =>
-      ae.aggregateFunction match {
-        case _: Max | _: Min => ae.copy(isDistinct = false)
-        case _ => ae
-      }
+  // Use transformAllExpressions so that aggregate expressions in every operator of the plan
+  // are visited; transformExpressions would only rewrite the expressions of the root operator.
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case ae: AggregateExpression if ae.isDistinct && isDuplicateAgnostic(ae.aggregateFunction) =>
+      ae.copy(isDistinct = false)
+  }
+
+  /**
+   * Returns true for aggregate functions whose result does not change when duplicate input
+   * values are removed, which makes the DISTINCT qualifier redundant for them.
+   */
+  private def isDuplicateAgnostic(af: AggregateFunction): Boolean = af match {
+    case _: Max => true
+    case _: Min => true
+    case _: BitAndAgg => true
+    case _: BitOrAgg => true
+    case _: CollectSet => true
+    case _ => false
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
index 51c751923e414..0848d5609ff02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -32,25 +34,41 @@ class EliminateDistinctSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int)
 
-  test("Eliminate Distinct in Max") {
-    val query = testRelation
-      .select(maxDistinct('a).as('result))
-      .analyze
-    val answer = testRelation
-      .select(max('a).as('result))
-      .analyze
-    assert(query != answer)
-    comparePlans(Optimize.execute(query), answer)
-  }
-
-  test("Eliminate Distinct in Min") {
-    val query = testRelation
-      .select(minDistinct('a).as('result))
-      .analyze
-    val answer = testRelation
-      .select(min('a).as('result))
-      .analyze
-    assert(query != answer)
-    comparePlans(Optimize.execute(query), answer)
+  // Each builder creates an aggregate for which DISTINCT is redundant; every one of them
+  // gets the same pair of tests.
+  Seq(
+    Max(_),
+    Min(_),
+    BitAndAgg(_),
+    BitOrAgg(_),
+    CollectSet(_: Expression)
+  ).foreach { aggBuilder =>
+    val agg = aggBuilder('a)
+
+    test(s"Eliminate Distinct in ${agg.prettyName}") {
+      val query = testRelation
+        .select(agg.toAggregateExpression(isDistinct = true).as('result))
+        .analyze
+      val answer = testRelation
+        .select(agg.toAggregateExpression(isDistinct = false).as('result))
+        .analyze
+      assert(query != answer)
+      comparePlans(Optimize.execute(query), answer)
+    }
+
+    // The aggregate sits below a limit, so this only passes when the rule also rewrites
+    // expressions of operators that are not the root of the plan.
+    test(s"Eliminate Distinct in non-root ${agg.prettyName}") {
+      val query = testRelation
+        .select(agg.toAggregateExpression(isDistinct = true).as('result))
+        .limit(1)
+        .analyze
+      val answer = testRelation
+        .select(agg.toAggregateExpression(isDistinct = false).as('result))
+        .limit(1)
+        .analyze
+      assert(query != answer)
+      comparePlans(Optimize.execute(query), answer)
+    }
   }
 }