Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ public void testEvaluateDataFrame_OutlierDetection() throws IOException {
AucRocResult aucRocResult =
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getValue(), closeTo(0.70025, 1e-9));
assertThat(aucRocResult.getValue(), closeTo(0.70, 1e-3));
assertNotNull(aucRocResult.getCurve());
List<AucRocPoint> curve = aucRocResult.getCurve();
AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get();
Expand Down Expand Up @@ -1989,7 +1989,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {

AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getValue(), closeTo(0.6425, 1e-9));
assertThat(aucRocResult.getValue(), closeTo(0.619, 1e-3));
assertNotNull(aucRocResult.getCurve());
}
{ // Accuracy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3616,7 +3616,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
assertThat(otherClassesCount, equalTo(0L));

assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
assertThat(aucRocScore, closeTo(0.619, 1e-3));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,76 @@ protected static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, doub
assert tpPercentiles.length == fpPercentiles.length;
assert tpPercentiles.length == 99;

List<AucRocPoint> aucRocCurve = new ArrayList<>();
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
List<AucRocPoint> points = new ArrayList<>(tpPercentiles.length + fpPercentiles.length);
RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
Collections.sort(aucRocCurve);
points.addAll(tpCurve.scanPoints(fpCurve));
points.addAll(fpCurve.scanPoints(tpCurve));
Collections.sort(points);

// As our auc roc curve is composed of two sets of points coming from two
// percentiles aggregations, it is possible that we get a non-monotonic result
// because the percentiles aggregation is an approximation. In order to make
// our final curve monotonic, we collapse equal threshold points.
points = collapseEqualThresholdPoints(points);

List<AucRocPoint> aucRocCurve = new ArrayList<>(points.size() + 2);
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
aucRocCurve.addAll(points);
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
return aucRocCurve;
}

/**
 * Visible for testing
 *
 * Walks a sorted list of {@link AucRocPoint} points and merges every run of
 * adjacent points that share the same threshold into a single point holding
 * the average of that run. Runs of length one are kept as they are.
 *
 * @param points A sorted list of {@link AucRocPoint} points
 * @return a new list of points where each run of equal threshold points is replaced by its average
 */
static List<AucRocPoint> collapseEqualThresholdPoints(List<AucRocPoint> points) {
    List<AucRocPoint> result = new ArrayList<>();
    final int total = points.size();
    // Index of the first point of the run currently being accumulated.
    int runStart = 0;
    for (int i = 1; i <= total; i++) {
        // A run ends either at the end of the input or when the threshold
        // differs from the threshold of the first point of the run.
        boolean runEnds = i == total || points.get(runStart).threshold != points.get(i).threshold;
        if (runEnds) {
            result.add(calculateAveragePoint(points.subList(runStart, i)));
            runStart = i;
        }
    }
    return result;
}

/**
 * Computes the point whose tpr, fpr and threshold are the arithmetic means
 * of the corresponding values of the given points. A list with a single
 * point yields that very point.
 *
 * @param points the points to average; must not be empty
 * @return the average point
 * @throws IllegalArgumentException if {@code points} is empty
 */
private static AucRocPoint calculateAveragePoint(List<AucRocPoint> points) {
    final int count = points.size();
    if (count == 0) {
        throw new IllegalArgumentException("points must not be empty");
    }
    if (count == 1) {
        return points.get(0);
    }

    // Accumulate in list order so the floating point sums are reproducible.
    double tprSum = 0.0;
    double fprSum = 0.0;
    double thresholdSum = 0.0;
    for (int i = 0; i < count; i++) {
        AucRocPoint current = points.get(i);
        tprSum += current.tpr;
        fprSum += current.fpr;
        thresholdSum += current.threshold;
    }

    return new AucRocPoint(tprSum / count, fprSum / count, thresholdSum / count);
}

/**
* Visible for testing
*/
Expand Down Expand Up @@ -114,7 +173,10 @@ private double getRate(int index) {
}

/**
 * Returns the threshold for the percentile at the given index, lowered by
 * one ULP of that percentile and clamped so it never drops below zero.
 *
 * @param index index into the percentiles array
 * @return the adjusted threshold
 */
private double getThreshold(int index) {
    // We subtract the minimum value possible here in order to
    // ensure no point has a threshold of 1.0 as we are adding
    // that point separately so that fpr = tpr = 0.
    // The unadjusted "return percentiles[index];" that used to precede this
    // made the adjustment unreachable, so only the adjusted return is kept.
    double percentile = percentiles[index];
    return Math.max(0, percentile - Math.ulp(percentile));
}

private double interpolateRate(double threshold) {
Expand Down Expand Up @@ -160,9 +222,9 @@ public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXCon
private static final String FPR = "fpr";
private static final String THRESHOLD = "threshold";

private final double tpr;
private final double fpr;
private final double threshold;
final double tpr;
final double fpr;
final double threshold;

public AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc.AucRocPoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class AbstractAucRocTests extends ESTestCase {
Expand All @@ -22,7 +27,7 @@ public void testCalculateAucScore_GivenZeroPercentiles() {
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = zeroPercentiles();

List<AbstractAucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);

assertThat(aucRocScore, closeTo(0.5, 0.01));
Expand All @@ -32,7 +37,7 @@ public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles()
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = zeroPercentiles();

List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);

assertThat(aucRocScore, closeTo(1.0, 0.1));
Expand All @@ -42,7 +47,7 @@ public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles()
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = randomPercentiles();

List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);

