@@ -29,8 +29,8 @@ import org.apache.spark.Logging
2929 * @param totalCount label counter for all labels
3030 */
3131private case class BinaryConfusionMatrixImpl (
32- private val count : LabelCounter ,
33- private val totalCount : LabelCounter ) extends BinaryConfusionMatrix with Serializable {
32+ count : LabelCounter ,
33+ totalCount : LabelCounter ) extends BinaryConfusionMatrix with Serializable {
3434
3535 /** number of true positives */
3636 override def tp : Long = count.numPositives
@@ -54,16 +54,16 @@ private case class BinaryConfusionMatrixImpl(
5454/**
5555 * Evaluator for binary classification.
5656 *
57- * @param scoreAndlabels an RDD of (score, label) pairs.
57+ * @param scoreAndLabels an RDD of (score, label) pairs.
5858 */
59- class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable with Logging {
59+ class BinaryClassificationEvaluator (scoreAndLabels : RDD [(Double , Double )]) extends Serializable with Logging {
6060
6161 private lazy val (
6262 cumCounts : RDD [(Double , LabelCounter )],
63- confusionByThreshold : RDD [(Double , BinaryConfusionMatrix )]) = {
63+ confusions : RDD [(Double , BinaryConfusionMatrix )]) = {
6464 // Create a bin for each distinct score value, count positives and negatives within each bin,
6565 // and then sort by score values in descending order.
66- val counts = scoreAndlabels .combineByKey(
66+ val counts = scoreAndLabels .combineByKey(
6767 createCombiner = (label : Double ) => new LabelCounter (0L , 0L ) += label,
6868 mergeValue = (c : LabelCounter , label : Double ) => c += label,
6969 mergeCombiners = (c1 : LabelCounter , c2 : LabelCounter ) => c1 += c2
@@ -73,21 +73,21 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
7373 iter.foreach(agg += _)
7474 Iterator (agg)
7575 }, preservesPartitioning = true ).collect()
76- val cum = agg.scanLeft(new LabelCounter ())((agg : LabelCounter , c : LabelCounter ) => agg + c)
77- val totalCount = cum .last
78- logInfo(s " Total counts: totalCount " )
76+ val partitionwiseCumCounts = agg.scanLeft(new LabelCounter ())((agg : LabelCounter , c : LabelCounter ) => agg + c)
77+ val totalCount = partitionwiseCumCounts .last
78+ logInfo(s " Total counts: $ totalCount" )
7979 val cumCounts = counts.mapPartitionsWithIndex((index : Int , iter : Iterator [(Double , LabelCounter )]) => {
80- val cumCount = cum (index)
80+ val cumCount = partitionwiseCumCounts (index)
8181 iter.map { case (score, c) =>
8282 cumCount += c
8383 (score, cumCount.clone())
8484 }
8585 }, preservesPartitioning = true )
8686 cumCounts.persist()
87- val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
88- (score, BinaryConfusionMatrixImpl (cumCount, totalCount))
87+ val confusions = cumCounts.map { case (score, cumCount) =>
88+ (score, BinaryConfusionMatrixImpl (cumCount, totalCount). asInstanceOf [ BinaryConfusionMatrix ] )
8989 }
90- (cumCounts, totalCount, scoreAndConfusion )
90+ (cumCounts, confusions )
9191 }
9292
9393 /** Unpersist intermediate RDDs used in the computation. */
@@ -126,18 +126,18 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
126126 def fMeasureByThreshold (beta : Double ): RDD [(Double , Double )] = createCurve(FMeasure (beta))
127127
128128 /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
129- def fMeasureByThreshold () = fMeasureByThreshold(1.0 )
129+ def fMeasureByThreshold (): RDD [( Double , Double )] = fMeasureByThreshold(1.0 )
130130
131131 /** Creates a curve of (threshold, metric). */
132132 private def createCurve (y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
133- confusionByThreshold .map { case (s, c) =>
133+ confusions .map { case (s, c) =>
134134 (s, y(c))
135135 }
136136 }
137137
138138 /** Creates a curve of (metricX, metricY). */
139139 private def createCurve (x : BinaryClassificationMetric , y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
140- confusionByThreshold .map { case (_, c) =>
140+ confusions .map { case (_, c) =>
141141 (x(c), y(c))
142142 }
143143 }
@@ -151,35 +151,29 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
151151 */
152152private class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
153153
154- /** Process a label. */
154+ /** Processes a label. */
155155 def += (label : Double ): LabelCounter = {
156156 // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
157157 // -1.0 for negative as well.
158158 if (label > 0.5 ) numPositives += 1L else numNegatives += 1L
159159 this
160160 }
161161
162- /** Merge another counter. */
162+ /** Merges another counter. */
163163 def += (other : LabelCounter ): LabelCounter = {
164164 numPositives += other.numPositives
165165 numNegatives += other.numNegatives
166166 this
167167 }
168168
169- def + (label : Double ): LabelCounter = {
170- this .clone() += label
171- }
172-
169+ /** Sums this counter and another counter and returns the result in a new counter. */
173170 def + (other : LabelCounter ): LabelCounter = {
174171 this .clone() += other
175172 }
176173
177- def sum : Long = numPositives + numNegatives
178-
179174 override def clone : LabelCounter = {
180175 new LabelCounter (numPositives, numNegatives)
181176 }
182177
183178 override def toString : String = s " {numPos: $numPositives, numNeg: $numNegatives} "
184179}
185-
0 commit comments