Skip to content

Commit d932719

Browse files
Yanbo Liang authored and mengxr committed
SPARK-4111 [MLlib] add regression metrics
Add RegressionMetrics.scala as regression metrics used for evaluation and corresponding test case RegressionMetricsSuite.scala.

Author: Yanbo Liang <[email protected]>
Author: liangyanbo <[email protected]>

Closes #2978 from yanbohappy/regression_metrics and squashes the following commits:

730d0a9 [Yanbo Liang] more clearly annotation
3d0bec1 [Yanbo Liang] rename and keep code style
a8ad3e3 [Yanbo Liang] simplify code for keeping style
d454909 [Yanbo Liang] rename parameter and function names, delete unused columns, add reference
2e56282 [liangyanbo] rename r2_score() and remove unused column
43bb12b [liangyanbo] add regression metrics
1 parent c7ad085 commit d932719

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.rdd.RDD
22+
import org.apache.spark.Logging
23+
import org.apache.spark.mllib.linalg.Vectors
24+
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
25+
26+
/**
 * :: Experimental ::
 * Evaluator for regression.
 *
 * Every metric is derived from one set of summary statistics that is computed lazily on
 * first access and then reused by all accessors.
 *
 * @param predictionAndObservations an RDD of (prediction, observation) pairs.
 */
@Experimental
class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {

  /**
   * Column-wise summary statistics over vectors of the form (observation, error), where
   * error = observation - prediction. Column 0 therefore describes the observations and
   * column 1 describes the errors.
   */
  private lazy val summary: MultivariateStatisticalSummary = {
    predictionAndObservations.map { case (prediction, observation) =>
      Vectors.dense(observation, observation - prediction)
    }.aggregate(new MultivariateOnlineSummarizer())(
      (summarizer, v) => summarizer.add(v),
      (left, right) => left.merge(right)
    )
  }

  /**
   * Sum of squared errors, obtained by squaring the L2 norm of the error column.
   * Shared by the squared-error based metrics so they all use the same quantity.
   */
  private def sumSquaredErrors: Double = math.pow(summary.normL2(1), 2)

  /**
   * Returns the explained variance regression score.
   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
   */
  def explainedVariance: Double = {
    1 - summary.variance(1) / summary.variance(0)
  }

  /**
   * Returns the mean absolute error, which is a risk function corresponding to the
   * expected value of the absolute error loss or l1-norm loss.
   */
  def meanAbsoluteError: Double = {
    summary.normL1(1) / summary.count
  }

  /**
   * Returns the mean squared error, which is a risk function corresponding to the
   * expected value of the squared error loss or quadratic loss.
   */
  def meanSquaredError: Double = {
    // Divide the sum of squared errors directly instead of squaring an intermediate RMSE.
    sumSquaredErrors / summary.count
  }

  /**
   * Returns the root mean squared error, which is defined as the square root of
   * the mean squared error.
   */
  def rootMeanSquaredError: Double = {
    math.sqrt(meanSquaredError)
  }

  /**
   * Returns R^2^, the coefficient of determination.
   * r2 = 1 - (sum of squared errors) / (total sum of squares of the observations)
   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
   */
  def r2: Double = {
    // variance(0) * (count - 1) is the total sum of squares of the observations,
    // given that variance(0) is the unbiased sample variance (NOTE(review): confirm).
    1 - sumSquaredErrors / (summary.variance(0) * (summary.count - 1))
  }
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.mllib.util.TestingUtils._
24+
25+
/**
 * Tests for [[RegressionMetrics]]: one case with imperfect predictions and one case where
 * every prediction equals its observation.
 */
class RegressionMetricsSuite extends FunSuite with LocalSparkContext {

  test("regression metrics") {
    // (prediction, observation) pairs with non-zero errors, spread over 2 partitions.
    val pairs = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0))
    val predictionAndObservations = sc.parallelize(pairs, 2)
    val metrics = new RegressionMetrics(predictionAndObservations)

    assert(
      metrics.explainedVariance ~== 0.95717 absTol 1E-5,
      "explained variance regression score mismatch")
    assert(
      metrics.meanAbsoluteError ~== 0.5 absTol 1E-5,
      "mean absolute error mismatch")
    assert(
      metrics.meanSquaredError ~== 0.375 absTol 1E-5,
      "mean squared error mismatch")
    assert(
      metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
      "root mean squared error mismatch")
    assert(
      metrics.r2 ~== 0.94861 absTol 1E-5,
      "r2 score mismatch")
  }

  test("regression metrics with complete fitting") {
    // Every prediction equals its observation, so all error metrics should be zero
    // and both scores should be one.
    val pairs = Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0))
    val predictionAndObservations = sc.parallelize(pairs, 2)
    val metrics = new RegressionMetrics(predictionAndObservations)

    assert(
      metrics.explainedVariance ~== 1.0 absTol 1E-5,
      "explained variance regression score mismatch")
    assert(
      metrics.meanAbsoluteError ~== 0.0 absTol 1E-5,
      "mean absolute error mismatch")
    assert(
      metrics.meanSquaredError ~== 0.0 absTol 1E-5,
      "mean squared error mismatch")
    assert(
      metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
      "root mean squared error mismatch")
    assert(
      metrics.r2 ~== 1.0 absTol 1E-5,
      "r2 score mismatch")
  }
}

0 commit comments

Comments
 (0)