From 9b094c8c07f904b0e3fb99ddae18182d9667cdcc Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 14 Jun 2017 16:42:48 -0500 Subject: [PATCH 01/12] no message --- .../main/scala/org/apache/spark/sql/Dataset.scala | 15 ++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d28ff7888d12..41f7f148cd55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2185,9 +2185,9 @@ class Dataset[T] private[sql]( } /** - * Computes statistics for numeric and string columns, including count, mean, stddev, min, and - * max. If no columns are given, this function computes statistics for all numerical or string - * columns. + * Computes statistics for numeric and string columns, including count, mean, stddev, min, + * 25th, 50th and 75th percentiles and max. If no columns are given, this function computes + * statistics for all numerical or string columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -2202,6 +2202,9 @@ class Dataset[T] private[sql]( * // mean 53.3 178.05 * // stddev 11.6 15.7 * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 50% 24.0 176.0 + * // 75% 32.0 180.0 * // max 92.0 192.0 * }}} * @@ -2217,6 +2220,12 @@ class Dataset[T] private[sql]( "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "25%" -> ((child: Expression) => + new ApproximatePercentile(child, Literal(0.25)).toAggregateExpression()), + "50%" -> ((child: Expression) => + new ApproximatePercentile(child, Literal(0.5)).toAggregateExpression()), + "75%" -> ((child: Expression) => + new ApproximatePercentile(child, Literal(0.75)).toAggregateExpression()), "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9ea9951c24ef..02a31147019f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -675,6 +675,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), Row("min", "Alice", "16", "164"), + Row("25%", null, "24.0", "176.0"), + Row("50%", null, "24.0", "176.0"), + Row("75%", null, "32.0", "180.0"), Row("max", "David", "60", "192")) val emptyDescribeResult = Seq( @@ -682,6 +685,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, null, null), Row("stddev", null, null, null), Row("min", null, null, null), + Row("25%", null, null, null), + Row("50%", null, null, null), + Row("75%", null, null, null), Row("max", null, null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) From 459a6e1d46e2659218a3420ecfd1231db7690e65 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 15 Jun 2017 09:19:33 -0500 Subject: [PATCH 02/12] fix pyspark doctest and documentation --- python/pyspark/sql/dataframe.py | 8 
+++++++- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8541403dfe2f..1fd2dfa6857a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -926,7 +926,7 @@ def _sort_cols(self, cols, kwargs): def describe(self, *cols): """Computes statistics for numeric and string columns. - This include count, mean, stddev, min, and max. If no columns are + This include count, mean, stddev, min, approximate quartiles, and max. If no columns are given, this function computes statistics for all numerical or string columns. .. note:: This function is meant for exploratory data analysis, as we make no @@ -940,6 +940,9 @@ def describe(self, *cols): | mean| 3.5| | stddev|2.1213203435596424| | min| 2| + | 25%| 5.0| + | 50%| 5.0| + | 75%| 5.0| | max| 5| +-------+------------------+ >>> df.describe().show() @@ -950,6 +953,9 @@ def describe(self, *cols): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| + | 25%| 5.0| null| + | 50%| 5.0| null| + | 75%| 5.0| null| | max| 5| Bob| +-------+------------------+-----+ """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 41f7f148cd55..f482fc590418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2186,7 +2186,7 @@ class Dataset[T] private[sql]( /** * Computes statistics for numeric and string columns, including count, mean, stddev, min, - * 25th, 50th and 75th percentiles and max. If no columns are given, this function computes + * approximate quartiles, and max. If no columns are given, this function computes * statistics for all numerical or string columns. 
* * This function is meant for exploratory data analysis, as we make no guarantee about the From cf289f9cb83854f5e9ffb55caac6226eefe6d48f Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 15 Jun 2017 12:27:35 -0500 Subject: [PATCH 03/12] hopefully fix R tests --- R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index af529067f43e..50f0e88243c6 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2485,11 +2485,13 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "summary"], "min") - expect_equal(collect(stats)[5, "age"], "30") + expect_equal(collect(stats)[5, "summary"], "25%") + expect_equal(collect(stats)[5, "age"], "30.0") + expect_equal(collect(stats)[8, "age"], "30") stats2 <- summary(df) expect_equal(collect(stats2)[4, "summary"], "min") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[8, "age"], "30") # SPARK-16425: SparkR summary() fails on column of type logical df <- withColumn(df, "boolean", df$age == 30) From 4fe081d9dacf1468dee262ee257f7ff712b8cdf8 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 15 Jun 2017 17:24:05 -0500 Subject: [PATCH 04/12] WIP single agg & selectable percentiles --- .../scala/org/apache/spark/sql/Dataset.scala | 76 ++++++++++++++++--- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f482fc590418..2a882518002f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2212,7 +2212,40 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = withPlan { + def describe(cols: String*): DataFrame = describe(Array(0.25, 0.5, 0.75), cols: _*) + + /** + * Computes statistics for numeric and string columns, including count, mean, stddev, min, + * approximate quartiles, and max. If no columns are given, this function computes + * statistics for all numerical or string columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting Dataset. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * ds.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 50% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * @group action + * @since 1.6.0 + */ + @scala.annotation.varargs + def describe(percentiles: Array[Double], cols: String*): DataFrame = withPlan { + + val hasPercentiles = percentiles.length > 0 + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") // The list of summary statistics to compute, in the form of expressions. 
val statistics = List[(String, Expression => Expression)]( @@ -2220,28 +2253,53 @@ class Dataset[T] private[sql]( "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "25%" -> ((child: Expression) => - new ApproximatePercentile(child, Literal(0.25)).toAggregateExpression()), - "50%" -> ((child: Expression) => - new ApproximatePercentile(child, Literal(0.5)).toAggregateExpression()), - "75%" -> ((child: Expression) => - new ApproximatePercentile(child, Literal(0.75)).toAggregateExpression()), "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + def percentileAgg(child: Expression): Expression = + new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) + .toAggregateExpression() + + val percentileNames = percentiles.map(p => s"${(p * 100.0).round}%") + val outputCols = (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList val ret: Seq[Row] = if (outputCols.nonEmpty) { - val aggExprs = statistics.flatMap { case (_, colToAgg) => + var aggExprs = statistics.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } + if (hasPercentiles) { + aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + } val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + + val basicStats = if (hasPercentiles) grouped.tail else grouped + + val rows = basicStats.zip(statistics).map { case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) } + + if (hasPercentiles) { + def nullSafeString(x: Any) = if (x == null) null else x.toString + val percentileRows = grouped.head + .map { + case a: Seq[Any] => a + case _ => Seq.fill(percentiles.length)(null: Any) + } + .transpose + .zip(percentileNames) + .map { case (values: Seq[Any], name) => + Row(name :: values.map(nullSafeString).toList: _*) + } + val max :: rest = rows.reverse.toList + rest.reverse ++ percentileRows :+ max + } else { + rows + } } else { // If there are no output columns, just output a single column that contains the stats. 
statistics.map { case (name, _) => Row(name) } From 0d764fd955e9d72ad729f06d36e109955753d18d Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 16 Jun 2017 10:23:42 -0500 Subject: [PATCH 05/12] fix r tests --- R/pkg/tests/fulltests/test_sparkSQL.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 50f0e88243c6..e3fd6638f898 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2724,15 +2724,15 @@ test_that("attach() on a DataFrame", { expected_age <- data.frame(age = c(NA, 30, 19)) expect_equal(head(age), expected_age) stat <- summary(age) - expect_equal(collect(stat)[5, "age"], "30") + expect_equal(collect(stat)[8, "age"], "30") age <- age$age + 1 expect_is(age, "Column") rm(age) stat2 <- summary(age) - expect_equal(collect(stat2)[5, "age"], "30") + expect_equal(collect(stat2)[8, "age"], "30") detach("df") stat3 <- summary(df[, "age", drop = F]) - expect_equal(collect(stat3)[5, "age"], "30") + expect_equal(collect(stat3)[8, "age"], "30") expect_error(age) }) From 5d73422fa439f25af8681c267264879ffd6c4920 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 28 Jun 2017 16:45:47 -0500 Subject: [PATCH 06/12] Reimplement as describeExtended and describeAdvanced. Add new unit test. Also revert changes for other languages. --- R/pkg/tests/fulltests/test_sparkSQL.R | 12 +- python/pyspark/sql/dataframe.py | 5 +- .../scala/org/apache/spark/sql/Dataset.scala | 115 ++++++++++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 85 +++++++++++-- .../apache/spark/sql/test/SQLTestData.scala | 11 ++ 5 files changed, 181 insertions(+), 47 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index e3fd6638f898..af529067f43e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2485,13 +2485,11 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "summary"], "min") - expect_equal(collect(stats)[5, "summary"], "25%") - expect_equal(collect(stats)[5, "age"], "30.0") - expect_equal(collect(stats)[8, "age"], "30") + expect_equal(collect(stats)[5, "age"], "30") stats2 <- summary(df) expect_equal(collect(stats2)[4, "summary"], "min") - expect_equal(collect(stats2)[8, "age"], "30") + expect_equal(collect(stats2)[5, "age"], "30") # SPARK-16425: SparkR summary() fails on column of type logical df <- withColumn(df, "boolean", df$age == 30) @@ -2724,15 +2724,15 @@ test_that("attach() on a DataFrame", { expected_age <- data.frame(age = c(NA, 30, 19)) expect_equal(head(age), expected_age) stat <- summary(age) - expect_equal(collect(stat)[8, "age"], "30") + expect_equal(collect(stat)[5, "age"], "30") age <- age$age + 1 expect_is(age, "Column") rm(age) stat2 <- summary(age) - expect_equal(collect(stat2)[8, "age"], "30") + expect_equal(collect(stat2)[5, "age"], "30") detach("df") stat3 <- summary(df[, "age", drop = F]) - expect_equal(collect(stat3)[8, "age"], "30") + expect_equal(collect(stat3)[5, "age"], "30") expect_error(age) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1fd2dfa6857a..d4adeb78ba9f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -926,7 +926,7 @@ def _sort_cols(self, cols, kwargs): def describe(self, *cols): """Computes statistics for numeric and string
columns. - This include count, mean, stddev, min, approximate quartiles, and max. If no columns are + This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns. .. note:: This function is meant for exploratory data analysis, as we make no @@ -940,9 +940,6 @@ def describe(self, *cols): | mean| 3.5| | stddev|2.1213203435596424| | min| 2| - | 25%| 5.0| - | 50%| 5.0| - | 75%| 5.0| | max| 5| +-------+------------------+ >>> df.describe().show() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2a882518002f..ef36fc6dac94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.{Date, Timestamp} +import scala.collection.JavaConversions.asJavaCollection import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -39,13 +40,13 @@ import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, usePrettyExpression} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -2185,9 +2186,9 @@ class Dataset[T] private[sql]( } /** - * Computes statistics for numeric and string columns, including count, mean, stddev, min, - * approximate quartiles, and max. If no columns are given, this function computes - * statistics for all numerical or string columns. + * Computes basic statistics for numeric and string columns, including count, mean, stddev, min, + * and max. If no columns are given, this function computes statistics for all numerical or + * string columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -2202,17 +2203,19 @@ class Dataset[T] private[sql]( * // mean 53.3 178.05 * // stddev 11.6 15.7 * // min 18.0 163.0 - * // 25% 24.0 176.0 - * // 50% 24.0 176.0 - * // 75% 32.0 180.0 * // max 92.0 192.0 * }}} * + * See also [[describeExtended]] and [[describeAdvanced]] + * + * @param cols Columns to compute statistics on. 
+ * * @group action * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = describe(Array(0.25, 0.5, 0.75), cols: _*) + def describe(cols: String*): DataFrame = + describeAdvanced(Array("count", "mean", "stddev", "min", "max"), cols: _*) /** * Computes statistics for numeric and string columns, including count, mean, stddev, min, @@ -2224,7 +2227,7 @@ class Dataset[T] private[sql]( * programmatically compute summary statistics, use the `agg` function instead. * * {{{ - * ds.describe("age", "height").show() + * ds.describeExtended("age", "height").show() * * // output: * // summary age height @@ -2238,34 +2241,95 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * + * To specify which statistics or percentiles are desired see [[describeAdvanced]] + * + * @param cols Columns to compute statistics on. + * * @group action - * @since 1.6.0 + * @since 2.3.0 */ @scala.annotation.varargs - def describe(percentiles: Array[Double], cols: String*): DataFrame = withPlan { + def describeExtended(cols: String*): DataFrame = + describeAdvanced(Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max"), cols: _*) + + /** + * Computes specified statistics for numeric and string columns. Available statistics are: + * + * - count + * - mean + * - stddev + * - min + * - max + * - arbitrary approximate percentiles specifid as a percentage (eg, 75%) + * + * If no columns are given, this function computes statistics for all numerical or string + * columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting Dataset. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * ds.describeAdvanced(Array("count", "min", "25%", "75%", "max"), "age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // min 18.0 163.0 + * // 25% 24.0 176.0 + * // 75% 32.0 180.0 + * // max 92.0 192.0 + * }}} + * + * @param statistics Statistics from above list to be computed. + * @param cols Columns to compute statistics on. + * + * @group action + * @since 2.3.0 + */ + @scala.annotation.varargs + def describeAdvanced(statistics: Array[String], cols: String*): DataFrame = withPlan { + + val hasPercentiles = statistics.exists(_.endsWith("%")) + val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { + val (pStrings, rest) = statistics.toSeq.partition(a => a.endsWith("%")) + val percentiles = pStrings.map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a double", e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + (percentiles, pStrings, rest) + } else { + (Seq(), Seq(), statistics.toSeq) + } - val hasPercentiles = percentiles.length > 0 - require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") // The list of summary statistics to compute, in the form of expressions. 
- val statistics = List[(String, Expression => Expression)]( + val availableStatistics = Map[String, Expression => Expression]( "count" -> ((child: Expression) => Count(child).toAggregateExpression()), "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), "min" -> ((child: Expression) => Min(child).toAggregateExpression()), "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + val statisticFns = remainingAggregates.map { agg => + require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") + agg -> availableStatistics(agg) + } + def percentileAgg(child: Expression): Expression = new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) .toAggregateExpression() - val percentileNames = percentiles.map(p => s"${(p * 100.0).round}%") - val outputCols = (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList - val ret: Seq[Row] = if (outputCols.nonEmpty) { - var aggExprs = statistics.flatMap { case (_, colToAgg) => + val ret: Seq[Row] = if (outputCols.nonEmpty && statistics.nonEmpty) { + var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } if (hasPercentiles) { @@ -2279,7 +2343,7 @@ class Dataset[T] private[sql]( val basicStats = if (hasPercentiles) grouped.tail else grouped - val rows = basicStats.zip(statistics).map { case (aggregation, (statistic, _)) => + val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) } @@ -2295,14 +2359,17 @@ class Dataset[T] private[sql]( .map { case (values: Seq[Any], name) => Row(name :: values.map(nullSafeString).toList: _*) } - val max :: rest = rows.reverse.toList - rest.reverse ++ percentileRows :+ max + (rows ++ percentileRows) + .sortWith((left, right) => statistics.indexOf(left(0)) < statistics.indexOf(right(0))) } else { rows } - } else { + } else if (outputCols.isEmpty) { // If there are no output columns, just output a single column that contains the stats. 
- statistics.map { case (name, _) => Row(name) } + statistics.map(Row(_)) + } else { + // If there are no aggregates, return empty Seq + Seq() } // All columns are string type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 02a31147019f..a47ce8dd9725 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} @@ -664,11 +663,52 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("describe") { - val describeTestData = Seq( - ("Bob", 16, 176), - ("Alice", 32, 164), - ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + val describeTestData = person2 + + val describeResult = Seq( + Row("count", "4", "4", "4"), + Row("mean", null, "33.0", "178.0"), + Row("stddev", null, "19.148542155126762", "11.547005383792516"), + Row("min", "Alice", "16", "164"), + Row("max", "David", "60", "192")) + + val emptyDescribeResult = Seq( + Row("count", "0", "0", "0"), + Row("mean", null, null, null), + Row("stddev", null, null, null), + Row("min", null, null, null), + Row("max", null, null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("name", "age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + // All aggregate value should have been cast to string + describeTwoCols.collect().foreach { row => + assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) + assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) + } + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + + test("describeExtended") { + val describeTestData = person2 val describeResult = Seq( Row("count", "4", "4", "4"), @@ -692,7 +732,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = 
describeTestData.describe("name", "age", "height") + val describeTwoCols = describeTestData.describeExtended("name", "age", "height") assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeTwoCols, describeResult) // All aggregate value should have been cast to string @@ -701,23 +741,44 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) } - val describeAllCols = describeTestData.describe() + val describeAllCols = describeTestData.describeExtended() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) - val describeOneCol = describeTestData.describe("age") + val describeOneCol = describeTestData.describeExtended("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describe() + val describeNoCol = describeTestData.select("name").describeExtended() assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) - val emptyDescription = describeTestData.limit(0).describe() + val emptyDescription = describeTestData.limit(0).describeExtended() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } + test("describeAdvanced") { + val stats = Array("count", "50.01%", "max", "mean", "min", "25%") + val orderMatters = person2.describeAdvanced(stats) + assert(orderMatters.collect().map(_.getString(0)) === stats) + + val onlyPercentiles = person2.describeAdvanced(Array("0.1%", "99.9%")) + assert(onlyPercentiles.count() === 2) + + val fooE = intercept[IllegalArgumentException] { + person2.describeAdvanced(Array("foo")) + } + assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") + + val parseE = intercept[IllegalArgumentException] { + person2.describeAdvanced(Array("foo%")) + } + assert(parseE.getMessage === "Unable to parse foo% as a double") + + assert(person2.describeAdvanced(Array()).count() === 0) + } + test("apply on query results (SPARK-5462)") { val df = testData.sparkSession.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index f9b3ff840582..53b7214e07d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -230,6 +230,16 @@ private[sql] trait SQLTestData { self => df } + protected lazy val person2: DataFrame = { + val df = spark.sparkContext.parallelize( + Person2("Bob", 16, 176) :: + Person2("Alice", 32, 164) :: + Person2("David", 60, 192) :: + Person2("Amy", 24, 180) :: Nil).toDF() + df.createOrReplaceTempView("person2") + df + } + protected lazy val salary: DataFrame = { val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: @@ -310,6 +320,7 @@ private[sql] object SQLTestData { case class NullStrings(n: Int, s: String) case class TableName(tableName: String) case class Person(id: Int, name: String, age: Int) + case class Person2(name: String, age: Int, height: Int) case class Salary(personId: Int, 
salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double) From 6590f1a87f0e994998d83faca64f5dd8609f5b53 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 28 Jun 2017 16:51:59 -0500 Subject: [PATCH 07/12] missed revert --- python/pyspark/sql/dataframe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d4adeb78ba9f..8541403dfe2f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -950,9 +950,6 @@ def describe(self, *cols): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| - | 25%| 5.0| null| - | 50%| 5.0| null| - | 75%| 5.0| null| | max| 5| Bob| +-------+------------------+-----+ """ From f052665b15b3c3e8f6e7367c65e6bb4f6062340e Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 28 Jun 2017 16:54:27 -0500 Subject: [PATCH 08/12] revert changes to imports --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef36fc6dac94..41c7b63df995 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.{Date, Timestamp} -import scala.collection.JavaConversions.asJavaCollection import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -40,13 +39,13 @@ import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, usePrettyExpression} +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation From 32f3b828dccaadd4cd477d541c34006c7e420616 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 30 Jun 2017 14:46:26 -0500 Subject: [PATCH 09/12] summary --- .../scala/org/apache/spark/sql/Dataset.scala | 83 +++++++------------ .../org/apache/spark/sql/DataFrameSuite.scala | 24 +++--- 2 files changed, 43 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 41c7b63df995..7c2184df6dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2205,7 +2205,7 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * - * See also [[describeExtended]] and [[describeAdvanced]] + * See also [[summary]] * * 
@param cols Columns to compute statistics on. * @@ -2213,20 +2213,30 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = - describeAdvanced(Array("count", "mean", "stddev", "min", "max"), cols: _*) + def describe(cols: String*): DataFrame = { + val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*) + selected.summary("count", "mean", "stddev", "min", "max") + } /** - * Computes statistics for numeric and string columns, including count, mean, stddev, min, - * approximate quartiles, and max. If no columns are given, this function computes - * statistics for all numerical or string columns. + * Computes specified statistics for numeric and string columns. Available statistics are: + * + * - count + * - mean + * - stddev + * - min + * - max + * - arbitrary approximate percentiles specified as a percentage (eg, 75%) + * + * If no statistics are given, this function computes count, mean, stddev, min, + * approximate quartiles, and max. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to * programmatically compute summary statistics, use the `agg` function instead. * * {{{ - * ds.describeExtended("age", "height").show() + * ds.summary().show() * * // output: * // summary age height @@ -2240,36 +2250,8 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * - * To specify which statistics or percentiles are desired see [[describeAdvanced]] - * - * @param cols Columns to compute statistics on. - * - * @group action - * @since 2.3.0 - */ - @scala.annotation.varargs - def describeExtended(cols: String*): DataFrame = - describeAdvanced(Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max"), cols: _*) - - /** - * Computes specified statistics for numeric and string columns. Available statistics are: - * - * - count - * - mean - * - stddev - * - min - * - max - * - arbitrary approximate percentiles specifid as a percentage (eg, 75%) - * - * If no columns are given, this function computes statistics for all numerical or string - * columns. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting Dataset. If you want to - * programmatically compute summary statistics, use the `agg` function instead. - * * {{{ - * ds.describeAdvanced(Array("count", "min", "25%", "75%", "max"), "age", "height").show() + * ds.summary("count", "min", "25%", "75%", "max").show() * * // output: * // summary age height @@ -2281,29 +2263,31 @@ class Dataset[T] private[sql]( * }}} * * @param statistics Statistics from above list to be computed. - * @param cols Columns to compute statistics on. 
* * @group action * @since 2.3.0 */ @scala.annotation.varargs - def describeAdvanced(statistics: Array[String], cols: String*): DataFrame = withPlan { + def summary(statistics: String*): DataFrame = withPlan { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics.toSeq else defaultStatistics - val hasPercentiles = statistics.exists(_.endsWith("%")) + val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { - val (pStrings, rest) = statistics.toSeq.partition(a => a.endsWith("%")) + val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) val percentiles = pStrings.map { p => try { p.stripSuffix("%").toDouble / 100.0 } catch { case e: NumberFormatException => - throw new IllegalArgumentException(s"Unable to parse $p as a double", e) + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) } } require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") (percentiles, pStrings, rest) } else { - (Seq(), Seq(), statistics.toSeq) + (Seq(), Seq(), selectedStatistics) } @@ -2324,10 +2308,9 @@ class Dataset[T] private[sql]( new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) .toAggregateExpression() - val outputCols = - (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList + val outputCols = aggregatableColumns.map(usePrettyExpression(_).sql).toList - val ret: Seq[Row] = if (outputCols.nonEmpty && statistics.nonEmpty) { + val ret: Seq[Row] = if (outputCols.nonEmpty) { var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } @@ -2359,16 +2342,14 @@ class Dataset[T] private[sql]( Row(name :: values.map(nullSafeString).toList: _*) } (rows ++ percentileRows) - .sortWith((left, right) => statistics.indexOf(left(0)) < statistics.indexOf(right(0))) + .sortWith((left, right) => + selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) } else { rows } - } else if (outputCols.isEmpty) { - // If there are no output columns, just output a single column that contains the stats. - statistics.map(Row(_)) } else { - // If there are no aggregates, return empty Seq - Seq() + // If there are no output columns, just output a single column that contains the stats. 
+ selectedStatistics.map(Row(_)) } // All columns are string type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a47ce8dd9725..1b84d98ee681 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -707,7 +707,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(emptyDescription, emptyDescribeResult) } - test("describeExtended") { + test("summary") { val describeTestData = person2 val describeResult = Seq( @@ -732,7 +732,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describeExtended("name", "age", "height") + val describeTwoCols = describeTestData.summary() assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeTwoCols, describeResult) // All aggregate value should have been cast to string @@ -741,42 +741,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) } - val describeAllCols = describeTestData.describeExtended() + val describeAllCols = describeTestData.summary() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) - val describeOneCol = describeTestData.describeExtended("age") + val describeOneCol = describeTestData.select("age").summary() assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describeExtended() + val describeNoCol = describeTestData.select("name").summary() assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) - val emptyDescription = describeTestData.limit(0).describeExtended() + val emptyDescription = describeTestData.limit(0).summary() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } - test("describeAdvanced") { + test("summary advanced") { val stats = Array("count", "50.01%", "max", "mean", "min", "25%") - val orderMatters = person2.describeAdvanced(stats) + val orderMatters = person2.summary(stats: _*) assert(orderMatters.collect().map(_.getString(0)) === stats) - val onlyPercentiles = person2.describeAdvanced(Array("0.1%", "99.9%")) + val onlyPercentiles = person2.summary("0.1%", "99.9%") assert(onlyPercentiles.count() === 2) val fooE = intercept[IllegalArgumentException] { - person2.describeAdvanced(Array("foo")) + person2.summary("foo") } assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic") val parseE = intercept[IllegalArgumentException] { - person2.describeAdvanced(Array("foo%")) + person2.summary("foo%") } assert(parseE.getMessage === "Unable to parse foo% as a double") - - assert(person2.describeAdvanced(Array()).count() === 0) } test("apply on query results (SPARK-5462)") { From cba1b0e6c3ac32f7cb327ead54f6e8307aed00ac Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 30 Jun 2017 21:15:47 -0500 Subject: [PATCH 10/12] fix test --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1b84d98ee681..8aff51515baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -774,7 +774,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val parseE = intercept[IllegalArgumentException] { person2.summary("foo%") } - assert(parseE.getMessage === "Unable to parse foo% as a double") + assert(parseE.getMessage === "Unable to parse foo% as a percentile") } test("apply on query results (SPARK-5462)") { From 38ec8ddba96ec2c26c9436571c6581ab489499d4 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 5 Jul 2017 16:13:22 -0500 Subject: [PATCH 11/12] move impl to StatFunctions and add revise doc --- .../scala/org/apache/spark/sql/Dataset.scala | 107 +++--------------- .../sql/execution/stat/StatFunctions.scala | 98 +++++++++++++++- 2 files changed, 109 insertions(+), 96 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7c2184df6dad..61146539decb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -224,7 +224,7 @@ class Dataset[T] private[sql]( } } - private def aggregatableColumns: Seq[Expression] = { + private[sql] def aggregatableColumns: Seq[Expression] = { schema.fields .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) .map { n => @@ -2205,7 +2205,7 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * - * See also [[summary]] + * Use [[summary]] for expanded statistics and control over which statistics to compute. * * @param cols Columns to compute statistics on. * @@ -2262,102 +2262,21 @@ class Dataset[T] private[sql]( * // max 92.0 192.0 * }}} * + * To do a summary for specific columns first select them: + * + * {{{ + * ds.select("age", "height").summary().show() + * }}} + * + * See also [[describe]] for basic statistics. + * * @param statistics Statistics from above list to be computed. 
* * @group action * @since 2.3.0 */ @scala.annotation.varargs - def summary(statistics: String*): DataFrame = withPlan { - - val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") - val selectedStatistics = if (statistics.nonEmpty) statistics.toSeq else defaultStatistics - - val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) - val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { - val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) - val percentiles = pStrings.map { p => - try { - p.stripSuffix("%").toDouble / 100.0 - } catch { - case e: NumberFormatException => - throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) - } - } - require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - (percentiles, pStrings, rest) - } else { - (Seq(), Seq(), selectedStatistics) - } - - - // The list of summary statistics to compute, in the form of expressions. - val availableStatistics = Map[String, Expression => Expression]( - "count" -> ((child: Expression) => Count(child).toAggregateExpression()), - "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), - "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), - "min" -> ((child: Expression) => Min(child).toAggregateExpression()), - "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - - val statisticFns = remainingAggregates.map { agg => - require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") - agg -> availableStatistics(agg) - } - - def percentileAgg(child: Expression): Expression = - new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) - .toAggregateExpression() - - val outputCols = aggregatableColumns.map(usePrettyExpression(_).sql).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - if (hasPercentiles) { - aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs - } - - val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq - - // Pivot the data so each summary is one row - val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq - - val basicStats = if (hasPercentiles) grouped.tail else grouped - - val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } - - if (hasPercentiles) { - def nullSafeString(x: Any) = if (x == null) null else x.toString - val percentileRows = grouped.head - .map { - case a: Seq[Any] => a - case _ => Seq.fill(percentiles.length)(null: Any) - } - .transpose - .zip(percentileNames) - .map { case (values: Seq[Any], name) => - Row(name :: values.map(nullSafeString).toList: _*) - } - (rows ++ percentileRows) - .sortWith((left, right) => - selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) - } else { - rows - } - } else { - // If there are no output columns, just output a single column that contains the stats. 
- selectedStatistics.map(Row(_)) - } - - // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - // `toArray` forces materialization to make the seq serializable - LocalRelation.fromExternalRows(schema, ret.toArray.toSeq) - } + def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq) /** * Returns the first `n` rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 1debad03c93f..7f772e1f64f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -220,4 +221,97 @@ object StatFunctions extends Logging { Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } + + /** Calculate selected summary statistics for a dataset */ + def summary[T](ds: Dataset[T], statistics: Seq[String]): DataFrame = { + + val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + + val hasPercentiles = selectedStatistics.exists(_.endsWith("%")) + val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) { + val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%")) + val percentiles = pStrings.map { p => + try { + p.stripSuffix("%").toDouble / 100.0 + } catch { + case e: NumberFormatException => + throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e) + } + } + require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") + (percentiles, pStrings, rest) + } else { + (Seq(), Seq(), selectedStatistics) + } + + + // The list of summary statistics to compute, in the form of expressions. 
+ val availableStatistics = Map[String, Expression => Expression]( + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + + val statisticFns = remainingAggregates.map { agg => + require(availableStatistics.contains(agg), s"$agg is not a recognised statistic") + agg -> availableStatistics(agg) + } + + def percentileAgg(child: Expression): Expression = + new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_)))) + .toAggregateExpression() + + val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) + } + if (hasPercentiles) { + aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs + } + + val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq + + val basicStats = if (hasPercentiles) grouped.tail else grouped + + val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + + if (hasPercentiles) { + def nullSafeString(x: Any) = if (x == null) null else x.toString + val percentileRows = grouped.head + .map { + case a: Seq[Any] => a + case _ => Seq.fill(percentiles.length)(null: Any) + } + .transpose + .zip(percentileNames) + .map { case (values: Seq[Any], name) => + Row(name :: values.map(nullSafeString).toList: _*) + } + (rows ++ percentileRows) + .sortWith((left, right) => + selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0))) + } else { + rows + } + } else { + // If there are no output columns, just output a single column that contains the stats. + selectedStatistics.map(Row(_)) + } + + // All columns are string type + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes + // `toArray` forces materialization to make the seq serializable + Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)) + } + } From 3b548cc3d5ad8928785fe644db9ea788dfb8fad2 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 6 Jul 2017 09:45:31 -0500 Subject: [PATCH 12/12] adress pr comments --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 79 ++++++++++--------- .../apache/spark/sql/test/SQLTestData.scala | 11 --- 4 files changed, 42 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 61146539decb..4060b3e923f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2229,7 +2229,7 @@ class Dataset[T] private[sql]( * - arbitrary approximate percentiles specified as a percentage (eg, 75%) * * If no statistics are given, this function computes count, mean, stddev, min, - * approximate quartiles, and max. 
+ * approximate quartiles (percentiles at 25%, 50%, and 75%), and max. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 7f772e1f64f1..436e18fdb5ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -223,7 +223,7 @@ object StatFunctions extends Logging { } /** Calculate selected summary statistics for a dataset */ - def summary[T](ds: Dataset[T], statistics: Seq[String]): DataFrame = { + def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8aff51515baf..2c7051bf431c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -662,9 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("describe") { - val describeTestData = person2 + private lazy val person2: DataFrame = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + test("describe") { val describeResult = Seq( Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0"), @@ -681,36 +685,33 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describe("name", "age", "height") - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeTwoCols, describeResult) - // All aggregate value should have been cast to string - describeTwoCols.collect().foreach { row => - assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) - assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) - } - - val describeAllCols = describeTestData.describe() + val describeAllCols = person2.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) + // All aggregate value should have been cast to string + describeAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } + } - val describeOneCol = describeTestData.describe("age") + val describeOneCol = person2.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + val describeNoCol = person2.select().describe() + 
assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} ) - val emptyDescription = describeTestData.limit(0).describe() + val emptyDescription = person2.limit(0).describe() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) } test("summary") { - val describeTestData = person2 - - val describeResult = Seq( + val summaryResult = Seq( Row("count", "4", "4", "4"), Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), @@ -720,7 +721,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("75%", null, "32.0", "180.0"), Row("max", "David", "60", "192")) - val emptyDescribeResult = Seq( + val emptySummaryResult = Seq( Row("count", "0", "0", "0"), Row("mean", null, null, null), Row("stddev", null, null, null), @@ -732,30 +733,30 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.summary() - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeTwoCols, describeResult) + val summaryAllCols = person2.summary() + + assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height")) + checkAnswer(summaryAllCols, summaryResult) // All aggregate value should have been cast to string - describeTwoCols.collect().foreach { row => - assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) - assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) + summaryAllCols.collect().foreach { row => + row.toSeq.foreach { value => + if (value != null) { + assert(value.isInstanceOf[String], "expected string but found " + value.getClass) + } + } } - val describeAllCols = describeTestData.summary() - assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) - checkAnswer(describeAllCols, describeResult) - - val describeOneCol = describeTestData.select("age").summary() - assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) - checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) + val summaryOneCol = person2.select("age").summary() + assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age")) + checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} ) - val describeNoCol = describeTestData.select("name").summary() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) + val summaryNoCol = person2.select().summary() + assert(getSchemaAsSeq(summaryNoCol) === Seq("summary")) + checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} ) - val emptyDescription = describeTestData.limit(0).summary() + val emptyDescription = person2.limit(0).summary() assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) - checkAnswer(emptyDescription, emptyDescribeResult) + checkAnswer(emptyDescription, emptySummaryResult) } test("summary advanced") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 53b7214e07d0..f9b3ff840582 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -230,16 +230,6 @@ private[sql] trait SQLTestData { self => df } - protected lazy val person2: DataFrame = { - val df = spark.sparkContext.parallelize( - Person2("Bob", 16, 176) :: - Person2("Alice", 32, 164) :: - Person2("David", 60, 192) :: - Person2("Amy", 24, 180) :: Nil).toDF() - df.createOrReplaceTempView("person2") - df - } - protected lazy val salary: DataFrame = { val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: @@ -320,7 +310,6 @@ private[sql] object SQLTestData { case class NullStrings(n: Int, s: String) case class TableName(tableName: String) case class Person(id: Int, name: String, age: Int) - case class Person2(name: String, age: Int, height: Int) case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double)
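The series leaves `describe(cols: String*)` with its original five statistics and adds `summary(statistics: String*)`, backed by `StatFunctions.summary`, for quartiles and arbitrary percentiles. Below is a minimal sketch of exercising the resulting API from spark-shell on a build that includes these patches; the sample rows mirror the person2 fixture used in DataFrameSuite, and `spark` is the session provided by the shell.

// Paste into spark-shell; `spark` is already in scope there.
import spark.implicits._

val people = Seq(
  ("Bob", 16, 176),
  ("Alice", 32, 164),
  ("David", 60, 192),
  ("Amy", 24, 180)).toDF("name", "age", "height")

// describe() keeps the historical five statistics: count, mean, stddev, min, max.
people.describe("age", "height").show()

// summary() with no arguments adds the approximate quartiles (25%, 50%, 75%).
people.summary().show()

// Statistics, including arbitrary percentiles, can be requested explicitly.
// summary() takes no column arguments, so restrict the columns with select() first.
people.select("age", "height").summary("count", "min", "25%", "75%", "max").show()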