Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.github.mrpowers.spark.daria.sql.udafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}

class ArraySum(elementSchema: DataType, nullable: Boolean = true)
    extends UserDefinedAggregateFunction {

  // UDAF that concatenates array-typed rows into a single array.
  // `elementSchema` is the element type of the input arrays; `nullable`
  // controls the nullability of the single input/buffer field.

  // Input rows and the intermediate buffer share the same one-field shape:
  // a single array column named "value".
  private val arrayStruct =
    StructType(StructField("value", dataType, nullable) :: Nil)

  override def inputSchema: StructType = arrayStruct

  override def bufferSchema: StructType = arrayStruct

  // The final result is an array of the caller-supplied element type.
  override def dataType: DataType = ArrayType(elementSchema)

  // NOTE(review): the result depends on row order, which Spark does not
  // guarantee across partitions — confirm `deterministic = true` is intended.
  override def deterministic: Boolean = true

  // Every aggregation starts from an empty accumulator.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[Any]

  // Append one row's array to the running accumulator; null rows are skipped.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val incoming = input.getAs[Seq[Any]](0)
    if (incoming != null) {
      buffer(0) = buffer.getAs[Seq[Any]](0) ++ incoming
    }
  }

  // Combine partial results from two partitions by concatenation.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Seq[Any]](0) ++ buffer2.getAs[Seq[Any]](0)

  // The accumulated sequence is the final aggregate value.
  override def evaluate(buffer: Row): Any = buffer.getAs[Seq[Any]](0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.github.mrpowers.spark.daria.sql.udafs

import utest._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import com.github.mrpowers.spark.daria.sql.SparkSessionTestWrapper
import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


object ArraySumTest extends TestSuite with DataFrameComparer with SparkSessionTestWrapper {

  val tests = Tests {

    'arraySum - {
      "concatenates rows of arrays" - {

        val sumUdaf = new ArraySum(StringType)

        // Input: two populated arrays plus a null row, which the UDAF skips.
        val sourceDF = spark.createDF(
          List(
            Array("snake", "rat"),
            null,
            Array("cat", "crazy")
          ),
          List(("array", ArrayType(StringType), true))
        )

        val actualDF = sourceDF.agg(sumUdaf(col("array")).as("array"))

        // Expected: one row holding every element in input order, null dropped.
        val expectedDF = spark.createDF(
          List(Array("snake", "rat", "cat", "crazy")),
          List(("array", ArrayType(StringType), true))
        )

        assertSmallDataFrameEquality(actualDF, expectedDF)
      }
    }
  }

}