Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,66 +109,21 @@ object StatFunctions extends Logging {

/** Calculate the Pearson Correlation Coefficient for the given columns */
def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
// Aggregate the co-moment (Ck) and the per-column sums of squared deviations (MkX, MkY).
val counts = collectStatisticalData(df, cols, "correlation")
// Pearson r = co-moment / sqrt(product of the two sums of squared deviations).
// The denominator is not guarded: if either sum is 0.0 the division is not finite.
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
}

/** Helper class to simplify tracking and merging counts. */
// Mutable, serializable accumulator of pairwise statistics for two columns.
// `add` folds one (x, y) observation into the running values; `merge` folds another
// counter into this one. Both mutate this instance and return it.
private class CovarianceCounter extends Serializable {
var xAvg = 0.0 // the mean of all examples seen so far in col1
var yAvg = 0.0 // the mean of all examples seen so far in col2
var Ck = 0.0 // the co-moment after k examples
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
var count = 0L // count of observed examples
// Add one example to the calculation (single-pass running update).
def add(x: Double, y: Double): this.type = {
// Deltas are taken against the means as they were BEFORE this example is included.
val deltaX = x - xAvg
val deltaY = y - yAvg
count += 1
xAvg += deltaX / count
yAvg += deltaY / count
// Each accumulator multiplies a pre-update delta by a post-update difference,
// so these three lines must stay after the mean updates above.
Ck += deltaX * (y - yAvg)
MkX += deltaX * (x - xAvg)
MkY += deltaY * (y - yAvg)
this
}
// merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
// An empty `other` contributes nothing; this counter is left untouched.
if (other.count > 0) {
val totalCount = count + other.count
// Mean differences are captured before xAvg / yAvg are overwritten below.
val deltaX = xAvg - other.xAvg
val deltaY = yAvg - other.yAvg
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
// `count` is reassigned last because every line above still needs the old value.
count = totalCount
}
this
// NOTE(review): the statements from here to the next closing brace use `cols` and `df`,
// which are neither members of this class nor parameters of `merge`. This text comes
// from a rendered diff, so they presumably belong to the post-change body of
// pearsonCorrelation rather than to this method -- confirm against the real file.
// They require exactly two columns, require each to be NumericType, and build one
// expression per column that is 0.0 where the value is null and a Double cast otherwise.
require(cols.length == 2,
"Currently correlation calculation is supported between two columns.")
val Seq(col1, col2) = cols.map { c =>
val dataType = df.resolve(c).dataType
require(dataType.isInstanceOf[NumericType],
"Currently correlation calculation for columns with dataType " +
s"${dataType.catalogString} not supported.")
when(isnull(col(c)), lit(0.0))
.otherwise(col(c).cast(DoubleType))
}
// return the sample covariance for the observed examples
// (divides by count - 1; counts of 0 or 1 are not guarded against here)
def cov: Double = Ck / (count - 1)
}

// Validates `cols` (exactly two, both NumericType), casts both columns to Double and
// aggregates every row into a CovarianceCounter. `functionName` is only interpolated
// into the error messages.
private def collectStatisticalData(df: DataFrame, cols: Seq[String],
functionName: String): CovarianceCounter = {
require(cols.length == 2, s"Currently $functionName calculation is supported " +
"between two columns.")
// Resolve each name against the DataFrame and reject non-numeric columns.
cols.map(name => (name, df.resolve(name))).foreach { case (name, data) =>
require(data.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " +
s"for columns with dataType ${data.dataType.catalogString} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
// seqOp adds one row to a counter; combOp merges two counters.
df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
combOp = (baseCounter, other) => {
baseCounter.merge(other)
})
// NOTE(review): `col1` and `col2` are not defined anywhere in this method. This text
// comes from a rendered diff, so the remaining statements presumably are the tail of the
// post-change pearsonCorrelation body -- confirm against the real file.
// They evaluate `corr` over the two expressions and replace a null result with NaN.
val correlation = corr(col1, col2)
df.select(
when(isnull(correlation), lit(Double.NaN))
.otherwise(correlation)
).head.getDouble(0)
}

/**
Expand All @@ -178,8 +133,21 @@ object StatFunctions extends Logging {
* @return the covariance of the two columns.
*/
def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
// NOTE(review): as the text stands, `counts.cov` is not the last expression, so its value
// would be discarded. The enclosing hunk header (-178,8 +133,21) suggests the next two
// lines are the pre-change body and everything after them is its replacement -- confirm.
val counts = collectStatisticalData(df, cols, "covariance")
counts.cov
require(cols.length == 2,
"Currently covariance calculation is supported between two columns.")
// Build one expression per column: reject non-numeric types, then use 0.0 where the
// value is null and a Double cast of the value otherwise.
val Seq(col1, col2) = cols.map { c =>
val dataType = df.resolve(c).dataType
require(dataType.isInstanceOf[NumericType],
"Currently covariance calculation for columns with dataType " +
s"${dataType.catalogString} not supported.")
when(isnull(col(c)), lit(0.0))
.otherwise(col(c).cast(DoubleType))
}
val covariance = covar_samp(col1, col2)
// A null aggregate result is reported as 0.0; the SPARK-40933 test in this change
// asserts 0 for an empty DataFrame.
df.select(
when(isnull(covariance), lit(0.0))
.otherwise(covariance)
).head.getDouble(0)
}

/** Generate a table of frequencies for the elements of two columns. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.scalatest.matchers.must.Matchers._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.functions.{col, lit, struct, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
Expand Down Expand Up @@ -152,6 +152,18 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-40933 test cov & corr with null values and empty dataset") {
  // "value" copies "id" only on multiples of 3; the `when` has no `otherwise`
  // branch, so the remaining rows carry no value for that column.
  val idCol = col("id")
  val withNulls = spark.range(0, 10).withColumn("value", when(idCol % 3 === 0, idCol))
  val tolerance = 1e-12

  val covWithNulls = withNulls.stat.cov("id", "value")
  assert(math.abs(covWithNulls - 5.0) < tolerance)
  val corrWithNulls = withNulls.stat.corr("id", "value")
  assert(math.abs(corrWithNulls - 0.5120915564991891) < tolerance)

  // ids start at 0, so filtering on id < 0 leaves an empty DataFrame
  val noRows = withNulls.where(idCol < 0)
  assert(noRows.stat.cov("id", "value") === 0)
  assert(noRows.stat.corr("id", "value").isNaN)
}

test("covariance") {
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

Expand Down