Skip to content

Commit d454909

Browse files
author
Yanbo Liang
committed
rename parameter and function names, delete unused columns, add reference
1 parent 2e56282 commit d454909

File tree

2 files changed

+48
-32
lines changed

2 files changed

+48
-32
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,62 +22,71 @@ import org.apache.spark.rdd.RDD
2222
import org.apache.spark.Logging
2323
import org.apache.spark.mllib.linalg.Vectors
2424
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
25-
import org.apache.spark.mllib.rdd.RDDFunctions._
2625

2726
/**
2827
* :: Experimental ::
2928
* Evaluator for regression.
3029
*
31-
* @param valuesAndPreds an RDD of (value, pred) pairs.
30+
* @param predictionAndObservations an RDD of (prediction, observation) pairs.
3231
*/
3332
@Experimental
34-
class RegressionMetrics(valuesAndPreds: RDD[(Double, Double)]) extends Logging {
33+
class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
3534

3635
/**
3736
* Use MultivariateOnlineSummarizer to calculate the mean and variance of different combinations.
3837
* MultivariateOnlineSummarizer is a numerically stable algorithm to compute mean and variance
3938
* in an online fashion.
4039
*/
4140
private lazy val summarizer: MultivariateOnlineSummarizer = {
42-
val summarizer: MultivariateOnlineSummarizer = valuesAndPreds.map{
43-
case (value,pred) => Vectors.dense(
44-
Array(value, value - pred, math.abs(value - pred), math.pow(value - pred, 2.0))
41+
val summarizer: MultivariateOnlineSummarizer = predictionAndObservations.map{
42+
case (prediction,observation) => Vectors.dense(
43+
Array(observation, observation - prediction)
4544
)
46-
}.treeAggregate(new MultivariateOnlineSummarizer())(
45+
}.aggregate(new MultivariateOnlineSummarizer())(
4746
(summary, v) => summary.add(v),
4847
(sum1,sum2) => sum1.merge(sum2)
4948
)
5049
summarizer
5150
}
5251

5352
/**
54-
* Computes the explained variance regression score
53+
* Returns the explained variance regression score.
54+
* explainedVarianceScore = 1 - variance(y - \hat{y}) / variance(y)
55+
* Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
5556
*/
56-
def explainedVarianceScore(): Double = {
57+
def explainedVarianceScore: Double = {
5758
1 - summarizer.variance(1) / summarizer.variance(0)
5859
}
5960

6061
/**
61-
* Computes the mean absolute error, which is a risk function corresponding to the
62+
* Returns the mean absolute error, which is a risk function corresponding to the
6263
* expected value of the absolute error loss or l1-norm loss.
6364
*/
64-
def mae(): Double = {
65-
summarizer.mean(2)
65+
def meanAbsoluteError: Double = {
66+
summarizer.normL1(1) / summarizer.count
6667
}
6768

6869
/**
69-
* Computes the mean square error, which is a risk function corresponding to the
70+
* Returns the mean squared error, which is a risk function corresponding to the
7071
* expected value of the squared error loss or quadratic loss.
7172
*/
72-
def mse(): Double = {
73-
summarizer.mean(3)
73+
def meanSquaredError: Double = {
74+
summarizer.normL2(1) * summarizer.normL2(1) / summarizer.count
7475
}
7576

7677
/**
77-
* Computes R^2^, the coefficient of determination.
78-
* @return
78+
* Returns the root mean squared error, which is defined as the square root of
79+
* the mean squared error.
7980
*/
80-
def r2_score(): Double = {
81-
1 - summarizer.mean(3) * summarizer.count / (summarizer.variance(0) * (summarizer.count - 1))
81+
def rootMeanSquaredError: Double = {
82+
summarizer.normL2(1) / math.sqrt(summarizer.count)
83+
}
84+
85+
/**
86+
* Returns R^2^, the coefficient of determination.
87+
* Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
88+
*/
89+
def r2Score: Double = {
90+
1 - summarizer.normL2(1) * summarizer.normL2(1) / (summarizer.variance(0) * (summarizer.count - 1))
8291
}
8392
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,35 @@
1818
package org.apache.spark.mllib.evaluation
1919

2020
import org.scalatest.FunSuite
21+
2122
import org.apache.spark.mllib.util.LocalSparkContext
2223
import org.apache.spark.mllib.util.TestingUtils._
2324

2425
class RegressionMetricsSuite extends FunSuite with LocalSparkContext {
2526

2627
test("regression metrics") {
27-
val valuesAndPreds = sc.parallelize(
28-
Seq((3.0,2.5),(-0.5,0.0),(2.0,2.0),(7.0,8.0)),2)
29-
val metrics = new RegressionMetrics(valuesAndPreds)
30-
assert(metrics.explainedVarianceScore() ~== 0.95717 absTol 1E-5,"explained variance regression score mismatch")
31-
assert(metrics.mae() ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
32-
assert(metrics.mse() ~== 0.375 absTol 1E-5, "mean square error mismatch")
33-
assert(metrics.r2_score() ~== 0.94861 absTol 1E-5, "r2 score mismatch")
28+
val predictionAndObservations = sc.parallelize(
29+
Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)),2)
30+
val metrics = new RegressionMetrics(predictionAndObservations)
31+
assert(metrics.explainedVarianceScore ~== 0.95717 absTol 1E-5,
32+
"explained variance regression score mismatch")
33+
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
34+
assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
35+
assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
36+
"root mean squared error mismatch")
37+
assert(metrics.r2Score ~== 0.94861 absTol 1E-5, "r2 score mismatch")
3438
}
3539

3640
test("regression metrics with complete fitting") {
37-
val valuesAndPreds = sc.parallelize(
41+
val predictionAndObservations = sc.parallelize(
3842
Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)),2)
39-
val metrics = new RegressionMetrics(valuesAndPreds)
40-
assert(metrics.explainedVarianceScore() ~== 1.0 absTol 1E-5,"explained variance regression score mismatch")
41-
assert(metrics.mae() ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
42-
assert(metrics.mse() ~== 0.0 absTol 1E-5, "mean square error mismatch")
43-
assert(metrics.r2_score() ~== 1.0 absTol 1E-5, "r2 score mismatch")
43+
val metrics = new RegressionMetrics(predictionAndObservations)
44+
assert(metrics.explainedVarianceScore ~== 1.0 absTol 1E-5,
45+
"explained variance regression score mismatch")
46+
assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
47+
assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
48+
assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
49+
"root mean squared error mismatch")
50+
assert(metrics.r2Score ~== 1.0 absTol 1E-5, "r2 score mismatch")
4451
}
4552
}

0 commit comments

Comments
 (0)