Commit 58e07e0

tanelk authored and cloud-fan committed
[SPARK-32940][SQL] Collect, first and last should be deterministic aggregate functions
### What changes were proposed in this pull request?

Collect, first and last have mistakenly been marked as non-deterministic. They are actually deterministic iff their child expression is deterministic.

Collect, for example, was marked as non-deterministic in #14749, on the grounds that its output depends on the actual order of input rows. That is true, but it does not make these aggregators non-deterministic. The `EliminateSorts` optimizer rule has a method `isOrderIrrelevantAggs` that lists the aggregators whose results do not depend on input row order; collect, first and last are correctly absent from that list. An aggregator would only be non-deterministic if its output for a group depended on the previous groups it had aggregated - I can't think of a practical example of such an aggregator in Spark. An analogous case is sum over float or double columns: its result depends on the order of its inputs, yet it is deterministic. `max_by` and `min_by` are similar - deterministic functions that can return different results when the order of rows changes.

### Why are the changes needed?

The optimizer rule `PushPredicateThroughNonJoin` can work in more cases.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes #29810 from tanelk/SPARK-32940.

Lead-authored-by: [email protected] <[email protected]>
Co-authored-by: Tanel Kiis <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
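The "deterministic iff their child expression is deterministic" behaviour falls out of Catalyst's defaults: `Expression` derives `deterministic` from its children, so deleting the hard-coded `deterministic = false` overrides (below) is the whole fix. A minimal self-contained sketch of that inheritance rule - not Spark's actual classes; `Expr`, `Attr`, `FirstAgg` and `RandExpr` are invented for illustration:

```scala
// Sketch of Catalyst's default: an expression is deterministic iff all of
// its children are. First/Last/Collect now inherit this instead of
// overriding it to false.
trait Expr {
  def children: Seq[Expr]
  lazy val deterministic: Boolean = children.forall(_.deterministic)
}

case class Attr(name: String) extends Expr { def children = Nil }
case class FirstAgg(child: Expr) extends Expr { def children = Seq(child) }
case class RandExpr(seed: Long) extends Expr {
  def children = Nil
  override lazy val deterministic: Boolean = false // truly non-deterministic
}

object DeterminismDemo extends App {
  println(FirstAgg(Attr("b")).deterministic)   // true: child is deterministic
  println(FirstAgg(RandExpr(0)).deterministic) // false: child is not
}
```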
1 parent 3f3201a commit 58e07e0

File tree

9 files changed: +58 -24 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 5 additions & 0 deletions

```diff
@@ -218,6 +218,11 @@ package object dsl {
       BitOrAgg(e).toAggregateExpression(isDistinct = false, filter = filter)
     def bitXor(e: Expression, filter: Option[Expression] = None): Expression =
       BitXorAgg(e).toAggregateExpression(isDistinct = false, filter = filter)
+    def collectList(e: Expression, filter: Option[Expression] = None): Expression =
+      CollectList(e).toAggregateExpression(isDistinct = false, filter = filter)
+    def collectSet(e: Expression, filter: Option[Expression] = None): Expression =
+      CollectSet(e).toAggregateExpression(isDistinct = false, filter = filter)
+
     def upper(e: Expression): Expression = Upper(e)
     def lower(e: Expression): Expression = Lower(e)
     def coalesce(args: Expression*): Expression = Coalesce(args)
```
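For illustration, here is how the new DSL helpers might be used in an optimizer test. The relation is hypothetical; `LocalRelation` and the symbol-to-attribute implicits are the standard catalyst test ingredients:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// A two-column relation, in the style of the catalyst test suites.
val testRelation = LocalRelation('a.int, 'b.int)

// collectList/collectSet wrap CollectList/CollectSet in an
// AggregateExpression, mirroring the existing bitAnd/bitOr/bitXor helpers.
val plan = testRelation
  .groupBy('a)(collectList('b).as('bs), collectSet('b).as('bset))
  .analyze
```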

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala

Lines changed: 0 additions & 3 deletions

```diff
@@ -62,9 +62,6 @@ case class First(child: Expression, ignoreNulls: Boolean)

   override def nullable: Boolean = true

-  // First is not a deterministic function.
-  override lazy val deterministic: Boolean = false
-
   // Return data type.
   override def dataType: DataType = child.dataType

```
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala

Lines changed: 0 additions & 3 deletions

```diff
@@ -61,9 +61,6 @@ case class Last(child: Expression, ignoreNulls: Boolean)

   override def nullable: Boolean = true

-  // Last is not a deterministic function.
-  override lazy val deterministic: Boolean = false
-
   // Return data type.
   override def dataType: DataType = child.dataType

```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala

Lines changed: 0 additions & 4 deletions

```diff
@@ -43,10 +43,6 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]

   override def dataType: DataType = ArrayType(child.dataType, false)

-  // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
-  // actual order of input rows.
-  override lazy val deterministic: Boolean = false
-
   override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))

   protected def convertToBufferElement(value: Any): Any
```
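The commit message's floating-point analogy applies directly here: summing doubles is order-dependent yet deterministic, which is exactly the status `Collect`, `First` and `Last` now have. A quick plain-Scala demonstration:

```scala
// Floating-point addition is not associative, so the order of inputs
// changes the result - but each order reproducibly gives the same answer.
val a = (0.1 + 0.2) + 0.3       // 0.6000000000000001
val b = 0.1 + (0.2 + 0.3)       // 0.6
println(a == b)                 // false: order matters
println(a == (0.1 + 0.2) + 0.3) // true: still deterministic for a fixed order
```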

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 0 deletions

```diff
@@ -423,6 +423,8 @@ object EliminateDistinct extends Rule[LogicalPlan] {
     case _: BitAndAgg => true
     case _: BitOrAgg => true
     case _: CollectSet => true
+    case _: First => true
+    case _: Last => true
     case _ => false
   }
 }
```
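The shape of the rewrite this enables, paraphrased - the guard method being extended above is the one the rule consults; the name `isDuplicateAgnostic` and the exact code are recalled from the surrounding rule, so treat this as a sketch:

```scala
// Paraphrased: for duplicate-agnostic aggregates such as max, min, and now
// first and last, DISTINCT cannot change the result, so the flag is dropped:
//   first(DISTINCT b)  ==>  first(b)
plan transformExpressions {
  case ae: AggregateExpression
      if ae.isDistinct && isDuplicateAgnostic(ae.aggregateFunction) =>
    ae.copy(isDistinct = false)
}
```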

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala

Lines changed: 5 additions & 1 deletion

```diff
@@ -39,11 +39,15 @@ class EliminateDistinctSuite extends PlanTest {
       Min(_),
       BitAndAgg(_),
       BitOrAgg(_),
+      First(_, ignoreNulls = true),
+      First(_, ignoreNulls = false),
+      Last(_, ignoreNulls = true),
+      Last(_, ignoreNulls = false),
       CollectSet(_: Expression)
     ).foreach {
       aggBuilder =>
         val agg = aggBuilder('a)
-        test(s"Eliminate Distinct in ${agg.prettyName}") {
+        test(s"Eliminate Distinct in $agg") {
           val query = testRelation
             .select(agg.toAggregateExpression(isDistinct = true).as('result))
             .analyze
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 23 additions & 0 deletions

```diff
@@ -840,6 +840,29 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

+  test("SPARK-32940: aggregate: push filters through first, last and collect") {
+    Seq(
+      first(_: Expression),
+      last(_: Expression),
+      collectList(_: Expression),
+      collectSet(_: Expression)
+    ).foreach { agg =>
+      val originalQuery = testRelation
+        .groupBy('a)(agg('b))
+        .where('a > 42)
+        .analyze
+
+      val optimized = Optimize.execute(originalQuery)
+
+      val correctAnswer = testRelation
+        .where('a > 42)
+        .groupBy('a)(agg('b))
+        .analyze
+
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   test("union") {
     val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)

```
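At the user level, the new test corresponds to plans like the following - a hypothetical `DataFrame` sketch assuming an active `SparkSession` named `spark`; the pushdown itself is done by `PushPredicateThroughNonJoin`:

```scala
import org.apache.spark.sql.functions.first
import spark.implicits._ // assumes an active SparkSession named spark

val df = spark.range(100).selectExpr("id % 10 AS a", "id AS b")

// With first() now deterministic, the filter on the grouping key a can be
// pushed below the aggregate, so this...
df.groupBy($"a").agg(first($"b")).where($"a" > 5)
// ...is planned like this, filtering before aggregating:
df.where($"a" > 5).groupBy($"a").agg(first($"b"))
```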

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala

Lines changed: 3 additions & 4 deletions

```diff
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -183,10 +182,10 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper {
   }

   test("unwrap cast should skip when expression is non-deterministic or foldable") {
-    Seq(positiveInt, negativeInt).foreach(v => {
-      val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v
+    Seq(positiveLong, negativeLong).foreach (v => {
+      val e = Cast(Rand(0), LongType) <=> v
       assertEquivalent(e, e, evaluate = false)
-      val e2 = Cast(Literal(30.toShort), IntegerType) >= v
+      val e2 = Cast(Literal(30), LongType) >= v
       assertEquivalent(e2, e2, evaluate = false)
     })
   }
```
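The swap is needed because `First` can no longer stand in for "a non-deterministic expression" in this test, while `Rand` still can. A REPL-style sketch of the property the test relies on, using the constructors as they appear in the diff above:

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.First

First(Literal(1), ignoreNulls = true).deterministic // now true
Rand(0).deterministic                               // still false
```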

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 20 additions & 9 deletions

```diff
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, Reparti
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate._
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, FunctionsCommand}
 import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, LogicalRelation}
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -2790,15 +2791,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper {
   }

   test("Non-deterministic aggregate functions should not be deduplicated") {
-    val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
-    val df = sql(query)
-    val physical = df.queryExecution.sparkPlan
-    val aggregateExpressions = physical.collectFirst {
-      case agg : HashAggregateExec => agg.aggregateExpressions
-      case agg : SortAggregateExec => agg.aggregateExpressions
+    withUserDefinedFunction("sumND" -> true) {
+      spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
+        def zero: Long = 0L
+        def reduce(b: Long, a: Long): Long = b + a
+        def merge(b1: Long, b2: Long): Long = b1 + b2
+        def finish(r: Long): Long = r
+        def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+        def outputEncoder: Encoder[Long] = Encoders.scalaLong
+      }).asNondeterministic())
+
+      val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
+      val df = sql(query)
+      val physical = df.queryExecution.sparkPlan
+      val aggregateExpressions = physical.collectFirst {
+        case agg: BaseAggregateExec => agg.aggregateExpressions
+      }
+      assert(aggregateExpressions.isDefined)
+      assert(aggregateExpressions.get.size == 2)
     }
-    assert (aggregateExpressions.isDefined)
-    assert (aggregateExpressions.get.size == 2)
   }

   test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
```
