Skip to content

Commit 154164b

Browse files
committed
Multilabel evaluation metrics and tests: macro precision and recall averaged by docs, micro and per-class precision and recall averaged by class
1 parent 67fca18 commit 154164b

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package org.apache.spark.mllib.evaluation
2+
3+
import org.apache.spark.Logging
4+
import org.apache.spark.rdd.RDD
5+
import org.apache.spark.SparkContext._
6+
7+
8+
/**
 * Evaluator for multilabel classification.
 *
 * @param predictionAndLabels an RDD of (predicted label set, true label set) pairs,
 *                            one pair per document.
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) extends Logging {

  // Total number of documents (pairs) in the input RDD.
  private lazy val numDocs = predictionAndLabels.count()

  /**
   * Document-averaged (macro) precision: per document, the fraction of predicted
   * labels that are correct, averaged over all documents.
   * A document with no predictions contributes 0 instead of dividing by zero.
   */
  lazy val macroPrecisionDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (predictions.nonEmpty) {
      predictions.intersect(labels).size.toDouble / predictions.size
    } else {
      0.0
    }
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Document-averaged (macro) recall: per document, the fraction of true labels
   * that were predicted, averaged over all documents.
   * A document with no true labels contributes 0 instead of producing NaN
   * (the original divided by labels.size unconditionally).
   */
  lazy val macroRecallDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (labels.nonEmpty) {
      predictions.intersect(labels).size.toDouble / labels.size
    } else {
      0.0
    }
  }.fold(0.0)(_ + _) / numDocs

  /**
   * Micro-averaged (over documents) precision: total true positives divided by
   * the total number of predicted labels across all documents.
   */
  lazy val microPrecisionDoc: Double = {
    val (sumTp, sumPredictions) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, predictions.size)
    }.fold((0, 0)) { case ((tp1, predictions1), (tp2, predictions2)) =>
      (tp1 + tp2, predictions1 + predictions2)
    }
    sumTp.toDouble / sumPredictions
  }

  /**
   * Micro-averaged (over documents) recall: total true positives divided by
   * the total number of true labels across all documents.
   */
  lazy val microRecallDoc: Double = {
    val (sumTp, sumLabels) = predictionAndLabels.map { case (predictions, labels) =>
      (predictions.intersect(labels).size, labels.size)
    }.fold((0, 0)) { case ((tp1, labels1), (tp2, labels2)) =>
      (tp1 + tp2, labels1 + labels2)
    }
    sumTp.toDouble / sumLabels
  }

  // True positives per class: documents where the class is both predicted and true.
  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.intersect(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False positives per class: documents where the class is predicted but not true.
  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.diff(labels).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  // False negatives per class: documents where the class is true but not predicted.
  private lazy val fnPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    labels.diff(predictions).map(category => (category, 1))
  }.reduceByKey(_ + _).collectAsMap()

  /**
   * Precision for the given class: tp / (tp + fp).
   * Returns 0 when the class was never predicted. Uses getOrElse on all three
   * maps so a class with zero true positives (original used tpPerClass(label))
   * no longer throws NoSuchElementException.
   */
  def precisionClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fp = fpPerClass.getOrElse(label, 0)
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
  }

  /**
   * Recall for the given class: tp / (tp + fn).
   * Returns 0 when the class never occurs in the true labels; safe for classes
   * with zero true positives (see precisionClass).
   */
  def recallClass(label: Double): Double = {
    val tp = tpPerClass.getOrElse(label, 0)
    val fn = fnPerClass.getOrElse(label, 0)
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
  }

  /**
   * F1-measure for the given class: harmonic mean of per-class precision and
   * recall. Returns 0 when both are 0.
   */
  def f1MeasureClass(label: Double): Double = {
    val precision = precisionClass(label)
    val recall = recallClass(label)
    if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }

  // Total true positives across all classes.
  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }

  /** Micro-averaged (over classes) precision: sum(tp) / (sum(tp) + sum(fp)). */
  lazy val microPrecisionClass: Double = {
    val sumFp = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
    sumTp.toDouble / (sumTp + sumFp)
  }

  /** Micro-averaged (over classes) recall: sum(tp) / (sum(tp) + sum(fn)). */
  lazy val microRecallClass: Double = {
    // Long accumulator for consistency with microPrecisionClass (original used 0.0).
    val sumFn = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
    sumTp.toDouble / (sumTp + sumFn)
  }

  /** Micro-averaged F1-measure: harmonic mean of the micro precision and recall. */
  lazy val microF1MeasureClass: Double = {
    val precision = microPrecisionClass
    val recall = microRecallClass
    if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
  }
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package org.apache.spark.mllib.evaluation
2+
3+
import org.apache.spark.mllib.util.LocalSparkContext
4+
import org.apache.spark.rdd.RDD
5+
import org.scalatest.FunSuite
6+
7+
8+
class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {

  test("Multilabel evaluation metrics") {
    /*
     * Documents true labels (5x class0, 3x class1, 4x class2):
     * doc 0 - predict 0, 1 - class 0, 2
     * doc 1 - predict 0, 2 - class 0, 1
     * doc 2 - predict none - class 0
     * doc 3 - predict 2 - class 2
     * doc 4 - predict 2, 0 - class 2, 0
     * doc 5 - predict 0, 1, 2 - class 0, 1
     * doc 6 - predict 1 - class 1, 2
     *
     * predicted classes
     * class 0 - doc 0, 1, 4, 5 (total 4)
     * class 1 - doc 0, 5, 6 (total 3)
     * class 2 - doc 1, 3, 4, 5 (total 4)
     *
     * true classes
     * class 0 - doc 0, 1, 2, 4, 5 (total 5)
     * class 1 - doc 1, 5, 6 (total 3)
     * class 2 - doc 0, 3, 4, 6 (total 4)
     */
    val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
      Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
        (Set(0.0, 2.0), Set(0.0, 1.0)),
        (Set.empty[Double], Set(0.0)),
        (Set(2.0), Set(2.0)),
        (Set(2.0, 0.0), Set(2.0, 0.0)),
        (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
        (Set(1.0), Set(1.0, 2.0))), 2)
    val metrics = new MultilabelMetrics(scoreAndLabels)
    val delta = 0.00001

    // Expected per-class precision/recall from the tp/fp/fn counts tabulated above.
    val precision0 = 4.0 / (4 + 0)
    val precision1 = 2.0 / (2 + 1)
    val precision2 = 2.0 / (2 + 2)
    val recall0 = 4.0 / (4 + 1)
    val recall1 = 2.0 / (2 + 1)
    val recall2 = 2.0 / (2 + 2)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)

    // Micro averages: pooled tp over pooled (tp + fp) resp. (tp + fn).
    val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2)
    val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2)
    val microF1MeasureClass =
      2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass)

    // Document-averaged precision/recall, one term per document (doc 2 predicts nothing).
    val macroPrecisionDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
    val macroRecallDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)

    assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
    assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
    assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
    assert(math.abs(metrics.recallClass(0.0) - recall0) < delta)
    assert(math.abs(metrics.recallClass(1.0) - recall1) < delta)
    assert(math.abs(metrics.recallClass(2.0) - recall2) < delta)
    assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
    assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
    assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)

    assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
    assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
    assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)

    assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
    assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
  }
}

0 commit comments

Comments
 (0)