1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
118package org .apache .spark .mllib .evaluation
219
320import org .apache .spark .rdd .RDD
421import org .apache .spark .SparkContext ._
522
6- class BinaryClassificationEvaluator (scoreAndLabel : RDD [(Double , Double )]) {
7-
23+ /**
24+ * Binary confusion matrix.
25+ *
26+ * @param count label counter for labels with scores greater than or equal to the current score
27+ * @param total label counter for all labels
28+ */
29+ case class BinaryConfusionMatrix (
30+ private val count : LabelCounter ,
31+ private val total : LabelCounter ) extends Serializable {
32+
33+ /** number of true positives */
34+ def tp : Long = count.numPositives
35+
36+ /** number of false positives */
37+ def fp : Long = count.numNegatives
38+
39+ /** number of false negatives */
40+ def fn : Long = total.numPositives - count.numPositives
41+
42+ /** number of true negatives */
43+ def tn : Long = total.numNegatives - count.numNegatives
44+
45+ /** number of positives */
46+ def p : Long = total.numPositives
47+
48+ /** number of negatives */
49+ def n : Long = total.numNegatives
850}
951
10- object BinaryClassificationEvaluator {
52+ private trait Metric {
53+ def apply (c : BinaryConfusionMatrix ): Double
54+ }
55+
56+ object Precision extends Metric {
57+ override def apply (c : BinaryConfusionMatrix ): Double =
58+ c.tp.toDouble / (c.tp + c.fp)
59+ }
1160
12- def get (rdd : RDD [(Double , Double )]) {
61+ object FalsePositiveRate extends Metric {
62+ override def apply (c : BinaryConfusionMatrix ): Double =
63+ c.fp.toDouble / c.n
64+ }
65+
66+ object Recall extends Metric {
67+ override def apply (c : BinaryConfusionMatrix ): Double =
68+ c.tp.toDouble / c.p
69+ }
70+
71+ case class FMeasure (beta : Double ) extends Metric {
72+ private val beta2 = beta * beta
73+ override def apply (c : BinaryConfusionMatrix ): Double = {
74+ val precision = Precision (c)
75+ val recall = Recall (c)
76+ (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
77+ }
78+ }
79+
80+ /**
81+ * Evaluator for binary classification.
82+ *
83+ * @param scoreAndlabels an RDD of (score, label) pairs.
84+ */
85+ class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable {
86+
87+ private lazy val (cumCounts : RDD [(Double , LabelCounter )], totalCount : LabelCounter , scoreAndConfusion : RDD [(Double , BinaryConfusionMatrix )]) = {
1388 // Create a bin for each distinct score value, count positives and negatives within each bin,
1489 // and then sort by score values in descending order.
15- val counts = rdd .combineByKey(
16- createCombiner = (label : Double ) => new Counter (0L , 0L ) += label,
17- mergeValue = (c : Counter , label : Double ) => c += label,
18- mergeCombiners = (c1 : Counter , c2 : Counter ) => c1 += c2
90+ val counts = scoreAndlabels .combineByKey(
91+ createCombiner = (label : Double ) => new LabelCounter (0L , 0L ) += label,
92+ mergeValue = (c : LabelCounter , label : Double ) => c += label,
93+ mergeCombiners = (c1 : LabelCounter , c2 : LabelCounter ) => c1 += c2
1994 ).sortByKey(ascending = false )
20- println(counts.collect().toList)
21- val agg = counts.values.mapPartitions((iter : Iterator [Counter ]) => {
22- val agg = new Counter ()
95+ val agg = counts.values.mapPartitions({ iter =>
96+ val agg = new LabelCounter ()
2397 iter.foreach(agg += _)
2498 Iterator (agg)
2599 }, preservesPartitioning = true ).collect()
26- println(agg.toList)
27- val cum = agg.scanLeft(new Counter ())((agg : Counter , c : Counter ) => agg + c)
28- val total = cum.last
29- println(total)
30- println(cum.toList)
31- val cumCountsRdd = counts.mapPartitionsWithIndex((index : Int , iter : Iterator [(Double , Counter )]) => {
100+ val cum = agg.scanLeft(new LabelCounter ())((agg : LabelCounter , c : LabelCounter ) => agg + c)
101+ val totalCount = cum.last
102+ val cumCounts = counts.mapPartitionsWithIndex((index : Int , iter : Iterator [(Double , LabelCounter )]) => {
32103 val cumCount = cum(index)
33104 iter.map { case (score, c) =>
34105 cumCount += c
35106 (score, cumCount.clone())
36107 }
37108 }, preservesPartitioning = true )
38- println(" cum: " + cumCountsRdd.collect().toList)
39- val rocAUC = AreaUnderCurve .of(cumCountsRdd.values.map((c : Counter ) => {
40- (1.0 * c.numNegatives / total.numNegatives,
41- 1.0 * c.numPositives / total.numPositives)
42- }))
43- println(rocAUC)
44- val prAUC = AreaUnderCurve .of(cumCountsRdd.values.map((c : Counter ) => {
45- (1.0 * c.numPositives / total.numPositives,
46- 1.0 * c.numPositives / (c.numPositives + c.numNegatives))
47- }))
48- println(prAUC)
109+ cumCounts.persist()
110+ val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
111+ (score, BinaryConfusionMatrix (cumCount, totalCount))
112+ }
113+ (cumCounts, totalCount, scoreAndConfusion)
114+ }
115+
116+ def unpersist () {
117+ cumCounts.unpersist()
49118 }
50119
120+ def rocCurve (): RDD [(Double , Double )] = createCurve(FalsePositiveRate , Recall )
121+
122+ def rocAUC (): Double = AreaUnderCurve .of(rocCurve())
123+
124+ def prCurve (): RDD [(Double , Double )] = createCurve(Recall , Precision )
125+
126+ def prAUC (): Double = AreaUnderCurve .of(prCurve())
127+
128+ def fMeasureByThreshold (beta : Double ): RDD [(Double , Double )] = createCurve(FMeasure (beta))
129+
130+ def fMeasureByThreshold () = fMeasureByThreshold(1.0 )
131+
132+ private def createCurve (y : Metric ): RDD [(Double , Double )] = {
133+ scoreAndConfusion.map { case (s, c) =>
134+ (s, y(c))
135+ }
136+ }
137+
138+ private def createCurve (x : Metric , y : Metric ): RDD [(Double , Double )] = {
139+ scoreAndConfusion.map { case (_, c) =>
140+ (x(c), y(c))
141+ }
142+ }
143+ }
144+
145+ class LocalBinaryClassificationEvaluator {
51146 def get (data : Iterable [(Double , Double )]) {
52147 val counts = data.groupBy(_._1).mapValues { s =>
53- val c = new Counter ()
148+ val c = new LabelCounter ()
54149 s.foreach(c += _._2)
55150 c
56151 }.toSeq.sortBy(- _._1)
57152 println(" counts: " + counts.toList)
58- val total = new Counter ()
153+ val total = new LabelCounter ()
59154 val cum = counts.map { s =>
60155 total += s._2
61156 (s._1, total.clone())
@@ -75,35 +170,44 @@ object BinaryClassificationEvaluator {
75170 }
76171}
77172
78- class Counter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
173+ /**
174+ * A counter for positives and negatives.
175+ *
176+ * @param numPositives
177+ * @param numNegatives
178+ */
179+ private [evaluation]
180+ class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
79181
80- def += (label : Double ): Counter = {
182+ /** Process a label. */
183+ def += (label : Double ): LabelCounter = {
81184 // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
82185 // -1.0 for negative as well.
83186 if (label > 0.5 ) numPositives += 1L else numNegatives += 1L
84187 this
85188 }
86189
87- def += (other : Counter ): Counter = {
190+ /** Merge another counter. */
191+ def += (other : LabelCounter ): LabelCounter = {
88192 numPositives += other.numPositives
89193 numNegatives += other.numNegatives
90194 this
91195 }
92196
93- def + (label : Double ): Counter = {
197+ def + (label : Double ): LabelCounter = {
94198 this .clone() += label
95199 }
96200
97- def + (other : Counter ): Counter = {
201+ def + (other : LabelCounter ): LabelCounter = {
98202 this .clone() += other
99203 }
100204
101205 def sum : Long = numPositives + numNegatives
102206
103- override def clone () : Counter = {
104- new Counter (numPositives, numNegatives)
207+ override def clone : LabelCounter = {
208+ new LabelCounter (numPositives, numNegatives)
105209 }
106210
107- override def toString () : String = s " [ $numPositives, $numNegatives] "
211+ override def toString : String = s " [ $numPositives, $numNegatives] "
108212}
109213
0 commit comments