Skip to content

Commit e5e9a04

Browse files
committed
Use withUserDefinedFunction
1 parent 0f4d0af commit e5e9a04

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2786,23 +2786,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
27862786
}
27872787

27882788
test("Non-deterministic aggregate functions should not be deduplicated") {
2789-
spark.udf.register("countND", udaf(new Aggregator[Long, Long, Long] {
2790-
def zero: Long = 0L
2791-
def reduce(b: Long, a: Long): Long = b + a
2792-
def merge(b1: Long, b2: Long): Long = b1 + b2
2793-
def finish(r: Long): Long = r
2794-
def bufferEncoder: Encoder[Long] = Encoders.scalaLong
2795-
def outputEncoder: Encoder[Long] = Encoders.scalaLong
2796-
}).asNondeterministic())
2797-
2798-
val query = "SELECT a, countND(b), countND(b) + 1 FROM testData2 GROUP BY a"
2799-
val df = sql(query)
2800-
val physical = df.queryExecution.sparkPlan
2801-
val aggregateExpressions = physical.collectFirst {
2802-
case agg : BaseAggregateExec => agg.aggregateExpressions
2789+
withUserDefinedFunction("sumND" -> true) {
2790+
spark.udf.register("sumND", udaf(new Aggregator[Long, Long, Long] {
2791+
def zero: Long = 0L
2792+
def reduce(b: Long, a: Long): Long = b + a
2793+
def merge(b1: Long, b2: Long): Long = b1 + b2
2794+
def finish(r: Long): Long = r
2795+
def bufferEncoder: Encoder[Long] = Encoders.scalaLong
2796+
def outputEncoder: Encoder[Long] = Encoders.scalaLong
2797+
}).asNondeterministic())
2798+
2799+
val query = "SELECT a, sumND(b), sumND(b) + 1 FROM testData2 GROUP BY a"
2800+
val df = sql(query)
2801+
val physical = df.queryExecution.sparkPlan
2802+
val aggregateExpressions = physical.collectFirst {
2803+
case agg: BaseAggregateExec => agg.aggregateExpressions
2804+
}
2805+
assert(aggregateExpressions.isDefined)
2806+
assert(aggregateExpressions.get.size == 2)
28032807
}
2804-
assert (aggregateExpressions.isDefined)
2805-
assert (aggregateExpressions.get.size == 2)
28062808
}
28072809

28082810
test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {

0 commit comments

Comments (0)