Skip to content

Commit efcc4d1

Browse files
committed
Implement new analysis type: classification (elastic#46537)
1 parent 31a5e1c commit efcc4d1

File tree

27 files changed

+1833
-427
lines changed

27 files changed

+1833
-427
lines changed
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe;
20+
21+
import org.elasticsearch.common.Nullable;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.Strings;
24+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.Objects;
30+
31+
public class Classification implements DataFrameAnalysis {
32+
33+
public static Classification fromXContent(XContentParser parser) {
34+
return PARSER.apply(parser, null);
35+
}
36+
37+
public static Builder builder(String dependentVariable) {
38+
return new Builder(dependentVariable);
39+
}
40+
41+
public static final ParseField NAME = new ParseField("classification");
42+
43+
static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
44+
static final ParseField LAMBDA = new ParseField("lambda");
45+
static final ParseField GAMMA = new ParseField("gamma");
46+
static final ParseField ETA = new ParseField("eta");
47+
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
48+
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
49+
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
50+
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
51+
52+
private static final ConstructingObjectParser<Classification, Void> PARSER =
53+
new ConstructingObjectParser<>(
54+
NAME.getPreferredName(),
55+
true,
56+
a -> new Classification(
57+
(String) a[0],
58+
(Double) a[1],
59+
(Double) a[2],
60+
(Double) a[3],
61+
(Integer) a[4],
62+
(Double) a[5],
63+
(String) a[6],
64+
(Double) a[7]));
65+
66+
static {
67+
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
68+
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
69+
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
70+
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
71+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
72+
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
73+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
74+
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
75+
}
76+
77+
private final String dependentVariable;
78+
private final Double lambda;
79+
private final Double gamma;
80+
private final Double eta;
81+
private final Integer maximumNumberTrees;
82+
private final Double featureBagFraction;
83+
private final String predictionFieldName;
84+
private final Double trainingPercent;
85+
86+
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
87+
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
88+
@Nullable Double trainingPercent) {
89+
this.dependentVariable = Objects.requireNonNull(dependentVariable);
90+
this.lambda = lambda;
91+
this.gamma = gamma;
92+
this.eta = eta;
93+
this.maximumNumberTrees = maximumNumberTrees;
94+
this.featureBagFraction = featureBagFraction;
95+
this.predictionFieldName = predictionFieldName;
96+
this.trainingPercent = trainingPercent;
97+
}
98+
99+
@Override
100+
public String getName() {
101+
return NAME.getPreferredName();
102+
}
103+
104+
public String getDependentVariable() {
105+
return dependentVariable;
106+
}
107+
108+
public Double getLambda() {
109+
return lambda;
110+
}
111+
112+
public Double getGamma() {
113+
return gamma;
114+
}
115+
116+
public Double getEta() {
117+
return eta;
118+
}
119+
120+
public Integer getMaximumNumberTrees() {
121+
return maximumNumberTrees;
122+
}
123+
124+
public Double getFeatureBagFraction() {
125+
return featureBagFraction;
126+
}
127+
128+
public String getPredictionFieldName() {
129+
return predictionFieldName;
130+
}
131+
132+
public Double getTrainingPercent() {
133+
return trainingPercent;
134+
}
135+
136+
@Override
137+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
138+
builder.startObject();
139+
builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
140+
if (lambda != null) {
141+
builder.field(LAMBDA.getPreferredName(), lambda);
142+
}
143+
if (gamma != null) {
144+
builder.field(GAMMA.getPreferredName(), gamma);
145+
}
146+
if (eta != null) {
147+
builder.field(ETA.getPreferredName(), eta);
148+
}
149+
if (maximumNumberTrees != null) {
150+
builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
151+
}
152+
if (featureBagFraction != null) {
153+
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
154+
}
155+
if (predictionFieldName != null) {
156+
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
157+
}
158+
if (trainingPercent != null) {
159+
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
160+
}
161+
builder.endObject();
162+
return builder;
163+
}
164+
165+
@Override
166+
public int hashCode() {
167+
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
168+
trainingPercent);
169+
}
170+
171+
@Override
172+
public boolean equals(Object o) {
173+
if (this == o) return true;
174+
if (o == null || getClass() != o.getClass()) return false;
175+
Classification that = (Classification) o;
176+
return Objects.equals(dependentVariable, that.dependentVariable)
177+
&& Objects.equals(lambda, that.lambda)
178+
&& Objects.equals(gamma, that.gamma)
179+
&& Objects.equals(eta, that.eta)
180+
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
181+
&& Objects.equals(featureBagFraction, that.featureBagFraction)
182+
&& Objects.equals(predictionFieldName, that.predictionFieldName)
183+
&& Objects.equals(trainingPercent, that.trainingPercent);
184+
}
185+
186+
@Override
187+
public String toString() {
188+
return Strings.toString(this);
189+
}
190+
191+
public static class Builder {
192+
private String dependentVariable;
193+
private Double lambda;
194+
private Double gamma;
195+
private Double eta;
196+
private Integer maximumNumberTrees;
197+
private Double featureBagFraction;
198+
private String predictionFieldName;
199+
private Double trainingPercent;
200+
201+
private Builder(String dependentVariable) {
202+
this.dependentVariable = Objects.requireNonNull(dependentVariable);
203+
}
204+
205+
public Builder setLambda(Double lambda) {
206+
this.lambda = lambda;
207+
return this;
208+
}
209+
210+
public Builder setGamma(Double gamma) {
211+
this.gamma = gamma;
212+
return this;
213+
}
214+
215+
public Builder setEta(Double eta) {
216+
this.eta = eta;
217+
return this;
218+
}
219+
220+
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
221+
this.maximumNumberTrees = maximumNumberTrees;
222+
return this;
223+
}
224+
225+
public Builder setFeatureBagFraction(Double featureBagFraction) {
226+
this.featureBagFraction = featureBagFraction;
227+
return this;
228+
}
229+
230+
public Builder setPredictionFieldName(String predictionFieldName) {
231+
this.predictionFieldName = predictionFieldName;
232+
return this;
233+
}
234+
235+
public Builder setTrainingPercent(Double trainingPercent) {
236+
this.trainingPercent = trainingPercent;
237+
return this;
238+
}
239+
240+
public Classification build() {
241+
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
242+
trainingPercent);
243+
}
244+
}
245+
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
3636
new NamedXContentRegistry.Entry(
3737
DataFrameAnalysis.class,
3838
Regression.NAME,
39-
(p, c) -> Regression.fromXContent(p)));
39+
(p, c) -> Regression.fromXContent(p)),
40+
new NamedXContentRegistry.Entry(
41+
DataFrameAnalysis.class,
42+
Classification.NAME,
43+
(p, c) -> Classification.fromXContent(p)));
4044
}
4145
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,19 @@ public static Builder builder(String dependentVariable) {
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
5151

52-
private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
53-
a -> new Regression(
54-
(String) a[0],
55-
(Double) a[1],
56-
(Double) a[2],
57-
(Double) a[3],
58-
(Integer) a[4],
59-
(Double) a[5],
60-
(String) a[6],
61-
(Double) a[7]));
52+
private static final ConstructingObjectParser<Regression, Void> PARSER =
53+
new ConstructingObjectParser<>(
54+
NAME.getPreferredName(),
55+
true,
56+
a -> new Regression(
57+
(String) a[0],
58+
(Double) a[1],
59+
(Double) a[2],
60+
(Double) a[3],
61+
(Integer) a[4],
62+
(Double) a[5],
63+
(String) a[6],
64+
(Double) a[7]));
6265

