@@ -32,14 +32,14 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
3232 labels}.distinct.count
3333
3434 /**
35- * Returns strict Accuracy
35+ * Returns subset accuracy
3636 * (for equal sets of labels)
3737 */
38- lazy val strictAccuracy : Double = predictionAndLabels.filter { case (predictions, labels) =>
38+ lazy val subsetAccuracy : Double = predictionAndLabels.filter { case (predictions, labels) =>
3939 predictions == labels}.count.toDouble / numDocs
4040
4141 /**
42- * Returns Accuracy
42+ * Returns accuracy
4343 */
4444 lazy val accuracy : Double = predictionAndLabels.map { case (predictions, labels) =>
4545 labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
@@ -52,43 +52,26 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
5252 sum).toDouble / (numDocs * numLabels)
5353
5454 /**
55- * Returns Document -based Precision averaged by the number of documents
55+ * Returns document -based precision averaged by the number of documents
5656 */
57- lazy val macroPrecisionDoc : Double = (predictionAndLabels.map { case (predictions, labels) =>
57+ lazy val precision : Double = (predictionAndLabels.map { case (predictions, labels) =>
5858 if (predictions.size > 0 ) {
5959 predictions.intersect(labels).size.toDouble / predictions.size
6060 } else 0
6161 }.sum) / numDocs
6262
6363 /**
64- * Returns Document -based Recall averaged by the number of documents
64+ * Returns document -based recall averaged by the number of documents
6565 */
66- lazy val macroRecallDoc : Double = (predictionAndLabels.map { case (predictions, labels) =>
66+ lazy val recall : Double = (predictionAndLabels.map { case (predictions, labels) =>
6767 labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
6868
6969 /**
70- * Returns Document -based F1 -measure averaged by the number of documents
70+ * Returns document -based f1 -measure averaged by the number of documents
7171 */
72- lazy val macroF1MeasureDoc : Double = (predictionAndLabels.map { case (predictions, labels) =>
72+ lazy val f1Measure : Double = (predictionAndLabels.map { case (predictions, labels) =>
7373 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
7474
75- /**
76- * Returns micro-averaged document-based Precision
77- * (equals to label-based microPrecision)
78- */
79- lazy val microPrecisionDoc : Double = microPrecisionClass
80-
81- /**
82- * Returns micro-averaged document-based Recall
83- * (equals to label-based microRecall)
84- */
85- lazy val microRecallDoc : Double = microRecallClass
86-
87- /**
88- * Returns micro-averaged document-based F1-measure
89- * (equals to label-based microF1measure)
90- */
91- lazy val microF1MeasureDoc : Double = microF1MeasureClass
9275
9376 private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
9477 predictions.intersect(labels).map(category => (category, 1 ))}.reduceByKey(_ + _).collectAsMap()
@@ -100,57 +83,65 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
10083 labels.diff(predictions).map(category => (category, 1 ))}.reduceByKey(_ + _).collectAsMap()
10184
10285 /**
103- * Returns Precision for a given label (category)
86+ * Returns precision for a given label (category)
10487 * @param label the label.
10588 */
106- def precisionClass (label : Double ) = {
89+ def precision (label : Double ) = {
10790 val tp = tpPerClass(label)
10891 val fp = fpPerClass.getOrElse(label, 0 )
10992 if (tp + fp == 0 ) 0 else tp.toDouble / (tp + fp)
11093 }
11194
11295 /**
113- * Returns Recall for a given label (category)
96+ * Returns recall for a given label (category)
11497 * @param label the label.
11598 */
116- def recallClass (label : Double ) = {
99+ def recall (label : Double ) = {
117100 val tp = tpPerClass(label)
118101 val fn = fnPerClass.getOrElse(label, 0 )
119102 if (tp + fn == 0 ) 0 else tp.toDouble / (tp + fn)
120103 }
121104
122105 /**
123- * Returns F1 -measure for a given label (category)
106+ * Returns f1 -measure for a given label (category)
124107 * @param label the label.
125108 */
126- def f1MeasureClass (label : Double ) = {
127- val precision = precisionClass (label)
128- val recall = recallClass (label)
129- if ((precision + recall ) == 0 ) 0 else 2 * precision * recall / (precision + recall )
109+ def f1Measure (label : Double ) = {
110+ val p = precision (label)
111+ val r = recall (label)
112+ if ((p + r ) == 0 ) 0 else 2 * p * r / (p + r )
130113 }
131114
132115 private lazy val sumTp = tpPerClass.foldLeft(0L ){ case (sum, (_, tp)) => sum + tp}
133116 private lazy val sumFpClass = fpPerClass.foldLeft(0L ){ case (sum, (_, fp)) => sum + fp}
134117 private lazy val sumFnClass = fnPerClass.foldLeft(0L ){ case (sum, (_, fn)) => sum + fn}
135118
136119 /**
137- * Returns micro-averaged label-based Precision
120+ * Returns micro-averaged label-based precision
121+ * (equals to micro-averaged document-based precision)
138122 */
139- lazy val microPrecisionClass = {
123+ lazy val microPrecision = {
140124 val sumFp = fpPerClass.foldLeft(0L ){ case (sumFp, (_, fp)) => sumFp + fp}
141125 sumTp.toDouble / (sumTp + sumFp)
142126 }
143127
144128 /**
145- * Returns micro-averaged label-based Recall
129+ * Returns micro-averaged label-based recall
130+ * (equals to micro-averaged document-based recall)
146131 */
147- lazy val microRecallClass = {
132+ lazy val microRecall = {
148133 val sumFn = fnPerClass.foldLeft(0.0 ){ case (sumFn, (_, fn)) => sumFn + fn}
149134 sumTp.toDouble / (sumTp + sumFn)
150135 }
151136
152137 /**
153- * Returns micro-averaged label-based F1-measure
138+ * Returns micro-averaged label-based f1-measure
139+ * (equals to micro-averaged document-based f1-measure)
140+ */
141+ lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
142+
143+ /**
144+ * Returns the sequence of labels in ascending order
154145 */
155- lazy val microF1MeasureClass = 2.0 * sumTp / ( 2 * sumTp + sumFnClass + sumFpClass)
146+ lazy val labels : Array [ Double ] = tpPerClass.keys.toArray.sorted
156147}
0 commit comments