Skip to content

Commit e303683

Browse files
authored
Default "prediction_field_name" to (dependent_variable + "_prediction") (#48232) (#48282)
1 parent 0744cde commit e303683

File tree

6 files changed

+53
-17
lines changed

6 files changed

+53
-17
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@
126126
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
127127
import org.elasticsearch.client.ml.dataframe.QueryConfig;
128128
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
129-
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
130129
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
130+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
131131
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
132132
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
133133
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -1297,6 +1297,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
12971297
.setIndex("put-test-dest-index")
12981298
.build())
12991299
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
1300+
.setPredictionFieldName("my_dependent_variable_prediction")
13001301
.setTrainingPercent(80.0)
13011302
.build())
13021303
.setDescription("this is a regression")
@@ -1331,6 +1332,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
13311332
.setIndex("put-test-dest-index")
13321333
.build())
13331334
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
1335+
.setPredictionFieldName("my_dependent_variable_prediction")
13341336
.setTrainingPercent(80.0)
13351337
.setNumTopClasses(1)
13361338
.build())

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public Classification(String dependentVariable,
9292
}
9393
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
9494
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
95-
this.predictionFieldName = predictionFieldName;
95+
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
9696
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
9797
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
9898
}
@@ -113,6 +113,10 @@ public String getDependentVariable() {
113113
return dependentVariable;
114114
}
115115

116+
public String getPredictionFieldName() {
117+
return predictionFieldName;
118+
}
119+
116120
public int getNumTopClasses() {
117121
return numTopClasses;
118122
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public Regression(String dependentVariable,
7070
}
7171
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
7272
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
73-
this.predictionFieldName = predictionFieldName;
73+
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
7474
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
7575
}
7676

@@ -89,6 +89,10 @@ public String getDependentVariable() {
8989
return dependentVariable;
9090
}
9191

92+
public String getPredictionFieldName() {
93+
return predictionFieldName;
94+
}
95+
9296
public double getTrainingPercent() {
9397
return trainingPercent;
9498
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
7373
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
7474
}
7575

76+
public void testGetPredictionFieldName() {
77+
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
78+
assertThat(classification.getPredictionFieldName(), equalTo("result"));
79+
80+
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
81+
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
82+
}
83+
7684
public void testGetNumTopClasses() {
7785
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
7886
assertThat(classification.getNumTopClasses(), equalTo(7));

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
2121

22+
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
23+
2224
@Override
2325
protected Regression doParseInstance(XContentParser parser) throws IOException {
2426
return Regression.fromXContent(parser, false);
@@ -42,32 +44,45 @@ protected Writeable.Reader<Regression> instanceReader() {
4244
return Regression::new;
4345
}
4446

45-
public void testConstructor_GivenTrainingPercentIsNull() {
46-
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null);
47-
assertThat(regression.getTrainingPercent(), equalTo(100.0));
48-
}
49-
50-
public void testConstructor_GivenTrainingPercentIsBoundary() {
51-
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0);
52-
assertThat(regression.getTrainingPercent(), equalTo(1.0));
53-
regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0);
54-
assertThat(regression.getTrainingPercent(), equalTo(100.0));
55-
}
56-
5747
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
5848
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
59-
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
49+
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
6050

6151
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
6252
}
6353

6454
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
6555
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
66-
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
56+
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));
6757

6858
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
6959
}
7060

61+
public void testGetPredictionFieldName() {
62+
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
63+
assertThat(regression.getPredictionFieldName(), equalTo("result"));
64+
65+
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
66+
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
67+
}
68+
69+
public void testGetTrainingPercent() {
70+
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
71+
assertThat(regression.getTrainingPercent(), equalTo(50.0));
72+
73+
// Boundary condition: training_percent == 1.0
74+
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
75+
assertThat(regression.getTrainingPercent(), equalTo(1.0));
76+
77+
// Boundary condition: training_percent == 100.0
78+
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
79+
assertThat(regression.getTrainingPercent(), equalTo(100.0));
80+
81+
// training_percent == null, default applied
82+
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
83+
assertThat(regression.getTrainingPercent(), equalTo(100.0));
84+
}
85+
7186
public void testFieldCardinalityLimitsIsNonNull() {
7287
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
7388
}

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,7 @@ setup:
14701470
"eta": 0.5,
14711471
"maximum_number_trees": 400,
14721472
"feature_bag_fraction": 0.3,
1473+
"prediction_field_name": "foo_prediction",
14731474
"training_percent": 60.3
14741475
}
14751476
}}
@@ -1809,6 +1810,7 @@ setup:
18091810
"eta": 0.5,
18101811
"maximum_number_trees": 400,
18111812
"feature_bag_fraction": 0.3,
1813+
"prediction_field_name": "foo_prediction",
18121814
"training_percent": 60.3,
18131815
"num_top_classes": 2
18141816
}
@@ -1844,6 +1846,7 @@ setup:
18441846
- match: { analysis: {
18451847
"regression":{
18461848
"dependent_variable": "foo",
1849+
"prediction_field_name": "foo_prediction",
18471850
"training_percent": 100.0
18481851
}
18491852
}}

0 commit comments

Comments
 (0)