1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
118package org .apache .spark .mllib .evaluation
219
320import org .apache .spark .Logging
421import org .apache .spark .rdd .RDD
522import org .apache .spark .SparkContext ._
623
7-
24+ /**
25+ * Evaluator for multilabel classification.
26+ * NB: type Double both for prediction and label is retained
27+ * for compatibility with model.predict that returns Double
28+ * and MLUtils.loadLibSVMFile that loads class labels as Double
29+ *
30+ * @param predictionAndLabels an RDD of pairs (predictions, labels) sets.
31+ */
832class MultilabelMetrics (predictionAndLabels: RDD [(Set [Double ], Set [Double ])]) extends Logging {
933
1034 private lazy val numDocs = predictionAndLabels.count()
1135
36+ /**
37+ * Returns Document-based Precision averaged by the number of documents
38+ * @return macroPrecisionDoc.
39+ */
1240 lazy val macroPrecisionDoc = (predictionAndLabels.map{ case (predictions, labels) =>
13- if (predictions.size > 0 )
14- predictions.intersect(labels).size.toDouble / predictions.size else 0 }.fold(0.0 )(_ + _)) / numDocs
15-
41+ if (predictions.size > 0 )
42+ predictions.intersect(labels).size.toDouble / predictions.size else 0 }.fold(0.0 )(_ + _)) /
43+ numDocs
44+
45+ /**
46+ * Returns Document-based Recall averaged by the number of documents
47+ * @return macroRecallDoc.
48+ */
1649 lazy val macroRecallDoc = (predictionAndLabels.map{ case (predictions, labels) =>
1750 predictions.intersect(labels).size.toDouble / labels.size}.fold(0.0 )(_ + _)) / numDocs
1851
52+ /**
53+ * Returns micro-averaged document-based Precision
54+ * @return microPrecisionDoc.
55+ */
1956 lazy val microPrecisionDoc = {
2057 val (sumTp, sumPredictions) = predictionAndLabels.map{ case (predictions, labels) =>
2158 (predictions.intersect(labels).size, predictions.size)}.
@@ -24,6 +61,10 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
2461 sumTp.toDouble / sumPredictions
2562 }
2663
64+ /**
65+ * Returns micro-averaged document-based Recall
66+ * @return microRecallDoc.
67+ */
2768 lazy val microRecallDoc = {
2869 val (sumTp, sumLabels) = predictionAndLabels.map{ case (predictions, labels) =>
2970 (predictions.intersect(labels).size, labels.size)}.
@@ -41,12 +82,28 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
4182 private lazy val fnPerClass = predictionAndLabels.flatMap{ case (predictions, labels) =>
4283 labels.diff(predictions).map(category => (category, 1 ))}.reduceByKey(_ + _).collectAsMap()
4384
44- def precisionClass (label : Double ) = if ((tpPerClass(label) + fpPerClass.getOrElse(label, 0 )) == 0 ) 0 else
45- tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0 ))
46-
47- def recallClass (label : Double ) = if ((tpPerClass(label) + fnPerClass.getOrElse(label, 0 )) == 0 ) 0 else
85+ /**
86+ * Returns Precision for a given label (category)
87+ * @param label the label.
88+ * @return Precision.
89+ */
90+ def precisionClass (label : Double ) = if ((tpPerClass(label) + fpPerClass.getOrElse(label, 0 )) == 0 )
91+ 0 else tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0 ))
92+
93+ /**
94+ * Returns Recall for a given label (category)
95+ * @param label the label.
96+ * @return Recall.
97+ */
98+ def recallClass (label : Double ) = if ((tpPerClass(label) + fnPerClass.getOrElse(label, 0 )) == 0 )
99+ 0 else
48100 tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0 ))
49101
102+ /**
103+ * Returns F1-measure for a given label (category)
104+ * @param label the label.
105+ * @return F1-measure.
106+ */
50107 def f1MeasureClass (label : Double ) = {
51108 val precision = precisionClass(label)
52109 val recall = recallClass(label)
@@ -55,16 +112,28 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
55112
56113 private lazy val sumTp = tpPerClass.foldLeft(0L ){ case (sumTp, (_, tp)) => sumTp + tp}
57114
115+ /**
116+ * Returns micro-averaged label-based Precision
117+ * @return microPrecisionClass.
118+ */
58119 lazy val microPrecisionClass = {
59120 val sumFp = fpPerClass.foldLeft(0L ){ case (sumFp, (_, fp)) => sumFp + fp}
60121 sumTp.toDouble / (sumTp + sumFp)
61122 }
62123
124+ /**
125+ * Returns micro-averaged label-based Recall
126+ * @return microRecallClass.
127+ */
63128 lazy val microRecallClass = {
64129 val sumFn = fnPerClass.foldLeft(0.0 ){ case (sumFn, (_, fn)) => sumFn + fn}
65130 sumTp.toDouble / (sumTp + sumFn)
66131 }
67132
133+ /**
134+ * Returns micro-averaged label-based F1-measure
135+ * @return microRecallClass.
136+ */
68137 lazy val microF1MeasureClass = {
69138 val precision = microPrecisionClass
70139 val recall = microRecallClass
0 commit comments