|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.evaluation.binary |
19 | 19 |
|
20 | | -import org.apache.spark.rdd.RDD |
| 20 | +import org.apache.spark.rdd.{UnionRDD, RDD} |
21 | 21 | import org.apache.spark.SparkContext._ |
22 | 22 | import org.apache.spark.mllib.evaluation.AreaUnderCurve |
23 | 23 | import org.apache.spark.Logging |
@@ -103,22 +103,34 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) |
103 | 103 |
|
104 | 104 | /** |
105 | 105 | * Returns the receiver operating characteristic (ROC) curve, |
106 | | - * which is an RDD of (false positive rate, true positive rate). |
| 106 | + * which is an RDD of (false positive rate, true positive rate) |
| 107 | + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. |
107 | 108 | * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic |
108 | 109 | */ |
109 | | - def roc(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall) |
| 110 | + def roc(): RDD[(Double, Double)] = { |
| 111 | + val rocCurve = createCurve(FalsePositiveRate, Recall) |
| 112 | + val sc = confusions.context |
| 113 | + val first = sc.makeRDD(Seq((0.0, 0.0)), 1) |
| 114 | + val last = sc.makeRDD(Seq((1.0, 1.0)), 1) |
| 115 | + new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last)) |
| 116 | + } |
110 | 117 |
|
111 | 118 | /** |
112 | 119 | * Computes the area under the receiver operating characteristic (ROC) curve. |
113 | 120 | */ |
114 | 121 | def areaUnderROC(): Double = AreaUnderCurve.of(roc()) |
115 | 122 |
|
116 | 123 | /** |
117 | | - * Returns the precision-recall curve, |
118 | | - * which is an RDD of (recall, precision), NOT (precision, recall). |
| 124 | + * Returns the precision-recall curve, which is an RDD of (recall, precision), |
| 125 | + * NOT (precision, recall), with (0.0, 1.0) prepended to it. |
119 | 126 | * @see http://en.wikipedia.org/wiki/Precision_and_recall |
120 | 127 | */ |
121 | | - def pr(): RDD[(Double, Double)] = createCurve(Recall, Precision) |
| 128 | + def pr(): RDD[(Double, Double)] = { |
| 129 | + val prCurve = createCurve(Recall, Precision) |
| 130 | + val sc = confusions.context |
| 131 | + val first = sc.makeRDD(Seq((0.0, 1.0)), 1) |
| 132 | + first.union(prCurve) |
| 133 | + } |
122 | 134 |
|
123 | 135 | /** |
124 | 136 | * Computes the area under the precision-recall curve. |
|
0 commit comments