[SPARK-24722][SQL] pivot() with Column type argument #7
Changes from all commits: 889e922, f736ea2, 5e68226, c82c397, 7d0d226, 0fdd11f
@@ -340,36 +340,52 @@ class RelationalGroupedDataset protected[sql](

        /**
         * Pivots a column of the current `DataFrame` and performs the specified aggregation.
         * There are two versions of pivot function: one that requires the caller to specify the list
         * of distinct values to pivot on, and one that does not. The latter is more concise but less
         * efficient, because Spark needs to first compute the list of distinct values internally.
         *
         * {{{
         * // Compute the sum of earnings for each year by course with each course as a separate column
  -      * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
  -      *
  -      * // Or without specifying column values (less efficient)
  -      * df.groupBy("year").pivot("course").sum("earnings")
  +      * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
         * }}}
         *
  -      * @param pivotColumn Name of the column to pivot.
  +      * @param pivotColumn the column to pivot.
         * @param values List of values that will be translated to columns in the output DataFrame.
  -      * @since 1.6.0
  +      * @since 2.4.0
         */
  -    def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
  +    def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
|
Review comment: We should probably add a Column version of the pivot() API for the signature with no "values" argument as well.

Author's reply: Please look at the PR for OSS Spark, apache#21699; it has that API: https://github.com/apache/spark/pull/21699/files#diff-95bb2228c67e3cce4c729e44e2d82422R377
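For context, a minimal sketch of what such a no-"values" Column overload could look like, assuming it delegates to the values-based overload after collecting the distinct values; this is an illustration, not the exact code from apache#21699:

    // Hypothetical sketch: `df` is the DataFrame backing this
    // RelationalGroupedDataset, as in the surrounding class.
    def pivot(pivotColumn: Column): RelationalGroupedDataset = {
      // Compute the distinct pivot values first (the "less efficient" path
      // described in the scaladoc above), then reuse the explicit-values API.
      val values = df.select(pivotColumn)
        .distinct()
        .collect()
        .map(_.get(0))
        .toSeq
      pivot(pivotColumn, values)
    }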
       groupType match {
         case RelationalGroupedDataset.GroupByType =>
           new RelationalGroupedDataset(
             df,
             groupingExprs,
  -          RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
  +          RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
|
Review comment: Now that we allow an arbitrary Expression as the pivot column, we might want to accept either a literal value object or an Expression as the pivot value.

Author's reply: Are you sure about arbitrary expressions for pivot values? Looking at the implementation, we still expect literals. I am not sure that we should give users a broader choice in the API.
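The author's point can be seen in the line above: each pivot value goes through `Literal.apply`, which only converts plain Scala values it recognizes into Catalyst literals. A small illustration (import path is Spark's Catalyst package):

    import org.apache.spark.sql.catalyst.expressions.Literal

    // Plain values convert cleanly to Catalyst literals:
    Literal.apply("dotNET") // Literal of StringType
    Literal.apply(2012)     // Literal of IntegerType

    // A value with no literal encoding (e.g. an arbitrary object) makes
    // Literal.apply throw at runtime, so pivot values are effectively
    // restricted to literals despite the declared Seq[Any] parameter type.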
         case _: RelationalGroupedDataset.PivotType =>
           throw new UnsupportedOperationException("repeated pivots are not supported")
         case _ =>
           throw new UnsupportedOperationException("pivot is only supported after a groupBy")
       }
     }
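The two error arms correspond to user-facing failures such as the following (illustrative, with `df` standing in for any DataFrame and `spark.implicits._` imported):

    // Second pivot on an already-pivoted grouping:
    df.groupBy($"year").pivot($"course", Seq("dotNET")).pivot($"year", Seq(2012))
    // => UnsupportedOperationException: repeated pivots are not supported

    // Pivot after rollup/cube instead of groupBy:
    df.rollup($"year").pivot($"course", Seq("dotNET"))
    // => UnsupportedOperationException: pivot is only supported after a groupBy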
|
|
  +    /**
  +     * Pivots a column of the current `DataFrame` and performs the specified aggregation.
  +     * There are two versions of pivot function: one that requires the caller to specify the list
  +     * of distinct values to pivot on, and one that does not. The latter is more concise but less
  +     * efficient, because Spark needs to first compute the list of distinct values internally.
  +     *
  +     * {{{
  +     * // Compute the sum of earnings for each year by course with each course as a separate column
  +     * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
  +     *
  +     * // Or without specifying column values (less efficient)
  +     * df.groupBy("year").pivot("course").sum("earnings")
  +     * }}}
  +     *
  +     * @param pivotColumn Name of the column to pivot.
  +     * @param values List of values that will be translated to columns in the output DataFrame.
  +     * @since 1.6.0
  +     */
  +    def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
  +      pivot(Column(pivotColumn), values)
  +    }
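The String overload is now a thin wrapper, so the two call styles end up in the same code path (illustrative):

    // Equivalent after this change: the String overload wraps the name
    // in a Column and delegates to the Column overload.
    df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
    df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")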
|
|
       /**
        * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
        * aggregation.
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
     import testImplicits._

     test("pivot courses") {
|
Review comment: We need more tests here for the Pivot column, such as: For 1 and 3, since we only supported one or multiple column references as the Pivot column before this change, Pivot node processing in the Analyzer does not perform these checks. Now we need to.

Author's reply: @maryannxue I added the tests.
  +      val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
         checkAnswer(
           courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
             .agg(sum($"earnings")),
  -        Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
  -      )
  +        expected)
  +      checkAnswer(
  +        courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
  +          .agg(sum($"earnings")),
  +        expected)
       }
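For orientation, the expected rows correspond to a pivoted result of roughly this shape (column names come from the pivot values; shown here informally):

    year | dotNET  | Java
    -----+---------+--------
    2012 | 15000.0 | 20000.0
    2013 | 48000.0 | 30000.0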
|
|
       test("pivot year") {
  +      val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
         checkAnswer(
           courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
  -        Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
  -      )
  +        expected)
  +      checkAnswer(
  +        courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
  +        expected)
       }
|
|
       test("pivot courses with multiple aggregations") {
  +      val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
  +        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
         checkAnswer(
           courseSales.groupBy($"year")
             .pivot("course", Seq("dotNET", "Java"))
             .agg(sum($"earnings"), avg($"earnings")),
  -        Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
  -          Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
  -      )
  +        expected)
  +      checkAnswer(
  +        courseSales.groupBy($"year")
  +          .pivot($"course", Seq("dotNET", "Java"))
  +          .agg(sum($"earnings"), avg($"earnings")),
  +        expected)
       }
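The updated tests exercise the interchangeable column-reference syntaxes of the Scala API, all of which resolve to the same Column (illustrative):

    import org.apache.spark.sql.functions.col

    courseSales.groupBy($"year")     // string interpolator (via testImplicits)
    courseSales.groupBy('year)       // Scala Symbol syntax (via testImplicits)
    courseSales.groupBy(col("year")) // explicit functions.col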
|
|
       test("pivot year with string values (cast)") {

@@ -181,10 +193,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       }
|
|
       test("pivot with datatype not supported by PivotFirst") {
  +      val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
         checkAnswer(
           complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
  -        Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
  -      )
  +        expected)
  +      checkAnswer(
  +        complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
  +        expected)
       }
|
|
       test("pivot with datatype not supported by PivotFirst 2") {

@@ -246,4 +261,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
     }

  +    test("SPARK-24722: pivot trainings - nested columns") {
  +      val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
  +      checkAnswer(
  +        trainingSales
  +          .groupBy($"sales.year")
  +          .pivot($"sales.course", Seq("dotNET", "Java"))
  +          .agg(sum($"sales.earnings")),
  +        expected)
  +    }
     }
Review comment: Slight improvement: would you like to highlight in the doc example that the Column API gives you easy access to nested columns?
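Such a doc example could be drawn directly from the nested-columns test above (assuming, as there, that `trainingSales` has a struct column `sales` with `year`, `course`, and `earnings` fields):

    // A String pivot column is just a name; a Column can address a
    // nested struct field directly:
    trainingSales
      .groupBy($"sales.year")
      .pivot($"sales.course", Seq("dotNET", "Java"))
      .agg(sum($"sales.earnings"))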