diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 9eda1aa61010..f5fd725b9ade 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -18,6 +18,7 @@ import unittest
 
 from pyspark.rdd import PythonEvalType
+from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import *
 
@@ -461,6 +462,18 @@ def test_register_vectorized_udf_basic(self):
         expected = [1, 5]
         self.assertEqual(actual, expected)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda x: x.sum(),
+                       'int', PandasUDFType.GROUPED_AGG)
+
+        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index 1d87c636ab34..32d6720b2c12 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -504,6 +504,18 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
 
         self.assertEquals(result.collect()[0]['sum'], 165)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
+                       'id long, x int', PandasUDFType.GROUPED_MAP)
+
+        result = df.groupBy('id').apply(f).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_map import *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 0c78cca086ed..fcbd0b19515b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -105,7 +105,8 @@ case class AggregateInPandasExec(
       StructField(s"_$i", dt)
     })
 
-    inputRDD.mapPartitionsInternal { iter =>
+    // Map grouped rows to ArrowPythonRunner results. Only execute if the partition is not empty.
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
       val prunedProj = UnsafeProjection.create(allInputs, child.output)
 
       val grouped = if (groupingExpressions.isEmpty) {
@@ -151,6 +152,6 @@ case class AggregateInPandasExec(
         val joinedRow = joined(leftRow, aggOutputRow)
         resultProj(joinedRow)
       }
-    }
+    }}
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 7b0e014f9ca4..267698d1bca5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -125,7 +125,8 @@ case class FlatMapGroupsInPandasExec(
     val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
     val dedupSchema = StructType.fromAttributes(dedupAttributes)
 
-    inputRDD.mapPartitionsInternal { iter =>
+    // Map grouped rows to ArrowPythonRunner results. Only execute if the partition is not empty.
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
      val grouped = if (groupingAttributes.isEmpty) {
        Iterator(iter)
      } else {
@@ -156,6 +157,6 @@ case class FlatMapGroupsInPandasExec(
         flattenedBatch.setNumRows(batch.numRows())
         flattenedBatch.rowIterator.asScala
       }.map(unsafeProj)
-    }
+    }}
   }
 }
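
Both exec nodes apply the same guard: test `iter.isEmpty` before building projections and launching a Python worker, and pass the (empty) iterator straight through otherwise. Below is a minimal standalone sketch of that pattern, not Spark code: `EmptyPartitionGuard` and `process` are hypothetical names, with `process` standing in for the per-partition ArrowPythonRunner pipeline.

```scala
// Sketch of the empty-partition guard used in the diff above.
// `process` is a hypothetical stand-in for the expensive per-partition
// work (in Spark, starting an ArrowPythonRunner / Python worker).
object EmptyPartitionGuard {
  def guarded[T, U](iter: Iterator[T])(process: Iterator[T] => Iterator[U]): Iterator[U] =
    // Short-circuit: never pay the per-partition setup cost for zero rows.
    if (iter.isEmpty) Iterator.empty else process(iter)

  def main(args: Array[String]): Unit = {
    println(guarded(Iterator(1, 2, 3))(_.map(_ * 2)).toList)          // List(2, 4, 6)
    println(guarded(Iterator[Int]())(_ => sys.error("not reached")).toList) // List()
  }
}
```

The tests exercise this path by parallelizing the data into `len(data) + 1` slices, which guarantees at least one empty partition reaches the exec node.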