Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,12 @@ object HiveTypeCoercion {

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
Expand All @@ -44,186 +42,169 @@ import org.apache.spark.sql.types._
*
* @param child to compute central moments of.
*/
abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {

/**
* The central moment order to be computed.
*/
protected def momentOrder: Int

override def children: Seq[Expression] = Seq(child)

override def nullable: Boolean = true

override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
protected val n = AttributeReference("n", DoubleType, nullable = false)()
protected val avg = AttributeReference("avg", DoubleType, nullable = false)()
protected val m2 = AttributeReference("m2", DoubleType, nullable = false)()
protected val m3 = AttributeReference("m3", DoubleType, nullable = false)()
protected val m4 = AttributeReference("m4", DoubleType, nullable = false)()

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
// Keeps only the leading `momentOrder + 1` entries of `expressions`.
private def trimHigherOrder[T](expressions: Seq[T]) = expressions.slice(0, momentOrder + 1)

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4))

/**
* Size of aggregation buffer.
*/
private[this] val bufferSize = 5
// One zero literal per retained buffer slot (`momentOrder + 1` of them). Built directly as a
// Seq so the declared Seq[Expression] type does not depend on an implicit Array-to-Seq
// conversion.
override val initialValues: Seq[Expression] = Seq.fill(momentOrder + 1)(Literal(0.0))

override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
AttributeReference(s"M$i", DoubleType)()
/**
 * Expressions producing the next value of each buffer slot after folding in `child`.
 * A null `child` leaves every slot unchanged.
 */
override val updateExpressions: Seq[Expression] = {
  val nNext = n + Literal(1.0)
  val diff = child - avg
  val shift = diff / nNext
  val avgNext = avg + shift
  val m2Next = m2 + diff * (diff - shift)

  val diffSq = diff * diff
  val shiftSq = shift * shift

  // Slots above `momentOrder` are removed by trimHigherOrder below, so a zero literal
  // stands in for them rather than the full update expression.
  val m3Next =
    if (momentOrder < 3) Literal(0.0)
    else m3 - Literal(3.0) * shift * m2Next + diff * (diffSq - shiftSq)
  val m4Next =
    if (momentOrder < 4) {
      Literal(0.0)
    } else {
      m4 - Literal(4.0) * shift * m3Next - Literal(6.0) * shiftSq * m2Next +
        diff * (diff * diffSq - shift * shiftSq)
    }

  // Pair each current slot with its candidate replacement and guard on a null input.
  val current = Seq(n, avg, m2, m3, m4)
  val candidates = Seq(nNext, avgNext, m2Next, m3Next, m4Next)
  trimHigherOrder(current.zip(candidates).map { case (old, next) =>
    If(IsNull(child), old, next)
  })
}

// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

// buffer offsets
private[this] val nOffset = mutableAggBufferOffset
private[this] val meanOffset = mutableAggBufferOffset + 1
private[this] val secondMomentOffset = mutableAggBufferOffset + 2
private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
private[this] val fourthMomentOffset = mutableAggBufferOffset + 4

// frequently used values for online updates
private[this] var delta = 0.0
private[this] var deltaN = 0.0
private[this] var delta2 = 0.0
private[this] var deltaN2 = 0.0
private[this] var n = 0.0
private[this] var mean = 0.0
private[this] var m2 = 0.0
private[this] var m3 = 0.0
private[this] var m4 = 0.0
override val mergeExpressions: Seq[Expression] = {

/**
* Initialize all moments to zero.
*/
override def initialize(buffer: MutableRow): Unit = {
for (aggIndex <- 0 until bufferSize) {
buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
val n1 = n.left
val n2 = n.right
val newN = n1 + n2
val delta = avg.right - avg.left
val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
val newAvg = avg.left + deltaN * n2

// higher order moments computed according to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2
// `m3.right` is not available if momentOrder < 3
val newM3 = if (momentOrder >= 3) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a Cast() here is very expensive

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if branch is not necessary because we will trim the result at the end. Shall we follow the approach in updateExpression and remove it? It is still useful to leave a comment here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These expressions require m3.right, which is not valid if momentOrder < 3; I will add a comment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then shall we make updateExpression follow the same style?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updateExpression does not need that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that the if branch is not necessary in updateExpression. This is just to make the code more consistent. For example, once I had figured out why it is not necessary in updateExpression, I was confused about why we use the if branches here. The logic would be clearer if we used if branches in both methods.

m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) +
Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left)
} else {
Literal(0.0)
}
// `m4.right` is not available if momentOrder < 4
val newM4 = if (momentOrder >= 4) {
m4.left + m4.right +
deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) +
Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) +
Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left)
} else {
Literal(0.0)
}

trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
}
}

/**
* Update the central moments buffer.
*/
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = Cast(child, DoubleType).eval(input)
if (v != null) {
val updateValue = v match {
case d: Double => d
}

n = buffer.getDouble(nOffset)
mean = buffer.getDouble(meanOffset)

n += 1.0
buffer.setDouble(nOffset, n)
delta = updateValue - mean
deltaN = delta / n
mean += deltaN
buffer.setDouble(meanOffset, mean)

if (momentOrder >= 2) {
m2 = buffer.getDouble(secondMomentOffset)
m2 += delta * (delta - deltaN)
buffer.setDouble(secondMomentOffset, m2)
}

if (momentOrder >= 3) {
delta2 = delta * delta
deltaN2 = deltaN * deltaN
m3 = buffer.getDouble(thirdMomentOffset)
m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
buffer.setDouble(thirdMomentOffset, m3)
}

if (momentOrder >= 4) {
m4 = buffer.getDouble(fourthMomentOffset)
m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
delta * (delta * delta2 - deltaN * deltaN2)
buffer.setDouble(fourthMomentOffset, m4)
}
}
// Compute the population standard deviation of a column
case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, StddevPop is simply Sqrt(VariancePop) (and StddevSamp = Sqrt(VarianceSamp)). I'm not sure whether it can help simplify the code here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may still want to have a good name when you call explain, I'd like to keep them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those if branches are important to save computation for low-order statistics. Even if we won't use CentralMomentAgg for second-order statistics, it is still good to keep them.

