diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 6981601cb129..6c84bd740171 100644 --- a/python/pyspark/sql/tests/test_group.py +++ b/python/pyspark/sql/tests/test_group.py @@ -22,6 +22,29 @@ class GroupTestsMixin: + def test_agg_func(self): + data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1, value=30)] + df = self.spark.createDataFrame(data) + g = df.groupBy("key") + self.assertEqual(g.max("value").collect(), [Row(**{"key": 1, "max(value)": 30})]) + self.assertEqual(g.min("value").collect(), [Row(**{"key": 1, "min(value)": 10})]) + self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1, "sum(value)": 60})]) + self.assertEqual(g.count().collect(), [Row(key=1, count=3)]) + self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1, "avg(value)": 20.0})]) + + data = [ + Row(electronic="Smartphone", year=2018, sales=150000), + Row(electronic="Tablet", year=2018, sales=120000), + Row(electronic="Smartphone", year=2019, sales=180000), + Row(electronic="Tablet", year=2019, sales=50000), + ] + + df_pivot = self.spark.createDataFrame(data) + assertDataFrameEqual( + df_pivot.groupBy("year").pivot("electronic", ["Smartphone", "Tablet"]).sum("sales"), + df_pivot.groupBy("year").pivot("electronic").sum("sales"), + ) + def test_aggregator(self): df = self.df g = df.groupBy()