diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index e680f98ebebc0..73c16c71e4a9f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1854,7 +1854,7 @@ public void testEvaluateDataFrame_OutlierDetection() throws IOException { AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocResult.getValue(), closeTo(0.70025, 1e-9)); + assertThat(aucRocResult.getValue(), closeTo(0.70, 1e-3)); assertNotNull(aucRocResult.getCurve()); List curve = aucRocResult.getCurve(); AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); @@ -1989,7 +1989,7 @@ public void testEvaluateDataFrame_Classification() throws IOException { AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocResult.getValue(), closeTo(0.6425, 1e-9)); + assertThat(aucRocResult.getValue(), closeTo(0.619, 1e-3)); assertNotNull(aucRocResult.getCurve()); } { // Accuracy diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index d1d1fed3d6f0d..4a623f751ceb3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3616,7 +3616,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { assertThat(otherClassesCount, equalTo(0L)); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocScore, closeTo(0.6425, 1e-9)); + assertThat(aucRocScore, closeTo(0.619, 1e-3)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java index 38dd90000678d..75e7d2a1777df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java @@ -74,17 +74,76 @@ protected static List buildAucRocCurve(double[] tpPercentiles, doub assert tpPercentiles.length == fpPercentiles.length; assert tpPercentiles.length == 99; - List aucRocCurve = new ArrayList<>(); - aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); - aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); + List points = new ArrayList<>(tpPercentiles.length + fpPercentiles.length); RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); - aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); - aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); - Collections.sort(aucRocCurve); + points.addAll(tpCurve.scanPoints(fpCurve)); + points.addAll(fpCurve.scanPoints(tpCurve)); + Collections.sort(points); + + // As our auc roc curve is comprised by two sets of points coming from two + // percentiles aggregations, it is possible that we get a non-monotonic result + // because the percentiles aggregation is an approximation. In order to make + // our final curve monotonic, we collapse equal threshold points. + points = collapseEqualThresholdPoints(points); + + List aucRocCurve = new ArrayList<>(points.size() + 2); + aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); + aucRocCurve.addAll(points); + aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); return aucRocCurve; } + /** + * Visible for testing + * + * Expects a sorted list of {@link AucRocPoint} points. + * Collapses points with equal threshold by replacing them + * with a single point that is the average. + * + * @param points A sorted list of {@link AucRocPoint} points + * @return a new list of points where equal threshold points have been collapsed into their average + */ + static List collapseEqualThresholdPoints(List points) { + List collapsed = new ArrayList<>(); + List equalThresholdPoints = new ArrayList<>(); + for (AucRocPoint point : points) { + if (equalThresholdPoints.isEmpty() == false && equalThresholdPoints.get(0).threshold != point.threshold) { + collapsed.add(calculateAveragePoint(equalThresholdPoints)); + equalThresholdPoints = new ArrayList<>(); + } + equalThresholdPoints.add(point); + } + + if (equalThresholdPoints.isEmpty() == false) { + collapsed.add(calculateAveragePoint(equalThresholdPoints)); + } + + return collapsed; + } + + private static AucRocPoint calculateAveragePoint(List points) { + if (points.isEmpty()) { + throw new IllegalArgumentException("points must not be empty"); + } + + if (points.size() == 1) { + return points.get(0); + } + + double avgTpr = 0.0; + double avgFpr = 0.0; + double avgThreshold = 0.0; + for (AucRocPoint sameThresholdPoint : points) { + avgTpr += sameThresholdPoint.tpr; + avgFpr += sameThresholdPoint.fpr; + avgThreshold += sameThresholdPoint.threshold; + } + + int n = points.size(); + return new AucRocPoint(avgTpr / n, avgFpr / n, avgThreshold / n); + } + /** * Visible for testing */ @@ -114,7 +173,10 @@ private double getRate(int index) { } private double getThreshold(int index) { - return percentiles[index]; + // We subtract the minimum value possible here in order to + // ensure no point has a threshold of 1.0 as we are adding + // that point separately so that fpr = tpr = 0. + return Math.max(0, percentiles[index] - Math.ulp(percentiles[index])); } private double interpolateRate(double threshold) { @@ -160,9 +222,9 @@ public static final class AucRocPoint implements Comparable, ToXCon private static final String FPR = "fpr"; private static final String THRESHOLD = "threshold"; - private final double tpr; - private final double fpr; - private final double threshold; + final double tpr; + final double fpr; + final double threshold; public AucRocPoint(double tpr, double fpr, double threshold) { this.tpr = tpr; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java index 45dc5fa587287..d7d66930784a7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java @@ -8,12 +8,17 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc.AucRocPoint; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; public class AbstractAucRocTests extends ESTestCase { @@ -22,7 +27,7 @@ public void testCalculateAucScore_GivenZeroPercentiles() { double[] tpPercentiles = zeroPercentiles(); double[] fpPercentiles = zeroPercentiles(); - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); assertThat(aucRocScore, closeTo(0.5, 0.01)); @@ -32,7 +37,7 @@ public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() double[] tpPercentiles = randomPercentiles(); double[] fpPercentiles = zeroPercentiles(); - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); assertThat(aucRocScore, closeTo(1.0, 0.1)); @@ -42,7 +47,7 @@ public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() double[] tpPercentiles = zeroPercentiles(); double[] fpPercentiles = randomPercentiles(); - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); assertThat(aucRocScore, closeTo(0.0, 0.1)); @@ -53,10 +58,10 @@ public void testCalculateAucScore_GivenRandomPercentiles() { double[] tpPercentiles = randomPercentiles(); double[] fpPercentiles = randomPercentiles(); - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); - List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); assertThat(aucRocScore, greaterThanOrEqualTo(0.0)); @@ -80,16 +85,44 @@ public void testCalculateAucScore_GivenPrecalculated() { fpPercentiles[i] = fpSimplified[simplifiedIndex]; } - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); - List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); assertThat(aucRocScore, closeTo(0.8, 0.05)); assertThat(inverseAucRocScore, closeTo(0.2, 0.05)); } + public void testCollapseEqualThresholdPoints_GivenEmpty() { + assertThat(AbstractAucRoc.collapseEqualThresholdPoints(Collections.emptyList()), is(empty())); + } + + public void testCollapseEqualThresholdPoints() { + List curve = Arrays.asList( + new AucRocPoint(0.0, 0.0, 1.0), + new AucRocPoint(0.1, 0.9, 0.1), + new AucRocPoint(0.2, 0.8, 0.2), + new AucRocPoint(0.1, 0.9, 0.2), + new AucRocPoint(0.3, 0.6, 0.3), + new AucRocPoint(0.5, 0.5, 0.4), + new AucRocPoint(0.4, 0.6, 0.4), + new AucRocPoint(0.9, 0.1, 0.4), + new AucRocPoint(1.0, 1.0, 0.0) + ); + + List collapsed = AbstractAucRoc.collapseEqualThresholdPoints(curve); + + assertThat(collapsed.size(), equalTo(6)); + assertThat(collapsed.get(0), equalTo(curve.get(0))); + assertThat(collapsed.get(1), equalTo(curve.get(1))); + assertPointCloseTo(collapsed.get(2), 0.15, 0.85, 0.2); + assertThat(collapsed.get(3), equalTo(curve.get(4))); + assertPointCloseTo(collapsed.get(4), 0.6, 0.4, 0.4); + assertThat(collapsed.get(5), equalTo(curve.get(8))); + } + public static double[] zeroPercentiles() { double[] percentiles = new double[99]; Arrays.fill(percentiles, 0.0); @@ -104,4 +137,10 @@ public static double[] randomPercentiles() { Arrays.sort(percentiles); return percentiles; } + + private static void assertPointCloseTo(AucRocPoint point, double expectedTpr, double expectedFpr, double expectedThreshold) { + assertThat(point.tpr, closeTo(expectedTpr, 0.00001)); + assertThat(point.fpr, closeTo(expectedFpr, 0.00001)); + assertThat(point.threshold, closeTo(expectedThreshold, 0.00001)); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java index e573944d5fb0e..a57a57b70cf90 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java @@ -98,13 +98,13 @@ private AucRoc.Result evaluateAucRoc(String actualField, String predictedField, public void testEvaluate_AucRoc_DoNotIncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false); - assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001))); assertThat(aucrocResult.getCurve(), hasSize(0)); } public void testEvaluate_AucRoc_IncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true); - assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001))); assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 83fe922c02492..d83f3255daeb8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -206,7 +206,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.value: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9919403846153846 } - is_false: outlier_detection.auc_roc.curve --- @@ -226,7 +226,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.value: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9919403846153846 } - is_false: outlier_detection.auc_roc.curve --- @@ -246,7 +246,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.value: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9919403846153846 } - is_true: outlier_detection.auc_roc.curve --- @@ -721,7 +721,7 @@ setup: } } } - - match: { classification.auc_roc.value: 0.8050111095212122 } + - match: { classification.auc_roc.value: 0.7754152761810909 } - is_false: classification.auc_roc.curve --- "Test classification auc_roc with default top_classes_field": @@ -741,7 +741,7 @@ setup: } } } - - match: { classification.auc_roc.value: 0.8050111095212122 } + - match: { classification.auc_roc.value: 0.7754152761810909 } - is_false: classification.auc_roc.curve --- "Test classification accuracy with missing predicted_field":