Skip to content

Commit fc8175e

Browse files
committed
Merge with previous updates
2 parents 43a613e + 517a594 commit fc8175e

File tree

2 files changed

+74
-74
lines changed

2 files changed

+74
-74
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -33,126 +33,125 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
3333
labels}.distinct().count()
3434

3535
/**
36-
* Returns strict Accuracy
36+
* Returns subset accuracy
3737
* (for equal sets of labels)
3838
*/
39-
lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
40-
predictions.deep == labels.deep }.count().toDouble / numDocs
39+
lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
40+
predictions.deep == labels.deep
41+
}.count().toDouble / numDocs
4142

4243
/**
43-
* Returns Accuracy
44+
* Returns accuracy
4445
*/
4546
lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
4647
labels.intersect(predictions).size.toDouble /
4748
(labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
4849

50+
4951
/**
5052
* Returns Hamming-loss
5153
*/
5254
lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
53-
labels.diff(predictions).size + predictions.diff(labels).size}.
54-
sum / (numDocs * numLabels)
55+
labels.size + predictions.size - 2 * labels.intersect(predictions).size
56+
}.sum / (numDocs * numLabels)
5557

5658
/**
57-
* Returns Document-based Precision averaged by the number of documents
59+
* Returns document-based precision averaged by the number of documents
5860
*/
59-
lazy val macroPrecisionDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
61+
lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
6062
if (predictions.size > 0) {
6163
predictions.intersect(labels).size.toDouble / predictions.size
62-
} else 0
64+
} else {
65+
0
66+
}
6367
}.sum / numDocs
6468

6569
/**
66-
* Returns Document-based Recall averaged by the number of documents
67-
*/
68-
lazy val macroRecallDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
69-
labels.intersect(predictions).size.toDouble / labels.size}.sum / numDocs
70-
71-
/**
72-
* Returns Document-based F1-measure averaged by the number of documents
70+
* Returns document-based recall averaged by the number of documents
7371
*/
74-
lazy val macroF1MeasureDoc: Double = predictionAndLabels.map { case (predictions, labels) =>
75-
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum / numDocs
76-
77-
/**
78-
* Returns micro-averaged document-based Precision
79-
* (equals to label-based microPrecision)
80-
*/
81-
lazy val microPrecisionDoc: Double = microPrecisionClass
82-
83-
/**
84-
* Returns micro-averaged document-based Recall
85-
* (equals to label-based microRecall)
86-
*/
87-
lazy val microRecallDoc: Double = microRecallClass
72+
lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
73+
labels.intersect(predictions).size.toDouble / labels.size
74+
}.sum / numDocs
8875

8976
/**
90-
* Returns micro-averaged document-based F1-measure
91-
* (equals to label-based microF1measure)
77+
* Returns document-based f1-measure averaged by the number of documents
9278
*/
93-
lazy val microF1MeasureDoc: Double = microF1MeasureClass
79+
lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
80+
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
81+
}.sum / numDocs
9482

9583
private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
96-
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
84+
predictions.intersect(labels)
85+
}.countByValue()
9786

98-
private lazy val fpPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
99-
predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
87+
private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
88+
predictions.diff(labels)
89+
}.countByValue()
10090

101-
private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
102-
labels.diff(predictions).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
91+
private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
92+
labels.diff(predictions)
93+
}.countByValue()
10394

10495
/**
105-
* Returns Precision for a given label (category)
96+
* Returns precision for a given label (category)
10697
* @param label the label.
10798
*/
108-
def precisionClass(label: Double) = {
99+
def precision(label: Double) = {
109100
val tp = tpPerClass(label)
110-
val fp = fpPerClass.getOrElse(label, 0)
101+
val fp = fpPerClass.getOrElse(label, 0L)
111102
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
112103
}
113104

114105
/**
115-
* Returns Recall for a given label (category)
106+
* Returns recall for a given label (category)
116107
* @param label the label.
117108
*/
118-
def recallClass(label: Double) = {
109+
def recall(label: Double) = {
119110
val tp = tpPerClass(label)
120-
val fn = fnPerClass.getOrElse(label, 0)
111+
val fn = fnPerClass.getOrElse(label, 0L)
121112
if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
122113
}
123114

124115
/**
125-
* Returns F1-measure for a given label (category)
116+
* Returns f1-measure for a given label (category)
126117
* @param label the label.
127118
*/
128-
def f1MeasureClass(label: Double) = {
129-
val precision = precisionClass(label)
130-
val recall = recallClass(label)
131-
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
119+
def f1Measure(label: Double) = {
120+
val p = precision(label)
121+
val r = recall(label)
122+
if((p + r) == 0) 0 else 2 * p * r / (p + r)
132123
}
133124

134-
private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
135-
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
136-
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
125+
private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
126+
private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
127+
private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
137128

138129
/**
139-
* Returns micro-averaged label-based Precision
130+
* Returns micro-averaged label-based precision
131+
* (equal to micro-averaged document-based precision)
140132
*/
141-
lazy val microPrecisionClass = {
133+
lazy val microPrecision = {
142134
val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
143135
sumTp.toDouble / (sumTp + sumFp)
144136
}
145137

146138
/**
147-
* Returns micro-averaged label-based Recall
139+
* Returns micro-averaged label-based recall
140+
* (equal to micro-averaged document-based recall)
148141
*/
149-
lazy val microRecallClass = {
142+
lazy val microRecall = {
150143
val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
151144
sumTp.toDouble / (sumTp + sumFn)
152145
}
153146

154147
/**
155-
* Returns micro-averaged label-based F1-measure
148+
* Returns micro-averaged label-based f1-measure
149+
* (equal to micro-averaged document-based f1-measure)
150+
*/
151+
lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
152+
153+
/**
154+
* Returns the sequence of labels in ascending order
156155
*/
157-
lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
156+
lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
158157
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,24 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
8080
val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
8181
val strictAccuracy = 2.0 / 7
8282
val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
83-
assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
84-
assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
85-
assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
86-
assert(math.abs(metrics.recallClass(0.0) - recall0) < delta)
87-
assert(math.abs(metrics.recallClass(1.0) - recall1) < delta)
88-
assert(math.abs(metrics.recallClass(2.0) - recall2) < delta)
89-
assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
90-
assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
91-
assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)
92-
assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
93-
assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
94-
assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)
95-
assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
96-
assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
97-
assert(math.abs(metrics.macroF1MeasureDoc - macroF1MeasureDoc) < delta)
83+
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
84+
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
85+
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
86+
assert(math.abs(metrics.recall(0.0) - recall0) < delta)
87+
assert(math.abs(metrics.recall(1.0) - recall1) < delta)
88+
assert(math.abs(metrics.recall(2.0) - recall2) < delta)
89+
assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
90+
assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
91+
assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
92+
assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
93+
assert(math.abs(metrics.microRecall - microRecallClass) < delta)
94+
assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
95+
assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
96+
assert(math.abs(metrics.recall - macroRecallDoc) < delta)
97+
assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
9898
assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
99-
assert(math.abs(metrics.strictAccuracy - strictAccuracy) < delta)
99+
assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
100100
assert(math.abs(metrics.accuracy - accuracy) < delta)
101+
assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
101102
}
102103
}

0 commit comments

Comments
 (0)