Skip to content

Commit 96b5a75

Browse files
committed
Make num_top_classes parameter's default value equal to 2
1 parent e5af8bb commit 96b5a75

File tree

3 files changed

+64
-17
lines changed

3 files changed

+64
-17
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
6565
Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
6666
.flatMap(Set::stream)
6767
.collect(Collectors.toUnmodifiableSet());
68+
/**
69+
* As long as we only support binary classification, it makes sense to always report both classes with their probabilities.
70+
* This way the user can see whether the prediction was made with the confidence they need.
71+
*/
72+
private static final int DEFAULT_NUM_TOP_CLASSES = 2;
6873

6974
private final String dependentVariable;
7075
private final BoostedTreeParams boostedTreeParams;
@@ -86,7 +91,7 @@ public Classification(String dependentVariable,
8691
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
8792
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
8893
this.predictionFieldName = predictionFieldName;
89-
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
94+
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
9095
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
9196
}
9297

@@ -106,6 +111,10 @@ public String getDependentVariable() {
106111
return dependentVariable;
107112
}
108113

114+
public int getNumTopClasses() {
115+
return numTopClasses;
116+
}
117+
109118
public double getTrainingPercent() {
110119
return trainingPercent;
111120
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
2121

22+
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
23+
2224
@Override
2325
protected Classification doParseInstance(XContentParser parser) throws IOException {
2426
return Classification.fromXContent(parser, false);
@@ -43,32 +45,68 @@ protected Writeable.Reader<Classification> instanceReader() {
4345
return Classification::new;
4446
}
4547

46-
public void testConstructor_GivenTrainingPercentIsNull() {
47-
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null);
48-
assertThat(classification.getTrainingPercent(), equalTo(100.0));
49-
}
50-
51-
public void testConstructor_GivenTrainingPercentIsBoundary() {
52-
Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0);
53-
assertThat(classification.getTrainingPercent(), equalTo(1.0));
54-
classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0);
55-
assertThat(classification.getTrainingPercent(), equalTo(100.0));
56-
}
57-
5848
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
5949
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
60-
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999));
50+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
6151

6252
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
6353
}
6454

6555
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
6656
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
67-
() -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001));
57+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001));
6858

6959
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
7060
}
7161

62+
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
63+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
64+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0));
65+
66+
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
67+
}
68+
69+
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
70+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
71+
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0));
72+
73+
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
74+
}
75+
76+
public void testGetNumTopClasses() {
77+
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
78+
assertThat(classification.getNumTopClasses(), equalTo(7));
79+
80+
// Boundary condition: num_top_classes == 0
81+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0);
82+
assertThat(classification.getNumTopClasses(), equalTo(0));
83+
84+
// Boundary condition: num_top_classes == 1000
85+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0);
86+
assertThat(classification.getNumTopClasses(), equalTo(1000));
87+
88+
// num_top_classes == null, default applied
89+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0);
90+
assertThat(classification.getNumTopClasses(), equalTo(2));
91+
}
92+
93+
public void testGetTrainingPercent() {
94+
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
95+
assertThat(classification.getTrainingPercent(), equalTo(50.0));
96+
97+
// Boundary condition: training_percent == 1.0
98+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0);
99+
assertThat(classification.getTrainingPercent(), equalTo(1.0));
100+
101+
// Boundary condition: training_percent == 100.0
102+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0);
103+
assertThat(classification.getTrainingPercent(), equalTo(100.0));
104+
105+
// training_percent == null, default applied
106+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null);
107+
assertThat(classification.getTrainingPercent(), equalTo(100.0));
108+
}
109+
72110
public void testFieldCardinalityLimitsIsNonNull() {
73111
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
74112
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8383
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
8484
assertThat(resultsObject.containsKey("is_training"), is(true));
8585
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
86-
assertThat(resultsObject.containsKey("top_classes"), is(false));
86+
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
8787
}
8888

8989
assertProgress(jobId, 100, 100, 100, 100);
@@ -120,7 +120,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
120120
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
121121
assertThat(resultsObject.containsKey("is_training"), is(true));
122122
assertThat(resultsObject.get("is_training"), is(true));
123-
assertThat(resultsObject.containsKey("top_classes"), is(false));
123+
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
124124
}
125125

126126
assertProgress(jobId, 100, 100, 100, 100);

0 commit comments

Comments
 (0)