Skip to content
Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ object FunctionRegistry {
expression[Max]("max"),
expression[Average]("mean"),
expression[Min]("min"),
expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
expression[StddevSamp]("std"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.Countings
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET
import org.apache.spark.util.collection.OpenHashMap


/**
* The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
* the given percentage(s) with value range in [0.0, 1.0].
*
* The operator is bound to the slower sort based aggregation path because the number of elements
* and their partial order cannot be determined in advance. Therefore we have to store all the
* elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
* Errors.
*
* @param child child expression that produce numeric column value with `child.eval(inputRow)`
* @param percentageExpression Expression that represents a single percentage value or an array of
* percentage values. Each percentage value must be in the range
* [0.0, 1.0].
*/
@ExpressionDescription(
usage =
"""
_FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
given percentage. The value of percentage must be between 0.0 and 1.0.

_FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
of numeric column `col` at the given percentage(s). Each value of the percentage array must
be between 0.0 and 1.0.
""")
case class Percentile(
child: Expression,
percentageExpression: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Countings] {

def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, 0, 0)
}

override def prettyName: String = "percentile"

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile =
copy(inputAggBufferOffset = newInputAggBufferOffset)

// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
private lazy val (returnPercentileArray: Boolean, percentages: Seq[Number]) =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be problematic with serialization. Just put the percentages in a @transient lazy val and inline the use of returnPercentileArray.

evalPercentages(percentageExpression)

override def children: Seq[Expression] = child :: percentageExpression :: Nil

// Returns null for empty inputs
override def nullable: Boolean = true

override def dataType: DataType =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override lazy val dataType: DataType = percentageExpression.dataType match {
  case _: ArrayType => ArrayType(DoubleType, false)
  case _ => DoubleType
}

if (returnPercentileArray) ArrayType(DoubleType) else DoubleType

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return the type of the input. We can always interpolate the value and cast that to the input type. Is this is different from what Hive does?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HIVE could return double value or array of double values even the column dataType is integer, for example:

hive> insert into tbl values(1,2,5,10);
hive> insert into tbl values(1),(2),(5),(10);
hive> select percentile(a, array(0, 0.25, 0.5, 0.75, 1)) from tbl;
[1.0,1.75,3.5,6.25,10.0]


override def inputTypes: Seq[AbstractDataType] =
Seq(NumericType, TypeCollection(NumericType, ArrayType))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be the devil's advocate: Why don't we support AtomicTypes here? They are all orderable.

@hvanhovell hvanhovell Nov 17, 2016

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of the second argument should be FractionalType or even DoubleType. Can we also specify the type of the array (FractionalType)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should support AtomicTypes because not all of them could be converted to double, and we have to use double values to do interpolation operations.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't specify the type of the array(FractionalType), do we want to support DecimalType in the array?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Supporting NumericType does not really make sense for the percentage value. Use FractionalType instead.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW - you can make the analyzer add casts for you:

override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
  case _: ArrayType => Seq(NumericType, ArrayType(DoubleType, false))
  case _ => Seq(NumericType, DoubleType)
}

Then you are alway sure you get a double or a double array for the percentageExpression.


override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function percentile")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call super.checkInputDataTypes(), that will validate the inputTypes(). Also check the percentageExpression, that must foldable and the percentage(s) must be in the range [0, 1].


override def createAggregationBuffer(): Countings = {
// Initialize new Countings instance here.
Countings()
}

private def evalPercentages(expr: Expression): (Boolean, Seq[Number]) = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not return doubles?

val (isArrayType, values) = (expr.dataType, expr.eval()) match {
case (_, n: Number) => (false, Array(n))
case (_, d: Decimal) => (false, Array(d.toDouble.asInstanceOf[Number]))
case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
val numericArray = arrayData.toObjectArray(baseType)
(true, numericArray.map { x =>
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]).asInstanceOf[Number]
})
case other =>
throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
}

require(values.forall(value => value.doubleValue() >= 0.0 && value.doubleValue() <= 1.0),
s"Percentage values must be between 0.0 and 1.0, current values = ${values.mkString(", ")}")

(isArrayType, values)
}

override def update(buffer: Countings, input: InternalRow): Unit = {
val key = child.eval(input).asInstanceOf[Number]
buffer.add(key)
}

override def merge(buffer: Countings, other: Countings): Unit = {
buffer.merge(other)
}

override def eval(buffer: Countings): Any = {
generateOutput(buffer.getPercentiles(percentages))
}

private def generateOutput(results: Seq[Double]): Any = {
if (results.isEmpty) {
null
} else if (returnPercentileArray) {
new GenericArrayData(results)
} else {
results.head
}
}

override def serialize(obj: Countings): Array[Byte] = {
Percentile.serializer.serialize(obj, child.dataType)
}

override def deserialize(bytes: Array[Byte]): Countings = {
Percentile.serializer.deserialize(bytes, child.dataType)
}
}

