Skip to content

Commit ad62df0

Browse files
committed
Comments and scala style check
1 parent 154164b commit ad62df0

File tree

1 file changed

+77
-8
lines changed

1 file changed

+77
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.mllib.evaluation
219

320
import org.apache.spark.Logging
421
import org.apache.spark.rdd.RDD
522
import org.apache.spark.SparkContext._
623

7-
24+
/**
25+
* Evaluator for multilabel classification.
26+
* NB: type Double both for prediction and label is retained
27+
* for compatibility with model.predict that returns Double
28+
* and MLUtils.loadLibSVMFile that loads class labels as Double
29+
*
30+
* @param predictionAndLabels an RDD of pairs (predictions, labels) sets.
31+
*/
832
class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) extends Logging{
933

1034
private lazy val numDocs = predictionAndLabels.count()
1135

36+
/**
37+
* Returns Document-based Precision averaged by the number of documents
38+
* @return macroPrecisionDoc.
39+
*/
1240
lazy val macroPrecisionDoc = (predictionAndLabels.map{ case(predictions, labels) =>
13-
if (predictions.size >0)
14-
predictions.intersect(labels).size.toDouble / predictions.size else 0}.fold(0.0)(_ + _)) / numDocs
15-
41+
if(predictions.size >0)
42+
predictions.intersect(labels).size.toDouble / predictions.size else 0}.fold(0.0)(_ + _)) /
43+
numDocs
44+
45+
/**
46+
* Returns Document-based Recall averaged by the number of documents
47+
* @return macroRecallDoc.
48+
*/
1649
lazy val macroRecallDoc = (predictionAndLabels.map{ case(predictions, labels) =>
1750
predictions.intersect(labels).size.toDouble / labels.size}.fold(0.0)(_ + _)) / numDocs
1851

52+
/**
53+
* Returns micro-averaged document-based Precision
54+
* @return microPrecisionDoc.
55+
*/
1956
lazy val microPrecisionDoc = {
2057
val (sumTp, sumPredictions) = predictionAndLabels.map{ case(predictions, labels) =>
2158
(predictions.intersect(labels).size, predictions.size)}.
@@ -24,6 +61,10 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
2461
sumTp.toDouble / sumPredictions
2562
}
2663

64+
/**
65+
* Returns micro-averaged document-based Recall
66+
* @return microRecallDoc.
67+
*/
2768
lazy val microRecallDoc = {
2869
val (sumTp, sumLabels) = predictionAndLabels.map{ case(predictions, labels) =>
2970
(predictions.intersect(labels).size, labels.size)}.
@@ -41,12 +82,28 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
4182
private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
4283
labels.diff(predictions).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
4384

44-
def precisionClass(label: Double) = if((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0) 0 else
45-
tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0))
46-
47-
def recallClass(label: Double) = if((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0) 0 else
85+
/**
86+
* Returns Precision for a given label (category)
87+
* @param label the label.
88+
* @return Precision.
89+
*/
90+
def precisionClass(label: Double) = if((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0)
91+
0 else tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0))
92+
93+
/**
94+
* Returns Recall for a given label (category)
95+
* @param label the label.
96+
* @return Recall.
97+
*/
98+
def recallClass(label: Double) = if((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0)
99+
0 else
48100
tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0))
49101

102+
/**
103+
* Returns F1-measure for a given label (category)
104+
* @param label the label.
105+
* @return F1-measure.
106+
*/
50107
def f1MeasureClass(label: Double) = {
51108
val precision = precisionClass(label)
52109
val recall = recallClass(label)
@@ -55,16 +112,28 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
55112

56113
private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sumTp, (_, tp)) => sumTp + tp}
57114

115+
/**
116+
* Returns micro-averaged label-based Precision
117+
* @return microPrecisionClass.
118+
*/
58119
lazy val microPrecisionClass = {
59120
val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
60121
sumTp.toDouble / (sumTp + sumFp)
61122
}
62123

124+
/**
125+
* Returns micro-averaged label-based Recall
126+
* @return microRecallClass.
127+
*/
63128
lazy val microRecallClass = {
64129
val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
65130
sumTp.toDouble / (sumTp + sumFn)
66131
}
67132

133+
/**
134+
* Returns micro-averaged label-based F1-measure
135+
* @return microRecallClass.
136+
*/
68137
lazy val microF1MeasureClass = {
69138
val precision = microPrecisionClass
70139
val recall = microRecallClass

0 commit comments

Comments
 (0)