package com.github.mrpowers.spark.daria.sql.udafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}

/**
 * User-defined aggregate function that concatenates an array column across rows,
 * producing a single array containing every element from every non-null input array.
 *
 * Null input arrays are skipped; an all-null (or empty) group yields an empty array.
 *
 * @param elementSchema the Spark SQL type of the array's elements (e.g. StringType)
 * @param nullable      whether the input/buffer array column is nullable (default true)
 */
class ArraySum(elementSchema: DataType,
               nullable: Boolean = true) extends UserDefinedAggregateFunction {

  // Build the schema directly from the constructor parameter rather than calling the
  // overridable `dataType` method during construction: virtual dispatch while `this`
  // is still initializing is an initialization-order hazard if this class is ever
  // subclassed. The resulting value is identical (ArrayType(elementSchema)).
  private val schema = StructType(List(StructField("value", ArrayType(elementSchema), nullable)))

  // Input and buffer share the same single-column array schema.
  override def inputSchema: StructType = schema

  override def bufferSchema: StructType = schema

  override def dataType: DataType = ArrayType(elementSchema)

  // NOTE(review): result element order depends on row-processing order, which Spark
  // does not guarantee across partitions — strictly this aggregation is order-sensitive.
  // Kept as `true` to preserve existing behavior; confirm callers don't rely on ordering.
  override def deterministic: Boolean = true

  // Start each group with an empty sequence so `merge` never sees a null buffer value.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Any]
  }

  // Append the incoming row's array to the buffer; null arrays are ignored.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val value = input.getAs[Seq[Any]](0)
    if (value != null) {
      buffer(0) = buffer.getAs[Seq[Any]](0) ++ value
    }
  }

  // Combine two partial buffers; both are guaranteed non-null by `initialize`.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Seq[Any]](0) ++ buffer2.getAs[Seq[Any]](0)
  }

  // The accumulated buffer sequence is the final result.
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Seq[Any]](0)
  }
}
package com.github.mrpowers.spark.daria.sql.udafs

import utest._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import com.github.mrpowers.spark.daria.sql.SparkSessionTestWrapper
import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Verifies that the ArraySum UDAF flattens an array column across rows,
 * dropping null input arrays along the way.
 */
object ArraySumTest extends TestSuite with DataFrameComparer with SparkSessionTestWrapper {

  val tests = Tests {

    'arraySum - {
      "concatenates rows of arrays" - {

        val concatArrays = new ArraySum(StringType)

        // Three rows: two real arrays with a null in between — the null must be skipped.
        val sourceDF = spark.createDF(
          List(
            Array("snake", "rat"),
            null,
            Array("cat", "crazy")
          ),
          List(("array", ArrayType(StringType), true))
        )

        val resultDF = sourceDF.agg(concatArrays(col("array")).as("array"))

        // A single row holding every element from the non-null inputs, in row order.
        val expectedResult = spark.createDF(
          List(Array("snake", "rat", "cat", "crazy")),
          List(("array", ArrayType(StringType), true))
        )

        assertSmallDataFrameEquality(
          resultDF,
          expectedResult
        )
      }
    }
  }

}