@@ -22,62 +22,71 @@ import org.apache.spark.rdd.RDD
2222import org .apache .spark .Logging
2323import org .apache .spark .mllib .linalg .Vectors
2424import org .apache .spark .mllib .stat .MultivariateOnlineSummarizer
25- import org .apache .spark .mllib .rdd .RDDFunctions ._
2625
/**
 * :: Experimental ::
 * Evaluator for regression.
 *
 * @param predictionAndObservations an RDD of (prediction, observation) pairs.
 */
@Experimental
class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {

  /**
   * Summary statistics accumulated with MultivariateOnlineSummarizer over a
   * two-column vector built from every input pair:
   *  - column 0 holds the observation,
   *  - column 1 holds the residual (observation - prediction).
   *
   * Evaluated lazily, so the pass over the RDD only happens when a metric is
   * first requested, and its result is reused by every metric afterwards.
   */
  private lazy val summarizer: MultivariateOnlineSummarizer = {
    val observationAndResidual = predictionAndObservations.map { pair =>
      val (prediction, observation) = pair
      Vectors.dense(Array(observation, observation - prediction))
    }
    observationAndResidual.aggregate(new MultivariateOnlineSummarizer())(
      (partial, row) => partial.add(row),
      (left, right) => left.merge(right)
    )
  }

  /**
   * Returns the explained variance regression score, defined as
   * 1 - variance(y - \hat{y}) / variance(y).
   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
   */
  def explainedVarianceScore: Double = {
    val residualVariance = summarizer.variance(1)
    val observationVariance = summarizer.variance(0)
    1 - residualVariance / observationVariance
  }

  /**
   * Returns the mean absolute error: the L1 norm of the residual column
   * divided by the number of samples. It is the risk function corresponding to
   * the expected value of the absolute error loss (l1-norm loss).
   */
  def meanAbsoluteError: Double = {
    val absoluteResidualSum = summarizer.normL1(1)
    absoluteResidualSum / summarizer.count
  }

  /**
   * Returns the mean squared error: the squared L2 norm of the residual column
   * divided by the number of samples. It is the risk function corresponding to
   * the expected value of the squared error loss (quadratic loss).
   */
  def meanSquaredError: Double = {
    val residualNorm = summarizer.normL2(1)
    residualNorm * residualNorm / summarizer.count
  }

  /**
   * Returns the root mean squared error, i.e. the square root of the mean
   * squared error, computed as the L2 norm of the residual column divided by
   * the square root of the number of samples.
   */
  def rootMeanSquaredError: Double = {
    val residualNorm = summarizer.normL2(1)
    residualNorm / math.sqrt(summarizer.count)
  }

  /**
   * Returns R^2^, the coefficient of determination: one minus the ratio of the
   * residual sum of squares to the total sum of squares of the observations.
   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
   */
  def r2Score: Double = {
    val residualNorm = summarizer.normL2(1)
    val residualSumOfSquares = residualNorm * residualNorm
    // variance(0) * (count - 1) is the sum of squared deviations of the observations.
    val totalSumOfSquares = summarizer.variance(0) * (summarizer.count - 1)
    1 - residualSumOfSquares / totalSumOfSquares
  }
}
0 commit comments