Skip to content

Commit dda82d5

Browse files
committed
add confusion matrix
1 parent aa7e278 commit dda82d5

File tree

1 file changed

+142
-38
lines changed

1 file changed

+142
-38
lines changed
Lines changed: 142 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,156 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.mllib.evaluation
219

320
import org.apache.spark.rdd.RDD
421
import org.apache.spark.SparkContext._
522

6-
class BinaryClassificationEvaluator(scoreAndLabel: RDD[(Double, Double)]) {
7-
23+
/**
24+
* Binary confusion matrix.
25+
*
26+
* @param count label counter for labels with scores greater than or equal to the current score
27+
* @param total label counter for all labels
28+
*/
29+
case class BinaryConfusionMatrix(
30+
private val count: LabelCounter,
31+
private val total: LabelCounter) extends Serializable {
32+
33+
/** number of true positives */
34+
def tp: Long = count.numPositives
35+
36+
/** number of false positives */
37+
def fp: Long = count.numNegatives
38+
39+
/** number of false negatives */
40+
def fn: Long = total.numPositives - count.numPositives
41+
42+
/** number of true negatives */
43+
def tn: Long = total.numNegatives - count.numNegatives
44+
45+
/** number of positives */
46+
def p: Long = total.numPositives
47+
48+
/** number of negatives */
49+
def n: Long = total.numNegatives
850
}
951

10-
object BinaryClassificationEvaluator {
52+
private trait Metric {
53+
def apply(c: BinaryConfusionMatrix): Double
54+
}
55+
56+
object Precision extends Metric {
57+
override def apply(c: BinaryConfusionMatrix): Double =
58+
c.tp.toDouble / (c.tp + c.fp)
59+
}
1160

12-
def get(rdd: RDD[(Double, Double)]) {
61+
object FalsePositiveRate extends Metric {
62+
override def apply(c: BinaryConfusionMatrix): Double =
63+
c.fp.toDouble / c.n
64+
}
65+
66+
object Recall extends Metric {
67+
override def apply(c: BinaryConfusionMatrix): Double =
68+
c.tp.toDouble / c.p
69+
}
70+
71+
case class FMeasure(beta: Double) extends Metric {
72+
private val beta2 = beta * beta
73+
override def apply(c: BinaryConfusionMatrix): Double = {
74+
val precision = Precision(c)
75+
val recall = Recall(c)
76+
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
77+
}
78+
}
79+
80+
/**
81+
* Evaluator for binary classification.
82+
*
83+
* @param scoreAndlabels an RDD of (score, label) pairs.
84+
*/
85+
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable {
86+
87+
private lazy val (cumCounts: RDD[(Double, LabelCounter)], totalCount: LabelCounter, scoreAndConfusion: RDD[(Double, BinaryConfusionMatrix)]) = {
1388
// Create a bin for each distinct score value, count positives and negatives within each bin,
1489
// and then sort by score values in descending order.
15-
val counts = rdd.combineByKey(
16-
createCombiner = (label: Double) => new Counter(0L, 0L) += label,
17-
mergeValue = (c: Counter, label: Double) => c += label,
18-
mergeCombiners = (c1: Counter, c2: Counter) => c1 += c2
90+
val counts = scoreAndlabels.combineByKey(
91+
createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
92+
mergeValue = (c: LabelCounter, label: Double) => c += label,
93+
mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
1994
).sortByKey(ascending = false)
20-
println(counts.collect().toList)
21-
val agg = counts.values.mapPartitions((iter: Iterator[Counter]) => {
22-
val agg = new Counter()
95+
val agg = counts.values.mapPartitions({ iter =>
96+
val agg = new LabelCounter()
2397
iter.foreach(agg += _)
2498
Iterator(agg)
2599
}, preservesPartitioning = true).collect()
26-
println(agg.toList)
27-
val cum = agg.scanLeft(new Counter())((agg: Counter, c: Counter) => agg + c)
28-
val total = cum.last
29-
println(total)
30-
println(cum.toList)
31-
val cumCountsRdd = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, Counter)]) => {
100+
val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
101+
val totalCount = cum.last
102+
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
32103
val cumCount = cum(index)
33104
iter.map { case (score, c) =>
34105
cumCount += c
35106
(score, cumCount.clone())
36107
}
37108
}, preservesPartitioning = true)
38-
println("cum: " + cumCountsRdd.collect().toList)
39-
val rocAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
40-
(1.0 * c.numNegatives / total.numNegatives,
41-
1.0 * c.numPositives / total.numPositives)
42-
}))
43-
println(rocAUC)
44-
val prAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
45-
(1.0 * c.numPositives / total.numPositives,
46-
1.0 * c.numPositives / (c.numPositives + c.numNegatives))
47-
}))
48-
println(prAUC)
109+
cumCounts.persist()
110+
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
111+
(score, BinaryConfusionMatrix(cumCount, totalCount))
112+
}
113+
(cumCounts, totalCount, scoreAndConfusion)
114+
}
115+
116+
def unpersist() {
117+
cumCounts.unpersist()
49118
}
50119

