8 changes: 8 additions & 0 deletions python/pyspark/sql/group.py
@@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None):

>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
if values is None:
jgd = self._jgd.pivot(pivot_col)
@@ -296,6 +298,12 @@ def _test():
Row(course="dotNET", year=2012, earnings=5000),
Row(course="dotNET", year=2013, earnings=48000),
Row(course="Java", year=2013, earnings=30000)]).toDF()
globs['df5'] = sc.parallelize([
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -340,36 +340,52 @@ class RelationalGroupedDataset protected[sql](

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")

Review comment: Slight improvement: would you like to highlight in the doc example that the Column API gives you easy access to nested columns?
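
A minimal sketch of what such a doc example might look like (hypothetical, not part of this diff); it reuses the nested sales struct from the trainingSales test data added later in this PR:

// Sketch only: pivoting on a nested struct field with the Column-based API,
// mirroring the SPARK-24722 test added in DataFramePivotSuite below.
trainingSales
  .groupBy($"sales.year")
  .pivot($"sales.course", Seq("dotNET", "Java"))
  .sum("sales.earnings")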

* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
* @since 2.4.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {

Review comment: We should probably also add a Column-based pivot() overload with no values argument (see the sketch after the String overload below).

groupType match {
case RelationalGroupedDataset.GroupByType =>
new RelationalGroupedDataset(
df,
groupingExprs,
RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))

Review comment: Now that we allow an arbitrary Expression as the pivot column, we might want to accept either a literal value or an Expression as each pivot value.
Another thing we need to do is check that groupByExprs and pivotColumn do not share any column reference, i.e., a column reference cannot appear in both the groupBy expressions and the pivot column.

Author reply:
"we might want to accept either a literal value or an Expression as each pivot value"

Are you sure about allowing arbitrary expressions for pivot values? Looking at the implementation, we still expect literals, and I am not sure we should give users a broader choice in the API.
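
As a small illustration of the reply's point (assumed usage, not from this diff): each supplied value goes through Literal.apply in the code above, so pivot values must be plain Scala values rather than Columns or expressions.

// Sketch only: values are wrapped by values.map(Literal.apply), so they must
// be plain values (strings, numbers, booleans, ...), not Columns.
df.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))  // fine: literal strings
// df.groupBy($"year").pivot($"course", Seq($"course"))      // a Column is not a literal and fails at Literal.apply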

case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
}
}

/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}
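
An earlier review comment asks for a Column-based overload that takes no explicit values. A minimal sketch of what that could look like (assumption, not part of this diff): collect the distinct values of the pivot column and delegate to the overload above; a production version would also cap the number of values via spark.sql.pivotMaxValues and fail if the limit is exceeded.

// Sketch only (not in this diff): Column-based pivot without explicit values,
// assumed to live next to the overloads above so that df is in scope.
def pivot(pivotColumn: Column): RelationalGroupedDataset = {
  // Collect and sort the distinct pivot values so the output column order is deterministic.
  val values = df.select(pivotColumn)
    .distinct()
    .sort(pivotColumn)
    .collect()
    .map(_.get(0))
    .toSeq
  pivot(pivotColumn, values)
}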

/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("pivot courses") {

Review comment: We need more tests here for the pivot column as:

  1. a constant value: although this might not cause any correctness issue, I suggest we add a check for this in our pivot method and throw an exception.
  2. an expression involving two or more column references (we don't need to cover a column list here because that's covered elsewhere, so think about other operators); a sketch of a test that pivots on an expression, rather than a bare column reference, appears at the end of this suite.
  3. an aggregate expression: verify that we throw the right exception.

For 1 and 3, since we only supported one or more column references as the pivot column before this change, the Pivot node processing in the Analyzer does not perform these checks. Now we need to.

val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
expected)
}

test("pivot year") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
expected)
}

test("pivot courses with multiple aggregations") {
val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy($"year")
.pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year")
.pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
expected)
}

test("pivot year with string values (cast)") {
@@ -181,10 +193,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
}

test("pivot with datatype not supported by PivotFirst") {
val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
checkAnswer(
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
)
expected)
checkAnswer(
complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
expected)
}

test("pivot with datatype not supported by PivotFirst 2") {
@@ -246,4 +261,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
}
}

test("SPARK-24722: pivot trainings - nested columns") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
trainingSales
.groupBy($"sales.year")
.pivot($"sales.course", Seq("dotNET", "Java"))
.agg(sum($"sales.earnings")),
expected)
}
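
A review comment earlier in this file asks for coverage of pivot columns that are expressions rather than bare column references. A possible sketch of such a test (hypothetical, not part of this diff; it assumes lower() from org.apache.spark.sql.functions and the existing courseSales data):

  // Sketch only: pivot on an expression over a column, exercising the new
  // Column-based API with something other than a bare column reference.
  // Expected values match the "pivot courses" test above.
  test("pivot on an expression over the course column (sketch)") {
    checkAnswer(
      courseSales.groupBy($"year")
        .pivot(lower($"course"), Seq("dotnet", "java"))
        .agg(sum($"earnings")),
      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil)
  }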
}
sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -255,6 +255,17 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val trainingSales: DataFrame = {
val df = spark.sparkContext.parallelize(
TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) ::
TrainingSales("Experts", CourseSales("Java", 2012, 20000)) ::
TrainingSales("Dummies", CourseSales("dotNET", 2012, 5000)) ::
TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) ::
TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF()
df.createOrReplaceTempView("trainingSales")
df
}

/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -310,4 +321,5 @@ private[sql] object SQLTestData {
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
}