Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cf6c7e9
Bitwise operations are commutative
tanelk Sep 17, 2020
317f313
Experiment with SQLQueryTestSuite
tanelk Sep 19, 2020
281ed68
Optimizer rules
tanelk Sep 19, 2020
326ec05
Collect, first and last should be deterministic aggregate functions
tanelk Sep 19, 2020
ab2901e
Fix test, that required non-deterministic expression
tanelk Sep 19, 2020
e8badd1
Revert "Optimizer rules"
tanelk Sep 19, 2020
f1e6711
Revert "Experiment with SQLQueryTestSuite"
tanelk Sep 19, 2020
c17f2ef
Revert "Bitwise operations are commutative"
tanelk Sep 19, 2020
3e98b13
Merge branch 'master' into SPARK-32940
tanelk Sep 19, 2020
b9fd2f1
Fix test, that required non-deterministic expression
tanelk Sep 19, 2020
9898c56
Fix test, that required non-deterministic aggregator
tanelk Sep 20, 2020
b0919a2
Improve docstrings
tanelk Sep 20, 2020
a080b53
Merge branch 'master' into SPARK-32940_deterministic_agg
tanelk Dec 30, 2020
dc6e7c0
Fix merge
tanelk Dec 30, 2020
577ec60
Merge branch 'master' into SPARK-32940_deterministic_agg
tanelk Mar 20, 2021
4229251
Revert doc changes
tanelk Mar 20, 2021
89823e4
Merge branch 'master' into SPARK-32940_deterministic_agg
tanelk Mar 21, 2021
13c92f3
Remove distinct from first and last
tanelk Mar 21, 2021
3f945f0
Remove the deterministic flag
tanelk Apr 13, 2021
da196d0
merge
tanelk Apr 13, 2021
0f4d0af
Merge remote-tracking branch 'upstream/master' into SPARK-32940
tanelk Jun 24, 2021
e5e9a04
Use withUserDefinedFunction
tanelk Jun 24, 2021
56fbf15
Merge branch 'master' into SPARK-32940
tanelk Sep 30, 2021
0d40311
Merge remote-tracking branch 'upstream/master' into SPARK-32940
tanelk Oct 19, 2021
e4ed57c
Address comments
tanelk Nov 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ package object dsl {
BitOrAgg(e).toAggregateExpression(isDistinct = false, filter = filter)
/**
 * DSL helper: wraps `e` in a [[BitXorAgg]] and converts it to a non-distinct aggregate
 * expression, forwarding the optional `filter` unchanged.
 */
def bitXor(e: Expression, filter: Option[Expression] = None): Expression = {
  val xorAggregate = BitXorAgg(e)
  xorAggregate.toAggregateExpression(isDistinct = false, filter = filter)
}
/**
 * DSL helper: wraps `e` in a [[CollectList]] and converts it to a non-distinct aggregate
 * expression, forwarding the optional `filter` unchanged.
 */
def collectList(e: Expression, filter: Option[Expression] = None): Expression = {
  val listAggregate = CollectList(e)
  listAggregate.toAggregateExpression(isDistinct = false, filter = filter)
}
/**
 * DSL helper: wraps `e` in a [[CollectSet]] and converts it to a non-distinct aggregate
 * expression, forwarding the optional `filter` unchanged.
 */
def collectSet(e: Expression, filter: Option[Expression] = None): Expression = {
  val setAggregate = CollectSet(e)
  setAggregate.toAggregateExpression(isDistinct = false, filter = filter)
}

// DSL helper: wraps `e` in an `Upper` expression.
def upper(e: Expression): Expression = Upper(e)
// DSL helper: wraps `e` in a `Lower` expression.
def lower(e: Expression): Expression = Lower(e)
// DSL helper: builds a `Coalesce` expression over all of the given arguments.
def coalesce(args: Expression*): Expression = Coalesce(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ case class First(child: Expression, ignoreNulls: Boolean)

override def nullable: Boolean = true

// First is not a deterministic function.
override lazy val deterministic: Boolean = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you may need to update the note above to say something like: "The function can be non-deterministic because its results depend on the order of input rows, which is usually non-deterministic after a shuffle." You might also need to update functions.py, functions.R and functions.scala.


// Return data type.
override def dataType: DataType = child.dataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ case class Last(child: Expression, ignoreNulls: Boolean)

override def nullable: Boolean = true

// Last is not a deterministic function.
override lazy val deterministic: Boolean = false

// Return data type.
override def dataType: DataType = child.dataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper

override def dataType: DataType = ArrayType(child.dataType, false)

// Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
// actual order of input rows.
override lazy val deterministic: Boolean = false

override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))