120+
def rocCurve(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall)
121+
122+
def rocAUC(): Double = AreaUnderCurve.of(rocCurve())
123+
124+
def prCurve(): RDD[(Double, Double)] = createCurve(Recall, Precision)
125+
126+
def prAUC(): Double = AreaUnderCurve.of(prCurve())
127+
128+
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
129+
130+
def fMeasureByThreshold() = fMeasureByThreshold(1.0)
131+
132+
private def createCurve(y: Metric): RDD[(Double, Double)] = {
133+
scoreAndConfusion.map { case (s, c) =>
134+
(s, y(c))
135+
}
136+
}
137+
138+
private def createCurve(x: Metric, y: Metric): RDD[(Double, Double)] = {
139+
scoreAndConfusion.map { case (_, c) =>
140+
(x(c), y(c))
141+
}
142+
}
143+
}
144+
145+
class LocalBinaryClassificationEvaluator {
51146
def get(data: Iterable[(Double, Double)]) {
52147
val counts = data.groupBy(_._1).mapValues { s =>
53-
val c = new Counter()
148+
val c = new LabelCounter()
54149
s.foreach(c += _._2)
55150
c
56151
}.toSeq.sortBy(- _._1)
57152
println("counts: " + counts.toList)
58-
val total = new Counter()
153+
val total = new LabelCounter()
59154
val cum = counts.map { s =>
60155
total += s._2
61156
(s._1, total.clone())
@@ -75,35 +170,44 @@ object BinaryClassificationEvaluator {
75170
}
76171
}
77172

78-
class Counter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
173+
/**
174+
* A counter for positives and negatives.
175+
*
176+
* @param numPositives
177+
* @param numNegatives
178+
*/
179+
private[evaluation]
180+
class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
79181

80-
def +=(label: Double): Counter = {
182+
/** Process a label. */
183+
def +=(label: Double): LabelCounter = {
81184
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
82185
// -1.0 for negative as well.
83186
if (label > 0.5) numPositives += 1L else numNegatives += 1L
84187
this
85188
}
86189

87-
def +=(other: Counter): Counter = {
190+
/** Merge another counter. */
191+
def +=(other: LabelCounter): LabelCounter = {
88192
numPositives += other.numPositives
89193
numNegatives += other.numNegatives
90194
this
91195
}
92196

93-
def +(label: Double): Counter = {
197+
def +(label: Double): LabelCounter = {
94198
this.clone() += label
95199
}
96200

97-
def +(other: Counter): Counter = {
201+
def +(other: LabelCounter): LabelCounter = {
98202
this.clone() += other
99203
}
100204

101205
def sum: Long = numPositives + numNegatives
102206

103-
override def clone(): Counter = {
104-
new Counter(numPositives, numNegatives)
207+
override def clone: LabelCounter = {
208+
new LabelCounter(numPositives, numNegatives)
105209
}
106210

107-
override def toString(): String = s"[$numPositives,$numNegatives]"
211+
override def toString: String = s"[$numPositives,$numNegatives]"
108212
}
109213

0 commit comments

Comments
 (0)