From 889e9223510c821f359ee3ce5bec6ce2f746a027 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 2 Jul 2018 19:09:19 +0200 Subject: [PATCH 01/18] Adding pivot() which takes Column as its argument --- .../spark/sql/RelationalGroupedDataset.scala | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c6449cd5a16b..e2f6de92d77e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -340,29 +340,23 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") * }}} * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -370,6 +364,28 @@ class RelationalGroupedDataset protected[sql]( } } + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified * aggregation. 
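A minimal usage sketch of the Column-based overload added above (not part of the patch itself; it assumes a spark-shell session, where `spark.implicits._` is pre-imported, and sample data that mirrors the `courseSales` fixture the tests later in this series rely on):

    import org.apache.spark.sql.functions.sum

    val courseSales = Seq(
      ("dotNET", 2012, 10000.0), ("Java", 2012, 20000.0), ("dotNET", 2012, 5000.0),
      ("dotNET", 2013, 48000.0), ("Java", 2013, 30000.0)
    ).toDF("course", "year", "earnings")

    // Pivot on a Column expression with an explicit value list; each listed
    // value becomes an output column.
    courseSales.groupBy($"year")
      .pivot($"course", Seq("dotNET", "Java"))
      .agg(sum($"earnings"))
      .show()
    // +----+-------+-------+
    // |year| dotNET|   Java|
    // +----+-------+-------+
    // |2012|15000.0|20000.0|
    // |2013|48000.0|30000.0|
    // +----+-------+-------+
    // (row order may vary)

Passing the values explicitly avoids an extra job to compute the distinct values of the pivot column; a later patch in this series adds the value-free `pivot(Column)` variant for the inferred case.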
From f736ea2bca27fee37281bcadb333e7b6bbcd6124 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Mon, 2 Jul 2018 19:21:07 +0200
Subject: [PATCH 02/18] Tests for the new function

---
 .../spark/sql/DataFramePivotSuite.scala | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 6ca9ee57e8f4..c934a93e2e79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   test("pivot courses") {
+    val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
-      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings")),
+      expected)
   }

   test("pivot year") {
+    val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
-      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
+      expected)
   }

   test("pivot courses with multiple aggregations") {
+    val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+      Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
     checkAnswer(
       courseSales.groupBy($"year")
         .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
-      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
-        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
-    )
+      expected)
+    checkAnswer(
+      courseSales.groupBy($"year")
+        .pivot($"course", Seq("dotNET", "Java"))
+        .agg(sum($"earnings"), avg($"earnings")),
+      expected)
   }

   test("pivot year with string values (cast)") {
@@ -181,10 +193,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
   }

   test("pivot with datatype not supported by PivotFirst") {
+    val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
     checkAnswer(
       complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
-      Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
-    )
+      expected)
+    checkAnswer(
+      complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
+      expected)
   }

   test("pivot with datatype not supported by PivotFirst 2") {

From 5e6822650f4781c343d477589bf252c37b8453c4 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Mon, 2 Jul 2018 20:24:18 +0200
Subject: [PATCH 03/18] Update the since tag

---
 .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index e2f6de92d77e..91e44fe9d566 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -348,7 +348,7 @@ class RelationalGroupedDataset protected[sql](
    *
    * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
- * @since 1.6.0 + * @since 2.4.0 */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { From c82c3979065aba48536a743ebf3384f3c95b570c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 2 Jul 2018 21:05:48 +0200 Subject: [PATCH 04/18] Test for nested columns --- .../org/apache/spark/sql/DataFramePivotSuite.scala | 10 ++++++++++ .../org/apache/spark/sql/test/SQLTestData.scala | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index c934a93e2e79..1cc9b3dc7e89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -261,4 +261,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) } } + + test("pivot trainings - nested columns") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + checkAnswer( + trainingSales + .groupBy($"sales.year") + .pivot($"sales.course", Seq("dotNET", "Java")) + .agg(sum($"sales.earnings")), + expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e5215..85af2ebf9038 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -255,6 +255,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val trainingSales: DataFrame = { + val df = spark.sparkContext.parallelize( + TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) :: + TrainingSales("Experts", CourseSales("Java", 2012, 20000)) :: + TrainingSales("Dummies", CourseSales("dotNET", 2012, 5000)) :: + TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) :: + TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF() + df.createOrReplaceTempView("trainingSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. 
*/ @@ -310,4 +321,5 @@ private[sql] object SQLTestData { case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double) + case class TrainingSales(training: String, sales: CourseSales) } From 7d0d2261cef4c66226cd59635603391faabf0046 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 2 Jul 2018 22:38:02 +0200 Subject: [PATCH 05/18] Python test for nested columns --- python/pyspark/sql/group.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 0906c9c6b329..cc1da8e7c1f7 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None): >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: jgd = self._jgd.pivot(pivot_col) @@ -296,6 +298,12 @@ def _test(): Row(course="dotNET", year=2012, earnings=5000), Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() + globs['df5'] = sc.parallelize([ + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), + Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), + Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), + Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, From 0fdd11ff26b4f4ca3b79bdd116aaf1c558643698 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 2 Jul 2018 23:02:58 +0200 Subject: [PATCH 06/18] Adding ticket number to test's title --- .../test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 1cc9b3dc7e89..8bee2e27221b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -262,7 +262,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } } - test("pivot trainings - nested columns") { + test("SPARK-24722: pivot trainings - nested columns") { val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil checkAnswer( trainingSales From 74ddcdd9e41fa59476a07f8ac9606d595a1d6cf9 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 3 Jul 2018 13:07:51 +0200 Subject: [PATCH 07/18] Making diff shorter --- .../spark/sql/RelationalGroupedDataset.scala | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 91e44fe9d566..19040bb9d7bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -338,32 +338,6 @@ class 
RelationalGroupedDataset protected[sql]( pivot(pivotColumn, values) } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") - * }}} - * - * @param pivotColumn the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. - * @since 2.4.0 - */ - def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { - groupType match { - case RelationalGroupedDataset.GroupByType => - new RelationalGroupedDataset( - df, - groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) - case _: RelationalGroupedDataset.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") - } - } - /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list @@ -410,6 +384,32 @@ class RelationalGroupedDataset protected[sql]( pivot(pivotColumn, values.asScala) } + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") + * }}} + * + * @param pivotColumn the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { + groupType match { + case RelationalGroupedDataset.GroupByType => + new RelationalGroupedDataset( + df, + groupingExprs, + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) + case _: RelationalGroupedDataset.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + /** * Applies the given serialized R function `func` to each group of data. For each unique group, * the function will be passed the group key and an iterator that contains all of the elements in From 390d832e11f8293f89d3f635f9fdc4b3926d356f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 3 Jul 2018 13:18:13 +0200 Subject: [PATCH 08/18] Adding a function which accepts Column argument --- .../spark/sql/RelationalGroupedDataset.scala | 57 ++++++++++++------- .../spark/sql/DataFramePivotSuite.scala | 14 +++-- 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 19040bb9d7bf..43ac97f88db9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -315,28 +315,7 @@ class RelationalGroupedDataset protected[sql]( * @param pivotColumn Name of the column to pivot. 
   * @since 1.6.0
   */
-  def pivot(pivotColumn: String): RelationalGroupedDataset = {
-    // This is to prevent unintended OOM errors when the number of distinct values is large
-    val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
-    // Get the distinct values of the column and sort them so its consistent
-    val values = df.select(pivotColumn)
-      .distinct()
-      .limit(maxValues + 1)
-      .sort(pivotColumn) // ensure that the output columns are in a consistent logical order
-      .collect()
-      .map(_.get(0))
-      .toSeq
-
-    if (values.length > maxValues) {
-      throw new AnalysisException(
-        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
-          "this could indicate an error. " +
-          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
-          "to at least the number of distinct values of the pivot column.")
-    }
-
-    pivot(pivotColumn, values)
-  }
+  def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))

   /**
    * Pivots a column of the current `DataFrame` and performs the specified aggregation.
@@ -384,6 +363,40 @@ class RelationalGroupedDataset protected[sql](
     pivot(pivotColumn, values.asScala)
   }

+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   *
+   * {{{
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy($"year").pivot($"course").sum($"earnings");
+   * }}}
+   *
+   * @param pivotColumn the column to pivot.
+   * @since 2.4.0
+   */
+  def pivot(pivotColumn: Column): RelationalGroupedDataset = {
+    // This is to prevent unintended OOM errors when the number of distinct values is large
+    val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
+    // Get the distinct values of the column and sort them so it's consistent
+    val values = df.select(pivotColumn)
+      .distinct()
+      .limit(maxValues + 1)
+      .sort(pivotColumn) // ensure that the output columns are in a consistent logical order
+      .collect()
+      .map(_.get(0))
+      .toSeq
+
+    if (values.length > maxValues) {
+      throw new AnalysisException(
+        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+          "this could indicate an error. " +
+          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+          "to at least the number of distinct values of the pivot column.")
+    }
+
+    pivot(pivotColumn, values)
+  }
+
   /**
    * Pivots a column of the current `DataFrame` and performs the specified aggregation.
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 8bee2e27221b..8173ca928a1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -79,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { test("pivot courses with no values") { // Note Java comes before dotNet in sorted order + val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), - Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + expected) } test("pivot year with no values") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + expected) } test("pivot max values enforced") { From d62b7e789f38219b62fb5b010fb2cacc0324fe29 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 3 Jul 2018 13:23:38 +0200 Subject: [PATCH 09/18] Adding Java-specific functions --- .../spark/sql/RelationalGroupedDataset.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 43ac97f88db9..3d3a78a73193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -360,7 +360,7 @@ class RelationalGroupedDataset protected[sql]( * @since 1.6.0 */ def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(pivotColumn, values.asScala) + pivot(Column(pivotColumn), values) } /** @@ -423,6 +423,18 @@ class RelationalGroupedDataset protected[sql]( } } + /** + * (Java-specific) Pivots a column of the current `DataFrame` + * and performs the specified aggregation. + * + * @param pivotColumn the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(pivotColumn, values.asScala) + } + /** * Applies the given serialized R function `func` to each group of data. 
For each unique group, * the function will be passed the group key and an iterator that contains all of the elements in From fae4fd2f607c0b44adb03827039f26c5ff592d31 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 3 Jul 2018 22:06:56 +0200 Subject: [PATCH 10/18] Tests for column expression --- .../org/apache/spark/sql/DataFramePivotSuite.scala | 12 ++++++------ .../org/apache/spark/sql/test/SQLTestData.scala | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 8173ca928a1b..521fbffa4395 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -270,11 +270,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { test("SPARK-24722: pivot trainings - nested columns") { val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil - checkAnswer( - trainingSales - .groupBy($"sales.year") - .pivot($"sales.course", Seq("dotNET", "Java")) - .agg(sum($"sales.earnings")), - expected) + val df = trainingSales + .groupBy($"sales.year") + .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase)) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 85af2ebf9038..90f5c3b73ae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -258,8 +258,8 @@ private[sql] trait SQLTestData { self => protected lazy val trainingSales: DataFrame = { val df = spark.sparkContext.parallelize( TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) :: - TrainingSales("Experts", CourseSales("Java", 2012, 20000)) :: - TrainingSales("Dummies", CourseSales("dotNET", 2012, 5000)) :: + TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) :: + TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) :: TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) :: TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF() df.createOrReplaceTempView("trainingSales") From 8ffdc32834596e49a6bb0160f74d0e7d1107f625 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 8 Jul 2018 17:37:04 +0200 Subject: [PATCH 11/18] Tests for pivot column which refers to multiple other columns --- .../org/apache/spark/sql/DataFramePivotSuite.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 521fbffa4395..77b3a034e228 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -268,7 +268,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-24722: pivot trainings - nested columns") { + test("SPARK-24722: pivoting nested columns") { val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil val df = trainingSales .groupBy($"sales.year") @@ -277,4 +277,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { checkAnswer(df, expected) } + + test("SPARK-24722: references to 
multiple columns in the pivot column") {
+    val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
+    val df = trainingSales
+      .groupBy($"sales.year")
+      .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df, expected)
+  }
 }

From 57c0f64651e7ce1f2dbf2383fd5325a3e87aa494 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sun, 8 Jul 2018 18:28:41 +0200
Subject: [PATCH 12/18] Test pivoting by a constant

---
 .../org/apache/spark/sql/DataFramePivotSuite.scala | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 77b3a034e228..3df8d46dee63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -287,4 +287,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

     checkAnswer(df, expected)
   }
+
+  test("SPARK-24722: pivoting by a constant") {
+    val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
+    val df1 = trainingSales
+      .groupBy($"sales.year")
+      .pivot(lit(123), Seq(123))
+      .agg(sum($"sales.earnings"))
+
+    checkAnswer(df1, expected)
+  }
 }

From f32a85bd7d114adb85e7281e2a039b383392a17b Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sun, 8 Jul 2018 18:29:20 +0200
Subject: [PATCH 13/18] Test for pivoting by an aggregate

---
 .../org/apache/spark/sql/DataFramePivotSuite.scala | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 3df8d46dee63..b972b9ef93e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -297,4 +297,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {

     checkAnswer(df1, expected)
   }
+
+  test("SPARK-24722: aggregate as the pivot column") {
+    val exception = intercept[AnalysisException] {
+      trainingSales
+        .groupBy($"sales.year")
+        .pivot(min($"training"), Seq("Experts"))
+        .agg(sum($"sales.earnings"))
+    }
+
+    assert(exception.getMessage.contains("aggregate functions are not allowed"))
+  }
 }

From 74ddcdd9e41fa59476a07f8ac9606d595a1d6cf9 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Sat, 28 Jul 2018 21:34:29 +0200
Subject: [PATCH 14/18] Improving comments by referencing the overloaded
 methods

---
 .../org/apache/spark/sql/RelationalGroupedDataset.scala | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 4206dfb9ad76..c8f9706ac8b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -365,6 +365,7 @@ class RelationalGroupedDataset protected[sql](

   /**
    * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
* * {{{ * // Or without specifying column values (less efficient) @@ -399,6 +400,7 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column @@ -424,8 +426,9 @@ class RelationalGroupedDataset protected[sql]( } /** - * (Java-specific) Pivots a column of the current `DataFrame` - * and performs the specified aggregation. + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of + * the `String` type. * * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. From 34535a9cc5ec7a2ba880f7f525feb7dbbc0b0c37 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 29 Jul 2018 11:35:41 +0200 Subject: [PATCH 15/18] Fix expected error message --- .../test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b972b9ef93e5..f48364616fd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -306,6 +306,6 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { .agg(sum($"sales.earnings")) } - assert(exception.getMessage.contains("aggregate functions are not allowed")) + assert(exception.getMessage.contains("It is not allowed to use an aggregate function")) } } From cf55135f430b2012723f8e09a1aa4651d6c7161b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 31 Jul 2018 16:55:19 +0200 Subject: [PATCH 16/18] Fix expected message --- .../test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index f48364616fd8..b972b9ef93e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -306,6 +306,6 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { .agg(sum($"sales.earnings")) } - assert(exception.getMessage.contains("It is not allowed to use an aggregate function")) + assert(exception.getMessage.contains("aggregate functions are not allowed")) } } From 5da5a2c94a1e99cc3edd920080470b3d17cfc699 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 31 Jul 2018 23:53:05 +0200 Subject: [PATCH 17/18] Support multiple values --- .../spark/sql/RelationalGroupedDataset.scala | 11 +++++-- .../spark/sql/DataFramePivotSuite.scala | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index ed39eac5598d..7206a405b51a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,6 +31,7 @@ 
import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -411,12 +412,18 @@ class RelationalGroupedDataset protected[sql]( * @since 2.4.0 */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { + import org.apache.spark.sql.functions.struct groupType match { case RelationalGroupedDataset.GroupByType => + val pivotValues = values.map { + case row: GenericRow => struct(row.values.map(lit(_)): _*).expr + case multi: Iterable[Any] => struct(multi.map(lit(_)).toSeq: _*).expr + case single => lit(single).expr + } new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, pivotValues)) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -558,5 +565,5 @@ private[sql] object RelationalGroupedDataset { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b972b9ef93e5..f99ba8dfef36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -308,4 +308,36 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { assert(exception.getMessage.contains("aggregate functions are not allowed")) } + + test("SPARK-24722: pivoting column list with values") { + val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training"), Seq( + struct(lit("dotnet"), lit("Experts")), + struct(lit("java"), lit("Dummies"))) + ).agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + + val df2 = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training"), + Seq(Seq("dotnet", "Experts"), Seq("java", "Dummies")) + ).agg(sum($"sales.earnings")) + + checkAnswer(df2, expected) + } + + test("SPARK-24722: pivoting column list") { + val expected = Seq( + Row(2012, 5000.0, 10000.0, null, 20000.0), + Row(2013, null, 48000.0, 30000.0, null)) + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training")) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } } From ca1250b29f4edf8f38eb81c27773e04068e0fdf4 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 3 Aug 2018 09:09:26 +0200 Subject: [PATCH 18/18] Revert "Support multiple values" This reverts commit 5da5a2c94a1e99cc3edd920080470b3d17cfc699. 
--- .../spark/sql/RelationalGroupedDataset.scala | 11 ++----- .../spark/sql/DataFramePivotSuite.scala | 32 ------------------- 2 files changed, 2 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7206a405b51a..ed39eac5598d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -412,18 +411,12 @@ class RelationalGroupedDataset protected[sql]( * @since 2.4.0 */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { - import org.apache.spark.sql.functions.struct groupType match { case RelationalGroupedDataset.GroupByType => - val pivotValues = values.map { - case row: GenericRow => struct(row.values.map(lit(_)): _*).expr - case multi: Iterable[Any] => struct(multi.map(lit(_)).toSeq: _*).expr - case single => lit(single).expr - } new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, pivotValues)) + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -565,5 +558,5 @@ private[sql] object RelationalGroupedDataset { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index f99ba8dfef36..b972b9ef93e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -308,36 +308,4 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { assert(exception.getMessage.contains("aggregate functions are not allowed")) } - - test("SPARK-24722: pivoting column list with values") { - val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil - val df = trainingSales - .groupBy($"sales.year") - .pivot(struct(lower($"sales.course"), $"training"), Seq( - struct(lit("dotnet"), lit("Experts")), - struct(lit("java"), lit("Dummies"))) - ).agg(sum($"sales.earnings")) - - checkAnswer(df, expected) - - val df2 = trainingSales - .groupBy($"sales.year") - .pivot(struct(lower($"sales.course"), $"training"), - Seq(Seq("dotnet", "Experts"), Seq("java", "Dummies")) - ).agg(sum($"sales.earnings")) - - checkAnswer(df2, expected) - } - - test("SPARK-24722: pivoting column list") { - val expected = Seq( - Row(2012, 5000.0, 10000.0, null, 20000.0), - Row(2013, null, 48000.0, 30000.0, null)) - val df = trainingSales - .groupBy($"sales.year") - .pivot(struct(lower($"sales.course"), $"training")) - .agg(sum($"sales.earnings")) - - checkAnswer(df, expected) - } }
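
With the final revert applied, the API surface this series leaves behind is `pivot(Column)`, `pivot(Column, Seq[Any])`, and the Java-specific `pivot(Column, java.util.List[Any])`. A minimal end-to-end sketch of the capability that motivated SPARK-24722, pivoting on an expression over a nested column, which the String-based API cannot express (again assuming a spark-shell session; the data mirrors the `trainingSales` fixture from the tests above):

    import org.apache.spark.sql.functions.{lower, sum}

    case class CourseSales(course: String, year: Int, earnings: Double)
    case class TrainingSales(training: String, sales: CourseSales)

    val trainingSales = Seq(
      TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)),
      TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)),
      TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)),
      TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)),
      TrainingSales("Dummies", CourseSales("Java", 2013, 30000))
    ).toDF()

    // lower() normalizes the inconsistent course spellings inside pivot():
    // the pivot column is now an arbitrary Column expression, not just a name.
    trainingSales
      .groupBy($"sales.year")
      .pivot(lower($"sales.course"), Seq("dotnet", "java"))
      .agg(sum($"sales.earnings"))
      .show()
    // year=2012: dotnet=15000.0, java=20000.0
    // year=2013: dotnet=48000.0, java=30000.0

The struct-based multi-value form exercised in PATCH 17 is reverted in PATCH 18, so pivoting on several columns at once is not part of what this series merges.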