Skip to content

Commit cf4222b

Browse files
committed
Addressing reviewers' comments: renaming. Added a labels method that returns the list of labels
1 parent 1843f73 commit cf4222b

File tree

2 files changed

+49
-57
lines changed

2 files changed

+49
-57
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
3232
labels}.distinct.count
3333

3434
/**
35-
* Returns strict Accuracy
35+
* Returns subset accuracy
3636
* (for equal sets of labels)
3737
*/
38-
lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
38+
lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
3939
predictions == labels}.count.toDouble / numDocs
4040

4141
/**
42-
* Returns Accuracy
42+
* Returns accuracy
4343
*/
4444
lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
4545
labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
@@ -52,43 +52,26 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
5252
sum).toDouble / (numDocs * numLabels)
5353

5454
/**
55-
* Returns Document-based Precision averaged by the number of documents
55+
* Returns document-based precision averaged by the number of documents
5656
*/
57-
lazy val macroPrecisionDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
57+
lazy val precision: Double = (predictionAndLabels.map { case (predictions, labels) =>
5858
if (predictions.size > 0) {
5959
predictions.intersect(labels).size.toDouble / predictions.size
6060
} else 0
6161
}.sum) / numDocs
6262

6363
/**
64-
* Returns Document-based Recall averaged by the number of documents
64+
* Returns document-based recall averaged by the number of documents
6565
*/
66-
lazy val macroRecallDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
66+
lazy val recall: Double = (predictionAndLabels.map { case (predictions, labels) =>
6767
labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
6868

6969
/**
70-
* Returns Document-based F1-measure averaged by the number of documents
70+
* Returns document-based f1-measure averaged by the number of documents
7171
*/
72-
lazy val macroF1MeasureDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
72+
lazy val f1Measure: Double = (predictionAndLabels.map { case (predictions, labels) =>
7373
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
7474

75-
/**
76-
* Returns micro-averaged document-based Precision
77-
* (equals to label-based microPrecision)
78-
*/
79-
lazy val microPrecisionDoc: Double = microPrecisionClass
80-
81-
/**
82-
* Returns micro-averaged document-based Recall
83-
* (equals to label-based microRecall)
84-
*/
85-
lazy val microRecallDoc: Double = microRecallClass
86-
87-
/**
88-
* Returns micro-averaged document-based F1-measure
89-
* (equals to label-based microF1measure)
90-
*/
91-
lazy val microF1MeasureDoc: Double = microF1MeasureClass
9275

9376
private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
9477
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
@@ -100,57 +83,65 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
10083
labels.diff(predictions).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
10184

10285
/**
103-
* Returns Precision for a given label (category)
86+
* Returns precision for a given label (category)
10487
* @param label the label.
10588
*/
106-
def precisionClass(label: Double) = {
89+
def precision(label: Double) = {
10790
val tp = tpPerClass(label)
10891
val fp = fpPerClass.getOrElse(label, 0)
10992
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
11093
}
11194

11295
/**
113-
* Returns Recall for a given label (category)
96+
* Returns recall for a given label (category)
11497
* @param label the label.
11598
*/
116-
def recallClass(label: Double) = {
99+
def recall(label: Double) = {
117100
val tp = tpPerClass(label)
118101
val fn = fnPerClass.getOrElse(label, 0)
119102
if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
120103
}
121104

122105
/**
123-
* Returns F1-measure for a given label (category)
106+
* Returns f1-measure for a given label (category)
124107
* @param label the label.
125108
*/
126-
def f1MeasureClass(label: Double) = {
127-
val precision = precisionClass(label)
128-
val recall = recallClass(label)
129-
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
109+
def f1Measure(label: Double) = {
110+
val p = precision(label)
111+
val r = recall(label)
112+
if((p + r) == 0) 0 else 2 * p * r / (p + r)
130113
}
131114

132115
private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
133116
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
134117
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
135118

136119
/**
137-
* Returns micro-averaged label-based Precision
120+
* Returns micro-averaged label-based precision
121+
* (equals to micro-averaged document-based precision)
138122
*/
139-
lazy val microPrecisionClass = {
123+
lazy val microPrecision = {
140124
val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
141125
sumTp.toDouble / (sumTp + sumFp)
142126
}
143127

144128
/**
145-
* Returns micro-averaged label-based Recall
129+
* Returns micro-averaged label-based recall
130+
* (equals to micro-averaged document-based recall)
146131
*/
147-
lazy val microRecallClass = {
132+
lazy val microRecall = {
148133
val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
149134
sumTp.toDouble / (sumTp + sumFn)
150135
}
151136

152137
/**
153-
* Returns micro-averaged label-based F1-measure
138+
* Returns micro-averaged label-based f1-measure
139+
* (equals to micro-averaged document-based f1-measure)
140+
*/
141+
lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
142+
143+
/**
144+
* Returns the sequence of labels in ascending order
154145
*/
155-
lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
146+
lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
156147
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,24 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
8080
val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
8181
val strictAccuracy = 2.0 / 7
8282
val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
83-
assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
84-
assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
85-
assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
86-
assert(math.abs(metrics.recallClass(0.0) - recall0) < delta)
87-
assert(math.abs(metrics.recallClass(1.0) - recall1) < delta)
88-
assert(math.abs(metrics.recallClass(2.0) - recall2) < delta)
89-
assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
90-
assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
91-
assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)
92-
assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
93-
assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
94-
assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)
95-
assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
96-
assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
97-
assert(math.abs(metrics.macroF1MeasureDoc - macroF1MeasureDoc) < delta)
83+
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
84+
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
85+
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
86+
assert(math.abs(metrics.recall(0.0) - recall0) < delta)
87+
assert(math.abs(metrics.recall(1.0) - recall1) < delta)
88+
assert(math.abs(metrics.recall(2.0) - recall2) < delta)
89+
assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
90+
assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
91+
assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
92+
assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
93+
assert(math.abs(metrics.microRecall - microRecallClass) < delta)
94+
assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
95+
assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
96+
assert(math.abs(metrics.recall - macroRecallDoc) < delta)
97+
assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
9898
assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
99-
assert(math.abs(metrics.strictAccuracy - strictAccuracy) < delta)
99+
assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
100100
assert(math.abs(metrics.accuracy - accuracy) < delta)
101+
assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
101102
}
102103
}

0 commit comments

Comments
 (0)