@@ -22,55 +22,57 @@ import org.apache.spark.SparkContext._
2222
2323/**
2424 * Evaluator for multilabel classification.
25- * @param predictionAndLabels an RDD of (predictions, labels) pairs, both are non-null sets.
25+ * @param predictionAndLabels an RDD of (predictions, labels) pairs,
26+ * both are non-null Arrays, each with unique elements.
2627 */
27- class MultilabelMetrics (predictionAndLabels : RDD [(Set [Double ], Set [Double ])]) {
28+ class MultilabelMetrics (predictionAndLabels : RDD [(Array [Double ], Array [Double ])]) {
2829
29- private lazy val numDocs : Long = predictionAndLabels.count
30+ private lazy val numDocs : Long = predictionAndLabels.count()
3031
3132 private lazy val numLabels : Long = predictionAndLabels.flatMap { case (_, labels) =>
32- labels}.distinct.count
33+ labels}.distinct() .count()
3334
3435 /**
3536 * Returns strict Accuracy
3637 * (for equal sets of labels)
3738 */
3839 lazy val strictAccuracy : Double = predictionAndLabels.filter { case (predictions, labels) =>
39- predictions == labels}.count.toDouble / numDocs
40+ predictions.deep == labels.deep }.count() .toDouble / numDocs
4041
4142 /**
4243 * Returns Accuracy
4344 */
4445 lazy val accuracy : Double = predictionAndLabels.map { case (predictions, labels) =>
45- labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
46+ labels.intersect(predictions).size.toDouble /
47+ (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
4648
4749 /**
4850 * Returns Hamming-loss
4951 */
50- lazy val hammingLoss : Double = ( predictionAndLabels.map { case (predictions, labels) =>
52+ lazy val hammingLoss : Double = predictionAndLabels.map { case (predictions, labels) =>
5153 labels.diff(predictions).size + predictions.diff(labels).size}.
52- sum).toDouble / (numDocs * numLabels)
54+ sum / (numDocs * numLabels)
5355
5456 /**
5557 * Returns Document-based Precision averaged by the number of documents
5658 */
57- lazy val macroPrecisionDoc : Double = ( predictionAndLabels.map { case (predictions, labels) =>
59+ lazy val macroPrecisionDoc : Double = predictionAndLabels.map { case (predictions, labels) =>
5860 if (predictions.size > 0 ) {
5961 predictions.intersect(labels).size.toDouble / predictions.size
6062 } else 0
61- }.sum) / numDocs
63+ }.sum / numDocs
6264
6365 /**
6466 * Returns Document-based Recall averaged by the number of documents
6567 */
66- lazy val macroRecallDoc : Double = ( predictionAndLabels.map { case (predictions, labels) =>
67- labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
68+ lazy val macroRecallDoc : Double = predictionAndLabels.map { case (predictions, labels) =>
69+ labels.intersect(predictions).size.toDouble / labels.size}.sum / numDocs
6870
6971 /**
7072 * Returns Document-based F1-measure averaged by the number of documents
7173 */
72- lazy val macroF1MeasureDoc : Double = ( predictionAndLabels.map { case (predictions, labels) =>
73- 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
74+ lazy val macroF1MeasureDoc : Double = predictionAndLabels.map { case (predictions, labels) =>
75+ 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum / numDocs
7476
7577 /**
7678 * Returns micro-averaged document-based Precision
@@ -137,15 +139,15 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
137139 * Returns micro-averaged label-based Precision
138140 */
139141 lazy val microPrecisionClass = {
140- val sumFp = fpPerClass.foldLeft(0L ){ case (sumFp , (_, fp)) => sumFp + fp}
142+ val sumFp = fpPerClass.foldLeft(0L ){ case (cum , (_, fp)) => cum + fp}
141143 sumTp.toDouble / (sumTp + sumFp)
142144 }
143145
144146 /**
145147 * Returns micro-averaged label-based Recall
146148 */
147149 lazy val microRecallClass = {
148- val sumFn = fnPerClass.foldLeft(0.0 ){ case (sumFn , (_, fn)) => sumFn + fn}
150+ val sumFn = fnPerClass.foldLeft(0.0 ){ case (cum , (_, fn)) => cum + fn}
149151 sumTp.toDouble / (sumTp + sumFn)
150152 }
151153
0 commit comments