From 37c4539978f4e92fef9055dfae292b22392a0bf8 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Thu, 22 Sep 2016 18:45:38 -0700
Subject: [PATCH 1/2] Do not collect null values.

---
 .../sql/catalyst/expressions/aggregate/collect.scala |  5 ++++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala   | 12 ++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 896ff61b2309..8de9c9901171 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,7 +65,10 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e172bee4f66..7aa4f0026f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),

From 80b2166c62b449eae4032eb96be39b9d8659ba8a Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 27 Sep 2016 11:49:57 -0700
Subject: [PATCH 2/2] Add comment on filtering null semantic

---
 .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 8de9c9901171..78a388d20630 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,6 +65,8 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     val value = child.eval(input)
    if (value != null) {
      buffer += value
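
Note (illustrative, not part of the patch): a minimal spark-shell sketch of the
behavior after this change, assuming a Spark 2.x session bound to `spark`. The
expected values below mirror the assertions in the new DataFrameAggregateSuite
test; `collect_set` makes no ordering guarantee, so its result is shown as a set.

    import org.apache.spark.sql.functions.{collect_list, collect_set}
    import spark.implicits._

    // Same data as the SPARK-17641 test: column "a" contains a null.
    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")

    // The null in "a" is filtered out in update(), following Hive's
    // collect_list/collect_set semantics, instead of landing in the buffer.
    df.select(collect_list($"a")).head.getSeq[String](0)  // Seq("1", "1")
    df.select(collect_set($"a")).head.getSeq[String](0)   // Seq("1")

    // Columns without nulls are unaffected.
    df.select(collect_list($"b")).head.getSeq[Int](0)     // Seq(2, 2, 4)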