Skip to content

Commit aac3bc5

Browse files
committed
Add RegressionMetrics in PySpark/MLlib
1 parent 32cdc81 commit aac3bc5

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
2222
import org.apache.spark.Logging
2323
import org.apache.spark.mllib.linalg.Vectors
2424
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
25+
import org.apache.spark.sql.DataFrame
2526

2627
/**
2728
* :: Experimental ::
@@ -32,6 +33,14 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate
3233
@Experimental
3334
class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
3435

/**
 * An auxiliary constructor taking a DataFrame.
 * @param predictionAndObservations a DataFrame with two double columns:
 *                                  prediction and observation
 */
private[mllib] def this(predictionAndObservations: DataFrame) =
  this(predictionAndObservations.map { row =>
    (row.getDouble(0), row.getDouble(1))
  })
43+
3544
/**
3645
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
3746
*/

python/pyspark/mllib/evaluation.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,73 @@ def unpersist(self):
6767
self.call("unpersist")
6868

6969

class RegressionMetrics(JavaModelWrapper):
    """
    Evaluator for regression.

    >>> predictionAndObservations = sc.parallelize([
    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    >>> metrics = RegressionMetrics(predictionAndObservations)
    >>> metrics.explainedVariance()
    0.95...
    >>> metrics.meanAbsoluteError()
    0.5...
    >>> metrics.meanSquaredError()
    0.37...
    >>> metrics.rootMeanSquaredError()
    0.61...
    >>> metrics.r2()
    0.94...
    """

    def __init__(self, predictionAndObservations):
        """
        :param predictionAndObservations: an RDD of (prediction, observation) pairs.
        """
        sc = predictionAndObservations.ctx
        sql_ctx = SQLContext(sc)
        # Convert the pair RDD into a two-double-column DataFrame so the
        # JVM-side auxiliary constructor RegressionMetrics(DataFrame) can
        # consume it across the Py4J bridge.
        df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("observation", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
        java_model = java_class(df._jdf)
        super(RegressionMetrics, self).__init__(java_model)

    def explainedVariance(self):
        # Raw docstring: "\h" in a normal string is an invalid escape
        # sequence (SyntaxWarning on modern Python).
        r"""
        Returns the explained variance regression score.
        explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
        """
        return self.call("explainedVariance")

    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function corresponding to the
        expected value of the absolute error loss or l1-norm loss.
        """
        return self.call("meanAbsoluteError")

    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function corresponding to the
        expected value of the squared error loss or quadratic loss.
        """
        return self.call("meanSquaredError")

    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the square root of
        the mean squared error.
        """
        return self.call("rootMeanSquaredError")

    def r2(self):
        """
        Returns R^2, the coefficient of determination.
        """
        return self.call("r2")

70137
def _test():
71138
import doctest
72139
from pyspark import SparkContext

0 commit comments

Comments
 (0)