diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8b3243067a16c..b61c4b8d065f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -197,6 +197,12 @@ package object dsl { Max(e).toAggregateExpression(isDistinct = false, filter = filter) def maxDistinct(e: Expression, filter: Option[Expression] = None): Expression = Max(e).toAggregateExpression(isDistinct = true, filter = filter) + def bitAnd(e: Expression, filter: Option[Expression] = None): Expression = + BitAndAgg(e).toAggregateExpression(isDistinct = false, filter = filter) + def bitOr(e: Expression, filter: Option[Expression] = None): Expression = + BitOrAgg(e).toAggregateExpression(isDistinct = false, filter = filter) + def bitXor(e: Expression, filter: Option[Expression] = None): Expression = + BitXorAgg(e).toAggregateExpression(isDistinct = false, filter = filter) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def coalesce(args: Expression*): Expression = Coalesce(args) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9216ab1631e7b..b7791cd442694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1044,7 +1044,7 @@ object EliminateSorts extends Rule[LogicalPlan] { private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = { def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match { - case _: Min | _: Max | _: Count => true + case _: Min | _: Max | _: Count | _: BitAggregate => true // Arithmetic operations for floating-point values are order-sensitive // (they are not associative). case _: Sum | _: Average | _: CentralMomentAgg => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index e2b599a7c090c..265f0a9936759 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -197,13 +197,25 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimizedThrice, correctAnswerThrice) } - test("remove orderBy in groupBy clause with count aggs") { - val projectPlan = testRelation.select('a, 'b) - val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(count(1)) - val optimized = Optimize.execute(groupByPlan.analyze) - val correctAnswer = projectPlan.groupBy('a)(count(1)).analyze - comparePlans(optimized, correctAnswer) + test("remove orderBy in groupBy clause with order irrelevant aggs") { + Seq( + (e : Expression) => min(e), + (e : Expression) => minDistinct(e), + (e : Expression) => max(e), + (e : Expression) => maxDistinct(e), + (e : Expression) => count(e), + (e : Expression) => countDistinct(e), + (e : Expression) => bitAnd(e), + (e : Expression) => bitOr(e), + (e : Expression) => bitXor(e) + ).foreach(agg => { + val projectPlan = testRelation.select('a, 'b) + val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc) + val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(agg('b)) + val optimized = Optimize.execute(groupByPlan.analyze) + val correctAnswer = projectPlan.groupBy('a)(agg('b)).analyze + comparePlans(optimized, correctAnswer) + }) } test("remove orderBy in groupBy clause with sum aggs") {