@@ -22,24 +22,27 @@ import org.apache.spark.rdd.RDD
2222import org .apache .spark .Logging
2323import org .apache .spark .SparkContext ._
2424
25+ import scala .collection .Map
26+
2527/**
28+ * ::Experimental::
2629 * Evaluator for multiclass classification.
2730 *
2831 * @param predictionsAndLabels an RDD of (prediction, label) pairs.
2932 */
3033@ Experimental
3134class MulticlassMetrics (predictionsAndLabels : RDD [(Double , Double )]) extends Logging {
3235
33- private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
34- private lazy val labelCount = labelCountByClass.values.sum
35- private lazy val tpByClass = predictionsAndLabels
36- .map{ case (prediction, label) =>
37- (label, if (label == prediction) 1 else 0 )
36+ private lazy val labelCountByClass : Map [ Double , Long ] = predictionsAndLabels.values.countByValue()
37+ private lazy val labelCount : Long = labelCountByClass.values.sum
38+ private lazy val tpByClass : Map [ Double , Int ] = predictionsAndLabels
39+ .map { case (prediction, label) =>
40+ (label, if (label == prediction) 1 else 0 )
3841 }.reduceByKey(_ + _)
3942 .collectAsMap()
40- private lazy val fpByClass = predictionsAndLabels
41- .map{ case (prediction, label) =>
42- (prediction, if (prediction != label) 1 else 0 )
43+ private lazy val fpByClass : Map [ Double , Int ] = predictionsAndLabels
44+ .map { case (prediction, label) =>
45+ (prediction, if (prediction != label) 1 else 0 )
4346 }.reduceByKey(_ + _)
4447 .collectAsMap()
4548
@@ -63,35 +66,41 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
6366 * Returns f-measure for a given label (category)
6467 * @param label the label.
6568 */
66- def fMeasure (label : Double , beta: Double = 1.0 ): Double = {
69+ def fMeasure (label : Double , beta : Double ): Double = {
6770 val p = precision(label)
6871 val r = recall(label)
6972 val betaSqrd = beta * beta
7073 if (p + r == 0 ) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
7174 }
7275
7376 /**
74- * Returns micro-averaged recall
75- * (equals to microPrecision and microF1measure for multiclass classifier)
77+ * Returns f1-measure for a given label (category)
78+ * @param label the label.
79+ */
80+ def fMeasure (label : Double ): Double = fMeasure(label, 1.0 )
81+
82+ /**
83+ * Returns precision
7684 */
77- lazy val recall : Double =
78- tpByClass.values.sum.toDouble / labelCount
85+ lazy val precision : Double = tpByClass.values.sum.toDouble / labelCount
7986
8087 /**
81- * Returns micro-averaged precision
82- * (equals to microPrecision and microF1measure for multiclass classifier)
88+ * Returns recall
89+ * (equals to precision for multiclass classifier
90+ * because sum of all false positives is equal to sum
91+ * of all false negatives)
8392 */
84- lazy val precision : Double = recall
93+ lazy val recall : Double = precision
8594
8695 /**
87- * Returns micro-averaged f-measure
88- * (equals to microPrecision and microRecall for multiclass classifier )
96+ * Returns f-measure
97+ * (equals to precision and recall because precision equals recall )
8998 */
90- lazy val fMeasure : Double = recall
99+ lazy val fMeasure : Double = precision
91100
92101 /**
93102 * Returns weighted averaged recall
94- * (equals to micro-averaged precision, recall and f-measure)
103+ * (equals to precision, recall and f-measure)
95104 */
96105 lazy val weightedRecall : Double = labelCountByClass.map { case (category, count) =>
97106 recall(category) * count.toDouble / labelCount
@@ -114,6 +123,5 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
114123 /**
115124 * Returns the sequence of labels in ascending order
116125 */
117- lazy val labels = tpByClass.unzip._1.toSeq.sorted
118-
126+ lazy val labels : Array [Double ] = tpByClass.keys.toArray.sorted
119127}
0 commit comments