Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package org.tribuo.classification.evaluation;

import java.util.logging.Logger;

import org.tribuo.classification.Classifiable;
import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
import org.tribuo.evaluation.metrics.MetricTarget;

import java.util.logging.Logger;

/**
* Static functions for computing classification metrics based on a {@link ConfusionMatrix}.
*/
Expand Down Expand Up @@ -60,7 +60,7 @@ public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatr
double support = cm.support(label);
// handle div-by-zero
if (support == 0d) {
logger.warning("No predictions: accuracy ill-defined");
logger.warning("No predictions for " + label + ": accuracy ill-defined");
return Double.NaN;
}
return cm.tp(label) / cm.support(label);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@

package org.tribuo.classification.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.classification.Utils.label;
import static org.tribuo.classification.Utils.mkDomain;
import static org.tribuo.classification.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;


public class LabelConfusionMatrixTest {
Expand All @@ -38,7 +38,8 @@ public void testMulticlass() {
mkPrediction("a", "a"),
mkPrediction("c", "b"),
mkPrediction("b", "b"),
mkPrediction("b", "c")
mkPrediction("b", "c"),
mkPrediction("a", "b")
);
ImmutableOutputInfo<Label> domain = mkDomain(predictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, predictions);
Expand All @@ -54,25 +55,25 @@ public void testMulticlass() {
assertEquals(1, cm.tp(a));
assertEquals(0, cm.fp(a));
assertEquals(3, cm.tn(a));
assertEquals(0, cm.fn(a));
assertEquals(1, cm.support(a));
assertEquals(1, cm.fn(a));
assertEquals(2, cm.support(a));

assertEquals(1, cm.tp(b));
assertEquals(1, cm.fp(b));
assertEquals(2, cm.fp(b));
assertEquals(1, cm.tn(b));
assertEquals(1, cm.fn(b));
assertEquals(2, cm.support(b));

assertEquals(0, cm.tp(c));
assertEquals(1, cm.fp(c));
assertEquals(2, cm.tn(c));
assertEquals(3, cm.tn(c));
assertEquals(1, cm.fn(c));
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());
assertEquals(5, cm.support());
String cmToString = cm.toString();
assertEquals(" a b c\n" +
"a 1 0 0\n" +
"a 1 1 0\n" +
"b 0 1 1\n" +
"c 0 1 0\n", cmToString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package org.tribuo.multilabel.evaluation;

import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
Expand All @@ -25,10 +30,6 @@
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

import java.util.List;
import java.util.Set;
import java.util.function.Function;

/**
* A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
*
Expand Down Expand Up @@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) {

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < mcm.length; i++) {
DenseMatrix cm = mcm[i];
sb.append(cm.toString());
sb.append("\n");
}
sb.append("]");
return sb.toString();
return getDomain().getDomain().stream()
.map(multiLabel -> {
final int tp = (int) tp(multiLabel);
final int fn = (int) fn(multiLabel);
final int fp = (int) fp(multiLabel);
final int tn = (int) tn(multiLabel);
return String.join("\n",
multiLabel.toString(),
String.format(" [tn: %,d fn: %,d]", tn, fn),
String.format(" [fp: %,d tp: %,d]", fp, tp));
}
).collect(Collectors.joining("\n"));
}

static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,37 @@

package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.ClassifierEvaluation;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelConfusionMatrix;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.impl.ListExample;
import org.tribuo.multilabel.baseline.IndependentMultiLabelTrainer;
import org.tribuo.multilabel.evaluation.MultiLabelEvaluator;
import org.tribuo.multilabel.example.MultiLabelDataGenerator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.test.Helpers;

import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.oracle.labs.mlrg.olcut.util.Pair;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -67,4 +81,97 @@ public void testIndependentBinaryPredictions() {
Helpers.testModelSerialization(model,MultiLabel.class);
}

@Test
public void testMultiLabelConfusionMatrixToStrings() {
Dataset<MultiLabel> train = MultiLabelDataGenerator.generateTrainData();
Dataset<MultiLabel> test = MultiLabelDataGenerator.generateTestData();

IndependentMultiLabelTrainer trainer = new IndependentMultiLabelTrainer(
new LogisticRegressionTrainer());
Model<MultiLabel> model = trainer.train(train);

ClassifierEvaluation<MultiLabel> evaluation = new MultiLabelEvaluator()
.evaluate(model, test);

System.out.println(evaluation);

// MultiLabelConfusionMatrix toString() hard to interpret
final ConfusionMatrix<MultiLabel> mcm = evaluation.getConfusionMatrix();

System.out.println("new toString()");
System.out.println(mcm);

System.out.println("\npredictions");
evaluation.getPredictions().forEach(System.out::println);

final List<Prediction<MultiLabel>> predictions = evaluation.getPredictions();
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

public static LabelConfusionMatrix singleLabelConfusionMatrix(final List<Prediction<MultiLabel>> predictions) {
final List<Prediction<Label>> singleLabelPredictions = mkSingleLabelPredictions(predictions);
ImmutableOutputInfo<Label> domain = mkDomain(singleLabelPredictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, singleLabelPredictions);
return cm;
}

public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<MultiLabel>> predictions) {
return mkSingleLabelPredictions(predictions, false);
}

public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<MultiLabel>> predictions,
final boolean falseNegativeHeuristic) {
return predictions.stream()
.flatMap(p -> {
final Set<Label> trueLabels = p.getExample().getOutput().getLabelSet();
final Set<Label> predicted = p.getOutput().getLabelSet();
// intersection(trueLabels, predicted) = true positives
// predicted - trueLabels = false positives
// trueLabels - predicted = false negatives
return Stream.concat(predicted.stream().map(pred -> {
if (trueLabels.contains(pred)) {
return mkPrediction(pred.getLabel(), pred.getLabel());
} else if (trueLabels.size() == 1) {
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
} else {
// arbitrarily pick first trueLabel
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
}
}),
!falseNegativeHeuristic ? Stream.of() :
// partially represent false negatives by calling them false positives tied to some predicted label if there is one
trueLabels.stream().filter(t -> !predicted.contains(t)).flatMap(fnTrueLabel -> {
if (predicted.isEmpty()) {
// nothing to pin this on
return Stream.of();
} else if (predicted.size() == 1) {
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
} else {
// arbitrarily pick first predicted label
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
}
})
);
}).collect(Collectors.toList());
}

// FIXME HACK copied from Classification/Core/src/test/java/org/tribuo/classification/Utils.java

public static Prediction<Label> mkPrediction(String trueVal, String predVal) {
LabelFactory factory = new LabelFactory();
Example<Label> example = new ListExample<>(factory.generateOutput(trueVal));
example.add(new Feature("noop", 1d));
Prediction<Label> prediction = new Prediction<>(factory.generateOutput(predVal), 0, example);
return prediction;
}

public static ImmutableOutputInfo<Label> mkDomain(List<Prediction<Label>> predictions) {
final MutableOutputInfo<Label> info = new LabelFactory().generateInfo();
for (Prediction<Label> p : predictions) {
info.observe(p.getExample().getOutput());
info.observe(p.getOutput()); // TODO? LN added
}
return info.generateImmutableOutputInfo();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@

package org.tribuo.multilabel.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.multilabel.MultiLabel;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.tribuo.multilabel.IndependentMultiLabelTest.singleLabelConfusionMatrix;
import static org.tribuo.multilabel.Utils.getUnknown;
import static org.tribuo.multilabel.Utils.label;
import static org.tribuo.multilabel.Utils.mkDomain;
import static org.tribuo.multilabel.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class MultiLabelConfusionMatrixTest {

Expand Down Expand Up @@ -158,6 +159,11 @@ public void testSingleLabel() {
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

@Test
Expand Down Expand Up @@ -231,6 +237,11 @@ public void testMultiLabel() {
assertEquals(1, cm.support(c));

assertEquals(5, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}


Expand Down