protected def convertToBufferElement(value: Any): Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ object EliminateDistinct extends Rule[LogicalPlan] {
case _: BitAndAgg => true
case _: BitOrAgg => true
case _: CollectSet => true
case _: First => true
case _: Last => true
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@ class EliminateDistinctSuite extends PlanTest {
Min(_),
BitAndAgg(_),
BitOrAgg(_),
First(_, ignoreNulls = true),
First(_, ignoreNulls = false),
Last(_, ignoreNulls = true),
Last(_, ignoreNulls = false),
CollectSet(_: Expression)
).foreach {
aggBuilder =>
val agg = aggBuilder('a)
test(s"Eliminate Distinct in ${agg.prettyName}") {
test(s"Eliminate Distinct in $agg") {
val query = testRelation
.select(agg.toAggregateExpression(isDistinct = true).as('result))
.analyze
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,29 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

// For each of first, last, collectList and collectSet, a filter on the grouping column `a`
// placed above the aggregate is expected to end up below it after optimization.
test("SPARK-32940: aggregate: push filters through first, last and collect") {
  val aggregateBuilders = Seq(
    first(_: Expression),
    last(_: Expression),
    collectList(_: Expression),
    collectSet(_: Expression)
  )

  for (buildAggregate <- aggregateBuilders) {
    // Input plan: the filter sits on top of the aggregate.
    val filterAboveAggregate = testRelation
      .groupBy('a)(buildAggregate('b))
      .where('a > 42)
      .analyze

    val optimizedPlan = Optimize.execute(filterAboveAggregate)

    // Expected plan: the same filter applied before grouping.
    val filterBelowAggregate = testRelation
      .where('a > 42)
      .groupBy('a)(buildAggregate('b))
      .analyze

    comparePlans(optimizedPlan, filterBelowAggregate)
  }
}

test("union") {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -183,10 +182,10 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
}

test("unwrap cast should skip when expression is non-deterministic or foldable") {
Seq(positiveInt, negativeInt).foreach(v => {
val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v
Seq(positiveLong, negativeLong).foreach (v => {
val e = Cast(Rand(0), LongType) <=> v
assertEquivalent(e, e, evaluate = false)
val e2 = Cast(Literal(30.toShort), IntegerType) >= v
val e2 = Cast(Literal(30), LongType) >= v
assertEquivalent(e2, e2, evaluate = false)
})
}
Expand Down
29 changes: 20 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, Reparti
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, FunctionsCommand}
import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, LogicalRelation}
Expand All @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -2790,15 +2791,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}

test("Non-deterministic aggregate functions should not be deduplicated") {
val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a"
val df = sql(query)
val physical = df.queryExecution.sparkPlan
val aggregateExpressions = physical.collectFirst {
case agg : HashAggregateExec => agg.aggregateExpressions
case agg : SortAggregateExec => agg.aggregateExpressions
withUserDefinedFunction("sumND" -> true) {
spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
def zero: Long = 0L
def reduce(b: Long, a: Long): Long = b + a
def merge(b1: Long, b2: Long): Long = b1 + b2
def finish(r: Long): Long = r
def bufferEncoder: Encoder[Long] = Encoders.scalaLong
def outputEncoder: Encoder[Long] = Encoders.scalaLong
}).asNondeterministic())

val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
val df = sql(query)
val physical = df.queryExecution.sparkPlan
val aggregateExpressions = physical.collectFirst {
case agg: BaseAggregateExec => agg.aggregateExpressions
}
assert(aggregateExpressions.isDefined)
assert(aggregateExpressions.get.size == 2)
}
assert (aggregateExpressions.isDefined)
assert (aggregateExpressions.get.size == 2)
}

test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
Expand Down