1515 * limitations under the License.
1616 */
1717
18- package org .apache .spark .mllib .evaluation
18+ package org .apache .spark .mllib .evaluation . binary
1919
2020import org .apache .spark .rdd .RDD
2121import org .apache .spark .SparkContext ._
22+ import org .apache .spark .mllib .evaluation .AreaUnderCurve
23+ import org .apache .spark .Logging
2224
2325/**
24- * Binary confusion matrix .
26+ * Implementation of [[ org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix ]] .
2527 *
2628 * @param count label counter for labels with scores greater than or equal to the current score
27- * @param total label counter for all labels
29+ * @param totalCount label counter for all labels
2830 */
29- case class BinaryConfusionMatrix (
31+ private case class BinaryConfusionMatrixImpl (
3032 private val count : LabelCounter ,
31- private val total : LabelCounter ) extends Serializable {
33+ private val totalCount : LabelCounter ) extends BinaryConfusionMatrix with Serializable {
3234
3335 /** number of true positives */
34- def tp : Long = count.numPositives
36+ override def tp : Long = count.numPositives
3537
3638 /** number of false positives */
37- def fp : Long = count.numNegatives
39+ override def fp : Long = count.numNegatives
3840
3941 /** number of false negatives */
40- def fn : Long = total .numPositives - count.numPositives
42+ override def fn : Long = totalCount .numPositives - count.numPositives
4143
4244 /** number of true negatives */
43- def tn : Long = total .numNegatives - count.numNegatives
45+ override def tn : Long = totalCount .numNegatives - count.numNegatives
4446
4547 /** number of positives */
46- def p : Long = total .numPositives
48+ override def p : Long = totalCount .numPositives
4749
4850 /** number of negatives */
49- def n : Long = total.numNegatives
50- }
51-
52- private trait Metric {
53- def apply (c : BinaryConfusionMatrix ): Double
54- }
55-
56- object Precision extends Metric {
57- override def apply (c : BinaryConfusionMatrix ): Double =
58- c.tp.toDouble / (c.tp + c.fp)
59- }
60-
61- object FalsePositiveRate extends Metric {
62- override def apply (c : BinaryConfusionMatrix ): Double =
63- c.fp.toDouble / c.n
64- }
65-
66- object Recall extends Metric {
67- override def apply (c : BinaryConfusionMatrix ): Double =
68- c.tp.toDouble / c.p
69- }
70-
71- case class FMeasure (beta : Double ) extends Metric {
72- private val beta2 = beta * beta
73- override def apply (c : BinaryConfusionMatrix ): Double = {
74- val precision = Precision (c)
75- val recall = Recall (c)
76- (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
77- }
51+ override def n : Long = totalCount.numNegatives
7852}
7953
8054/**
8155 * Evaluator for binary classification.
8256 *
8357 * @param scoreAndlabels an RDD of (score, label) pairs.
8458 */
85- class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable {
59+ class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable with Logging {
8660
87- private lazy val (cumCounts : RDD [(Double , LabelCounter )], totalCount : LabelCounter , scoreAndConfusion : RDD [(Double , BinaryConfusionMatrix )]) = {
61+ private lazy val (
62+ cumCounts : RDD [(Double , LabelCounter )],
63+ confusionByThreshold : RDD [(Double , BinaryConfusionMatrix )]) = {
8864 // Create a bin for each distinct score value, count positives and negatives within each bin,
8965 // and then sort by score values in descending order.
9066 val counts = scoreAndlabels.combineByKey(
@@ -99,6 +75,7 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
9975 }, preservesPartitioning = true ).collect()
10076 val cum = agg.scanLeft(new LabelCounter ())((agg : LabelCounter , c : LabelCounter ) => agg + c)
10177 val totalCount = cum.last
78+ logInfo(s " Total counts: $totalCount " )
10279 val cumCounts = counts.mapPartitionsWithIndex((index : Int , iter : Iterator [(Double , LabelCounter )]) => {
10380 val cumCount = cum(index)
10481 iter.map { case (score, c) =>
@@ -108,76 +85,71 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
10885 }, preservesPartitioning = true )
10986 cumCounts.persist()
11087 val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
111- (score, BinaryConfusionMatrix (cumCount, totalCount))
88+ (score, BinaryConfusionMatrixImpl (cumCount, totalCount))
11289 }
11390 (cumCounts, scoreAndConfusion)
11491 }
11592
93+ /** Unpersist intermediate RDDs used in the computation. */
11694 def unpersist () {
11795 cumCounts.unpersist()
11896 }
11997
98+ /**
99+ * Returns the receiver operating characteristic (ROC) curve.
100+ * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
101+ */
120102 def rocCurve (): RDD [(Double , Double )] = createCurve(FalsePositiveRate , Recall )
121103
104+ /**
105+ * Computes the area under the receiver operating characteristic (ROC) curve.
106+ */
122107 def rocAUC (): Double = AreaUnderCurve .of(rocCurve())
123108
109+ /**
110+ * Returns the precision-recall curve.
111+ * @see http://en.wikipedia.org/wiki/Precision_and_recall
112+ */
124113 def prCurve (): RDD [(Double , Double )] = createCurve(Recall , Precision )
125114
115+ /**
116+ * Computes the area under the precision-recall curve.
117+ */
126118 def prAUC (): Double = AreaUnderCurve .of(prCurve())
127119
120+ /**
121+ * Returns the (threshold, F-Measure) curve.
122+ * @param beta the beta factor in F-Measure computation.
123+ * @return an RDD of (threshold, F-Measure) pairs.
124+ * @see http://en.wikipedia.org/wiki/F1_score
125+ */
128126 def fMeasureByThreshold (beta : Double ): RDD [(Double , Double )] = createCurve(FMeasure (beta))
129127
128+ /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
130129 def fMeasureByThreshold () = fMeasureByThreshold(1.0 )
131130
132- private def createCurve (y : Metric ): RDD [(Double , Double )] = {
133- scoreAndConfusion.map { case (s, c) =>
131+ /** Creates a curve of (threshold, metric). */
132+ private def createCurve (y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
133+ confusionByThreshold.map { case (s, c) =>
134134 (s, y(c))
135135 }
136136 }
137137
138- private def createCurve (x : Metric , y : Metric ): RDD [(Double , Double )] = {
139- scoreAndConfusion.map { case (_, c) =>
138+ /** Creates a curve of (metricX, metricY). */
139+ private def createCurve (x : BinaryClassificationMetric , y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
140+ confusionByThreshold.map { case (_, c) =>
140141 (x(c), y(c))
141142 }
142143 }
143144}
144145
145- class LocalBinaryClassificationEvaluator {
146- def get (data : Iterable [(Double , Double )]) {
147- val counts = data.groupBy(_._1).mapValues { s =>
148- val c = new LabelCounter ()
149- s.foreach(c += _._2)
150- c
151- }.toSeq.sortBy(- _._1)
152- println(" counts: " + counts.toList)
153- val total = new LabelCounter ()
154- val cum = counts.map { s =>
155- total += s._2
156- (s._1, total.clone())
157- }
158- println(" cum: " + cum.toList)
159- val roc = cum.map { case (s, c) =>
160- (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
161- }
162- val rocAUC = AreaUnderCurve .of(roc)
163- println(rocAUC)
164- val pr = cum.map { case (s, c) =>
165- (1.0 * c.numPositives / total.numPositives,
166- 1.0 * c.numPositives / (c.numPositives + c.numNegatives))
167- }
168- val prAUC = AreaUnderCurve .of(pr)
169- println(prAUC)
170- }
171- }
172-
173146/**
174147 * A counter for positives and negatives.
175148 *
176- * @param numPositives
177- * @param numNegatives
149+ * @param numPositives number of positive labels
150+ * @param numNegatives number of negative labels
178151 */
179- private [evaluation]
180- class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
152+ private class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
181153
182154 /** Process a label. */
183155 def += (label : Double ): LabelCounter = {
@@ -208,6 +180,6 @@ class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) ext
208180 new LabelCounter (numPositives, numNegatives)
209181 }
210182
211- override def toString : String = s " [ $numPositives, $numNegatives] "
183+ override def toString : String = s " {numPos: $numPositives, numNeg: $numNegatives} "
212184}
213185
0 commit comments