
Commit 43a613e

Addressing reviewers' comments: change Set to Array

1 parent 1843f73 commit 43a613e

File tree

2 files changed, +26 -24 lines changed


mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

Lines changed: 18 additions & 16 deletions
@@ -22,55 +22,57 @@ import org.apache.spark.SparkContext._
 
 /**
  * Evaluator for multilabel classification.
- * @param predictionAndLabels an RDD of (predictions, labels) pairs, both are non-null sets.
+ * @param predictionAndLabels an RDD of (predictions, labels) pairs,
+ * both are non-null Arrays, each with unique elements.
  */
-class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
+class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
 
-  private lazy val numDocs: Long = predictionAndLabels.count
+  private lazy val numDocs: Long = predictionAndLabels.count()
 
   private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
-    labels}.distinct.count
+    labels}.distinct().count()
 
   /**
   * Returns strict Accuracy
   * (for equal sets of labels)
   */
   lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
-    predictions == labels}.count.toDouble / numDocs
+    predictions.deep == labels.deep }.count().toDouble / numDocs
 
   /**
   * Returns Accuracy
   */
   lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
-    labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
+    labels.intersect(predictions).size.toDouble /
+      (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
 
   /**
   * Returns Hamming-loss
   */
-  lazy val hammingLoss: Double = (predictionAndLabels.map { case (predictions, labels) =>
+  lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
     labels.diff(predictions).size + predictions.diff(labels).size}.
-    sum).toDouble / (numDocs * numLabels)
+    sum / (numDocs * numLabels)
 
   /**
   * Returns Document-based Precision averaged by the number of documents
   */
-  lazy val macroPrecisionDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
+  lazy val macroPrecisionDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
     if (predictions.size > 0) {
       predictions.intersect(labels).size.toDouble / predictions.size
     } else 0
-  }.sum) / numDocs
+  }.sum / numDocs
 
   /**
   * Returns Document-based Recall averaged by the number of documents
   */
-  lazy val macroRecallDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
-    labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
+  lazy val macroRecallDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
+    labels.intersect(predictions).size.toDouble / labels.size}.sum / numDocs
 
   /**
   * Returns Document-based F1-measure averaged by the number of documents
   */
-  lazy val macroF1MeasureDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
-    2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
+  lazy val macroF1MeasureDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
+    2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum / numDocs
 
   /**
   * Returns micro-averaged document-based Precision
@@ -137,15 +139,15 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
   * Returns micro-averaged label-based Precision
   */
   lazy val microPrecisionClass = {
-    val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
+    val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
     sumTp.toDouble / (sumTp + sumFp)
   }
 
   /**
   * Returns micro-averaged label-based Recall
   */
   lazy val microRecallClass = {
-    val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
+    val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
     sumTp.toDouble / (sumTp + sumFn)
   }
 
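Why these particular edits: with Arrays in place of Sets, two standard-library behaviors have to be worked around. Scala's Array is a Java array, so == compares references rather than contents, which is why strictAccuracy switches to predictions.deep == labels.deep. And Array#union simply concatenates its arguments, duplicates included, which is why the Jaccard denominator in accuracy is computed by inclusion-exclusion instead of labels.union(predictions).size. (The foldLeft accumulator renames, sumFp/sumFn to cum, merely avoid shadowing the enclosing vals.) A minimal sketch of the two pitfalls, assuming a Scala 2.10-era REPL where ArrayOps provides .deep; the values are illustrative, not taken from the patch:

val a = Array(1.0, 2.0)
val b = Array(1.0, 2.0)

a == b                                 // false: reference equality on Java arrays
a.deep == b.deep                       // true: .deep enables structural comparison

a.union(b).size                        // 4: Array#union concatenates, keeping duplicates
a.size + b.size - a.intersect(b).size  // 2: |A union B| by inclusion-exclusion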

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala

Lines changed: 8 additions & 8 deletions
@@ -45,14 +45,14 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
    * class 2 - doc 0, 3, 4, 6 (total 4)
    *
    */
-    val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
-      Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
-        (Set(0.0, 2.0), Set(0.0, 1.0)),
-        (Set(), Set(0.0)),
-        (Set(2.0), Set(2.0)),
-        (Set(2.0, 0.0), Set(2.0, 0.0)),
-        (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
-        (Set(1.0), Set(1.0, 2.0))), 2)
+    val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+      Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+        (Array(0.0, 2.0), Array(0.0, 1.0)),
+        (Array(), Array(0.0)),
+        (Array(2.0), Array(2.0)),
+        (Array(2.0, 0.0), Array(2.0, 0.0)),
+        (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+        (Array(1.0), Array(1.0, 2.0))), 2)
     val metrics = new MultilabelMetrics(scoreAndLabels)
     val delta = 0.00001
     val precision0 = 4.0 / (4 + 0)
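
For context, a sketch of how the updated fixture exercises the new Array-based API. This assumes a spark-shell session where sc is an active SparkContext (the suite itself gets one from LocalSparkContext), and it trims the fixture down to three illustrative documents:

import org.apache.spark.mllib.evaluation.MultilabelMetrics
import org.apache.spark.rdd.RDD

val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),   // partial overlap
  (Array(2.0), Array(2.0)),             // exact match
  (Array(2.0, 0.0), Array(2.0, 0.0))),  // exact match
  2)
val metrics = new MultilabelMetrics(scoreAndLabels)

// Two of the three documents match element-for-element under .deep comparison,
// so strict accuracy is 2/3.
assert(math.abs(metrics.strictAccuracy - 2.0 / 3) < 0.00001)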
