-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-12913] [SQL] Improve performance of stat functions #10960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
61edd5e
3c8d737
ae83955
1481bb4
448e0e1
1b95b7c
ae78e81
1086810
383c193
ab32659
7e57a1a
5f98588
fe6fe50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.util.TypeUtils | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
|
|
@@ -44,186 +42,169 @@ import org.apache.spark.sql.types._ | |
| * | ||
| * @param child to compute central moments of. | ||
| */ | ||
| abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { | ||
| abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { | ||
|
|
||
| /** | ||
| * The central moment order to be computed. | ||
| */ | ||
| protected def momentOrder: Int | ||
|
|
||
| override def children: Seq[Expression] = Seq(child) | ||
|
|
||
| override def nullable: Boolean = true | ||
|
|
||
| override def dataType: DataType = DoubleType | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) | ||
| protected val n = AttributeReference("n", DoubleType, nullable = false)() | ||
| protected val avg = AttributeReference("avg", DoubleType, nullable = false)() | ||
| protected val m2 = AttributeReference("m2", DoubleType, nullable = false)() | ||
| protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() | ||
| protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = | ||
| TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") | ||
| private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) | ||
|
|
||
| override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) | ||
| override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) | ||
|
|
||
| /** | ||
| * Size of aggregation buffer. | ||
| */ | ||
| private[this] val bufferSize = 5 | ||
| override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) | ||
|
|
||
| override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => | ||
| AttributeReference(s"M$i", DoubleType)() | ||
| override val updateExpressions: Seq[Expression] = { | ||
| val newN = n + Literal(1.0) | ||
| val delta = child - avg | ||
| val deltaN = delta / newN | ||
| val newAvg = avg + deltaN | ||
| val newM2 = m2 + delta * (delta - deltaN) | ||
|
|
||
| val delta2 = delta * delta | ||
| val deltaN2 = deltaN * deltaN | ||
| val newM3 = if (momentOrder >= 3) { | ||
| m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) | ||
| } else { | ||
| Literal(0.0) | ||
| } | ||
| val newM4 = if (momentOrder >= 4) { | ||
| m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + | ||
| delta * (delta * delta2 - deltaN * deltaN2) | ||
| } else { | ||
| Literal(0.0) | ||
| } | ||
|
|
||
| trimHigherOrder(Seq( | ||
| If(IsNull(child), n, newN), | ||
| If(IsNull(child), avg, newAvg), | ||
| If(IsNull(child), m2, newM2), | ||
| If(IsNull(child), m3, newM3), | ||
| If(IsNull(child), m4, newM4) | ||
| )) | ||
| } | ||
|
|
||
| // Note: although this simply copies aggBufferAttributes, this common code can not be placed | ||
| // in the superclass because that will lead to initialization ordering issues. | ||
| override val inputAggBufferAttributes: Seq[AttributeReference] = | ||
| aggBufferAttributes.map(_.newInstance()) | ||
|
|
||
| // buffer offsets | ||
| private[this] val nOffset = mutableAggBufferOffset | ||
| private[this] val meanOffset = mutableAggBufferOffset + 1 | ||
| private[this] val secondMomentOffset = mutableAggBufferOffset + 2 | ||
| private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 | ||
| private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 | ||
|
|
||
| // frequently used values for online updates | ||
| private[this] var delta = 0.0 | ||
| private[this] var deltaN = 0.0 | ||
| private[this] var delta2 = 0.0 | ||
| private[this] var deltaN2 = 0.0 | ||
| private[this] var n = 0.0 | ||
| private[this] var mean = 0.0 | ||
| private[this] var m2 = 0.0 | ||
| private[this] var m3 = 0.0 | ||
| private[this] var m4 = 0.0 | ||
| override val mergeExpressions: Seq[Expression] = { | ||
|
|
||
| /** | ||
| * Initialize all moments to zero. | ||
| */ | ||
| override def initialize(buffer: MutableRow): Unit = { | ||
| for (aggIndex <- 0 until bufferSize) { | ||
| buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) | ||
| val n1 = n.left | ||
| val n2 = n.right | ||
| val newN = n1 + n2 | ||
| val delta = avg.right - avg.left | ||
| val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) | ||
| val newAvg = avg.left + deltaN * n2 | ||
|
|
||
| // higher order moments computed according to: | ||
| // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics | ||
| val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2 | ||
| // `m3.right` is not available if momentOrder < 3 | ||
| val newM3 = if (momentOrder >= 3) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These expressions require
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Then shall we make
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I understand that |
||
| m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + | ||
| Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left) | ||
| } else { | ||
| Literal(0.0) | ||
| } | ||
| // `m4.right` is not available if momentOrder < 4 | ||
| val newM4 = if (momentOrder >= 4) { | ||
| m4.left + m4.right + | ||
| deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + | ||
| Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) + | ||
| Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left) | ||
| } else { | ||
| Literal(0.0) | ||
| } | ||
|
|
||
| trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Update the central moments buffer. | ||
| */ | ||
| override def update(buffer: MutableRow, input: InternalRow): Unit = { | ||
| val v = Cast(child, DoubleType).eval(input) | ||
| if (v != null) { | ||
| val updateValue = v match { | ||
| case d: Double => d | ||
| } | ||
|
|
||
| n = buffer.getDouble(nOffset) | ||
| mean = buffer.getDouble(meanOffset) | ||
|
|
||
| n += 1.0 | ||
| buffer.setDouble(nOffset, n) | ||
| delta = updateValue - mean | ||
| deltaN = delta / n | ||
| mean += deltaN | ||
| buffer.setDouble(meanOffset, mean) | ||
|
|
||
| if (momentOrder >= 2) { | ||
| m2 = buffer.getDouble(secondMomentOffset) | ||
| m2 += delta * (delta - deltaN) | ||
| buffer.setDouble(secondMomentOffset, m2) | ||
| } | ||
|
|
||
| if (momentOrder >= 3) { | ||
| delta2 = delta * delta | ||
| deltaN2 = deltaN * deltaN | ||
| m3 = buffer.getDouble(thirdMomentOffset) | ||
| m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) | ||
| buffer.setDouble(thirdMomentOffset, m3) | ||
| } | ||
|
|
||
| if (momentOrder >= 4) { | ||
| m4 = buffer.getDouble(fourthMomentOffset) | ||
| m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + | ||
| delta * (delta * delta2 - deltaN * deltaN2) | ||
| buffer.setDouble(fourthMomentOffset, m4) | ||
| } | ||
| } | ||
| // Compute the population standard deviation of a column | ||
| case class StddevPop(child: Expression) extends CentralMomentAgg(child) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Btw,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We may still want to have a good name when you call
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good. |
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Those |
||
| override protected def momentOrder = 2 | ||
|
|
||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| Sqrt(m2 / n)) | ||
| } | ||
|
|
||
| /** | ||
| * Merge two central moment buffers. | ||
| */ | ||
| override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { | ||
| val n1 = buffer1.getDouble(nOffset) | ||
| val n2 = buffer2.getDouble(inputAggBufferOffset) | ||
| val mean1 = buffer1.getDouble(meanOffset) | ||
| val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) | ||
| override def prettyName: String = "stddev_pop" | ||
| } | ||
|
|
||
| // Compute the sample standard deviation of a column | ||
| case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { | ||
|
|
||
| override protected def momentOrder = 2 | ||
|
|
||
| var secondMoment1 = 0.0 | ||
| var secondMoment2 = 0.0 | ||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| If(n === Literal(1.0), Literal(Double.NaN), | ||
| Sqrt(m2 / (n - Literal(1.0))))) | ||
| } | ||
|
|
||
| var thirdMoment1 = 0.0 | ||
| var thirdMoment2 = 0.0 | ||
| override def prettyName: String = "stddev_samp" | ||
| } | ||
|
|
||
| var fourthMoment1 = 0.0 | ||
| var fourthMoment2 = 0.0 | ||
| // Compute the population variance of a column | ||
| case class VariancePop(child: Expression) extends CentralMomentAgg(child) { | ||
|
|
||
| n = n1 + n2 | ||
| buffer1.setDouble(nOffset, n) | ||
| delta = mean2 - mean1 | ||
| deltaN = if (n == 0.0) 0.0 else delta / n | ||
| mean = mean1 + deltaN * n2 | ||
| buffer1.setDouble(mutableAggBufferOffset + 1, mean) | ||
| override protected def momentOrder = 2 | ||
|
|
||
| // higher order moments computed according to: | ||
| // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics | ||
| if (momentOrder >= 2) { | ||
| secondMoment1 = buffer1.getDouble(secondMomentOffset) | ||
| secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) | ||
| m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 | ||
| buffer1.setDouble(secondMomentOffset, m2) | ||
| } | ||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| m2 / n) | ||
| } | ||
|
|
||
| if (momentOrder >= 3) { | ||
| thirdMoment1 = buffer1.getDouble(thirdMomentOffset) | ||
| thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) | ||
| m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * | ||
| (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) | ||
| buffer1.setDouble(thirdMomentOffset, m3) | ||
| } | ||
| override def prettyName: String = "var_pop" | ||
| } | ||
|
|
||
| if (momentOrder >= 4) { | ||
| fourthMoment1 = buffer1.getDouble(fourthMomentOffset) | ||
| fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) | ||
| m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * | ||
| n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * | ||
| (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + | ||
| 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) | ||
| buffer1.setDouble(fourthMomentOffset, m4) | ||
| } | ||
| // Compute the sample variance of a column | ||
| case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { | ||
|
|
||
| override protected def momentOrder = 2 | ||
|
|
||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| If(n === Literal(1.0), Literal(Double.NaN), | ||
| m2 / (n - Literal(1.0)))) | ||
| } | ||
|
|
||
| /** | ||
| * Compute aggregate statistic from sufficient moments. | ||
| * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) | ||
| * needed to compute the aggregate stat. | ||
| */ | ||
| def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any | ||
|
|
||
| override final def eval(buffer: InternalRow): Any = { | ||
| val n = buffer.getDouble(nOffset) | ||
| val mean = buffer.getDouble(meanOffset) | ||
| val moments = Array.ofDim[Double](momentOrder + 1) | ||
| moments(0) = 1.0 | ||
| moments(1) = 0.0 | ||
| if (momentOrder >= 2) { | ||
| moments(2) = buffer.getDouble(secondMomentOffset) | ||
| } | ||
| if (momentOrder >= 3) { | ||
| moments(3) = buffer.getDouble(thirdMomentOffset) | ||
| } | ||
| if (momentOrder >= 4) { | ||
| moments(4) = buffer.getDouble(fourthMomentOffset) | ||
| } | ||
| override def prettyName: String = "var_samp" | ||
| } | ||
|
|
||
| case class Skewness(child: Expression) extends CentralMomentAgg(child) { | ||
|
|
||
| override def prettyName: String = "skewness" | ||
|
|
||
| override protected def momentOrder = 3 | ||
|
|
||
| getStatistic(n, mean, moments) | ||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| If(m2 === Literal(0.0), Literal(Double.NaN), | ||
| Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) | ||
| } | ||
| } | ||
|
|
||
| case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { | ||
|
|
||
| override protected def momentOrder = 4 | ||
|
|
||
| override val evaluateExpression: Expression = { | ||
| If(n === Literal(0.0), Literal.create(null, DoubleType), | ||
| If(m2 === Literal(0.0), Literal(Double.NaN), | ||
| n * m4 / (m2 * m2) - Literal(3.0))) | ||
| } | ||
|
|
||
| override def prettyName: String = "kurtosis" | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
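Creating a Cast() here is very expensive (this appears to refer to the old `update` method, which constructs a new `Cast(child, DoubleType)` expression on every call before evaluating it).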