assertThat(aucRocScore, closeTo(0.0, 0.1));
Expand All @@ -53,10 +58,10 @@ public void testCalculateAucScore_GivenRandomPercentiles() {
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = randomPercentiles();

List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);

List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
List<AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
Expand All @@ -80,16 +85,44 @@ public void testCalculateAucScore_GivenPrecalculated() {
fpPercentiles[i] = fpSimplified[simplifiedIndex];
}

List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);

List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
List<AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

assertThat(aucRocScore, closeTo(0.8, 0.05));
assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
}

/**
 * Collapsing an empty list of points must produce an empty list.
 */
public void testCollapseEqualThresholdPoints_GivenEmpty() {
    List<AucRocPoint> noPoints = Collections.emptyList();
    List<AucRocPoint> collapsed = AbstractAucRoc.collapseEqualThresholdPoints(noPoints);
    assertThat(collapsed, is(empty()));
}

/**
 * Points with a unique threshold must pass through untouched, while each
 * run of adjacent points sharing a threshold must be replaced by its average.
 */
public void testCollapseEqualThresholdPoints() {
    // Points whose threshold occurs exactly once in the input.
    AucRocPoint first = new AucRocPoint(0.0, 0.0, 1.0);
    AucRocPoint uniqueAtPoint1 = new AucRocPoint(0.1, 0.9, 0.1);
    AucRocPoint uniqueAtPoint3 = new AucRocPoint(0.3, 0.6, 0.3);
    AucRocPoint last = new AucRocPoint(1.0, 1.0, 0.0);

    List<AucRocPoint> curve = Arrays.asList(
        first,
        uniqueAtPoint1,
        // Two points at threshold 0.2
        new AucRocPoint(0.2, 0.8, 0.2),
        new AucRocPoint(0.1, 0.9, 0.2),
        uniqueAtPoint3,
        // Three points at threshold 0.4
        new AucRocPoint(0.5, 0.5, 0.4),
        new AucRocPoint(0.4, 0.6, 0.4),
        new AucRocPoint(0.9, 0.1, 0.4),
        last
    );

    List<AucRocPoint> collapsed = AbstractAucRoc.collapseEqualThresholdPoints(curve);

    assertThat(collapsed.size(), equalTo(6));
    assertThat(collapsed.get(0), equalTo(first));
    assertThat(collapsed.get(1), equalTo(uniqueAtPoint1));
    assertPointCloseTo(collapsed.get(2), 0.15, 0.85, 0.2);
    assertThat(collapsed.get(3), equalTo(uniqueAtPoint3));
    assertPointCloseTo(collapsed.get(4), 0.6, 0.4, 0.4);
    assertThat(collapsed.get(5), equalTo(last));
}

public static double[] zeroPercentiles() {
double[] percentiles = new double[99];
Arrays.fill(percentiles, 0.0);
Expand All @@ -104,4 +137,10 @@ public static double[] randomPercentiles() {
Arrays.sort(percentiles);
return percentiles;
}

/**
 * Asserts that each coordinate of the given point is within 0.00001 of the
 * corresponding expected value.
 */
private static void assertPointCloseTo(AucRocPoint point, double expectedTpr, double expectedFpr, double expectedThreshold) {
    final double tolerance = 0.00001;
    assertThat(point.tpr, closeTo(expectedTpr, tolerance));
    assertThat(point.fpr, closeTo(expectedFpr, tolerance));
    assertThat(point.threshold, closeTo(expectedThreshold, tolerance));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ private AucRoc.Result evaluateAucRoc(String actualField, String predictedField,

/**
 * Evaluates AUC ROC with the last argument of {@code evaluateAucRoc} set to
 * {@code false} and checks the score as well as that the returned curve is empty.
 */
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
    AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false);
    // Only one expectation for the score is kept: requiring the value to be
    // within 1e-4 of 1.0 and within 1e-3 of 0.98 at the same time can never pass.
    assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001)));
    assertThat(aucrocResult.getCurve(), hasSize(0));
}

/**
 * Evaluates AUC ROC with the last argument of {@code evaluateAucRoc} set to
 * {@code true} and checks the score as well as that the returned curve is not empty.
 */
public void testEvaluate_AucRoc_IncludeCurve() {
    AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true);
    // Only one expectation for the score is kept: requiring the value to be
    // within 1e-4 of 1.0 and within 1e-3 of 0.98 at the same time can never pass.
    assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001)));
    assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ setup:
}
}
}
- match: { outlier_detection.auc_roc.value: 0.9899 }
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
- is_false: outlier_detection.auc_roc.curve

---
Expand All @@ -226,7 +226,7 @@ setup:
}
}
}
- match: { outlier_detection.auc_roc.value: 0.9899 }
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
- is_false: outlier_detection.auc_roc.curve

---
Expand All @@ -246,7 +246,7 @@ setup:
}
}
}
- match: { outlier_detection.auc_roc.value: 0.9899 }
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
- is_true: outlier_detection.auc_roc.curve

---
Expand Down Expand Up @@ -721,7 +721,7 @@ setup:
}
}
}
- match: { classification.auc_roc.value: 0.8050111095212122 }
- match: { classification.auc_roc.value: 0.7754152761810909 }
- is_false: classification.auc_roc.curve
---
"Test classification auc_roc with default top_classes_field":
Expand All @@ -741,7 +741,7 @@ setup:
}
}
}
- match: { classification.auc_roc.value: 0.8050111095212122 }
- match: { classification.auc_roc.value: 0.7754152761810909 }
- is_false: classification.auc_roc.curve
---
"Test classification accuracy with missing predicted_field":
Expand Down