object Percentile {
object Countings {
def apply(): Countings = Countings(new OpenHashMap[Number, Long])

def apply(counts: OpenHashMap[Number, Long]): Countings = new Countings(counts)
}

/**
* A class that stores the numbers and their counts, used to support [[Percentile]] function.
*/
class Countings(val counts: OpenHashMap[Number, Long]) extends Serializable {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this class and put its implementation in the Percentile Aggregate.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class TypedImperativeAggregate[T] requires access of this class, so perhaps we should keep it outside of the Percentile.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could entirely remove the class Countings.

/**
* Insert a key into countings map.
*/
def add(key: Number): Unit = {
// Null values are ignored in countings.
if (key != null) {
counts.changeValue(key, 1L, _ + 1L)
}
}

/**
* In place merges in another Countings.
*/
def merge(other: Countings): Unit = {
other.counts.foreach { pair =>
counts.changeValue(pair._1, pair._2, _ + pair._2)
}
}

/**
* Get the percentile value for every percentile in `percentages`.
*/
def getPercentiles(percentages: Seq[Number]): Seq[Double] = {
if (counts.isEmpty) {
return Seq.empty
}

val sortedCounts = counts.toSeq.sortBy(_._1)(new Ordering[Number]() {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use child.asInstanceOf[NumericType].ordering.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a dumb question: How can we order a sequence of Number using the Ordering[NumericType#InternalType] ?

@hvanhovell hvanhovell Nov 26, 2016

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could cast the ordering?

override def compare(a: Number, b: Number): Int =
scala.math.signum(a.doubleValue() - b.doubleValue()).toInt
})
val aggreCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just use an imperative loop.

(k1: (Number, Long), k2: (Number, Long)) => (k2._1, k1._2 + k2._2)
}.tail
val maxPosition = aggreCounts.last._2 - 1

percentages.map { percentile =>
getPercentile(aggreCounts, maxPosition * percentile.doubleValue()).doubleValue()
}
}

/**
* Get the percentile value.
*/
private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
// We may need to do linear interpolation to get the exact percentile
val lower = position.floor
val higher = position.ceil

// Linear search since this won't take much time from the total execution anyway

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't make it right :)... Anyway there are enough binarySearch implementations around. So maybe use one of those.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was taken from Hive UDAFPercentile. It is fine if you do that, but please acknowledge that you have done so by adding a line of documentation. See this for example: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala#L524

// lower has the range of [0 .. total-1]
// The first entry with accumulated count (lower+1) corresponds to the lower position.
var i = 0
while (aggreCounts(i)._2 < lower + 1) {
i += 1
}

val lowerKey = aggreCounts(i)._1
if (higher == lower) {
// no interpolation needed because position does not have a fraction
return lowerKey
}

if (aggreCounts(i)._2 < higher + 1) {
i += 1
}
val higherKey = aggreCounts(i)._1

if (higherKey == lowerKey) {
// no interpolation needed because lower position and higher position has the same key
return lowerKey
}

// Linear interpolation to get the exact percentile
return (higher - position) * lowerKey.doubleValue() +
(position - lower) * higherKey.doubleValue()
}
}


/**
* Serializer for class [[Countings]]
*
* This class is thread safe.
*/
class CountingsSerializer {

final def serialize(obj: Countings, dataType: DataType): Array[Byte] = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just put this in the Percentile class.

val counts = obj.counts

// Write the size of counts map.
val sizeProjection = UnsafeProjection.create(Array[DataType](IntegerType))
val row = InternalRow.apply(counts.size)
var buffer = sizeProjection.apply(row).getBytes

// Write the pairs of counts map.
val projection = UnsafeProjection.create(Array[DataType](dataType, LongType))
counts.foreach { pair =>
val row = InternalRow.apply(pair._1, pair._2)
val unsafeRow = projection.apply(row)
buffer ++= unsafeRow.getBytes

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is extremely expensive, because you are resizing the buffer for every entry. Please use a ByteArrayOutputStream and a DataOutputStream. See this for an example: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala#L226-L239

}

buffer
}

final def deserialize(bytes: Array[Byte], dataType: DataType): Countings = {
val counts = new OpenHashMap[Number, Long]
var offset = 0

// Read the size of counts map
val sizeRow = new UnsafeRow(1)
val rowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(1)
sizeRow.pointTo(bytes, rowSizeInBytes)
val size = sizeRow.get(0, IntegerType).asInstanceOf[Integer]
offset += rowSizeInBytes

// Read the pairs of counts map
val row = new UnsafeRow(2)
val pairRowSizeInBytes = UnsafeRow.calculateFixedPortionByteSize(2)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might cause an issue for a DecimalType, a decimal does not have to be fixed. I think we need to write out row sizes or not allow variable length keys. BTW if you only allow fixed length keys, you could get rid of UnsafeRows and projections and directly use a DataOutputStream.

var i = 0
while (i < size) {
row.pointTo(bytes, offset + BYTE_ARRAY_OFFSET, pairRowSizeInBytes)
val key = row.get(0, dataType).asInstanceOf[Number]
val count = row.get(1, LongType).asInstanceOf[Long]
offset += pairRowSizeInBytes
counts.update(key, count)
i += 1
}
Countings(counts)
}
}

val serializer: CountingsSerializer = new CountingsSerializer
}
Loading