override protected def momentOrder = 2

/** Null when `n` is 0; otherwise `sqrt(m2 / n)`. */
override val evaluateExpression: Expression = {
  val nullDouble = Literal.create(null, DoubleType)
  val populationStddev = Sqrt(m2 / n)
  If(n === Literal(0.0), nullDouble, populationStddev)
}

/**
* Merge two central moment buffers.
*/
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val n1 = buffer1.getDouble(nOffset)
val n2 = buffer2.getDouble(inputAggBufferOffset)
val mean1 = buffer1.getDouble(meanOffset)
val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
override def prettyName: String = "stddev_pop"
}

// Compute the sample standard deviation of a column
case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {

override protected def momentOrder = 2

var secondMoment1 = 0.0
var secondMoment2 = 0.0
/** Null when `n` is 0, NaN when `n` is 1; otherwise `sqrt(m2 / (n - 1))`. */
override val evaluateExpression: Expression = {
  val nullDouble = Literal.create(null, DoubleType)
  val sampleStddev = Sqrt(m2 / (n - Literal(1.0)))
  val whenNonEmpty = If(n === Literal(1.0), Literal(Double.NaN), sampleStddev)
  If(n === Literal(0.0), nullDouble, whenNonEmpty)
}

var thirdMoment1 = 0.0
var thirdMoment2 = 0.0
override def prettyName: String = "stddev_samp"
}

var fourthMoment1 = 0.0
var fourthMoment2 = 0.0
// Compute the population variance of a column
case class VariancePop(child: Expression) extends CentralMomentAgg(child) {

n = n1 + n2
buffer1.setDouble(nOffset, n)
delta = mean2 - mean1
deltaN = if (n == 0.0) 0.0 else delta / n
mean = mean1 + deltaN * n2
buffer1.setDouble(mutableAggBufferOffset + 1, mean)
override protected def momentOrder = 2

// higher order moments computed according to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
if (momentOrder >= 2) {
secondMoment1 = buffer1.getDouble(secondMomentOffset)
secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
buffer1.setDouble(secondMomentOffset, m2)
}
/** Null when `n` is 0; otherwise `m2 / n`. */
override val evaluateExpression: Expression = {
  val nullDouble = Literal.create(null, DoubleType)
  val populationVariance = m2 / n
  If(n === Literal(0.0), nullDouble, populationVariance)
}

if (momentOrder >= 3) {
thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
(n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
buffer1.setDouble(thirdMomentOffset, m3)
}
override def prettyName: String = "var_pop"
}

if (momentOrder >= 4) {
fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
(n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
buffer1.setDouble(fourthMomentOffset, m4)
}
// Compute the sample variance of a column
case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {

override protected def momentOrder = 2

/** Null when `n` is 0, NaN when `n` is 1; otherwise `m2 / (n - 1)`. */
override val evaluateExpression: Expression = {
  val nullDouble = Literal.create(null, DoubleType)
  val sampleVariance = m2 / (n - Literal(1.0))
  val whenNonEmpty = If(n === Literal(1.0), Literal(Double.NaN), sampleVariance)
  If(n === Literal(0.0), nullDouble, whenNonEmpty)
}

/**
* Compute aggregate statistic from sufficient moments.
* @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
* needed to compute the aggregate stat.
*/
def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any

override final def eval(buffer: InternalRow): Any = {
val n = buffer.getDouble(nOffset)
val mean = buffer.getDouble(meanOffset)
val moments = Array.ofDim[Double](momentOrder + 1)
moments(0) = 1.0
moments(1) = 0.0
if (momentOrder >= 2) {
moments(2) = buffer.getDouble(secondMomentOffset)
}
if (momentOrder >= 3) {
moments(3) = buffer.getDouble(thirdMomentOffset)
}
if (momentOrder >= 4) {
moments(4) = buffer.getDouble(fourthMomentOffset)
}
override def prettyName: String = "var_samp"
}

case class Skewness(child: Expression) extends CentralMomentAgg(child) {

override def prettyName: String = "skewness"

override protected def momentOrder = 3

getStatistic(n, mean, moments)
/** Null when `n` is 0, NaN when `m2` is 0; otherwise `sqrt(n) * m3 / sqrt(m2 * m2 * m2)`. */
override val evaluateExpression: Expression = {
  val nullDouble = Literal.create(null, DoubleType)
  val skew = Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)
  val whenNonEmpty = If(m2 === Literal(0.0), Literal(Double.NaN), skew)
  If(n === Literal(0.0), nullDouble, whenNonEmpty)
}
}

/**
 * Aggregate named "kurtosis" that keeps buffer slots up to the fourth moment.
 *
 * The result is null when `n` is 0, NaN when `m2` is 0, and otherwise
 * `n * m4 / (m2 * m2) - 3`.
 *
 * @param child expression whose values are aggregated.
 */
case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {

  override def prettyName: String = "kurtosis"

  override protected def momentOrder = 4

  override val evaluateExpression: Expression = {
    val nullDouble = Literal.create(null, DoubleType)
    val kurtosisValue = n * m4 / (m2 * m2) - Literal(3.0)
    val whenNonEmpty = If(m2 === Literal(0.0), Literal(Double.NaN), kurtosisValue)
    If(n === Literal(0.0), nullDouble, whenNonEmpty)
  }
}
Loading