Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ object FunctionRegistry {

// aggregate functions
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
expression[First]("first"),
expression[First]("first_value"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util
import com.clearspring.analytics.hash.MurmurHash

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -524,6 +525,164 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = Cast(currentSum, resultType)
}

/**
* Compute Pearson correlation between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*
* @param left one of the expressions to compute correlation with.
* @param right another expression to compute correlation with.
*/
case class Corr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Please add ScalaDoc and document the behavior for null and NaN values.
  • Provide a link to the wikipedia page that contains the update formula.

left: Expression,
right: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate {

def children: Seq[Expression] = Seq(left, right)

def nullable: Boolean = false

def dataType: DataType = DoubleType

override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

def inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance())

val aggBufferAttributes: Seq[AttributeReference] = Seq(
AttributeReference("xAvg", DoubleType)(),
AttributeReference("yAvg", DoubleType)(),
AttributeReference("Ck", DoubleType)(),
AttributeReference("MkX", DoubleType)(),
AttributeReference("MkY", DoubleType)(),
AttributeReference("count", LongType)())

// Local cache of mutableAggBufferOffset(s) that will be used in update and merge
private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5

// Local cache of inputAggBufferOffset(s) that will be used in update and merge
private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def initialize(buffer: MutableRow): Unit = {
buffer.setDouble(mutableAggBufferOffset, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
}

override def update(buffer: MutableRow, input: InternalRow): Unit = {
val leftEval = left.eval(input)
val rightEval = right.eval(input)

if (leftEval != null && rightEval != null) {
val x = leftEval.asInstanceOf[Double]
val y = rightEval.asInstanceOf[Double]

var xAvg = buffer.getDouble(mutableAggBufferOffset)
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer.getLong(mutableAggBufferOffsetPlus5)

val deltaX = x - xAvg
val deltaY = y - yAvg
count += 1
xAvg += deltaX / count
yAvg += deltaY / count
Ck += deltaX * (y - yAvg)
MkX += deltaX * (x - xAvg)
MkY += deltaY * (y - yAvg)

buffer.setDouble(mutableAggBufferOffset, xAvg)
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer.setLong(mutableAggBufferOffsetPlus5, count)
}
}

// Merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)

// We only go to merge two buffers if there is at least one record aggregated in buffer2.
// We don't need to check count in buffer1 because if count2 is more than zero, totalCount
// is more than zero too, then we won't get a divide by zero exception.
if (count2 > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to assume that the count2 in buffer1 is non zero? There is - currently - no documentation on this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need to consider count in buffer2. I will add document for it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. Now it is obvious, I wasn't thinking...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add comment for it?

var xAvg = buffer1.getDouble(mutableAggBufferOffset)
var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer1.getLong(mutableAggBufferOffsetPlus5)

val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)

val totalCount = count + count2
val deltaX = xAvg - xAvg2
val deltaY = yAvg - yAvg2
Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
xAvg = (xAvg * count + xAvg2 * count2) / totalCount
yAvg = (yAvg * count + yAvg2 * count2) / totalCount
MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
count = totalCount

buffer1.setDouble(mutableAggBufferOffset, xAvg)
buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer1.setLong(mutableAggBufferOffsetPlus5, count)
}
}

override def eval(buffer: InternalRow): Any = {
val count = buffer.getLong(mutableAggBufferOffsetPlus5)
if (count > 0) {
val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
val corr = Ck / math.sqrt(MkX * MkY)
if (corr.isNaN) {
null
} else {
corr
}
} else {
null
}
}
}

// scalastyle:off
/**
* HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. This class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ object Utils {
mode = aggregate.Complete,
isDistinct = true)

case expressions.Corr(left, right) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Corr(left, right),
mode = aggregate.Complete,
isDistinct = false)

case expressions.ApproxCountDistinct(child, rsd) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,24 @@ case class LastFunction(
}
}

/**
* Calculate Pearson Correlation Coefficient for the given columns.
* Only support AggregateExpression2.
*
*/
case class Corr(left: Expression, right: Expression)
extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
override def nullable: Boolean = false
override def dataType: DoubleType.type = DoubleType
override def toString: String = s"CORRELATION($left, $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
override def newInstance(): AggregateFunction1 = {
throw new UnsupportedOperationException(
"Corr only supports the new AggregateExpression2 and can only be used " +
"when spark.sql.useAggregate2 = true")
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will be the error message if we call this function when spark.sql.useAggregate2=false? It will be good to provide a meaning error message.


// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,24 @@ object functions {
*/
def avg(columnName: String): Column = avg(Column(columnName))

/**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
* @since 1.6.0
*/
def corr(column1: Column, column2: Column): Column =
Corr(column1.expr, column2.expr)

/**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
* @since 1.6.0
*/
def corr(columnName1: String, columnName2: String): Column =
corr(Column(columnName1), Column(columnName2))

/**
* Aggregate function: returns the number of items in a group.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {

// classpath problems
"compute_stats.*",
"udf_bitmap_.*"
"udf_bitmap_.*",

// The difference between the double numbers generated by Hive and Spark
// can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322)
"udaf_corr"
)

/**
Expand Down Expand Up @@ -858,7 +862,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"type_cast_1",
"type_widening",
"udaf_collect_set",
"udaf_corr",
"udaf_covar_pop",
"udaf_covar_samp",
"udaf_histogram_numeric",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution

import scala.collection.JavaConverters._

import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
Expand Down Expand Up @@ -556,6 +557,109 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(0, null, 1, 1, null, 0) :: Nil)
}

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr1 - 1.0) < 1e-12)
val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr2 + 1.0) < 1e-12)
// non-trivial example. To reproduce in python, use:
// >>> from scipy.stats import pearsonr
// >>> import numpy as np
// >>> a = np.array(range(20))
// >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
// >>> pearsonr(a, b)
// (0.95723391394758572, 3.8902121417802199e-11)
// In R, use:
// > a <- 0:19
// > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
// > cor(a, b)
// [1] 0.957233913947585835
val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)

val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
assert(corr4 == Row(null))

val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr5 - 1.0) < 1e-12)
val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr6 + 1.0) < 1e-12)

// Test for udaf_corr in HiveCompatibilitySuite
// udaf_corr has been blacklisted due to numerical errors
// We test it here:
// SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL
// SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL
// SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL
// SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; =>
// 1 NULL
// 2 NULL
// 3 NULL
// 4 NULL
// 5 NULL
// 6 NULL
// SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323

val covar_tab = Seq[(Integer, Integer, Integer)](
(1, null, 15),
(2, 3, null),
(3, 7, 12),
(4, 4, 14),
(5, 8, 17),
(6, 2, 11)).toDF("a", "b", "c")

covar_tab.registerTempTable("covar_tab")

checkAnswer(
sqlContext.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a < 1
""".stripMargin),
Row(null) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a < 3
""".stripMargin),
Row(null) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a = 3
""".stripMargin),
Row(null) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a
""".stripMargin),
Row(1, null) ::
Row(2, null) ::
Row(3, null) ::
Row(4, null) ::
Row(5, null) ::
Row(6, null) :: Nil)

val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)

withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
val errorMessage = intercept[SparkException] {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
}.getMessage
assert(errorMessage.contains("java.lang.UnsupportedOperationException: " +
"Corr only supports the new AggregateExpression2"))
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen if the data type of input parameters are not double?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add ImplicitCastInputTypes to case class Corr. So the other NumericType can be automatically casting to double.


test("test Last implemented based on AggregateExpression1") {
// TODO: Remove this test once we remove AggregateExpression1.
import org.apache.spark.sql.functions._
Expand Down