@@ -2786,23 +2786,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
27862786 }
27872787
27882788 test(" Non-deterministic aggregate functions should not be deduplicated" ) {
2789- spark.udf.register(" countND" , udaf(new Aggregator [Long , Long , Long ] {
2790- def zero : Long = 0L
2791- def reduce (b : Long , a : Long ): Long = b + a
2792- def merge (b1 : Long , b2 : Long ): Long = b1 + b2
2793- def finish (r : Long ): Long = r
2794- def bufferEncoder : Encoder [Long ] = Encoders .scalaLong
2795- def outputEncoder : Encoder [Long ] = Encoders .scalaLong
2796- }).asNondeterministic())
2797-
2798- val query = " SELECT a, countND(b), countND(b) + 1 FROM testData2 GROUP BY a"
2799- val df = sql(query)
2800- val physical = df.queryExecution.sparkPlan
2801- val aggregateExpressions = physical.collectFirst {
2802- case agg : BaseAggregateExec => agg.aggregateExpressions
2789+ withUserDefinedFunction(" sumND" -> true ) {
2790+ spark.udf.register(" sumND" , udaf(new Aggregator [Long , Long , Long ] {
2791+ def zero : Long = 0L
2792+ def reduce (b : Long , a : Long ): Long = b + a
2793+ def merge (b1 : Long , b2 : Long ): Long = b1 + b2
2794+ def finish (r : Long ): Long = r
2795+ def bufferEncoder : Encoder [Long ] = Encoders .scalaLong
2796+ def outputEncoder : Encoder [Long ] = Encoders .scalaLong
2797+ }).asNondeterministic())
2798+
2799+ val query = " SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
2800+ val df = sql(query)
2801+ val physical = df.queryExecution.sparkPlan
2802+ val aggregateExpressions = physical.collectFirst {
2803+ case agg : BaseAggregateExec => agg.aggregateExpressions
2804+ }
2805+ assert(aggregateExpressions.isDefined)
2806+ assert(aggregateExpressions.get.size == 2 )
28032807 }
2804- assert (aggregateExpressions.isDefined)
2805- assert (aggregateExpressions.get.size == 2 )
28062808 }
28072809
28082810 test(" SPARK-22356: overlapped columns between data and partition schema in data source tables" ) {
0 commit comments