Skip to content

Commit b5f7350

Browse files
[7.x][ML] Ensure auc_roc curve is monotonic (#70628) (#70707)
As we collect points for our auc roc curve from two different percentiles aggregations, the result may contain points with equal threshold that are not monotonic. This is because the percentiles aggregation is an approximation. This commit ensures the fina auc roc curve we calculate is monotonic by collapsing points of equal threshold into a single point that is the average of the equal threshold points it represents. Backport of #70628
1 parent 81da10f commit b5f7350

File tree

6 files changed

+128
-27
lines changed

6 files changed

+128
-27
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,7 +1854,7 @@ public void testEvaluateDataFrame_OutlierDetection() throws IOException {
18541854
AucRocResult aucRocResult =
18551855
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
18561856
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
1857-
assertThat(aucRocResult.getValue(), closeTo(0.70025, 1e-9));
1857+
assertThat(aucRocResult.getValue(), closeTo(0.70, 1e-3));
18581858
assertNotNull(aucRocResult.getCurve());
18591859
List<AucRocPoint> curve = aucRocResult.getCurve();
18601860
AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get();
@@ -1989,7 +1989,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
19891989

19901990
AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
19911991
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
1992-
assertThat(aucRocResult.getValue(), closeTo(0.6425, 1e-9));
1992+
assertThat(aucRocResult.getValue(), closeTo(0.619, 1e-3));
19931993
assertNotNull(aucRocResult.getCurve());
19941994
}
19951995
{ // Accuracy

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3616,7 +3616,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
36163616
assertThat(otherClassesCount, equalTo(0L));
36173617

36183618
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
3619-
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
3619+
assertThat(aucRocScore, closeTo(0.619, 1e-3));
36203620
}
36213621
}
36223622

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,76 @@ protected static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, doub
7474
assert tpPercentiles.length == fpPercentiles.length;
7575
assert tpPercentiles.length == 99;
7676

77-
List<AucRocPoint> aucRocCurve = new ArrayList<>();
78-
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
79-
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
77+
List<AucRocPoint> points = new ArrayList<>(tpPercentiles.length + fpPercentiles.length);
8078
RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
8179
RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
82-
aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
83-
aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
84-
Collections.sort(aucRocCurve);
80+
points.addAll(tpCurve.scanPoints(fpCurve));
81+
points.addAll(fpCurve.scanPoints(tpCurve));
82+
Collections.sort(points);
83+
84+
// As our auc roc curve is comprised by two sets of points coming from two
85+
// percentiles aggregations, it is possible that we get a non-monotonic result
86+
// because the percentiles aggregation is an approximation. In order to make
87+
// our final curve monotonic, we collapse equal threshold points.
88+
points = collapseEqualThresholdPoints(points);
89+
90+
List<AucRocPoint> aucRocCurve = new ArrayList<>(points.size() + 2);
91+
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
92+
aucRocCurve.addAll(points);
93+
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
8594
return aucRocCurve;
8695
}
8796

97+
/**
98+
* Visible for testing
99+
*
100+
* Expects a sorted list of {@link AucRocPoint} points.
101+
* Collapses points with equal threshold by replacing them
102+
* with a single point that is the average.
103+
*
104+
* @param points A sorted list of {@link AucRocPoint} points
105+
* @return a new list of points where equal threshold points have been collapsed into their average
106+
*/
107+
static List<AucRocPoint> collapseEqualThresholdPoints(List<AucRocPoint> points) {
108+
List<AucRocPoint> collapsed = new ArrayList<>();
109+
List<AucRocPoint> equalThresholdPoints = new ArrayList<>();
110+
for (AucRocPoint point : points) {
111+
if (equalThresholdPoints.isEmpty() == false && equalThresholdPoints.get(0).threshold != point.threshold) {
112+
collapsed.add(calculateAveragePoint(equalThresholdPoints));
113+
equalThresholdPoints = new ArrayList<>();
114+
}
115+
equalThresholdPoints.add(point);
116+
}
117+
118+
if (equalThresholdPoints.isEmpty() == false) {
119+
collapsed.add(calculateAveragePoint(equalThresholdPoints));
120+
}
121+
122+
return collapsed;
123+
}
124+
125+
private static AucRocPoint calculateAveragePoint(List<AucRocPoint> points) {
126+
if (points.isEmpty()) {
127+
throw new IllegalArgumentException("points must not be empty");
128+
}
129+
130+
if (points.size() == 1) {
131+
return points.get(0);
132+
}
133+
134+
double avgTpr = 0.0;
135+
double avgFpr = 0.0;
136+
double avgThreshold = 0.0;
137+
for (AucRocPoint sameThresholdPoint : points) {
138+
avgTpr += sameThresholdPoint.tpr;
139+
avgFpr += sameThresholdPoint.fpr;
140+
avgThreshold += sameThresholdPoint.threshold;
141+
}
142+
143+
int n = points.size();
144+
return new AucRocPoint(avgTpr / n, avgFpr / n, avgThreshold / n);
145+
}
146+
88147
/**
89148
* Visible for testing
90149
*/
@@ -114,7 +173,10 @@ private double getRate(int index) {
114173
}
115174

