diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index bb05c76cfee6d..a4ff09596ad8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -28,10 +28,11 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
-import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction}
+import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -101,9 +102,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.2.0
    */
   def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
-    def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
-    functionRegistry.createOrReplaceTempFunction(name, builder)
-    udf
+    udf match {
+      case udaf: UserDefinedAggregator[_, _, _] =>
+        def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
+        functionRegistry.createOrReplaceTempFunction(name, builder)
+        udf
+      case _ =>
+        def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
+        functionRegistry.createOrReplaceTempFunction(name, builder)
+        udf
+    }
   }
 
   // scalastyle:off line.size.limit
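For reference, a minimal sketch (not part of this patch) of the new register(...) dispatch in action, e.g. from spark-shell. The aggregator `LongSumAgg`, the view name `t`, and the column name `v` are made up for illustration; registering a udaf-wrapped Aggregator takes the UserDefinedAggregator branch above, so the SQL call plans a ScalaAggregator rather than a ScalaUDF.

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf

    // A trivial typed aggregator: sums Long inputs with a Long buffer.
    object LongSumAgg extends Aggregator[Long, Long, Long] {
      def zero: Long = 0L
      def reduce(b: Long, a: Long): Long = b + a
      def merge(b1: Long, b2: Long): Long = b1 + b2
      def finish(r: Long): Long = r
      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
      def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    import spark.implicits._
    spark.udf.register("longSum", udaf(LongSumAgg))   // hits the UserDefinedAggregator case above
    Seq(1L, 2L, 3L).toDF("v").createOrReplaceTempView("t")
    spark.sql("SELECT longSum(v) FROM t").show()      // 6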
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 100486fa9850f..dfae5c07e0373 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _}
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._
 
 /**
@@ -450,3 +454,63 @@ case class ScalaUDAF(
 
   override def nodeName: String = udaf.getClass.getSimpleName
 }
+
+case class ScalaAggregator[IN, BUF, OUT](
+    children: Seq[Expression],
+    agg: Aggregator[IN, BUF, OUT],
+    inputEncoderNR: ExpressionEncoder[IN],
+    nullable: Boolean = true,
+    isDeterministic: Boolean = true,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[BUF]
+  with NonSQLExpression
+  with UserDefinedExpression
+  with ImplicitCastInputTypes
+  with Logging {
+
+  private[this] lazy val inputEncoder = inputEncoderNR.resolveAndBind()
+  private[this] lazy val bufferEncoder =
+    agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
+  private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+
+  def dataType: DataType = outputEncoder.objSerializer.dataType
+
+  def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
+
+  override lazy val deterministic: Boolean = isDeterministic
+
+  def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  private[this] lazy val inputProjection = UnsafeProjection.create(children)
+
+  def createAggregationBuffer(): BUF = agg.zero
+
+  def update(buffer: BUF, input: InternalRow): BUF =
+    agg.reduce(buffer, inputEncoder.fromRow(inputProjection(input)))
+
+  def merge(buffer: BUF, input: BUF): BUF = agg.merge(buffer, input)
+
+  def eval(buffer: BUF): Any = {
+    val row = outputEncoder.toRow(agg.finish(buffer))
+    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+  }
+
+  private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
+
+  def serialize(agg: BUF): Array[Byte] =
+    bufferEncoder.toRow(agg).asInstanceOf[UnsafeRow].getBytes()
+
+  def deserialize(storageFormat: Array[Byte]): BUF = {
+    bufferRow.pointTo(storageFormat, storageFormat.length)
+    bufferEncoder.fromRow(bufferRow)
+  }
+
+  override def toString: String = s"""${nodeName}(${children.mkString(",")})"""
+
+  override def nodeName: String = agg.getClass.getSimpleName
+}
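ScalaAggregator is the bridge between the typed Aggregator API and TypedImperativeAggregate: update() decodes each input row through the input encoder, eval() encodes the finished result, and serialize()/deserialize() shuttle the typed buffer through an UnsafeRow. The following standalone sketch (not part of the patch) replays that buffer round trip for a (Double, Long) buffer, assuming the same encoder API the patch itself uses (resolveAndBind, toRow, fromRow, namedExpressions):

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    val bufferEncoder = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
      .asInstanceOf[ExpressionEncoder[(Double, Long)]]
      .resolveAndBind()

    val buf = (42.0, 7L)
    // serialize(): typed buffer -> UnsafeRow -> bytes
    val bytes = bufferEncoder.toRow(buf).asInstanceOf[UnsafeRow].getBytes()
    // deserialize(): bytes -> UnsafeRow -> typed buffer
    val row = new UnsafeRow(bufferEncoder.namedExpressions.length)
    row.pointTo(bytes, bytes.length)
    assert(bufferEncoder.fromRow(row) == buf)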
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 0c956ecbf936e..85b2cd379ba24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -17,10 +17,15 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.Column
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.{Experimental, Stable}
+import org.apache.spark.sql.{Column, Encoder}
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
+import org.apache.spark.sql.execution.aggregate.ScalaAggregator
 import org.apache.spark.sql.types.{AnyDataType, DataType}
 
 /**
@@ -136,3 +141,42 @@ private[sql] case class SparkUserDefinedFunction(
     }
   }
 }
+
+private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
+    aggregator: Aggregator[IN, BUF, OUT],
+    inputEncoder: Encoder[IN],
+    name: Option[String] = None,
+    nullable: Boolean = true,
+    deterministic: Boolean = true) extends UserDefinedFunction {
+
+  @scala.annotation.varargs
+  def apply(exprs: Column*): Column = {
+    Column(AggregateExpression(scalaAggregator(exprs.map(_.expr)), Complete, isDistinct = false))
+  }
+
+  // This is also used by udf.register(...) when it detects a UserDefinedAggregator
+  def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
+    val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
+    ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
+  }
+
+  override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
+    copy(name = Option(name))
+  }
+
+  override def asNonNullable(): UserDefinedAggregator[IN, BUF, OUT] = {
+    if (!nullable) {
+      this
+    } else {
+      copy(nullable = false)
+    }
+  }
+
+  override def asNondeterministic(): UserDefinedAggregator[IN, BUF, OUT] = {
+    if (!deterministic) {
+      this
+    } else {
+      copy(deterministic = false)
+    }
+  }
+}
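UserDefinedAggregator mirrors the builder-style API of SparkUserDefinedFunction. A hedged usage sketch (not part of the patch), assuming `spark.implicits._` is in scope, `df` is an existing DataFrame with `key` and `value` columns, and `agg` is any Aggregator such as the ones defined in the test suites below:

    import org.apache.spark.sql.functions.udaf

    val baseUdf = udaf(agg)            // returns a UserDefinedAggregator under the hood
    val tunedUdf = baseUdf
      .withName("myAgg")               // copy(name = Some("myAgg"))
      .asNonNullable()                 // copy(nullable = false)
      .asNondeterministic()            // copy(deterministic = false)

    // apply(...) wraps the ScalaAggregator in an AggregateExpression (Complete mode),
    // so the result can be used anywhere an aggregate Column is accepted.
    val aggregated = df.groupBy($"key").agg(tunedUdf($"value"))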
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9911972d0f1ba..0ebbf98a77ec3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -32,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
 import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
+import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-
 /**
  * Commonly used functions available for DataFrame operations. Using functions defined here provides
  * a little bit more compile-time safety to make sure the function exists.
@@ -4231,6 +4230,67 @@
   // Scala UDF functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
+   * so that it may be used with untyped Data Frames.
+   * {{{
+   *   val agg = // Aggregator[IN, BUF, OUT]
+   *
+   *   // declare a UDF based on agg
+   *   val aggUDF = udaf(agg)
+   *   val aggData = df.agg(aggUDF($"colname"))
+   *
+   *   // register agg as a named function
+   *   spark.udf.register("myAggName", udaf(agg))
+   * }}}
+   *
+   * @tparam IN the aggregator input type
+   * @tparam BUF the aggregating buffer type
+   * @tparam OUT the finalized output type
+   *
+   * @param agg the typed Aggregator
+   *
+   * @return a UserDefinedFunction that can be used as an aggregating expression.
+   *
+   * @note The input encoder is inferred from the input type IN.
+   */
+  def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = {
+    udaf(agg, ExpressionEncoder[IN]())
+  }
+
+  /**
+   * Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
+   * so that it may be used with untyped Data Frames.
+   * {{{
+   *   Aggregator<IN, BUF, OUT> agg = // custom Aggregator
+   *   Encoder<IN> enc = // input encoder
+   *
+   *   // declare a UDF based on agg
+   *   UserDefinedFunction aggUDF = udaf(agg, enc)
+   *   DataFrame aggData = df.agg(aggUDF($"colname"))
+   *
+   *   // register agg as a named function
+   *   spark.udf.register("myAggName", udaf(agg, enc))
+   * }}}
+   *
+   * @tparam IN the aggregator input type
+   * @tparam BUF the aggregating buffer type
+   * @tparam OUT the finalized output type
+   *
+   * @param agg the typed Aggregator
+   * @param inputEncoder a specific input encoder to use
+   *
+   * @return a UserDefinedFunction that can be used as an aggregating expression
+   *
+   * @note This overloading takes an explicit input encoder, to support UDAF
+   * declarations in Java.
+   */
+  def udaf[IN, BUF, OUT](
+      agg: Aggregator[IN, BUF, OUT],
+      inputEncoder: Encoder[IN]): UserDefinedFunction = {
+    UserDefinedAggregator(agg, inputEncoder)
+  }
+
   /**
    * Defines a Scala closure of 0 arguments as user-defined function (UDF).
    * The data types are automatically inferred based on the Scala closure's
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 696b056a682b3..2e37879ea1658 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.Matchers.the
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
 import org.apache.spark.sql.execution.exchange.Exchange
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -412,6 +412,42 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSparkSession {
       Row("b", 2, 4, 8)))
   }
 
+  test("window function with aggregator") {
+    val agg = udaf(new Aggregator[(Long, Long), Long, Long] {
+      def zero: Long = 0L
+      def reduce(b: Long, a: (Long, Long)): Long = b + (a._1 * a._2)
+      def merge(b1: Long, b2: Long): Long = b1 + b2
+      def finish(r: Long): Long = r
+      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+      def outputEncoder: Encoder[Long] = Encoders.scalaLong
+    })
+
+    val df = Seq(
+      ("a", 1, 1),
+      ("a", 1, 5),
+      ("a", 2, 10),
+      ("a", 2, -1),
+      ("b", 4, 7),
+      ("b", 3, 8),
+      ("b", 2, 4))
+      .toDF("key", "a", "b")
+    val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L)
+    checkAnswer(
+      df.select(
+        $"key",
+        $"a",
+        $"b",
+        agg($"a", $"b").over(window)),
+      Seq(
+        Row("a", 1, 1, 6),
+        Row("a", 1, 5, 6),
+        Row("a", 2, 10, 24),
+        Row("a", 2, -1, 24),
+        Row("b", 4, 7, 60),
+        Row("b", 3, 8, 32),
+        Row("b", 2, 4, 8)))
+  }
+
   test("null inputs") {
     val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2))
       .toDF("key", "value")
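The window test above builds the Aggregator inline and lets the one-argument udaf infer the input encoder from a TypeTag. As a complement, a hedged sketch (not part of the patch) of the explicit-encoder overload, which is effectively what a Java caller would use; `df` and the column name `name` are assumptions for illustration:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf

    // Sums the lengths of string inputs; nulls contribute zero.
    object StrLenSum extends Aggregator[String, Long, Long] {
      def zero: Long = 0L
      def reduce(b: Long, a: String): Long = b + (if (a != null) a.length else 0)
      def merge(b1: Long, b2: Long): Long = b1 + b2
      def finish(r: Long): Long = r
      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
      def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    // Passing Encoders.STRING explicitly mirrors the Java-facing overload,
    // which cannot rely on a Scala TypeTag to infer the input encoder.
    val lenSum = udaf(StrLenSum, Encoders.STRING)
    val total = df.agg(lenSum($"name"))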
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
new file mode 100644
index 0000000000000..e6856a58b0ea9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.lang.{Double => jlDouble, Integer => jlInt, Long => jlLong}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import test.org.apache.spark.sql.MyDoubleAvg
+import test.org.apache.spark.sql.MyDoubleSum
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.expressions.{Aggregator}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+class MyDoubleAvgAggBase extends Aggregator[jlDouble, (Double, Long), jlDouble] {
+  def zero: (Double, Long) = (0.0, 0L)
+  def reduce(b: (Double, Long), a: jlDouble): (Double, Long) = {
+    if (a != null) (b._1 + a, b._2 + 1L) else b
+  }
+  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  def finish(r: (Double, Long)): jlDouble =
+    if (r._2 > 0L) 100.0 + (r._1 / r._2.toDouble) else null
+  def bufferEncoder: Encoder[(Double, Long)] =
+    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
+  def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE
+}
+
+object MyDoubleAvgAgg extends MyDoubleAvgAggBase
+object MyDoubleSumAgg extends MyDoubleAvgAggBase {
+  override def finish(r: (Double, Long)): jlDouble = if (r._2 > 0L) r._1 else null
+}
+
+object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
+  def zero: Long = 0L
+  def reduce(b: Long, a: (jlLong, jlLong)): Long = {
+    if ((a._1 != null) && (a._2 != null)) b + (a._1 * a._2) else b
+  }
+  def merge(b1: Long, b2: Long): Long = b1 + b2
+  def finish(r: Long): jlLong = r
+  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+  def outputEncoder: Encoder[jlLong] = Encoders.LONG
+}
+
+@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
+case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
+
+class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
+  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]
+
+  override def typeName: String = "count-ser-de"
+
+  private[spark] override def asNullable: CountSerDeUDT = this
+
+  def sqlType: DataType = StructType(
+    StructField("nSer", IntegerType, false) ::
+    StructField("nDeSer", IntegerType, false) ::
+    StructField("sum", IntegerType, false) ::
+    Nil)
+
+  def serialize(sql: CountSerDeSQL): Any = {
+    val row = new GenericInternalRow(3)
+    row.setInt(0, 1 + sql.nSer)
+    row.setInt(1, sql.nDeSer)
+    row.setInt(2, sql.sum)
+    row
+  }
+
+  def deserialize(any: Any): CountSerDeSQL = any match {
+    case row: InternalRow if (row.numFields == 3) =>
+      CountSerDeSQL(row.getInt(0), 1 + row.getInt(1), row.getInt(2))
+    case u => throw new Exception(s"failed to deserialize: $u")
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case _: CountSerDeUDT => true
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
+}
+
+case object CountSerDeUDT extends CountSerDeUDT
+
+object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
+  def zero: CountSerDeSQL = CountSerDeSQL(0, 0, 0)
+  def reduce(b: CountSerDeSQL, a: Int): CountSerDeSQL = b.copy(sum = b.sum + a)
+  def merge(b1: CountSerDeSQL, b2: CountSerDeSQL): CountSerDeSQL =
+    CountSerDeSQL(b1.nSer + b2.nSer, b1.nDeSer + b2.nDeSer, b1.sum + b2.sum)
+  def finish(r: CountSerDeSQL): CountSerDeSQL = r
+  def bufferEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
+  def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
+}
+
+abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val data1 = Seq[(Integer, Integer)](
+      (1, 10),
+      (null, -60),
+      (1, 20),
+      (1, 30),
+      (2, 0),
+      (null, -10),
+      (2, -1),
+      (2, null),
+      (2, null),
+      (null, 100),
+      (3, null),
+      (null, null),
+      (3, null)).toDF("key", "value")
+    data1.write.saveAsTable("agg1")
+
+    val data2 = Seq[(Integer, Integer, Integer)](
+      (1, 10, -10),
+      (null, -60, 60),
+      (1, 30, -30),
+      (1, 30, 30),
+      (2, 1, 1),
+      (null, -10, 10),
+      (2, -1, null),
+      (2, 1, 1),
+      (2, null, 1),
+      (null, 100, -10),
+      (3, null, 3),
+      (null, null, null),
+      (3, null, null)).toDF("key", "value1", "value2")
+    data2.write.saveAsTable("agg2")
+
+    val data3 = Seq[(Seq[Integer], Integer, Integer)](
+      (Seq[Integer](1, 1), 10, -10),
+      (Seq[Integer](null), -60, 60),
+      (Seq[Integer](1, 1), 30, -30),
+      (Seq[Integer](1), 30, 30),
+      (Seq[Integer](2), 1, 1),
+      (null, -10, 10),
+      (Seq[Integer](2, 3), -1, null),
+      (Seq[Integer](2, 3), 1, 1),
+      (Seq[Integer](2, 3, 4), null, 1),
+      (Seq[Integer](null), 100, -10),
+      (Seq[Integer](3), null, 3),
+      (null, null, null),
+      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    data3.write.saveAsTable("agg3")
+
+    val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
+    data4.write.saveAsTable("agg4")
+
+    val emptyDF = spark.createDataFrame(
+      sparkContext.emptyRDD[Row],
+      StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
+    emptyDF.createOrReplaceTempView("emptyTable")
+
+    // Register UDAs
+    spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
+    spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
+    spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      spark.sql("DROP TABLE IF EXISTS agg1")
+      spark.sql("DROP TABLE IF EXISTS agg2")
+      spark.sql("DROP TABLE IF EXISTS agg3")
+      spark.sql("DROP TABLE IF EXISTS agg4")
+      spark.catalog.dropTempView("emptyTable")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  test("aggregators") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  mydoubleavg(value),
+          |  avg(value - key),
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+        Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+        Row(3, null, null, null, null, null) ::
+        Row(null, null, 110.0, null, null, 10.0) :: Nil)
+  }
+
+  test("non-deterministic children expressions of aggregator") {
+    val e = intercept[AnalysisException] {
+      spark.sql(
+        """
+          |SELECT mydoublesum(value + 1.5 * key + rand())
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin)
+    }.getMessage
+    assert(Seq("nondeterministic expression",
+      "should not appear in the arguments of an aggregate function").forall(e.contains))
+  }
+
+  test("interpreted aggregate function") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value) FROM agg1
+        """.stripMargin),
+      Row(89.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+
+  test("interpreted and expression-based aggregation functions") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1, 20.0) ::
+        Row(-1.0, 2, -0.5) ::
+        Row(null, 3, null) ::
+        Row(30.0, null, 10.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoublesum(value + 1.5 * key),
+          |  avg(value - key),
+          |  key,
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(64.5, 19.0, 1, 55.5, 20.0) ::
+        Row(5.0, -2.5, 2, -7.0, -0.5) ::
+        Row(null, null, 3, null, null) ::
+        Row(null, null, null, null, 10.0) :: Nil)
+  }
+
+  test("single distinct column set") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoubleavg(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  key,
+          |  mydoubleavg(value1 - 1),
+          |  mydoubleavg(distinct value1) * 0.1,
+          |  avg(value1 + value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+        Row(null, null, 3.0, 3, null, null, null) ::
+        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoubleavg(distinct value1),
+          |  mydoublesum(value2),
+          |  mydoublesum(distinct value1),
+          |  mydoubleavg(distinct value1),
+          |  mydoubleavg(value1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+        Row(3, null, 3.0, null, null, null) ::
+        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+  }
+
+  test("multiple distinct multiple columns sets") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1),
+          |  sum(distinct value1),
+          |  count(distinct value2),
+          |  sum(distinct value2),
+          |  count(distinct value1, value2),
+          |  longProductSum(distinct value1, value2),
+          |  count(value1),
+          |  sum(value1),
+          |  count(value2),
+          |  sum(value2),
+          |  longProductSum(value1, value2),
+          |  count(*),
+          |  count(1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
+  }
+
+  test("verify aggregator ser/de behavior") {
+    val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
+    val agg = udaf(CountSerDeAgg)
+    checkAnswer(
+      data.agg(agg($"value1")),
+      Row(CountSerDeSQL(4, 4, 5050)) :: Nil)
+  }
+
test("verify type casting failure") { + assertThrows[org.apache.spark.sql.AnalysisException] { + spark.sql( + """ + |SELECT mydoublesum(boolvalues) FROM agg4 + """.stripMargin) + } + } +} + +class HashUDAQuerySuite extends UDAQuerySuite + +class HashUDAQueryWithControlledFallbackSuite extends UDAQuerySuite { + + override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + super.checkAnswer(actual, expectedAnswer) + Seq("true", "false").foreach { enableTwoLevelMaps => + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" -> + enableTwoLevelMaps) { + (1 to 3).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> + s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { + QueryTest.getErrorMessageInCheckAnswer(actual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using HashAggregate with + |controlled fallback (it falls back to bytes to bytes map once it has processed + |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has + |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => // Success + } + } + } + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +}