diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 979c2805e08b..9e83051313f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -218,6 +218,11 @@ package object dsl {
       BitOrAgg(e).toAggregateExpression(isDistinct = false, filter = filter)
     def bitXor(e: Expression, filter: Option[Expression] = None): Expression =
       BitXorAgg(e).toAggregateExpression(isDistinct = false, filter = filter)
+    def collectList(e: Expression, filter: Option[Expression] = None): Expression =
+      CollectList(e).toAggregateExpression(isDistinct = false, filter = filter)
+    def collectSet(e: Expression, filter: Option[Expression] = None): Expression =
+      CollectSet(e).toAggregateExpression(isDistinct = false, filter = filter)
+
     def upper(e: Expression): Expression = Upper(e)
     def lower(e: Expression): Expression = Lower(e)
     def coalesce(args: Expression*): Expression = Coalesce(args)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 176da5a566f2..4fe00099ddc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -62,9 +62,6 @@ case class First(child: Expression, ignoreNulls: Boolean)
 
   override def nullable: Boolean = true
 
-  // First is not a deterministic function.
-  override lazy val deterministic: Boolean = false
-
   // Return data type.
   override def dataType: DataType = child.dataType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 0fe6199cd8c3..5840c783cb8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -61,9 +61,6 @@ case class Last(child: Expression, ignoreNulls: Boolean)
 
   override def nullable: Boolean = true
 
-  // Last is not a deterministic function.
-  override lazy val deterministic: Boolean = false
-
   // Return data type.
   override def dataType: DataType = child.dataType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index a8db8211a9e4..2514461d4c05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -43,10 +43,6 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
 
   override def dataType: DataType = ArrayType(child.dataType, false)
 
-  // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
-  // actual order of input rows.
-  override lazy val deterministic: Boolean = false
-
   override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
 
   protected def convertToBufferElement(value: Any): Any
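Note: with the overrides gone, `First`, `Last`, `CollectList` and `CollectSet` fall back to the default `Expression.deterministic`, which holds when every child is deterministic. Spark defines "deterministic" as "same result for the same input rows in the same order", and these aggregates do satisfy that; it is only the input order itself that is unguaranteed. A minimal REPL-ready sketch of that default contract (simplified stand-in types, not Spark's actual classes):

```scala
// Simplified stand-ins, not Spark's real Expression hierarchy.
trait Expr {
  def children: Seq[Expr]
  // Spark's default: deterministic iff all children are. "Deterministic"
  // means same output for the same input rows in the same order.
  lazy val deterministic: Boolean = children.forall(_.deterministic)
}
case class Col(name: String) extends Expr { val children: Seq[Expr] = Nil }
case class FirstAgg(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

assert(FirstAgg(Col("b")).deterministic) // deterministic by default now
```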
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 73be7902b998..298da4fa322e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -423,6 +423,8 @@ object EliminateDistinct extends Rule[LogicalPlan] {
     case _: BitAndAgg => true
     case _: BitOrAgg => true
     case _: CollectSet => true
+    case _: First => true
+    case _: Last => true
     case _ => false
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
index 0848d5609ff0..08773720d717 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala
@@ -39,11 +39,15 @@ class EliminateDistinctSuite extends PlanTest {
     Min(_),
     BitAndAgg(_),
     BitOrAgg(_),
+    First(_, ignoreNulls = true),
+    First(_, ignoreNulls = false),
+    Last(_, ignoreNulls = true),
+    Last(_, ignoreNulls = false),
     CollectSet(_: Expression)
   ).foreach { aggBuilder =>
     val agg = aggBuilder('a)
-    test(s"Eliminate Distinct in ${agg.prettyName}") {
+    test(s"Eliminate Distinct in $agg") {
       val query = testRelation
         .select(agg.toAggregateExpression(isDistinct = true).as('result))
         .analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index acdbff95422f..dbbb03a7ae2b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -840,6 +840,29 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("SPARK-32940: aggregate: push filters through first, last and collect") {
+    Seq(
+      first(_: Expression),
+      last(_: Expression),
+      collectList(_: Expression),
+      collectSet(_: Expression)
+    ).foreach { agg =>
+      val originalQuery = testRelation
+        .groupBy('a)(agg('b))
+        .where('a > 42)
+        .analyze
+
+      val optimized = Optimize.execute(originalQuery)
+
+      val correctAnswer = testRelation
+        .where('a > 42)
+        .groupBy('a)(agg('b))
+        .analyze
+
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   test("union") {
     val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
 
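`EliminateDistinct` can now drop the DISTINCT keyword on `first` and `last` as well: like the other functions in this list, their result cannot change when duplicate values are removed, since they only promise one of the input values. A plain-Scala illustration of the duplicate-agnostic cases (my example, not from the patch):

```scala
val xs = Seq(3, 1, 3, 2, 1)
assert(xs.max == xs.distinct.max)      // max(DISTINCT x) == max(x)
assert(xs.min == xs.distinct.min)      // min(DISTINCT x) == min(x)
assert(xs.toSet == xs.distinct.toSet)  // collect_set drops duplicates anyway
```

The new `FilterPushdownSuite` test covers the other payoff: predicate pushdown through `Aggregate` requires the aggregate expressions to be deterministic, so `WHERE a > 42` could not previously be pushed below a `first`, `last`, `collect_list` or `collect_set`.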
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 5acf997e0a00..bdbb51bf31c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -183,10 +182,10 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   }
 
   test("unwrap cast should skip when expression is non-deterministic or foldable") {
-    Seq(positiveInt, negativeInt).foreach(v => {
-      val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v
+    Seq(positiveLong, negativeLong).foreach (v => {
+      val e = Cast(Rand(0), LongType) <=> v
       assertEquivalent(e, e, evaluate = false)
-      val e2 = Cast(Literal(30.toShort), IntegerType) >= v
+      val e2 = Cast(Literal(30), LongType) >= v
       assertEquivalent(e2, e2, evaluate = false)
     })
   }
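This test swaps `First` for `Rand(0)` because it needs an expression that is still non-deterministic after this patch. A tiny sketch of the two skip conditions the test exercises (`ExprInfo` and `skipUnwrap` are hypothetical names for illustration, not Spark's API):

```scala
case class ExprInfo(deterministic: Boolean, foldable: Boolean)

// UnwrapCastInBinaryComparison leaves the comparison untouched in either case.
def skipUnwrap(e: ExprInfo): Boolean = !e.deterministic || e.foldable

assert(skipUnwrap(ExprInfo(deterministic = false, foldable = false))) // Cast(Rand(0), LongType)
assert(skipUnwrap(ExprInfo(deterministic = true, foldable = true)))   // Cast(Literal(30), LongType)
```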
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 11b7ee65dad5..594c626a755f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, Reparti
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate._
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, FunctionsCommand}
 import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, LogicalRelation}
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -2790,15 +2791,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
   }
 
   test("Non-deterministic aggregate functions should not be deduplicated") {
-    val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
-    val df = sql(query)
-    val physical = df.queryExecution.sparkPlan
-    val aggregateExpressions = physical.collectFirst {
-      case agg : HashAggregateExec => agg.aggregateExpressions
-      case agg : SortAggregateExec => agg.aggregateExpressions
+    withUserDefinedFunction("sumND" -> true) {
+      spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
+        def zero: Long = 0L
+        def reduce(b: Long, a: Long): Long = b + a
+        def merge(b1: Long, b2: Long): Long = b1 + b2
+        def finish(r: Long): Long = r
+        def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+        def outputEncoder: Encoder[Long] = Encoders.scalaLong
+      }).asNondeterministic())
+
+      val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
+      val df = sql(query)
+      val physical = df.queryExecution.sparkPlan
+      val aggregateExpressions = physical.collectFirst {
+        case agg: BaseAggregateExec => agg.aggregateExpressions
+      }
+      assert(aggregateExpressions.isDefined)
+      assert(aggregateExpressions.get.size == 2)
     }
-    assert (aggregateExpressions.isDefined)
-    assert (aggregateExpressions.get.size == 2)
   }
 
   test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
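For the same reason the deduplication test can no longer rely on `first_value`: it now registers its own non-deterministic UDAF. A standalone sketch of the behaviour being asserted (the session setup and sample data are mine; assumes a Spark 3.x dependency on the classpath):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

object NonDeterministicAggDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
  import spark.implicits._

  // Same UDAF as the test above; asNondeterministic() forces the planner to
  // keep both calls instead of deduplicating them.
  spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
    def zero: Long = 0L
    def reduce(b: Long, a: Long): Long = b + a
    def merge(b1: Long, b2: Long): Long = b1 + b2
    def finish(r: Long): Long = r
    def bufferEncoder: Encoder[Long] = Encoders.scalaLong
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }).asNondeterministic())

  Seq((1, 1L), (1, 2L), (2, 3L)).toDF("a", "b").createOrReplaceTempView("testData2")

  // The printed physical plan carries two separate sumND aggregate
  // expressions; a deterministic aggregate (e.g. first_value after this
  // patch) would be planned, and computed, only once.
  println(spark.sql("SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a")
    .queryExecution.sparkPlan)

  spark.stop()
}
```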