116175
private double getThreshold(int index) {
117-
return percentiles[index];
176+
// We subtract the minimum value possible here in order to
177+
// ensure no point has a threshold of 1.0 as we are adding
178+
// that point separately so that fpr = tpr = 0.
179+
return Math.max(0, percentiles[index] - Math.ulp(percentiles[index]));
118180
}
119181

120182
private double interpolateRate(double threshold) {
@@ -160,9 +222,9 @@ public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXCon
160222
private static final String FPR = "fpr";
161223
private static final String THRESHOLD = "threshold";
162224

163-
private final double tpr;
164-
private final double fpr;
165-
private final double threshold;
225+
final double tpr;
226+
final double fpr;
227+
final double threshold;
166228

167229
public AucRocPoint(double tpr, double fpr, double threshold) {
168230
this.tpr = tpr;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88

99
import org.elasticsearch.test.ESTestCase;
1010
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
11+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc.AucRocPoint;
1112

1213
import java.util.Arrays;
14+
import java.util.Collections;
1315
import java.util.List;
1416

1517
import static org.hamcrest.Matchers.closeTo;
18+
import static org.hamcrest.Matchers.empty;
19+
import static org.hamcrest.Matchers.equalTo;
1620
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
21+
import static org.hamcrest.Matchers.is;
1722
import static org.hamcrest.Matchers.lessThanOrEqualTo;
1823

1924
public class AbstractAucRocTests extends ESTestCase {
@@ -22,7 +27,7 @@ public void testCalculateAucScore_GivenZeroPercentiles() {
2227
double[] tpPercentiles = zeroPercentiles();
2328
double[] fpPercentiles = zeroPercentiles();
2429

25-
List<AbstractAucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
30+
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
2631
double aucRocScore = AucRoc.calculateAucScore(curve);
2732

2833
assertThat(aucRocScore, closeTo(0.5, 0.01));
@@ -32,7 +37,7 @@ public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles()
3237
double[] tpPercentiles = randomPercentiles();
3338
double[] fpPercentiles = zeroPercentiles();
3439

35-
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
40+
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
3641
double aucRocScore = AucRoc.calculateAucScore(curve);
3742

3843
assertThat(aucRocScore, closeTo(1.0, 0.1));
@@ -42,7 +47,7 @@ public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles()
4247
double[] tpPercentiles = zeroPercentiles();
4348
double[] fpPercentiles = randomPercentiles();
4449

45-
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
50+
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
4651
double aucRocScore = AucRoc.calculateAucScore(curve);
4752

4853
assertThat(aucRocScore, closeTo(0.0, 0.1));
@@ -53,10 +58,10 @@ public void testCalculateAucScore_GivenRandomPercentiles() {
5358
double[] tpPercentiles = randomPercentiles();
5459
double[] fpPercentiles = randomPercentiles();
5560

56-
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
61+
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
5762
double aucRocScore = AucRoc.calculateAucScore(curve);
5863

59-
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
64+
List<AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
6065
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
6166

6267
assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
@@ -80,16 +85,44 @@ public void testCalculateAucScore_GivenPrecalculated() {
8085
fpPercentiles[i] = fpSimplified[simplifiedIndex];
8186
}
8287

83-
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
88+
List<AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
8489
double aucRocScore = AucRoc.calculateAucScore(curve);
8590

86-
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
91+
List<AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
8792
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
8893

8994
assertThat(aucRocScore, closeTo(0.8, 0.05));
9095
assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
9196
}
9297

98+
public void testCollapseEqualThresholdPoints_GivenEmpty() {
99+
assertThat(AbstractAucRoc.collapseEqualThresholdPoints(Collections.emptyList()), is(empty()));
100+
}
101+
102+
public void testCollapseEqualThresholdPoints() {
103+
List<AucRocPoint> curve = Arrays.asList(
104+
new AucRocPoint(0.0, 0.0, 1.0),
105+
new AucRocPoint(0.1, 0.9, 0.1),
106+
new AucRocPoint(0.2, 0.8, 0.2),
107+
new AucRocPoint(0.1, 0.9, 0.2),
108+
new AucRocPoint(0.3, 0.6, 0.3),
109+
new AucRocPoint(0.5, 0.5, 0.4),
110+
new AucRocPoint(0.4, 0.6, 0.4),
111+
new AucRocPoint(0.9, 0.1, 0.4),
112+
new AucRocPoint(1.0, 1.0, 0.0)
113+
);
114+
115+
List<AucRocPoint> collapsed = AbstractAucRoc.collapseEqualThresholdPoints(curve);
116+
117+
assertThat(collapsed.size(), equalTo(6));
118+
assertThat(collapsed.get(0), equalTo(curve.get(0)));
119+
assertThat(collapsed.get(1), equalTo(curve.get(1)));
120+
assertPointCloseTo(collapsed.get(2), 0.15, 0.85, 0.2);
121+
assertThat(collapsed.get(3), equalTo(curve.get(4)));
122+
assertPointCloseTo(collapsed.get(4), 0.6, 0.4, 0.4);
123+
assertThat(collapsed.get(5), equalTo(curve.get(8)));
124+
}
125+
93126
public static double[] zeroPercentiles() {
94127
double[] percentiles = new double[99];
95128
Arrays.fill(percentiles, 0.0);
@@ -104,4 +137,10 @@ public static double[] randomPercentiles() {
104137
Arrays.sort(percentiles);
105138
return percentiles;
106139
}
140+
141+
private static void assertPointCloseTo(AucRocPoint point, double expectedTpr, double expectedFpr, double expectedThreshold) {
142+
assertThat(point.tpr, closeTo(expectedTpr, 0.00001));
143+
assertThat(point.fpr, closeTo(expectedFpr, 0.00001));
144+
assertThat(point.threshold, closeTo(expectedThreshold, 0.00001));
145+
}
107146
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ private AucRoc.Result evaluateAucRoc(String actualField, String predictedField,
9898

9999
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
100100
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false);
101-
assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001)));
101+
assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001)));
102102
assertThat(aucrocResult.getCurve(), hasSize(0));
103103
}
104104

