@@ -24,39 +24,104 @@ import org.apache.spark.mllib.util.TestingUtils._
2424
/**
 * Checks the curves and per-threshold statistics produced by
 * [[BinaryClassificationMetrics]] against values derived by hand from small
 * score/label data sets, including the two degenerate inputs in which every
 * example carries the same label.
 */
class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {

  /** Returns true when the two members of the pair agree within an absolute tolerance of 1E-5. */
  private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5

  /**
   * Returns true when two points agree coordinate by coordinate, each within an
   * absolute tolerance of 1E-5.
   */
  private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
    (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)

  /**
   * Asserts that two sequences of doubles have the same length and agree
   * element by element within the tolerance used by `areWithinEpsilon`.
   *
   * @param left  the values under test
   * @param right the reference values
   */
  private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
    // zip stops at the end of the shorter input, so without this check a result
    // with missing or surplus entries (or no entries at all) would still pass.
    assert(left.size == right.size,
      s"length mismatch: ${left.mkString("[", ", ", "]")} vs ${right.mkString("[", ", ", "]")}")
    assert(left.zip(right).forall(areWithinEpsilon),
      s"${left.mkString("[", ", ", "]")} did not match ${right.mkString("[", ", ", "]")}")
  }

  /**
   * Asserts that two sequences of points have the same length and agree point
   * by point within the tolerance used by `pairsWithinEpsilon`.
   *
   * @param left  the points under test
   * @param right the reference points
   */
  private def assertTupleSequencesMatch(
      left: Seq[(Double, Double)],
      right: Seq[(Double, Double)]): Unit = {
    // Same reasoning as in assertSequencesMatch: zip would hide a length mismatch.
    assert(left.size == right.size,
      s"length mismatch: ${left.mkString("[", ", ", "]")} vs ${right.mkString("[", ", ", "]")}")
    assert(left.zip(right).forall(pairsWithinEpsilon),
      s"${left.mkString("[", ", ", "]")} did not match ${right.mkString("[", ", ", "]")}")
  }

  /**
   * Compares everything `metrics` reports with the supplied expectations: the
   * thresholds, the ROC and PR curves together with the areas under them, and
   * the F-measure (default argument and 2.0), precision and recall reported per
   * threshold.
   *
   * @param metrics            the metrics object under test
   * @param expectedThresholds thresholds in the order `metrics` reports them
   * @param expectedROCCurve   expected points of `metrics.roc()`
   * @param expectedPRCurve    expected points of `metrics.pr()`
   * @param expectedFMeasures1 expected `fMeasureByThreshold()` values, one per threshold
   * @param expectedFMeasures2 expected `fMeasureByThreshold(2.0)` values, one per threshold
   * @param expectedPrecisions expected precision, one per threshold
   * @param expectedRecalls    expected recall, one per threshold
   */
  private def validateMetrics(
      metrics: BinaryClassificationMetrics,
      expectedThresholds: Seq[Double],
      expectedROCCurve: Seq[(Double, Double)],
      expectedPRCurve: Seq[(Double, Double)],
      expectedFMeasures1: Seq[Double],
      expectedFMeasures2: Seq[Double],
      expectedPrecisions: Seq[Double],
      expectedRecalls: Seq[Double]): Unit = {

    assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
    assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
    assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)
    assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
    assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)
    // The by-threshold results are (threshold, value) pairs, hence the zip with
    // the expected thresholds.
    assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
      expectedThresholds.zip(expectedFMeasures1))
    assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(),
      expectedThresholds.zip(expectedFMeasures2))
    assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
      expectedThresholds.zip(expectedPrecisions))
    assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
      expectedThresholds.zip(expectedRecalls))
  }

  test("binary evaluation metrics") {
    val scoreAndLabels = sc.parallelize(
      Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
    // Cumulative counts of positive / negative examples scoring at or above each threshold.
    val numTruePositives = Seq(1, 3, 3, 4)
    val numFalsePositives = Seq(0, 1, 2, 3)
    val numPositives = 4
    val numNegatives = 3
    val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
      t.toDouble / (t + f)
    }
    val recalls = numTruePositives.map(t => t.toDouble / numPositives)
    val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
    val pr = recalls.zip(precisions)
    val prCurve = Seq((0.0, 1.0)) ++ pr
    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }

    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
  }

  test("binary evaluation metrics for RDD where all examples have positive label") {
    val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)

    val thresholds = Seq(0.5)
    val precisions = Seq(1.0)
    val recalls = Seq(1.0)
    val fpr = Seq(0.0)
    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
    val pr = recalls.zip(precisions)
    val prCurve = Seq((0.0, 1.0)) ++ pr
    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }

    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
  }

  test("binary evaluation metrics for RDD where all examples have negative label") {
    val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)

    val thresholds = Seq(0.5)
    val precisions = Seq(0.0)
    val recalls = Seq(0.0)
    val fpr = Seq(1.0)
    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
    val pr = recalls.zip(precisions)
    val prCurve = Seq((0.0, 1.0)) ++ pr
    // With precision and recall both zero the F-measure formulas evaluate 0.0 / 0.0;
    // the expected value for that case is pinned to 0.0 instead.
    val f1 = pr.map {
      case (0.0, 0.0) => 0.0
      case (r, p) => 2.0 * (p * r) / (p + r)
    }
    val f2 = pr.map {
      case (0.0, 0.0) => 0.0
      case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
    }

    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
  }
}
0 commit comments