Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
Expand Down Expand Up @@ -1267,6 +1267,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.build())
.setDescription("this is a regression")
Expand Down Expand Up @@ -1301,6 +1302,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setNumTopClasses(1)
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public Classification(String dependentVariable,
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
Expand All @@ -111,6 +111,10 @@ public String getDependentVariable() {
return dependentVariable;
}

public String getPredictionFieldName() {
return predictionFieldName;
}

public int getNumTopClasses() {
return numTopClasses;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public Regression(String dependentVariable,
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}

Expand All @@ -89,6 +89,10 @@ public String getDependentVariable() {
return dependentVariable;
}

public String getPredictionFieldName() {
return predictionFieldName;
}

public double getTrainingPercent() {
return trainingPercent;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}

public void testGetPredictionFieldName() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
assertThat(classification.getPredictionFieldName(), equalTo("result"));

classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
}

public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
assertThat(classification.getNumTopClasses(), equalTo(7));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

public class RegressionTests extends AbstractSerializingTestCase<Regression> {

private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);

@Override
protected Regression doParseInstance(XContentParser parser) throws IOException {
return Regression.fromXContent(parser, false);
Expand All @@ -42,32 +44,45 @@ protected Writeable.Reader<Regression> instanceReader() {
return Regression::new;
}

public void testConstructor_GivenTrainingPercentIsNull() {
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}

public void testConstructor_GivenTrainingPercentIsBoundary() {
Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}

public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));

assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));

assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}

public void testGetPredictionFieldName() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
assertThat(regression.getPredictionFieldName(), equalTo("result"));

regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
}

public void testGetTrainingPercent() {
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
assertThat(regression.getTrainingPercent(), equalTo(50.0));

// Boundary condition: training_percent == 1.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
assertThat(regression.getTrainingPercent(), equalTo(1.0));

// Boundary condition: training_percent == 100.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
assertThat(regression.getTrainingPercent(), equalTo(100.0));

// training_percent == null, default applied
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}

public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,7 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3
}
}}
Expand Down Expand Up @@ -1809,6 +1810,7 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3,
"num_top_classes": 2
}
Expand Down Expand Up @@ -1844,6 +1846,7 @@ setup:
- match: { analysis: {
"regression":{
"dependent_variable": "foo",
"prediction_field_name": "foo_prediction",
"training_percent": 100.0
}
}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
- match: { data_frame_analytics.0.source.index: ["bwc_ml_regression_job_source"] }
- match: { data_frame_analytics.0.source.query: {"term": { "user": "Kimchy" }} }
- match: { data_frame_analytics.0.dest.index: "old_cluster_regression_job_results" }
- match: { data_frame_analytics.0.analysis: {"regression":{ "dependent_variable": "foo", "training_percent": 100.0 }} }
- match: { data_frame_analytics.0.analysis.regression.dependent_variable: "foo" }
- match: { data_frame_analytics.0.analysis.regression.training_percent: 100.0 }

---
"Get old regression job stats":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
- match: { data_frame_analytics.0.source.index: ["bwc_ml_regression_job_source"] }
- match: { data_frame_analytics.0.source.query: {"term": { "user": "Kimchy" }} }
- match: { data_frame_analytics.0.dest.index: "old_cluster_regression_job_results" }
- match: { data_frame_analytics.0.analysis: {"regression":{ "dependent_variable": "foo", "training_percent": 100.0 }} }
- match: { data_frame_analytics.0.analysis.regression.dependent_variable: "foo" }
- match: { data_frame_analytics.0.analysis.regression.training_percent: 100.0 }

---
"Get old cluster regression job stats":
Expand Down