
Commit 365a29b

Zhenhua Wang authored and gatorsmile committed
[SPARK-22100][SQL] Make percentile_approx support date/timestamp type and change the output type to be the same as input type
## What changes were proposed in this pull request?

The `percentile_approx` function previously accepted numeric type input and produced double type results. But since all numeric types, as well as date and timestamp types, are represented as numerics internally, `percentile_approx` can support them easily. After this PR, it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. This change is also required when we generate equi-height histograms for these types.

## How was this patch tested?

Added a new test and modified some existing tests.

Author: Zhenhua Wang <[email protected]>

Closes apache#19321 from wzhfy/approx_percentile_support_types.
1 parent 20adf9a commit 365a29b
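As a quick illustration of the behavior described in the commit message (not part of this commit; the table and column names are made up), the result type of `percentile_approx` now follows the input column's type instead of always being double:

```scala
// Hypothetical spark-shell sketch, assuming `spark` is an active SparkSession.
import java.sql.Date
import spark.implicits._

Seq(
  (1, Date.valueOf("2017-01-01")),
  (2, Date.valueOf("2017-06-01")),
  (3, Date.valueOf("2017-12-31"))
).toDF("id", "d").createOrReplaceTempView("events")

val res = spark.sql(
  "SELECT percentile_approx(id, 0.5) AS p_id, percentile_approx(d, 0.5) AS p_d FROM events")
res.printSchema()
// Before this change both result columns were DoubleType; with it,
// p_id is IntegerType and p_d is DateType, matching the input columns.
```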

File tree

7 files changed (+70, -19 lines)


R/pkg/tests/fulltests/test_sparkSQL.R

Lines changed: 2 additions & 2 deletions
@@ -2538,14 +2538,14 @@ test_that("describe() and summary() on a DataFrame", {

   stats2 <- summary(df)
   expect_equal(collect(stats2)[5, "summary"], "25%")
-  expect_equal(collect(stats2)[5, "age"], "30.0")
+  expect_equal(collect(stats2)[5, "age"], "30")

   stats3 <- summary(df, "min", "max", "55.1%")

   expect_equal(collect(stats3)[1, "summary"], "min")
   expect_equal(collect(stats3)[2, "summary"], "max")
   expect_equal(collect(stats3)[3, "summary"], "55.1%")
-  expect_equal(collect(stats3)[3, "age"], "30.0")
+  expect_equal(collect(stats3)[3, "age"], "30")

   # SPARK-16425: SparkR summary() fails on column of type logical
   df <- withColumn(df, "boolean", df$age == 30)

docs/sql-programming-guide.md

Lines changed: 1 addition & 0 deletions
@@ -1553,6 +1553,7 @@ options.

 ## Upgrading From Spark SQL 2.2 to 2.3

 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+- The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.

 ## Upgrading From Spark SQL 2.1 to 2.2
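For code that collects the aggregate's result, the practical impact of this migration note looks roughly like the following sketch (illustrative only; `people` and `age` are hypothetical names, not from the patch):

```scala
// Illustrative only: `people` is assumed to be a registered view with an
// integer `age` column.
val q = spark.sql("SELECT percentile_approx(age, 0.5) FROM people")

// Spark 2.2: the result column was always DoubleType.
// val median = q.collect().head.getDouble(0)

// Spark 2.3: the result column has the input column's type (IntegerType here),
// so the value must be read accordingly.
val median: Int = q.collect().head.getInt(0)
```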

python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 5 deletions
@@ -1038,9 +1038,9 @@ def summary(self, *statistics):
     |   mean|               3.5| null|
     | stddev|2.1213203435596424| null|
     |    min|                 2|Alice|
-    |    25%|               5.0| null|
-    |    50%|               5.0| null|
-    |    75%|               5.0| null|
+    |    25%|                 5| null|
+    |    50%|                 5| null|
+    |    75%|                 5| null|
     |    max|                 5|  Bob|
     +-------+------------------+-----+

@@ -1050,8 +1050,8 @@ def summary(self, *statistics):
     +-------+---+-----+
     |  count|  2|    2|
     |    min|  2|Alice|
-    |    25%|5.0| null|
-    |    75%|5.0| null|
+    |    25%|  5| null|
+    |    75%|  5| null|
     |    max|  5|  Bob|
     +-------+---+-----+

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 29 additions & 4 deletions
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
   private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]

   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
+    // Support NumericType, DateType and TimestampType since their internal types are all numeric,
+    // and can be easily cast to double for processing.
+    Seq(TypeCollection(NumericType, DateType, TimestampType),
+      TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
   }

   // Mark as lazy so that percentageExpression is not evaluated during tree transformation.

@@ -123,7 +126,15 @@ case class ApproximatePercentile(
     val value = child.eval(inputRow)
     // Ignore empty rows, for example: percentile_approx(null)
     if (value != null) {
-      buffer.add(value.asInstanceOf[Double])
+      // Convert the value to a double value
+      val doubleValue = child.dataType match {
+        case DateType => value.asInstanceOf[Int].toDouble
+        case TimestampType => value.asInstanceOf[Long].toDouble
+        case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
+        case other: DataType =>
+          throw new UnsupportedOperationException(s"Unexpected data type $other")
+      }
+      buffer.add(doubleValue)
     }
     buffer
   }

@@ -134,7 +145,20 @@ case class ApproximatePercentile(
   }

   override def eval(buffer: PercentileDigest): Any = {
-    val result = buffer.getPercentiles(percentages)
+    val doubleResult = buffer.getPercentiles(percentages)
+    val result = child.dataType match {
+      case DateType => doubleResult.map(_.toInt)
+      case TimestampType => doubleResult.map(_.toLong)
+      case ByteType => doubleResult.map(_.toByte)
+      case ShortType => doubleResult.map(_.toShort)
+      case IntegerType => doubleResult.map(_.toInt)
+      case LongType => doubleResult.map(_.toLong)
+      case FloatType => doubleResult.map(_.toFloat)
+      case DoubleType => doubleResult
+      case _: DecimalType => doubleResult.map(Decimal(_))
+      case other: DataType =>
+        throw new UnsupportedOperationException(s"Unexpected data type $other")
+    }
     if (result.length == 0) {
       null
     } else if (returnPercentileArray) {

@@ -155,8 +179,9 @@ case class ApproximatePercentile(
   // Returns null for empty inputs
   override def nullable: Boolean = true

+  // The result type is the same as the input type.
   override def dataType: DataType = {
-    if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
+    if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }

   override def prettyName: String = "percentile_approx"
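The round trip that the new `update`/`eval` code performs is lossless for date values, and for timestamp values over any realistic range, because their internal representations (days and microseconds since the epoch) fit exactly in a double. A small standalone sketch of that round trip for `DateType` (not the Spark source, just the idea):

```scala
// Standalone sketch, not Spark code: DateType is stored internally as the
// number of days since 1970-01-01 (an Int). Widening that Int to Double for
// the percentile digest and narrowing the digest's result back is lossless.
import java.time.LocalDate

object DateRoundTrip extends App {
  val date = LocalDate.of(2017, 9, 22)
  val internalDays: Int = date.toEpochDay.toInt // what DateType stores internally
  val asDouble: Double = internalDays.toDouble  // what update() feeds to the digest
  val backToDays: Int = asDouble.toInt          // what eval() does with the result

  assert(backToDays == internalDays)
  assert(LocalDate.ofEpochDay(backToDays.toLong) == date)
}
```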

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala

Lines changed: 2 additions & 4 deletions
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}

@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       percentageExpression = percentageExpression,
       accuracyExpression = Literal(100))

-    val result = wrongPercentage.checkInputDataTypes()
     assert(
       wrongPercentage.checkInputDataTypes() match {
         case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true

@@ -281,7 +280,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
   test("class ApproximatePercentile, automatically add type casting for parameters") {
     val testRelation = LocalRelation('a.int)
-    val analyzer = SimpleAnalyzer

     // Compatible accuracy types: Long type and decimal type
     val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))

@@ -299,7 +297,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       analyzed match {
         case Alias(agg: ApproximatePercentile, _) =>
           assert(agg.resolved)
-          assert(agg.child.dataType == DoubleType)
+          assert(agg.child.dataType == IntegerType)
           assert(agg.percentageExpression.dataType == DoubleType ||
             agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
           assert(agg.accuracyExpression.dataType == IntegerType)

sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala

Lines changed: 28 additions & 1 deletion
@@ -17,8 +17,11 @@

 package org.apache.spark.sql

+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext

 /**

@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
     }
   }

+  test("percentile_approx, different column types") {
+    withTempView(table) {
+      val intSeq = 1 to 1000
+      val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
+        (new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
+      }
+      data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             | percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
+             | percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
+             | percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
+             |FROM $table
+           """.stripMargin),
+        Row(
+          Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
+            .map(i => new java.math.BigDecimal(i)),
+          Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
+          Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
+      )
+    }
+  }
+
   test("percentile_approx, multiple records with the minimum value in a partition") {
     withTempView(table) {
       spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")

@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
     val accuracies = Array(1, 10, 100, 1000, 10000)
     val errors = accuracies.map { accuracy =>
       val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
-      val approximatePercentile = df.collect().head.getDouble(0)
+      val approximatePercentile = df.collect().head.getInt(0)
       val error = Math.abs(approximatePercentile - expectedPercentile)
       error
     }
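The decimal expectations in the new test carry 18 decimal places, presumably because Spark infers its default decimal(38, 18) type for `java.math.BigDecimal` columns; the percentile itself is computed on doubles and wrapped back into a `Decimal`, as the `case _: DecimalType => doubleResult.map(Decimal(_))` branch above shows. A hedged sketch of that last step (the literal values are illustrative, not part of the patch):

```scala
// Sketch only: a Double produced by the percentile digest is wrapped back
// into a Spark Decimal; it compares equal to the test's expected BigDecimal
// regardless of scale.
import org.apache.spark.sql.types.Decimal

object DecimalWrap extends App {
  val fromDigest: Double = 250.0
  val asDecimal: Decimal = Decimal(fromDigest)
  val expected = new java.math.BigDecimal("250.000000000000000000")
  assert(asDecimal.toJavaBigDecimal.compareTo(expected) == 0)
}
```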

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("mean", null, "33.0", "178.0"),
       Row("stddev", null, "19.148542155126762", "11.547005383792516"),
       Row("min", "Alice", "16", "164"),
-      Row("25%", null, "24.0", "176.0"),
-      Row("50%", null, "24.0", "176.0"),
-      Row("75%", null, "32.0", "180.0"),
+      Row("25%", null, "24", "176"),
+      Row("50%", null, "24", "176"),
+      Row("75%", null, "32", "180"),
       Row("max", "David", "60", "192"))

     val emptySummaryResult = Seq(
