
Commit 62d01d2

avulanov authored and mengxr committed
[MLLIB] SPARK-2329 Add multi-label evaluation metrics
Implementation of various multi-label classification measures, including: Hamming-loss, strict and default Accuracy, macro-averaged Precision, Recall and F1-measure based on documents and labels, and micro-averaged measures: https://issues.apache.org/jira/browse/SPARK-2329

Multi-class measures are currently in the following pull request: #1155

Author: Alexander Ulanov <[email protected]>
Author: avulanov <[email protected]>

Closes #1270 from avulanov/multilabelmetrics and squashes the following commits:

fc8175e [Alexander Ulanov] Merge with previous updates
43a613e [Alexander Ulanov] Addressing reviewers' comments: change Set to Array
517a594 [avulanov] Addressing reviewers' comments: Scala style
cf4222b [avulanov] Addressing reviewers' comments: renaming. Added label method that returns the list of labels
1843f73 [Alexander Ulanov] Scala style fix
79e8476 [Alexander Ulanov] Replacing fold(_ + _) with sum as suggested by srowen
ca46765 [Alexander Ulanov] Cosmetic changes: Apache header and parameter explanation
40593f5 [Alexander Ulanov] Multi-label metrics: Hamming-loss, strict and normal accuracy, fix to macro measures, bunch of tests
ad62df0 [Alexander Ulanov] Comments and scala style check
154164b [Alexander Ulanov] Multilabel evaluation metrics and tests: macro precision and recall averaged by docs, micro and per-class precision and recall averaged by class
1 parent 23f73f5 commit 62d01d2
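
For context, here is a minimal usage sketch of the MultilabelMetrics evaluator this commit adds. The input pairs below are made up for illustration, a live SparkContext named sc is assumed, and the member names come from the class in the diff further down.

import org.apache.spark.mllib.evaluation.MultilabelMetrics
import org.apache.spark.rdd.RDD

// Hypothetical (predictions, labels) pairs: each array lists the class ids
// assigned to one document, without duplicates.
val predictionAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(0.0), Array(0.0, 1.0))), 2)

val metrics = new MultilabelMetrics(predictionAndLabels)
println(metrics.subsetAccuracy)  // fraction of documents whose predicted label set matches exactly
println(metrics.hammingLoss)     // fraction of (document, label) assignments that disagree
println(metrics.microF1Measure)  // micro-averaged F1 over all labels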

File tree: 2 files changed, +260 -0

MultilabelMetrics.scala: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

/**
 * Evaluator for multilabel classification.
 * @param predictionAndLabels an RDD of (predictions, labels) pairs,
 * both are non-null Arrays, each with unique elements.
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {

  private lazy val numDocs: Long = predictionAndLabels.count()

  private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
    labels }.distinct().count()

  /**
   * Returns subset accuracy
   * (for equal sets of labels)
   */
  lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
    predictions.deep == labels.deep
  }.count().toDouble / numDocs

  /**
   * Returns accuracy
   */
  lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
    labels.intersect(predictions).size.toDouble /
      (labels.size + predictions.size - labels.intersect(predictions).size) }.sum / numDocs

  /**
   * Returns Hamming-loss
   */
  lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
    labels.size + predictions.size - 2 * labels.intersect(predictions).size
  }.sum / (numDocs * numLabels)

  /**
   * Returns document-based precision averaged by the number of documents
   */
  lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
    if (predictions.size > 0) {
      predictions.intersect(labels).size.toDouble / predictions.size
    } else {
      0
    }
  }.sum / numDocs

  /**
   * Returns document-based recall averaged by the number of documents
   */
  lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
    labels.intersect(predictions).size.toDouble / labels.size
  }.sum / numDocs

  /**
   * Returns document-based f1-measure averaged by the number of documents
   */
  lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
    2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
  }.sum / numDocs

  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.intersect(labels)
  }.countByValue()

  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    predictions.diff(labels)
  }.countByValue()

  private lazy val fnPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
    labels.diff(predictions)
  }.countByValue()

  /**
   * Returns precision for a given label (category)
   * @param label the label.
   */
  def precision(label: Double) = {
    val tp = tpPerClass(label)
    val fp = fpPerClass.getOrElse(label, 0L)
    if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
  }

  /**
   * Returns recall for a given label (category)
   * @param label the label.
   */
  def recall(label: Double) = {
    val tp = tpPerClass(label)
    val fn = fnPerClass.getOrElse(label, 0L)
    if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
  }

  /**
   * Returns f1-measure for a given label (category)
   * @param label the label.
   */
  def f1Measure(label: Double) = {
    val p = precision(label)
    val r = recall(label)
    if ((p + r) == 0) 0 else 2 * p * r / (p + r)
  }

  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
  private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
  private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }

  /**
   * Returns micro-averaged label-based precision
   * (equal to micro-averaged document-based precision)
   */
  lazy val microPrecision = {
    val sumFp = fpPerClass.foldLeft(0L) { case (cum, (_, fp)) => cum + fp }
    sumTp.toDouble / (sumTp + sumFp)
  }

  /**
   * Returns micro-averaged label-based recall
   * (equal to micro-averaged document-based recall)
   */
  lazy val microRecall = {
    val sumFn = fnPerClass.foldLeft(0.0) { case (cum, (_, fn)) => cum + fn }
    sumTp.toDouble / (sumTp + sumFn)
  }

  /**
   * Returns micro-averaged label-based f1-measure
   * (equal to micro-averaged document-based f1-measure)
   */
  lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)

  /**
   * Returns the sequence of labels in ascending order
   */
  lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
}
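
As a quick illustration of the per-label accessors defined above, the labels list can drive a simple per-class report; this sketch assumes a metrics instance built as in the usage example near the top of the page.

// Print precision, recall and F1 for every label seen in the data
// (precision(label), recall(label) and f1Measure(label) come from the class above).
metrics.labels.foreach { label =>
  println(s"class $label: P=${metrics.precision(label)} " +
    s"R=${metrics.recall(label)} F1=${metrics.f1Measure(label)}")
}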
MultilabelMetricsSuite.scala: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.rdd.RDD

class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
  test("Multilabel evaluation metrics") {
    /*
     * Documents true labels (5x class0, 3x class1, 4x class2):
     * doc 0 - predict 0, 1 - class 0, 2
     * doc 1 - predict 0, 2 - class 0, 1
     * doc 2 - predict none - class 0
     * doc 3 - predict 2 - class 2
     * doc 4 - predict 2, 0 - class 2, 0
     * doc 5 - predict 0, 1, 2 - class 0, 1
     * doc 6 - predict 1 - class 1, 2
     *
     * predicted classes
     * class 0 - doc 0, 1, 4, 5 (total 4)
     * class 1 - doc 0, 5, 6 (total 3)
     * class 2 - doc 1, 3, 4, 5 (total 4)
     *
     * true classes
     * class 0 - doc 0, 1, 2, 4, 5 (total 5)
     * class 1 - doc 1, 5, 6 (total 3)
     * class 2 - doc 0, 3, 4, 6 (total 4)
     */
    val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
      Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
        (Array(0.0, 2.0), Array(0.0, 1.0)),
        (Array(), Array(0.0)),
        (Array(2.0), Array(2.0)),
        (Array(2.0, 0.0), Array(2.0, 0.0)),
        (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
        (Array(1.0), Array(1.0, 2.0))), 2)
    val metrics = new MultilabelMetrics(scoreAndLabels)
    val delta = 0.00001
    val precision0 = 4.0 / (4 + 0)
    val precision1 = 2.0 / (2 + 1)
    val precision2 = 2.0 / (2 + 2)
    val recall0 = 4.0 / (4 + 1)
    val recall1 = 2.0 / (2 + 1)
    val recall2 = 2.0 / (2 + 2)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
    val sumTp = 4 + 2 + 2
    assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
    val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
    val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
    val microF1MeasureClass = 2.0 * sumTp.toDouble /
      (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
    val macroPrecisionDoc = 1.0 / 7 *
      (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
    val macroRecallDoc = 1.0 / 7 *
      (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
    val macroF1MeasureDoc = (1.0 / 7) *
      2 * (1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
        2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2))
    val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
    val strictAccuracy = 2.0 / 7
    val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 / 3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
    assert(math.abs(metrics.precision(0.0) - precision0) < delta)
    assert(math.abs(metrics.precision(1.0) - precision1) < delta)
    assert(math.abs(metrics.precision(2.0) - precision2) < delta)
    assert(math.abs(metrics.recall(0.0) - recall0) < delta)
    assert(math.abs(metrics.recall(1.0) - recall1) < delta)
    assert(math.abs(metrics.recall(2.0) - recall2) < delta)
    assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
    assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
    assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
    assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
    assert(math.abs(metrics.microRecall - microRecallClass) < delta)
    assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
    assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
    assert(math.abs(metrics.recall - macroRecallDoc) < delta)
    assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
    assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
    assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
    assert(math.abs(metrics.accuracy - accuracy) < delta)
    assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
  }
}
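
As a sanity check on the hard-coded expectations above: the symmetric differences between predicted and true label sets are 2, 2, 1, 0, 0, 1, 1 for the seven documents, so with 3 distinct labels the expected Hamming loss is 7 / (7 * 3) = 1/3, which is exactly the hammingLoss value used in the test. A tiny standalone recomputation (plain Scala, no Spark required):

// Per-document symmetric-difference sizes for the fixture in the test above.
val symDiffSizes = Seq(2, 2, 1, 0, 0, 1, 1)
val expectedHammingLoss = symDiffSizes.sum.toDouble / (7 * 3)  // = 0.333...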
