@@ -67,6 +67,73 @@ def unpersist(self):
6767 self .call ("unpersist" )
6868
6969
class RegressionMetrics(JavaModelWrapper):
    """
    Evaluator for regression.

    Wraps a JVM ``org.apache.spark.mllib.evaluation.RegressionMetrics``
    instance; each metric method delegates to ``self.call`` with the
    name of the corresponding metric.

    >>> predictionAndObservations = sc.parallelize([
    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
    >>> metrics = RegressionMetrics(predictionAndObservations)
    >>> metrics.explainedVariance()
    0.95...
    >>> metrics.meanAbsoluteError()
    0.5...
    >>> metrics.meanSquaredError()
    0.37...
    >>> metrics.rootMeanSquaredError()
    0.61...
    >>> metrics.r2()
    0.94...
    """

    def __init__(self, predictionAndObservations):
        """
        :param predictionAndObservations: an RDD of (prediction, observation) pairs.
        """
        sc = predictionAndObservations.ctx
        sql_ctx = SQLContext(sc)
        # The Java constructor is handed a DataFrame, so the RDD is converted
        # using an explicit two-column, non-nullable double schema.
        schema = StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("observation", DoubleType(), nullable=False)])
        df = sql_ctx.createDataFrame(predictionAndObservations, schema=schema)
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
        java_model = java_class(df._jdf)
        super(RegressionMetrics, self).__init__(java_model)

    def explainedVariance(self):
        # Raw docstring: "\hat" would otherwise be an invalid escape sequence.
        r"""
        Returns the explained variance regression score.
        explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
        """
        return self.call("explainedVariance")

    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function corresponding to the
        expected value of the absolute error loss or l1-norm loss.
        """
        return self.call("meanAbsoluteError")

    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function corresponding to the
        expected value of the squared error loss or quadratic loss.
        """
        return self.call("meanSquaredError")

    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the square root of
        the mean squared error.
        """
        return self.call("rootMeanSquaredError")

    def r2(self):
        """
        Returns R^2, the coefficient of determination.
        """
        return self.call("r2")
135+
136+
70137def _test ():
71138 import doctest
72139 from pyspark import SparkContext
0 commit comments