[SPARK-21100][SQL] Add summary method as alternative to describe that gives quartiles similar to Pandas #18307
Changes from 11 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._ | |
| import org.apache.spark.sql.catalyst.catalog.CatalogRelation | ||
| import org.apache.spark.sql.catalyst.encoders._ | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
| import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} | ||
| import org.apache.spark.sql.catalyst.optimizer.CombineUnions | ||
| import org.apache.spark.sql.catalyst.parser.ParseException | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} | ||
| import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} | ||
| import org.apache.spark.sql.catalyst.util.DateTimeUtils | ||
| import org.apache.spark.sql.execution._ | ||
| import org.apache.spark.sql.execution.command._ | ||
| import org.apache.spark.sql.execution.datasources.LogicalRelation | ||
| import org.apache.spark.sql.execution.python.EvaluatePython | ||
| import org.apache.spark.sql.execution.stat.StatFunctions | ||
| import org.apache.spark.sql.streaming.DataStreamWriter | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.storage.StorageLevel | ||
|
|
@@ -224,7 +224,7 @@ class Dataset[T] private[sql]( | |
| } | ||
| } | ||
|
|
||
| private def aggregatableColumns: Seq[Expression] = { | ||
| private[sql] def aggregatableColumns: Seq[Expression] = { | ||
| schema.fields | ||
| .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) | ||
| .map { n => | ||
|
|
@@ -2185,9 +2185,9 @@ class Dataset[T] private[sql]( | |
| } | ||
|
|
||
| /** | ||
| * Computes statistics for numeric and string columns, including count, mean, stddev, min, and | ||
| * max. If no columns are given, this function computes statistics for all numerical or string | ||
| * columns. | ||
| * Computes basic statistics for numeric and string columns, including count, mean, stddev, min, | ||
| * and max. If no columns are given, this function computes statistics for all numerical or | ||
| * string columns. | ||
| * | ||
| * This function is meant for exploratory data analysis, as we make no guarantee about the | ||
| * backward compatibility of the schema of the resulting Dataset. If you want to | ||
|
|
@@ -2205,46 +2205,79 @@ class Dataset[T] private[sql]( | |
| * // max 92.0 192.0 | ||
| * }}} | ||
| * | ||
| * Use [[summary]] for expanded statistics and control over which statistics to compute. | ||
| * | ||
| * @param cols Columns to compute statistics on. | ||
| * | ||
| * @group action | ||
| * @since 1.6.0 | ||
| */ | ||
| @scala.annotation.varargs | ||
| def describe(cols: String*): DataFrame = withPlan { | ||
|
|
||
| // The list of summary statistics to compute, in the form of expressions. | ||
| val statistics = List[(String, Expression => Expression)]( | ||
| "count" -> ((child: Expression) => Count(child).toAggregateExpression()), | ||
| "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), | ||
| "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), | ||
| "min" -> ((child: Expression) => Min(child).toAggregateExpression()), | ||
| "max" -> ((child: Expression) => Max(child).toAggregateExpression())) | ||
|
|
||
| val outputCols = | ||
| (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList | ||
|
|
||
| val ret: Seq[Row] = if (outputCols.nonEmpty) { | ||
| val aggExprs = statistics.flatMap { case (_, colToAgg) => | ||
| outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) | ||
| } | ||
|
|
||
| val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq | ||
|
|
||
| // Pivot the data so each summary is one row | ||
| row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => | ||
| Row(statistic :: aggregation.toList: _*) | ||
| } | ||
| } else { | ||
| // If there are no output columns, just output a single column that contains the stats. | ||
| statistics.map { case (name, _) => Row(name) } | ||
| } | ||
|
|
||
| // All columns are string type | ||
| val schema = StructType( | ||
| StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes | ||
| // `toArray` forces materialization to make the seq serializable | ||
| LocalRelation.fromExternalRows(schema, ret.toArray.toSeq) | ||
| def describe(cols: String*): DataFrame = { | ||
| val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*) | ||
| selected.summary("count", "mean", "stddev", "min", "max") | ||
| } | ||
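With this change `describe` becomes a thin wrapper around `summary`. A minimal sketch of the equivalence from the caller's side, assuming a Spark session with `spark.implicits._` in scope (the `ds` DataFrame here is made up for illustration):

```scala
// Hypothetical data; requires `import spark.implicits._` in a Spark session.
val ds = Seq(("Bob", 16, 176), ("Alice", 32, 164), ("Amy", 24, 180))
  .toDF("name", "age", "height")

// With this patch, these two produce the same rows and schema:
val viaDescribe = ds.describe("age", "height")
val viaSummary  = ds.select("age", "height").summary("count", "mean", "stddev", "min", "max")

viaDescribe.show()
viaSummary.show()
```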
|
|
||
| /** | ||
| * Computes specified statistics for numeric and string columns. Available statistics are: | ||
| * | ||
| * - count | ||
| * - mean | ||
| * - stddev | ||
| * - min | ||
| * - max | ||
| * - arbitrary approximate percentiles specified as a percentage (e.g. 75%) | ||
| * | ||
| * If no statistics are given, this function computes count, mean, stddev, min, | ||
| * approximate quartiles, and max. | ||
|
||
| * | ||
| * This function is meant for exploratory data analysis, as we make no guarantee about the | ||
| * backward compatibility of the schema of the resulting Dataset. If you want to | ||
| * programmatically compute summary statistics, use the `agg` function instead. | ||
| * | ||
| * {{{ | ||
| * ds.summary().show() | ||
| * | ||
| * // output: | ||
| * // summary age height | ||
| * // count 10.0 10.0 | ||
| * // mean 53.3 178.05 | ||
| * // stddev 11.6 15.7 | ||
| * // min 18.0 163.0 | ||
| * // 25% 24.0 176.0 | ||
| * // 50% 24.0 176.0 | ||
| * // 75% 32.0 180.0 | ||
| * // max 92.0 192.0 | ||
| * }}} | ||
| * | ||
| * {{{ | ||
| * ds.summary("count", "min", "25%", "75%", "max").show() | ||
| * | ||
| * // output: | ||
| * // summary age height | ||
| * // count 10.0 10.0 | ||
| * // min 18.0 163.0 | ||
| * // 25% 24.0 176.0 | ||
| * // 75% 32.0 180.0 | ||
| * // max 92.0 192.0 | ||
| * }}} | ||
| * | ||
| * To do a summary for specific columns first select them: | ||
|
Contributor
This is a better usage than the previous one.
||
| * | ||
| * {{{ | ||
| * ds.select("age", "height").summary().show() | ||
| * }}} | ||
| * | ||
| * See also [[describe]] for basic statistics. | ||
| * | ||
| * @param statistics Statistics from above list to be computed. | ||
| * | ||
| * @group action | ||
| * @since 2.3.0 | ||
| */ | ||
| @scala.annotation.varargs | ||
| def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq) | ||
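The scaladoc above points to `agg` for programmatic use; a hedged sketch of what that could look like (the column name and the `percentile_approx` SQL call are assumptions, not part of this diff):

```scala
import org.apache.spark.sql.functions._

// Programmatic alternative to summary(): keeps native column types instead of
// summary's all-string schema. Assumes a DataFrame `ds` with a numeric `age` column.
val stats = ds.agg(
  count("age"), avg("age"), stddev("age"), min("age"), max("age"),
  expr("percentile_approx(age, array(0.25, 0.5, 0.75))").as("age_quartiles"))

stats.show(truncate = false)
```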
|
|
||
| /** | ||
| * Returns the first `n` rows. | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat | |
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} | ||
| import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.LocalRelation | ||
| import org.apache.spark.sql.catalyst.util.QuantileSummaries | ||
| import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
|
|
@@ -220,4 +221,97 @@ object StatFunctions extends Logging { | |
|
|
||
| Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) | ||
| } | ||
|
|
||
| /** Calculate selected summary statistics for a dataset */ | ||
| def summary[T](ds: Dataset[T], statistics: Seq[String]): DataFrame = { | ||
|
||
|
|
||
| val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") | ||
| val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics | ||
|
|
||
| val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) | ||
| val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { | ||
| val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) | ||
| val percentiles = pStrings.map { p => | ||
| try { | ||
| p.stripSuffix("%").toDouble / 100.0 | ||
| } catch { | ||
| case e: NumberFormatException => | ||
| throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) | ||
| } | ||
| } | ||
| require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") | ||
| (percentiles, pStrings, rest) | ||
| } else { | ||
| (Seq(), Seq(), selectedStatistics) | ||
| } | ||
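A standalone sketch of the percentile-string handling above: strip the trailing `%`, divide by 100, and validate the range (the helper name is mine, not from the patch):

```scala
def parsePercentile(p: String): Double = {
  // "25%" -> 0.25; invalid input surfaces as IllegalArgumentException.
  val fraction =
    try p.stripSuffix("%").toDouble / 100.0
    catch {
      case e: NumberFormatException =>
        throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
    }
  require(fraction >= 0 && fraction <= 1, "Percentiles must be in the range [0, 1]")
  fraction
}

parsePercentile("25%")     // 0.25
parsePercentile("50.01%")  // 0.5001
```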
|
|
||
|
|
||
| // The list of summary statistics to compute, in the form of expressions. | ||
| val availableStatistics = Map[String, Expression => Expression]( | ||
| "count" -> ((child: Expression) => Count(child).toAggregateExpression()), | ||
| "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), | ||
| "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), | ||
| "min" -> ((child: Expression) => Min(child).toAggregateExpression()), | ||
| "max" -> ((child: Expression) => Max(child).toAggregateExpression())) | ||
|
|
||
| val statisticFns = remainingAggregates.map { agg => | ||
| require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") | ||
| agg -> availableStatistics(agg) | ||
| } | ||
|
|
||
| def percentileAgg(child: Expression): Expression = | ||
| new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) | ||
| .toAggregateExpression() | ||
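For context, `ApproximatePercentile` is the aggregate behind Spark SQL's `percentile_approx` function; a hedged sketch of invoking it directly (the `ds` DataFrame and its `height` column are assumed):

```scala
// One array of approximate percentiles comes back per column, which is why the
// code below pulls the first result group out and transposes it into one
// output row per requested percentile.
ds.selectExpr("percentile_approx(height, array(0.25, 0.5, 0.75)) AS height_quartiles")
  .show(truncate = false)
```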
|
|
||
| val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList | ||
|
|
||
| val ret: Seq[Row] = if (outputCols.nonEmpty) { | ||
| var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => | ||
| outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) | ||
| } | ||
| if (hasPercentiles) { | ||
| aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs | ||
| } | ||
|
|
||
| val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq | ||
|
|
||
| // Pivot the data so each summary is one row | ||
| val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq | ||
|
|
||
| val basicStats = if (hasPercentiles) grouped.tail else grouped | ||
|
|
||
| val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => | ||
| Row(statistic :: aggregation.toList: _*) | ||
| } | ||
|
|
||
| if (hasPercentiles) { | ||
| def nullSafeString(x: Any) = if (x == null) null else x.toString | ||
| val percentileRows = grouped.head | ||
| .map { | ||
| case a: Seq[Any] => a | ||
| case _ => Seq.fill(percentiles.length)(null: Any) | ||
| } | ||
| .transpose | ||
| .zip(percentileNames) | ||
| .map { case (values: Seq[Any], name) => | ||
| Row(name :: values.map(nullSafeString).toList: _*) | ||
| } | ||
| (rows ++ percentileRows) | ||
| .sortWith((left, right) => | ||
| selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) | ||
| } else { | ||
| rows | ||
| } | ||
| } else { | ||
| // If there are no output columns, just output a single column that contains the stats. | ||
| selectedStatistics.map(Row(_)) | ||
| } | ||
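A self-contained sketch of the pivot step above, using made-up values: the single aggregate row comes back flat (statistic-major), `grouped(outputCols.size)` slices it into one sequence per statistic, and each slice becomes one output `Row`:

```scala
import org.apache.spark.sql.Row

val outputCols = Seq("name", "age", "height")
// Flat result of one agg() call: first the counts, then the means (illustrative values).
val flatRow: Seq[Any] = Seq("4", "4", "4", null, "33.0", "178.0")

val perStatistic = flatRow.grouped(outputCols.size).toSeq
// Seq(Seq("4", "4", "4"), Seq(null, "33.0", "178.0"))

val rows = perStatistic.zip(Seq("count", "mean")).map { case (values, stat) =>
  Row(stat :: values.toList: _*)
}
// Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0")
```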
|
|
||
| // All columns are string type | ||
| val schema = StructType( | ||
| StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes | ||
| // `toArray` forces materialization to make the seq serializable | ||
| Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,8 +28,7 @@ import org.scalatest.Matchers._ | |
|
|
||
| import org.apache.spark.SparkException | ||
| import org.apache.spark.sql.catalyst.TableIdentifier | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} | ||
| import org.apache.spark.sql.catalyst.util.DateTimeUtils | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} | ||
| import org.apache.spark.sql.execution.{FilterExec, QueryExecution} | ||
| import org.apache.spark.sql.execution.aggregate.HashAggregateExec | ||
| import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} | ||
|
|
@@ -664,11 +663,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { | |
| } | ||
|
|
||
| test("describe") { | ||
| val describeTestData = Seq( | ||
| ("Bob", 16, 176), | ||
| ("Alice", 32, 164), | ||
| ("David", 60, 192), | ||
| ("Amy", 24, 180)).toDF("name", "age", "height") | ||
| val describeTestData = person2 | ||
|
|
||
| val describeResult = Seq( | ||
| Row("count", "4", "4", "4"), | ||
|
|
@@ -712,6 +707,76 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { | |
| checkAnswer(emptyDescription, emptyDescribeResult) | ||
| } | ||
|
|
||
| test("summary") { | ||
| val describeTestData = person2 | ||
|
||
|
|
||
| val describeResult = Seq( | ||
|
||
| Row("count", "4", "4", "4"), | ||
| Row("mean", null, "33.0", "178.0"), | ||
| Row("stddev", null, "19.148542155126762", "11.547005383792516"), | ||
| Row("min", "Alice", "16", "164"), | ||
| Row("25%", null, "24.0", "176.0"), | ||
| Row("50%", null, "24.0", "176.0"), | ||
| Row("75%", null, "32.0", "180.0"), | ||
| Row("max", "David", "60", "192")) | ||
|
|
||
| val emptyDescribeResult = Seq( | ||
| Row("count", "0", "0", "0"), | ||
| Row("mean", null, null, null), | ||
| Row("stddev", null, null, null), | ||
| Row("min", null, null, null), | ||
| Row("25%", null, null, null), | ||
| Row("50%", null, null, null), | ||
| Row("75%", null, null, null), | ||
| Row("max", null, null, null)) | ||
|
|
||
| def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) | ||
|
|
||
| val describeTwoCols = describeTestData.summary() | ||
| assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) | ||
| checkAnswer(describeTwoCols, describeResult) | ||
| // All aggregate value should have been cast to string | ||
| describeTwoCols.collect().foreach { row => | ||
| assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) | ||
| assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) | ||
| } | ||
|
|
||
| val describeAllCols = describeTestData.summary() | ||
| assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) | ||
| checkAnswer(describeAllCols, describeResult) | ||
|
|
||
| val describeOneCol = describeTestData.select("age").summary() | ||
| assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) | ||
| checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) | ||
|
|
||
| val describeNoCol = describeTestData.select("name").summary() | ||
|
||
| assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) | ||
| checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) | ||
|
|
||
| val emptyDescription = describeTestData.limit(0).summary() | ||
| assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) | ||
| checkAnswer(emptyDescription, emptyDescribeResult) | ||
| } | ||
|
|
||
| test("summary advanced") { | ||
| val stats = Array("count", "50.01%", "max", "mean", "min", "25%") | ||
| val orderMatters = person2.summary(stats: _*) | ||
| assert(orderMatters.collect().map(_.getString(0)) === stats) | ||
|
|
||
| val onlyPercentiles = person2.summary("0.1%", "99.9%") | ||
| assert(onlyPercentiles.count() === 2) | ||
|
|
||
| val fooE = intercept[IllegalArgumentException] { | ||
| person2.summary("foo") | ||
| } | ||
| assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") | ||
|
|
||
| val parseE = intercept[IllegalArgumentException] { | ||
| person2.summary("foo%") | ||
| } | ||
| assert(parseE.getMessage === "Unable to parse foo% as a percentile") | ||
| } | ||
|
|
||
| test("apply on query results (SPARK-5462)") { | ||
| val df = testData.sparkSession.sql("select key from testData") | ||
| checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,6 +230,16 @@ private[sql] trait SQLTestData { self => | |
| df | ||
| } | ||
|
|
||
| protected lazy val person2: DataFrame = { | ||
|
||
| val df = spark.sparkContext.parallelize( | ||
| Person2("Bob", 16, 176) :: | ||
| Person2("Alice", 32, 164) :: | ||
| Person2("David", 60, 192) :: | ||
| Person2("Amy", 24, 180) :: Nil).toDF() | ||
| df.createOrReplaceTempView("person2") | ||
| df | ||
| } | ||
|
|
||
| protected lazy val salary: DataFrame = { | ||
| val df = spark.sparkContext.parallelize( | ||
| Salary(0, 2000.0) :: | ||
|
|
@@ -310,6 +320,7 @@ private[sql] object SQLTestData { | |
| case class NullStrings(n: Int, s: String) | ||
| case class TableName(tableName: String) | ||
| case class Person(id: Int, name: String, age: Int) | ||
| case class Person2(name: String, age: Int, height: Int) | ||
| case class Salary(personId: Int, salary: Double) | ||
| case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) | ||
| case class CourseSales(course: String, year: Int, earnings: Double) | ||
|
|
||
I'd also give an example of how to compute summary for specific columns.
and explain the difference between describe and summary (basically summary seems easier to use).
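A hedged sketch along the lines of both comments, with an assumed `ds` that has `age` and `height` columns: select the columns first to summarise only those, and note that `describe` is now just the fixed five-statistic subset of `summary`:

```scala
// Summary for specific columns: select them first.
ds.select("age", "height").summary("min", "25%", "75%", "max").show()

// describe vs summary: describe always returns count/mean/stddev/min/max,
// while summary adds approximate quartiles by default and lets the caller
// choose statistics, including arbitrary percentiles.
ds.describe("age").show()
ds.summary().show()
ds.summary("count", "75%", "max").show()
```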