@@ -59,4 +59,60 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
5959 assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
6060 assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
6161 }
62+
63+ test(" binary evaluation metrics for All Positive RDD" ) {
64+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 1.0 )), 2 )
65+ val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
66+
67+ val threshold = Seq (0.5 )
68+ val precision = Seq (1.0 )
69+ val recall = Seq (1.0 )
70+ val fpr = Seq (0.0 )
71+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recall) ++ Seq ((1.0 , 1.0 ))
72+ val pr = recall.zip(precision)
73+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
74+ val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
75+ val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
76+
77+ assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
78+ assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
79+ assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
80+ assert(metrics.pr().collect().zip(prCurve).forall(cond2))
81+ assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
82+ assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
83+ assert(metrics.fMeasureByThreshold(2.0 ).collect().zip(threshold.zip(f2)).forall(cond2))
84+ assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
85+ assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
86+ }
87+
88+ test(" binary evaluation metrics for All Negative RDD" ) {
89+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 0.0 )), 2 )
90+ val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
91+
92+ val threshold = Seq (0.5 )
93+ val precision = Seq (0.0 )
94+ val recall = Seq (0.0 )
95+ val fpr = Seq (1.0 )
96+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recall) ++ Seq ((1.0 , 1.0 ))
97+ val pr = recall.zip(precision)
98+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
99+ val f1 = pr.map {
100+ case (0 ,0 ) => 0.0
101+ case (r, p) => 2.0 * (p * r) / (p + r)
102+ }
103+ val f2 = pr.map {
104+ case (0 ,0 ) => 0.0
105+ case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106+ }
107+
108+ assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
109+ assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
110+ assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
111+ assert(metrics.pr().collect().zip(prCurve).forall(cond2))
112+ assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
113+ assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
114+ assert(metrics.fMeasureByThreshold(2.0 ).collect().zip(threshold.zip(f2)).forall(cond2))
115+ assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
116+ assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
117+ }
62118}
0 commit comments