134 changes: 124 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2185,9 +2185,9 @@ class Dataset[T] private[sql](
}

/**
* Computes statistics for numeric and string columns, including count, mean, stddev, min, and
* max. If no columns are given, this function computes statistics for all numerical or string
* columns.
* Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
* and max. If no columns are given, this function computes statistics for all numerical or
* string columns.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2205,37 +2205,151 @@ class Dataset[T] private[sql](
* // max 92.0 192.0
* }}}
*
* See also [[summary]]
*
* @param cols Columns to compute statistics on.
*
* @group action
* @since 1.6.0
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = withPlan {
def describe(cols: String*): DataFrame = {
val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
selected.summary("count", "mean", "stddev", "min", "max")
}
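
With this change, `describe(cols)` is just a projection followed by `summary` with the five basic statistics, so the two calls below should produce the same result; a quick sketch reusing the column names from the scaladoc example:

```scala
// describe(cols) selects the columns, then delegates to summary
// with the five basic statistics, so these are equivalent:
ds.describe("age", "height").show()
ds.select("age", "height").summary("count", "mean", "stddev", "min", "max").show()
```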

/**
* Computes specified statistics for numeric and string columns. Available statistics are:
Contributor: I'd also give an example of how to compute summary for specific columns.

Contributor: and explain the difference between describe and summary (basically summary seems easier to use).

*
* - count
* - mean
* - stddev
* - min
* - max
* - arbitrary approximate percentiles specified as a percentage (e.g. 75%)
*
* If no statistics are given, this function computes count, mean, stddev, min,
* approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
Contributor: approximate quartiles at 25%, 50% and 75%?

*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
* programmatically compute summary statistics, use the `agg` function instead.
*
* {{{
* ds.summary().show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 50% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
*
* {{{
* ds.summary("count", "min", "25%", "75%", "max").show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
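*
* To compute a summary for specific columns, first select them:
*
* {{{
*   ds.select("age", "height").summary().show()
* }}}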
*
* @param statistics Statistics from the list above to be computed.
*
* @group action
* @since 2.3.0
*/
@scala.annotation.varargs
def summary(statistics: String*): DataFrame = withPlan {
Contributor: can we move the implementation into org.apache.spark.sql.execution.stat.StatFunctions? I worry Dataset is getting too long. It should probably be mostly an interface / delegation and most of the implementations are elsewhere.
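
A minimal sketch of the delegation the reviewer is asking for, assuming a hypothetical StatFunctions.summary helper that would absorb the logic below (the name and signature are assumptions, not part of this diff):

```scala
// In org.apache.spark.sql.execution.stat.StatFunctions (hypothetical):
object StatFunctions {
  def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
    ??? // the aggregation and pivoting code currently inlined below
  }
}

// Dataset.summary then shrinks to pure delegation:
@scala.annotation.varargs
def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics)
```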


val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
val selectedStatistics = if (statistics.nonEmpty) statistics.toSeq else defaultStatistics

val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
val percentiles = pStrings.map { p =>
try {
p.stripSuffix("%").toDouble / 100.0
} catch {
case e: NumberFormatException =>
throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
}
}
require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
(percentiles, pStrings, rest)
} else {
(Seq(), Seq(), selectedStatistics)
}
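
The percentile branch above boils down to a small parser: strip the "%" suffix, divide by 100, and range-check the result. A self-contained sketch of the same steps:

```scala
def parsePercentile(s: String): Double = {
  val p = try {
    s.stripSuffix("%").toDouble / 100.0
  } catch {
    case e: NumberFormatException =>
      throw new IllegalArgumentException(s"Unable to parse $s as a percentile", e)
  }
  require(p >= 0 && p <= 1, "Percentiles must be in the range [0, 1]")
  p
}

parsePercentile("75%")     // 0.75
parsePercentile("50.01%")  // 0.5001
// parsePercentile("foo%") throws IllegalArgumentException
```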


// The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
val availableStatistics = Map[String, Expression => Expression](
"count" -> ((child: Expression) => Count(child).toAggregateExpression()),
"mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
"stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
"min" -> ((child: Expression) => Min(child).toAggregateExpression()),
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))

val outputCols =
(if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
val statisticFns = remainingAggregates.map { agg =>
require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
agg -> availableStatistics(agg)
}

def percentileAgg(child: Expression): Expression =
new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
.toAggregateExpression()
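
For comparison, the public DataFrameStatFunctions API computes the same kind of approximate percentiles without going through the internal expression; a roughly equivalent user-facing call (the column name and error tolerance are assumptions for the sketch, not what summary uses internally):

```scala
// Approximate 25th/50th/75th percentiles of the "age" column.
// 0.01 is the relative target error, an arbitrary choice for this sketch.
val quartiles: Array[Double] =
  df.stat.approxQuantile("age", Array(0.25, 0.5, 0.75), 0.01)
```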

val outputCols = aggregatableColumns.map(usePrettyExpression(_).sql).toList

val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}
if (hasPercentiles) {
aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
}

val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq

// Pivot the data so each summary is one row
row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq

val basicStats = if (hasPercentiles) grouped.tail else grouped

val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
Row(statistic :: aggregation.toList: _*)
}

if (hasPercentiles) {
def nullSafeString(x: Any) = if (x == null) null else x.toString
val percentileRows = grouped.head
.map {
case a: Seq[Any] => a
case _ => Seq.fill(percentiles.length)(null: Any)
}
.transpose
.zip(percentileNames)
.map { case (values: Seq[Any], name) =>
Row(name :: values.map(nullSafeString).toList: _*)
}
(rows ++ percentileRows)
.sortWith((left, right) =>
selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
} else {
rows
}
} else {
// If there are no output columns, just output a single column that contains the stats.
statistics.map { case (name, _) => Row(name) }
selectedStatistics.map(Row(_))
}

// All columns are string type
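The pivot at the end of summary is plain Scala collections work: grouped chops the single flat aggregation row into one slice per statistic, and transpose turns the per-column percentile arrays into one row per percentile. A self-contained sketch with two output columns (age, height), two plain statistics, and two percentiles:

```scala
// Flat aggregation row: [percentile array per column] ++ [count per column] ++ [min per column]
val row: Seq[Any] = Seq(
  Seq(24.0, 32.0),   // age:    25%, 75%
  Seq(176.0, 180.0), // height: 25%, 75%
  "4", "4",          // count:  age, height
  "16", "164")       // min:    age, height

val grouped = row.grouped(2).toSeq // slices of outputCols.size = 2
val basicStats = grouped.tail      // Seq(Seq("4", "4"), Seq("16", "164"))

val percentileRows = grouped.head
  .map { case a: Seq[Any] => a }   // each element is one column's percentile array
  .transpose                       // regroup per percentile across columns
  .zip(Seq("25%", "75%"))
// percentileRows: Seq((Seq(24.0, 176.0), "25%"), (Seq(32.0, 180.0), "75%"))
```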
79 changes: 72 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,8 +28,7 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -664,11 +663,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
("David", 60, 192),
("Amy", 24, 180)).toDF("name", "age", "height")
val describeTestData = person2

val describeResult = Seq(
Row("count", "4", "4", "4"),
@@ -712,6 +707,76 @@
checkAnswer(emptyDescription, emptyDescribeResult)
}

test("summary") {
val describeTestData = person2
Contributor: summaryTestData? Actually can we just use person2?


val describeResult = Seq(
Contributor: summaryResult?

Row("count", "4", "4", "4"),
Row("mean", null, "33.0", "178.0"),
Row("stddev", null, "19.148542155126762", "11.547005383792516"),
Row("min", "Alice", "16", "164"),
Row("25%", null, "24.0", "176.0"),
Row("50%", null, "24.0", "176.0"),
Row("75%", null, "32.0", "180.0"),
Row("max", "David", "60", "192"))

val emptyDescribeResult = Seq(
Row("count", "0", "0", "0"),
Row("mean", null, null, null),
Row("stddev", null, null, null),
Row("min", null, null, null),
Row("25%", null, null, null),
Row("50%", null, null, null),
Row("75%", null, null, null),
Row("max", null, null, null))

def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)

val describeTwoCols = describeTestData.summary()
assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeTwoCols, describeResult)
// All aggregate values should have been cast to string
describeTwoCols.collect().foreach { row =>
assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
}

val describeAllCols = describeTestData.summary()
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeAllCols, describeResult)

val describeOneCol = describeTestData.select("age").summary()
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d) })

val describeNoCol = describeTestData.select("name").summary()
Contributor: describeNoCol?

assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n) })

val emptyDescription = describeTestData.limit(0).summary()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}

test("summary advanced") {
val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
val orderMatters = person2.summary(stats: _*)
assert(orderMatters.collect().map(_.getString(0)) === stats)

val onlyPercentiles = person2.summary("0.1%", "99.9%")
assert(onlyPercentiles.count() === 2)

val fooE = intercept[IllegalArgumentException] {
person2.summary("foo")
}
assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")

val parseE = intercept[IllegalArgumentException] {
person2.summary("foo%")
}
assert(parseE.getMessage === "Unable to parse foo% as a percentile")
}
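
Both failures surface as IllegalArgumentException when summary is called (the plan is built eagerly inside withPlan), so a caller can validate input up front. A sketch against the behaviour this test pins down:

```scala
// Returns the summary DataFrame, or the validation message for bad input, e.g.
// "requirement failed: foo is not a recognised statistic" or
// "Unable to parse foo% as a percentile".
def trySummary(df: DataFrame, stats: String*): Either[String, DataFrame] =
  try Right(df.summary(stats: _*))
  catch { case e: IllegalArgumentException => Left(e.getMessage) }
```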

test("apply on query results (SPARK-5462)") {
val df = testData.sparkSession.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -230,6 +230,16 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val person2: DataFrame = {
Contributor: if it's only used in DataFrameSuite, can we put this in DataFrameSuite?

val df = spark.sparkContext.parallelize(
Person2("Bob", 16, 176) ::
Person2("Alice", 32, 164) ::
Person2("David", 60, 192) ::
Person2("Amy", 24, 180) :: Nil).toDF()
df.createOrReplaceTempView("person2")
df
}
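
Since the fixture also registers a person2 temp view, tests can reach the same rows through the catalog; a small usage sketch (the means follow from the four rows above: (16 + 32 + 60 + 24) / 4 = 33.0 and (176 + 164 + 192 + 180) / 4 = 178.0):

```scala
person2 // touch the lazy val so the view gets registered
spark.table("person2").summary("mean").show()
// summary  name  age   height
// mean     null  33.0  178.0
```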

protected lazy val salary: DataFrame = {
val df = spark.sparkContext.parallelize(
Salary(0, 2000.0) ::
@@ -310,6 +320,7 @@ private[sql] object SQLTestData {
case class NullStrings(n: Int, s: String)
case class TableName(tableName: String)
case class Person(id: Int, name: String, age: Int)
case class Person2(name: String, age: Int, height: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)