-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-16282][SQL] Implement percentile SQL function. #14136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
15ba3c2
ef14aab
91ddabd
324483f
6bda505
a29d8b3
c7193d4
d21a104
8eebb6a
79a2b97
2ae7b48
8f24a9b
59a61cf
7ad1a35
93f8285
8a08576
7731066
4ace3bc
e01d0b2
b0aabf9
5b8cd4d
3c699ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,292 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings | ||
| import org.apache.spark.sql.catalyst.util._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET | ||
| import org.apache.spark.util.collection.OpenHashMap | ||
|
|
||
|
|
||
| /** | ||
| * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at | ||
| * the given percentage(s) with value range in [0.0, 1.0]. | ||
| * | ||
| * The operator is bound to the slower sort based aggregation path because the number of elements | ||
| * and their partial order cannot be determined in advance. Therefore we have to store all the | ||
| * elements in memory, and too many elements can cause GC pauses and eventually OutOfMemory | ||
| * Errors. | ||
| * | ||
| * @param child child expression that produces numeric column value with `child.eval(inputRow)` | ||
| * @param percentageExpression Expression that represents a single percentage value or an array of | ||
| * percentage values. Each percentage value must be in the range | ||
| * [0.0, 1.0]. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = | ||
| """ | ||
| _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the | ||
| given percentage. The value of percentage must be between 0.0 and 1.0. | ||
|
|
||
| _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array | ||
| of numeric column `col` at the given percentage(s). Each value of the percentage array must | ||
| be between 0.0 and 1.0. | ||
| """) | ||
| case class Percentile( | ||
| child: Expression, | ||
| percentageExpression: Expression, | ||
| mutableAggBufferOffset: Int = 0, | ||
| inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Countings] { | ||
|
|
||
| // Two-argument constructor: delegates to the primary constructor with both buffer offsets at 0. | ||
| def this(child: Expression, percentageExpression: Expression) = | ||
| this(child, percentageExpression, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) | ||
|
|
||
| /** Name reported for this expression: `percentile`. */ | ||
| override def prettyName: String = { | ||
| "percentile" | ||
| } | ||
|
|
||
| /** | ||
| * Returns a copy of this aggregate whose `mutableAggBufferOffset` is replaced by | ||
| * `newMutableAggBufferOffset`; every other field is carried over unchanged. | ||
| */ | ||
| override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile = { | ||
| val relocated = copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
| relocated | ||
| } | ||
|
|
||
| /** | ||
| * Returns a copy of this aggregate whose `inputAggBufferOffset` is replaced by | ||
| * `newInputAggBufferOffset`; every other field is carried over unchanged. | ||
| */ | ||
| override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile = { | ||
| val relocated = copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
| relocated | ||
| } | ||
|
|
||
| // Kept lazy so that percentageExpression is only evaluated on first use, and not while the | ||
| // expression tree is being transformed. The pair holds (percentages given as array?, values). | ||
| private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) = { | ||
| evalPercentages(percentageExpression) | ||
| } | ||
|
|
||
| /** Child expressions, in order: the value column first, then the percentage expression. */ | ||
| override def children: Seq[Expression] = Seq(child, percentageExpression) | ||
|
|
||
| // Always nullable: the result is null when there is no input. | ||
| override def nullable: Boolean = { | ||
| true | ||
| } | ||
|
|
||
| // Result type: an array of doubles when the percentages were supplied as an array (see | ||
| // returnPercentileArray), otherwise a single double. | ||
| override def dataType: DataType = | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. override lazy val dataType: DataType = percentageExpression.dataType match {
case _: ArrayType => ArrayType(DoubleType, false)
case _ => DoubleType
} |
||
| if (returnPercentileArray) ArrayType(DoubleType) else DoubleType | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should return the type of the input. We can always interpolate the value and cast that to the input type. Is this different from what Hive does?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hive can return a double value or an array of double values even if the column dataType is integer, for example: |
||
|
|
||
| // Expected argument types: a numeric column, then either a single numeric percentage or an | ||
| // array (whose element type is not constrained here). | ||
| override def inputTypes: Seq[AbstractDataType] = { | ||
| val percentageTypes = TypeCollection(NumericType, ArrayType) | ||
| Seq(NumericType, percentageTypes) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to be the devil's advocate: Why don't we support AtomicTypes here? They are all orderable.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type of the second argument should be
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we should support AtomicTypes because not all of them could be converted to double, and we have to use double values to do interpolation operations.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't specify the type of the array(
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Supporting
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW - you can make the analyzer add casts for you: override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
case _: ArrayType => Seq(NumericType, ArrayType(DoubleType, false))
case _ => Seq(NumericType, DoubleType)
} Then you are always sure you get a double or a double array for the |
||
|
|
||
| /** | ||
| * Validates the argument types of this function. | ||
| * | ||
| * The inherited check runs first and its failure, if any, is returned as-is. Before this | ||
| * change the override replaced the inherited check entirely, so only `child` was validated. | ||
| * NOTE(review): assumes the inherited implementation validates all children against | ||
| * `inputTypes` - confirm against the parent trait. | ||
| * | ||
| * @return the failure of the inherited check, otherwise the result of the numeric check on | ||
| * `child.dataType` | ||
| */ | ||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val defaultCheck = super.checkInputDataTypes() | ||
| if (defaultCheck.isFailure) { | ||
| defaultCheck | ||
| } else { | ||
| TypeUtils.checkForNumericExpr(child.dataType, "function percentile") | ||
| } | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Call |
||
|
|
||
| /** Creates a fresh, empty [[Countings]] instance to be used as the aggregation buffer. */ | ||
| override def createAggregationBuffer(): Countings = Countings() | ||
|
|
||
| // Evaluates the percentage expression once, without an input row, and normalizes the result. | ||
| // Returns a pair: (true if the percentages were given as an array, the percentage values). | ||
| // Throws AnalysisException when the evaluated value matches none of the supported shapes, and | ||
| // `require` fails when any value lies outside [0.0, 1.0]. | ||
| private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not return doubles? |
||
| val (isArrayType, values) = (expr.dataType, expr.eval()) match { | ||
| // Single percentage that already evaluates to a Number. | ||
| case (_, n: Number) => (false, Array(n)) | ||
| // Single Decimal percentage: converted to a double. | ||
| case (_, d: Decimal) => (false, Array(d.toDouble.asInstanceOf[Number])) | ||
| // Array of numeric percentages: every element is converted to a double. | ||
| case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => | ||
| val numericArray = arrayData.toObjectArray(baseType) | ||
| (true, numericArray.map { x => | ||
| baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]).asInstanceOf[Number] | ||
| }) | ||
| // Anything else (including a null result) is rejected; the message reports the data type. | ||
| case other => | ||
| throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") | ||
| } | ||
|
|
||
| require(values.forall(value => value.doubleValue() >= 0.0 && value.doubleValue() <= 1.0), | ||
| s"Percentage values must be between 0.0 and 1.0, current values = ${values.mkString(", ")}") | ||
|
|
||
| (isArrayType, values) | ||
| } | ||
|
|
||
| /** Evaluates `child` against `input` and hands the resulting value to `buffer.add`. */ | ||
| override def update(buffer: Countings, input: InternalRow): Unit = | ||
| buffer.add(child.eval(input).asInstanceOf[Number]) | ||
|
|
||
| /** Merges the contents of `other` into `buffer` in place. */ | ||
| override def merge(buffer: Countings, other: Countings): Unit = | ||
| buffer.merge(other) | ||
|
|
||
| /** Computes the requested percentiles from `buffer` and shapes them into the final result. */ | ||
| override def eval(buffer: Countings): Any = { | ||
| val percentileValues = buffer.getPercentiles(percentages) | ||
| generateOutput(percentileValues) | ||
| } | ||
|
|
||
| /** | ||
| * Shapes the computed percentiles into the value returned by `eval`: null when `results` is | ||
| * empty, a [[GenericArrayData]] holding all values when the percentages were given as an | ||
| * array, and the first value alone otherwise. | ||
| */ | ||
| private def generateOutput(results: Seq[Double]): Any = { | ||
| if (results.nonEmpty) { | ||
| if (returnPercentileArray) new GenericArrayData(results) else results.head | ||
| } else { | ||
| null | ||
| } | ||
| } | ||
|
|
||
| /** Serializes `obj` through the shared serializer, passing `child.dataType` as the key type. */ | ||
| override def serialize(obj: Countings): Array[Byte] = | ||
| Percentile.serializer.serialize(obj, child.dataType) | ||
|
|
||
| /** Rebuilds a [[Countings]] from `bytes` through the shared serializer, using `child.dataType`. */ | ||
| override def deserialize(bytes: Array[Byte]): Countings = | ||
| Percentile.serializer.deserialize(bytes, child.dataType) | ||
| } | ||
|
|
||
| object Percentile { | ||
| // Factory methods for the Countings class. | ||
| object Countings { | ||
| /** Wraps an existing value-to-count map in a Countings. */ | ||
| def apply(counts: OpenHashMap[Number, Long]): Countings = new Countings(counts) | ||
|
|
||
| /** Builds a Countings backed by a new, empty map. */ | ||
| def apply(): Countings = apply(new OpenHashMap[Number, Long]) | ||
| } | ||
|
|
||
| /** | ||
| * A class that stores the numbers and their counts, used to support [[Percentile]] function. | ||
| */ | ||
| class Countings(val counts: OpenHashMap[Number, Long]) extends Serializable { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove this class and put its implementation in the Percentile Aggregate.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The class
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could entirely remove the class |
||
| /** | ||
| * Records one occurrence of `key` in the counts map by calling `changeValue` with 1 as the | ||
| * initial count and `_ + 1` as the update function. A null `key` is skipped. | ||
| */ | ||
| def add(key: Number): Unit = key match { | ||
| // Null values take no part in the counts. | ||
| case null => () | ||
| case present => | ||
| counts.changeValue(present, 1L, _ + 1L) | ||
| () | ||
| } | ||
|
|
||
| /** | ||
| * Folds every (value, count) entry of `other` into this instance's map, in place. | ||
| * `other` itself is only read. | ||
| */ | ||
| def merge(other: Countings): Unit = { | ||
| other.counts.foreach { case (value, occurrences) => | ||
| counts.changeValue(value, occurrences, _ + occurrences) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Get the percentile value for every percentile in `percentages`. | ||
| * | ||
| * The entries are sorted by the double value of their key and turned into a list of | ||
| * (key, accumulated count) pairs. Each percentage is then mapped to the fractional position | ||
| * `(total count - 1) * percentage`, which is resolved by `getPercentile`. | ||
| * | ||
| * @param percentages the requested percentages | ||
| * @return one value per requested percentage, or an empty sequence when nothing was counted | ||
| */ | ||
| def getPercentiles(percentages: Seq[Number]): Seq[Double] = { | ||
| if (counts.isEmpty) { | ||
| return Seq.empty | ||
| } | ||
|
|
||
| val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a dumb question: How can we order a sequence of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could cast the ordering? |
||
| override def compare(a: Number, b: Number): Int = | ||
| scala.math.signum(a.doubleValue() - b.doubleValue()).toInt | ||
| }) | ||
| val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just use an imperative loop. |
||
| (k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2) | ||
| }.tail | ||
| val maxPosition = aggreCounts.last._2 - 1 | ||
|
|
||
| percentages.map { percentile => | ||
| getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue() | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Get the percentile value. | ||
| * | ||
| * @param aggreCounts (key, accumulated count) pairs; the caller passes them sorted by key | ||
| * @param position zero-based, possibly fractional position of the wanted value | ||
| * @return the key found at `position` when the position is integral or when its floor and | ||
| * ceiling resolve to the same key; otherwise the linear interpolation between the | ||
| * two neighbouring keys | ||
| */ | ||
| private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { | ||
| // We may need to do linear interpolation to get the exact percentile | ||
| val lower = position.floor | ||
| val higher = position.ceil | ||
|
|
||
| // Linear search since this won't take much time from the total execution anyway | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That doesn't make it right :)... Anyway there are enough binarySearch implementations around. So maybe use one of those.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was taken from Hive |
||
| // lower has the range of [0 .. total-1] | ||
| // The first entry with accumulated count (lower+1) corresponds to the lower position. | ||
| var i = 0 | ||
| while (aggreCounts(i)._2 < lower + 1) { | ||
| i += 1 | ||
| } | ||
|
|
||
| val lowerKey = aggreCounts(i)._1 | ||
| if (higher == lower) { | ||
| // no interpolation needed because position does not have a fraction | ||
| return lowerKey | ||
| } | ||
|
|
||
| if (aggreCounts(i)._2 < higher + 1) { | ||
| i += 1 | ||
| } | ||
| val higherKey = aggreCounts(i)._1 | ||
|
|
||
| if (higherKey == lowerKey) { | ||
| // no interpolation needed because the lower and higher positions have the same key | ||
| return lowerKey | ||
| } | ||
|
|
||
| // Linear interpolation to get the exact percentile | ||
| return (higher - position) * lowerKey.doubleValue() + | ||
| (position - lower) * higherKey.doubleValue() | ||
| } | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Serializer for class [[Countings]] | ||
| * | ||
| * This class is thread safe. | ||
| */ | ||
| class CountingsSerializer { | ||
|
|
||
| // Serializes `obj` into one byte array: first the bytes of a one-column row holding the number | ||
| // of entries, then, per entry, the bytes of a two-column row (key of type `dataType`, count). | ||
| final def serialize(obj: Countings, dataType: DataType): Array[Byte] = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just put this in the Percentile class. |
||
| val counts = obj.counts | ||
|
|
||
| // Write the size of the counts map as the first row. | ||
| val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType)) | ||
| val row = InternalRow.apply(counts.size) | ||
| var buffer = sizeProjection.apply(row).getBytes | ||
|
|
||
| // Write the (key, count) pairs of the counts map, one row per pair. | ||
| // NOTE(review): `++=` on an Array builds a new array on every iteration - consider a builder. | ||
| val projection = UnsafeProjection.create(Array[DataType](dataType, LongType)) | ||
| counts.foreach { pair => | ||
| val row = InternalRow.apply(pair._1, pair._2) | ||
| val unsafeRow = projection.apply(row) | ||
| buffer ++= unsafeRow.getBytes | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is extremely expensive, because you are resizing the buffer for every entry. Please use a |
||
| } | ||
|
|
||
| buffer | ||
| } | ||
|
|
||
| // Rebuilds a Countings from `bytes`: reads the entry count from a leading one-column row, then | ||
| // reads that many two-column (key of type `dataType`, count) rows that follow it. | ||
| final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = { | ||
| val counts = new OpenHashMap[Number, Long] | ||
| var offset = 0 | ||
|
|
||
| // Read the size of the counts map from the first row | ||
| val sizeRow = new UnsafeRow(1) | ||
| val rowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1) | ||
| sizeRow.pointTo(bytes, rowSizeInBytes) | ||
| val size = sizeRow.get(0, IntegerType).asInstanceOf[Integer] | ||
| offset += rowSizeInBytes | ||
|
|
||
| // Read the pairs of the counts map; `offset` advances by a fixed row size per pair | ||
| // NOTE(review): assumes every pair row is exactly the fixed-portion size - confirm for key | ||
| // types that may need a variable-length part. | ||
| val row = new UnsafeRow(2) | ||
| val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might cause an issue for a |
||
| var i = 0 | ||
| while (i < size) { | ||
| row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes) | ||
| val key = row.get(0, dataType).asInstanceOf[Number] | ||
| val count = row.get(1, LongType).asInstanceOf[Long] | ||
| offset += pairRowSizeInBytes | ||
| counts.update(key, count) | ||
| i += 1 | ||
| } | ||
| Countings(counts) | ||
| } | ||
| } | ||
|
|
||
| /** Single shared [[CountingsSerializer]] instance used by [[Percentile]]. */ | ||
| val serializer: CountingsSerializer = new CountingsSerializer() | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be problematic with serialization. Just put the percentages in a
`@transient lazy val` and inline the use of `returnPercentileArray`.