diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala index cd041431..ca997701 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala @@ -420,4 +420,12 @@ object functions { val isLuhnNumber = udf[Option[Boolean], String](isLuhn) + def array_sum(col: Column): Column = { + aggregate( + col, + 0, + (acc, x) -> acc + x + ) + } + } diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala index b8041c9e..360e5a37 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala @@ -816,6 +816,49 @@ object FunctionsTest extends TestSuite with DataFrameComparer with ColumnCompare } + 'array_sum - { + + val df = spark + .createDF( + List( + ( + Array( + 1, + 4, + 9 + ), + true + ), + ( + Array( + 1, + 3, + 5 + ), + false + ) + ), + List( + ( + "nums", + ArrayType( + IntegerType, + true + ), + true + ), + ("expected", BooleanType, false) + ) + ) + .withColumn( + "array_sum", + functions.array_sum(col("nums")) + ) + + df.show() + + } + } }