105105
public void testEvaluate_AucRoc_IncludeCurve() {
106106
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true);
107-
assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001)));
107+
assertThat(aucrocResult.getValue(), is(closeTo(0.98, 0.001)));
108108
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
109109
}
110110

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ setup:
206206
}
207207
}
208208
}
209-
- match: { outlier_detection.auc_roc.value: 0.9899 }
209+
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
210210
- is_false: outlier_detection.auc_roc.curve
211211

212212
---
@@ -226,7 +226,7 @@ setup:
226226
}
227227
}
228228
}
229-
- match: { outlier_detection.auc_roc.value: 0.9899 }
229+
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
230230
- is_false: outlier_detection.auc_roc.curve
231231

232232
---
@@ -246,7 +246,7 @@ setup:
246246
}
247247
}
248248
}
249-
- match: { outlier_detection.auc_roc.value: 0.9899 }
249+
- match: { outlier_detection.auc_roc.value: 0.9919403846153846 }
250250
- is_true: outlier_detection.auc_roc.curve
251251

252252
---
@@ -721,7 +721,7 @@ setup:
721721
}
722722
}
723723
}
724-
- match: { classification.auc_roc.value: 0.8050111095212122 }
724+
- match: { classification.auc_roc.value: 0.7754152761810909 }
725725
- is_false: classification.auc_roc.curve
726726
---
727727
"Test classification auc_roc with default top_classes_field":
@@ -741,7 +741,7 @@ setup:
741741
}
742742
}
743743
}
744-
- match: { classification.auc_roc.value: 0.8050111095212122 }
744+
- match: { classification.auc_roc.value: 0.7754152761810909 }
745745
- is_false: classification.auc_roc.curve
746746
---
747747
"Test classification accuracy with missing predicted_field":

0 commit comments

Comments
 (0)