6366
static {
6467
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,41 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
13151315
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
13161316
}
13171317

1318+
public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception {
1319+
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
1320+
String configId = "test-put-df-analytics-classification";
1321+
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
1322+
.setId(configId)
1323+
.setSource(DataFrameAnalyticsSource.builder()
1324+
.setIndex("put-test-source-index")
1325+
.build())
1326+
.setDest(DataFrameAnalyticsDest.builder()
1327+
.setIndex("put-test-dest-index")
1328+
.build())
1329+
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
1330+
.builder("my_dependent_variable")
1331+
.setTrainingPercent(80.0)
1332+
.build())
1333+
.setDescription("this is a classification")
1334+
.build();
1335+
1336+
createIndex("put-test-source-index", defaultMappingForTest());
1337+
1338+
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute(
1339+
new PutDataFrameAnalyticsRequest(config),
1340+
machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync);
1341+
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig();
1342+
assertThat(createdConfig.getId(), equalTo(config.getId()));
1343+
assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex()));
1344+
assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
1345+
assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex()));
1346+
assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value
1347+
assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis()));
1348+
assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
1349+
assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
1350+
assertThat(createdConfig.getDescription(), equalTo("this is a classification"));
1351+
}
1352+
13181353
public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
13191354
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
13201355
String configId = "get-test-config";

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ public void testDefaultNamedXContents() {
684684

685685
public void testProvidedNamedXContents() {
686686
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
687-
assertEquals(44, namedXContents.size());
687+
assertEquals(48, namedXContents.size());
688688
Map<Class<?>, Integer> categories = new HashMap<>();
689689
List<String> names = new ArrayList<>();
690690
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -718,9 +718,10 @@ public void testProvidedNamedXContents() {
718718
assertTrue(names.contains(ShrinkAction.NAME));
719719
assertTrue(names.contains(FreezeAction.NAME));
720720
assertTrue(names.contains(SetPriorityAction.NAME));
721-
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
721+
assertEquals(Integer.valueOf(3), categories.get(DataFrameAnalysis.class));
722722
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
723723
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
724+
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
724725
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
725726
assertTrue(names.contains(TimeSyncConfig.NAME));
726727
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));

0 commit comments

Comments
 (0)