|
| 1 | +package org.apache.spark.mllib.evaluation |
| 2 | + |
| 3 | +import org.apache.spark.mllib.util.LocalSparkContext |
| 4 | +import org.apache.spark.rdd.RDD |
| 5 | +import org.scalatest.FunSuite |
| 6 | + |
| 7 | + |
| 8 | +class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { |
| 9 | + test("Multilabel evaluation metrics") { |
| 10 | + /* |
| 11 | + * Documents true labels (5x class0, 3x class1, 4x class2): |
| 12 | + * doc 0 - predict 0, 1 - class 0, 2 |
| 13 | + * doc 1 - predict 0, 2 - class 0, 1 |
| 14 | + * doc 2 - predict none - class 0 |
| 15 | + * doc 3 - predict 2 - class 2 |
| 16 | + * doc 4 - predict 2, 0 - class 2, 0 |
| 17 | + * doc 5 - predict 0, 1, 2 - class 0, 1 |
| 18 | + * doc 6 - predict 1 - class 1, 2 |
| 19 | + * |
| 20 | + * predicted classes |
| 21 | + * class 0 - doc 0, 1, 4, 5 (total 4) |
| 22 | + * class 1 - doc 0, 5, 6 (total 3) |
| 23 | + * class 2 - doc 1, 3, 4, 5 (total 4) |
| 24 | + * |
| 25 | + * true classes |
| 26 | + * class 0 - doc 0, 1, 2, 4, 5 (total 5) |
| 27 | + * class 1 - doc 1, 5, 6 (total 3) |
| 28 | + * class 2 - doc 0, 3, 4, 6 (total 4) |
| 29 | + * |
| 30 | + */ |
| 31 | + val scoreAndLabels:RDD[(Set[Double], Set[Double])] = sc.parallelize( |
| 32 | + Seq((Set(0.0, 1.0), Set(0.0, 2.0)), |
| 33 | + (Set(0.0, 2.0), Set(0.0, 1.0)), |
| 34 | + (Set(), Set(0.0)), |
| 35 | + (Set(2.0), Set(2.0)), |
| 36 | + (Set(2.0, 0.0), Set(2.0, 0.0)), |
| 37 | + (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)), |
| 38 | + (Set(1.0), Set(1.0, 2.0))), 2) |
| 39 | + val metrics = new MultilabelMetrics(scoreAndLabels) |
| 40 | + val delta = 0.00001 |
| 41 | + val precision0 = 4.0 / (4 + 0) |
| 42 | + val precision1 = 2.0 / (2 + 1) |
| 43 | + val precision2 = 2.0 / (2 + 2) |
| 44 | + val recall0 = 4.0 / (4 + 1) |
| 45 | + val recall1 = 2.0 / (2 + 1) |
| 46 | + val recall2 = 2.0 / (2 + 2) |
| 47 | + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) |
| 48 | + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) |
| 49 | + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) |
| 50 | + val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2) |
| 51 | + val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2) |
| 52 | + val microF1MeasureClass = 2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass) |
| 53 | + |
| 54 | + val macroPrecisionDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0) |
| 55 | + val macroRecallDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2) |
| 56 | + |
| 57 | + println("Ev" + metrics.macroPrecisionDoc) |
| 58 | + println(macroPrecisionDoc) |
| 59 | + println("Ev" + metrics.macroRecallDoc) |
| 60 | + println(macroRecallDoc) |
| 61 | + assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta) |
| 62 | + assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta) |
| 63 | + assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta) |
| 64 | + assert(math.abs(metrics.recallClass(0.0) - recall0) < delta) |
| 65 | + assert(math.abs(metrics.recallClass(1.0) - recall1) < delta) |
| 66 | + assert(math.abs(metrics.recallClass(2.0) - recall2) < delta) |
| 67 | + assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta) |
| 68 | + assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta) |
| 69 | + assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta) |
| 70 | + |
| 71 | + assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta) |
| 72 | + assert(math.abs(metrics.microRecallClass - microRecallClass) < delta) |
| 73 | + assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta) |
| 74 | + |
| 75 | + assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta) |
| 76 | + assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta) |
| 77 | + |
| 78 | + |
| 79 | + } |
| 80 | + |
| 81 | +} |
0 commit comments