Skip to content

Commit 9dc3518

Browse files
committed
add tests for BinaryClassificationEvaluator
1 parent ca31da5 commit 9dc3518

File tree

6 files changed

+71
-42
lines changed

6 files changed

+71
-42
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
2323
/**
2424
* Computes the area under the curve (AUC) using the trapezoidal rule.
2525
*/
26-
private[mllib] object AreaUnderCurve {
26+
private[evaluation] object AreaUnderCurve {
2727

2828
/**
2929
* Uses the trapezoidal rule to compute the area under the line connecting the two input points.

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ import org.apache.spark.Logging
2929
* @param totalCount label counter for all labels
3030
*/
3131
private case class BinaryConfusionMatrixImpl(
32-
private val count: LabelCounter,
33-
private val totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
32+
count: LabelCounter,
33+
totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
3434

3535
/** number of true positives */
3636
override def tp: Long = count.numPositives
@@ -54,16 +54,16 @@ private case class BinaryConfusionMatrixImpl(
5454
/**
5555
* Evaluator for binary classification.
5656
*
57-
* @param scoreAndlabels an RDD of (score, label) pairs.
57+
* @param scoreAndLabels an RDD of (score, label) pairs.
5858
*/
59-
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable with Logging {
59+
class BinaryClassificationEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Serializable with Logging {
6060

6161
private lazy val (
6262
cumCounts: RDD[(Double, LabelCounter)],
63-
confusionByThreshold: RDD[(Double, BinaryConfusionMatrix)]) = {
63+
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
6464
// Create a bin for each distinct score value, count positives and negatives within each bin,
6565
// and then sort by score values in descending order.
66-
val counts = scoreAndlabels.combineByKey(
66+
val counts = scoreAndLabels.combineByKey(
6767
createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
6868
mergeValue = (c: LabelCounter, label: Double) => c += label,
6969
mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
@@ -73,21 +73,21 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
7373
iter.foreach(agg += _)
7474
Iterator(agg)
7575
}, preservesPartitioning = true).collect()
76-
val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
77-
val totalCount = cum.last
78-
logInfo(s"Total counts: totalCount")
76+
val partitionwiseCumCounts = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
77+
val totalCount = partitionwiseCumCounts.last
78+
logInfo(s"Total counts: $totalCount")
7979
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
80-
val cumCount = cum(index)
80+
val cumCount = partitionwiseCumCounts(index)
8181
iter.map { case (score, c) =>
8282
cumCount += c
8383
(score, cumCount.clone())
8484
}
8585
}, preservesPartitioning = true)
8686
cumCounts.persist()
87-
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
88-
(score, BinaryConfusionMatrixImpl(cumCount, totalCount))
87+
val confusions = cumCounts.map { case (score, cumCount) =>
88+
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
8989
}
90-
(cumCounts, totalCount, scoreAndConfusion)
90+
(cumCounts, confusions)
9191
}
9292

9393
/** Unpersist intermediate RDDs used in the computation. */
@@ -126,18 +126,18 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
126126
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
127127

128128
/** Returns the (threshold, F-Measure) curve with beta = 1.0. */
129-
def fMeasureByThreshold() = fMeasureByThreshold(1.0)
129+
def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)
130130

131131
/** Creates a curve of (threshold, metric). */
132132
private def createCurve(y: BinaryClassificationMetric): RDD[(Double, Double)] = {
133-
confusionByThreshold.map { case (s, c) =>
133+
confusions.map { case (s, c) =>
134134
(s, y(c))
135135
}
136136
}
137137

138138
/** Creates a curve of (metricX, metricY). */
139139
private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = {
140-
confusionByThreshold.map { case (_, c) =>
140+
confusions.map { case (_, c) =>
141141
(x(c), y(c))
142142
}
143143
}
@@ -151,35 +151,29 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
151151
*/
152152
private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
153153

154-
/** Process a label. */
154+
/** Processes a label. */
155155
def +=(label: Double): LabelCounter = {
156156
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
157157
// -1.0 for negative as well.
158158
if (label > 0.5) numPositives += 1L else numNegatives += 1L
159159
this
160160
}
161161

162-
/** Merge another counter. */
162+
/** Merges another counter. */
163163
def +=(other: LabelCounter): LabelCounter = {
164164
numPositives += other.numPositives
165165
numNegatives += other.numNegatives
166166
this
167167
}
168168

169-
def +(label: Double): LabelCounter = {
170-
this.clone() += label
171-
}
172-
169+
/** Sums this counter and another counter and returns the result in a new counter. */
173170
def +(other: LabelCounter): LabelCounter = {
174171
this.clone() += other
175172
}
176173

177-
def sum: Long = numPositives + numNegatives
178-
179174
override def clone: LabelCounter = {
180175
new LabelCounter(numPositives, numNegatives)
181176
}
182177

183178
override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
184179
}
185-

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation.binary
2020
/**
2121
* Trait for a binary classification evaluation metric.
2222
*/
23-
private[evaluation] trait BinaryClassificationMetric {
23+
private[evaluation] trait BinaryClassificationMetric extends Serializable {
2424
def apply(c: BinaryConfusionMatrix): Double
2525
}
2626

@@ -37,7 +37,7 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetric
3737
}
3838

3939
/** Recall. */
40-
private[evalution] object Recall extends BinaryClassificationMetric {
40+
private[evaluation] object Recall extends BinaryClassificationMetric {
4141
override def apply(c: BinaryConfusionMatrix): Double =
4242
c.tp.toDouble / c.p
4343
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import org.scalatest.FunSuite
2222
import org.apache.spark.mllib.util.LocalSparkContext
2323

2424
class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
25-
2625
test("auc computation") {
2726
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
2827
val auc = 4.0

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala

Lines changed: 0 additions & 13 deletions
This file was deleted.
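mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluatorSuite.scala

Lines changed: 49 additions & 0 deletions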
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.mllib.evaluation.AreaUnderCurve
24+
25+
class BinaryClassificationEvaluatorSuite extends FunSuite with LocalSparkContext {
26+
test("binary evaluation metrics") {
27+
val scoreAndLabels = sc.parallelize(
28+
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
29+
val evaluator = new BinaryClassificationEvaluator(scoreAndLabels)
30+
val score = Seq(0.8, 0.6, 0.4, 0.1)
31+
val tp = Seq(1, 3, 3, 4)
32+
val fp = Seq(0, 1, 2, 3)
33+
val p = 4
34+
val n = 3
35+
val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
36+
val recall = tp.map(t => t.toDouble / p)
37+
val fpr = fp.map(f => f.toDouble / n)
38+
val roc = fpr.zip(recall)
39+
val pr = recall.zip(precision)
40+
val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
41+
val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
42+
assert(evaluator.rocCurve().collect().toSeq === roc)
43+
assert(evaluator.rocAUC() === AreaUnderCurve.of(roc))
44+
assert(evaluator.prCurve().collect().toSeq === pr)
45+
assert(evaluator.prAUC() === AreaUnderCurve.of(pr))
46+
assert(evaluator.fMeasureByThreshold().collect().toSeq === score.zip(f1))
47+
assert(evaluator.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
48+
}
49+
}

0 commit comments

Comments
 (0)